🎨 实现 chatGPT 流式接口

This commit is contained in:
ronger 2023-05-26 15:37:32 +08:00
parent 02c809f2d2
commit e7f9ddc35d
10 changed files with 407 additions and 15 deletions

14
pom.xml
View File

@ -129,7 +129,7 @@
<dependency> <dependency>
<groupId>net.minidev</groupId> <groupId>net.minidev</groupId>
<artifactId>json-smart</artifactId> <artifactId>json-smart</artifactId>
<version>2.4.8</version> <version>2.4.9</version>
</dependency> </dependency>
<dependency> <dependency>
@ -328,7 +328,7 @@
<dependency> <dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId> <groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>client</artifactId> <artifactId>client</artifactId>
<version>0.11.0</version> <version>0.12.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.squareup.retrofit2</groupId> <groupId>com.squareup.retrofit2</groupId>
@ -355,6 +355,12 @@
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId> <artifactId>spring-boot-starter-actuator</artifactId>
<exclusions>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>io.micrometer</groupId> <groupId>io.micrometer</groupId>
@ -429,7 +435,7 @@
<profileActive>dev</profileActive> <profileActive>dev</profileActive>
</properties> </properties>
<activation> <activation>
<activeByDefault>false</activeByDefault> <activeByDefault>true</activeByDefault>
</activation> </activation>
</profile> </profile>
<profile> <profile>
@ -438,7 +444,7 @@
<profileActive>prod</profileActive> <profileActive>prod</profileActive>
</properties> </properties>
<activation> <activation>
<activeByDefault>true</activeByDefault> <activeByDefault>false</activeByDefault>
</activation> </activation>
</profile> </profile>
</profiles> </profiles>

View File

@ -45,6 +45,7 @@ public class ShiroConfig {
filterChainDefinitionMap.put("/api/v1/auth/login/**", "anon"); filterChainDefinitionMap.put("/api/v1/auth/login/**", "anon");
filterChainDefinitionMap.put("/api/v1/auth/logout/**", "anon"); filterChainDefinitionMap.put("/api/v1/auth/logout/**", "anon");
filterChainDefinitionMap.put("/api/v1/auth/refresh-token/**", "anon"); filterChainDefinitionMap.put("/api/v1/auth/refresh-token/**", "anon");
filterChainDefinitionMap.put("/api/v1/sse/**", "anon");
filterChainDefinitionMap.put("/**", "jwt"); filterChainDefinitionMap.put("/**", "jwt");
shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap); shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);

View File

@ -3,10 +3,15 @@ package com.rymcu.forest.openai;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.rymcu.forest.core.result.GlobalResult; import com.rymcu.forest.core.result.GlobalResult;
import com.rymcu.forest.core.result.GlobalResultGenerator; import com.rymcu.forest.core.result.GlobalResultGenerator;
import com.rymcu.forest.entity.User;
import com.rymcu.forest.openai.service.OpenAiService; import com.rymcu.forest.openai.service.OpenAiService;
import com.rymcu.forest.openai.service.SseService;
import com.rymcu.forest.util.UserUtils;
import com.theokanning.openai.completion.chat.ChatCompletionChoice; import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.completion.chat.ChatMessage;
import io.reactivex.Flowable;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
@ -14,6 +19,7 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -28,26 +34,38 @@ import java.util.List;
@RestController @RestController
@RequestMapping("/api/v1/openai") @RequestMapping("/api/v1/openai")
public class OpenAiController { public class OpenAiController {
@Resource
private SseService sseService;
@Value("${openai.token}") @Value("${openai.token}")
private String token; private String token;
@PostMapping("/chat") @PostMapping("/chat")
public GlobalResult<List<ChatCompletionChoice>> chat(@RequestBody JSONObject jsonObject) { public GlobalResult chat(@RequestBody JSONObject jsonObject) {
String message = jsonObject.getString("message"); String message = jsonObject.getString("message");
if (StringUtils.isBlank(message)) { if (StringUtils.isBlank(message)) {
throw new IllegalArgumentException("参数异常!"); throw new IllegalArgumentException("参数异常!");
} }
User user = UserUtils.getCurrentUserByToken();
ChatMessage chatMessage = new ChatMessage("user", message); ChatMessage chatMessage = new ChatMessage("user", message);
List<ChatMessage> list = new ArrayList(); List<ChatMessage> list = new ArrayList<>(4);
list.add(chatMessage); list.add(chatMessage);
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180)); OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder() ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
.model("gpt-3.5-turbo") .model("gpt-3.5-turbo")
.stream(true)
.messages(list) .messages(list)
.build(); .build();
List<ChatCompletionChoice> choices = service.createChatCompletion(completionRequest).getChoices(); service.streamChatCompletion(completionRequest).doOnError(Throwable::printStackTrace)
return GlobalResultGenerator.genSuccessResult(choices); .blockingForEach(chunk -> {
String text = chunk.getChoices().get(0).getMessage().getContent();
if (text == null) {
return;
}
System.out.print(text);
sseService.send(user.getIdUser(), text);
});
service.shutdownExecutor();
return GlobalResultGenerator.genSuccessResult();
} }
} }

