集成 open ai 接口

This commit is contained in:
ronger 2023-03-21 14:22:14 +08:00
parent 15a4dc1804
commit 1b42ce8cf8
4 changed files with 392 additions and 2 deletions

40
pom.xml
View File

@ -324,6 +324,42 @@
</exclusion> </exclusion>
</exclusions> </exclusions>
</dependency> </dependency>
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>client</artifactId>
<version>0.11.0</version>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit</artifactId>
<version>2.9.0</version>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>adapter-rxjava2</artifactId>
<version>2.9.0</version>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-jackson</artifactId>
<version>2.9.0</version>
<exclusions>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-registry-prometheus</artifactId>
</dependency>
</dependencies> </dependencies>
<build> <build>
@ -393,7 +429,7 @@
<profileActive>dev</profileActive> <profileActive>dev</profileActive>
</properties> </properties>
<activation> <activation>
<activeByDefault>true</activeByDefault> <activeByDefault>false</activeByDefault>
</activation> </activation>
</profile> </profile>
<profile> <profile>
@ -402,7 +438,7 @@
<profileActive>prod</profileActive> <profileActive>prod</profileActive>
</properties> </properties>
<activation> <activation>
<activeByDefault>false</activeByDefault> <activeByDefault>true</activeByDefault>
</activation> </activation>
</profile> </profile>
</profiles> </profiles>

View File

@ -0,0 +1,53 @@
package com.rymcu.forest.openai;
import com.alibaba.fastjson.JSONObject;
import com.rymcu.forest.core.result.GlobalResult;
import com.rymcu.forest.core.result.GlobalResultGenerator;
import com.rymcu.forest.openai.service.OpenAiService;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
/**
* Created on 2023/2/15 10:04.
*
* @author ronger
* @email ronger-x@outlook.com
* @desc : com.rymcu.forest.openai
*/
@RestController
@RequestMapping("/api/v1/openai")
public class OpenAiController {
@Value("${openai.token}")
private String token;
@PostMapping("/chat")
public GlobalResult<List<ChatCompletionChoice>> chat(@RequestBody JSONObject jsonObject) {
String message = jsonObject.getString("message");
if (StringUtils.isBlank(message)) {
throw new IllegalArgumentException("参数异常!");
}
ChatMessage chatMessage = new ChatMessage("user", message);
List<ChatMessage> list = new ArrayList();
list.add(chatMessage);
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
.model("gpt-3.5-turbo")
.messages(list)
.build();
List<ChatCompletionChoice> choices = service.createChatCompletion(completionRequest).getChoices();
return GlobalResultGenerator.genSuccessResult(choices);
}
}

View File

@ -0,0 +1,29 @@
package com.rymcu.forest.openai.service;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import java.io.IOException;
/**
* OkHttp Interceptor that adds an authorization token header
* @author ronger
*/
public class AuthenticationInterceptor implements Interceptor {
private final String token;
AuthenticationInterceptor(String token) {
this.token = token;
}
@Override
public Response intercept(Chain chain) throws IOException {
Request request = chain.request()
.newBuilder()
.header("Authorization", "Bearer " + token)
.build();
return chain.proceed(request);
}
}

View File

