diff --git a/pom.xml b/pom.xml
index 8ee3128..7db8d2d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -57,6 +57,14 @@
jasypt-spring-boot-starter
3.0.5
+
+ com.alibaba.cloud.ai
+ spring-ai-alibaba-starter-memory-redis
+
+
+ org.springframework.boot
+ spring-boot-starter-data-redis
+
diff --git a/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java b/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java
index 0baaeb0..7aeed1f 100644
--- a/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java
+++ b/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java
@@ -1,19 +1,57 @@
package com.hanserwei.snailsai.config;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
+import com.alibaba.cloud.ai.memory.redis.BaseRedisChatMemoryRepository;
+import com.alibaba.cloud.ai.memory.redis.LettuceRedisChatMemoryRepository;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
+import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
+import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.ai.chat.memory.MessageWindowChatMemory;
+import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class ChatClientConfiguration {
+ @Value("${spring.ai.memory.redis.host}")
+ private String redisHost;
+ @Value("${spring.ai.memory.redis.port}")
+ private int redisPort;
+ @Value("${spring.ai.memory.redis.password}")
+ private String redisPassword;
+ @Value("${spring.ai.memory.redis.timeout}")
+ private int redisTimeout;
+
+
@Resource
private DashScopeChatModel dashScopeChatModel;
@Bean
- public ChatClient dashScopeChatClient() {
- return ChatClient.builder(dashScopeChatModel).build();
+ public BaseRedisChatMemoryRepository redisChatMemoryRepository() {
+ // 构建RedissonRedisChatMemoryRepository实例
+ return LettuceRedisChatMemoryRepository.builder()
+ .host(redisHost)
+ .port(redisPort)
+ .password(redisPassword)
+ .timeout(redisTimeout)
+ .build();
+ }
+
+ @Bean
+ public ChatMemory chatMemory(BaseRedisChatMemoryRepository chatMemoryRepository) {
+ return MessageWindowChatMemory
+ .builder()
+ .maxMessages(100000)
+ .chatMemoryRepository(chatMemoryRepository).build();
+ }
+
+ @Bean
+ public ChatClient dashScopeChatClient(ChatMemory chatMemory) {
+ return ChatClient.builder(dashScopeChatModel)
+ .defaultAdvisors(PromptChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor())
+ .build();
}
}
diff --git a/src/main/java/com/hanserwei/snailsai/config/CorsConfig.java b/src/main/java/com/hanserwei/snailsai/config/CorsConfig.java
new file mode 100644
index 0000000..ec9e2e7
--- /dev/null
+++ b/src/main/java/com/hanserwei/snailsai/config/CorsConfig.java
@@ -0,0 +1,19 @@
+package com.hanserwei.snailsai.config;
+
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.servlet.config.annotation.CorsRegistry;
+import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
+
+@Configuration
+public class CorsConfig implements WebMvcConfigurer {
+
+ @Override
+ public void addCorsMappings(CorsRegistry registry) {
+ registry.addMapping("/**") // 匹配所有路径
+ .allowedOriginPatterns("*") // 允许所有域名(生产环境应指定具体域名)
+ .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS") // 允许的请求方法
+ .allowedHeaders("*") // 允许所有请求头
+ .allowCredentials(true) // 允许发送 Cookie
+ .maxAge(3600); // 预检请求的有效期(秒)
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java b/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java
index 5155de4..05640fb 100644
--- a/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java
+++ b/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java
@@ -1,7 +1,9 @@
package com.hanserwei.snailsai.controller;
+import com.hanserwei.snailsai.model.ChatMessageDTO;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
@@ -15,9 +17,10 @@ public class DashScopeController {
private ChatClient dashScopeChatClient;
@PostMapping(value = "/chat",produces = MediaType.TEXT_EVENT_STREAM_VALUE)
- public Flux chat(@RequestBody String userPrompt) {
+ public Flux chat(@RequestBody ChatMessageDTO chatMessageDTO) {
return dashScopeChatClient.prompt()
- .user(userPrompt)
+ .user(chatMessageDTO.getMessage())
+ .advisors(e-> e.param(ChatMemory.CONVERSATION_ID,chatMessageDTO.getConversionId()))
.stream()
.content();
}
diff --git a/src/main/java/com/hanserwei/snailsai/model/ChatMessageDTO.java b/src/main/java/com/hanserwei/snailsai/model/ChatMessageDTO.java
new file mode 100644
index 0000000..923e838
--- /dev/null
+++ b/src/main/java/com/hanserwei/snailsai/model/ChatMessageDTO.java
@@ -0,0 +1,17 @@
+package com.hanserwei.snailsai.model;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+@Data
+@AllArgsConstructor
+@NoArgsConstructor
+@Builder
+public class ChatMessageDTO {
+
+ private String message;
+
+ private Long conversionId;
+}
diff --git a/src/main/resources/config/application.yml b/src/main/resources/config/application.yml
index a44f393..cfcca7f 100644
--- a/src/main/resources/config/application.yml
+++ b/src/main/resources/config/application.yml
@@ -1,4 +1,9 @@
#file: noinspection SpringBootConfigYamlInspection
+server:
+ servlet:
+ encoding:
+ charset: utf-8
+ force: true
spring:
application:
name: snails-ai
@@ -12,6 +17,20 @@ spring:
hibernate.format_sql: true
# 开启 SQL 语法高亮 (重点:让 SQL 醒目)
hibernate.highlight_sql: true
+ data:
+ redis:
+ host: localhost
+ port: 6379
+ password: redis
+ database: 4
+ lettuce:
+ pool:
+ enabled: true
+ max-active: 20
+ max-idle: 10
+ max-wait: 10000
+ min-idle: 10
+ time-between-eviction-runs: 10000
datasource:
driver-class-name: com.mysql.cj.jdbc.Driver
url: jdbc:mysql://127.0.0.1:3306/snails_ai?useUnicode=true&characterEncoding=utf-8&zeroDateTimeBehavior=convertToNull&transformedBitIsBoolean=true&allowMultiQueries=true&useSSL=false&allowPublicKeyRetrieval=true
@@ -24,6 +43,12 @@ spring:
connection-timeout: 5000 # 获取连接超时 5 秒
max-lifetime: 28800000 # 8 小时(确保在数据库连接超时前被回收)
ai:
+ memory:
+ redis:
+ host: localhost
+ port: 6379
+ timeout: 5000
+ password: redis
dashscope:
api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=)
chat:
@@ -35,6 +60,7 @@ logging:
org.hibernate.SQL: debug
# 隐藏掉 Hibernate 冗长的连接池 INFO 信息
org.hibernate.orm.connections.pooling: WARN
+ org.springframework.ai.chat.client.advisor: DEBUG
jasypt:
encryptor:
password: ${jasypt.encryptor.password}