Jelajahi Sumber

使用自定义 chat util 访问上游chat api

zhenghao 4 bulan lalu
induk
melakukan
29c715d385

+ 6 - 0
pom.xml

@@ -62,6 +62,12 @@
 			<groupId>org.projectlombok</groupId>
 			<artifactId>lombok</artifactId>
 		</dependency>
+		<!--configuration-->
+		<dependency>
+			<groupId>org.springframework.boot</groupId>
+			<artifactId>spring-boot-configuration-processor</artifactId>
+			<optional>true</optional>
+		</dependency>
 		<!--鉴权-->
 		<dependency>
 			<groupId>cn.dev33</groupId>

+ 7 - 0
src/main/java/cn/jlsxwkj/common/config/UserConfig.java

@@ -1,6 +1,7 @@
 package cn.jlsxwkj.common.config;
 
 import lombok.Data;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.boot.context.properties.ConfigurationProperties;
 import org.springframework.stereotype.Component;
 
@@ -36,4 +37,10 @@ public class UserConfig {
      * 打印 api 大于指定执行时间
      */
     private Integer execTime;
+
+    @Value("${spring.ai.ollama.base-url}")
+    private String baseUrl;
+
+    @Value("${spring.ai.ollama.chat.model}")
+    private String model;
 }

+ 1 - 1
src/main/java/cn/jlsxwkj/common/handler/ResultResponseHandler.java

@@ -52,7 +52,7 @@ public class ResultResponseHandler implements ResponseBodyAdvice<Object> {
                 Log.error(e.getClass(), Arrays.toString(e.getStackTrace()));
             }
         }
-        if (null == o) {
+        if (o == null) {
             return null;
         }
         return ResponseSucceed.data(o);

+ 35 - 0
src/main/java/cn/jlsxwkj/common/utils/FileType.java

@@ -0,0 +1,35 @@
+package cn.jlsxwkj.common.utils;
+
+import cn.hutool.core.lang.Tuple;
+import cn.jlsxwkj.common.reader.ParagraphDocReader;
+import cn.jlsxwkj.common.reader.ParagraphOcrReader;
+import cn.jlsxwkj.common.reader.ParagraphTextReader;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * @author zh
+ */
+public class FileType {
+
+    private static final Map<String, Tuple> TYPE = new HashMap<>();
+
+    public FileType(byte[] bytes,
+                    String fileName,
+                    int winSize,
+                    String cnocrUrl) {
+        TYPE.put("txt", new Tuple("文本", new ParagraphTextReader(bytes, fileName, winSize)));
+        TYPE.put("doc docx", new Tuple("文档", new ParagraphDocReader(bytes, fileName, winSize)));
+        TYPE.put("jpg png", new Tuple("图片", new ParagraphOcrReader(bytes, fileName, cnocrUrl, winSize)));
+    }
+
+    public Tuple getFileTypeAndResource(String fileType) {
+        for (String s : TYPE.keySet()) {
+            if (s.contains(fileType)) {
+                return TYPE.get(s);
+            }
+        }
+        return null;
+    }
+}

+ 4 - 4
src/main/java/cn/jlsxwkj/moudles/chat/ChatController.java

@@ -35,7 +35,7 @@ public class ChatController {
     }
 
     @Operation(summary = "问答文档流")
-    @PostMapping(value = "/chatStream", produces = {MediaType.TEXT_EVENT_STREAM_VALUE, MediaType.APPLICATION_JSON_VALUE})
+    @PostMapping(value = "/chatStream", produces = {MediaType.TEXT_EVENT_STREAM_VALUE})
     public Flux<String> chatStream(@RequestParam String message) {
         return chatService.chatStream(message);
     }
@@ -47,9 +47,9 @@ public class ChatController {
         return chatService.chatSummaryStream(message, file);
     }
 
-    @Operation(summary = "停止答")
-    @GetMapping(value = "/stopChat")
+    @Operation(summary = "停止答")
+    @GetMapping("/stopChat")
     public void stopChat() {
-        chatService.stopChat();
+        chatService.stop();
     }
 }

+ 16 - 0
src/main/java/cn/jlsxwkj/moudles/chat/ChatRequest.java

@@ -0,0 +1,16 @@
+package cn.jlsxwkj.moudles.chat;
+
+import cn.jlsxwkj.moudles.chat.message.Message;
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * @author zh
+ */
+@Data
+public class ChatRequest {
+    private String model;
+    private List<Message> messages;
+    private Boolean stream;
+}

