feat(ai): 实现对话记忆与消息持久化功能
- 移除 Cassandra 相关配置及依赖 - 新增 CustomChatMemoryAdvisor 实现对话记忆管理 -重命名并扩展 CustomStreamLoggerAdvisor 为 CustomStreamLoggerAndMessage2DBAdvisor,增加消息入库逻辑 - 在 ChatController 中集成新的 Advisor 并注入相关依赖 - 使用 TransactionTemplate 管理消息存储事务 -限制记忆消息数量为最新 50 条 - 支持将用户消息与 AI 回答同步写入数据库
This commit is contained in:
5
pom.xml
5
pom.xml
@@ -43,11 +43,6 @@
|
|||||||
<groupId>com.alibaba.cloud.ai</groupId>
|
<groupId>com.alibaba.cloud.ai</groupId>
|
||||||
<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
|
<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<!-- Cassandra -->
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.ai</groupId>
|
|
||||||
<artifactId>spring-ai-starter-model-chat-memory-repository-cassandra</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-test</artifactId>
|
<artifactId>spring-boot-starter-test</artifactId>
|
||||||
|
|||||||
@@ -0,0 +1,100 @@
|
|||||||
|
package com.hanserwei.airobot.advisor;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.hanserwei.airobot.domain.dos.ChatMessageDO;
|
||||||
|
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
|
||||||
|
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||||
|
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||||
|
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||||
|
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||||
|
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||||
|
import org.springframework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.ai.chat.messages.MessageType;
|
||||||
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class CustomChatMemoryAdvisor implements StreamAdvisor {
|
||||||
|
|
||||||
|
private final ChatMessageMapper chatMessageMapper;
|
||||||
|
private final AiChatReqVO aiChatReqVO;
|
||||||
|
private final int limit;
|
||||||
|
|
||||||
|
public CustomChatMemoryAdvisor(ChatMessageMapper chatMessageMapper, AiChatReqVO aiChatReqVO, int limit) {
|
||||||
|
this.chatMessageMapper = chatMessageMapper;
|
||||||
|
this.aiChatReqVO = aiChatReqVO;
|
||||||
|
this.limit = limit;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOrder() {
|
||||||
|
return 2; // order 值越小,越先执行
|
||||||
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return this.getClass().getSimpleName();
|
||||||
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Override
|
||||||
|
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
|
||||||
|
log.info("## 自定义聊天记忆 Advisor...");
|
||||||
|
|
||||||
|
// 对话 UUID
|
||||||
|
String chatUuid = aiChatReqVO.getChatId();
|
||||||
|
|
||||||
|
// 查询数据库拉取最新的聊天消息
|
||||||
|
List<ChatMessageDO> messages = chatMessageMapper.selectList(Wrappers.<ChatMessageDO>lambdaQuery()
|
||||||
|
.eq(ChatMessageDO::getChatUuid, chatUuid) // 查询指定对话 UUID 下的聊天记录
|
||||||
|
.orderByDesc(ChatMessageDO::getCreateTime) // 查询最新的消息
|
||||||
|
.last(String.format("LIMIT %d", limit))); // 仅查询 LIMIT 条
|
||||||
|
|
||||||
|
// 按发布时间升序排列
|
||||||
|
List<ChatMessageDO> sortedMessages = messages.stream()
|
||||||
|
.sorted(Comparator.comparing(ChatMessageDO::getCreateTime)) // 升序排列
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
// 所有消息
|
||||||
|
List<Message> messageList = getMessageList(sortedMessages);
|
||||||
|
|
||||||
|
// 除了记忆消息,还需要添加当前用户消息
|
||||||
|
messageList.addAll(chatClientRequest.prompt().getInstructions());
|
||||||
|
|
||||||
|
// 构建一个新的 ChatClientRequest 请求对象
|
||||||
|
ChatClientRequest processedChatClientRequest = chatClientRequest
|
||||||
|
.mutate()
|
||||||
|
.prompt(chatClientRequest.prompt().mutate().messages(messageList).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return streamAdvisorChain.nextStream(processedChatClientRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
private static List<Message> getMessageList(List<ChatMessageDO> sortedMessages) {
|
||||||
|
List<Message> messageList = Lists.newArrayList();
|
||||||
|
|
||||||
|
// 将数据库记录转换为对应类型的消息
|
||||||
|
for (ChatMessageDO chatMessageDO : sortedMessages) {
|
||||||
|
// 消息类型
|
||||||
|
String type = chatMessageDO.getRole();
|
||||||
|
if (Objects.equals(type, MessageType.USER.getValue())) { // 用户消息
|
||||||
|
Message userMessage = new UserMessage(chatMessageDO.getContent());
|
||||||
|
messageList.add(userMessage);
|
||||||
|
} else if (Objects.equals(type, MessageType.ASSISTANT.getValue())) { // AI 助手消息
|
||||||
|
Message assistantMessage = new AssistantMessage(chatMessageDO.getContent());
|
||||||
|
messageList.add(assistantMessage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return messageList;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
package com.hanserwei.airobot.advisor;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
|
||||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
|
||||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
|
||||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
|
||||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
|
||||||
import reactor.core.publisher.Flux;
|
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class CustomStreamLoggerAdvisor implements StreamAdvisor {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getOrder() {
|
|
||||||
return 99; // order 值越小,越先执行
|
|
||||||
}
|
|
||||||
|
|
||||||
@NotNull
|
|
||||||
@Override
|
|
||||||
public String getName() {
|
|
||||||
return this.getClass().getSimpleName();
|
|
||||||
}
|
|
||||||
|
|
||||||
@NotNull
|
|
||||||
@Override
|
|
||||||
public Flux<ChatClientResponse> adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
|
|
||||||
|
|
||||||
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
|
|
||||||
|
|
||||||
// 创建 AI 流式回答聚合容器(线程安全)
|
|
||||||
AtomicReference<StringBuilder> fullContent = new AtomicReference<>(new StringBuilder());
|
|
||||||
|
|
||||||
// 返回处理后的流
|
|
||||||
return chatClientResponseFlux
|
|
||||||
.doOnNext(response -> {
|
|
||||||
// 逐块收集内容
|
|
||||||
String chunk = null;
|
|
||||||
if (response.chatResponse() != null) {
|
|
||||||
chunk = response.chatResponse().getResult().getOutput().getText();
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info("## chunk: {}", chunk);
|
|
||||||
|
|
||||||
// 若 chunk 块不为空,则追加到 fullContent 中
|
|
||||||
if (chunk != null) {
|
|
||||||
fullContent.get().append(chunk);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.doOnComplete(() -> {
|
|
||||||
// 流完成后打印完整回答
|
|
||||||
String completeResponse = fullContent.get().toString();
|
|
||||||
log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
|
|
||||||
})
|
|
||||||
.doOnError(error -> {
|
|
||||||
// 出错时打印已收集的部分
|
|
||||||
String partialResponse = fullContent.get().toString();
|
|
||||||
log.error("## Stream 流出现错误,已收集回答如下: {}", partialResponse, error);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package com.hanserwei.airobot.advisor;
|
||||||
|
|
||||||
|
import com.hanserwei.airobot.domain.dos.ChatMessageDO;
|
||||||
|
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
|
||||||
|
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||||
|
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||||
|
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||||
|
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||||
|
import org.springframework.ai.chat.messages.MessageType;
|
||||||
|
import org.springframework.transaction.support.TransactionTemplate;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class CustomStreamLoggerAndMessage2DBAdvisor implements StreamAdvisor {
|
||||||
|
|
||||||
|
private final ChatMessageMapper chatMessageMapper;
|
||||||
|
private final AiChatReqVO aiChatReqVO;
|
||||||
|
private final TransactionTemplate transactionTemplate;
|
||||||
|
|
||||||
|
public CustomStreamLoggerAndMessage2DBAdvisor(ChatMessageMapper chatMessageMapper,
|
||||||
|
AiChatReqVO aiChatReqVO,
|
||||||
|
TransactionTemplate transactionTemplate) {
|
||||||
|
this.chatMessageMapper = chatMessageMapper;
|
||||||
|
this.aiChatReqVO = aiChatReqVO;
|
||||||
|
this.transactionTemplate = transactionTemplate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOrder() {
|
||||||
|
return 99; // order 值越小,越先执行
|
||||||
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return this.getClass().getSimpleName();
|
||||||
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Override
|
||||||
|
public Flux<ChatClientResponse> adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {
|
||||||
|
|
||||||
|
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest);
|
||||||
|
|
||||||
|
// 对话 UUID
|
||||||
|
String chatUuid = aiChatReqVO.getChatId();
|
||||||
|
// 用户消息
|
||||||
|
String userMessage = aiChatReqVO.getMessage();
|
||||||
|
|
||||||
|
// 创建 AI 流式回答聚合容器(线程安全)
|
||||||
|
AtomicReference<StringBuilder> fullContent = new AtomicReference<>(new StringBuilder());
|
||||||
|
|
||||||
|
// 返回处理后的流
|
||||||
|
return chatClientResponseFlux
|
||||||
|
.doOnNext(response -> {
|
||||||
|
// 逐块收集内容
|
||||||
|
String chunk = null;
|
||||||
|
if (response.chatResponse() != null) {
|
||||||
|
chunk = response.chatResponse().getResult().getOutput().getText();
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("## chunk: {}", chunk);
|
||||||
|
|
||||||
|
// 若 chunk 块不为空,则追加到 fullContent 中
|
||||||
|
if (chunk != null) {
|
||||||
|
fullContent.get().append(chunk);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.doOnComplete(() -> {
|
||||||
|
// 流完成后打印完整回答
|
||||||
|
String completeResponse = fullContent.get().toString();
|
||||||
|
log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse);
|
||||||
|
|
||||||
|
// 开启编程式事务
|
||||||
|
transactionTemplate.execute(status -> {
|
||||||
|
try {
|
||||||
|
// 1. 存储用户消息
|
||||||
|
chatMessageMapper.insert(ChatMessageDO.builder()
|
||||||
|
.chatUuid(chatUuid)
|
||||||
|
.content(userMessage)
|
||||||
|
.role(MessageType.USER.getValue()) // 用户消息
|
||||||
|
.createTime(LocalDateTime.now())
|
||||||
|
.build());
|
||||||
|
|
||||||
|
// 2. 存储 AI 回答
|
||||||
|
chatMessageMapper.insert(ChatMessageDO.builder()
|
||||||
|
.chatUuid(chatUuid)
|
||||||
|
.content(completeResponse)
|
||||||
|
.role(MessageType.ASSISTANT.getValue()) // AI 回答
|
||||||
|
.createTime(LocalDateTime.now())
|
||||||
|
.build());
|
||||||
|
|
||||||
|
return true;
|
||||||
|
} catch (Exception ex) {
|
||||||
|
status.setRollbackOnly(); // 标记事务为回滚
|
||||||
|
log.error("", ex);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,8 +4,10 @@ import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
|
|||||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.hanserwei.airobot.advisor.CustomStreamLoggerAdvisor;
|
import com.hanserwei.airobot.advisor.CustomChatMemoryAdvisor;
|
||||||
|
import com.hanserwei.airobot.advisor.CustomStreamLoggerAndMessage2DBAdvisor;
|
||||||
import com.hanserwei.airobot.aspect.ApiOperationLog;
|
import com.hanserwei.airobot.aspect.ApiOperationLog;
|
||||||
|
import com.hanserwei.airobot.domain.mapper.ChatMessageMapper;
|
||||||
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
import com.hanserwei.airobot.model.vo.chat.AiChatReqVO;
|
||||||
import com.hanserwei.airobot.model.vo.chat.AiResponse;
|
import com.hanserwei.airobot.model.vo.chat.AiResponse;
|
||||||
import com.hanserwei.airobot.model.vo.chat.NewChatReqVO;
|
import com.hanserwei.airobot.model.vo.chat.NewChatReqVO;
|
||||||
@@ -18,6 +20,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor;
|
|||||||
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.transaction.support.TransactionTemplate;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
@@ -34,6 +37,10 @@ public class ChatController {
|
|||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private ChatService chatService;
|
private ChatService chatService;
|
||||||
|
@Resource
|
||||||
|
private ChatMessageMapper chatMessageMapper;
|
||||||
|
@Resource
|
||||||
|
private TransactionTemplate transactionTemplate;
|
||||||
|
|
||||||
@Value("${spring.ai.dashscope.api-key}")
|
@Value("${spring.ai.dashscope.api-key}")
|
||||||
private String apiKey;
|
private String apiKey;
|
||||||
@@ -72,8 +79,10 @@ public class ChatController {
|
|||||||
|
|
||||||
// Advisor 集合
|
// Advisor 集合
|
||||||
List<Advisor> advisors = Lists.newArrayList();
|
List<Advisor> advisors = Lists.newArrayList();
|
||||||
|
// 添加自定义对话记忆 Advisor(以最新的 50 条消息作为记忆)
|
||||||
|
advisors.add(new CustomChatMemoryAdvisor(chatMessageMapper, aiChatReqVO, 50));
|
||||||
// 添加自定义打印流式对话日志 Advisor
|
// 添加自定义打印流式对话日志 Advisor
|
||||||
advisors.add(new CustomStreamLoggerAdvisor());
|
advisors.add(new CustomStreamLoggerAndMessage2DBAdvisor(chatMessageMapper, aiChatReqVO, transactionTemplate));
|
||||||
|
|
||||||
// 应用 Advisor 集合
|
// 应用 Advisor 集合
|
||||||
chatClientRequestSpec.advisors(advisors);
|
chatClientRequestSpec.advisors(advisors);
|
||||||
|
|||||||
@@ -17,10 +17,6 @@ spring:
|
|||||||
maximum-pool-size: 20 # 最大连接池大小
|
maximum-pool-size: 20 # 最大连接池大小
|
||||||
connection-test-query: SELECT 1 # 连接测试查询
|
connection-test-query: SELECT 1 # 连接测试查询
|
||||||
validation-timeout: 5000 # 验证连接的有效性
|
validation-timeout: 5000 # 验证连接的有效性
|
||||||
cassandra:
|
|
||||||
contact-points: 127.0.0.1 # Cassandra 集群节点地址(可配置多个,用逗号分隔)
|
|
||||||
port: 9042 # 端口号
|
|
||||||
local-datacenter: datacenter1 # 必须与集群配置的数据中心名称一致(大小写敏感)
|
|
||||||
ai:
|
ai:
|
||||||
dashscope:
|
dashscope:
|
||||||
api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)
|
api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)
|
||||||
|
|||||||
Reference in New Issue
Block a user