diff --git a/pom.xml b/pom.xml index e863555..3f4d58f 100644 --- a/pom.xml +++ b/pom.xml @@ -129,7 +129,7 @@ net.minidev json-smart - 2.4.8 + 2.4.9 @@ -328,7 +328,7 @@ com.theokanning.openai-gpt3-java client - 0.11.0 + 0.12.0 com.squareup.retrofit2 @@ -355,6 +355,12 @@ org.springframework.boot spring-boot-starter-actuator + + + com.fasterxml.jackson.core + jackson-databind + + io.micrometer @@ -429,7 +435,7 @@ dev - false + true @@ -438,7 +444,7 @@ prod - true + false diff --git a/src/main/java/com/rymcu/forest/config/ShiroConfig.java b/src/main/java/com/rymcu/forest/config/ShiroConfig.java index 4b2a47e..bd6563a 100644 --- a/src/main/java/com/rymcu/forest/config/ShiroConfig.java +++ b/src/main/java/com/rymcu/forest/config/ShiroConfig.java @@ -45,6 +45,7 @@ public class ShiroConfig { filterChainDefinitionMap.put("/api/v1/auth/login/**", "anon"); filterChainDefinitionMap.put("/api/v1/auth/logout/**", "anon"); filterChainDefinitionMap.put("/api/v1/auth/refresh-token/**", "anon"); + filterChainDefinitionMap.put("/api/v1/sse/**", "anon"); filterChainDefinitionMap.put("/**", "jwt"); shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap); diff --git a/src/main/java/com/rymcu/forest/openai/OpenAiController.java b/src/main/java/com/rymcu/forest/openai/OpenAiController.java index 9b96ef1..b07ae55 100644 --- a/src/main/java/com/rymcu/forest/openai/OpenAiController.java +++ b/src/main/java/com/rymcu/forest/openai/OpenAiController.java @@ -3,10 +3,15 @@ package com.rymcu.forest.openai; import com.alibaba.fastjson.JSONObject; import com.rymcu.forest.core.result.GlobalResult; import com.rymcu.forest.core.result.GlobalResultGenerator; +import com.rymcu.forest.entity.User; import com.rymcu.forest.openai.service.OpenAiService; +import com.rymcu.forest.openai.service.SseService; +import com.rymcu.forest.util.UserUtils; import com.theokanning.openai.completion.chat.ChatCompletionChoice; +import com.theokanning.openai.completion.chat.ChatCompletionChunk; import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatMessage; +import io.reactivex.Flowable; import org.apache.commons.lang.StringUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annotation.PostMapping; @@ -14,6 +19,7 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import javax.annotation.Resource; import java.time.Duration; import java.util.ArrayList; import java.util.List; @@ -28,26 +34,38 @@ import java.util.List; @RestController @RequestMapping("/api/v1/openai") public class OpenAiController { + @Resource + private SseService sseService; @Value("${openai.token}") private String token; @PostMapping("/chat") - public GlobalResult> chat(@RequestBody JSONObject jsonObject) { + public GlobalResult chat(@RequestBody JSONObject jsonObject) { String message = jsonObject.getString("message"); if (StringUtils.isBlank(message)) { throw new IllegalArgumentException("参数异常!"); } + User user = UserUtils.getCurrentUserByToken(); ChatMessage chatMessage = new ChatMessage("user", message); - List list = new ArrayList(); + List list = new ArrayList<>(4); list.add(chatMessage); OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180)); ChatCompletionRequest completionRequest = ChatCompletionRequest.builder() .model("gpt-3.5-turbo") + .stream(true) .messages(list) .build(); - List choices = service.createChatCompletion(completionRequest).getChoices(); - return GlobalResultGenerator.genSuccessResult(choices); + service.streamChatCompletion(completionRequest).doOnError(Throwable::printStackTrace) + .blockingForEach(chunk -> { + String text = chunk.getChoices().get(0).getMessage().getContent(); + if (text == null) { + return; + } + System.out.print(text); + sseService.send(user.getIdUser(), text); + }); + service.shutdownExecutor(); + return GlobalResultGenerator.genSuccessResult(); } - } diff --git a/src/main/java/com/rymcu/forest/openai/SeeController.java b/src/main/java/com/rymcu/forest/openai/SeeController.java new file mode 100644 index 0000000..948ed20 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/SeeController.java @@ -0,0 +1,39 @@ +package com.rymcu.forest.openai; + +import com.rymcu.forest.entity.User; +import com.rymcu.forest.openai.service.SseService; +import com.rymcu.forest.util.UserUtils; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import javax.annotation.Resource; + +/** + * Created on 2023/5/26 11:26. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.openai + */ +@RestController +@RequestMapping("/api/v1/sse") +public class SeeController { + + @Resource + private SseService sseService; + + @GetMapping(value = "/subscribe/{idUser}", produces = {MediaType.TEXT_EVENT_STREAM_VALUE}) + public SseEmitter subscribe(@PathVariable Long idUser) { + return sseService.connect(idUser); + } + + @GetMapping(value = "/close/{idUser}") + public void close(@PathVariable Long idUser) { + sseService.close(idUser); + } + +} diff --git a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java index 2781b74..ad468b8 100644 --- a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java +++ b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java @@ -9,8 +9,10 @@ import com.theokanning.openai.DeleteResult; import com.theokanning.openai.OpenAiApi; import com.theokanning.openai.OpenAiError; import com.theokanning.openai.OpenAiHttpException; +import com.theokanning.openai.completion.CompletionChunk; import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.completion.CompletionResult; +import com.theokanning.openai.completion.chat.ChatCompletionChunk; import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.edit.EditRequest; @@ -28,11 +30,15 @@ import com.theokanning.openai.image.ImageResult; import com.theokanning.openai.model.Model; import com.theokanning.openai.moderation.ModerationRequest; import com.theokanning.openai.moderation.ModerationResult; + +import io.reactivex.BackpressureStrategy; +import io.reactivex.Flowable; import io.reactivex.Single; import okhttp3.*; import org.springframework.core.env.Environment; import retrofit2.HttpException; import retrofit2.Retrofit; +import retrofit2.Call; import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; import retrofit2.converter.jackson.JacksonConverterFactory; @@ -40,18 +46,19 @@ import java.io.IOException; import java.time.Duration; import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; public class OpenAiService { private static final String BASE_URL = "https://api.openai.com/"; private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10); - private static final ObjectMapper errorMapper = defaultObjectMapper(); + private static final ObjectMapper mapper = defaultObjectMapper(); private final OpenAiApi api; + private final ExecutorService executorService; private static final Environment env = SpringContextHolder.getBean(Environment.class); - /** * Creates a new OpenAiService that wraps OpenAiApi * @@ -68,17 +75,39 @@ public class OpenAiService { * @param timeout http read timeout, Duration.ZERO means no timeout */ public OpenAiService(final String token, final Duration timeout) { - this(buildApi(token, timeout)); + ObjectMapper mapper = defaultObjectMapper(); + OkHttpClient client = defaultClient(token, timeout); + Retrofit retrofit = defaultRetrofit(client, mapper); + + this.api = retrofit.create(OpenAiApi.class); + this.executorService = client.dispatcher().executorService(); } /** * Creates a new OpenAiService that wraps OpenAiApi. - * Use this if you need more customization. + * Use this if you need more customization, but use OpenAiService(api, executorService) if you use streaming and + * want to shut down instantly * * @param api OpenAiApi instance to use for all methods */ public OpenAiService(final OpenAiApi api) { this.api = api; + this.executorService = null; + } + + /** + * Creates a new OpenAiService that wraps OpenAiApi. + * The ExecutorService must be the one you get from the client you created the api with + * otherwise shutdownExecutor() won't work. + *

+ * Use this if you need more customization. + * + * @param api OpenAiApi instance to use for all methods + * @param executorService the ExecutorService from client.dispatcher().executorService() + */ + public OpenAiService(final OpenAiApi api, final ExecutorService executorService) { + this.api = api; + this.executorService = executorService; } public List listModels() { @@ -93,10 +122,22 @@ public class OpenAiService { return execute(api.createCompletion(request)); } + public Flowable streamCompletion(CompletionRequest request) { + request.setStream(true); + + return stream(api.createCompletionStream(request), CompletionChunk.class); + } + public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { return execute(api.createChatCompletion(request)); } + public Flowable streamChatCompletion(ChatCompletionRequest request) { + request.setStream(true); + + return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class); + } + public EditResult createEdit(EditRequest request) { return execute(api.createEdit(request)); } @@ -227,7 +268,7 @@ public class OpenAiService { } String errorBody = e.response().errorBody().string(); - OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class); + OpenAiError error = mapper.readValue(errorBody, OpenAiError.class); throw new OpenAiHttpException(error, e, e.code()); } catch (IOException ex) { // couldn't parse OpenAI error @@ -236,8 +277,49 @@ public class OpenAiService { } } + /** + * Calls the Open AI api and returns a Flowable of SSE for streaming + * omitting the last message. + * + * @param apiCall The api call + */ + public static Flowable stream(Call apiCall) { + return stream(apiCall, false); + } + + /** + * Calls the Open AI api and returns a Flowable of SSE for streaming. + * + * @param apiCall The api call + * @param emitDone If true the last message ([DONE]) is emitted + */ + public static Flowable stream(Call apiCall, boolean emitDone) { + return Flowable.create(emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)), BackpressureStrategy.BUFFER); + } + + /** + * Calls the Open AI api and returns a Flowable of type T for streaming + * omitting the last message. + * + * @param apiCall The api call + * @param cl Class of type T to return + */ + public static Flowable stream(Call apiCall, Class cl) { + return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl)); + } + + /** + * Shuts down the OkHttp ExecutorService. + * The default behaviour of OkHttp's ExecutorService (ConnectionPool) + * is to shut down after an idle timeout of 60s. + * Call this method to shut down the ExecutorService immediately. + */ + public void shutdownExecutor() { + Objects.requireNonNull(this.executorService, "executorService must be set in order to shut down"); + this.executorService.shutdown(); + } + public static OpenAiApi buildApi(String token, Duration timeout) { - Objects.requireNonNull(token, "OpenAI token required"); ObjectMapper mapper = defaultObjectMapper(); OkHttpClient client = defaultClient(token, timeout); Retrofit retrofit = defaultRetrofit(client, mapper); diff --git a/src/main/java/com/rymcu/forest/openai/service/ResponseBodyCallback.java b/src/main/java/com/rymcu/forest/openai/service/ResponseBodyCallback.java new file mode 100644 index 0000000..b1cbbe9 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/ResponseBodyCallback.java @@ -0,0 +1,99 @@ +package com.rymcu.forest.openai.service; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.theokanning.openai.OpenAiError; +import com.theokanning.openai.OpenAiHttpException; + +import io.reactivex.FlowableEmitter; + +import okhttp3.ResponseBody; +import retrofit2.Call; +import retrofit2.Callback; +import retrofit2.HttpException; +import retrofit2.Response; + +/** + * Callback to parse Server Sent Events (SSE) from raw InputStream and + * emit the events with io.reactivex.FlowableEmitter to allow streaming of + * SSE. + */ +public class ResponseBodyCallback implements Callback { + private static final ObjectMapper mapper = OpenAiService.defaultObjectMapper(); + + private FlowableEmitter emitter; + private boolean emitDone; + + public ResponseBodyCallback(FlowableEmitter emitter, boolean emitDone) { + this.emitter = emitter; + this.emitDone = emitDone; + } + + @Override + public void onResponse(Call call, Response response) { + BufferedReader reader = null; + + try { + if (!response.isSuccessful()) { + HttpException e = new HttpException(response); + ResponseBody errorBody = response.errorBody(); + + if (errorBody == null) { + throw e; + } else { + OpenAiError error = mapper.readValue( + errorBody.string(), + OpenAiError.class + ); + throw new OpenAiHttpException(error, e, e.code()); + } + } + + InputStream in = response.body().byteStream(); + reader = new BufferedReader(new InputStreamReader(in)); + String line; + SSE sse = null; + + while ((line = reader.readLine()) != null) { + if (line.startsWith("data:")) { + String data = line.substring(5).trim(); + sse = new SSE(data); + } else if (line.equals("") && sse != null) { + if (sse.isDone()) { + if (emitDone) { + emitter.onNext(sse); + } + break; + } + + emitter.onNext(sse); + sse = null; + } else { + throw new SSEFormatException("Invalid sse format! " + line); + } + } + + emitter.onComplete(); + + } catch (Throwable t) { + onFailure(call, t); + } finally { + if (reader != null) { + try { + reader.close(); + } catch (IOException e) { + // do nothing + } + } + } + } + + @Override + public void onFailure(Call call, Throwable t) { + emitter.onError(t); + } +} diff --git a/src/main/java/com/rymcu/forest/openai/service/SSE.java b/src/main/java/com/rymcu/forest/openai/service/SSE.java new file mode 100644 index 0000000..21da02a --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/SSE.java @@ -0,0 +1,26 @@ +package com.rymcu.forest.openai.service; + +/** + * Simple Server Sent Event representation + */ +public class SSE { + private static final String DONE_DATA = "[DONE]"; + + private final String data; + + public SSE(String data){ + this.data = data; + } + + public String getData(){ + return this.data; + } + + public byte[] toBytes(){ + return String.format("data: %s\n\n", this.data).getBytes(); + } + + public boolean isDone(){ + return DONE_DATA.equalsIgnoreCase(this.data); + } +} diff --git a/src/main/java/com/rymcu/forest/openai/service/SSEFormatException.java b/src/main/java/com/rymcu/forest/openai/service/SSEFormatException.java new file mode 100644 index 0000000..55a468e --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/SSEFormatException.java @@ -0,0 +1,10 @@ +package com.rymcu.forest.openai.service; + +/** + * Exception indicating a SSE format error + */ +public class SSEFormatException extends Throwable{ + public SSEFormatException(String msg){ + super(msg); + } +} diff --git a/src/main/java/com/rymcu/forest/openai/service/SseService.java b/src/main/java/com/rymcu/forest/openai/service/SseService.java new file mode 100644 index 0000000..d8b4372 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/SseService.java @@ -0,0 +1,19 @@ +package com.rymcu.forest.openai.service; + +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +/** + * Created on 2023/5/26 11:24. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.openai.service + */ + +public interface SseService { + SseEmitter connect(Long idUser); + + boolean send(Long idUser, String content); + + void close(Long idUser); +} diff --git a/src/main/java/com/rymcu/forest/openai/service/impl/SseServiceImpl.java b/src/main/java/com/rymcu/forest/openai/service/impl/SseServiceImpl.java new file mode 100644 index 0000000..a97cd66 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/impl/SseServiceImpl.java @@ -0,0 +1,92 @@ +package com.rymcu.forest.openai.service.impl; + +import com.rymcu.forest.openai.service.SseService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Created on 2023/5/26 11:38. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.openai.service.impl + */ +@Slf4j +@Service +public class SseServiceImpl implements SseService { + + private static final Map sessionMap = new ConcurrentHashMap<>(); + + @Override + public SseEmitter connect(Long idUser) { + if (existsUser(idUser)) { + removeUser(idUser); + } + SseEmitter sseEmitter = new SseEmitter(0L); + sseEmitter.onError((err) -> { + log.error("type: SseSession Error, msg: {} session Id : {}", err.getMessage(), idUser); + onError(idUser, err); + }); + + sseEmitter.onTimeout(() -> { + log.info("type: SseSession Timeout, session Id : {}", idUser); + removeUser(idUser); + }); + + sseEmitter.onCompletion(() -> { + log.info("type: SseSession Completion, session Id : {}", idUser); + removeUser(idUser); + }); + addUser(idUser, sseEmitter); + return sseEmitter; + } + + @Override + public boolean send(Long idUser, String content) { + if (existsUser(idUser)) { + try { + sendMessage(idUser, content); + return true; + } catch (IOException exception) { + log.error("type: SseSession send Error:IOException, msg: {} session Id : {}", exception.getMessage(), idUser); + } + } else { + throw new IllegalArgumentException("User Id " + idUser + " not Found"); + } + return false; + } + + @Override + public void close(Long idUser) { + log.info("type: SseSession Close, session Id : {}", idUser); + removeUser(idUser); + } + + private void addUser(Long idUser, SseEmitter sseEmitter) { + sessionMap.put(idUser, sseEmitter); + } + + private void onError(Long sessionKey, Throwable throwable) { + SseEmitter sseEmitter = sessionMap.get(sessionKey); + if (sseEmitter != null) { + sseEmitter.completeWithError(throwable); + } + } + + private void removeUser(Long idUser) { + sessionMap.remove(idUser); + } + + private boolean existsUser(Long idUser) { + return sessionMap.containsKey(idUser); + } + + private void sendMessage(Long idUser, String content) throws IOException { + sessionMap.get(idUser).send(content); + } +}