|
@@ -1,9 +1,15 @@
|
|
|
package cn.jlsxwkj.moudles.chat;
|
|
|
|
|
|
import cn.dev33.satoken.SaManager;
|
|
|
+import cn.hutool.core.lang.Tuple;
|
|
|
import cn.hutool.crypto.digest.MD5;
|
|
|
-import cn.jlsxwkj.common.exception.*;
|
|
|
+import cn.jlsxwkj.common.config.UserConfig;
|
|
|
+import cn.jlsxwkj.common.exception.CustomException;
|
|
|
+import cn.jlsxwkj.common.exception.FileTypeDoesNotSupportException;
|
|
|
+import cn.jlsxwkj.common.exception.InsertFailException;
|
|
|
+import cn.jlsxwkj.common.exception.UnknownException;
|
|
|
import cn.jlsxwkj.common.reader.ParagraphDocReader;
|
|
|
+import cn.jlsxwkj.common.reader.ParagraphOcrReader;
|
|
|
import cn.jlsxwkj.common.reader.ParagraphTextReader;
|
|
|
import cn.jlsxwkj.common.utils.Log;
|
|
|
import cn.jlsxwkj.common.utils.MergeDocuments;
|
|
@@ -11,10 +17,7 @@ import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.Data;
|
|
|
import org.reactivestreams.Subscription;
|
|
|
-import org.springframework.ai.chat.messages.AssistantMessage;
|
|
|
-import org.springframework.ai.chat.messages.Message;
|
|
|
-import org.springframework.ai.chat.messages.SystemMessage;
|
|
|
-import org.springframework.ai.chat.messages.UserMessage;
|
|
|
+import org.springframework.ai.chat.messages.*;
|
|
|
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.ai.document.Document;
|
|
@@ -42,149 +45,190 @@ import java.util.stream.Collectors;
|
|
|
@Data
|
|
|
public class ChatService {
|
|
|
|
|
|
- /**
|
|
|
- * md5 校验
|
|
|
- */
|
|
|
- private static final MD5 MD5 = cn.hutool.crypto.digest.MD5.create();
|
|
|
- /**
|
|
|
- * 当前项目路径 + /ai_doc
|
|
|
- */
|
|
|
- private static final String PATH = Objects.requireNonNull(ChatService.class.getClassLoader().getResource("")).getPath() + "\\ai_doc\\";
|
|
|
- /**
|
|
|
- * 用户对话上下文
|
|
|
- */
|
|
|
- private static final List<Message> LIST_MESSAGE = new ArrayList<>();
|
|
|
- /**
|
|
|
- * 向量
|
|
|
- */
|
|
|
- @Resource
|
|
|
- private VectorStore vectorStore;
|
|
|
- /**
|
|
|
- * 大模型
|
|
|
- */
|
|
|
- @Resource
|
|
|
- private OllamaChatModel ollamaChatModel;
|
|
|
- /**
|
|
|
- * 历史记录
|
|
|
- */
|
|
|
- @Resource
|
|
|
- private ChatHistoryService chatHistoryService;
|
|
|
- /**
|
|
|
- * 单条消息
|
|
|
- */
|
|
|
- private static StringBuffer chatMessage;
|
|
|
- /**
|
|
|
- * 上游订阅
|
|
|
- */
|
|
|
- private Subscription subscription;
|
|
|
-
|
|
|
- /**
|
|
|
- * 使用spring ai解析txt文档并保存至 pg
|
|
|
- * @param file 保存的文件
|
|
|
- * @return 保存状态
|
|
|
- * @throws IOException 保存失败
|
|
|
- */
|
|
|
- public String uploadDocument(MultipartFile file) throws CustomException, IOException {
|
|
|
- //分割文件名为 文件名, 后缀
|
|
|
- String[] split = Objects.requireNonNull(file.getOriginalFilename()).split("\\.");
|
|
|
- //获取文件类型
|
|
|
- String fileType = split[split.length - 1].toLowerCase(Locale.ROOT);
|
|
|
- //文件夹地址
|
|
|
- File savePath = new File(PATH);
|
|
|
- //保存文件对象
|
|
|
- File saveFile = new File(PATH + ChatService.MD5.digestHex(file.getInputStream()) + "." + fileType);
|
|
|
- //文件保存地址
|
|
|
- String fileUrl = saveFile.toURI().toURL().toString();
|
|
|
- //判断文件夹是否存在
|
|
|
- if (!savePath.exists()) {
|
|
|
- if (!savePath.mkdirs()) {
|
|
|
- Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
|
|
|
- }
|
|
|
- }
|
|
|
- //判断文件类型
|
|
|
- DocumentReader reader;
|
|
|
- switch (fileType) {
|
|
|
- case "txt" -> reader = new ParagraphTextReader(fileUrl, 5);
|
|
|
- case "doc", "docx" -> reader = new ParagraphDocReader(fileUrl, 5);
|
|
|
- default -> throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
|
|
|
- }
|
|
|
- // 判断文件是否存在
|
|
|
- if (!saveFile.exists()) {
|
|
|
- try {
|
|
|
- file.transferTo(saveFile);
|
|
|
- vectorStore.add(reader.get());
|
|
|
- return "上传完成";
|
|
|
- } catch (Exception e){
|
|
|
- boolean delete = saveFile.delete();
|
|
|
- if (!delete) {
|
|
|
- throw new FileDeleteFailException("删除文件失败: " + fileUrl);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- throw new UnknownException("未知错误");
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 根据关键词搜索向量库
|
|
|
- *
|
|
|
- * @param keyword 关键词
|
|
|
- * @return 文本内容
|
|
|
- */
|
|
|
- public String search(String keyword) {
|
|
|
- return MergeDocuments.mergeDocuments(
|
|
|
- vectorStore.similaritySearch(SearchRequest.query(keyword).withSimilarityThreshold(0.5))
|
|
|
- ).stream().map(Document::getContent).collect(Collectors.joining("\n"));
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 问答流,根据输入内容回答
|
|
|
- *
|
|
|
- * @param message 输入内容
|
|
|
- * @return 回答内容
|
|
|
- */
|
|
|
- public Flux<String> chatStream(String message) {
|
|
|
- String userId = SaManager.getStpLogic("").getLoginIdAsString();
|
|
|
- chatMessage = new StringBuffer();
|
|
|
- //查询获取文档信息
|
|
|
- String content = search(message);
|
|
|
- //封装prompt并调用大模型
|
|
|
- Flux<ChatResponse> stream = this.ollamaChatModel.stream(getChatPrompt2String(message, content));
|
|
|
- return stream.doOnSubscribe(subscription -> this.subscription = subscription)
|
|
|
- .map(response -> {
|
|
|
- String str = response.getResult().getOutput().getContent();
|
|
|
- chatMessage.append(str);
|
|
|
- return str.replace(" ", " ");
|
|
|
- }).concatWithValues("<{完成}>")
|
|
|
- .onErrorComplete()
|
|
|
- .doOnComplete(() -> {
|
|
|
- LIST_MESSAGE.add(new AssistantMessage(chatMessage.toString()));
|
|
|
- try {
|
|
|
- chatHistoryService.insert(userId, message, chatMessage.toString());
|
|
|
- } catch (InsertFailException e) {
|
|
|
- e.printStackTrace();
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 取消订阅
|
|
|
- */
|
|
|
- public void stopChat() {
|
|
|
- subscription.cancel();
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 获取prompt
|
|
|
- *
|
|
|
- * @param message 提问内容
|
|
|
- * @param context 上下文
|
|
|
- * @return 提示词
|
|
|
- */
|
|
|
- private Prompt getChatPrompt2String(String message, String context) {
|
|
|
- Log.info(this.getClass(),"{} message : {}", "user", message);
|
|
|
- LIST_MESSAGE.add(new SystemMessage("请永远说你是玄武科技公司是 ai 或者 大模型"));
|
|
|
- LIST_MESSAGE.add(new SystemMessage(context));
|
|
|
- LIST_MESSAGE.add(new UserMessage(message));
|
|
|
- return new Prompt(LIST_MESSAGE);
|
|
|
- }
|
|
|
+ private static final MD5 MD5 = cn.hutool.crypto.digest.MD5.create();
|
|
|
+ private static final String PATH = Objects.requireNonNull(ChatService.class.getClassLoader().getResource("")).getPath() + "\\ai_doc\\";
|
|
|
+ private static final List<Message> LIST_MESSAGE = new ArrayList<>();
|
|
|
+
|
|
|
+ @Resource
|
|
|
+ private VectorStore vectorStore;
|
|
|
+ @Resource
|
|
|
+ private OllamaChatModel ollamaChatModel;
|
|
|
+ @Resource
|
|
|
+ private ChatHistoryService chatHistoryService;
|
|
|
+ @Resource
|
|
|
+ private UserConfig userConfig;
|
|
|
+
|
|
|
+ private StringBuffer chatMessage;
|
|
|
+ private Subscription subscription;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 使用spring ai解析txt文档并保存至 pg
|
|
|
+ *
|
|
|
+ * @param file 保存的文件
|
|
|
+ * @return 保存状态
|
|
|
+ */
|
|
|
+ public String uploadDocument(MultipartFile file) throws CustomException {
|
|
|
+ Tuple supportFile = isSupportFile(file);
|
|
|
+ File saveFile = supportFile.get(0);
|
|
|
+ File fileName = supportFile.get(1);
|
|
|
+ DocumentReader reader = supportFile.get(3);
|
|
|
+ if (!saveFile.exists()) {
|
|
|
+ try {
|
|
|
+ file.transferTo(saveFile);
|
|
|
+ vectorStore.add(reader.get());
|
|
|
+ return "上传完成";
|
|
|
+ } catch (Exception e) {
|
|
|
+ boolean delete = saveFile.delete();
|
|
|
+ if (!delete) {
|
|
|
+ Log.warn(this.getClass(), "删除文件失败: {}", fileName);
|
|
|
+ throw new UnknownException();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return "上传完成";
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 根据关键词搜索向量库
|
|
|
+ *
|
|
|
+ * @param keyword 关键词
|
|
|
+ * @return 文本内容
|
|
|
+ */
|
|
|
+ public String search(String keyword) {
|
|
|
+ return mergeDocuments(vectorStore.similaritySearch(SearchRequest.query(keyword).withSimilarityThreshold(0.5)));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 问答流,根据输入内容回答
|
|
|
+ *
|
|
|
+ * @param message 输入内容
|
|
|
+ * @return 回答内容
|
|
|
+ */
|
|
|
+ public Flux<String> chatStream(String message) {
|
|
|
+ String context = search(message);
|
|
|
+ return stream(message, new SystemMessage(context));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 问答总结流, 根据给定的文档总结内容
|
|
|
+ *
|
|
|
+ * @param message 输入内容
|
|
|
+ * @param file 文件
|
|
|
+ * @return 内容总结
|
|
|
+ */
|
|
|
+ public Flux<String> chatSummaryStream(String message, MultipartFile file) throws CustomException {
|
|
|
+ Tuple supportFile = isSupportFile(file);
|
|
|
+ String fileName = supportFile.get(1);
|
|
|
+ String fileTypeCn = supportFile.get(2);
|
|
|
+ DocumentReader reader = supportFile.get(3);
|
|
|
+ String context = fileTypeCn + " : " + fileName + " ====> " + mergeDocuments(reader.get()).replace("\n", " ");
|
|
|
+ Log.info(this.getClass(), context);
|
|
|
+ return stream(message, new UserMessage(context + " : " + message));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 取消订阅
|
|
|
+ */
|
|
|
+ public void stopChat() {
|
|
|
+ subscription.cancel();
|
|
|
+ }
|
|
|
+
|
|
|
+ private String mergeDocuments(List<Document> documents) {
|
|
|
+ return MergeDocuments.mergeDocuments(documents)
|
|
|
+ .stream().map(Document::getContent).collect(Collectors.joining("\n"));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取prompt
|
|
|
+ *
|
|
|
+ * @param message 提问内容
|
|
|
+ * @param context 提示词
|
|
|
+ * @return 提示词
|
|
|
+ */
|
|
|
+ private Prompt getChatPrompt2String(String message, Message context) {
|
|
|
+ Log.info(this.getClass(), "{} message : {}", "user", message);
|
|
|
+ ChatService.LIST_MESSAGE.add(new SystemMessage(userConfig.getSysMessage()));
|
|
|
+ if (context.getMessageType().equals(MessageType.SYSTEM)) {
|
|
|
+ ChatService.LIST_MESSAGE.add(new UserMessage(message));
|
|
|
+ ChatService.LIST_MESSAGE.add(context);
|
|
|
+ }
|
|
|
+ if (context.getMessageType().equals(MessageType.USER)) {
|
|
|
+ ChatService.LIST_MESSAGE.add(context);
|
|
|
+ }
|
|
|
+ return new Prompt(ChatService.LIST_MESSAGE);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * chat stream 封装
|
|
|
+ *
|
|
|
+ * @param message 提问内容
|
|
|
+ * @param context 提示词
|
|
|
+ * @return flux
|
|
|
+ */
|
|
|
+ private Flux<String> stream(String message, Message context) {
|
|
|
+ String userId = SaManager.getStpLogic("").getLoginIdAsString();
|
|
|
+ chatMessage = new StringBuffer();
|
|
|
+ Flux<ChatResponse> stream = ollamaChatModel.stream(getChatPrompt2String(message, context));
|
|
|
+ return stream.doOnSubscribe(subscription -> this.subscription = subscription).map(
|
|
|
+ response -> {
|
|
|
+ String str = response.getResult().getOutput().getContent();
|
|
|
+ chatMessage.append(str);
|
|
|
+ return str.replace(" ", " ");
|
|
|
+ }).concatWithValues("<{完成}>")
|
|
|
+ .onErrorComplete()
|
|
|
+ .doOnComplete(() -> {
|
|
|
+ String chatMsg = chatMessage.toString();
|
|
|
+ ChatService.LIST_MESSAGE.add(new AssistantMessage(chatMsg));
|
|
|
+ try {
|
|
|
+ chatHistoryService.insert(userId, message, chatMsg);
|
|
|
+ } catch (InsertFailException e) {
|
|
|
+ Log.warn(this.getClass(), e.getMessage());
|
|
|
+ }
|
|
|
+ Log.info(this.getClass(), "{} user message ==> {}", userId, message);
|
|
|
+ Log.info(this.getClass(), "{} chat message ==> {}", userId, chatMsg);
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 封装文件
|
|
|
+ *
|
|
|
+ * @param file 文件流
|
|
|
+ * @return 元组
|
|
|
+ */
|
|
|
+ private Tuple isSupportFile(MultipartFile file) throws CustomException {
|
|
|
+ byte[] bytes = new byte[0];
|
|
|
+ try {
|
|
|
+ bytes = file.getBytes();
|
|
|
+ } catch (IOException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ String[] split = Objects.requireNonNull(file.getOriginalFilename()).split("\\.");
|
|
|
+ String fileType = split[split.length - 1].toLowerCase(Locale.ROOT);
|
|
|
+ File savePath = new File(PATH);
|
|
|
+ File saveFile = new File(PATH + ChatService.MD5.digestHex(bytes) + "." + fileType);
|
|
|
+ String fileName = saveFile.getName();
|
|
|
+ String fileTypeCn;
|
|
|
+ if (!savePath.exists()) {
|
|
|
+ if (!savePath.mkdirs()) {
|
|
|
+ Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ DocumentReader reader;
|
|
|
+ switch (fileType) {
|
|
|
+ case "txt" -> {
|
|
|
+ reader = new ParagraphTextReader(bytes, fileName, 5);
|
|
|
+ fileTypeCn = "文本";
|
|
|
+ }
|
|
|
+ case "doc", "docx" -> {
|
|
|
+ reader = new ParagraphDocReader(bytes, fileName, 5);
|
|
|
+ fileTypeCn = "文档";
|
|
|
+ }
|
|
|
+ case "jpg", "png" -> {
|
|
|
+ reader = new ParagraphOcrReader(bytes, fileName, userConfig.getCnocrUrl(), 5);
|
|
|
+ fileTypeCn = "图片";
|
|
|
+ }
|
|
|
+ default -> throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
|
|
|
+ }
|
|
|
+ return new Tuple(saveFile, fileName, fileTypeCn, reader);
|
|
|
+ }
|
|
|
}
|