diff --git a/pom.xml b/pom.xml
index 2a75010..e863555 100644
--- a/pom.xml
+++ b/pom.xml
@@ -324,6 +324,42 @@
+
+
+ com.theokanning.openai-gpt3-java
+ client
+ 0.11.0
+
+
+ com.squareup.retrofit2
+ retrofit
+ 2.9.0
+
+
+ com.squareup.retrofit2
+ adapter-rxjava2
+ 2.9.0
+
+
+ com.squareup.retrofit2
+ converter-jackson
+ 2.9.0
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter-actuator
+
+
+ io.micrometer
+ micrometer-registry-prometheus
+
@@ -393,7 +429,7 @@
dev
- true
+ false
@@ -402,7 +438,7 @@
prod
- false
+ true
diff --git a/src/main/java/com/rymcu/forest/openai/OpenAiController.java b/src/main/java/com/rymcu/forest/openai/OpenAiController.java
new file mode 100644
index 0000000..9b96ef1
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/OpenAiController.java
@@ -0,0 +1,53 @@
+package com.rymcu.forest.openai;
+
+import com.alibaba.fastjson.JSONObject;
+import com.rymcu.forest.core.result.GlobalResult;
+import com.rymcu.forest.core.result.GlobalResultGenerator;
+import com.rymcu.forest.openai.service.OpenAiService;
+import com.theokanning.openai.completion.chat.ChatCompletionChoice;
+import com.theokanning.openai.completion.chat.ChatCompletionRequest;
+import com.theokanning.openai.completion.chat.ChatMessage;
+import org.apache.commons.lang.StringUtils;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Created on 2023/2/15 10:04.
+ *
+ * @author ronger
+ * @email ronger-x@outlook.com
+ * @desc : com.rymcu.forest.openai
+ */
+@RestController
+@RequestMapping("/api/v1/openai")
+public class OpenAiController {
+
+ @Value("${openai.token}")
+ private String token;
+
+ @PostMapping("/chat")
+ public GlobalResult> chat(@RequestBody JSONObject jsonObject) {
+ String message = jsonObject.getString("message");
+ if (StringUtils.isBlank(message)) {
+ throw new IllegalArgumentException("参数异常!");
+ }
+ ChatMessage chatMessage = new ChatMessage("user", message);
+ List list = new ArrayList();
+ list.add(chatMessage);
+ OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
+ ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
+ .model("gpt-3.5-turbo")
+ .messages(list)
+ .build();
+ List choices = service.createChatCompletion(completionRequest).getChoices();
+ return GlobalResultGenerator.genSuccessResult(choices);
+ }
+
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java b/src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java
new file mode 100644
index 0000000..30afd43
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java
@@ -0,0 +1,29 @@
+package com.rymcu.forest.openai.service;
+
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import okhttp3.Response;
+
+import java.io.IOException;
+
+/**
+ * OkHttp Interceptor that adds an authorization token header
+ * @author ronger
+ */
+public class AuthenticationInterceptor implements Interceptor {
+
+ private final String token;
+
+ AuthenticationInterceptor(String token) {
+ this.token = token;
+ }
+
+ @Override
+ public Response intercept(Chain chain) throws IOException {
+ Request request = chain.request()
+ .newBuilder()
+ .header("Authorization", "Bearer " + token)
+ .build();
+ return chain.proceed(request);
+ }
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
new file mode 100644
index 0000000..2781b74
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
@@ -0,0 +1,272 @@
+package com.rymcu.forest.openai.service;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.PropertyNamingStrategy;
+import com.rymcu.forest.util.SpringContextHolder;
+import com.theokanning.openai.DeleteResult;
+import com.theokanning.openai.OpenAiApi;
+import com.theokanning.openai.OpenAiError;
+import com.theokanning.openai.OpenAiHttpException;
+import com.theokanning.openai.completion.CompletionRequest;
+import com.theokanning.openai.completion.CompletionResult;
+import com.theokanning.openai.completion.chat.ChatCompletionRequest;
+import com.theokanning.openai.completion.chat.ChatCompletionResult;
+import com.theokanning.openai.edit.EditRequest;
+import com.theokanning.openai.edit.EditResult;
+import com.theokanning.openai.embedding.EmbeddingRequest;
+import com.theokanning.openai.embedding.EmbeddingResult;
+import com.theokanning.openai.file.File;
+import com.theokanning.openai.finetune.FineTuneEvent;
+import com.theokanning.openai.finetune.FineTuneRequest;
+import com.theokanning.openai.finetune.FineTuneResult;
+import com.theokanning.openai.image.CreateImageEditRequest;
+import com.theokanning.openai.image.CreateImageRequest;
+import com.theokanning.openai.image.CreateImageVariationRequest;
+import com.theokanning.openai.image.ImageResult;
+import com.theokanning.openai.model.Model;
+import com.theokanning.openai.moderation.ModerationRequest;
+import com.theokanning.openai.moderation.ModerationResult;
+import io.reactivex.Single;
+import okhttp3.*;
+import org.springframework.core.env.Environment;
+import retrofit2.HttpException;
+import retrofit2.Retrofit;
+import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
+import retrofit2.converter.jackson.JacksonConverterFactory;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+
+public class OpenAiService {
+
+ private static final String BASE_URL = "https://api.openai.com/";
+ private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
+ private static final ObjectMapper errorMapper = defaultObjectMapper();
+
+ private final OpenAiApi api;
+ private static final Environment env = SpringContextHolder.getBean(Environment.class);
+
+
+ /**
+ * Creates a new OpenAiService that wraps OpenAiApi
+ *
+ * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+ */
+ public OpenAiService(final String token) {
+ this(token, DEFAULT_TIMEOUT);
+ }
+
+ /**
+ * Creates a new OpenAiService that wraps OpenAiApi
+ *
+ * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+ * @param timeout http read timeout, Duration.ZERO means no timeout
+ */
+ public OpenAiService(final String token, final Duration timeout) {
+ this(buildApi(token, timeout));
+ }
+
+ /**
+ * Creates a new OpenAiService that wraps OpenAiApi.
+ * Use this if you need more customization.
+ *
+ * @param api OpenAiApi instance to use for all methods
+ */
+ public OpenAiService(final OpenAiApi api) {
+ this.api = api;
+ }
+
+ public List listModels() {
+ return execute(api.listModels()).data;
+ }
+
+ public Model getModel(String modelId) {
+ return execute(api.getModel(modelId));
+ }
+
+ public CompletionResult createCompletion(CompletionRequest request) {
+ return execute(api.createCompletion(request));
+ }
+
+ public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
+ return execute(api.createChatCompletion(request));
+ }
+
+ public EditResult createEdit(EditRequest request) {
+ return execute(api.createEdit(request));
+ }
+
+ public EmbeddingResult createEmbeddings(EmbeddingRequest request) {
+ return execute(api.createEmbeddings(request));
+ }
+
+ public List listFiles() {
+ return execute(api.listFiles()).data;
+ }
+
+ public File uploadFile(String purpose, String filepath) {
+ java.io.File file = new java.io.File(filepath);
+ RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose);
+ RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file);
+ MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody);
+
+ return execute(api.uploadFile(purposeBody, body));
+ }
+
+ public DeleteResult deleteFile(String fileId) {
+ return execute(api.deleteFile(fileId));
+ }
+
+ public File retrieveFile(String fileId) {
+ return execute(api.retrieveFile(fileId));
+ }
+
+ public FineTuneResult createFineTune(FineTuneRequest request) {
+ return execute(api.createFineTune(request));
+ }
+
+ public CompletionResult createFineTuneCompletion(CompletionRequest request) {
+ return execute(api.createFineTuneCompletion(request));
+ }
+
+ public List listFineTunes() {
+ return execute(api.listFineTunes()).data;
+ }
+
+ public FineTuneResult retrieveFineTune(String fineTuneId) {
+ return execute(api.retrieveFineTune(fineTuneId));
+ }
+
+ public FineTuneResult cancelFineTune(String fineTuneId) {
+ return execute(api.cancelFineTune(fineTuneId));
+ }
+
+ public List listFineTuneEvents(String fineTuneId) {
+ return execute(api.listFineTuneEvents(fineTuneId)).data;
+ }
+
+ public DeleteResult deleteFineTune(String fineTuneId) {
+ return execute(api.deleteFineTune(fineTuneId));
+ }
+
+ public ImageResult createImage(CreateImageRequest request) {
+ return execute(api.createImage(request));
+ }
+
+ public ImageResult createImageEdit(CreateImageEditRequest request, String imagePath, String maskPath) {
+ java.io.File image = new java.io.File(imagePath);
+ java.io.File mask = null;
+ if (maskPath != null) {
+ mask = new java.io.File(maskPath);
+ }
+ return createImageEdit(request, image, mask);
+ }
+
+ public ImageResult createImageEdit(CreateImageEditRequest request, java.io.File image, java.io.File mask) {
+ RequestBody imageBody = RequestBody.create(MediaType.parse("image"), image);
+
+ MultipartBody.Builder builder = new MultipartBody.Builder()
+ .setType(MediaType.get("multipart/form-data"))
+ .addFormDataPart("prompt", request.getPrompt())
+ .addFormDataPart("size", request.getSize())
+ .addFormDataPart("response_format", request.getResponseFormat())
+ .addFormDataPart("image", "image", imageBody);
+
+ if (request.getN() != null) {
+ builder.addFormDataPart("n", request.getN().toString());
+ }
+
+ if (mask != null) {
+ RequestBody maskBody = RequestBody.create(MediaType.parse("image"), mask);
+ builder.addFormDataPart("mask", "mask", maskBody);
+ }
+
+ return execute(api.createImageEdit(builder.build()));
+ }
+
+ public ImageResult createImageVariation(CreateImageVariationRequest request, String imagePath) {
+ java.io.File image = new java.io.File(imagePath);
+ return createImageVariation(request, image);
+ }
+
+ public ImageResult createImageVariation(CreateImageVariationRequest request, java.io.File image) {
+ RequestBody imageBody = RequestBody.create(MediaType.parse("image"), image);
+
+ MultipartBody.Builder builder = new MultipartBody.Builder()
+ .setType(MediaType.get("multipart/form-data"))
+ .addFormDataPart("size", request.getSize())
+ .addFormDataPart("response_format", request.getResponseFormat())
+ .addFormDataPart("image", "image", imageBody);
+
+ if (request.getN() != null) {
+ builder.addFormDataPart("n", request.getN().toString());
+ }
+
+ return execute(api.createImageVariation(builder.build()));
+ }
+
+ public ModerationResult createModeration(ModerationRequest request) {
+ return execute(api.createModeration(request));
+ }
+
+ /**
+ * Calls the Open AI api, returns the response, and parses error messages if the request fails
+ */
+ public static T execute(Single apiCall) {
+ try {
+ return apiCall.blockingGet();
+ } catch (HttpException e) {
+ try {
+ if (e.response() == null || e.response().errorBody() == null) {
+ throw e;
+ }
+ String errorBody = e.response().errorBody().string();
+
+ OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
+ throw new OpenAiHttpException(error, e, e.code());
+ } catch (IOException ex) {
+ // couldn't parse OpenAI error
+ throw e;
+ }
+ }
+ }
+
+ public static OpenAiApi buildApi(String token, Duration timeout) {
+ Objects.requireNonNull(token, "OpenAI token required");
+ ObjectMapper mapper = defaultObjectMapper();
+ OkHttpClient client = defaultClient(token, timeout);
+ Retrofit retrofit = defaultRetrofit(client, mapper);
+
+ return retrofit.create(OpenAiApi.class);
+ }
+
+ public static ObjectMapper defaultObjectMapper() {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+ mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
+ mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
+ return mapper;
+ }
+
+ public static OkHttpClient defaultClient(String token, Duration timeout) {
+ return new OkHttpClient.Builder()
+ .addInterceptor(new AuthenticationInterceptor(token))
+ .connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
+ .readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS)
+ .build();
+ }
+
+ public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) {
+ return new Retrofit.Builder()
+ .baseUrl(env.getProperty("openai.url", BASE_URL))
+ .client(client)
+ .addConverterFactory(JacksonConverterFactory.create(mapper))
+ .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
+ .build();
+ }
+}