feat(ai): 实现对话记忆与消息持久化功能

- 移除 Cassandra 相关配置及依赖
- 新增 CustomChatMemoryAdvisor 实现对话记忆管理
-重命名并扩展 CustomStreamLoggerAdvisor 为 CustomStreamLoggerAndMessage2DBAdvisor,增加消息入库逻辑
- 在 ChatController 中集成新的 Advisor 并注入相关依赖
- 使用 TransactionTemplate 管理消息存储事务
-限制记忆消息数量为最新 50 条
- 支持将用户消息与 AI 回答同步写入数据库
This commit is contained in:
2025-11-03 16:31:19 +08:00
parent f3f320f390
commit 59eb69747b
6 changed files with 219 additions and 74 deletions

View File

@@ -43,11 +43,6 @@
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
</dependency>
<!-- Cassandra -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-chat-memory-repository-cassandra</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>

View File

@@ -0,0 +1,100 @@
package com.hanserwei.airobot.advisor;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.google.common.collect.Lists;
import com.hanserwei.airobot.domain.dos.ChatMessageDO;
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import reactor.core.publisher.Flux;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@Slf4j
public class CustomChatMemoryAdvisor implements StreamAdvisor {
private final ChatMessageMapper chatMessageMapper;
private final AiChatReqVO aiChatReqVO;
private final int limit;
public CustomChatMemoryAdvisor(ChatMessageMapper chatMessageMapper, AiChatReqVO aiChatReqVO, int limit) {
this.chatMessageMapper = chatMessageMapper;
this.aiChatReqVO = aiChatReqVO;
this.limit = limit;
}
@Override
public int getOrder() {
return 2; // order 值越小,越先执行
}
@NotNull
@Override
public String getName() {
return this.getClass().getSimpleName();
}
@NotNull
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
log.info("## 自定义聊天记忆 Advisor...");
// 对话 UUID
String chatUuid = aiChatReqVO.getChatId();
// 查询数据库拉取最新的聊天消息
List<ChatMessageDO> messages = chatMessageMapper.selectList(Wrappers.<ChatMessageDO>lambdaQuery()
.eq(ChatMessageDO::getChatUuid, chatUuid) // 查询指定对话 UUID 下的聊天记录
.orderByDesc(ChatMessageDO::getCreateTime) // 查询最新的消息
.last(String.format("LIMIT %d", limit))); // 仅查询 LIMIT 条
// 按发布时间升序排列
List<ChatMessageDO> sortedMessages = messages.stream()
.sorted(Comparator.comparing(ChatMessageDO::getCreateTime)) // 升序排列
.toList();
// 所有消息
List<Message> messageList = getMessageList(sortedMessages);
// 除了记忆消息,还需要添加当前用户消息
messageList.addAll(chatClientRequest.prompt().getInstructions());
// 构建一个新的 ChatClientRequest 请求对象
ChatClientRequest processedChatClientRequest = chatClientRequest
.mutate()
.prompt(chatClientRequest.prompt().mutate().messages(messageList).build())
.build();
return streamAdvisorChain.nextStream(processedChatClientRequest);
}
@NotNull
private static List<Message> getMessageList(List<ChatMessageDO> sortedMessages) {
List<Message> messageList = Lists.newArrayList();
// 将数据库记录转换为对应类型的消息
for (ChatMessageDO chatMessageDO : sortedMessages) {
// 消息类型
String type = chatMessageDO.getRole();
if (Objects.equals(type, MessageType.USER.getValue())) { // 用户消息
Message userMessage = new UserMessage(chatMessageDO.getContent());
messageList.add(userMessage);
} else if (Objects.equals(type, MessageType.ASSISTANT.getValue())) { // AI 助手消息
Message assistantMessage = new AssistantMessage(chatMessageDO.getContent());
messageList.add(assistantMessage);
}
}
return messageList;
}
}

View File

@@ -1,63 +0,0 @@
package com.hanserwei.airobot.advisor;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import reactor.core.publisher.Flux;
import java.util.concurrent.atomic.AtomicReference;
@Slf4j
public class CustomStreamLoggerAdvisor implements StreamAdvisor {
@Override
public int getOrder() {
return 99; // order 值越小,越先执行
}
@NotNull
@Override
public String getName() {
return this.getClass().getSimpleName();
}
@NotNull
@Override
public Flux<ChatClientResponse> adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
// 创建 AI 流式回答聚合容器(线程安全)
AtomicReference<StringBuilder> fullContent = new AtomicReference<>(new StringBuilder());
// 返回处理后的流
return chatClientResponseFlux
.doOnNext(response -> {
// 逐块收集内容
String chunk = null;
if (response.chatResponse() != null) {
chunk = response.chatResponse().getResult().getOutput().getText();
}
log.info("## chunk: {}", chunk);
// 若 chunk 块不为空,则追加到 fullContent 中
if (chunk != null) {
fullContent.get().append(chunk);
}
})
.doOnComplete(() -> {
// 流完成后打印完整回答
String completeResponse = fullContent.get().toString();
log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
})
.doOnError(error -> {
// 出错时打印已收集的部分
String partialResponse = fullContent.get().toString();
log.error("## Stream 流出现错误,已收集回答如下: {}", partialResponse, error);
});
}
}