View File

@ -0,0 +1,39 @@
package com.rymcu.forest.openai;
import com.rymcu.forest.entity.User;
import com.rymcu.forest.openai.service.SseService;
import com.rymcu.forest.util.UserUtils;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.annotation.Resource;
/**
 * REST endpoints that let a client open and close a Server-Sent Events (SSE)
 * channel identified by a user id. All work is delegated to {@link SseService}.
 *
 * <p>NOTE(review): the user id is taken directly from the URL path and this class
 * does not compare it with the caller's identity — confirm that this is acceptable
 * for the routes under {@code /api/v1/sse}.
 *
 * <p>Created on 2023/5/26 11:26.
 *
 * @author ronger
 * @email ronger-x@outlook.com
 */
@RestController
@RequestMapping("/api/v1/sse")
public class SeeController {

    @Resource
    private SseService sseService;

    /**
     * Opens an event stream for the given user.
     *
     * @param idUser user id taken from the path
     * @return the emitter produced by {@link SseService#connect(Long)}
     */
    @GetMapping(path = "/subscribe/{idUser}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter subscribe(@PathVariable Long idUser) {
        final SseEmitter emitter = sseService.connect(idUser);
        return emitter;
    }

    /**
     * Asks the service to close the stream registered for the given user.
     *
     * @param idUser user id taken from the path
     */
    @GetMapping("/close/{idUser}")
    public void close(@PathVariable Long idUser) {
        sseService.close(idUser);
    }
}

View File