+ 15 - 0
src/main/java/cn/jlsxwkj/moudles/chat/ChatResponse.java

@@ -0,0 +1,15 @@
+package cn.jlsxwkj.moudles.chat;
+
+import cn.jlsxwkj.moudles.chat.message.AssistantMessage;
+import lombok.Data;
+
+/**
+ * @author zh
+ */
+@Data
+public class ChatResponse {
+    private String model;
+    private String createdAt;
+    private AssistantMessage message;
+    private Boolean done;
+}

+ 36 - 48
src/main/java/cn/jlsxwkj/moudles/chat/ChatService.java

@@ -8,21 +8,15 @@ import cn.jlsxwkj.common.exception.CustomException;
 import cn.jlsxwkj.common.exception.FileTypeDoesNotSupportException;
 import cn.jlsxwkj.common.exception.InsertFailException;
 import cn.jlsxwkj.common.exception.UnknownException;
-import cn.jlsxwkj.common.reader.ParagraphDocReader;
-import cn.jlsxwkj.common.reader.ParagraphOcrReader;
-import cn.jlsxwkj.common.reader.ParagraphTextReader;
+import cn.jlsxwkj.common.utils.FileType;
 import cn.jlsxwkj.common.utils.Log;
 import cn.jlsxwkj.common.utils.MergeDocuments;
+import cn.jlsxwkj.moudles.chat.message.*;
 import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
 import jakarta.annotation.Resource;
-import org.reactivestreams.Subscription;
-import org.springframework.ai.chat.messages.*;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.document.DocumentReader;
 import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.ollama.OllamaChatModel;
 import org.springframework.ai.vectorstore.SearchRequest;
 import org.springframework.ai.vectorstore.VectorStore;
 import org.springframework.stereotype.Service;