View File

@@ -0,0 +1,108 @@
package com.hanserwei.airobot.advisor;
import com.hanserwei.airobot.domain.dos.ChatMessageDO;
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.transaction.support.TransactionTemplate;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.util.concurrent.atomic.AtomicReference;
@Slf4j
public class CustomStreamLoggerAndMessage2DBAdvisor implements StreamAdvisor {
private final ChatMessageMapper chatMessageMapper;
private final AiChatReqVO aiChatReqVO;
private final TransactionTemplate transactionTemplate;
public CustomStreamLoggerAndMessage2DBAdvisor(ChatMessageMapper chatMessageMapper,
AiChatReqVO aiChatReqVO,
TransactionTemplate transactionTemplate) {
this.chatMessageMapper = chatMessageMapper;
this.aiChatReqVO = aiChatReqVO;
this.transactionTemplate = transactionTemplate;
}
@Override
public int getOrder() {
return 99; // order 值越小,越先执行
}
@NotNull
@Override
public String getName() {
return this.getClass().getSimpleName();
}
@NotNull
@Override
public Flux<ChatClientResponse> adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
// 对话 UUID
String chatUuid = aiChatReqVO.getChatId();
// 用户消息
String userMessage = aiChatReqVO.getMessage();
// 创建 AI 流式回答聚合容器(线程安全)
AtomicReference<StringBuilder> fullContent = new AtomicReference<>(new StringBuilder());
// 返回处理后的流
return chatClientResponseFlux
.doOnNext(response -> {
// 逐块收集内容
String chunk = null;
if (response.chatResponse() != null) {
chunk = response.chatResponse().getResult().getOutput().getText();
}
log.info("## chunk: {}", chunk);
// 若 chunk 块不为空,则追加到 fullContent 中
if (chunk != null) {
fullContent.get().append(chunk);
}
})
.doOnComplete(() -> {
// 流完成后打印完整回答
String completeResponse = fullContent.get().toString();
log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
// 开启编程式事务
transactionTemplate.execute(status -> {
try {
// 1. 存储用户消息
chatMessageMapper.insert(ChatMessageDO.builder()
.chatUuid(chatUuid)
.content(userMessage)
.role(MessageType.USER.getValue()) // 用户消息
.createTime(LocalDateTime.now())
.build());
// 2. 存储 AI 回答
chatMessageMapper.insert(ChatMessageDO.builder()
.chatUuid(chatUuid)
.content(completeResponse)
.role(MessageType.ASSISTANT.getValue()) // AI 回答
.createTime(LocalDateTime.now())
.build());
return true;
} catch (Exception ex) {
status.setRollbackOnly(); // 标记事务为回滚
log.error("", ex);
}
return false;
});
});
}
}

View File

@@ -4,8 +4,10 @@ import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import com.google.common.collect.Lists;
import com.hanserwei.airobot.advisor.CustomStreamLoggerAdvisor;
import com.hanserwei.airobot.advisor.CustomChatMemoryAdvisor;
import com.hanserwei.airobot.advisor.CustomStreamLoggerAndMessage2DBAdvisor;
import com.hanserwei.airobot.aspect.ApiOperationLog;
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
import com.hanserwei.airobot.model.vo.chat.AiResponse;
import com.hanserwei.airobot.model.vo.chat.NewChatReqVO;
@@ -18,6 +20,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
@@ -34,6 +37,10 @@ public class ChatController {
@Resource
private ChatService chatService;
@Resource
private ChatMessageMapper chatMessageMapper;
@Resource
private TransactionTemplate transactionTemplate;
@Value("${spring.ai.dashscope.api-key}")
private String apiKey;
@@ -72,8 +79,10 @@ public class ChatController {
// Advisor 集合
List<Advisor> advisors = Lists.newArrayList();
// 添加自定义对话记忆 Advisor以最新的 50 条消息作为记忆)
advisors.add(new CustomChatMemoryAdvisor(chatMessageMapper, aiChatReqVO, 50));
// 添加自定义打印流式对话日志 Advisor
advisors.add(new CustomStreamLoggerAdvisor());
advisors.add(new CustomStreamLoggerAndMessage2DBAdvisor(chatMessageMapper, aiChatReqVO, transactionTemplate));
// 应用 Advisor 集合
chatClientRequestSpec.advisors(advisors);

View File

@@ -17,10 +17,6 @@ spring:
maximum-pool-size: 20 # 最大连接池大小
connection-test-query: SELECT 1 # 连接测试查询
validation-timeout: 5000 # 验证连接的有效性
cassandra:
contact-points: 127.0.0.1 # Cassandra 集群节点地址(可配置多个,用逗号分隔)
port: 9042 # 端口号
local-datacenter: datacenter1 # 必须与集群配置的数据中心名称一致(大小写敏感)
ai:
dashscope:
api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)