@ -9,8 +9,10 @@ import com.theokanning.openai.DeleteResult;
import com.theokanning.openai.OpenAiApi; import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.OpenAiError; import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException; import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.CompletionChunk;
import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult; import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.edit.EditRequest; import com.theokanning.openai.edit.EditRequest;
@ -28,11 +30,15 @@ import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.model.Model; import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest; import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult; import com.theokanning.openai.moderation.ModerationResult;
import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import io.reactivex.Single; import io.reactivex.Single;
import okhttp3.*; import okhttp3.*;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import retrofit2.HttpException; import retrofit2.HttpException;
import retrofit2.Retrofit; import retrofit2.Retrofit;
import retrofit2.Call;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory; import retrofit2.converter.jackson.JacksonConverterFactory;
@ -40,18 +46,19 @@ import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class OpenAiService { public class OpenAiService {
private static final String BASE_URL = "https://api.openai.com/"; private static final String BASE_URL = "https://api.openai.com/";
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10); private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
private static final ObjectMapper errorMapper = defaultObjectMapper(); private static final ObjectMapper mapper = defaultObjectMapper();
private final OpenAiApi api; private final OpenAiApi api;
private final ExecutorService executorService;
private static final Environment env = SpringContextHolder.getBean(Environment.class); private static final Environment env = SpringContextHolder.getBean(Environment.class);
/** /**
* Creates a new OpenAiService that wraps OpenAiApi * Creates a new OpenAiService that wraps OpenAiApi
* *
@ -68,17 +75,39 @@ public class OpenAiService {
* @param timeout http read timeout, Duration.ZERO means no timeout * @param timeout http read timeout, Duration.ZERO means no timeout
*/ */
public OpenAiService(final String token, final Duration timeout) { public OpenAiService(final String token, final Duration timeout) {
this(buildApi(token, timeout)); ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);
this.api = retrofit.create(OpenAiApi.class);
this.executorService = client.dispatcher().executorService();
} }
/** /**
* Creates a new OpenAiService that wraps OpenAiApi. * Creates a new OpenAiService that wraps OpenAiApi.
* Use this if you need more customization. * Use this if you need more customization, but use OpenAiService(api, executorService) if you use streaming and
* want to shut down instantly
* *
* @param api OpenAiApi instance to use for all methods * @param api OpenAiApi instance to use for all methods
*/ */
public OpenAiService(final OpenAiApi api) { public OpenAiService(final OpenAiApi api) {
this.api = api; this.api = api;
this.executorService = null;
}
/**
* Creates a new OpenAiService that wraps OpenAiApi.
* The ExecutorService must be the one you get from the client you created the api with
* otherwise shutdownExecutor() won't work.
* <p>
* Use this if you need more customization.
*
* @param api OpenAiApi instance to use for all methods
* @param executorService the ExecutorService from client.dispatcher().executorService()
*/
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
this.api = api;
this.executorService = executorService;
} }
public List<Model> listModels() { public List<Model> listModels() {
@ -93,10 +122,22 @@ public class OpenAiService {
return execute(api.createCompletion(request)); return execute(api.createCompletion(request));
} }
public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
request.setStream(true);
return stream(api.createCompletionStream(request), CompletionChunk.class);
}
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request)); return execute(api.createChatCompletion(request));
} }
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
request.setStream(true);
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
}
public EditResult createEdit(EditRequest request) { public EditResult createEdit(EditRequest request) {
return execute(api.createEdit(request)); return execute(api.createEdit(request));
} }
@ -227,7 +268,7 @@ public class OpenAiService {
} }
String errorBody = e.response().errorBody().string(); String errorBody = e.response().errorBody().string();
OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class); OpenAiError error = mapper.readValue(errorBody, OpenAiError.class);
throw new OpenAiHttpException(error, e, e.code()); throw new OpenAiHttpException(error, e, e.code());
} catch (IOException ex) { } catch (IOException ex) {
// couldn't parse OpenAI error // couldn't parse OpenAI error
@ -236,8 +277,49 @@ public class OpenAiService {
} }
} }
/**
* Calls the Open AI api and returns a Flowable of SSE for streaming
* omitting the last message.
*
* @param apiCall The api call
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
return stream(apiCall, false);
}
/**
* Calls the Open AI api and returns a Flowable of SSE for streaming.
*
* @param apiCall The api call
* @param emitDone If true the last message ([DONE]) is emitted
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
return Flowable.create(emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)), BackpressureStrategy.BUFFER);
}
/**
* Calls the Open AI api and returns a Flowable of type T for streaming
* omitting the last message.
*
* @param apiCall The api call
* @param cl Class of type T to return
*/
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
}
/**
* Shuts down the OkHttp ExecutorService.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
* is to shut down after an idle timeout of 60s.
* Call this method to shut down the ExecutorService immediately.
*/
public void shutdownExecutor() {
Objects.requireNonNull(this.executorService, "executorService must be set in order to shut down");
this.executorService.shutdown();
}
public static OpenAiApi buildApi(String token, Duration timeout) { public static OpenAiApi buildApi(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");
ObjectMapper mapper = defaultObjectMapper(); ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout); OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper); Retrofit retrofit = defaultRetrofit(client, mapper);

