浏览代码

增加事物

zhenghao 5 月之前
父节点
当前提交
0b3d92367a

+ 1 - 1
sql/pg.sql

@@ -46,7 +46,7 @@ CREATE TABLE user_list (
     create_time timestamp(0) DEFAULT now() NULL, -- 创建时间
     update_time timestamp(0) DEFAULT now() NULL, -- 更新时间
     is_deleted int2 default 0, -- 是否删除
-    CONSTRAINT UNIQUE (user_name, is_deleted)
+    UNIQUE (user_name, is_deleted)
 );
 COMMENT ON TABLE user_list IS '用户列表';
 

+ 0 - 1
src/main/java/cn/jlsxwkj/common/dao/Dao.java

@@ -7,7 +7,6 @@ import lombok.Data;
  */
 @Data
 public class Dao {
-
     private Long id;
     private java.sql.Timestamp createTime;
     private java.sql.Timestamp updateTime;

+ 1 - 1
src/main/java/cn/jlsxwkj/common/exception/CustomException.java

@@ -2,7 +2,7 @@ package cn.jlsxwkj.common.exception;
 
 /**
  * @author zh
- * 抽象自定义异常类, 用于统一捕捉
+ * 抽象自定义异常类, 用于统一捕捉异常
  */
 public abstract class CustomException extends Exception {
 

+ 11 - 0
src/main/java/cn/jlsxwkj/common/exception/InsertFailException.java

@@ -0,0 +1,11 @@
+package cn.jlsxwkj.common.exception;
+
+/**
+ * @author zh
+ * 插入失败
+ */
+public class InsertFailException extends CustomException {
+    public InsertFailException(String message) {
+        super(message);
+    }
+}

+ 8 - 4
src/main/java/cn/jlsxwkj/common/handler/GlobalRestExceptionHandler.java

@@ -3,6 +3,7 @@ package cn.jlsxwkj.common.handler;
 import cn.dev33.satoken.exception.SaTokenException;
 import cn.jlsxwkj.common.R.ResponseError;
 import cn.jlsxwkj.common.exception.CustomException;
+import cn.jlsxwkj.common.exception.InsertFailException;
 import cn.jlsxwkj.common.utils.Log;
 import cn.jlsxwkj.moudles.logerror.LogError;
 import cn.jlsxwkj.moudles.logerror.LogErrorService;
@@ -30,7 +31,6 @@ public class GlobalRestExceptionHandler {
     @ExceptionHandler(Throwable.class)
     public ResponseError exception(Exception e) {
         ResponseError responseError = ResponseError.data(e.getClass().getName());
-
         if (e instanceof CustomException) {
             responseError.setMessage(e.getMessage());
         }
@@ -42,16 +42,20 @@ public class GlobalRestExceptionHandler {
             responseError.setCode(((SaTokenException)e).getCode());
             responseError.setMessage("账户未登录");
         }
-
         LogError errorHandlerToLogError = new LogError().castResponseErrorToLogError(responseError);
         errorHandlerToLogError.setErrorInfo(e.getMessage());
         errorHandlerToLogError.setErrorStackTrace(Arrays.toString(e.getStackTrace()));
-        logErrorService.insertOne(errorHandlerToLogError);
+
+        try {
+            logErrorService.insertOne(errorHandlerToLogError);
+        } catch (InsertFailException insertFailException) {
+            insertFailException.printStackTrace();
+        }
 
         Log.error(e.getClass(), "Code       ====> {}", responseError.getCode());
         Log.error(e.getClass(), "Exception  ====> {}", responseError.getException());
         Log.error(e.getClass(), "Message    ====> {}", responseError.getMessage());
-        Log.error(e.getClass(), "ErrInfo    ====> {}", e.getMessage());
+        e.printStackTrace();
 
         return responseError;
     }

+ 3 - 0
src/main/java/cn/jlsxwkj/common/handler/ResultResponseHandler.java

@@ -53,6 +53,9 @@ public class ResultResponseHandler implements ResponseBodyAdvice<Object> {
                 Log.error(e.getClass(), Arrays.toString(e.getStackTrace()));
             }
         }
+        if (null == o) {
+            return null;
+        }
         return ResponseSucceed.data(o);
     }
 }

