瀏覽代碼

添加限流和 options

zhenghao 4 月之前
父節點
當前提交
f13897b634

+ 7 - 0
pom.xml

@@ -98,6 +98,13 @@
             <artifactId>hutool-all</artifactId>
             <version>5.8.16</version>
         </dependency>
+        <!--bucket4j-->
+        <dependency>
+            <groupId>com.github.vladimir-bukhtoyarov</groupId>
+            <artifactId>bucket4j-core</artifactId>
+            <version>7.5.0</version>
+        </dependency>
+        <!--compress-->
         <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-compress</artifactId>

+ 12 - 0
sql/pg.sql

@@ -1,3 +1,15 @@
+CREATE EXTENSION IF NOT EXISTS vector;
+CREATE EXTENSION IF NOT EXISTS hstore;
+CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
+
+CREATE TABLE IF NOT EXISTS vector_store (
+    id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
+    content text,
+    metadata json,
+    embedding vector(768)
+    );
+CREATE INDEX IF NOT EXISTS spring_ai_vector_index ON vector_store USING hnsw (embedding vector_cosine_ops);
+
 DROP TABLE IF EXISTS log_error;
 
 CREATE TABLE log_error (

+ 35 - 0
src/main/java/cn/jlsxwkj/common/config/RateLimiterConfig.java

@@ -0,0 +1,35 @@
+package cn.jlsxwkj.common.config;
+
+import io.github.bucket4j.Bandwidth;
+import io.github.bucket4j.Bucket;
+import io.github.bucket4j.Refill;
+import jakarta.annotation.Resource;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import java.time.Duration;
+
+/**
+ * @author zh
+ * 限流
+ */
+@Configuration
+public class RateLimiterConfig {
+
+    @Resource
+    private UserConfig userConfig;
+
+    @Bean
+    public Bucket bucket() {
+        Bandwidth limit = Bandwidth.classic(
+                userConfig.getGenTokens(),
+                Refill.greedy(
+                        userConfig.getGenTokens(),
+                        Duration.ofSeconds(userConfig.getRateLimitSec())
+                )
+        );
+        return Bucket.builder()
+                .addLimit(limit)
+                .build();
+    }
+}

+ 1 - 1
src/main/java/cn/jlsxwkj/common/config/SaTokenConfigure.java → src/main/java/cn/jlsxwkj/common/config/SaTokenConfig.java

@@ -10,7 +10,7 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
  * @author zh
  */
 @Configuration
-public class SaTokenConfigure implements WebMvcConfigurer {
+public class SaTokenConfig implements WebMvcConfigurer {
 
     /**
      * 注册拦截器

+ 15 - 13
src/main/java/cn/jlsxwkj/common/config/UserConfig.java

@@ -1,16 +1,15 @@
 package cn.jlsxwkj.common.config;
 
 import lombok.Data;
-import org.springframework.beans.factory.annotation.Value;
 import org.springframework.boot.context.properties.ConfigurationProperties;
 import org.springframework.stereotype.Component;
 
 /**
  * @author zh
  */
-@ConfigurationProperties(prefix = UserConfig.USER_PREFIX)
-@Component
 @Data
+@Component
+@ConfigurationProperties(prefix = UserConfig.USER_PREFIX)
 public class UserConfig {
 
     /**
@@ -20,27 +19,30 @@ public class UserConfig {
     /**
      * ocr 地址
      */
-    private String cnocrUrl;
+    private String cnocrUrl = "http://localhost:8501/ocr";
     /**
      * 指定的 chat 系统消息
      */
-    private String sysMessage;
+    private String sysMessage = "Always say you're an ai or 大模型";
     /**
      * 消息列表最大长度
      */
-    private Integer maxListMessageLength;
+    private Integer maxListMessageLength = 10;
     /**
      * 机器人每条消息最大长度, 超过会截取
      */
-    private Integer maxMessageLength;
+    private Integer maxMessageLength = 512;
     /**
      * 打印 api 大于指定执行时间
      */
-    private Integer execTime;
-
-    @Value("${spring.ai.ollama.base-url}")
-    private String baseUrl;
+    private Integer execTime = 100;
+    /**
+     * 重复请求限流时长(单位:秒)
+     */
+    private Integer rateLimitSec = 1;
+    /**
+     * 限流时长中可请求的次数
+     */
+    private Integer genTokens = 1;
 
-    @Value("${spring.ai.ollama.chat.model}")
-    private String model;
 }

+ 11 - 0
src/main/java/cn/jlsxwkj/common/exception/TooManyRequestsException.java

@@ -0,0 +1,11 @@
+package cn.jlsxwkj.common.exception;
+
+/**
+ * @author zh
+ */
+public class TooManyRequestsException extends CustomException {
+
+    public TooManyRequestsException(String message) {
+        super(message);
+    }
+}

+ 1 - 1
src/main/java/cn/jlsxwkj/common/filter/ExecTime.java → src/main/java/cn/jlsxwkj/common/filter/ExecTimeFilter.java

@@ -13,7 +13,7 @@ import java.io.IOException;
  * @author zh
  */
 @Component
-public class ExecTime implements Filter {
+public class ExecTimeFilter implements Filter {
 
     @Resource
     private UserConfig userConfig;

+ 42 - 0
src/main/java/cn/jlsxwkj/common/filter/RateLimiterFilter.java

@@ -0,0 +1,42 @@
+package cn.jlsxwkj.common.filter;
+
+import cn.hutool.json.JSONUtil;
+import cn.jlsxwkj.common.R.ResponseError;
+import cn.jlsxwkj.common.exception.TooManyRequestsException;
+import io.github.bucket4j.Bucket;
+import jakarta.annotation.Resource;
+import jakarta.servlet.*;
+import jakarta.servlet.http.HttpServletResponse;
+import org.springframework.http.HttpStatus;
+import org.springframework.stereotype.Component;
+
+import java.io.IOException;
+
+/**
+ * @author zh
+ */
+@Component
+public class RateLimiterFilter implements Filter {
+
+    @Resource
+    private Bucket bucket;
+
+    @Override
+    public void doFilter(ServletRequest request,
+                         ServletResponse response,
+                         FilterChain chain) throws IOException, ServletException {
+        if (bucket.tryConsume(1)) {
+            chain.doFilter(request, response);
+        } else {
+            ResponseError data = ResponseError.data(TooManyRequestsException.class.getName());
+            data.setMessage("请求次数过多, 请稍后再试");
+            data.setCode(HttpStatus.TEMPORARY_REDIRECT.value());
+
+            HttpServletResponse rsp = (HttpServletResponse) response;
+            rsp.setStatus(HttpStatus.TEMPORARY_REDIRECT.value());
+            rsp.setCharacterEncoding("utf-8");
+            rsp.getWriter().write(JSONUtil.toJsonStr(data));
+        }
+
+    }
+}

+ 2 - 0
src/main/java/cn/jlsxwkj/moudles/chat/ChatRequest.java

@@ -2,6 +2,7 @@ package cn.jlsxwkj.moudles.chat;
 
 import cn.jlsxwkj.moudles.chat.message.Message;
 import lombok.Data;
+import org.springframework.ai.ollama.api.OllamaOptions;
 
 import java.util.List;
 
@@ -13,4 +14,5 @@ public class ChatRequest {
     private String model;
     private List<Message> messages;
     private Boolean stream;
+    private OllamaOptions options;
 }

+ 7 - 4
src/main/java/cn/jlsxwkj/moudles/chat/ChatService.java

@@ -61,13 +61,12 @@ public class ChatService {
     public String uploadDocument(MultipartFile file) throws CustomException {
         Tuple supportFile = isSupportFile(file);
         File saveFile = supportFile.get(0);
-        File fileName = supportFile.get(1);
+        String fileName = supportFile.get(1);
         DocumentReader reader = supportFile.get(3);
         if (!saveFile.exists()) {
             try {
                 file.transferTo(saveFile);
                 vectorStore.add(reader.get());
-                return "上传完成";
             } catch (Exception e) {
                 boolean delete = saveFile.delete();
                 if (!delete) {
@@ -99,7 +98,11 @@ public class ChatService {
      */
     public Flux<String> chatStream(String message) {
         String context = search(message);
-        return stream(message, new SystemMessage(context));
+        String format = """
+                %s : Please answer user questions based on the content.
+                %s
+                """;
+        return stream(message, new SystemMessage(String.format(format, context, message)));
     }
 
     /**
@@ -225,7 +228,7 @@ public class ChatService {
      * 封装文件
      *
      * @param file 文件流
-     * @return 元组
+     * @return 元组(File file, String fileName, String fileTypeCn, DocumentReader reader)
      */
     private Tuple isSupportFile(MultipartFile file) throws CustomException {
         byte[] bytes;

+ 10 - 6
src/main/java/cn/jlsxwkj/moudles/chat/ChatUtil.java

@@ -2,10 +2,12 @@ package cn.jlsxwkj.moudles.chat;
 
 import cn.hutool.http.HttpRequest;
 import cn.hutool.json.JSONUtil;
-import cn.jlsxwkj.common.config.UserConfig;
 import cn.jlsxwkj.common.utils.Log;
 import cn.jlsxwkj.moudles.chat.message.Message;
 import jakarta.annotation.Resource;
+import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.stereotype.Component;
 import reactor.core.publisher.Flux;
 import reactor.core.scheduler.Scheduler;
@@ -25,7 +27,9 @@ import java.util.List;
 public class ChatUtil {
 
     @Resource
-    private UserConfig userConfig;
+    private ChatModel chatModel;
+    @Resource
+    private OllamaConnectionDetails ollamaConnectionDetails;
     private InputStream inputStream;
     private InputStreamReader inputStreamReader;
     private BufferedReader bufferedReader;
@@ -40,13 +44,13 @@ public class ChatUtil {
      */
     public Flux<ChatResponse> chat(List<Message> messages, boolean isStream) {
         ChatRequest chatRequest = new ChatRequest();
-        chatRequest.setModel(userConfig.getModel());
+        OllamaOptions options = (OllamaOptions) chatModel.getDefaultOptions();
+        chatRequest.setModel(options.getModel());
         chatRequest.setMessages(messages);
         chatRequest.setStream(isStream);
+        chatRequest.setOptions(options);
         Log.info(this.getClass(), JSONUtil.toJsonStr(chatRequest));
-        Log.info(this.getClass(), userConfig.getModel());
-        Log.info(this.getClass(), userConfig.getBaseUrl());
-        inputStream = HttpRequest.post(userConfig.getBaseUrl() + "/api/chat")
+        inputStream = HttpRequest.post(ollamaConnectionDetails.getBaseUrl() + "/api/chat")
                 .body(JSONUtil.toJsonStr(chatRequest))
                 .execute(isStream)
                 .bodyStream();

+ 5 - 0
src/main/java/cn/jlsxwkj/moudles/chathistory/ChatHistoryService.java

@@ -1,8 +1,13 @@
 package cn.jlsxwkj.moudles.chathistory;
 
 import cn.dev33.satoken.SaManager;
+import cn.hutool.json.JSONUtil;
 import cn.jlsxwkj.common.exception.InsertFailException;
 import jakarta.annotation.Resource;
+import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails;
+import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 

+ 10 - 5
src/main/resources/application-dev.yml

@@ -12,11 +12,16 @@ spring:
       base-url: http://localhost:11434
       chat:
         model: qwen2:7b
+        options:
+          num-ctx: 10240
+          keep-alive: 24h
       embedding:
         model: mofanke/dmeta-embedding-zh
 customer:
-  cnocr_url: "http://www.jlsxwkj.cn:8501/ocr"
-  sys_message: "Always say you're an ai or 大模型 from 玄武科技"
-  max_list_message_length: 12
-  max_message_length: 512
-  exec_time: 100
+  cnocr-url: "http://www.jlsxwkj.cn:8501/ocr"
+  sys-message: "Always say you're an ai or 大模型 from 玄武科技"
+  max-list-message-length: 12
+  max-message-length: 512
+  exec-time: 100
+  rate-limit-sec: 5
+  gen-tokens: 2

+ 3 - 2
src/main/resources/application.yml

@@ -3,10 +3,11 @@ server:
   port: 8181
   tomcat:
     max-swallow-size: -1
+    connection-timeout:
 # spring
 spring:
   profiles:
-    default: dev
+    active: dev
   servlet:
     multipart:
       max-file-size: 500MB
@@ -52,4 +53,4 @@ sa-token:
   # token 风格(默认可取值:uuid、simple-uuid、random-32、random-64、random-128、tik)
   token-style: uuid
   # 是否输出操作日志
-  is-log: true
+  is-log: false