diff --git a/pom.xml b/pom.xml index 8ee3128..7db8d2d 100644 --- a/pom.xml +++ b/pom.xml @@ -57,6 +57,14 @@ jasypt-spring-boot-starter 3.0.5 + + com.alibaba.cloud.ai + spring-ai-alibaba-starter-memory-redis + + + org.springframework.boot + spring-boot-starter-data-redis + diff --git a/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java b/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java index 0baaeb0..7aeed1f 100644 --- a/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java +++ b/src/main/java/com/hanserwei/snailsai/config/ChatClientConfiguration.java @@ -1,19 +1,57 @@ package com.hanserwei.snailsai.config; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; +import com.alibaba.cloud.ai.memory.redis.BaseRedisChatMemoryRepository; +import com.alibaba.cloud.ai.memory.redis.LettuceRedisChatMemoryRepository; import jakarta.annotation.Resource; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @Configuration public class ChatClientConfiguration { + @Value("${spring.ai.memory.redis.host}") + private String redisHost; + @Value("${spring.ai.memory.redis.port}") + private int redisPort; + @Value("${spring.ai.memory.redis.password}") + private String redisPassword; + @Value("${spring.ai.memory.redis.timeout}") + private int redisTimeout; + + @Resource private DashScopeChatModel dashScopeChatModel; @Bean - public ChatClient dashScopeChatClient() { - return ChatClient.builder(dashScopeChatModel).build(); + public BaseRedisChatMemoryRepository redisChatMemoryRepository() { + // 构建RedissonRedisChatMemoryRepository实例 + return LettuceRedisChatMemoryRepository.builder() + .host(redisHost) + .port(redisPort) + .password(redisPassword) + .timeout(redisTimeout) + .build(); + } + + @Bean + public ChatMemory chatMemory(BaseRedisChatMemoryRepository chatMemoryRepository) { + return MessageWindowChatMemory + .builder() + .maxMessages(100000) + .chatMemoryRepository(chatMemoryRepository).build(); + } + + @Bean + public ChatClient dashScopeChatClient(ChatMemory chatMemory) { + return ChatClient.builder(dashScopeChatModel) + .defaultAdvisors(PromptChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor()) + .build(); } } diff --git a/src/main/java/com/hanserwei/snailsai/config/CorsConfig.java b/src/main/java/com/hanserwei/snailsai/config/CorsConfig.java new file mode 100644 index 0000000..ec9e2e7 --- /dev/null +++ b/src/main/java/com/hanserwei/snailsai/config/CorsConfig.java @@ -0,0 +1,19 @@ +package com.hanserwei.snailsai.config; + +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.CorsRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +@Configuration +public class CorsConfig implements WebMvcConfigurer { + + @Override + public void addCorsMappings(CorsRegistry registry) { + registry.addMapping("/**") // 匹配所有路径 + .allowedOriginPatterns("*") // 允许所有域名(生产环境应指定具体域名) + .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS") // 允许的请求方法 + .allowedHeaders("*") // 允许所有请求头 + .allowCredentials(true) // 允许发送 Cookie + .maxAge(3600); // 预检请求的有效期(秒) + } +} \ No newline at end of file diff --git a/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java b/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java index 5155de4..05640fb 100644 --- a/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java +++ b/src/main/java/com/hanserwei/snailsai/controller/DashScopeController.java @@ -1,7 +1,9 @@ package com.hanserwei.snailsai.controller; +import com.hanserwei.snailsai.model.ChatMessageDTO; import jakarta.annotation.Resource; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; @@ -15,9 +17,10 @@ public class DashScopeController { private ChatClient dashScopeChatClient; @PostMapping(value = "/chat",produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux chat(@RequestBody String userPrompt) { + public Flux chat(@RequestBody ChatMessageDTO chatMessageDTO) { return dashScopeChatClient.prompt() - .user(userPrompt) + .user(chatMessageDTO.getMessage()) + .advisors(e-> e.param(ChatMemory.CONVERSATION_ID,chatMessageDTO.getConversionId())) .stream() .content(); } diff --git a/src/main/java/com/hanserwei/snailsai/model/ChatMessageDTO.java b/src/main/java/com/hanserwei/snailsai/model/ChatMessageDTO.java new file mode 100644 index 0000000..923e838 --- /dev/null +++ b/src/main/java/com/hanserwei/snailsai/model/ChatMessageDTO.java @@ -0,0 +1,17 @@ +package com.hanserwei.snailsai.model; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +@Builder +public class ChatMessageDTO { + + private String message; + + private Long conversionId; +} diff --git a/src/main/resources/config/application.yml b/src/main/resources/config/application.yml index a44f393..cfcca7f 100644 --- a/src/main/resources/config/application.yml +++ b/src/main/resources/config/application.yml @@ -1,4 +1,9 @@ #file: noinspection SpringBootConfigYamlInspection +server: + servlet: + encoding: + charset: utf-8 + force: true spring: application: name: snails-ai @@ -12,6 +17,20 @@ spring: hibernate.format_sql: true # 开启 SQL 语法高亮 (重点:让 SQL 醒目) hibernate.highlight_sql: true + data: + redis: + host: localhost + port: 6379 + password: redis + database: 4 + lettuce: + pool: + enabled: true + max-active: 20 + max-idle: 10 + max-wait: 10000 + min-idle: 10 + time-between-eviction-runs: 10000 datasource: driver-class-name: com.mysql.cj.jdbc.Driver url: jdbc:mysql://127.0.0.1:3306/snails_ai?useUnicode=true&characterEncoding=utf-8&zeroDateTimeBehavior=convertToNull&transformedBitIsBoolean=true&allowMultiQueries=true&useSSL=false&allowPublicKeyRetrieval=true @@ -24,6 +43,12 @@ spring: connection-timeout: 5000 # 获取连接超时 5 秒 max-lifetime: 28800000 # 8 小时(确保在数据库连接超时前被回收) ai: + memory: + redis: + host: localhost + port: 6379 + timeout: 5000 + password: redis dashscope: api-key: ENC(cMgcKZkFllyE88DIbGwLKot9Vg02co+gsmY8L8o4/o3UjhcmqO4lJzFU35Sx0n+qFG8pDL0wBjoWrT8X6BuRw9vNlQhY1LgRWHaF9S1zzyM=) chat: @@ -35,6 +60,7 @@ logging: org.hibernate.SQL: debug # 隐藏掉 Hibernate 冗长的连接池 INFO 信息 org.hibernate.orm.connections.pooling: WARN + org.springframework.ai.chat.client.advisor: DEBUG jasypt: encryptor: password: ${jasypt.encryptor.password}