@@ -47,20 +41,19 @@ public class ChatService {
     private final MD5 md5 = cn.hutool.crypto.digest.MD5.create();
     private final String PATH = System.getProperty("user.dir") + "/ai_doc/";
     private List<Message> listMessage;
+    private StringBuffer chatMessage;
 
     @Resource
     private VectorStore vectorStore;
     @Resource
     private EmbeddingModel embeddingModel;
     @Resource
-    private OllamaChatModel ollamaChatModel;
+    private ChatUtil chatUtil;
     @Resource
     private ChatHistoryService chatHistoryService;
     @Resource
     private UserConfig userConfig;
 
-    private StringBuffer chatMessage;
-    private Subscription subscription;
 
     /**
      * 使用spring ai解析txt文档并保存至 pg
@@ -135,25 +128,30 @@ public class ChatService {
     }
 
     /**
-     * 取消订阅
+     * 合并文档
+     *
+     * @param documents 文档列表
+     * @return 合并后的文档
      */
-    public void stopChat() {
-        subscription.cancel();
-    }
-
     private String mergeDocuments(List<Document> documents) {
         return MergeDocuments.mergeDocuments(documents).stream()
                 .map(Document::getContent).collect(Collectors.joining("\n"));
     }
 
     /**
+     * 停止 chat
+     */
+    public void stop() {
+        chatUtil.stop();
+    }
+
+    /**
      * 获取prompt
      *
      * @param message 提问内容
      * @param context 提示词
-     * @return 提示词
      */
-    private Prompt getChatPrompt(String message, Message context) {
+    private void setChatMessage(String message, Message context) {
         if (listMessage == null) {
             listMessage = new ArrayList<>();
             listMessage.add(new SystemMessage(userConfig.getSysMessage()));
@@ -163,7 +161,7 @@ public class ChatService {
                     .stream()
                     .filter(msg -> !msg.getContent().equals(userConfig.getSysMessage()))
                     .map(msg -> {
-                        if (msg.getMessageType().equals(MessageType.ASSISTANT)) {
+                        if (msg.getRole().equals(UserRole.ASSISTANT)) {
                             String content = msg.getContent();
                             if (content.length() > userConfig.getMaxMessageLength()) {
                                 return new AssistantMessage(content.substring(0, userConfig.getMaxMessageLength()));
@@ -174,16 +172,15 @@ public class ChatService {
                     .collect(Collectors.toList());
             listMessage.add(new SystemMessage(userConfig.getSysMessage()));
         }
-        if (context.getMessageType().equals(MessageType.SYSTEM)) {
+        if (context.getRole().equals(UserRole.SYSTEM)) {
             listMessage.add(new UserMessage(message));
             if (!context.getContent().isBlank()) {
                 listMessage.add(context);
             }
         }
-        if (context.getMessageType().equals(MessageType.USER)) {
-            return new Prompt(context);
+        if (context.getRole().equals(UserRole.USER)) {
+            listMessage.add(context);
         }
-        return new Prompt(listMessage);
     }
 
     /**
@@ -196,15 +193,16 @@ public class ChatService {
     private Flux<String> stream(String message, Message context) {
         String userId = SaManager.getStpLogic("").getLoginIdAsString();
         chatMessage = new StringBuffer();
-        Flux<ChatResponse> stream = ollamaChatModel.stream(getChatPrompt(message, context));
-        return stream.doOnSubscribe(subscription -> this.subscription = subscription).map(
-                response -> {
-                    String str = response.getResult().getOutput().getContent();
-                    chatMessage.append(str);
-                    return str.replace(" ", "  ")
-                            .replace("\t", "    ");
-                }).concatWithValues("<{完成}>")
-                .onErrorComplete()
+        setChatMessage(message, context);
+        return chatUtil.chat(listMessage, true).map(response -> {
+            String str = response.getMessage().getContent();
+            chatMessage.append(str);
+            System.out.print(str);
+            return str.replace(" ", "  ")
+                    .replace("\t", "    ");
+        }).concatWithValues("<{完成}>")
+                .cancelOn(chatUtil.getScheduler())
+                .doOnCancel(this::stop)
                 .doOnComplete(() -> {
                     String chatMsg = chatMessage.toString();
                     listMessage.add(new AssistantMessage(chatMsg));
@@ -236,28 +234,18 @@ public class ChatService {
         File savePath = new File(PATH);
         File saveFile = new File(PATH + md5.digestHex(bytes) + "." + fileType);
         String fileName = saveFile.getName();
-        String fileTypeCn;
         if (!savePath.exists()) {
             if (!savePath.mkdirs()) {
                 Log.warn(this.getClass(), "创建文件夹失败: " + PATH);
             }
         }
-        DocumentReader reader;
-        switch (fileType) {
-            case "txt" -> {
-                reader = new ParagraphTextReader(bytes, fileName, 5);
-                fileTypeCn = "文本";
-            }
-            case "doc", "docx" -> {
-                reader = new ParagraphDocReader(bytes, fileName, 5);
-                fileTypeCn = "文档";
-            }
-            case "jpg", "png" -> {
-                reader = new ParagraphOcrReader(bytes, fileName, userConfig.getCnocrUrl(), 5);
-                fileTypeCn = "图片";
-            }
-            default -> throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
+        Tuple fileTypeAndResource = new FileType(bytes, fileName, 5, userConfig.getCnocrUrl())
+                .getFileTypeAndResource(fileType);
+        if (fileTypeAndResource == null) {
+            throw new FileTypeDoesNotSupportException("暂不支持的文件类型: " + fileType);
         }
+        String fileTypeCn = fileTypeAndResource.get(0);
+        DocumentReader reader = fileTypeAndResource.get(1);
         return new Tuple(saveFile, fileName, fileTypeCn, reader);
     }
 }

+ 89 - 0
src/main/java/cn/jlsxwkj/moudles/chat/ChatUtil.java

@@ -0,0 +1,89 @@
+package cn.jlsxwkj.moudles.chat;
+
+import cn.hutool.http.HttpRequest;
+import cn.hutool.json.JSONUtil;
+import cn.jlsxwkj.common.config.UserConfig;
+import cn.jlsxwkj.common.utils.Log;
+import cn.jlsxwkj.moudles.chat.message.Message;
+import jakarta.annotation.Resource;
+import org.springframework.stereotype.Component;
+import reactor.core.publisher.Flux;
+import reactor.core.scheduler.Scheduler;
+import reactor.core.scheduler.Schedulers;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.List;
+
+/**
+ * @author zh
+ * chat 工具类, 用来请求上游 chat api
+ */
+@Component
+public class ChatUtil {
+
+    @Resource
+    private UserConfig userConfig;
+    private InputStream inputStream;
+    private InputStreamReader inputStreamReader;
+    private BufferedReader bufferedReader;
+    private final Scheduler scheduler = Schedulers.single();
+
+    /**
+     * 请求上游 chat api
+     *
+     * @param messages 消息列表(带历史对话)
+     * @param isStream 是否流式返回
+     * @return 流式返回
+     */
+    public Flux<ChatResponse> chat(List<Message> messages, boolean isStream) {
+        ChatRequest chatRequest = new ChatRequest();
+        chatRequest.setModel(userConfig.getModel());
+        chatRequest.setMessages(messages);
+        chatRequest.setStream(isStream);
+        Log.info(this.getClass(), JSONUtil.toJsonStr(chatRequest));
+        Log.info(this.getClass(), userConfig.getModel());
+        Log.info(this.getClass(), userConfig.getBaseUrl());
+        inputStream = HttpRequest.post(userConfig.getBaseUrl() + "/api/chat")
+                .body(JSONUtil.toJsonStr(chatRequest))
+                .execute(isStream)
+                .bodyStream();
+        inputStreamReader = new InputStreamReader(inputStream);
+        bufferedReader = new BufferedReader(inputStreamReader);
+        return Flux.fromStream(bufferedReader.lines())
+                .map(str -> JSONUtil.parse(str).toBean(ChatResponse.class))
+                .subscribeOn(scheduler, false);
+    }
+
+    /**
+     * 停止流
+     */
+    public void stop() {
+        try {
+            bufferedReader.close();
+        } catch (IOException e) {
+            Log.error(e.getClass(), e.getMessage());
+        }
+        try {
+            inputStreamReader.close();
+        } catch (IOException e) {
+            Log.error(e.getClass(), e.getMessage());
+        }
+        try {
+            inputStream.close();
+        } catch (IOException e) {
+            Log.error(e.getClass(), e.getMessage());
+        }
+    }
+
+    /**
+     * 获取调度器
+     *
+     * @return 调度器
+     */
+    public Scheduler getScheduler() {
+        return scheduler;
+    }
+}

+ 18 - 0
src/main/java/cn/jlsxwkj/moudles/chat/message/AssistantMessage.java

@@ -0,0 +1,18 @@
+package cn.jlsxwkj.moudles.chat.message;
+
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+
+/**
+ * @author zh
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+public class AssistantMessage extends Message {
+
+    public AssistantMessage(String message) {
+        this.setRole(UserRole.ASSISTANT);
+        this.setContent(message);
+    }
+
+}

+ 12 - 0
src/main/java/cn/jlsxwkj/moudles/chat/message/Message.java

@@ -0,0 +1,12 @@
+package cn.jlsxwkj.moudles.chat.message;
+
+import lombok.Data;
+
+/**
+ * @author zh
+ */
+@Data
+public abstract class Message {
+    private String role;
+    private String content;
+}

+ 18 - 0
src/main/java/cn/jlsxwkj/moudles/chat/message/SystemMessage.java

@@ -0,0 +1,18 @@
+package cn.jlsxwkj.moudles.chat.message;
+
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+
+/**
+ * @author zh
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+public class SystemMessage extends Message {
+
+    public SystemMessage(String message) {
+        this.setRole(UserRole.SYSTEM);
+        this.setContent(message);
+    }
+
+}

+ 17 - 0
src/main/java/cn/jlsxwkj/moudles/chat/message/UserMessage.java

@@ -0,0 +1,17 @@
+package cn.jlsxwkj.moudles.chat.message;
+
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+
+/**
+ * @author zh
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+public class UserMessage extends Message {
+
+    public UserMessage(String message) {
+        this.setRole(UserRole.USER);
+        this.setContent(message);
+    }
+}

+ 12 - 0
src/main/java/cn/jlsxwkj/moudles/chat/message/UserRole.java

@@ -0,0 +1,12 @@
+package cn.jlsxwkj.moudles.chat.message;
+
+/**
+ * @author zh
+ */
+public final class UserRole {
+
+    public static final String USER = "user";
+    public static final String SYSTEM = "system";
+    public static final String ASSISTANT = "assistant";
+
+}

+ 1 - 3
src/main/resources/application-dev.yml

@@ -11,9 +11,7 @@ spring:
     ollama:
       base-url: http://localhost:11434
       chat:
-        model: qwen2:1.5b
-        options:
-          temperature: 0
+        model: qwen2:7b
       embedding:
         model: mofanke/dmeta-embedding-zh
 customer: