Explorar el Código

对上下文数量及大小进行限制

zhenghao hace 5 meses
padre
commit
9febac5192

+ 1 - 1
src/main/java/cn/jlsxwkj/common/handler/GlobalRestExceptionHandler.java

@@ -59,7 +59,7 @@ public class GlobalRestExceptionHandler {
         Log.error(e.getClass(), "Exception  ====> {}", responseError.getException());
         Log.error(e.getClass(), "Message    ====> {}", responseError.getMessage());
 
-        e.printStackTrace();
+        Log.info(this.getClass(), Arrays.toString(e.getStackTrace()));
 
         return responseError;
     }

+ 1 - 1
src/main/java/cn/jlsxwkj/moudles/chat/ChatController.java

@@ -42,7 +42,7 @@ public class ChatController {
 
     @Operation(summary = "问答文档总结流")
     @PostMapping(value = "/chatSummaryStream", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
-    public Flux<String> chatSummaryStream(@RequestParam(defaultValue = "帮我总结一下") String message,
+    public Flux<String> chatSummaryStream(@RequestParam(defaultValue = "帮我总结一下上面内容") String message,
                                       @RequestBody MultipartFile file) throws CustomException {
         return chatService.chatSummaryStream(message, file);
     }

+ 30 - 12
src/main/java/cn/jlsxwkj/moudles/chat/ChatService.java

@@ -48,7 +48,9 @@ public class ChatService {
 
     private static final MD5 MD5 = cn.hutool.crypto.digest.MD5.create();
     private static final String PATH = System.getProperty("user.dir") + "/ai_doc/";
-    private static final List<Message> LIST_MESSAGE = new ArrayList<>();
+    private static final int MAX_LIST_MESSAGE_LENGTH = 3;
+    private static final int MAX_MESSAGE_LENGTH = 512;
+    private static List<Message> listMessage = new ArrayList<>();
 
     @Resource
     private VectorStore vectorStore;
@@ -125,8 +127,10 @@ public class ChatService {
         String fileName = supportFile.get(1);
         String fileTypeCn = supportFile.get(2);
         DocumentReader reader = supportFile.get(3);
-        String context = fileTypeCn + " : " + fileName + " ====> " + mergeDocuments(reader.get()).replace("\n", " ");
-        return stream(message, new UserMessage(context + " : " + message));
+        String context = fileTypeCn + " :: " + fileName + " ====> " + mergeDocuments(reader.get()) + "\n" + "↑ " + message;
+        UserMessage userMessage = new UserMessage(context);
+        Log.info(this.getClass(), context);
+        return stream(message, userMessage);
     }
 
     /**
@@ -137,8 +141,8 @@ public class ChatService {
     }
 
     private String mergeDocuments(List<Document> documents) {
-        return MergeDocuments.mergeDocuments(documents)
-                .stream().map(Document::getContent).collect(Collectors.joining("\n"));
+        return MergeDocuments.mergeDocuments(documents).stream()
+                .map(Document::getContent).collect(Collectors.joining("\n"));
     }
 
     /**
@@ -149,15 +153,28 @@ public class ChatService {
      * @return 提示词
      */
     private Prompt getChatPrompt2String(String message, Message context) {
-        ChatService.LIST_MESSAGE.add(new SystemMessage(userConfig.getSysMessage()));
+        if (listMessage.size() > MAX_LIST_MESSAGE_LENGTH) {
+            listMessage = listMessage.subList(listMessage.size() - MAX_LIST_MESSAGE_LENGTH, listMessage.size())
+                    .stream()
+                    .map(msg -> {
+                        if (msg.getMessageType().equals(MessageType.ASSISTANT)) {
+                            String content = msg.getContent();
+                            if (content.length() > MAX_MESSAGE_LENGTH) {
+                                return new AssistantMessage(content.substring(0, MAX_MESSAGE_LENGTH));
+                            }
+                        }
+                        return msg;
+                    }).collect(Collectors.toList());
+        }
+        listMessage.add(new SystemMessage(userConfig.getSysMessage()));
         if (context.getMessageType().equals(MessageType.SYSTEM)) {
-            ChatService.LIST_MESSAGE.add(new UserMessage(message));
-            ChatService.LIST_MESSAGE.add(context);
+            listMessage.add(new UserMessage(message));
+            listMessage.add(context);
         }
         if (context.getMessageType().equals(MessageType.USER)) {
-            ChatService.LIST_MESSAGE.add(context);
+            return new Prompt(context);
         }
-        return new Prompt(ChatService.LIST_MESSAGE);
+        return new Prompt(listMessage);
     }
 
     /**
@@ -175,12 +192,13 @@ public class ChatService {
                 response -> {
                     String str = response.getResult().getOutput().getContent();
                     chatMessage.append(str);
-                    return str.replace(" ", "  ");
+                    return str.replace(" ", "  ")
+                            .replace("\t", "    ");
                 }).concatWithValues("<{完成}>")
                 .onErrorComplete()
                 .doOnComplete(() -> {
                     String chatMsg = chatMessage.toString();
-                    ChatService.LIST_MESSAGE.add(new AssistantMessage(chatMsg));
+                    listMessage.add(new AssistantMessage(chatMsg));
                     try {
                         chatHistoryService.insert(userId, message, chatMsg);
                     } catch (InsertFailException e) {

+ 3 - 1
src/main/resources/application-dev.yml

@@ -12,8 +12,10 @@ spring:
       base-url: http://localhost:11434
       chat:
         model: qwen2:1.5b
+        options:
+          num_ctx: 10240
       embedding:
         model: mofanke/dmeta-embedding-zh
 customer:
   cnocr-url: "http://www.jlsxwkj.cn:8501/ocr"
-  sysMessage: "如果问你是谁, 说你是玄武科技公司是 ai 或者 大模型"
+  sysMessage: "如果问你是谁或者你叫什么, 说你是吉林省玄武科技公司是 ai 或者 大模型"