feat(ai): 实现对话记忆与消息持久化功能
- 移除 Cassandra 相关配置及依赖 - 新增 CustomChatMemoryAdvisor 实现对话记忆管理 -重命名并扩展 CustomStreamLoggerAdvisor 为 CustomStreamLoggerAndMessage2DBAdvisor,增加消息入库逻辑 - 在 ChatController 中集成新的 Advisor 并注入相关依赖 - 使用 TransactionTemplate 管理消息存储事务 -限制记忆消息数量为最新 50 条 - 支持将用户消息与 AI 回答同步写入数据库
This commit is contained in:
5
pom.xml
5
pom.xml
@@ -43,11 +43,6 @@
|
||||
<groupId>com.alibaba.cloud.ai</groupId>
|
||||
<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
|
||||
</dependency>
|
||||
<!-- Cassandra -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-starter-model-chat-memory-repository-cassandra</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.hanserwei.airobot.advisor;
|
||||
|
||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hanserwei.airobot.domain.dos.ChatMessageDO;
|
||||
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
|
||||
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class CustomChatMemoryAdvisor implements StreamAdvisor {
|
||||
|
||||
private final ChatMessageMapper chatMessageMapper;
|
||||
private final AiChatReqVO aiChatReqVO;
|
||||
private final int limit;
|
||||
|
||||
public CustomChatMemoryAdvisor(ChatMessageMapper chatMessageMapper, AiChatReqVO aiChatReqVO, int limit) {
|
||||
this.chatMessageMapper = chatMessageMapper;
|
||||
this.aiChatReqVO = aiChatReqVO;
|
||||
this.limit = limit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 2; // order 值越小,越先执行
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
|
||||
log.info("## 自定义聊天记忆 Advisor...");
|
||||
|
||||
// 对话 UUID
|
||||
String chatUuid = aiChatReqVO.getChatId();
|
||||
|
||||
// 查询数据库拉取最新的聊天消息
|
||||
List<ChatMessageDO> messages = chatMessageMapper.selectList(Wrappers.<ChatMessageDO>lambdaQuery()
|
||||
.eq(ChatMessageDO::getChatUuid, chatUuid) // 查询指定对话 UUID 下的聊天记录
|
||||
.orderByDesc(ChatMessageDO::getCreateTime) // 查询最新的消息
|
||||
.last(String.format("LIMIT %d", limit))); // 仅查询 LIMIT 条
|
||||
|
||||
// 按发布时间升序排列
|
||||
List<ChatMessageDO> sortedMessages = messages.stream()
|
||||
.sorted(Comparator.comparing(ChatMessageDO::getCreateTime)) // 升序排列
|
||||
.toList();
|
||||
|
||||
// 所有消息
|
||||
List<Message> messageList = getMessageList(sortedMessages);
|
||||
|
||||
// 除了记忆消息,还需要添加当前用户消息
|
||||
messageList.addAll(chatClientRequest.prompt().getInstructions());
|
||||
|
||||
// 构建一个新的 ChatClientRequest 请求对象
|
||||
ChatClientRequest processedChatClientRequest = chatClientRequest
|
||||
.mutate()
|
||||
.prompt(chatClientRequest.prompt().mutate().messages(messageList).build())
|
||||
.build();
|
||||
|
||||
return streamAdvisorChain.nextStream(processedChatClientRequest);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static List<Message> getMessageList(List<ChatMessageDO> sortedMessages) {
|
||||
List<Message> messageList = Lists.newArrayList();
|
||||
|
||||
// 将数据库记录转换为对应类型的消息
|
||||
for (ChatMessageDO chatMessageDO : sortedMessages) {
|
||||
// 消息类型
|
||||
String type = chatMessageDO.getRole();
|
||||
if (Objects.equals(type, MessageType.USER.getValue())) { // 用户消息
|
||||
Message userMessage = new UserMessage(chatMessageDO.getContent());
|
||||
messageList.add(userMessage);
|
||||
} else if (Objects.equals(type, MessageType.ASSISTANT.getValue())) { // AI 助手消息
|
||||
Message assistantMessage = new AssistantMessage(chatMessageDO.getContent());
|
||||
messageList.add(assistantMessage);
|
||||
}
|
||||
}
|
||||
return messageList;
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package com.hanserwei.airobot.advisor;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
@Slf4j
|
||||
public class CustomStreamLoggerAdvisor implements StreamAdvisor {
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 99; // order 值越小,越先执行
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
|
||||
|
||||
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
|
||||
|
||||
// 创建 AI 流式回答聚合容器(线程安全)
|
||||
AtomicReference<StringBuilder> fullContent = new AtomicReference<>(new StringBuilder());
|
||||
|
||||
// 返回处理后的流
|
||||
return chatClientResponseFlux
|
||||
.doOnNext(response -> {
|
||||
// 逐块收集内容
|
||||
String chunk = null;
|
||||
if (response.chatResponse() != null) {
|
||||
chunk = response.chatResponse().getResult().getOutput().getText();
|
||||
}
|
||||
|
||||
log.info("## chunk: {}", chunk);
|
||||
|
||||
// 若 chunk 块不为空,则追加到 fullContent 中
|
||||
if (chunk != null) {
|
||||
fullContent.get().append(chunk);
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> {
|
||||
// 流完成后打印完整回答
|
||||
String completeResponse = fullContent.get().toString();
|
||||
log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
|
||||
})
|
||||
.doOnError(error -> {
|
||||
// 出错时打印已收集的部分
|
||||
String partialResponse = fullContent.get().toString();
|
||||
log.error("## Stream 流出现错误,已收集回答如下: {}", partialResponse, error);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package com.hanserwei.airobot.advisor;
|
||||
|
||||
import com.hanserwei.airobot.domain.dos.ChatMessageDO;
|
||||
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
|
||||
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
@Slf4j
|
||||
public class CustomStreamLoggerAndMessage2DBAdvisor implements StreamAdvisor {
|
||||
|
||||
private final ChatMessageMapper chatMessageMapper;
|
||||
private final AiChatReqVO aiChatReqVO;
|
||||
private final TransactionTemplate transactionTemplate;
|
||||
|
||||
public CustomStreamLoggerAndMessage2DBAdvisor(ChatMessageMapper chatMessageMapper,
|
||||
AiChatReqVO aiChatReqVO,
|
||||
TransactionTemplate transactionTemplate) {
|
||||
this.chatMessageMapper = chatMessageMapper;
|
||||
this.aiChatReqVO = aiChatReqVO;
|
||||
this.transactionTemplate = transactionTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 99; // order 值越小,越先执行
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
|
||||
|
||||
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
|
||||
|
||||
// 对话 UUID
|
||||
String chatUuid = aiChatReqVO.getChatId();
|
||||
// 用户消息
|
||||
String userMessage = aiChatReqVO.getMessage();
|
||||
|
||||
// 创建 AI 流式回答聚合容器(线程安全)
|
||||
AtomicReference<StringBuilder> fullContent = new AtomicReference<>(new StringBuilder());
|
||||
|
||||
// 返回处理后的流
|
||||
return chatClientResponseFlux
|
||||
.doOnNext(response -> {
|
||||
// 逐块收集内容
|
||||
String chunk = null;
|
||||
if (response.chatResponse() != null) {
|
||||
chunk = response.chatResponse().getResult().getOutput().getText();
|
||||
}
|
||||
|
||||
log.info("## chunk: {}", chunk);
|
||||
|
||||
// 若 chunk 块不为空,则追加到 fullContent 中
|
||||
if (chunk != null) {
|
||||
fullContent.get().append(chunk);
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> {
|
||||
// 流完成后打印完整回答
|
||||
String completeResponse = fullContent.get().toString();
|
||||
log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
|
||||
|
||||
// 开启编程式事务
|
||||
transactionTemplate.execute(status -> {
|
||||
try {
|
||||
// 1. 存储用户消息
|
||||
chatMessageMapper.insert(ChatMessageDO.builder()
|
||||
.chatUuid(chatUuid)
|
||||
.content(userMessage)
|
||||
.role(MessageType.USER.getValue()) // 用户消息
|
||||
.createTime(LocalDateTime.now())
|
||||
.build());
|
||||
|
||||
// 2. 存储 AI 回答
|
||||
chatMessageMapper.insert(ChatMessageDO.builder()
|
||||
.chatUuid(chatUuid)
|
||||
.content(completeResponse)
|
||||
.role(MessageType.ASSISTANT.getValue()) // AI 回答
|
||||
.createTime(LocalDateTime.now())
|
||||
.build());
|
||||
|
||||
return true;
|
||||
} catch (Exception ex) {
|
||||
status.setRollbackOnly(); // 标记事务为回滚
|
||||
log.error("", ex);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -4,8 +4,10 @@ import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hanserwei.airobot.advisor.CustomStreamLoggerAdvisor;
|
||||
import com.hanserwei.airobot.advisor.CustomChatMemoryAdvisor;
|
||||
import com.hanserwei.airobot.advisor.CustomStreamLoggerAndMessage2DBAdvisor;
|
||||
import com.hanserwei.airobot.aspect.ApiOperationLog;
|
||||
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
|
||||
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
||||
import com.hanserwei.airobot.model.vo.chat.AiResponse;
|
||||
import com.hanserwei.airobot.model.vo.chat.NewChatReqVO;
|
||||
@@ -18,6 +20,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
@@ -34,6 +37,10 @@ public class ChatController {
|
||||
|
||||
@Resource
|
||||
private ChatService chatService;
|
||||
@Resource
|
||||
private ChatMessageMapper chatMessageMapper;
|
||||
@Resource
|
||||
private TransactionTemplate transactionTemplate;
|
||||
|
||||
@Value("${spring.ai.dashscope.api-key}")
|
||||
private String apiKey;
|
||||
@@ -72,8 +79,10 @@ public class ChatController {
|
||||
|
||||
// Advisor 集合
|
||||
List<Advisor> advisors = Lists.newArrayList();
|
||||
// 添加自定义对话记忆 Advisor(以最新的 50 条消息作为记忆)
|
||||
advisors.add(new CustomChatMemoryAdvisor(chatMessageMapper, aiChatReqVO, 50));
|
||||
// 添加自定义打印流式对话日志 Advisor
|
||||
advisors.add(new CustomStreamLoggerAdvisor());
|
||||
advisors.add(new CustomStreamLoggerAndMessage2DBAdvisor(chatMessageMapper, aiChatReqVO, transactionTemplate));
|
||||
|
||||
// 应用 Advisor 集合
|
||||
chatClientRequestSpec.advisors(advisors);
|
||||
|
||||
@@ -17,10 +17,6 @@ spring:
|
||||
maximum-pool-size: 20 # 最大连接池大小
|
||||
connection-test-query: SELECT 1 # 连接测试查询
|
||||
validation-timeout: 5000 # 验证连接的有效性
|
||||
cassandra:
|
||||
contact-points: 127.0.0.1 # Cassandra 集群节点地址(可配置多个,用逗号分隔)
|
||||
port: 9042 # 端口号
|
||||
local-datacenter: datacenter1 # 必须与集群配置的数据中心名称一致(大小写敏感)
|
||||
ai:
|
||||
dashscope:
|
||||
api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)
|
||||
|
||||
Reference in New Issue
Block a user