diff --git a/pom.xml b/pom.xml index 3f4d58f..1719943 100644 --- a/pom.xml +++ b/pom.xml @@ -152,7 +152,7 @@ com.alibaba fastjson - 2.0.20 + 2.0.25 @@ -305,12 +305,12 @@ cn.hutool hutool-core - 5.8.11 + 5.8.19 cn.hutool hutool-http - 5.8.11 + 5.8.19 diff --git a/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java b/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java index 1a07fac..bedebd2 100644 --- a/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java +++ b/src/main/java/com/rymcu/forest/config/BaseExceptionHandler.java @@ -1,6 +1,6 @@ package com.rymcu.forest.config; -import com.alibaba.fastjson.support.spring.FastJsonJsonView; +import com.alibaba.fastjson.support.spring.annotation.FastJsonView; import com.rymcu.forest.core.exception.BusinessException; import com.rymcu.forest.core.exception.ServiceException; import com.rymcu.forest.core.exception.TransactionException; @@ -47,7 +47,7 @@ public class BaseExceptionHandler { result = new GlobalResult<>(ResultCode.UNAUTHORIZED); logger.info("用户无权限"); } else if (ex instanceof UnknownAccountException) { - // 账号或密码错误 + // 未知账号 result = new GlobalResult<>(ResultCode.UNKNOWN_ACCOUNT); logger.info(ex.getMessage()); } else if (ex instanceof AccountException) { @@ -91,7 +91,7 @@ public class BaseExceptionHandler { return result; } else { ModelAndView mv = new ModelAndView(); - FastJsonJsonView view = new FastJsonJsonView(); + FastJsonView view = new FastJsonView(); Map attributes = new HashMap(2); if (ex instanceof UnauthenticatedException) { attributes.put("code", ResultCode.UNAUTHENTICATED.getCode()); @@ -99,6 +99,16 @@ public class BaseExceptionHandler { } else if (ex instanceof UnauthorizedException) { attributes.put("code", ResultCode.UNAUTHORIZED.getCode()); attributes.put("message", ResultCode.UNAUTHORIZED.getMessage()); + } else if (ex instanceof UnknownAccountException) { + // 未知账号 + attributes.put("code", ResultCode.UNKNOWN_ACCOUNT.getCode()); + attributes.put("message", ex.getMessage()); + logger.info(ex.getMessage()); + } else if (ex instanceof AccountException) { + // 账号或密码错误 + attributes.put("code", ResultCode.INCORRECT_ACCOUNT_OR_PASSWORD.getCode()); + attributes.put("message", ex.getMessage()); + logger.info(ex.getMessage()); } else if (ex instanceof ServiceException) { //业务失败的异常,如“账号或密码错误” attributes.put("code", ((ServiceException) ex).getCode()); diff --git a/src/main/java/com/rymcu/forest/openai/OpenAiController.java b/src/main/java/com/rymcu/forest/openai/OpenAiController.java index b07ae55..f70c949 100644 --- a/src/main/java/com/rymcu/forest/openai/OpenAiController.java +++ b/src/main/java/com/rymcu/forest/openai/OpenAiController.java @@ -4,8 +4,10 @@ import com.alibaba.fastjson.JSONObject; import com.rymcu.forest.core.result.GlobalResult; import com.rymcu.forest.core.result.GlobalResultGenerator; import com.rymcu.forest.entity.User; +import com.rymcu.forest.openai.entity.ChatMessageModel; import com.rymcu.forest.openai.service.OpenAiService; import com.rymcu.forest.openai.service.SseService; +import com.rymcu.forest.util.Html2TextUtil; import com.rymcu.forest.util.UserUtils; import com.theokanning.openai.completion.chat.ChatCompletionChoice; import com.theokanning.openai.completion.chat.ChatCompletionChunk; @@ -13,6 +15,7 @@ import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatMessage; import io.reactivex.Flowable; import org.apache.commons.lang.StringUtils; +import org.jetbrains.annotations.NotNull; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -22,6 +25,7 @@ import org.springframework.web.bind.annotation.RestController; import javax.annotation.Resource; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; /** @@ -50,9 +54,36 @@ public class OpenAiController { ChatMessage chatMessage = new ChatMessage("user", message); List list = new ArrayList<>(4); list.add(chatMessage); + return sendMessage(user, list); + } + + @PostMapping("/new-chat") + public GlobalResult newChat(@RequestBody List messages) { + if (messages.isEmpty()) { + throw new IllegalArgumentException("参数异常!"); + } + User user = UserUtils.getCurrentUserByToken(); + Collections.reverse(messages); + List list = new ArrayList<>(messages.size()); + if (messages.size() > 4) { + messages = messages.subList(messages.size() - 4, messages.size()); + } + if (messages.size() >= 4 && messages.size() % 4 == 0) { + ChatMessage message = new ChatMessage("system", "简单总结一下你和用户的对话, 用作后续的上下文提示 prompt, 控制在 200 字内"); + list.add(message); + } + messages.forEach(chatMessageModel -> { + ChatMessage message = new ChatMessage(chatMessageModel.getRole(), Html2TextUtil.getContent(chatMessageModel.getContent())); + list.add(message); + }); + return sendMessage(user, list); + } + + @NotNull + private GlobalResult sendMessage(User user, List list) { OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180)); ChatCompletionRequest completionRequest = ChatCompletionRequest.builder() - .model("gpt-3.5-turbo") + .model("gpt-3.5-turbo-16k-0613") .stream(true) .messages(list) .build(); diff --git a/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java b/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java new file mode 100644 index 0000000..908f7ac --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java @@ -0,0 +1,26 @@ +package com.rymcu.forest.openai.entity; + +import lombok.Data; + +/** + * Created on 2023/7/16 14:52. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.openai.entity + */ +@Data +public class ChatMessageModel { + + Long dataId; + + String to; + + String from; + + Integer dataType; + + String content; + + String role; +} diff --git a/src/main/java/com/rymcu/forest/openai/service/IpAddressInterceptor.java b/src/main/java/com/rymcu/forest/openai/service/IpAddressInterceptor.java new file mode 100644 index 0000000..ea8aab7 --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/service/IpAddressInterceptor.java @@ -0,0 +1,29 @@ +package com.rymcu.forest.openai.service; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; + +/** + * OkHttp Interceptor that adds an ip address header + * @author ronger + */ +public class IpAddressInterceptor implements Interceptor { + + private final String ip; + + IpAddressInterceptor(String ip) { + this.ip = ip; + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request() + .newBuilder() + .header("x-forwarded-for", ip) + .build(); + return chain.proceed(request); + } +} diff --git a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java index ad468b8..b9d00fb 100644 --- a/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java +++ b/src/main/java/com/rymcu/forest/openai/service/OpenAiService.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; import com.rymcu.forest.util.SpringContextHolder; +import com.rymcu.forest.util.Utils; import com.theokanning.openai.DeleteResult; import com.theokanning.openai.OpenAiApi; import com.theokanning.openai.OpenAiError; @@ -30,18 +31,20 @@ import com.theokanning.openai.image.ImageResult; import com.theokanning.openai.model.Model; import com.theokanning.openai.moderation.ModerationRequest; import com.theokanning.openai.moderation.ModerationResult; - import io.reactivex.BackpressureStrategy; import io.reactivex.Flowable; import io.reactivex.Single; import okhttp3.*; import org.springframework.core.env.Environment; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import retrofit2.Call; import retrofit2.HttpException; import retrofit2.Retrofit; -import retrofit2.Call; import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; import retrofit2.converter.jackson.JacksonConverterFactory; +import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.time.Duration; import java.util.List; @@ -336,8 +339,11 @@ public class OpenAiService { } public static OkHttpClient defaultClient(String token, Duration timeout) { + HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest(); + String ip = Utils.getIpAddress(request); return new OkHttpClient.Builder() .addInterceptor(new AuthenticationInterceptor(token)) + .addInterceptor(new IpAddressInterceptor(ip)) .connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS)) .readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS) .build();