View File

@ -0,0 +1,99 @@
package com.rymcu.forest.openai.service;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException;
import io.reactivex.FlowableEmitter;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.HttpException;
import retrofit2.Response;
/**
* Callback to parse Server Sent Events (SSE) from raw InputStream and
* emit the events with io.reactivex.FlowableEmitter to allow streaming of
* SSE.
*/
public class ResponseBodyCallback implements Callback<ResponseBody> {
private static final ObjectMapper mapper = OpenAiService.defaultObjectMapper();
private FlowableEmitter<SSE> emitter;
private boolean emitDone;
public ResponseBodyCallback(FlowableEmitter<SSE> emitter, boolean emitDone) {
this.emitter = emitter;
this.emitDone = emitDone;
}
@Override
public void onResponse(Call<ResponseBody> call, Response<ResponseBody> response) {
BufferedReader reader = null;
try {
if (!response.isSuccessful()) {
HttpException e = new HttpException(response);
ResponseBody errorBody = response.errorBody();
if (errorBody == null) {
throw e;
} else {
OpenAiError error = mapper.readValue(
errorBody.string(),
OpenAiError.class
);
throw new OpenAiHttpException(error, e, e.code());
}
}
InputStream in = response.body().byteStream();
reader = new BufferedReader(new InputStreamReader(in));
String line;
SSE sse = null;
while ((line = reader.readLine()) != null) {
if (line.startsWith("data:")) {
String data = line.substring(5).trim();
sse = new SSE(data);
} else if (line.equals("") && sse != null) {
if (sse.isDone()) {
if (emitDone) {
emitter.onNext(sse);
}
break;
}
emitter.onNext(sse);
sse = null;
} else {
throw new SSEFormatException("Invalid sse format! " + line);
}
}
emitter.onComplete();
} catch (Throwable t) {
onFailure(call, t);
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e) {
// do nothing
}
}
}
}
@Override
public void onFailure(Call<ResponseBody> call, Throwable t) {
emitter.onError(t);
}
}

View File

@ -0,0 +1,26 @@
package com.rymcu.forest.openai.service;
/**
 * Simple Server Sent Event representation holding a single {@code data} payload.
 */
public class SSE {

    /** Payload that marks the end of a stream. */
    private static final String DONE_DATA = "[DONE]";

    private final String data;

    /**
     * @param data the event payload, i.e. the text after the {@code data:} prefix
     */
    public SSE(String data) {
        this.data = data;
    }

    /**
     * @return the event payload as passed to the constructor
     */
    public String getData() {
        return this.data;
    }

    /**
     * Serializes this event in SSE wire form ({@code data: <payload>} followed by an empty line).
     *
     * @return the UTF-8 encoded bytes of the event
     */
    public byte[] toBytes() {
        // Encode explicitly as UTF-8 so the output does not vary with the JVM default charset.
        return String.format("data: %s\n\n", this.data)
                .getBytes(java.nio.charset.StandardCharsets.UTF_8);
    }

    /**
     * @return {@code true} if the payload equals {@code [DONE]}, ignoring case
     */
    public boolean isDone() {
        return DONE_DATA.equalsIgnoreCase(this.data);
    }
}

View File

@ -0,0 +1,10 @@
package com.rymcu.forest.openai.service;
/**
 * Exception indicating a SSE format error.
 *
 * <p>Extends {@link Exception} rather than {@link Throwable} directly, so it is
 * handled like any other checked application exception.
 */
public class SSEFormatException extends Exception {

    private static final long serialVersionUID = 1L;

    /**
     * @param msg description of the malformed input
     */
    public SSEFormatException(String msg) {
        super(msg);
    }
}

View File

@ -0,0 +1,19 @@
package com.rymcu.forest.openai.service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
 * Manages Server-Sent Events (SSE) streams, one per user id.
 *
 * <p>Created on 2023/5/26 11:24.
 *
 * @author ronger
 * @email ronger-x@outlook.com
 */
public interface SseService {

