|
@@ -3,14 +3,22 @@ package cn.jlsxwkj.moudles.chat;
|
|
|
import cn.dev33.satoken.SaManager;
|
|
|
import cn.hutool.core.lang.Tuple;
|
|
|
import cn.hutool.crypto.digest.MD5;
|
|
|
+import cn.hutool.json.JSONObject;
|
|
|
+import cn.hutool.json.JSONUtil;
|
|
|
import cn.jlsxwkj.common.config.UserConfig;
|
|
|
-import cn.jlsxwkj.common.exception.*;
|
|
|
+import cn.jlsxwkj.common.exception.CustomException;
|
|
|
+import cn.jlsxwkj.common.exception.FileTypeDoesNotSupportException;
|
|
|
+import cn.jlsxwkj.common.exception.InsertFailException;
|
|
|
+import cn.jlsxwkj.common.exception.OpenFileFailException;
|
|
|
+import cn.jlsxwkj.common.exception.UnknownException;
|
|
|
import cn.jlsxwkj.common.utils.FileType;
|
|
|
import cn.jlsxwkj.common.utils.Log;
|
|
|
import cn.jlsxwkj.common.utils.MergeDocuments;
|
|
|
-import cn.jlsxwkj.moudles.chat.message.*;
|
|
|
import cn.jlsxwkj.moudles.chat_history.ChatHistoryService;
|
|
|
import jakarta.annotation.Resource;
|
|
|
+import org.springframework.ai.chat.messages.Message;
|
|
|
+import org.springframework.ai.chat.messages.SystemMessage;
|
|
|
+import org.springframework.ai.chat.messages.UserMessage;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.ai.document.Document;
|
|
|
import org.springframework.ai.document.DocumentReader;
|
|
@@ -25,9 +33,13 @@ import reactor.core.publisher.Flux;
|
|
|
|
|
|
import java.io.File;
|
|
|
import java.io.IOException;
|
|
|
-import java.util.*;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Locale;
|
|
|
+import java.util.Objects;
|
|
|
+import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
+import static cn.jlsxwkj.common.config.FunctionCallConfig.TOOLS_FUNCTION;
|
|
|
|
|
|
/**
|
|
|
* @author zh
|
|
@@ -42,20 +54,7 @@ public class ChatService {
|
|
|
/**
|
|
|
* 获取当前项目地址 + /ai_doc.
|
|
|
*/
|
|
|
- private static final String
|
|
|
- PATH = System.getProperty("user.dir") + "/ai_doc/";
|
|
|
- /**
|
|
|
- * 函数回调集合.
|
|
|
- */
|
|
|
- private static final HashSet<String> TOOLS_FUNCTION = new HashSet<>();
|
|
|
- /**
|
|
|
- * 历史消息列表.
|
|
|
- */
|
|
|
- private List<Message> listMessage;
|
|
|
- /**
|
|
|
- * chat 消息.
|
|
|
- */
|
|
|
- private StringBuffer chatMessage;
|
|
|
+ private static final String PATH = System.getProperty("user.dir") + "/ai_doc/";
|
|
|
/**
|
|
|
* pg 向量对象(将文档转换为向量并保存至数据库).
|
|
|
*/
|
|
@@ -72,11 +71,6 @@ public class ChatService {
|
|
|
@Resource
|
|
|
private EmbeddingModel embeddingModel;
|
|
|
/**
|
|
|
- * 自定义 chat 工具.
|
|
|
- */
|
|
|
- @Resource
|
|
|
- private ChatUtil chatUtil;
|
|
|
- /**
|
|
|
* 保存对话对象(将对话保存至数据库).
|
|
|
*/
|
|
|
@Resource
|
|
@@ -87,13 +81,6 @@ public class ChatService {
|
|
|
@Resource
|
|
|
private UserConfig userConfig;
|
|
|
|
|
|
- /*
|
|
|
- 初始化回调函数
|
|
|
- */
|
|
|
- static {
|
|
|
- TOOLS_FUNCTION.add("getWeather");
|
|
|
- }
|
|
|
-
|
|
|
/**
|
|
|
* 使用spring ai解析txt文档并保存至 pg.
|
|
|
*
|
|
@@ -101,24 +88,16 @@ public class ChatService {
|
|
|
* @return 保存状态
|
|
|
*/
|
|
|
public String uploadDocument(
|
|
|
- final MultipartFile file) throws CustomException {
|
|
|
- //检测文件是否支持
|
|
|
+ final MultipartFile file) throws CustomException {
|
|
|
Tuple supportFile = isSupportFile(file);
|
|
|
- //需要保存的文件
|
|
|
File saveFile = supportFile.get(0);
|
|
|
- //文件名
|
|
|
String fileName = supportFile.get(1);
|
|
|
- //分割文档对象
|
|
|
DocumentReader reader = supportFile.get(3);
|
|
|
- //判断文件是否存在
|
|
|
if (!saveFile.exists()) {
|
|
|
try {
|
|
|
- //保存文件
|
|
|
file.transferTo(saveFile);
|
|
|
- //保存文档到数据库
|
|
|
vectorStore.add(reader.get());
|
|
|
} catch (Exception e) {
|
|
|
- //保存文件到数据库失败删除本地缓存文件
|
|
|
boolean delete = saveFile.delete();
|
|
|
if (!delete) {
|
|
|
Log.warn(this.getClass(), "删除文件失败: {}", fileName);
|
|
@@ -136,12 +115,8 @@ public class ChatService {
|
|
|
* @return 文本内容
|
|
|
*/
|
|
|
public String search(final String keyword) {
|
|
|
- //设置检索文档关键字相似度至少近似 50% 以上
|
|
|
- SearchRequest searchRequest =
|
|
|
- SearchRequest.query(keyword).withSimilarityThreshold(0.5);
|
|
|
- //检索文档
|
|
|
+ SearchRequest searchRequest = SearchRequest.query(keyword).withSimilarityThreshold(0.5);
|
|
|
List<Document> documents = vectorStore.similaritySearch(searchRequest);
|
|
|
- //返回合并的文档
|
|
|
return mergeDocuments(documents);
|
|
|
}
|
|
|
|
|
@@ -152,27 +127,37 @@ public class ChatService {
|
|
|
* @return 回答内容
|
|
|
*/
|
|
|
public Flux<String> chatStream(final String message) {
|
|
|
- //检索知识库
|
|
|
String context = search(message);
|
|
|
- //预处理
|
|
|
- //创建用户消息
|
|
|
- org.springframework.ai.chat.messages.UserMessage userMessage =
|
|
|
- new org.springframework.ai.chat.messages.UserMessage(message);
|
|
|
- //设置提示词以及添加回调函数
|
|
|
- Prompt prompt = new Prompt(userMessage, OllamaOptions.builder()
|
|
|
- .withFunctions(TOOLS_FUNCTION).build());
|
|
|
- //预处理, chat 总结的函数结果
|
|
|
- String toolMessage = ollamaChatModel.call(prompt)
|
|
|
- .getResult().getOutput().getContent();
|
|
|
- //合并知识库结果和函数结果
|
|
|
- String format = """
|
|
|
+ Prompt prompt = new Prompt(List.of(
|
|
|
+ new SystemMessage(userConfig.getSysMessage()),
|
|
|
+ new UserMessage(message)),
|
|
|
+ OllamaOptions.builder().withFunctions(TOOLS_FUNCTION).build());
|
|
|
+ AtomicBoolean flag = new AtomicBoolean(false);
|
|
|
+ StringBuilder funcCall = new StringBuilder();
|
|
|
+ String funcCallResult = "";
|
|
|
+ Objects.requireNonNull(ollamaChatModel.stream(prompt)
|
|
|
+ .map(response -> {
|
|
|
+ String content = response.getResult().getOutput().getContent();
|
|
|
+ if (content.contains("<tool")) {
|
|
|
+ flag.set(true);
|
|
|
+ }
|
|
|
+ return content;
|
|
|
+ }).collectList().block())
|
|
|
+ .forEach(funcCall::append);
|
|
|
+ if (flag.get()) {
|
|
|
+ String tool = funcCall.toString().split("\n")[1];
|
|
|
+ JSONObject entries = JSONUtil.parseObj(tool);
|
|
|
+ String funcName = entries.get("name", String.class);
|
|
|
+ String args = entries.get("arguments", String.class);
|
|
|
+ funcCallResult = ollamaChatModel.getFunctionCallbackRegister().get(funcName).call(args);
|
|
|
+ }
|
|
|
+ return stream(new UserMessage(message),
|
|
|
+ new SystemMessage("""
|
|
|
%s
|
|
|
%s
|
|
|
- 如果以上内容不为空请永远用以上内容回答问题
|
|
|
- """;
|
|
|
- //提问 chat 并返回流
|
|
|
- return stream(message, new SystemMessage(
|
|
|
- String.format(format, context, toolMessage, message)));
|
|
|
+ 根据参考内容回答, 当内容和问题不匹配时不要参考 system context, 请参考 user context 回答
|
|
|
+ """.formatted(context, funcCallResult))
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -184,34 +169,23 @@ public class ChatService {
|
|
|
*/
|
|
|
public Flux<String> chatSummaryStream(final String message,
|
|
|
final MultipartFile file) {
|
|
|
- //检测模型是否开启
|
|
|
embeddingModel.embed("");
|
|
|
- //检测文件是否支持
|
|
|
Tuple supportFile;
|
|
|
try {
|
|
|
supportFile = isSupportFile(file);
|
|
|
} catch (CustomException e) {
|
|
|
- //如果异常则通过流返回异常
|
|
|
return Flux.just(e.getMessage()).concatWithValues("<{完成}>");
|
|
|
}
|
|
|
- //文件名
|
|
|
String fileName = supportFile.get(1);
|
|
|
- //文件类型
|
|
|
String fileTypeCn = supportFile.get(2);
|
|
|
- //分割文档对象
|
|
|
DocumentReader reader = supportFile.get(3);
|
|
|
- //格式化提问格式
|
|
|
- String format = """
|
|
|
+ String document = mergeDocuments(reader.get());
|
|
|
+ return stream(new UserMessage(message),
|
|
|
+ new UserMessage("""
|
|
|
%s :: %s ====> %s
|
|
|
↑ %s
|
|
|
- """;
|
|
|
- String context = String.format(format,
|
|
|
- fileTypeCn, fileName, mergeDocuments(reader.get()), message);
|
|
|
- //设置用户消息
|
|
|
- UserMessage userMessage = new UserMessage(context);
|
|
|
- Log.info(this.getClass(), context);
|
|
|
- //返回 chat 流
|
|
|
- return stream(message, userMessage);
|
|
|
+ """.formatted(fileTypeCn, fileName, document, message))
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -222,120 +196,33 @@ public class ChatService {
|
|
|
*/
|
|
|
private String mergeDocuments(final List<Document> documents) {
|
|
|
return MergeDocuments.mergeDocuments(documents).stream()
|
|
|
- .map(Document::getContent)
|
|
|
- .collect(Collectors.joining("\n"));
|
|
|
+ .map(Document::getContent).collect(Collectors.joining("\n"));
|
|
|
}
|
|
|
|
|
|
- /**
|
|
|
- * 停止 chat.
|
|
|
- */
|
|
|
- public void stop() {
|
|
|
- chatUtil.stop();
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * 获取prompt.
|
|
|
- *
|
|
|
- * @param message 提问内容
|
|
|
- * @param context 提示词
|
|
|
- */
|
|
|
- private void setChatMessage(final String message,
|
|
|
- final Message context) {
|
|
|
- //初始化消息并设置系统提示词
|
|
|
- if (listMessage == null) {
|
|
|
- listMessage = new ArrayList<>();
|
|
|
- listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
|
- }
|
|
|
- //判断消息列表是否大于定义的消息列表长度
|
|
|
- if (listMessage.size() > userConfig.getMaxListMessageLength()) {
|
|
|
- //缩减长度到指定列表长度
|
|
|
- int start =
|
|
|
- listMessage.size() - userConfig.getMaxListMessageLength();
|
|
|
- listMessage = listMessage.subList(start,
|
|
|
- listMessage.size()).stream()
|
|
|
- //过滤定义的定义的系统消息
|
|
|
- .filter(msg -> !msg.getContent()
|
|
|
- .equals(userConfig.getSysMessage()))
|
|
|
- .map(msg -> {
|
|
|
- //判断是否为 chat 消息, 如果是缩减 chat 回答后的长度
|
|
|
- if (msg.getRole().equals(UserRole.ASSISTANT)) {
|
|
|
- String content = msg.getContent();
|
|
|
- if (content.length()
|
|
|
- > userConfig.getMaxMessageLength()) {
|
|
|
- return new AssistantMessage(content.substring(
|
|
|
- 0,
|
|
|
- userConfig.getMaxMessageLength()));
|
|
|
- }
|
|
|
- }
|
|
|
- return msg;
|
|
|
- }).distinct()
|
|
|
- .collect(Collectors.toList());
|
|
|
- //再次将自定义系统提示词添加到列表末尾
|
|
|
- listMessage.add(new SystemMessage(userConfig.getSysMessage()));
|
|
|
- }
|
|
|
- //如果 context 是系统消息则将两个都添加
|
|
|
- if (context.getRole().equals(UserRole.SYSTEM)) {
|
|
|
- listMessage.add(new UserMessage(message));
|
|
|
- //判断系统消息是否为空
|
|
|
- if (!context.getContent().isBlank()) {
|
|
|
- listMessage.add(context);
|
|
|
- }
|
|
|
- }
|
|
|
- if (context.getRole().equals(UserRole.USER)) {
|
|
|
- listMessage.add(context);
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
/**
|
|
|
* chat stream 封装.
|
|
|
*
|
|
|
- * @param message 提问内容
|
|
|
- * @param context 提示词
|
|
|
+ * @param userMessage 提问内容
|
|
|
+ * @param sysOrUserMessage 用户或系统提示词
|
|
|
* @return flux
|
|
|
*/
|
|
|
- private Flux<String> stream(final String message,
|
|
|
- final Message context) {
|
|
|
- //获取用户id
|
|
|
+ private Flux<String> stream(final Message userMessage,
|
|
|
+ final Message sysOrUserMessage) {
|
|
|
String userId = SaManager.getStpLogic("").getLoginIdAsString();
|
|
|
- //初始化
|
|
|
- chatMessage = new StringBuffer();
|
|
|
- //设置消息列表
|
|
|
- setChatMessage(message, context);
|
|
|
- //返回消息流
|
|
|
- return chatUtil.chat(listMessage, true).map(response -> {
|
|
|
- //获取消息主体
|
|
|
- String str = response.getMessage().getContent();
|
|
|
- //添加至消息中
|
|
|
- chatMessage.append(str);
|
|
|
- //格式化字符串
|
|
|
- return str.replace(" ", " ")
|
|
|
- .replace("\t", " ");
|
|
|
- })
|
|
|
- //流完成后拼接的字符
|
|
|
- .concatWithValues("<{完成}>")
|
|
|
- //取消时返回的调度器
|
|
|
- .cancelOn(chatUtil.getScheduler())
|
|
|
- //取消时执行的操作
|
|
|
- .doOnCancel(() -> {
|
|
|
- //上游
|
|
|
- this.stop();
|
|
|
- Log.info(this.getClass(), "{} ==== 用户停止", userId);
|
|
|
- })
|
|
|
- //流完成时的操作
|
|
|
- .doOnComplete(() -> {
|
|
|
- String chatMsg = chatMessage.toString();
|
|
|
- listMessage.add(new AssistantMessage(chatMsg));
|
|
|
- try {
|
|
|
- //将消息添加至数据库
|
|
|
- chatHistoryService.insert(userId, message, chatMsg);
|
|
|
- } catch (InsertFailException e) {
|
|
|
- Log.warn(this.getClass(), e.getMessage());
|
|
|
- }
|
|
|
- Log.info(this.getClass(),
|
|
|
- "{} user message ==> {}", userId, message);
|
|
|
- Log.info(this.getClass(),
|
|
|
- "{} chat message ==> {}", userId, chatMsg);
|
|
|
- });
|
|
|
+ StringBuilder chatMessage = new StringBuilder();
|
|
|
+ return ollamaChatModel.stream(sysOrUserMessage, userMessage).map(response -> {
|
|
|
+ chatMessage.append(response);
|
|
|
+ return response.replace(" ", " ").replace("\t", " ");
|
|
|
+ }).concatWithValues("<{完成}>")
|
|
|
+ .doOnComplete(() -> {
|
|
|
+ String chatMsg = chatMessage.toString();
|
|
|
+ try {
|
|
|
+ chatHistoryService.insert(userId, userMessage.getContent(), chatMsg);
|
|
|
+ } catch (InsertFailException e) {
|
|
|
+ Log.warn(this.getClass(), e.getMessage());
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -345,48 +232,30 @@ public class ChatService {
|
|
|
* @return 元组
|
|
|
*/
|
|
|
private Tuple isSupportFile(
|
|
|
- final MultipartFile file) throws CustomException {
|
|
|
- //初始化
|
|
|
+ final MultipartFile file) throws CustomException {
|
|
|
byte[] bytes;
|
|
|
try {
|
|
|
- //读取文件
|
|
|
bytes = file.getBytes();
|
|
|
} catch (IOException e) {
|
|
|
throw new OpenFileFailException("读取文件失败");
|
|
|
}
|
|
|
- //对文件名分割
|
|
|
- String[] split = Objects.requireNonNull(file.getOriginalFilename())
|
|
|
- .split("\\.");
|
|
|
- //文件类型
|
|
|
+ String[] split = Objects.requireNonNull(file.getOriginalFilename()).split("\\.");
|
|
|
String fileType = split[split.length - 1].toLowerCase(Locale.ROOT);
|
|
|
- //保存文件地址
|
|
|
File savePath = new File(PATH);
|
|
|
- //保存文件
|
|
|
File saveFile = new File(PATH + md5.digestHex(bytes) + "." + fileType);
|
|
|
- //文件名
|
|
|
String fileName = saveFile.getName();
|
|
|
- //判断文件地址是否存在
|
|
|
if (!savePath.exists()) {
|
|
|
- //创建文件夹
|
|
|
if (!savePath.mkdirs()) {
|
|
|
Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
|
|
|
}
|
|
|
}
|
|
|
- //获取文件的各参数
|
|
|
- Tuple fileTypeAndResource = new FileType(
|
|
|
- bytes,
|
|
|
- fileName,
|
|
|
- 5,
|
|
|
- userConfig.getCnocrUrl())
|
|
|
- .getFileTypeAndResource(fileType);
|
|
|
+ Tuple fileTypeAndResource = new FileType(bytes, fileName, 5, userConfig.getCnocrUrl())
|
|
|
+ .getFileTypeAndResource(fileType);
|
|
|
if (fileTypeAndResource == null) {
|
|
|
throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
|
|
|
}
|
|
|
- //文件类型(中文)
|
|
|
String fileTypeCn = fileTypeAndResource.get(0);
|
|
|
- //分割文档对象
|
|
|
DocumentReader reader = fileTypeAndResource.get(1);
|
|
|
- //返回(保存文件对象, 文件名, 文件类型中文名, 分割文档对象)
|
|
|
return new Tuple(saveFile, fileName, fileTypeCn, reader);
|
|
|
}
|
|
|
}
|