From 59eb69747b12c6ab051ea03bfeeaffac43a4d6ce Mon Sep 17 00:00:00 2001
From: Hanserwei <2628273921@qq.com>
Date: Mon, 3 Nov 2025 16:31:19 +0800
Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E5=AE=9E=E7=8E=B0=E5=AF=B9?=
=?UTF-8?q?=E8=AF=9D=E8=AE=B0=E5=BF=86=E4=B8=8E=E6=B6=88=E6=81=AF=E6=8C=81?=
=?UTF-8?q?=E4=B9=85=E5=8C=96=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 移除 Cassandra 相关配置及依赖
- 新增 CustomChatMemoryAdvisor 实现对话记忆管理
-重命名并扩展 CustomStreamLoggerAdvisor 为 CustomStreamLoggerAndMessage2DBAdvisor,增加消息入库逻辑
- 在 ChatController 中集成新的 Advisor 并注入相关依赖
- 使用 TransactionTemplate 管理消息存储事务
-限制记忆消息数量为最新 50 条
- 支持将用户消息与 AI 回答同步写入数据库
---
pom.xml | 5 -
.../advisor/CustomChatMemoryAdvisor.java | 100 ++++++++++++++++
.../advisor/CustomStreamLoggerAdvisor.java | 63 ----------
...ustomStreamLoggerAndMessage2DBAdvisor.java | 108 ++++++++++++++++++
.../airobot/controller/ChatController.java | 13 ++-
src/main/resources/application.yml | 4 -
6 files changed, 219 insertions(+), 74 deletions(-)
create mode 100644 src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java
delete mode 100644 src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java
create mode 100644 src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java
diff --git a/pom.xml b/pom.xml
index 51f2eb1..af2cca0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -43,11 +43,6 @@
com.alibaba.cloud.ai
spring-ai-alibaba-starter-dashscope
-
-
- org.springframework.ai
- spring-ai-starter-model-chat-memory-repository-cassandra
-
org.springframework.boot
spring-boot-starter-test
diff --git a/src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java b/src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java
new file mode 100644
index 0000000..92ec3a5
--- /dev/null
+++ b/src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java
@@ -0,0 +1,100 @@
+package com.hanserwei.airobot.advisor;
+
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
+import com.google.common.collect.Lists;
+import com.hanserwei.airobot.domain.dos.ChatMessageDO;
+import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
+import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
+import lombok.extern.slf4j.Slf4j;
+import org.jetbrains.annotations.NotNull;
+import org.springframework.ai.chat.client.ChatClientRequest;
+import org.springframework.ai.chat.client.ChatClientResponse;
+import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
+import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.UserMessage;
+import reactor.core.publisher.Flux;
+
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+
+@Slf4j
+public class CustomChatMemoryAdvisor implements StreamAdvisor {
+
+ private final ChatMessageMapper chatMessageMapper;
+ private final AiChatReqVO aiChatReqVO;
+ private final int limit;
+
+ public CustomChatMemoryAdvisor(ChatMessageMapper chatMessageMapper, AiChatReqVO aiChatReqVO, int limit) {
+ this.chatMessageMapper = chatMessageMapper;
+ this.aiChatReqVO = aiChatReqVO;
+ this.limit = limit;
+ }
+
+ @Override
+ public int getOrder() {
+ return 2; // order 值越小,越先执行
+ }
+
+ @NotNull
+ @Override
+ public String getName() {
+ return this.getClass().getSimpleName();
+ }
+
+ @NotNull
+ @Override
+ public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
+ log.info("## 自定义聊天记忆 Advisor...");
+
+ // 对话 UUID
+ String chatUuid = aiChatReqVO.getChatId();
+
+ // 查询数据库拉取最新的聊天消息
+ List messages = chatMessageMapper.selectList(Wrappers.lambdaQuery()
+ .eq(ChatMessageDO::getChatUuid, chatUuid) // 查询指定对话 UUID 下的聊天记录
+ .orderByDesc(ChatMessageDO::getCreateTime) // 查询最新的消息
+ .last(String.format("LIMIT %d", limit))); // 仅查询 LIMIT 条
+
+ // 按发布时间升序排列
+ List sortedMessages = messages.stream()
+ .sorted(Comparator.comparing(ChatMessageDO::getCreateTime)) // 升序排列
+ .toList();
+
+ // 所有消息
+ List messageList = getMessageList(sortedMessages);
+
+ // 除了记忆消息,还需要添加当前用户消息
+ messageList.addAll(chatClientRequest.prompt().getInstructions());
+
+ // 构建一个新的 ChatClientRequest 请求对象
+ ChatClientRequest processedChatClientRequest = chatClientRequest
+ .mutate()
+ .prompt(chatClientRequest.prompt().mutate().messages(messageList).build())
+ .build();
+
+ return streamAdvisorChain.nextStream(processedChatClientRequest);
+ }
+
+ @NotNull
+ private static List getMessageList(List sortedMessages) {
+ List messageList = Lists.newArrayList();
+
+ // 将数据库记录转换为对应类型的消息
+ for (ChatMessageDO chatMessageDO : sortedMessages) {
+ // 消息类型
+ String type = chatMessageDO.getRole();
+ if (Objects.equals(type, MessageType.USER.getValue())) { // 用户消息
+ Message userMessage = new UserMessage(chatMessageDO.getContent());
+ messageList.add(userMessage);
+ } else if (Objects.equals(type, MessageType.ASSISTANT.getValue())) { // AI 助手消息
+ Message assistantMessage = new AssistantMessage(chatMessageDO.getContent());
+ messageList.add(assistantMessage);
+ }
+ }
+ return messageList;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java b/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java
deleted file mode 100644
index f2777f4..0000000
--- a/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java
+++ /dev/null
@@ -1,63 +0,0 @@
-package com.hanserwei.airobot.advisor;
-
-import lombok.extern.slf4j.Slf4j;
-import org.jetbrains.annotations.NotNull;
-import org.springframework.ai.chat.client.ChatClientRequest;
-import org.springframework.ai.chat.client.ChatClientResponse;
-import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
-import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
-import reactor.core.publisher.Flux;
-
-import java.util.concurrent.atomic.AtomicReference;
-
-@Slf4j
-public class CustomStreamLoggerAdvisor implements StreamAdvisor {
-
- @Override
- public int getOrder() {
- return 99; // order 值越小,越先执行
- }
-
- @NotNull
- @Override
- public String getName() {
- return this.getClass().getSimpleName();
- }
-
- @NotNull
- @Override
- public Flux adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
-
- Flux chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
-
- // 创建 AI 流式回答聚合容器(线程安全)
- AtomicReference fullContent = new AtomicReference<>(new StringBuilder());
-
- // 返回处理后的流
- return chatClientResponseFlux
- .doOnNext(response -> {
- // 逐块收集内容
- String chunk = null;
- if (response.chatResponse() != null) {
- chunk = response.chatResponse().getResult().getOutput().getText();
- }
-
- log.info("## chunk: {}", chunk);
-
- // 若 chunk 块不为空,则追加到 fullContent 中
- if (chunk != null) {
- fullContent.get().append(chunk);
- }
- })
- .doOnComplete(() -> {
- // 流完成后打印完整回答
- String completeResponse = fullContent.get().toString();
- log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
- })
- .doOnError(error -> {
- // 出错时打印已收集的部分
- String partialResponse = fullContent.get().toString();
- log.error("## Stream 流出现错误,已收集回答如下: {}", partialResponse, error);
- });
- }
-}
diff --git a/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java b/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java
new file mode 100644
index 0000000..100fc92
--- /dev/null
+++ b/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java
@@ -0,0 +1,108 @@
+package com.hanserwei.airobot.advisor;
+
+import com.hanserwei.airobot.domain.dos.ChatMessageDO;
+import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
+import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
+import lombok.extern.slf4j.Slf4j;
+import org.jetbrains.annotations.NotNull;
+import org.springframework.ai.chat.client.ChatClientRequest;
+import org.springframework.ai.chat.client.ChatClientResponse;
+import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
+import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
+import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.transaction.support.TransactionTemplate;
+import reactor.core.publisher.Flux;
+
+import java.time.LocalDateTime;
+import java.util.concurrent.atomic.AtomicReference;
+
+@Slf4j
+public class CustomStreamLoggerAndMessage2DBAdvisor implements StreamAdvisor {
+
+ private final ChatMessageMapper chatMessageMapper;
+ private final AiChatReqVO aiChatReqVO;
+ private final TransactionTemplate transactionTemplate;
+
+ public CustomStreamLoggerAndMessage2DBAdvisor(ChatMessageMapper chatMessageMapper,
+ AiChatReqVO aiChatReqVO,
+ TransactionTemplate transactionTemplate) {
+ this.chatMessageMapper = chatMessageMapper;
+ this.aiChatReqVO = aiChatReqVO;
+ this.transactionTemplate = transactionTemplate;
+ }
+
+ @Override
+ public int getOrder() {
+ return 99; // order 值越小,越先执行
+ }
+
+ @NotNull
+ @Override
+ public String getName() {
+ return this.getClass().getSimpleName();
+ }
+
+ @NotNull
+ @Override
+ public Flux adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
+
+ Flux chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
+
+ // 对话 UUID
+ String chatUuid = aiChatReqVO.getChatId();
+ // 用户消息
+ String userMessage = aiChatReqVO.getMessage();
+
+ // 创建 AI 流式回答聚合容器(线程安全)
+ AtomicReference fullContent = new AtomicReference<>(new StringBuilder());
+
+ // 返回处理后的流
+ return chatClientResponseFlux
+ .doOnNext(response -> {
+ // 逐块收集内容
+ String chunk = null;
+ if (response.chatResponse() != null) {
+ chunk = response.chatResponse().getResult().getOutput().getText();
+ }
+
+ log.info("## chunk: {}", chunk);
+
+ // 若 chunk 块不为空,则追加到 fullContent 中
+ if (chunk != null) {
+ fullContent.get().append(chunk);
+ }
+ })
+ .doOnComplete(() -> {
+ // 流完成后打印完整回答
+ String completeResponse = fullContent.get().toString();
+ log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
+
+ // 开启编程式事务
+ transactionTemplate.execute(status -> {
+ try {
+ // 1. 存储用户消息
+ chatMessageMapper.insert(ChatMessageDO.builder()
+ .chatUuid(chatUuid)
+ .content(userMessage)
+ .role(MessageType.USER.getValue()) // 用户消息
+ .createTime(LocalDateTime.now())
+ .build());
+
+ // 2. 存储 AI 回答
+ chatMessageMapper.insert(ChatMessageDO.builder()
+ .chatUuid(chatUuid)
+ .content(completeResponse)
+ .role(MessageType.ASSISTANT.getValue()) // AI 回答
+ .createTime(LocalDateTime.now())
+ .build());
+
+ return true;
+ } catch (Exception ex) {
+ status.setRollbackOnly(); // 标记事务为回滚
+ log.error("", ex);
+ }
+ return false;
+ });
+ });
+ }
+}
diff --git a/src/main/java/com/hanserwei/airobot/controller/ChatController.java b/src/main/java/com/hanserwei/airobot/controller/ChatController.java
index 178e6b9..43e1eca 100644
--- a/src/main/java/com/hanserwei/airobot/controller/ChatController.java
+++ b/src/main/java/com/hanserwei/airobot/controller/ChatController.java
@@ -4,8 +4,10 @@ import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import com.google.common.collect.Lists;
-import com.hanserwei.airobot.advisor.CustomStreamLoggerAdvisor;
+import com.hanserwei.airobot.advisor.CustomChatMemoryAdvisor;
+import com.hanserwei.airobot.advisor.CustomStreamLoggerAndMessage2DBAdvisor;
import com.hanserwei.airobot.aspect.ApiOperationLog;
+import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
import com.hanserwei.airobot.model.vo.chat.AiResponse;
import com.hanserwei.airobot.model.vo.chat.NewChatReqVO;
@@ -18,6 +20,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
+import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
@@ -34,6 +37,10 @@ public class ChatController {
@Resource
private ChatService chatService;
+ @Resource
+ private ChatMessageMapper chatMessageMapper;
+ @Resource
+ private TransactionTemplate transactionTemplate;
@Value("${spring.ai.dashscope.api-key}")
private String apiKey;
@@ -72,8 +79,10 @@ public class ChatController {
// Advisor 集合
List advisors = Lists.newArrayList();
+ // 添加自定义对话记忆 Advisor(以最新的 50 条消息作为记忆)
+ advisors.add(new CustomChatMemoryAdvisor(chatMessageMapper, aiChatReqVO, 50));
// 添加自定义打印流式对话日志 Advisor
- advisors.add(new CustomStreamLoggerAdvisor());
+ advisors.add(new CustomStreamLoggerAndMessage2DBAdvisor(chatMessageMapper, aiChatReqVO, transactionTemplate));
// 应用 Advisor 集合
chatClientRequestSpec.advisors(advisors);
diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml
index beff518..bd14655 100644
--- a/src/main/resources/application.yml
+++ b/src/main/resources/application.yml
@@ -17,10 +17,6 @@ spring:
maximum-pool-size: 20 # 最大连接池大小
connection-test-query: SELECT 1 # 连接测试查询
validation-timeout: 5000 # 验证连接的有效性
- cassandra:
- contact-points: 127.0.0.1 # Cassandra 集群节点地址(可配置多个,用逗号分隔)
- port: 9042 # 端口号
- local-datacenter: datacenter1 # 必须与集群配置的数据中心名称一致(大小写敏感)
ai:
dashscope:
api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)