diff --git a/pom.xml b/pom.xml
index 3f4d58f..1719943 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,7 +152,7 @@
com.alibaba
fastjson
- 2.0.20
+ 2.0.25
@@ -305,12 +305,12 @@
cn.hutool
hutool-core
- 5.8.11
+ 5.8.19
cn.hutool
hutool-http
- 5.8.11
+ 5.8.19
diff --git a/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java b/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java
index 1a07fac..bedebd2 100644
--- a/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java
+++ b/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java
@@ -1,6 +1,6 @@
package com.rymcu.forest.config;
-import com.alibaba.fastjson.support.spring.FastJsonJsonView;
+import com.alibaba.fastjson.support.spring.annotation.FastJsonView;
import com.rymcu.forest.core.exception.BusinessException;
import com.rymcu.forest.core.exception.ServiceException;
import com.rymcu.forest.core.exception.TransactionException;
@@ -47,7 +47,7 @@ public class BaseExceptionHandler {
result = new GlobalResult<>(ResultCode.UNAUTHORIZED);
logger.info("用户无权限");
} else if (ex instanceof UnknownAccountException) {
- // 账号或密码错误
+ // 未知账号
result = new GlobalResult<>(ResultCode.UNKNOWN_ACCOUNT);
logger.info(ex.getMessage());
} else if (ex instanceof AccountException) {
@@ -91,7 +91,7 @@ public class BaseExceptionHandler {
return result;
} else {
ModelAndView mv = new ModelAndView();
- FastJsonJsonView view = new FastJsonJsonView();
+ FastJsonView view = new FastJsonView();
Map attributes = new HashMap(2);
if (ex instanceof UnauthenticatedException) {
attributes.put("code", ResultCode.UNAUTHENTICATED.getCode());
@@ -99,6 +99,16 @@ public class BaseExceptionHandler {
} else if (ex instanceof UnauthorizedException) {
attributes.put("code", ResultCode.UNAUTHORIZED.getCode());
attributes.put("message", ResultCode.UNAUTHORIZED.getMessage());
+ } else if (ex instanceof UnknownAccountException) {
+ // 未知账号
+ attributes.put("code", ResultCode.UNKNOWN_ACCOUNT.getCode());
+ attributes.put("message", ex.getMessage());
+ logger.info(ex.getMessage());
+ } else if (ex instanceof AccountException) {
+ // 账号或密码错误
+ attributes.put("code", ResultCode.INCORRECT_ACCOUNT_OR_PASSWORD.getCode());
+ attributes.put("message", ex.getMessage());
+ logger.info(ex.getMessage());
} else if (ex instanceof ServiceException) {
//业务失败的异常,如“账号或密码错误”
attributes.put("code", ((ServiceException) ex).getCode());
diff --git a/src/main/java/com/rymcu/forest/openai/OpenAiController.java b/src/main/java/com/rymcu/forest/openai/OpenAiController.java
index b07ae55..f70c949 100644
--- a/src/main/java/com/rymcu/forest/openai/OpenAiController.java
+++ b/src/main/java/com/rymcu/forest/openai/OpenAiController.java
@@ -4,8 +4,10 @@ import com.alibaba.fastjson.JSONObject;
import com.rymcu.forest.core.result.GlobalResult;
import com.rymcu.forest.core.result.GlobalResultGenerator;
import com.rymcu.forest.entity.User;
+import com.rymcu.forest.openai.entity.ChatMessageModel;
import com.rymcu.forest.openai.service.OpenAiService;
import com.rymcu.forest.openai.service.SseService;
+import com.rymcu.forest.util.Html2TextUtil;
import com.rymcu.forest.util.UserUtils;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
@@ -13,6 +15,7 @@ import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import io.reactivex.Flowable;
import org.apache.commons.lang.StringUtils;
+import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
@@ -22,6 +25,7 @@ import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.time.Duration;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
/**
@@ -50,9 +54,36 @@ public class OpenAiController {
ChatMessage chatMessage = new ChatMessage("user", message);
List list = new ArrayList<>(4);
list.add(chatMessage);
+ return sendMessage(user, list);
+ }
+
+ @PostMapping("/new-chat")
+ public GlobalResult newChat(@RequestBody List messages) {
+ if (messages.isEmpty()) {
+ throw new IllegalArgumentException("参数异常!");
+ }
+ User user = UserUtils.getCurrentUserByToken();
+ Collections.reverse(messages);
+ List list = new ArrayList<>(messages.size());
+ if (messages.size() > 4) {
+ messages = messages.subList(messages.size() - 4, messages.size());
+ }
+ if (messages.size() >= 4 && messages.size() % 4 == 0) {
+ ChatMessage message = new ChatMessage("system", "简单总结一下你和用户的对话, 用作后续的上下文提示 prompt, 控制在 200 字内");
+ list.add(message);
+ }
+ messages.forEach(chatMessageModel -> {
+ ChatMessage message = new ChatMessage(chatMessageModel.getRole(), Html2TextUtil.getContent(chatMessageModel.getContent()));
+ list.add(message);
+ });
+ return sendMessage(user, list);
+ }
+
+ @NotNull
+ private GlobalResult sendMessage(User user, List list) {
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
- .model("gpt-3.5-turbo")
+ .model("gpt-3.5-turbo-16k-0613")
.stream(true)
.messages(list)
.build();
diff --git a/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java b/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java
new file mode 100644
index 0000000..908f7ac
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java
@@ -0,0 +1,26 @@
+package com.rymcu.forest.openai.entity;
+
+import lombok.Data;
+
+/**
+ * Created on 2023/7/16 14:52.
+ *
+ * @author ronger
+ * @email ronger-x@outlook.com
+ * @desc : com.rymcu.forest.openai.entity
+ */
+@Data
+public class ChatMessageModel {
+
+ Long dataId;
+
+ String to;
+
+ String from;
+
+ Integer dataType;
+
+ String content;
+
+ String role;
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/IpAddressInterceptor.java b/src/main/java/com/rymcu/forest/openai/service/IpAddressInterceptor.java
new file mode 100644
index 0000000..ea8aab7
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/openai/service/IpAddressInterceptor.java
@@ -0,0 +1,29 @@
+package com.rymcu.forest.openai.service;
+
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import okhttp3.Response;
+
+import java.io.IOException;
+
+/**
+ * OkHttp Interceptor that adds an ip address header
+ * @author ronger
+ */
+public class IpAddressInterceptor implements Interceptor {
+
+ private final String ip;
+
+ IpAddressInterceptor(String ip) {
+ this.ip = ip;
+ }
+
+ @Override
+ public Response intercept(Chain chain) throws IOException {
+ Request request = chain.request()
+ .newBuilder()
+ .header("x-forwarded-for", ip)
+ .build();
+ return chain.proceed(request);
+ }
+}
diff --git a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
index ad468b8..b9d00fb 100644
--- a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
+++ b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java
@@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.rymcu.forest.util.SpringContextHolder;
+import com.rymcu.forest.util.Utils;
import com.theokanning.openai.DeleteResult;
import com.theokanning.openai.OpenAiApi;
import com.theokanning.openai.OpenAiError;
@@ -30,18 +31,20 @@ import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;
-
import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import io.reactivex.Single;
import okhttp3.*;
import org.springframework.core.env.Environment;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import retrofit2.Call;
import retrofit2.HttpException;
import retrofit2.Retrofit;
-import retrofit2.Call;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
+import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
@@ -336,8 +339,11 @@ public class OpenAiService {
}
public static OkHttpClient defaultClient(String token, Duration timeout) {
+ HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
+ String ip = Utils.getIpAddress(request);
return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
+ .addInterceptor(new IpAddressInterceptor(ip))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
.readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS)
.build();