|
@@ -9,11 +9,14 @@ import cn.jlsxwkj.common.utils.FileType;
|
|
|
import cn.jlsxwkj.common.utils.Log;
|
|
|
import cn.jlsxwkj.common.utils.MergeDocuments;
|
|
|
import cn.jlsxwkj.moudles.chat.message.*;
|
|
|
-import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
|
|
|
+import cn.jlsxwkj.moudles.chat_history.ChatHistoryService;
|
|
|
import jakarta.annotation.Resource;
|
|
|
+import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.ai.document.Document;
|
|
|
import org.springframework.ai.document.DocumentReader;
|
|
|
import org.springframework.ai.embedding.EmbeddingModel;
|
|
|
+import org.springframework.ai.ollama.OllamaChatModel;
|
|
|
+import org.springframework.ai.ollama.api.OllamaOptions;
|
|
|
import org.springframework.ai.vectorstore.SearchRequest;
|
|
|
import org.springframework.ai.vectorstore.VectorStore;
|
|
|
import org.springframework.stereotype.Service;
|
|
@@ -22,10 +25,7 @@ import reactor.core.publisher.Flux;
|
|
|
|
|
|
import java.io.File;
|
|
|
import java.io.IOException;
|
|
|
-import java.util.ArrayList;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Locale;
|
|
|
-import java.util.Objects;
|
|
|
+import java.util.*;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
|
|
@@ -35,39 +35,90 @@ import java.util.stream.Collectors;
|
|
|
@Service
|
|
|
public class ChatService {
|
|
|
|
|
|
+ /**
|
|
|
+ * md5.
|
|
|
+ */
|
|
|
private final MD5 md5 = cn.hutool.crypto.digest.MD5.create();
|
|
|
- private final String PATH = System.getProperty("user.dir") + "/ai_doc/";
|
|
|
+ /**
|
|
|
+ * 获取当前项目地址 + /ai_doc.
|
|
|
+ */
|
|
|
+ private static final String
|
|
|
+ PATH = System.getProperty("user.dir") + "/ai_doc/";
|
|
|
+ /**
|
|
|
+ * 函数回调集合.
|
|
|
+ */
|
|
|
+ private static final HashSet<String> TOOLS_FUNCTION = new HashSet<>();
|
|
|
+ /**
|
|
|
+ * 历史消息列表.
|
|
|
+ */
|
|
|
private List<Message> listMessage;
|
|
|
+ /**
|
|
|
+ * chat 消息.
|
|
|
+ */
|
|
|
private StringBuffer chatMessage;
|
|
|
-
|
|
|
+ /**
|
|
|
+ * pg 向量对象(将文档转换为向量并保存至数据库).
|
|
|
+ */
|
|
|
@Resource
|
|
|
private VectorStore vectorStore;
|
|
|
+ /**
|
|
|
+ * ollama 模型.
|
|
|
+ */
|
|
|
+ @Resource
|
|
|
+ private OllamaChatModel ollamaChatModel;
|
|
|
+ /**
|
|
|
+ * 向量化模型.
|
|
|
+ */
|
|
|
@Resource
|
|
|
private EmbeddingModel embeddingModel;
|
|
|
+ /**
|
|
|
+ * 自定义 chat 工具.
|
|
|
+ */
|
|
|
@Resource
|
|
|
private ChatUtil chatUtil;
|
|
|
+ /**
|
|
|
+ * 保存对话对象(将对话保存至数据库).
|
|
|
+ */
|
|
|
@Resource
|
|
|
private ChatHistoryService chatHistoryService;
|
|
|
+ /**
|
|
|
+ * 自定义配置类.
|
|
|
+ */
|
|
|
@Resource
|
|
|
private UserConfig userConfig;
|
|
|
|
|
|
+ /*
|
|
|
+ 初始化回调函数
|
|
|
+ */
|
|
|
+ static {
|
|
|
+ TOOLS_FUNCTION.add("getWeather");
|
|
|
+ }
|
|
|
|
|
|
/**
|
|
|
- * 使用spring ai解析txt文档并保存至 pg
|
|
|
+ * 使用spring ai解析txt文档并保存至 pg.
|
|
|
*
|
|
|
* @param file 保存的文件
|
|
|
* @return 保存状态
|
|
|
*/
|
|
|
- public String uploadDocument(MultipartFile file) throws CustomException {
|
|
|
+ public String uploadDocument(
|
|
|
+ final MultipartFile file) throws CustomException {
|
|
|
+ //检测文件是否支持
|
|
|
Tuple supportFile = isSupportFile(file);
|
|
|
+ //需要保存的文件
|
|
|
File saveFile = supportFile.get(0);
|
|
|
+ //文件名
|
|
|
String fileName = supportFile.get(1);
|
|
|
+ //分割文档对象
|
|
|
DocumentReader reader = supportFile.get(3);
|
|
|
+ //判断文件是否存在
|
|
|
if (!saveFile.exists()) {
|
|
|
try {
|
|
|
+ //保存文件
|
|
|
file.transferTo(saveFile);
|
|
|
+ //保存文档到数据库
|
|
|
vectorStore.add(reader.get());
|
|
|
} catch (Exception e) {
|
|
|
+ //保存文件到数据库失败删除本地缓存文件
|
|
|
boolean delete = saveFile.delete();
|
|
|
if (!delete) {
|
|
|
Log.warn(this.getClass(), "删除文件失败: {}", fileName);
|
|
@@ -79,109 +130,153 @@ public class ChatService {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 根据关键词搜索向量库
|
|
|
+ * 根据关键词搜索向量库.
|
|
|
*
|
|
|
* @param keyword 关键词
|
|
|
* @return 文本内容
|
|
|
*/
|
|
|
- public String search(String keyword) {
|
|
|
- SearchRequest searchRequest = SearchRequest.query(keyword).withSimilarityThreshold(0.5);
|
|
|
+ public String search(final String keyword) {
|
|
|
+ //设置检索文档关键字相似度至少近似 50% 以上
|
|
|
+ SearchRequest searchRequest =
|
|
|
+ SearchRequest.query(keyword).withSimilarityThreshold(0.5);
|
|
|
+ //检索文档
|
|
|
List<Document> documents = vectorStore.similaritySearch(searchRequest);
|
|
|
+ //返回合并的文档
|
|
|
return mergeDocuments(documents);
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 问答流,根据输入内容回答
|
|
|
+ * 问答流,根据输入内容回答.
|
|
|
*
|
|
|
* @param message 输入内容
|
|
|
* @return 回答内容
|
|
|
*/
|
|
|
- public Flux<String> chatStream(String message) {
|
|
|
+ public Flux<String> chatStream(final String message) {
|
|
|
+ //检索知识库
|
|
|
String context = search(message);
|
|
|
+ //预处理
|
|
|
+ //创建用户消息
|
|
|
+ org.springframework.ai.chat.messages.UserMessage userMessage =
|
|
|
+ new org.springframework.ai.chat.messages.UserMessage(message);
|
|
|
+ //设置提示词以及添加回调函数
|
|
|
+ Prompt prompt = new Prompt(userMessage, OllamaOptions.builder()
|
|
|
+ .withFunctions(TOOLS_FUNCTION).build());
|
|
|
+ //预处理, chat 总结的函数结果
|
|
|
+ String toolMessage = ollamaChatModel.call(prompt)
|
|
|
+ .getResult().getOutput().getContent();
|
|
|
+ //合并知识库结果和函数结果
|
|
|
String format = """
|
|
|
- %s : Please answer user questions based on the content.
|
|
|
%s
|
|
|
+ %s
|
|
|
+ 如果以上内容不为空请永远用以上内容回答问题
|
|
|
""";
|
|
|
- return stream(message, new SystemMessage(String.format(format, context, message)));
|
|
|
+ //提问 chat 并返回流
|
|
|
+ return stream(message, new SystemMessage(
|
|
|
+ String.format(format, context, toolMessage, message)));
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 问答总结流, 根据给定的文档总结内容
|
|
|
+ * 问答总结流, 根据给定的文档总结内容.
|
|
|
*
|
|
|
* @param message 输入内容
|
|
|
* @param file 文件
|
|
|
* @return 内容总结
|
|
|
*/
|
|
|
- public Flux<String> chatSummaryStream(String message, MultipartFile file) {
|
|
|
+ public Flux<String> chatSummaryStream(final String message,
|
|
|
+ final MultipartFile file) {
|
|
|
+ //检测模型是否开启
|
|
|
embeddingModel.embed("");
|
|
|
+ //检测文件是否支持
|
|
|
Tuple supportFile;
|
|
|
try {
|
|
|
supportFile = isSupportFile(file);
|
|
|
} catch (CustomException e) {
|
|
|
+ //如果异常则通过流返回异常
|
|
|
return Flux.just(e.getMessage()).concatWithValues("<{完成}>");
|
|
|
}
|
|
|
+ //文件名
|
|
|
String fileName = supportFile.get(1);
|
|
|
+ //文件类型
|
|
|
String fileTypeCn = supportFile.get(2);
|
|
|
+ //分割文档对象
|
|
|
DocumentReader reader = supportFile.get(3);
|
|
|
+ //格式化提问格式
|
|
|
String format = """
|
|
|
%s :: %s ====> %s
|
|
|
↑ %s
|
|
|
""";
|
|
|
- String context = String.format(format, fileTypeCn, fileName, mergeDocuments(reader.get()), message);
|
|
|
+ String context = String.format(format,
|
|
|
+ fileTypeCn, fileName, mergeDocuments(reader.get()), message);
|
|
|
+ //设置用户消息
|
|
|
UserMessage userMessage = new UserMessage(context);
|
|
|
Log.info(this.getClass(), context);
|
|
|
+ //返回 chat 流
|
|
|
return stream(message, userMessage);
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 合并文档
|
|
|
+ * 合并文档.
|
|
|
*
|
|
|
* @param documents 文档列表
|
|
|
* @return 合并后的文档
|
|
|
*/
|
|
|
- private String mergeDocuments(List<Document> documents) {
|
|
|
+ private String mergeDocuments(final List<Document> documents) {
|
|
|
return MergeDocuments.mergeDocuments(documents).stream()
|
|
|
.map(Document::getContent)
|
|
|
.collect(Collectors.joining("\n"));
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 停止 chat
|
|
|
+ * 停止 chat.
|
|
|
*/
|
|
|
public void stop() {
|
|
|
chatUtil.stop();
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 获取prompt
|
|
|
+ * 获取prompt.
|
|
|
*
|
|
|
* @param message 提问内容
|
|
|
* @param context 提示词
|
|
|
*/
|
|
|
- private void setChatMessage(String message, Message context) {
|
|
|
+ private void setChatMessage(final String message,
|
|
|
+ final Message context) {
|
|
|
+ //初始化消息并设置系统提示词
|
|
|
if (listMessage == null) {
|
|
|
listMessage = new ArrayList<>();
|
|
|
listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
|
}
|
|
|
+ //判断消息列表是否大于定义的消息列表长度
|
|
|
if (listMessage.size() > userConfig.getMaxListMessageLength()) {
|
|
|
- int start = listMessage.size() - userConfig.getMaxListMessageLength();
|
|
|
- listMessage = listMessage.subList(start, listMessage.size()).stream()
|
|
|
- .filter(msg -> !msg.getContent().equals(userConfig.getSysMessage()))
|
|
|
+ //缩减长度到指定列表长度
|
|
|
+ int start =
|
|
|
+ listMessage.size() - userConfig.getMaxListMessageLength();
|
|
|
+ listMessage = listMessage.subList(start,
|
|
|
+ listMessage.size()).stream()
|
|
|
+ //过滤定义的定义的系统消息
|
|
|
+ .filter(msg -> !msg.getContent()
|
|
|
+ .equals(userConfig.getSysMessage()))
|
|
|
.map(msg -> {
|
|
|
+ //判断是否为 chat 消息, 如果是缩减 chat 回答后的长度
|
|
|
if (msg.getRole().equals(UserRole.ASSISTANT)) {
|
|
|
String content = msg.getContent();
|
|
|
- if (content.length() > userConfig.getMaxMessageLength()) {
|
|
|
- String assistantMessage = content.substring(0, userConfig.getMaxMessageLength());
|
|
|
- return new AssistantMessage(assistantMessage);
|
|
|
+ if (content.length()
|
|
|
+ > userConfig.getMaxMessageLength()) {
|
|
|
+ return new AssistantMessage(content.substring(
|
|
|
+ 0,
|
|
|
+ userConfig.getMaxMessageLength()));
|
|
|
}
|
|
|
}
|
|
|
return msg;
|
|
|
}).distinct()
|
|
|
.collect(Collectors.toList());
|
|
|
+ //再次将自定义系统提示词添加到列表末尾
|
|
|
listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
|
}
|
|
|
+ //如果 context 是系统消息则将两个都添加
|
|
|
if (context.getRole().equals(UserRole.SYSTEM)) {
|
|
|
listMessage.add(new UserMessage(message));
|
|
|
+ //判断系统消息是否为空
|
|
|
if (!context.getContent().isBlank()) {
|
|
|
listMessage.add(context);
|
|
|
}
|
|
@@ -192,68 +287,106 @@ public class ChatService {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * chat stream 封装
|
|
|
+ * chat stream 封装.
|
|
|
*
|
|
|
* @param message 提问内容
|
|
|
* @param context 提示词
|
|
|
* @return flux
|
|
|
*/
|
|
|
- private Flux<String> stream(String message, Message context) {
|
|
|
+ private Flux<String> stream(final String message,
|
|
|
+ final Message context) {
|
|
|
+ //获取用户id
|
|
|
String userId = SaManager.getStpLogic("").getLoginIdAsString();
|
|
|
+ //初始化
|
|
|
chatMessage = new StringBuffer();
|
|
|
+ //设置消息列表
|
|
|
setChatMessage(message, context);
|
|
|
+ //返回消息流
|
|
|
return chatUtil.chat(listMessage, true).map(response -> {
|
|
|
+ //获取消息主体
|
|
|
String str = response.getMessage().getContent();
|
|
|
+ //添加至消息中
|
|
|
chatMessage.append(str);
|
|
|
- System.out.print(str);
|
|
|
+ //格式化字符串
|
|
|
return str.replace(" ", " ")
|
|
|
.replace("\t", " ");
|
|
|
- }).concatWithValues("<{完成}>")
|
|
|
+ })
|
|
|
+ //流完成后拼接的字符
|
|
|
+ .concatWithValues("<{完成}>")
|
|
|
+ //取消时返回的调度器
|
|
|
.cancelOn(chatUtil.getScheduler())
|
|
|
- .doOnCancel(this::stop)
|
|
|
+ //取消时执行的操作
|
|
|
+ .doOnCancel(() -> {
|
|
|
+ //上游
|
|
|
+ this.stop();
|
|
|
+ Log.info(this.getClass(), "{} ==== 用户停止", userId);
|
|
|
+ })
|
|
|
+ //流完成时的操作
|
|
|
.doOnComplete(() -> {
|
|
|
String chatMsg = chatMessage.toString();
|
|
|
listMessage.add(new AssistantMessage(chatMsg));
|
|
|
try {
|
|
|
+ //将消息添加至数据库
|
|
|
chatHistoryService.insert(userId, message, chatMsg);
|
|
|
} catch (InsertFailException e) {
|
|
|
Log.warn(this.getClass(), e.getMessage());
|
|
|
}
|
|
|
- Log.info(this.getClass(), "{} user message ==> {}", userId, message);
|
|
|
- Log.info(this.getClass(), "{} chat message ==> {}", userId, chatMsg);
|
|
|
+ Log.info(this.getClass(),
|
|
|
+ "{} user message ==> {}", userId, message);
|
|
|
+ Log.info(this.getClass(),
|
|
|
+ "{} chat message ==> {}", userId, chatMsg);
|
|
|
});
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 封装文件
|
|
|
+ * 封装文件.
|
|
|
*
|
|
|
* @param file 文件流
|
|
|
- * @return 元组(File file, String fileName, String fileTypeCn, DocumentReader reader)
|
|
|
+ * @return 元组
|
|
|
*/
|
|
|
- private Tuple isSupportFile(MultipartFile file) throws CustomException {
|
|
|
+ private Tuple isSupportFile(
|
|
|
+ final MultipartFile file) throws CustomException {
|
|
|
+ //初始化
|
|
|
byte[] bytes;
|
|
|
try {
|
|
|
+ //读取文件
|
|
|
bytes = file.getBytes();
|
|
|
} catch (IOException e) {
|
|
|
throw new OpenFileFailException("读取文件失败");
|
|
|
}
|
|
|
- String[] split = Objects.requireNonNull(file.getOriginalFilename()).split("\\.");
|
|
|
+ //对文件名分割
|
|
|
+ String[] split = Objects.requireNonNull(file.getOriginalFilename())
|
|
|
+ .split("\\.");
|
|
|
+ //文件类型
|
|
|
String fileType = split[split.length - 1].toLowerCase(Locale.ROOT);
|
|
|
+ //保存文件地址
|
|
|
File savePath = new File(PATH);
|
|
|
+ //保存文件
|
|
|
File saveFile = new File(PATH + md5.digestHex(bytes) + "." + fileType);
|
|
|
+ //文件名
|
|
|
String fileName = saveFile.getName();
|
|
|
+ //判断文件地址是否存在
|
|
|
if (!savePath.exists()) {
|
|
|
+ //创建文件夹
|
|
|
if (!savePath.mkdirs()) {
|
|
|
Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
|
|
|
}
|
|
|
}
|
|
|
- Tuple fileTypeAndResource = new FileType(bytes, fileName, 5, userConfig.getCnocrUrl())
|
|
|
+ //获取文件的各参数
|
|
|
+ Tuple fileTypeAndResource = new FileType(
|
|
|
+ bytes,
|
|
|
+ fileName,
|
|
|
+ 5,
|
|
|
+ userConfig.getCnocrUrl())
|
|
|
.getFileTypeAndResource(fileType);
|
|
|
if (fileTypeAndResource == null) {
|
|
|
throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
|
|
|
}
|
|
|
+ //文件类型(中文)
|
|
|
String fileTypeCn = fileTypeAndResource.get(0);
|
|
|
+ //分割文档对象
|
|
|
DocumentReader reader = fileTypeAndResource.get(1);
|
|
|
+ //返回(保存文件对象, 文件名, 文件类型中文名, 分割文档对象)
|
|
|
return new Tuple(saveFile, fileName, fileTypeCn, reader);
|
|
|
}
|
|
|
}
|