Răsfoiți Sursa

保存历史对话

zhenghao 5 luni în urmă
părinte
comite
bea40e33a4

+ 11 - 14
src/main/java/cn/jlsxwkj/moudles/chat/ChatController.java

@@ -23,7 +23,7 @@ import java.util.List;
 public class ChatController {
 
 	@Resource
-	private ChatService documentService;
+	private ChatService chatService;
 	@Resource
 	private ChatHistoryService chatHistoryService;
 
@@ -31,23 +31,14 @@ public class ChatController {
 	@PostMapping("/upload")
 	public String uploadDoc(
 			@Parameter(name = "file", description = "文件") @RequestBody MultipartFile file) throws IOException {
-		return documentService.uploadDocument(file);
+		return chatService.uploadDocument(file);
 	}
 
 	@Operation(summary = "搜索文档")
 	@PostMapping("/search")
 	public String searchDoc(
 			@Parameter(name = "keyword", description = "关键词") @RequestParam String keyword) {
-		return documentService.search(keyword);
-	}
-
-	@Operation(summary = "插入历史对话")
-	@PostMapping("/setHistory")
-	public String insertHistory(
-			@Parameter(name = "userId", description = "用户id") @RequestParam String userId,
-			@Parameter(name = "userQ", description = "用户提问") @RequestParam String userQ,
-			@Parameter(name = "chatA", description = "机器人回答") @RequestParam String chatA) {
-		return chatHistoryService.insert(userId, userQ, chatA);
+		return chatService.search(keyword);
 	}
 
 	@Operation(summary = "获取历史对话")
@@ -57,11 +48,17 @@ public class ChatController {
 		return chatHistoryService.selectHistoryByUserId(userId);
 	}
 
+	@Operation(summary = "保存历史对话")
+	@PostMapping("/getHistory")
+	public String saveHistory() {
+		return chatService.saveDialog();
+	}
+
 	@Operation(summary = "问答文档流")
 	@PostMapping(value = "/chatStream", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
 	public Flux<String> chatStream(
 			@Parameter(name = "message", description = "消息") @RequestParam String message) {
-		return documentService.chatStream(message);
+		return chatService.chatStream(message);
 	}
 
 	@Operation(summary = "问答文档")
@@ -69,6 +66,6 @@ public class ChatController {
 	@Deprecated
 	public String chat(
 			@Parameter(name = "message", description = "消息") @RequestParam String message) {
-		return documentService.chat(message);
+		return chatService.chat(message);
 	}
 }

+ 40 - 21
src/main/java/cn/jlsxwkj/moudles/chat/ChatService.java

@@ -5,12 +5,10 @@ import cn.jlsxwkj.common.reader.ParagraphDocReader;
 import cn.jlsxwkj.common.reader.ParagraphTextReader;
 import cn.jlsxwkj.common.utils.Log;
 import cn.jlsxwkj.common.utils.MergeDocuments;
+import cn.jlsxwkj.moudles.chathistory.ChatHistory;
+import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
 import jakarta.annotation.Resource;
-import org.springframework.ai.chat.messages.AssistantMessage;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.messages.*;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.document.DocumentReader;
@@ -47,6 +45,11 @@ public class ChatService {
 	@Resource
 	private OllamaChatModel ollamaChatModel;
 	/**
+	 * 历史记录
+	 */
+	@Resource
+	private ChatHistoryService chatHistoryService;
+	/**
 	 * md5 校验
 	 */
 	private static final MD5 MD5 = cn.hutool.crypto.digest.MD5.create();
@@ -58,10 +61,33 @@ public class ChatService {
 	 * 用户对话上下文
 	 */
 	private static final List<Message> LIST_MESSAGE = new ArrayList<>();
-
+	/**
+	 * 单条消息
+	 */
 	private static StringBuffer chatMessage;
 
 	/**
+	 * 保存会话
+	 * @return 保存条数
+	 */
+	@SuppressWarnings("unused")
+	public String saveDialog() {
+		List<ChatHistory> chatHistoryList = new ArrayList<>();
+		List<Message> userMessageList = LIST_MESSAGE.stream().filter(x -> x.getMessageType().equals(MessageType.USER)).toList();
+		List<Message> chatMessageList = LIST_MESSAGE.stream().filter(x -> x.getMessageType().equals(MessageType.ASSISTANT)).toList();
+		for (int i = 0; i < userMessageList.size(); i++) {
+			String userMessage = userMessageList.get(i).getContent();
+			String chatMessage = chatMessageList.get(i).getContent();
+			ChatHistory chatHistory = new ChatHistory();
+			chatHistory.setUserId("test");
+			chatHistory.setUserQ(userMessage);
+			chatHistory.setChatA(chatMessage);
+			chatHistoryList.add(chatHistory);
+		}
+		return chatHistoryService.insertAll(chatHistoryList);
+	}
+
+	/**
 	 * 使用spring ai解析txt文档并保存至 pg
 	 * @param file 保存的文件
 	 * @return 保存状态
@@ -114,14 +140,9 @@ public class ChatService {
 	 * @return 			文本内容
 	 */
 	public String search(String keyword) {
-		//相似度不少于 50%
-		SearchRequest searchRequest = SearchRequest.query(keyword).withSimilarityThreshold(0.5);
-		//提取文本内容
-		List<Document> documents = vectorStore.similaritySearch(searchRequest);
-		//合并文档
-		List<Document> mergeDocuments = MergeDocuments.mergeDocuments(documents);
-		//获取文档中的文本并以换行分割
-		return mergeDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
+		return MergeDocuments.mergeDocuments(
+				vectorStore.similaritySearch(SearchRequest.query(keyword).withSimilarityThreshold(0.5))
+		).stream().map(Document::getContent).collect(Collectors.joining("\n"));
 	}
 
 	/**
@@ -136,14 +157,13 @@ public class ChatService {
 		String content = search(message);
 		//封装prompt并调用大模型
 		Prompt chatPrompt2String = getChatPrompt2String(message, content);
-		return ollamaChatModel.stream(chatPrompt2String)
-				.map(response -> {
+		return ollamaChatModel.stream(chatPrompt2String).map(
+				response -> {
 					String str = response.getResult().getOutput().getContent();
 					chatMessage.append(str);
 					return str.replace(" ", "  ");
 				})
-				.concatWithValues("<{完成}>")
-				.onErrorStop()
+				.concatWithValues("<{完成}>").onErrorStop()
 				.doOnComplete(() -> {
 					LIST_MESSAGE.add(new AssistantMessage(chatMessage.toString()));
 					LIST_MESSAGE.forEach(x -> Log.info(this.getClass(),
@@ -164,10 +184,9 @@ public class ChatService {
 		chatMessage = new StringBuffer();
 		//查询获取文档信息
 		String content = search(message);
-		//封装prompt并调用大模型
-		ChatResponse call = ollamaChatModel.call(getChatPrompt2String(message, content));
 		//获取消息
-		String result = call.getResult().getOutput().getContent();
+		String result = ollamaChatModel.call(getChatPrompt2String(message, content))
+				.getResult().getOutput().getContent();
 		chatMessage.append(result);
 		LIST_MESSAGE.add(new AssistantMessage(chatMessage.toString()));
 		LIST_MESSAGE.forEach(x -> Log.info(this.getClass(),

+ 9 - 0
src/main/java/cn/jlsxwkj/moudles/chathistory/ChatHistoryMapper.java

@@ -28,6 +28,15 @@ public interface ChatHistoryMapper {
                    @Param("userQ") String userQ,
                    @Param("chatA") String chatA);
 
+    @Insert("""
+            insert into chat_history(
+                user_id,
+                user_Q,
+                chat_A
+            ) values #{items}
+            """)
+    Integer insertAll(@Param("items") String items);
+
     @Select("""
             select 
                 id, user_id, 

+ 19 - 2
src/main/java/cn/jlsxwkj/moudles/chathistory/ChatHistoryService.java

@@ -14,10 +14,27 @@ public class ChatHistoryService {
     @Resource
     private ChatHistoryMapper chatHistoryMapper;
 
-    public String insert(String userId,
+    @SuppressWarnings("unused")
+    public Integer insert(String userId,
                          String userQ,
                          String chatA) {
-        return chatHistoryMapper.insert(userId, userQ, chatA) != 1 ? "插入失败" : "插入成功";
+        return chatHistoryMapper.insert(userId, userQ, chatA);
+    }
+
+    public String insertAll(List<ChatHistory> items) {
+        StringBuilder values = new StringBuilder();
+        items.forEach(item -> {
+            values.append(String.format("""
+                            ('{}','{}','{}'),
+                            """,
+                    item.getUserId(),
+                    item.getUserQ(),
+                    item.getChatA()
+            ));
+        });
+        values.deleteCharAt(values.lastIndexOf(","));
+        Integer rows = chatHistoryMapper.insertAll(values.toString());
+        return "插入" + rows + "条";
     }
 
     public List<ChatHistory> selectHistoryByUserId(String userId) {