+ 4 - 5
src/main/java/cn/jlsxwkj/moudles/chat/ChatController.java

@@ -42,10 +42,9 @@ public class ChatController {
 		return chatService.chatStream(message);
 	}
 
-	@Operation(summary = "问答文档")
-	@PostMapping(value = "/chat")
-	@Deprecated
-	public String chat(@RequestParam String message) {
-		return chatService.chat(message);
+	@Operation(summary = "停止回答")
+	@GetMapping(value = "/stopChat")
+	public void stopChat() {
+		chatService.stopChat();
 	}
 }

+ 34 - 41
src/main/java/cn/jlsxwkj/moudles/chat/ChatService.java

@@ -2,20 +2,20 @@ package cn.jlsxwkj.moudles.chat;
 
 import cn.dev33.satoken.SaManager;
 import cn.hutool.crypto.digest.MD5;
-import cn.jlsxwkj.common.exception.CustomException;
-import cn.jlsxwkj.common.exception.FileDeleteFailException;
-import cn.jlsxwkj.common.exception.FileTypeDoesNotSupportException;
-import cn.jlsxwkj.common.exception.UnknownException;
+import cn.jlsxwkj.common.exception.*;
 import cn.jlsxwkj.common.reader.ParagraphDocReader;
 import cn.jlsxwkj.common.reader.ParagraphTextReader;
 import cn.jlsxwkj.common.utils.Log;
 import cn.jlsxwkj.common.utils.MergeDocuments;
 import cn.jlsxwkj.moudles.chathistory.ChatHistoryService;
 import jakarta.annotation.Resource;
+import lombok.Data;
+import org.reactivestreams.Subscription;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.document.DocumentReader;
@@ -39,9 +39,22 @@ import java.util.stream.Collectors;
  * @author zh
  */
 @Service