@ -0,0 +1,272 @@
package com.rymcu.forest.openai.service;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.rymcu.forest.util.SpringContextHolder;
import com.theokanning.openai.DeleteResult;
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.edit.EditRequest;
import com.theokanning.openai.edit.EditResult;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneResult;
import com.theokanning.openai.image.CreateImageEditRequest;
import com.theokanning.openai.image.CreateImageRequest;
import com.theokanning.openai.image.CreateImageVariationRequest;
import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;
import io.reactivex.Single;
import okhttp3.*;
import org.springframework.core.env.Environment;
import retrofit2.HttpException;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
public class OpenAiService {
private static final String BASE_URL = "https://api.openai.com/";
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
private static final ObjectMapper errorMapper = defaultObjectMapper();
private final OpenAiApi api;
private static final Environment env = SpringContextHolder.getBean(Environment.class);
/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
*/
public OpenAiService(final String token) {
this(token, DEFAULT_TIMEOUT);
}
/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
this(buildApi(token, timeout));
}
/**
* Creates a new OpenAiService that wraps OpenAiApi.
* Use this if you need more customization.
*
* @param api OpenAiApi instance to use for all methods
*/
public OpenAiService(final OpenAiApi api) {
this.api = api;
}
public List<Model> listModels() {
return execute(api.listModels()).data;
}
public Model getModel(String modelId) {
return execute(api.getModel(modelId));
}
public CompletionResult createCompletion(CompletionRequest request) {
return execute(api.createCompletion(request));
}
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request));
}
public EditResult createEdit(EditRequest request) {
return execute(api.createEdit(request));
}
public EmbeddingResult createEmbeddings(EmbeddingRequest request) {
return execute(api.createEmbeddings(request));
}
public List<File> listFiles() {
return execute(api.listFiles()).data;
}
public File uploadFile(String purpose, String filepath) {
java.io.File file = new java.io.File(filepath);
RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose);
RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file);
MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody);
return execute(api.uploadFile(purposeBody, body));
}
public DeleteResult deleteFile(String fileId) {
return execute(api.deleteFile(fileId));
}
public File retrieveFile(String fileId) {
return execute(api.retrieveFile(fileId));
}
public FineTuneResult createFineTune(FineTuneRequest request) {
return execute(api.createFineTune(request));
}
public CompletionResult createFineTuneCompletion(CompletionRequest request) {
return execute(api.createFineTuneCompletion(request));
}
public List<FineTuneResult> listFineTunes() {
return execute(api.listFineTunes()).data;
}
public FineTuneResult retrieveFineTune(String fineTuneId) {
return execute(api.retrieveFineTune(fineTuneId));
}
public FineTuneResult cancelFineTune(String fineTuneId) {
return execute(api.cancelFineTune(fineTuneId));
}
public List<FineTuneEvent> listFineTuneEvents(String fineTuneId) {
return execute(api.listFineTuneEvents(fineTuneId)).data;
}
public DeleteResult deleteFineTune(String fineTuneId) {
return execute(api.deleteFineTune(fineTuneId));
}
public ImageResult createImage(CreateImageRequest request) {
return execute(api.createImage(request));
}
public ImageResult createImageEdit(CreateImageEditRequest request, String imagePath, String maskPath) {
java.io.File image = new java.io.File(imagePath);
java.io.File mask = null;
if (maskPath != null) {
mask = new java.io.File(maskPath);
}
return createImageEdit(request, image, mask);
}
public ImageResult createImageEdit(CreateImageEditRequest request, java.io.File image, java.io.File mask) {
RequestBody imageBody = RequestBody.create(MediaType.parse("image"), image);
MultipartBody.Builder builder = new MultipartBody.Builder()
.setType(MediaType.get("multipart/form-data"))
.addFormDataPart("prompt", request.getPrompt())
.addFormDataPart("size", request.getSize())
.addFormDataPart("response_format", request.getResponseFormat())
.addFormDataPart("image", "image", imageBody);
if (request.getN() != null) {
builder.addFormDataPart("n", request.getN().toString());
}
if (mask != null) {
RequestBody maskBody = RequestBody.create(MediaType.parse("image"), mask);
builder.addFormDataPart("mask", "mask", maskBody);
}
return execute(api.createImageEdit(builder.build()));
}
public ImageResult createImageVariation(CreateImageVariationRequest request, String imagePath) {
java.io.File image = new java.io.File(imagePath);
return createImageVariation(request, image);
}
public ImageResult createImageVariation(CreateImageVariationRequest request, java.io.File image) {
RequestBody imageBody = RequestBody.create(MediaType.parse("image"), image);
MultipartBody.Builder builder = new MultipartBody.Builder()
.setType(MediaType.get("multipart/form-data"))
.addFormDataPart("size", request.getSize())
.addFormDataPart("response_format", request.getResponseFormat())
.addFormDataPart("image", "image", imageBody);
if (request.getN() != null) {
builder.addFormDataPart("n", request.getN().toString());
}
return execute(api.createImageVariation(builder.build()));
}
public ModerationResult createModeration(ModerationRequest request) {
return execute(api.createModeration(request));
}
/**
* Calls the Open AI api, returns the response, and parses error messages if the request fails
*/
public static <T> T execute(Single<T> apiCall) {
try {
return apiCall.blockingGet();
} catch (HttpException e) {
try {
if (e.response() == null || e.response().errorBody() == null) {
throw e;
}
String errorBody = e.response().errorBody().string();
OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
throw new OpenAiHttpException(error, e, e.code());
} catch (IOException ex) {
// couldn't parse OpenAI error
throw e;
}
}
}
public static OpenAiApi buildApi(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);
return retrofit.create(OpenAiApi.class);
}
public static ObjectMapper defaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
return mapper;
}
public static OkHttpClient defaultClient(String token, Duration timeout) {
return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
.readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS)
.build();
}
public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) {
return new Retrofit.Builder()
.baseUrl(env.getProperty("openai.url", BASE_URL))
.client(client)
.addConverterFactory(JacksonConverterFactory.create(mapper))
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.build();
}
}