🎨 增加模型选择功能

This commit is contained in:
ronger 2024-08-06 10:50:42 +08:00
parent fc2bf19b44
commit ba374888eb
3 changed files with 32 additions and 28 deletions

View File

@ -1,10 +1,10 @@
package com.rymcu.forest.openai; package com.rymcu.forest.openai;
import com.alibaba.fastjson.JSONObject;
import com.rymcu.forest.core.result.GlobalResult; import com.rymcu.forest.core.result.GlobalResult;
import com.rymcu.forest.core.result.GlobalResultGenerator; import com.rymcu.forest.core.result.GlobalResultGenerator;
import com.rymcu.forest.entity.User; import com.rymcu.forest.entity.User;
import com.rymcu.forest.openai.entity.ChatMessageModel; import com.rymcu.forest.openai.entity.ChatMessageModel;
import com.rymcu.forest.openai.entity.ChatModel;
import com.rymcu.forest.openai.service.OpenAiService; import com.rymcu.forest.openai.service.OpenAiService;
import com.rymcu.forest.openai.service.SseService; import com.rymcu.forest.openai.service.SseService;
import com.rymcu.forest.util.Html2TextUtil; import com.rymcu.forest.util.Html2TextUtil;
@ -22,7 +22,6 @@ import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
/** /**
@ -42,20 +41,9 @@ public class OpenAiController {
private String token; private String token;
@PostMapping("/chat") @PostMapping("/chat")
public GlobalResult chat(@RequestBody JSONObject jsonObject) { public GlobalResult newChat(@RequestBody ChatModel chatModel) {
String message = jsonObject.getString("message"); List<ChatMessageModel> messages = chatModel.getMessages();
if (StringUtils.isBlank(message)) { String model = chatModel.getModel();
throw new IllegalArgumentException("参数异常!");
}
User user = UserUtils.getCurrentUserByToken();
ChatMessage chatMessage = new ChatMessage("user", message);
List<ChatMessage> list = new ArrayList<>(4);
list.add(chatMessage);
return sendMessage(user, list);
}
@PostMapping("/new-chat")
public GlobalResult newChat(@RequestBody List<ChatMessageModel> messages) {
if (messages.isEmpty()) { if (messages.isEmpty()) {
throw new IllegalArgumentException("参数异常!"); throw new IllegalArgumentException("参数异常!");
} }
@ -72,15 +60,17 @@ public class OpenAiController {
ChatMessage message = new ChatMessage(chatMessageModel.getRole(), Html2TextUtil.getContent(chatMessageModel.getContent())); ChatMessage message = new ChatMessage(chatMessageModel.getRole(), Html2TextUtil.getContent(chatMessageModel.getContent()));
list.add(message); list.add(message);
}); });
return sendMessage(user, list); return sendMessage(user, list, model);
} }
@NotNull @NotNull
private GlobalResult sendMessage(User user, List<ChatMessage> list) { private GlobalResult sendMessage(User user, List<ChatMessage> list, String model) {
boolean isAdmin = UserUtils.isAdmin(user.getEmail()); if (StringUtils.isBlank(model)) {
model = "gpt-3.5-turbo-16k-0613";
}
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180)); OpenAiService service = new OpenAiService(token, Duration.ofSeconds(180));
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder() ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
.model("gpt-3.5-turbo-16k-0613") .model(model)
.stream(true) .stream(true)
.messages(list) .messages(list)
.build(); .build();

View File

@ -0,0 +1,21 @@
package com.rymcu.forest.openai.entity;
import lombok.Data;
import java.util.List;
/**
* Created on 2024/8/6 10:24.
*
* @author ronger
* @email ronger-x@outlook.com
* @desc : com.rymcu.forest.openai.entity
*/
@Data
public class ChatModel {
String model;
List<ChatMessageModel> messages;
}

View File

@ -8,16 +8,11 @@ import com.rymcu.forest.dto.LinkToImageUrlDTO;
import com.rymcu.forest.dto.TokenUser; import com.rymcu.forest.dto.TokenUser;
import com.rymcu.forest.enumerate.FilePath; import com.rymcu.forest.enumerate.FilePath;
import com.rymcu.forest.service.ForestFileService; import com.rymcu.forest.service.ForestFileService;
import com.rymcu.forest.util.FileUtils; import com.rymcu.forest.util.*;
import com.rymcu.forest.util.SpringContextHolder;
import com.rymcu.forest.util.UserUtils;
import com.rymcu.forest.util.Utils;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
import org.apache.shiro.authz.UnauthorizedException; import org.apache.shiro.authz.UnauthorizedException;
import org.apache.shiro.authz.annotation.Logical;
import org.apache.shiro.authz.annotation.RequiresPermissions; import org.apache.shiro.authz.annotation.RequiresPermissions;
import org.apache.shiro.authz.annotation.RequiresRoles;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@ -34,8 +29,6 @@ import java.net.HttpURLConnection;
import java.net.URL; import java.net.URL;
import java.util.*; import java.util.*;
import com.rymcu.forest.util.SSRFUtil;
/** /**
* 文件上传控制器 * 文件上传控制器
* *