choices = service.createChatCompletion(completionRequest).getChoices();
- return GlobalResultGenerator.genSuccessResult(choices);
+ service.streamChatCompletion(completionRequest).doOnError(Throwable::printStackTrace)
+ .blockingForEach(chunk -> {
+ String text = chunk.getChoices().get(0).getMessage().getContent();
+ if (text == null) {
+ return;
+ }
+ System.out.print(text);
+ sseService.send(user.getIdUser(), text);
+ });
+ service.shutdownExecutor();
+ return GlobalResultGenerator.genSuccessResult();
}
-
}
diff --git a/src/main/java/com/rymcu/forest/openai/SeeController.java b/src/main/java/com/rymcu/forest/openai/SeeController.java
new file mode 100644
index 0000000..948ed20
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/SeeController.java
@@ -0,0 +1,39 @@
+package com.rymcu.forest.openai;
+
+import com.rymcu.forest.entity.User;
+import com.rymcu.forest.openai.service.SseService;
+import com.rymcu.forest.util.UserUtils;
+import org.springframework.http.MediaType;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PathVariable;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import javax.annotation.Resource;
+
+/**
+ * Created on 2023/5/26 11:26.
+ *
+ * @author ronger
+ * @email ronger-x@outlook.com
+ * @desc : com.rymcu.forest.openai
+ */
+@RestController
+@RequestMapping("/api/v1/sse")
+public class SeeController {
+
+ @Resource
+ private SseService sseService;
+
+ @GetMapping(value = "/subscribe/{idUser}", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
+ public SseEmitter subscribe(@PathVariable Long idUser) {
+ return sseService.connect(idUser);
+ }
+
+ @GetMapping(value = "/close/{idUser}")
+ public void close(@PathVariable Long idUser) {
+ sseService.close(idUser);
+ }
+
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
index 2781b74..ad468b8 100644
--- a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
+++ b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
@@ -9,8 +9,10 @@ import com.theokanning.openai.DeleteResult;
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException;
+import com.theokanning.openai.completion.CompletionChunk;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
+import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.edit.EditRequest;
@@ -28,11 +30,15 @@ import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;
+
+import io.reactivex.BackpressureStrategy;
+import io.reactivex.Flowable;
import io.reactivex.Single;
import okhttp3.*;
import org.springframework.core.env.Environment;
import retrofit2.HttpException;
import retrofit2.Retrofit;
+import retrofit2.Call;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
@@ -40,18 +46,19 @@ import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
public class OpenAiService {
private static final String BASE_URL = "https://api.openai.com/";
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
- private static final ObjectMapper errorMapper = defaultObjectMapper();
+ private static final ObjectMapper mapper = defaultObjectMapper();
private final OpenAiApi api;
+ private final ExecutorService executorService;
private static final Environment env = SpringContextHolder.getBean(Environment.class);
-
/**
* Creates a new OpenAiService that wraps OpenAiApi
*
@@ -68,17 +75,39 @@ public class OpenAiService {
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
- this(buildApi(token, timeout));
+ ObjectMapper mapper = defaultObjectMapper();
+ OkHttpClient client = defaultClient(token, timeout);
+ Retrofit retrofit = defaultRetrofit(client, mapper);
+
+ this.api = retrofit.create(OpenAiApi.class);
+ this.executorService = client.dispatcher().executorService();
}
/**
* Creates a new OpenAiService that wraps OpenAiApi.
- * Use this if you need more customization.
+ * Use this if you need more customization, but use OpenAiService(api, executorService) if you use streaming and
+ * want to shut down instantly
*
* @param api OpenAiApi instance to use for all methods
*/
public OpenAiService(final OpenAiApi api) {
this.api = api;
+ this.executorService = null;
+ }
+
+ /**
+ * Creates a new OpenAiService that wraps OpenAiApi.
+ * The ExecutorService must be the one you get from the client you created the api with
+ * otherwise shutdownExecutor() won't work.
+ *
+ * Use this if you need more customization.
+ *
+ * @param api OpenAiApi instance to use for all methods
+ * @param executorService the ExecutorService from client.dispatcher().executorService()
+ */
+ public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
+ this.api = api;
+ this.executorService = executorService;
}
public List listModels() {
@@ -93,10 +122,22 @@ public class OpenAiService {
return execute(api.createCompletion(request));
}
+ public Flowable streamCompletion(CompletionRequest request) {
+ request.setStream(true);
+
+ return stream(api.createCompletionStream(request), CompletionChunk.class);
+ }
+
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request));
}
+ public Flowable streamChatCompletion(ChatCompletionRequest request) {
+ request.setStream(true);
+
+ return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
+ }
+
public EditResult createEdit(EditRequest request) {
return execute(api.createEdit(request));
}
@@ -227,7 +268,7 @@ public class OpenAiService {
}
String errorBody = e.response().errorBody().string();
- OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
+ OpenAiError error = mapper.readValue(errorBody, OpenAiError.class);
throw new OpenAiHttpException(error, e, e.code());
} catch (IOException ex) {
// couldn't parse OpenAI error
@@ -236,8 +277,49 @@ public class OpenAiService {
}
}
+ /**
+ * Calls the Open AI api and returns a Flowable of SSE for streaming
+ * omitting the last message.
+ *
+ * @param apiCall The api call
+ */
+ public static Flowable stream(Call apiCall) {
+ return stream(apiCall, false);
+ }
+
+ /**
+ * Calls the Open AI api and returns a Flowable of SSE for streaming.
+ *
+ * @param apiCall The api call
+ * @param emitDone If true the last message ([DONE]) is emitted
+ */
+ public static Flowable stream(Call apiCall, boolean emitDone) {
+ return Flowable.create(emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)), BackpressureStrategy.BUFFER);
+ }
+
+ /**
+ * Calls the Open AI api and returns a Flowable of type T for streaming
+ * omitting the last message.
+ *
+ * @param apiCall The api call
+ * @param cl Class of type T to return
+ */
+ public static Flowable stream(Call apiCall, Class cl) {
+ return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
+ }
+
+ /**
+ * Shuts down the OkHttp ExecutorService.
+ * The default behaviour of OkHttp's ExecutorService (ConnectionPool)
+ * is to shut down after an idle timeout of 60s.
+ * Call this method to shut down the ExecutorService immediately.
+ */
+ public void shutdownExecutor() {
+ Objects.requireNonNull(this.executorService, "executorService must be set in order to shut down");
+ this.executorService.shutdown();
+ }
+
public static OpenAiApi buildApi(String token, Duration timeout) {
- Objects.requireNonNull(token, "OpenAI token required");
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);
diff --git a/src/main/java/com/rymcu/forest/openai/service/ResponseBodyCallback.java b/src/main/java/com/rymcu/forest/openai/service/ResponseBodyCallback.java
new file mode 100644
index 0000000..b1cbbe9
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/ResponseBodyCallback.java
@@ -0,0 +1,99 @@
+package com.rymcu.forest.openai.service;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.theokanning.openai.OpenAiError;
+import com.theokanning.openai.OpenAiHttpException;
+
+import io.reactivex.FlowableEmitter;
+
+import okhttp3.ResponseBody;
+import retrofit2.Call;
+import retrofit2.Callback;
+import retrofit2.HttpException;
+import retrofit2.Response;
+
+/**
+ * Callback to parse Server Sent Events (SSE) from raw InputStream and
+ * emit the events with io.reactivex.FlowableEmitter to allow streaming of
+ * SSE.
+ */
+public class ResponseBodyCallback implements Callback {
+ private static final ObjectMapper mapper = OpenAiService.defaultObjectMapper();
+
+ private FlowableEmitter emitter;
+ private boolean emitDone;
+
+ public ResponseBodyCallback(FlowableEmitter emitter, boolean emitDone) {
+ this.emitter = emitter;
+ this.emitDone = emitDone;
+ }
+
+ @Override
+ public void onResponse(Call call, Response response) {
+ BufferedReader reader = null;
+
+ try {
+ if (!response.isSuccessful()) {
+ HttpException e = new HttpException(response);
+ ResponseBody errorBody = response.errorBody();
+
+ if (errorBody == null) {
+ throw e;
+ } else {
+ OpenAiError error = mapper.readValue(
+ errorBody.string(),
+ OpenAiError.class
+ );
+ throw new OpenAiHttpException(error, e, e.code());
+ }
+ }
+
+ InputStream in = response.body().byteStream();
+ reader = new BufferedReader(new InputStreamReader(in));
+ String line;
+ SSE sse = null;
+
+ while ((line = reader.readLine()) != null) {
+ if (line.startsWith("data:")) {
+ String data = line.substring(5).trim();
+ sse = new SSE(data);
+ } else if (line.equals("") && sse != null) {
+ if (sse.isDone()) {
+ if (emitDone) {
+ emitter.onNext(sse);
+ }
+ break;
+ }
+
+ emitter.onNext(sse);
+ sse = null;
+ } else {
+ throw new SSEFormatException("Invalid sse format! " + line);
+ }
+ }
+
+ emitter.onComplete();
+
+ } catch (Throwable t) {
+ onFailure(call, t);
+ } finally {
+ if (reader != null) {
+ try {
+ reader.close();
+ } catch (IOException e) {
+ // do nothing
+ }
+ }
+ }
+ }
+
+ @Override
+ public void onFailure(Call call, Throwable t) {
+ emitter.onError(t);
+ }
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/SSE.java b/src/main/java/com/rymcu/forest/openai/service/SSE.java
new file mode 100644
index 0000000..21da02a
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/SSE.java
@@ -0,0 +1,26 @@
+package com.rymcu.forest.openai.service;
+
+/**
+ * Simple Server Sent Event representation
+ */
+public class SSE {
+ private static final String DONE_DATA = "[DONE]";
+
+ private final String data;
+
+ public SSE(String data){
+ this.data = data;
+ }
+
+ public String getData(){
+ return this.data;
+ }
+
+ public byte[] toBytes(){
+ return String.format("data: %s\n\n", this.data).getBytes();
+ }
+
+ public boolean isDone(){
+ return DONE_DATA.equalsIgnoreCase(this.data);
+ }
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/SSEFormatException.java b/src/main/java/com/rymcu/forest/openai/service/SSEFormatException.java
new file mode 100644
index 0000000..55a468e
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/SSEFormatException.java
@@ -0,0 +1,10 @@
+package com.rymcu.forest.openai.service;
+
+/**
+ * Exception indicating a SSE format error
+ */
+public class SSEFormatException extends Throwable{
+ public SSEFormatException(String msg){
+ super(msg);
+ }
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/SseService.java b/src/main/java/com/rymcu/forest/openai/service/SseService.java
new file mode 100644
index 0000000..d8b4372
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/SseService.java
@@ -0,0 +1,19 @@
+package com.rymcu.forest.openai.service;
+
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+/**
+ * Created on 2023/5/26 11:24.
+ *
+ * @author ronger
+ * @email ronger-x@outlook.com
+ * @desc : com.rymcu.forest.openai.service
+ */
+
+public interface SseService {
+ SseEmitter connect(Long idUser);
+
+ boolean send(Long idUser, String content);
+
+ void close(Long idUser);
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/impl/SseServiceImpl.java b/src/main/java/com/rymcu/forest/openai/service/impl/SseServiceImpl.java
new file mode 100644
index 0000000..a97cd66
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/impl/SseServiceImpl.java
@@ -0,0 +1,92 @@
+package com.rymcu.forest.openai.service.impl;
+
+import com.rymcu.forest.openai.service.SseService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Service;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Created on 2023/5/26 11:38.
+ *
+ * @author ronger
+ * @email ronger-x@outlook.com
+ * @desc : com.rymcu.forest.openai.service.impl
+ */
+@Slf4j
+@Service
+public class SseServiceImpl implements SseService {
+
+ private static final Map sessionMap = new ConcurrentHashMap<>();
+
+ @Override
+ public SseEmitter connect(Long idUser) {
+ if (existsUser(idUser)) {
+ removeUser(idUser);
+ }
+ SseEmitter sseEmitter = new SseEmitter(0L);
+ sseEmitter.onError((err) -> {
+ log.error("type: SseSession Error, msg: {} session Id : {}", err.getMessage(), idUser);
+ onError(idUser, err);
+ });
+
+ sseEmitter.onTimeout(() -> {
+ log.info("type: SseSession Timeout, session Id : {}", idUser);
+ removeUser(idUser);
+ });
+
+ sseEmitter.onCompletion(() -> {
+ log.info("type: SseSession Completion, session Id : {}", idUser);
+ removeUser(idUser);
+ });
+ addUser(idUser, sseEmitter);
+ return sseEmitter;
+ }
+
+ @Override
+ public boolean send(Long idUser, String content) {
+ if (existsUser(idUser)) {
+ try {
+ sendMessage(idUser, content);
+ return true;
+ } catch (IOException exception) {
+ log.error("type: SseSession send Error:IOException, msg: {} session Id : {}", exception.getMessage(), idUser);
+ }
+ } else {
+ throw new IllegalArgumentException("User Id " + idUser + " not Found");
+ }
+ return false;
+ }
+
+ @Override
+ public void close(Long idUser) {
+ log.info("type: SseSession Close, session Id : {}", idUser);
+ removeUser(idUser);
+ }
+
+ private void addUser(Long idUser, SseEmitter sseEmitter) {
+ sessionMap.put(idUser, sseEmitter);
+ }
+
+ private void onError(Long sessionKey, Throwable throwable) {
+ SseEmitter sseEmitter = sessionMap.get(sessionKey);
+ if (sseEmitter != null) {
+ sseEmitter.completeWithError(throwable);
+ }
+ }
+
+ private void removeUser(Long idUser) {
+ sessionMap.remove(idUser);
+ }
+
+ private boolean existsUser(Long idUser) {
+ return sessionMap.containsKey(idUser);
+ }
+
+ private void sendMessage(Long idUser, String content) throws IOException {
+ sessionMap.get(idUser).send(content);
+ }
+}