    /**
     * Registers an event stream for the given user.
     *
     * @param idUser id of the user the stream belongs to
     * @return the emitter backing the stream
     */
    SseEmitter connect(Long idUser);

    /**
     * Pushes a piece of content to the stream registered for the given user.
     *
     * @param idUser  id of the user whose stream receives the content
     * @param content text to send
     * @return whether the content was sent
     */
    boolean send(Long idUser, String content);

    /**
     * Drops the stream registered for the given user.
     *
     * @param idUser id of the user whose stream is closed
     */
    void close(Long idUser);
}

View File

@ -0,0 +1,92 @@
package com.rymcu.forest.openai.service.impl;
import com.rymcu.forest.openai.service.SseService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * In-memory {@link SseService} that keeps at most one {@link SseEmitter} per user id.
 *
 * <p>Emitters are created without a timeout ({@code 0L}), so a stream ends only when
 * it is completed explicitly, replaced by a newer connection, or fails.
 *
 * <p>Created on 2023/5/26 11:38.
 *
 * @author ronger
 * @email ronger-x@outlook.com
 */
@Slf4j
@Service
public class SseServiceImpl implements SseService {

    private static final Map<Long, SseEmitter> sessionMap = new ConcurrentHashMap<>();

    /**
     * Creates and registers a new emitter for the user. An emitter already registered
     * for the same user is replaced and completed so its connection is released.
     *
     * @param idUser id of the connecting user
     * @return the newly registered emitter
     */
    @Override
    public SseEmitter connect(Long idUser) {
        SseEmitter sseEmitter = new SseEmitter(0L);
        // Each callback acts only on the emitter it belongs to, so a late callback from
        // a replaced emitter cannot evict or complete a newer one for the same user.
        sseEmitter.onError((err) -> {
            log.error("type: SseSession Error, msg: {} session Id : {}", err.getMessage(), idUser);
            removeUser(idUser, sseEmitter);
            sseEmitter.completeWithError(err);
        });
        sseEmitter.onTimeout(() -> {
            log.info("type: SseSession Timeout, session Id : {}", idUser);
            removeUser(idUser, sseEmitter);
        });
        sseEmitter.onCompletion(() -> {
            log.info("type: SseSession Completion, session Id : {}", idUser);
            removeUser(idUser, sseEmitter);
        });

        SseEmitter previous = sessionMap.put(idUser, sseEmitter);
        if (previous != null) {
            // Finish the replaced stream instead of leaving its connection open forever.
            previous.complete();
        }
        return sseEmitter;
    }

    /**
     * Sends content to the user's emitter.
     *
     * @param idUser  id of the receiving user
     * @param content text to send
     * @return {@code true} if sent, {@code false} if sending failed with an {@link IOException}
     * @throws IllegalArgumentException if no emitter is registered for the user
     */
    @Override
    public boolean send(Long idUser, String content) {
        // Single lookup: avoids the window between a containsKey check and a later get.
        SseEmitter sseEmitter = sessionMap.get(idUser);
        if (sseEmitter == null) {
            throw new IllegalArgumentException("User Id " + idUser + " not Found");
        }
        try {
            sseEmitter.send(content);
            return true;
        } catch (IOException exception) {
            log.error("type: SseSession send Error:IOException, msg: {} session Id : {}", exception.getMessage(), idUser);
            return false;
        }
    }

    /**
     * Unregisters the user's emitter and completes it so the stream actually ends.
     *
     * @param idUser id of the user whose stream is closed
     */
    @Override
    public void close(Long idUser) {
        log.info("type: SseSession Close, session Id : {}", idUser);
        SseEmitter sseEmitter = sessionMap.remove(idUser);
        if (sseEmitter != null) {
            sseEmitter.complete();
        }
    }

    /** Removes the mapping only if the user is still mapped to the given emitter. */
    private void removeUser(Long idUser, SseEmitter sseEmitter) {
        sessionMap.remove(idUser, sseEmitter);
    }
}