🎨 实现 chatGPT 流式接口
This commit is contained in:
parent
02c809f2d2
commit
e7f9ddc35d
14
pom.xml
14
pom.xml
@ -129,7 +129,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>net.minidev</groupId>
|
<groupId>net.minidev</groupId>
|
||||||
<artifactId>json-smart</artifactId>
|
<artifactId>json-smart</artifactId>
|
||||||
<version>2.4.8</version>
|
<version>2.4.9</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
@ -328,7 +328,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||||
<artifactId>client</artifactId>
|
<artifactId>client</artifactId>
|
||||||
<version>0.11.0</version>
|
<version>0.12.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.squareup.retrofit2</groupId>
|
<groupId>com.squareup.retrofit2</groupId>
|
||||||
@ -355,6 +355,12 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-actuator</artifactId>
|
<artifactId>spring-boot-starter-actuator</artifactId>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-databind</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>io.micrometer</groupId>
|
<groupId>io.micrometer</groupId>
|
||||||
@ -429,7 +435,7 @@
|
|||||||
<profileActive>dev</profileActive>
|
<profileActive>dev</profileActive>
|
||||||
</properties>
|
</properties>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>true</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
@ -438,7 +444,7 @@
|
|||||||
<profileActive>prod</profileActive>
|
<profileActive>prod</profileActive>
|
||||||
</properties>
|
</properties>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>true</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
@ -45,6 +45,7 @@ public class ShiroConfig {
|
|||||||
filterChainDefinitionMap.put("/api/v1/auth/login/**", "anon");
|
filterChainDefinitionMap.put("/api/v1/auth/login/**", "anon");
|
||||||
filterChainDefinitionMap.put("/api/v1/auth/logout/**", "anon");
|
filterChainDefinitionMap.put("/api/v1/auth/logout/**", "anon");
|
||||||
filterChainDefinitionMap.put("/api/v1/auth/refresh-token/**", "anon");
|
filterChainDefinitionMap.put("/api/v1/auth/refresh-token/**", "anon");
|
||||||
|
filterChainDefinitionMap.put("/api/v1/sse/**", "anon");
|
||||||
filterChainDefinitionMap.put("/**", "jwt");
|
filterChainDefinitionMap.put("/**", "jwt");
|
||||||
shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);
|
shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);
|
||||||
|
|
||||||
|
@ -3,10 +3,15 @@ package com.rymcu.forest.openai;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.rymcu.forest.core.result.GlobalResult;
|
import com.rymcu.forest.core.result.GlobalResult;
|
||||||
import com.rymcu.forest.core.result.GlobalResultGenerator;
|
import com.rymcu.forest.core.result.GlobalResultGenerator;
|
||||||
|
import com.rymcu.forest.entity.User;
|
||||||
import com.rymcu.forest.openai.service.OpenAiService;
|
import com.rymcu.forest.openai.service.OpenAiService;
|
||||||
|
import com.rymcu.forest.openai.service.SseService;
|
||||||
|
import com.rymcu.forest.util.UserUtils;
|
||||||
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
|
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
|
||||||
|
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
|
||||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
|
import io.reactivex.Flowable;
|
||||||
import org.apache.commons.lang.StringUtils;
|
import org.apache.commons.lang.StringUtils;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
@ -14,6 +19,7 @@ import org.springframework.web.bind.annotation.RequestBody;
|
|||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
|
import javax.annotation.Resource;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@ -28,26 +34,38 @@ import java.util.List;
|
|||||||
@RestController
|
@RestController
|
||||||
@RequestMapping("/api/v1/openai")
|
@RequestMapping("/api/v1/openai")
|
||||||
public class OpenAiController {
|
public class OpenAiController {
|
||||||
|
@Resource
|
||||||
|
private SseService sseService;
|
||||||
|
|
||||||
@Value("${openai.token}")
|
@Value("${openai.token}")
|
||||||
private String token;
|
private String token;
|
||||||
|
|
||||||
@PostMapping("/chat")
|
@PostMapping("/chat")
|
||||||
public GlobalResult<List<ChatCompletionChoice>> chat(@RequestBody JSONObject jsonObject) {
|
public GlobalResult chat(@RequestBody JSONObject jsonObject) {
|
||||||
String message = jsonObject.getString("message");
|
String message = jsonObject.getString("message");
|
||||||
if (StringUtils.isBlank(message)) {
|
if (StringUtils.isBlank(message)) {
|
||||||
throw new IllegalArgumentException("参数异常!");
|
throw new IllegalArgumentException("参数异常!");
|
||||||
}
|
}
|
||||||
|
User user = UserUtils.getCurrentUserByToken();
|
||||||
ChatMessage chatMessage = new ChatMessage("user", message);
|
ChatMessage chatMessage = new ChatMessage("user", message);
|
||||||
List<ChatMessage> list = new ArrayList();
|
List<ChatMessage> list = new ArrayList<>(4);
|
||||||
list.add(chatMessage);
|
list.add(chatMessage);
|
||||||
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
|
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
|
||||||
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
|
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
|
||||||
.model("gpt-3.5-turbo")
|
.model("gpt-3.5-turbo")
|
||||||
|
.stream(true)
|
||||||
.messages(list)
|
.messages(list)
|
||||||
.build();
|
.build();
|
||||||
List<ChatCompletionChoice> choices = service.createChatCompletion(completionRequest).getChoices();
|
service.streamChatCompletion(completionRequest).doOnError(Throwable::printStackTrace)
|
||||||
return GlobalResultGenerator.genSuccessResult(choices);
|
.blockingForEach(chunk -> {
|
||||||
|
String text = chunk.getChoices().get(0).getMessage().getContent();
|
||||||
|
if (text == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
System.out.print(text);
|
||||||
|
sseService.send(user.getIdUser(), text);
|
||||||
|
});
|
||||||
|
service.shutdownExecutor();
|
||||||
|
return GlobalResultGenerator.genSuccessResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
39
src/main/java/com/rymcu/forest/openai/SeeController.java
Normal file
39
src/main/java/com/rymcu/forest/openai/SeeController.java
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package com.rymcu.forest.openai;
|
||||||
|
|
||||||
|
import com.rymcu.forest.entity.User;
|
||||||
|
import com.rymcu.forest.openai.service.SseService;
|
||||||
|
import com.rymcu.forest.util.UserUtils;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
import org.springframework.web.bind.annotation.PathVariable;
|
||||||
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
import javax.annotation.Resource;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Created on 2023/5/26 11:26.
|
||||||
|
*
|
||||||
|
* @author ronger
|
||||||
|
* @email ronger-x@outlook.com
|
||||||
|
* @desc : com.rymcu.forest.openai
|
||||||
|
*/
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/api/v1/sse")
|
||||||
|
public class SeeController {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private SseService sseService;
|
||||||
|
|
||||||
|
@GetMapping(value = "/subscribe/{idUser}", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
|
||||||
|
public SseEmitter subscribe(@PathVariable Long idUser) {
|
||||||
|
return sseService.connect(idUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping(value = "/close/{idUser}")
|
||||||
|
public void close(@PathVariable Long idUser) {
|
||||||
|
sseService.close(idUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -9,8 +9,10 @@ import com.theokanning.openai.DeleteResult;
|
|||||||
import com.theokanning.openai.OpenAiApi;
|
import com.theokanning.openai.OpenAiApi;
|
||||||
import com.theokanning.openai.OpenAiError;
|
import com.theokanning.openai.OpenAiError;
|
||||||
import com.theokanning.openai.OpenAiHttpException;
|
import com.theokanning.openai.OpenAiHttpException;
|
||||||
|
import com.theokanning.openai.completion.CompletionChunk;
|
||||||
import com.theokanning.openai.completion.CompletionRequest;
|
import com.theokanning.openai.completion.CompletionRequest;
|
||||||
import com.theokanning.openai.completion.CompletionResult;
|
import com.theokanning.openai.completion.CompletionResult;
|
||||||
|
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
|
||||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||||
import com.theokanning.openai.edit.EditRequest;
|
import com.theokanning.openai.edit.EditRequest;
|
||||||
@ -28,11 +30,15 @@ import com.theokanning.openai.image.ImageResult;
|
|||||||
import com.theokanning.openai.model.Model;
|
import com.theokanning.openai.model.Model;
|
||||||
import com.theokanning.openai.moderation.ModerationRequest;
|
import com.theokanning.openai.moderation.ModerationRequest;
|
||||||
import com.theokanning.openai.moderation.ModerationResult;
|
import com.theokanning.openai.moderation.ModerationResult;
|
||||||
|
|
||||||
|
import io.reactivex.BackpressureStrategy;
|
||||||
|
import io.reactivex.Flowable;
|
||||||
import io.reactivex.Single;
|
import io.reactivex.Single;
|
||||||
import okhttp3.*;
|
import okhttp3.*;
|
||||||
import org.springframework.core.env.Environment;
|
import org.springframework.core.env.Environment;
|
||||||
import retrofit2.HttpException;
|
import retrofit2.HttpException;
|
||||||
import retrofit2.Retrofit;
|
import retrofit2.Retrofit;
|
||||||
|
import retrofit2.Call;
|
||||||
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
|
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
|
||||||
import retrofit2.converter.jackson.JacksonConverterFactory;
|
import retrofit2.converter.jackson.JacksonConverterFactory;
|
||||||
|
|
||||||
@ -40,18 +46,19 @@ import java.io.IOException;
|
|||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
public class OpenAiService {
|
public class OpenAiService {
|
||||||
|
|
||||||
private static final String BASE_URL = "https://api.openai.com/";
|
private static final String BASE_URL = "https://api.openai.com/";
|
||||||
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
|
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
|
||||||
private static final ObjectMapper errorMapper = defaultObjectMapper();
|
private static final ObjectMapper mapper = defaultObjectMapper();
|
||||||
|
|
||||||
private final OpenAiApi api;
|
private final OpenAiApi api;
|
||||||
|
private final ExecutorService executorService;
|
||||||
private static final Environment env = SpringContextHolder.getBean(Environment.class);
|
private static final Environment env = SpringContextHolder.getBean(Environment.class);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new OpenAiService that wraps OpenAiApi
|
* Creates a new OpenAiService that wraps OpenAiApi
|
||||||
*
|
*
|
||||||
@ -68,17 +75,39 @@ public class OpenAiService {
|
|||||||
* @param timeout http read timeout, Duration.ZERO means no timeout
|
* @param timeout http read timeout, Duration.ZERO means no timeout
|
||||||
*/
|
*/
|
||||||
public OpenAiService(final String token, final Duration timeout) {
|
public OpenAiService(final String token, final Duration timeout) {
|
||||||
this(buildApi(token, timeout));
|
ObjectMapper mapper = defaultObjectMapper();
|
||||||
|
OkHttpClient client = defaultClient(token, timeout);
|
||||||
|
Retrofit retrofit = defaultRetrofit(client, mapper);
|
||||||
|
|
||||||
|
this.api = retrofit.create(OpenAiApi.class);
|
||||||
|
this.executorService = client.dispatcher().executorService();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new OpenAiService that wraps OpenAiApi.
|
* Creates a new OpenAiService that wraps OpenAiApi.
|
||||||
* Use this if you need more customization.
|
* Use this if you need more customization, but use OpenAiService(api, executorService) if you use streaming and
|
||||||
|
* want to shut down instantly
|
||||||
*
|
*
|
||||||
* @param api OpenAiApi instance to use for all methods
|
* @param api OpenAiApi instance to use for all methods
|
||||||
*/
|
*/
|
||||||
public OpenAiService(final OpenAiApi api) {
|
public OpenAiService(final OpenAiApi api) {
|
||||||
this.api = api;
|
this.api = api;
|
||||||
|
this.executorService = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new OpenAiService that wraps OpenAiApi.
|
||||||
|
* The ExecutorService must be the one you get from the client you created the api with
|
||||||
|
* otherwise shutdownExecutor() won't work.
|
||||||
|
* <p>
|
||||||
|
* Use this if you need more customization.
|
||||||
|
*
|
||||||
|
* @param api OpenAiApi instance to use for all methods
|
||||||
|
* @param executorService the ExecutorService from client.dispatcher().executorService()
|
||||||
|
*/
|
||||||
|
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
|
||||||
|
this.api = api;
|
||||||
|
this.executorService = executorService;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Model> listModels() {
|
public List<Model> listModels() {
|
||||||
@ -93,10 +122,22 @@ public class OpenAiService {
|
|||||||
return execute(api.createCompletion(request));
|
return execute(api.createCompletion(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
|
||||||
|
request.setStream(true);
|
||||||
|
|
||||||
|
return stream(api.createCompletionStream(request), CompletionChunk.class);
|
||||||
|
}
|
||||||
|
|
||||||
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
|
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
|
||||||
return execute(api.createChatCompletion(request));
|
return execute(api.createChatCompletion(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
|
||||||
|
request.setStream(true);
|
||||||
|
|
||||||
|
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
|
||||||
|
}
|
||||||
|
|
||||||
public EditResult createEdit(EditRequest request) {
|
public EditResult createEdit(EditRequest request) {
|
||||||
return execute(api.createEdit(request));
|
return execute(api.createEdit(request));
|
||||||
}
|
}
|
||||||
@ -227,7 +268,7 @@ public class OpenAiService {
|
|||||||
}
|
}
|
||||||
String errorBody = e.response().errorBody().string();
|
String errorBody = e.response().errorBody().string();
|
||||||
|
|
||||||
OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
|
OpenAiError error = mapper.readValue(errorBody, OpenAiError.class);
|
||||||
throw new OpenAiHttpException(error, e, e.code());
|
throw new OpenAiHttpException(error, e, e.code());
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
// couldn't parse OpenAI error
|
// couldn't parse OpenAI error
|
||||||
@ -236,8 +277,49 @@ public class OpenAiService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls the Open AI api and returns a Flowable of SSE for streaming
|
||||||
|
* omitting the last message.
|
||||||
|
*
|
||||||
|
* @param apiCall The api call
|
||||||
|
*/
|
||||||
|
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
|
||||||
|
return stream(apiCall, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls the Open AI api and returns a Flowable of SSE for streaming.
|
||||||
|
*
|
||||||
|
* @param apiCall The api call
|
||||||
|
* @param emitDone If true the last message ([DONE]) is emitted
|
||||||
|
*/
|
||||||
|
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
|
||||||
|
return Flowable.create(emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)), BackpressureStrategy.BUFFER);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls the Open AI api and returns a Flowable of type T for streaming
|
||||||
|
* omitting the last message.
|
||||||
|
*
|
||||||
|
* @param apiCall The api call
|
||||||
|
* @param cl Class of type T to return
|
||||||
|
*/
|
||||||
|
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
|
||||||
|
return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shuts down the OkHttp ExecutorService.
|
||||||
|
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
|
||||||
|
* is to shut down after an idle timeout of 60s.
|
||||||
|
* Call this method to shut down the ExecutorService immediately.
|
||||||
|
*/
|
||||||
|
public void shutdownExecutor() {
|
||||||
|
Objects.requireNonNull(this.executorService, "executorService must be set in order to shut down");
|
||||||
|
this.executorService.shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
public static OpenAiApi buildApi(String token, Duration timeout) {
|
public static OpenAiApi buildApi(String token, Duration timeout) {
|
||||||
Objects.requireNonNull(token, "OpenAI token required");
|
|
||||||
ObjectMapper mapper = defaultObjectMapper();
|
ObjectMapper mapper = defaultObjectMapper();
|
||||||
OkHttpClient client = defaultClient(token, timeout);
|
OkHttpClient client = defaultClient(token, timeout);
|
||||||
Retrofit retrofit = defaultRetrofit(client, mapper);
|
Retrofit retrofit = defaultRetrofit(client, mapper);
|
||||||
|
@ -0,0 +1,99 @@
|
|||||||
|
package com.rymcu.forest.openai.service;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.theokanning.openai.OpenAiError;
|
||||||
|
import com.theokanning.openai.OpenAiHttpException;
|
||||||
|
|
||||||
|
import io.reactivex.FlowableEmitter;
|
||||||
|
|
||||||
|
import okhttp3.ResponseBody;
|
||||||
|
import retrofit2.Call;
|
||||||
|
import retrofit2.Callback;
|
||||||
|
import retrofit2.HttpException;
|
||||||
|
import retrofit2.Response;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Callback to parse Server Sent Events (SSE) from raw InputStream and
|
||||||
|
* emit the events with io.reactivex.FlowableEmitter to allow streaming of
|
||||||
|
* SSE.
|
||||||
|
*/
|
||||||
|
public class ResponseBodyCallback implements Callback<ResponseBody> {
|
||||||
|
private static final ObjectMapper mapper = OpenAiService.defaultObjectMapper();
|
||||||
|
|
||||||
|
private FlowableEmitter<SSE> emitter;
|
||||||
|
private boolean emitDone;
|
||||||
|
|
||||||
|
public ResponseBodyCallback(FlowableEmitter<SSE> emitter, boolean emitDone) {
|
||||||
|
this.emitter = emitter;
|
||||||
|
this.emitDone = emitDone;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onResponse(Call<ResponseBody> call, Response<ResponseBody> response) {
|
||||||
|
BufferedReader reader = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!response.isSuccessful()) {
|
||||||
|
HttpException e = new HttpException(response);
|
||||||
|
ResponseBody errorBody = response.errorBody();
|
||||||
|
|
||||||
|
if (errorBody == null) {
|
||||||
|
throw e;
|
||||||
|
} else {
|
||||||
|
OpenAiError error = mapper.readValue(
|
||||||
|
errorBody.string(),
|
||||||
|
OpenAiError.class
|
||||||
|
);
|
||||||
|
throw new OpenAiHttpException(error, e, e.code());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InputStream in = response.body().byteStream();
|
||||||
|
reader = new BufferedReader(new InputStreamReader(in));
|
||||||
|
String line;
|
||||||
|
SSE sse = null;
|
||||||
|
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (line.startsWith("data:")) {
|
||||||
|
String data = line.substring(5).trim();
|
||||||
|
sse = new SSE(data);
|
||||||
|
} else if (line.equals("") && sse != null) {
|
||||||
|
if (sse.isDone()) {
|
||||||
|
if (emitDone) {
|
||||||
|
emitter.onNext(sse);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
emitter.onNext(sse);
|
||||||
|
sse = null;
|
||||||
|
} else {
|
||||||
|
throw new SSEFormatException("Invalid sse format! " + line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
emitter.onComplete();
|
||||||
|
|
||||||
|
} catch (Throwable t) {
|
||||||
|
onFailure(call, t);
|
||||||
|
} finally {
|
||||||
|
if (reader != null) {
|
||||||
|
try {
|
||||||
|
reader.close();
|
||||||
|
} catch (IOException e) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Call<ResponseBody> call, Throwable t) {
|
||||||
|
emitter.onError(t);
|
||||||
|
}
|
||||||
|
}
|
26
src/main/java/com/rymcu/forest/openai/service/SSE.java
Normal file
26
src/main/java/com/rymcu/forest/openai/service/SSE.java
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package com.rymcu.forest.openai.service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simple Server Sent Event representation
|
||||||
|
*/
|
||||||
|
public class SSE {
|
||||||
|
private static final String DONE_DATA = "[DONE]";
|
||||||
|
|
||||||
|
private final String data;
|
||||||
|
|
||||||
|
public SSE(String data){
|
||||||
|
this.data = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getData(){
|
||||||
|
return this.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public byte[] toBytes(){
|
||||||
|
return String.format("data: %s\n\n", this.data).getBytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isDone(){
|
||||||
|
return DONE_DATA.equalsIgnoreCase(this.data);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,10 @@
|
|||||||
|
package com.rymcu.forest.openai.service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Exception indicating a SSE format error
|
||||||
|
*/
|
||||||
|
public class SSEFormatException extends Throwable{
|
||||||
|
public SSEFormatException(String msg){
|
||||||
|
super(msg);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
package com.rymcu.forest.openai.service;
|
||||||
|
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Created on 2023/5/26 11:24.
|
||||||
|
*
|
||||||
|
* @author ronger
|
||||||
|
* @email ronger-x@outlook.com
|
||||||
|
* @desc : com.rymcu.forest.openai.service
|
||||||
|
*/
|
||||||
|
|
||||||
|
public interface SseService {
|
||||||
|
SseEmitter connect(Long idUser);
|
||||||
|
|
||||||
|
boolean send(Long idUser, String content);
|
||||||
|
|
||||||
|
void close(Long idUser);
|
||||||
|
}
|
@ -0,0 +1,92 @@
|
|||||||
|
package com.rymcu.forest.openai.service.impl;
|
||||||
|
|
||||||
|
import com.rymcu.forest.openai.service.SseService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Created on 2023/5/26 11:38.
|
||||||
|
*
|
||||||
|
* @author ronger
|
||||||
|
* @email ronger-x@outlook.com
|
||||||
|
* @desc : com.rymcu.forest.openai.service.impl
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class SseServiceImpl implements SseService {
|
||||||
|
|
||||||
|
private static final Map<Long, SseEmitter> sessionMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SseEmitter connect(Long idUser) {
|
||||||
|
if (existsUser(idUser)) {
|
||||||
|
removeUser(idUser);
|
||||||
|
}
|
||||||
|
SseEmitter sseEmitter = new SseEmitter(0L);
|
||||||
|
sseEmitter.onError((err) -> {
|
||||||
|
log.error("type: SseSession Error, msg: {} session Id : {}", err.getMessage(), idUser);
|
||||||
|
onError(idUser, err);
|
||||||
|
});
|
||||||
|
|
||||||
|
sseEmitter.onTimeout(() -> {
|
||||||
|
log.info("type: SseSession Timeout, session Id : {}", idUser);
|
||||||
|
removeUser(idUser);
|
||||||
|
});
|
||||||
|
|
||||||
|
sseEmitter.onCompletion(() -> {
|
||||||
|
log.info("type: SseSession Completion, session Id : {}", idUser);
|
||||||
|
removeUser(idUser);
|
||||||
|
});
|
||||||
|
addUser(idUser, sseEmitter);
|
||||||
|
return sseEmitter;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean send(Long idUser, String content) {
|
||||||
|
if (existsUser(idUser)) {
|
||||||
|
try {
|
||||||
|
sendMessage(idUser, content);
|
||||||
|
return true;
|
||||||
|
} catch (IOException exception) {
|
||||||
|
log.error("type: SseSession send Error:IOException, msg: {} session Id : {}", exception.getMessage(), idUser);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("User Id " + idUser + " not Found");
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close(Long idUser) {
|
||||||
|
log.info("type: SseSession Close, session Id : {}", idUser);
|
||||||
|
removeUser(idUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addUser(Long idUser, SseEmitter sseEmitter) {
|
||||||
|
sessionMap.put(idUser, sseEmitter);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void onError(Long sessionKey, Throwable throwable) {
|
||||||
|
SseEmitter sseEmitter = sessionMap.get(sessionKey);
|
||||||
|
if (sseEmitter != null) {
|
||||||
|
sseEmitter.completeWithError(throwable);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void removeUser(Long idUser) {
|
||||||
|
sessionMap.remove(idUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean existsUser(Long idUser) {
|
||||||
|
return sessionMap.containsKey(idUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendMessage(Long idUser, String content) throws IOException {
|
||||||
|
sessionMap.get(idUser).send(content);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user