diff --git a/pom.xml b/pom.xml index 51f2eb1..af2cca0 100644 --- a/pom.xml +++ b/pom.xml @@ -43,11 +43,6 @@ com.alibaba.cloud.ai spring-ai-alibaba-starter-dashscope - - - org.springframework.ai - spring-ai-starter-model-chat-memory-repository-cassandra - org.springframework.boot spring-boot-starter-test diff --git a/src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java b/src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java new file mode 100644 index 0000000..92ec3a5 --- /dev/null +++ b/src/main/java/com/hanserwei/airobot/advisor/CustomChatMemoryAdvisor.java @@ -0,0 +1,100 @@ +package com.hanserwei.airobot.advisor; + +import com.baomidou.mybatisplus.core.toolkit.Wrappers; +import com.google.common.collect.Lists; +import com.hanserwei.airobot.domain.dos.ChatMessageDO; +import com.hanserwei.airobot.domain.mapper.ChatMessageMapper; +import com.hanserwei.airobot.model.vo.chat.AiChatReqVO; +import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.NotNull; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import reactor.core.publisher.Flux; + +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +@Slf4j +public class CustomChatMemoryAdvisor implements StreamAdvisor { + + private final ChatMessageMapper chatMessageMapper; + private final AiChatReqVO aiChatReqVO; + private final int limit; + + public CustomChatMemoryAdvisor(ChatMessageMapper chatMessageMapper, AiChatReqVO aiChatReqVO, int limit) { + this.chatMessageMapper = chatMessageMapper; + this.aiChatReqVO = aiChatReqVO; + this.limit = limit; + } + + @Override + public int getOrder() { + return 2; // order 值越小,越先执行 + } + + @NotNull + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @NotNull + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { + log.info("## 自定义聊天记忆 Advisor..."); + + // 对话 UUID + String chatUuid = aiChatReqVO.getChatId(); + + // 查询数据库拉取最新的聊天消息 + List messages = chatMessageMapper.selectList(Wrappers.lambdaQuery() + .eq(ChatMessageDO::getChatUuid, chatUuid) // 查询指定对话 UUID 下的聊天记录 + .orderByDesc(ChatMessageDO::getCreateTime) // 查询最新的消息 + .last(String.format("LIMIT %d", limit))); // 仅查询 LIMIT 条 + + // 按发布时间升序排列 + List sortedMessages = messages.stream() + .sorted(Comparator.comparing(ChatMessageDO::getCreateTime)) // 升序排列 + .toList(); + + // 所有消息 + List messageList = getMessageList(sortedMessages); + + // 除了记忆消息,还需要添加当前用户消息 + messageList.addAll(chatClientRequest.prompt().getInstructions()); + + // 构建一个新的 ChatClientRequest 请求对象 + ChatClientRequest processedChatClientRequest = chatClientRequest + .mutate() + .prompt(chatClientRequest.prompt().mutate().messages(messageList).build()) + .build(); + + return streamAdvisorChain.nextStream(processedChatClientRequest); + } + + @NotNull + private static List getMessageList(List sortedMessages) { + List messageList = Lists.newArrayList(); + + // 将数据库记录转换为对应类型的消息 + for (ChatMessageDO chatMessageDO : sortedMessages) { + // 消息类型 + String type = chatMessageDO.getRole(); + if (Objects.equals(type, MessageType.USER.getValue())) { // 用户消息 + Message userMessage = new UserMessage(chatMessageDO.getContent()); + messageList.add(userMessage); + } else if (Objects.equals(type, MessageType.ASSISTANT.getValue())) { // AI 助手消息 + Message assistantMessage = new AssistantMessage(chatMessageDO.getContent()); + messageList.add(assistantMessage); + } + } + return messageList; + } +} \ No newline at end of file diff --git a/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java b/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java deleted file mode 100644 index f2777f4..0000000 --- a/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAdvisor.java +++ /dev/null @@ -1,63 +0,0 @@ -package com.hanserwei.airobot.advisor; - -import lombok.extern.slf4j.Slf4j; -import org.jetbrains.annotations.NotNull; -import org.springframework.ai.chat.client.ChatClientRequest; -import org.springframework.ai.chat.client.ChatClientResponse; -import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; -import reactor.core.publisher.Flux; - -import java.util.concurrent.atomic.AtomicReference; - -@Slf4j -public class CustomStreamLoggerAdvisor implements StreamAdvisor { - - @Override - public int getOrder() { - return 99; // order 值越小,越先执行 - } - - @NotNull - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @NotNull - @Override - public Flux adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { - - Flux chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest); - - // 创建 AI 流式回答聚合容器(线程安全) - AtomicReference fullContent = new AtomicReference<>(new StringBuilder()); - - // 返回处理后的流 - return chatClientResponseFlux - .doOnNext(response -> { - // 逐块收集内容 - String chunk = null; - if (response.chatResponse() != null) { - chunk = response.chatResponse().getResult().getOutput().getText(); - } - - log.info("## chunk: {}", chunk); - - // 若 chunk 块不为空,则追加到 fullContent 中 - if (chunk != null) { - fullContent.get().append(chunk); - } - }) - .doOnComplete(() -> { - // 流完成后打印完整回答 - String completeResponse = fullContent.get().toString(); - log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse); - }) - .doOnError(error -> { - // 出错时打印已收集的部分 - String partialResponse = fullContent.get().toString(); - log.error("## Stream 流出现错误,已收集回答如下: {}", partialResponse, error); - }); - } -} diff --git a/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java b/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java new file mode 100644 index 0000000..100fc92 --- /dev/null +++ b/src/main/java/com/hanserwei/airobot/advisor/CustomStreamLoggerAndMessage2DBAdvisor.java @@ -0,0 +1,108 @@ +package com.hanserwei.airobot.advisor; + +import com.hanserwei.airobot.domain.dos.ChatMessageDO; +import com.hanserwei.airobot.domain.mapper.ChatMessageMapper; +import com.hanserwei.airobot.model.vo.chat.AiChatReqVO; +import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.NotNull; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.transaction.support.TransactionTemplate; +import reactor.core.publisher.Flux; + +import java.time.LocalDateTime; +import java.util.concurrent.atomic.AtomicReference; + +@Slf4j +public class CustomStreamLoggerAndMessage2DBAdvisor implements StreamAdvisor { + + private final ChatMessageMapper chatMessageMapper; + private final AiChatReqVO aiChatReqVO; + private final TransactionTemplate transactionTemplate; + + public CustomStreamLoggerAndMessage2DBAdvisor(ChatMessageMapper chatMessageMapper, + AiChatReqVO aiChatReqVO, + TransactionTemplate transactionTemplate) { + this.chatMessageMapper = chatMessageMapper; + this.aiChatReqVO = aiChatReqVO; + this.transactionTemplate = transactionTemplate; + } + + @Override + public int getOrder() { + return 99; // order 值越小,越先执行 + } + + @NotNull + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @NotNull + @Override + public Flux adviseStream(@NotNull ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { + + Flux chatClientResponseFlux = streamAdvisorChain.nextStream(chatClientRequest); + + // 对话 UUID + String chatUuid = aiChatReqVO.getChatId(); + // 用户消息 + String userMessage = aiChatReqVO.getMessage(); + + // 创建 AI 流式回答聚合容器(线程安全) + AtomicReference fullContent = new AtomicReference<>(new StringBuilder()); + + // 返回处理后的流 + return chatClientResponseFlux + .doOnNext(response -> { + // 逐块收集内容 + String chunk = null; + if (response.chatResponse() != null) { + chunk = response.chatResponse().getResult().getOutput().getText(); + } + + log.info("## chunk: {}", chunk); + + // 若 chunk 块不为空,则追加到 fullContent 中 + if (chunk != null) { + fullContent.get().append(chunk); + } + }) + .doOnComplete(() -> { + // 流完成后打印完整回答 + String completeResponse = fullContent.get().toString(); + log.info("\n==== FULL AI RESPONSE ====\n{}\n========================", completeResponse); + + // 开启编程式事务 + transactionTemplate.execute(status -> { + try { + // 1. 存储用户消息 + chatMessageMapper.insert(ChatMessageDO.builder() + .chatUuid(chatUuid) + .content(userMessage) + .role(MessageType.USER.getValue()) // 用户消息 + .createTime(LocalDateTime.now()) + .build()); + + // 2. 存储 AI 回答 + chatMessageMapper.insert(ChatMessageDO.builder() + .chatUuid(chatUuid) + .content(completeResponse) + .role(MessageType.ASSISTANT.getValue()) // AI 回答 + .createTime(LocalDateTime.now()) + .build()); + + return true; + } catch (Exception ex) { + status.setRollbackOnly(); // 标记事务为回滚 + log.error("", ex); + } + return false; + }); + }); + } +} diff --git a/src/main/java/com/hanserwei/airobot/controller/ChatController.java b/src/main/java/com/hanserwei/airobot/controller/ChatController.java index 178e6b9..43e1eca 100644 --- a/src/main/java/com/hanserwei/airobot/controller/ChatController.java +++ b/src/main/java/com/hanserwei/airobot/controller/ChatController.java @@ -4,8 +4,10 @@ import com.alibaba.cloud.ai.dashscope.api.DashScopeApi; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions; import com.google.common.collect.Lists; -import com.hanserwei.airobot.advisor.CustomStreamLoggerAdvisor; +import com.hanserwei.airobot.advisor.CustomChatMemoryAdvisor; +import com.hanserwei.airobot.advisor.CustomStreamLoggerAndMessage2DBAdvisor; import com.hanserwei.airobot.aspect.ApiOperationLog; +import com.hanserwei.airobot.domain.mapper.ChatMessageMapper; import com.hanserwei.airobot.model.vo.chat.AiChatReqVO; import com.hanserwei.airobot.model.vo.chat.AiResponse; import com.hanserwei.airobot.model.vo.chat.NewChatReqVO; @@ -18,6 +20,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.MediaType; +import org.springframework.transaction.support.TransactionTemplate; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -34,6 +37,10 @@ public class ChatController { @Resource private ChatService chatService; + @Resource + private ChatMessageMapper chatMessageMapper; + @Resource + private TransactionTemplate transactionTemplate; @Value("${spring.ai.dashscope.api-key}") private String apiKey; @@ -72,8 +79,10 @@ public class ChatController { // Advisor 集合 List advisors = Lists.newArrayList(); + // 添加自定义对话记忆 Advisor(以最新的 50 条消息作为记忆) + advisors.add(new CustomChatMemoryAdvisor(chatMessageMapper, aiChatReqVO, 50)); // 添加自定义打印流式对话日志 Advisor - advisors.add(new CustomStreamLoggerAdvisor()); + advisors.add(new CustomStreamLoggerAndMessage2DBAdvisor(chatMessageMapper, aiChatReqVO, transactionTemplate)); // 应用 Advisor 集合 chatClientRequestSpec.advisors(advisors); diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index beff518..bd14655 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -17,10 +17,6 @@ spring: maximum-pool-size: 20 # 最大连接池大小 connection-test-query: SELECT 1 # 连接测试查询 validation-timeout: 5000 # 验证连接的有效性 - cassandra: - contact-points: 127.0.0.1 # Cassandra 集群节点地址(可配置多个,用逗号分隔) - port: 9042 # 端口号 - local-datacenter: datacenter1 # 必须与集群配置的数据中心名称一致(大小写敏感) ai: dashscope: api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)