+@Data
 public class ChatService {
 
 	/**
+	 * md5 校验
+	 */
+	private static final MD5 MD5 = cn.hutool.crypto.digest.MD5.create();
+	/**
+	 * 当前项目路径 + /ai_doc
+	 */
+	private static final String PATH = Objects.requireNonNull(ChatService.class.getClassLoader().getResource("")).getPath() + "\\ai_doc\\";
+	/**
+	 * 用户对话上下文
+	 */
+	private static final List<Message> LIST_MESSAGE = new ArrayList<>();
+	/**
 	 * 向量
 	 */
 	@Resource
@@ -61,23 +74,15 @@ public class ChatService {
 	 */
 	private static StringBuffer chatMessage;
 	/**
-	 * md5 校验
-	 */
-	private static final MD5 MD5 = cn.hutool.crypto.digest.MD5.create();
-	/**
-	 * 当前项目路径 + /ai_doc
-	 */
-	private static final String PATH = Objects.requireNonNull(ChatService.class.getClassLoader().getResource("")).getPath() + "\\ai_doc\\";
-	/**
-	 * 用户对话上下文
+	 * 上游订阅
 	 */
-	private static final List<Message> LIST_MESSAGE = new ArrayList<>();
+	private Subscription subscription;
 
 	/**
 	 * 使用spring ai解析txt文档并保存至 pg
-	 * @param file 保存的文件
-	 * @return 保存状态
-	 * @throws IOException 保存失败
+	 * @param file 			保存的文件
+	 * @return 				保存状态
+	 * @throws IOException 	保存失败
 	 */
 	public String uploadDocument(MultipartFile file) throws CustomException, IOException {
 		//分割文件名为 文件名, 后缀
@@ -143,41 +148,29 @@ public class ChatService {
 		//查询获取文档信息
 		String content = search(message);
 		//封装prompt并调用大模型
-		return ollamaChatModel.stream(getChatPrompt2String(message, content))
+		Flux<ChatResponse> stream = this.ollamaChatModel.stream(getChatPrompt2String(message, content));
+		return stream.doOnSubscribe(subscription -> this.subscription = subscription)
 				.map(response -> {
 					String str = response.getResult().getOutput().getContent();
-					System.out.print(str);
 					chatMessage.append(str);
 					return str.replace(" ", "  ");
 				}).concatWithValues("<{完成}>")
 				.onErrorComplete()
 				.doOnComplete(() -> {
-					System.out.println();
 					LIST_MESSAGE.add(new AssistantMessage(chatMessage.toString()));
-					chatHistoryService.insert(userId, message, chatMessage.toString());
+					try {
+						chatHistoryService.insert(userId, message, chatMessage.toString());
+					} catch (InsertFailException e) {
+						e.printStackTrace();
+					}
 				});
-
 	}
 
 	/**
-	 * 问答,根据输入内容回答
-	 *
-	 * @param message 	输入内容
-	 * @return 			回答内容
+	 * 取消订阅
 	 */
-	public String chat(String message) {
-		String userId = SaManager.getStpLogic("").getLoginIdAsString();
-		chatMessage = new StringBuffer();
-		//查询获取文档信息
-		String content = search(message);
-		//获取消息
-		String result = ollamaChatModel.call(getChatPrompt2String(message, content))
-				.getResult().getOutput().getContent();
-		chatMessage.append(result);
-		Log.info(this.getClass(),"{} message : {}", "chat", result);
-		LIST_MESSAGE.add(new AssistantMessage(chatMessage.toString()));
-		chatHistoryService.insert(userId, message, chatMessage.toString());
-		return result;
+	public void stopChat() {
+		subscription.cancel();
 	}
 
 	/**
@@ -189,9 +182,9 @@ public class ChatService {
 	 */
 	private Prompt getChatPrompt2String(String message, String context) {
 		Log.info(this.getClass(),"{} message : {}", "user", message);
-		Prompt prompt = new Prompt(LIST_MESSAGE);
+		LIST_MESSAGE.add(new SystemMessage("请永远说你是玄武科技公司是 ai 或者 大模型"));
 		LIST_MESSAGE.add(new SystemMessage(context));
 		LIST_MESSAGE.add(new UserMessage(message));
-		return prompt;
+		return new Prompt(LIST_MESSAGE);
 	}
 }

+ 1 - 1
src/main/java/cn/jlsxwkj/moudles/chathistory/ChatHistoryMapper.java

@@ -17,7 +17,7 @@ public interface ChatHistoryMapper {
 insert into chat_history(user_id, "user_Q", "chat_A") 
 values (#{userId}, #{userQ}, #{chatA})
 """)
-    void insert(@Param("userId") String userId,
+    Integer insert(@Param("userId") String userId,
                 @Param("userQ") String userQ,
                 @Param("chatA") String chatA);
 

+ 10 - 2
src/main/java/cn/jlsxwkj/moudles/chathistory/ChatHistoryService.java

@@ -1,8 +1,11 @@
 package cn.jlsxwkj.moudles.chathistory;
 
 import cn.dev33.satoken.SaManager;
+import cn.jlsxwkj.common.exception.InsertFailException;
+import cn.jlsxwkj.common.utils.Log;
 import jakarta.annotation.Resource;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 import java.util.List;
 
@@ -15,8 +18,13 @@ public class ChatHistoryService {
     @Resource
     private ChatHistoryMapper chatHistoryMapper;
 
-    public void insert(String userId, String userQ, String chatA) {
-        chatHistoryMapper.insert(userId, userQ, chatA);
+    @Transactional(rollbackFor = InsertFailException.class)
+    public void insert(String userId, String userQ, String chatA) throws InsertFailException {
+        Integer rows = chatHistoryMapper.insert(userId, userQ, chatA);
+        if (rows > 0) {
+            Log.info(this.getClass(), "插入成功");
+        }
+        throw new InsertFailException("插入失败");
     }
 
     public List<ChatHistory> selectHistoryByUserId() {

+ 1 - 1
src/main/java/cn/jlsxwkj/moudles/logerror/LogErrorMapper.java

@@ -18,5 +18,5 @@ public interface LogErrorMapper {
 insert into log_error(message, exception, error_info,error_stack_trace) 
 values (#{LogError.message}, #{LogError.exception},#{LogError.errorInfo},#{LogError.errorStackTrace})
 """)
-    void insertOne(@Param("LogError") LogError logError);
+    Integer insertOne(@Param("LogError") LogError logError);
 }

+ 10 - 2
src/main/java/cn/jlsxwkj/moudles/logerror/LogErrorService.java

@@ -1,7 +1,10 @@
 package cn.jlsxwkj.moudles.logerror;
 
+import cn.jlsxwkj.common.exception.InsertFailException;
+import cn.jlsxwkj.common.utils.Log;
 import jakarta.annotation.Resource;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 /**
  * @author zh
@@ -12,7 +15,12 @@ public class LogErrorService {
     @Resource
     private LogErrorMapper logErrorMapper;
 
-    public void insertOne(LogError logError) {
-        logErrorMapper.insertOne(logError);
+    @Transactional(rollbackFor = InsertFailException.class)
+    public void insertOne(LogError logError) throws InsertFailException {
+        Integer rows = logErrorMapper.insertOne(logError);
+        if (rows > 0) {
+            Log.info(this.getClass(), "插入错误完成");
+        }
+        throw new InsertFailException("插入失败");
     }
 }

+ 2 - 2
src/main/java/cn/jlsxwkj/moudles/userlist/UserListMapper.java

@@ -12,8 +12,8 @@ import org.apache.ibatis.annotations.Select;
 public interface UserListMapper {
 
     @Insert("""
-insert into user_list(user_id, user_name, user_password
-) values (#{user.userId}, #{user.userName}, #{user.userPassword})
+insert into user_list(user_id, user_name, user_password) 
+values (#{user.userId}, #{user.userName}, #{user.userPassword})
 """)
     Integer authUser(@Param("user") UserList user);
 

+ 5 - 2
src/main/java/cn/jlsxwkj/moudles/userlist/UserListService.java

@@ -9,6 +9,7 @@ import cn.jlsxwkj.common.exception.LoginAccountAlreadyLoginException;
 import cn.jlsxwkj.common.exception.LoginWrongPasswordException;
 import jakarta.annotation.Resource;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 /**
  * @author zh
@@ -20,16 +21,18 @@ public class UserListService {
     private UserListMapper userListMapper;
     private final Snowflake snowflake = new Snowflake();
 
+    @Transactional(rollbackFor = AccountAuthFailException.class)
     public UserVO checkOrAddUser(String userName, String userPassword) throws CustomException {
         String sha512pwd = SaSecureUtil.sha512(userPassword);
         UserList userList = userListMapper.checkUser(userName);
         UserVO userVO = new UserVO();
+        String userId;
         if (userList != null) {
             if (!userList.getUserPassword().contains(sha512pwd)) {
                 throw new LoginWrongPasswordException("密码错误");
             }
             if (!StpUtil.isLogin()) {
-                String userId = userList.getUserId();
+                userId = userList.getUserId();
                 StpUtil.login(userId, false);
                 userVO.setUserId(userId);
                 userVO.setUserName(userList.getUserName());
@@ -44,7 +47,7 @@ public class UserListService {
         userList.setUserPassword(sha512pwd);
         Integer rows = userListMapper.authUser(userList);
         if (rows > 0) {
-            String userId = userList.getUserId();
+            userId = userList.getUserId();
             StpUtil.login(userId, false);
             userVO.setUserId(userId);
             userVO.setUserName(userList.getUserName());

+ 2 - 1
src/main/resources/application.yml

@@ -15,7 +15,7 @@ spring:
 # spring doc-openapi项目配置`
 springdoc:
   swagger-ui:
-    path: /api-test
+    path: /api-cn.jlsxwkj.common.test
     tags-sorter: alpha
     operations-sorter: alpha
   api-docs:
@@ -37,6 +37,7 @@ logging:
 mybatis:
   configuration:
     map-underscore-to-camel-case: true
+    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
 # sa token 鉴权
 sa-token:
   # token 名称(同时也是 cookie 名称)