From 1b42ce8cf8fa7a7ce9589fa547201b420b706688 Mon Sep 17 00:00:00 2001 From: ronger Date: Tue, 21 Mar 2023 14:22:14 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E9=9B=86=E6=88=90=20open=20ai=20?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 40 ++- .../rymcu/forest/openai/OpenAiController.java | 53 ++++ .../service/AuthenticationInterceptor.java | 29 ++ .../forest/openai/service/OpenAiService.java | 272 ++++++++++++++++++ 4 files changed, 392 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/rymcu/forest/openai/OpenAiController.java create mode 100644 src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java create mode 100644 src/main/java/com/rymcu/forest/openai/service/OpenAiService.java diff --git a/pom.xml b/pom.xml index 2a75010..e863555 100644 --- a/pom.xml +++ b/pom.xml @@ -324,6 +324,42 @@ + + + com.theokanning.openai-gpt3-java + client + 0.11.0 + + + com.squareup.retrofit2 + retrofit + 2.9.0 + + + com.squareup.retrofit2 + adapter-rxjava2 + 2.9.0 + + + com.squareup.retrofit2 + converter-jackson + 2.9.0 + + + com.fasterxml.jackson.core + jackson-databind + + + + + + org.springframework.boot + spring-boot-starter-actuator + + + io.micrometer + micrometer-registry-prometheus + @@ -393,7 +429,7 @@ dev - true + false @@ -402,7 +438,7 @@ prod - false + true diff --git a/src/main/java/com/rymcu/forest/openai/OpenAiController.java b/src/main/java/com/rymcu/forest/openai/OpenAiController.java new file mode 100644 index 0000000..9b96ef1 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/OpenAiController.java @@ -0,0 +1,53 @@ +package com.rymcu.forest.openai; + +import com.alibaba.fastjson.JSONObject; +import com.rymcu.forest.core.result.GlobalResult; +import com.rymcu.forest.core.result.GlobalResultGenerator; +import com.rymcu.forest.openai.service.OpenAiService; +import com.theokanning.openai.completion.chat.ChatCompletionChoice; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatMessage; +import org.apache.commons.lang.StringUtils; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +/** + * Created on 2023/2/15 10:04. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.openai + */ +@RestController +@RequestMapping("/api/v1/openai") +public class OpenAiController { + + @Value("${openai.token}") + private String token; + + @PostMapping("/chat") + public GlobalResult> chat(@RequestBody JSONObject jsonObject) { + String message = jsonObject.getString("message"); + if (StringUtils.isBlank(message)) { + throw new IllegalArgumentException("参数异常!"); + } + ChatMessage chatMessage = new ChatMessage("user", message); + List list = new ArrayList(); + list.add(chatMessage); + OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180)); + ChatCompletionRequest completionRequest = ChatCompletionRequest.builder() + .model("gpt-3.5-turbo") + .messages(list) + .build(); + List choices = service.createChatCompletion(completionRequest).getChoices(); + return GlobalResultGenerator.genSuccessResult(choices); + } + +} diff --git a/src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java b/src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java new file mode 100644 index 0000000..30afd43 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/AuthenticationInterceptor.java @@ -0,0 +1,29 @@ +package com.rymcu.forest.openai.service; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; + +/** + * OkHttp Interceptor that adds an authorization token header + * @author ronger + */ +public class AuthenticationInterceptor implements Interceptor { + + private final String token; + + AuthenticationInterceptor(String token) { + this.token = token; + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request() + .newBuilder() + .header("Authorization", "Bearer " + token) + .build(); + return chain.proceed(request); + } +} diff --git a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java new file mode 100644 index 0000000..2781b74 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java @@ -0,0 +1,272 @@ +package com.rymcu.forest.openai.service; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.rymcu.forest.util.SpringContextHolder; +import com.theokanning.openai.DeleteResult; +import com.theokanning.openai.OpenAiApi; +import com.theokanning.openai.OpenAiError; +import com.theokanning.openai.OpenAiHttpException; +import com.theokanning.openai.completion.CompletionRequest; +import com.theokanning.openai.completion.CompletionResult; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatCompletionResult; +import com.theokanning.openai.edit.EditRequest; +import com.theokanning.openai.edit.EditResult; +import com.theokanning.openai.embedding.EmbeddingRequest; +import com.theokanning.openai.embedding.EmbeddingResult; +import com.theokanning.openai.file.File; +import com.theokanning.openai.finetune.FineTuneEvent; +import com.theokanning.openai.finetune.FineTuneRequest; +import com.theokanning.openai.finetune.FineTuneResult; +import com.theokanning.openai.image.CreateImageEditRequest; +import com.theokanning.openai.image.CreateImageRequest; +import com.theokanning.openai.image.CreateImageVariationRequest; +import com.theokanning.openai.image.ImageResult; +import com.theokanning.openai.model.Model; +import com.theokanning.openai.moderation.ModerationRequest; +import com.theokanning.openai.moderation.ModerationResult; +import io.reactivex.Single; +import okhttp3.*; +import org.springframework.core.env.Environment; +import retrofit2.HttpException; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +public class OpenAiService { + + private static final String BASE_URL = "https://api.openai.com/"; + private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10); + private static final ObjectMapper errorMapper = defaultObjectMapper(); + + private final OpenAiApi api; + private static final Environment env = SpringContextHolder.getBean(Environment.class); + + + /** + * Creates a new OpenAiService that wraps OpenAiApi + * + * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + */ + public OpenAiService(final String token) { + this(token, DEFAULT_TIMEOUT); + } + + /** + * Creates a new OpenAiService that wraps OpenAiApi + * + * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + * @param timeout http read timeout, Duration.ZERO means no timeout + */ + public OpenAiService(final String token, final Duration timeout) { + this(buildApi(token, timeout)); + } + + /** + * Creates a new OpenAiService that wraps OpenAiApi. + * Use this if you need more customization. + * + * @param api OpenAiApi instance to use for all methods + */ + public OpenAiService(final OpenAiApi api) { + this.api = api; + } + + public List listModels() { + return execute(api.listModels()).data; + } + + public Model getModel(String modelId) { + return execute(api.getModel(modelId)); + } + + public CompletionResult createCompletion(CompletionRequest request) { + return execute(api.createCompletion(request)); + } + + public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { + return execute(api.createChatCompletion(request)); + } + + public EditResult createEdit(EditRequest request) { + return execute(api.createEdit(request)); + } + + public EmbeddingResult createEmbeddings(EmbeddingRequest request) { + return execute(api.createEmbeddings(request)); + } + + public List listFiles() { + return execute(api.listFiles()).data; + } + + public File uploadFile(String purpose, String filepath) { + java.io.File file = new java.io.File(filepath); + RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose); + RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file); + MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody); + + return execute(api.uploadFile(purposeBody, body)); + } + + public DeleteResult deleteFile(String fileId) { + return execute(api.deleteFile(fileId)); + } + + public File retrieveFile(String fileId) { + return execute(api.retrieveFile(fileId)); + } + + public FineTuneResult createFineTune(FineTuneRequest request) { + return execute(api.createFineTune(request)); + } + + public CompletionResult createFineTuneCompletion(CompletionRequest request) { + return execute(api.createFineTuneCompletion(request)); + } + + public List listFineTunes() { + return execute(api.listFineTunes()).data; + } + + public FineTuneResult retrieveFineTune(String fineTuneId) { + return execute(api.retrieveFineTune(fineTuneId)); + } + + public FineTuneResult cancelFineTune(String fineTuneId) { + return execute(api.cancelFineTune(fineTuneId)); + } + + public List listFineTuneEvents(String fineTuneId) { + return execute(api.listFineTuneEvents(fineTuneId)).data; + } + + public DeleteResult deleteFineTune(String fineTuneId) { + return execute(api.deleteFineTune(fineTuneId)); + } + + public ImageResult createImage(CreateImageRequest request) { + return execute(api.createImage(request)); + } + + public ImageResult createImageEdit(CreateImageEditRequest request, String imagePath, String maskPath) { + java.io.File image = new java.io.File(imagePath); + java.io.File mask = null; + if (maskPath != null) { + mask = new java.io.File(maskPath); + } + return createImageEdit(request, image, mask); + } + + public ImageResult createImageEdit(CreateImageEditRequest request, java.io.File image, java.io.File mask) { + RequestBody imageBody = RequestBody.create(MediaType.parse("image"), image); + + MultipartBody.Builder builder = new MultipartBody.Builder() + .setType(MediaType.get("multipart/form-data")) + .addFormDataPart("prompt", request.getPrompt()) + .addFormDataPart("size", request.getSize()) + .addFormDataPart("response_format", request.getResponseFormat()) + .addFormDataPart("image", "image", imageBody); + + if (request.getN() != null) { + builder.addFormDataPart("n", request.getN().toString()); + } + + if (mask != null) { + RequestBody maskBody = RequestBody.create(MediaType.parse("image"), mask); + builder.addFormDataPart("mask", "mask", maskBody); + } + + return execute(api.createImageEdit(builder.build())); + } + + public ImageResult createImageVariation(CreateImageVariationRequest request, String imagePath) { + java.io.File image = new java.io.File(imagePath); + return createImageVariation(request, image); + } + + public ImageResult createImageVariation(CreateImageVariationRequest request, java.io.File image) { + RequestBody imageBody = RequestBody.create(MediaType.parse("image"), image); + + MultipartBody.Builder builder = new MultipartBody.Builder() + .setType(MediaType.get("multipart/form-data")) + .addFormDataPart("size", request.getSize()) + .addFormDataPart("response_format", request.getResponseFormat()) + .addFormDataPart("image", "image", imageBody); + + if (request.getN() != null) { + builder.addFormDataPart("n", request.getN().toString()); + } + + return execute(api.createImageVariation(builder.build())); + } + + public ModerationResult createModeration(ModerationRequest request) { + return execute(api.createModeration(request)); + } + + /** + * Calls the Open AI api, returns the response, and parses error messages if the request fails + */ + public static T execute(Single apiCall) { + try { + return apiCall.blockingGet(); + } catch (HttpException e) { + try { + if (e.response() == null || e.response().errorBody() == null) { + throw e; + } + String errorBody = e.response().errorBody().string(); + + OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class); + throw new OpenAiHttpException(error, e, e.code()); + } catch (IOException ex) { + // couldn't parse OpenAI error + throw e; + } + } + } + + public static OpenAiApi buildApi(String token, Duration timeout) { + Objects.requireNonNull(token, "OpenAI token required"); + ObjectMapper mapper = defaultObjectMapper(); + OkHttpClient client = defaultClient(token, timeout); + Retrofit retrofit = defaultRetrofit(client, mapper); + + return retrofit.create(OpenAiApi.class); + } + + public static ObjectMapper defaultObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); + return mapper; + } + + public static OkHttpClient defaultClient(String token, Duration timeout) { + return new OkHttpClient.Builder() + .addInterceptor(new AuthenticationInterceptor(token)) + .connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS)) + .readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS) + .build(); + } + + public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) { + return new Retrofit.Builder() + .baseUrl(env.getProperty("openai.url", BASE_URL)) + .client(client) + .addConverterFactory(JacksonConverterFactory.create(mapper)) + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .build(); + } +}