|
@@ -8,21 +8,15 @@ import cn.jlsxwkj.common.exception.CustomException;
|
|
import cn.jlsxwkj.common.exception.FileTypeDoesNotSupportException;
|
|
import cn.jlsxwkj.common.exception.FileTypeDoesNotSupportException;
|
|
import cn.jlsxwkj.common.exception.InsertFailException;
|
|
import cn.jlsxwkj.common.exception.InsertFailException;
|
|
import cn.jlsxwkj.common.exception.UnknownException;
|
|
import cn.jlsxwkj.common.exception.UnknownException;
|
|
-import cn.jlsxwkj.common.reader.ParagraphDocReader;
|
|
|
|
-import cn.jlsxwkj.common.reader.ParagraphOcrReader;
|
|
|
|
-import cn.jlsxwkj.common.reader.ParagraphTextReader;
|
|
|
|
|
|
+import cn.jlsxwkj.common.utils.FileType;
|
|
import cn.jlsxwkj.common.utils.Log;
|
|
import cn.jlsxwkj.common.utils.Log;
|
|
import cn.jlsxwkj.common.utils.MergeDocuments;
|
|
import cn.jlsxwkj.common.utils.MergeDocuments;
|
|
|
|
+import cn.jlsxwkj.moudles.chat.message.*;
|
|
import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
|
|
import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
|
|
import jakarta.annotation.Resource;
|
|
import jakarta.annotation.Resource;
|
|
-import org.reactivestreams.Subscription;
|
|
|
|
-import org.springframework.ai.chat.messages.*;
|
|
|
|
-import org.springframework.ai.chat.model.ChatResponse;
|
|
|
|
-import org.springframework.ai.chat.prompt.Prompt;
|
|
|
|
import org.springframework.ai.document.Document;
|
|
import org.springframework.ai.document.Document;
|
|
import org.springframework.ai.document.DocumentReader;
|
|
import org.springframework.ai.document.DocumentReader;
|
|
import org.springframework.ai.embedding.EmbeddingModel;
|
|
import org.springframework.ai.embedding.EmbeddingModel;
|
|
-import org.springframework.ai.ollama.OllamaChatModel;
|
|
|
|
import org.springframework.ai.vectorstore.SearchRequest;
|
|
import org.springframework.ai.vectorstore.SearchRequest;
|
|
import org.springframework.ai.vectorstore.VectorStore;
|
|
import org.springframework.ai.vectorstore.VectorStore;
|
|
import org.springframework.stereotype.Service;
|
|
import org.springframework.stereotype.Service;
|
|
@@ -47,20 +41,19 @@ public class ChatService {
|
|
private final MD5 md5 = cn.hutool.crypto.digest.MD5.create();
|
|
private final MD5 md5 = cn.hutool.crypto.digest.MD5.create();
|
|
private final String PATH = System.getProperty("user.dir") + "/ai_doc/";
|
|
private final String PATH = System.getProperty("user.dir") + "/ai_doc/";
|
|
private List<Message> listMessage;
|
|
private List<Message> listMessage;
|
|
|
|
+ private StringBuffer chatMessage;
|
|
|
|
|
|
@Resource
|
|
@Resource
|
|
private VectorStore vectorStore;
|
|
private VectorStore vectorStore;
|
|
@Resource
|
|
@Resource
|
|
private EmbeddingModel embeddingModel;
|
|
private EmbeddingModel embeddingModel;
|
|
@Resource
|
|
@Resource
|
|
- private OllamaChatModel ollamaChatModel;
|
|
|
|
|
|
+ private ChatUtil chatUtil;
|
|
@Resource
|
|
@Resource
|
|
private ChatHistoryService chatHistoryService;
|
|
private ChatHistoryService chatHistoryService;
|
|
@Resource
|
|
@Resource
|
|
private UserConfig userConfig;
|
|
private UserConfig userConfig;
|
|
|
|
|
|
- private StringBuffer chatMessage;
|
|
|
|
- private Subscription subscription;
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
* 使用spring ai解析txt文档并保存至 pg
|
|
* 使用spring ai解析txt文档并保存至 pg
|
|
@@ -135,25 +128,30 @@ public class ChatService {
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
- * 取消订阅
|
|
|
|
|
|
+ * 合并文档
|
|
|
|
+ *
|
|
|
|
+ * @param documents 文档列表
|
|
|
|
+ * @return 合并后的文档
|
|
*/
|
|
*/
|
|
- public void stopChat() {
|
|
|
|
- subscription.cancel();
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
private String mergeDocuments(List<Document> documents) {
|
|
private String mergeDocuments(List<Document> documents) {
|
|
return MergeDocuments.mergeDocuments(documents).stream()
|
|
return MergeDocuments.mergeDocuments(documents).stream()
|
|
.map(Document::getContent).collect(Collectors.joining("\n"));
|
|
.map(Document::getContent).collect(Collectors.joining("\n"));
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
|
|
+ * 停止 chat
|
|
|
|
+ */
|
|
|
|
+ public void stop() {
|
|
|
|
+ chatUtil.stop();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
* 获取prompt
|
|
* 获取prompt
|
|
*
|
|
*
|
|
* @param message 提问内容
|
|
* @param message 提问内容
|
|
* @param context 提示词
|
|
* @param context 提示词
|
|
- * @return 提示词
|
|
|
|
*/
|
|
*/
|
|
- private Prompt getChatPrompt(String message, Message context) {
|
|
|
|
|
|
+ private void setChatMessage(String message, Message context) {
|
|
if (listMessage == null) {
|
|
if (listMessage == null) {
|
|
listMessage = new ArrayList<>();
|
|
listMessage = new ArrayList<>();
|
|
listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
@@ -163,7 +161,7 @@ public class ChatService {
|
|
.stream()
|
|
.stream()
|
|
.filter(msg -> !msg.getContent().equals(userConfig.getSysMessage()))
|
|
.filter(msg -> !msg.getContent().equals(userConfig.getSysMessage()))
|
|
.map(msg -> {
|
|
.map(msg -> {
|
|
- if (msg.getMessageType().equals(MessageType.ASSISTANT)) {
|
|
|
|
|
|
+ if (msg.getRole().equals(UserRole.ASSISTANT)) {
|
|
String content = msg.getContent();
|
|
String content = msg.getContent();
|
|
if (content.length() > userConfig.getMaxMessageLength()) {
|
|
if (content.length() > userConfig.getMaxMessageLength()) {
|
|
return new AssistantMessage(content.substring(0, userConfig.getMaxMessageLength()));
|
|
return new AssistantMessage(content.substring(0, userConfig.getMaxMessageLength()));
|
|
@@ -174,16 +172,15 @@ public class ChatService {
|
|
.collect(Collectors.toList());
|
|
.collect(Collectors.toList());
|
|
listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
}
|
|
}
|
|
- if (context.getMessageType().equals(MessageType.SYSTEM)) {
|
|
|
|
|
|
+ if (context.getRole().equals(UserRole.SYSTEM)) {
|
|
listMessage.add(new UserMessage(message));
|
|
listMessage.add(new UserMessage(message));
|
|
if (!context.getContent().isBlank()) {
|
|
if (!context.getContent().isBlank()) {
|
|
listMessage.add(context);
|
|
listMessage.add(context);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- if (context.getMessageType().equals(MessageType.USER)) {
|
|
|
|
- return new Prompt(context);
|
|
|
|
|
|
+ if (context.getRole().equals(UserRole.USER)) {
|
|
|
|
+ listMessage.add(context);
|
|
}
|
|
}
|
|
- return new Prompt(listMessage);
|
|
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
@@ -196,15 +193,16 @@ public class ChatService {
|
|
private Flux<String> stream(String message, Message context) {
|
|
private Flux<String> stream(String message, Message context) {
|
|
String userId = SaManager.getStpLogic("").getLoginIdAsString();
|
|
String userId = SaManager.getStpLogic("").getLoginIdAsString();
|
|
chatMessage = new StringBuffer();
|
|
chatMessage = new StringBuffer();
|
|
- Flux<ChatResponse> stream = ollamaChatModel.stream(getChatPrompt(message, context));
|
|
|
|
- return stream.doOnSubscribe(subscription -> this.subscription = subscription).map(
|
|
|
|
- response -> {
|
|
|
|
- String str = response.getResult().getOutput().getContent();
|
|
|
|
- chatMessage.append(str);
|
|
|
|
- return str.replace(" ", " ")
|
|
|
|
- .replace("\t", " ");
|
|
|
|
- }).concatWithValues("<{完成}>")
|
|
|
|
- .onErrorComplete()
|
|
|
|
|
|
+ setChatMessage(message, context);
|
|
|
|
+ return chatUtil.chat(listMessage, true).map(response -> {
|
|
|
|
+ String str = response.getMessage().getContent();
|
|
|
|
+ chatMessage.append(str);
|
|
|
|
+ System.out.print(str);
|
|
|
|
+ return str.replace(" ", " ")
|
|
|
|
+ .replace("\t", " ");
|
|
|
|
+ }).concatWithValues("<{完成}>")
|
|
|
|
+ .cancelOn(chatUtil.getScheduler())
|
|
|
|
+ .doOnCancel(this::stop)
|
|
.doOnComplete(() -> {
|
|
.doOnComplete(() -> {
|
|
String chatMsg = chatMessage.toString();
|
|
String chatMsg = chatMessage.toString();
|
|
listMessage.add(new AssistantMessage(chatMsg));
|
|
listMessage.add(new AssistantMessage(chatMsg));
|
|
@@ -236,28 +234,18 @@ public class ChatService {
|
|
File savePath = new File(PATH);
|
|
File savePath = new File(PATH);
|
|
File saveFile = new File(PATH + md5.digestHex(bytes) + "." + fileType);
|
|
File saveFile = new File(PATH + md5.digestHex(bytes) + "." + fileType);
|
|
String fileName = saveFile.getName();
|
|
String fileName = saveFile.getName();
|
|
- String fileTypeCn;
|
|
|
|
if (!savePath.exists()) {
|
|
if (!savePath.exists()) {
|
|
if (!savePath.mkdirs()) {
|
|
if (!savePath.mkdirs()) {
|
|
Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
|
|
Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- DocumentReader reader;
|
|
|
|
- switch (fileType) {
|
|
|
|
- case "txt" -> {
|
|
|
|
- reader = new ParagraphTextReader(bytes, fileName, 5);
|
|
|
|
- fileTypeCn = "文本";
|
|
|
|
- }
|
|
|
|
- case "doc", "docx" -> {
|
|
|
|
- reader = new ParagraphDocReader(bytes, fileName, 5);
|
|
|
|
- fileTypeCn = "文档";
|
|
|
|
- }
|
|
|
|
- case "jpg", "png" -> {
|
|
|
|
- reader = new ParagraphOcrReader(bytes, fileName, userConfig.getCnocrUrl(), 5);
|
|
|
|
- fileTypeCn = "图片";
|
|
|
|
- }
|
|
|
|
- default -> throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
|
|
|
|
|
|
+ Tuple fileTypeAndResource = new FileType(bytes, fileName, 5, userConfig.getCnocrUrl())
|
|
|
|
+ .getFileTypeAndResource(fileType);
|
|
|
|
+ if (fileTypeAndResource == null) {
|
|
|
|
+ throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
|
|
}
|
|
}
|
|
|
|
+ String fileTypeCn = fileTypeAndResource.get(0);
|
|
|
|
+ DocumentReader reader = fileTypeAndResource.get(1);
|
|
return new Tuple(saveFile, fileName, fileTypeCn, reader);
|
|
return new Tuple(saveFile, fileName, fileTypeCn, reader);
|
|
}
|
|
}
|
|
}
|
|
}
|