diff --git a/src/main/java/com/rymcu/forest/openai/OpenAiController.java b/src/main/java/com/rymcu/forest/openai/OpenAiController.java index b07ae55..6820373 100644 --- a/src/main/java/com/rymcu/forest/openai/OpenAiController.java +++ b/src/main/java/com/rymcu/forest/openai/OpenAiController.java @@ -4,6 +4,7 @@ import com.alibaba.fastjson.JSONObject; import com.rymcu.forest.core.result.GlobalResult; import com.rymcu.forest.core.result.GlobalResultGenerator; import com.rymcu.forest.entity.User; +import com.rymcu.forest.openai.entity.ChatMessageModel; import com.rymcu.forest.openai.service.OpenAiService; import com.rymcu.forest.openai.service.SseService; import com.rymcu.forest.util.UserUtils; @@ -13,6 +14,7 @@ import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatMessage; import io.reactivex.Flowable; import org.apache.commons.lang.StringUtils; +import org.jetbrains.annotations.NotNull; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -22,6 +24,7 @@ import org.springframework.web.bind.annotation.RestController; import javax.annotation.Resource; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; /** @@ -50,6 +53,33 @@ public class OpenAiController { ChatMessage chatMessage = new ChatMessage("user", message); List list = new ArrayList<>(4); list.add(chatMessage); + return sendMessage(user, list); + } + + @PostMapping("/new-chat") + public GlobalResult newChat(@RequestBody List messages) { + if (messages.isEmpty()) { + throw new IllegalArgumentException("参数异常!"); + } + User user = UserUtils.getCurrentUserByToken(); + Collections.reverse(messages); + List list = new ArrayList<>(messages.size()); + if (messages.size() > 4) { + messages = messages.subList(messages.size() - 4, messages.size()); + } + if (messages.size() >= 4 && messages.size() % 4 == 0) { + ChatMessage message = new ChatMessage("system", "简单总结一下你和用户的对话, 用作后续的上下文提示 prompt, 控制在 200 字内"); + list.add(message); + } + messages.forEach(chatMessageModel -> { + ChatMessage message = new ChatMessage(chatMessageModel.getRole(), chatMessageModel.getContent()); + list.add(message); + }); + return sendMessage(user, list); + } + + @NotNull + private GlobalResult sendMessage(User user, List list) { OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180)); ChatCompletionRequest completionRequest = ChatCompletionRequest.builder() .model("gpt-3.5-turbo") diff --git a/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java b/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java new file mode 100644 index 0000000..908f7ac --- /dev/null +++ b/src/main/java/com/rymcu/forest/openai/entity/ChatMessageModel.java @@ -0,0 +1,26 @@ +package com.rymcu.forest.openai.entity; + +import lombok.Data; + +/** + * Created on 2023/7/16 14:52. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.openai.entity + */ +@Data +public class ChatMessageModel { + + Long dataId; + + String to; + + String from; + + Integer dataType; + + String content; + + String role; +}