diff --git a/pom.xml b/pom.xml
index b08430d..ab5d8b5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -79,6 +79,10 @@
com.baomidou
mybatis-plus-jsqlparser
+
+ com.fasterxml.jackson.datatype
+ jackson-datatype-jsr310
+
diff --git a/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java b/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java
index ca0d99f..6c5b481 100644
--- a/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java
+++ b/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java
@@ -3,6 +3,7 @@ package com.hanserwei.chat.config;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.memory.redis.BaseRedisChatMemoryRepository;
import com.alibaba.cloud.ai.memory.redis.LettuceRedisChatMemoryRepository;
+import com.hanserwei.chat.tools.AiDBTools;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
@@ -28,6 +29,8 @@ public class ChatClientConfiguration {
@Resource
private DashScopeChatModel dashScopeChatModel;
+ @Resource
+ private AiDBTools aiDBTools;
@Bean
public BaseRedisChatMemoryRepository redisChatMemoryRepository() {
@@ -51,6 +54,7 @@ public class ChatClientConfiguration {
@Bean
public ChatClient dashScopeChatClient(ChatMemory chatMemory) {
return ChatClient.builder(dashScopeChatModel)
+ .defaultTools(aiDBTools)
.defaultAdvisors(PromptChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor())
.build();
}
diff --git a/snails-chat/src/main/java/com/hanserwei/chat/controller/AiChatController.java b/snails-chat/src/main/java/com/hanserwei/chat/controller/AiChatController.java
new file mode 100644
index 0000000..1256208
--- /dev/null
+++ b/snails-chat/src/main/java/com/hanserwei/chat/controller/AiChatController.java
@@ -0,0 +1,36 @@
+package com.hanserwei.chat.controller;
+
+import com.hanserwei.chat.model.dto.ChatMessageDTO;
+import com.hanserwei.chat.model.vo.AIResponse;
+import jakarta.annotation.Resource;
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.http.MediaType;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+import reactor.core.publisher.Flux;
+
+@RestController
+@RequestMapping("/ai")
+public class AiChatController {
+
+ @Resource
+ private ChatClient dashScopeChatClient;
+
+ @PostMapping(path = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
+ public Flux chatWithAi(@RequestBody ChatMessageDTO chatMessageDTO) {
+
+ return dashScopeChatClient.prompt()
+ .user(chatMessageDTO.getMessage())
+ .advisors(p -> p.param(ChatMemory.CONVERSATION_ID, chatMessageDTO.getConversionId()))
+ .stream()
+ .chatResponse()
+ .mapNotNull(chatResponse -> AIResponse.builder()
+ .v(chatResponse.getResult().getOutput().getText())
+ .build());
+
+ }
+
+}
diff --git a/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java b/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java
index 6ca919a..32e26c9 100644
--- a/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java
+++ b/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java
@@ -4,12 +4,13 @@ import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
+import com.fasterxml.jackson.annotation.JsonFormat;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
-import java.time.LocalDateTime;
+import java.time.OffsetDateTime;
@Data
@Builder
@@ -30,7 +31,8 @@ public class User {
private Integer age;
@TableField(value = "created_at")
- private LocalDateTime createdAt;
+ @JsonFormat(shape = JsonFormat.Shape.STRING, pattern = "yyyy-MM-dd'T'HH:mm:ssXXX")
+ private OffsetDateTime createdAt;
@TableField(value = "is_active")
private Boolean isActive;
diff --git a/snails-chat/src/main/java/com/hanserwei/chat/model/dto/ChatMessageDTO.java b/snails-chat/src/main/java/com/hanserwei/chat/model/dto/ChatMessageDTO.java
new file mode 100644
index 0000000..56090d5
--- /dev/null
+++ b/snails-chat/src/main/java/com/hanserwei/chat/model/dto/ChatMessageDTO.java
@@ -0,0 +1,17 @@
+package com.hanserwei.chat.model.dto;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+@Data
+@AllArgsConstructor
+@NoArgsConstructor
+@Builder
+public class ChatMessageDTO {
+
+ private String message;
+
+ private Long conversionId;
+}
\ No newline at end of file
diff --git a/snails-chat/src/main/java/com/hanserwei/chat/model/vo/AIResponse.java b/snails-chat/src/main/java/com/hanserwei/chat/model/vo/AIResponse.java
new file mode 100644
index 0000000..8823a69
--- /dev/null
+++ b/snails-chat/src/main/java/com/hanserwei/chat/model/vo/AIResponse.java
@@ -0,0 +1,15 @@
+package com.hanserwei.chat.model.vo;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+@Data
+@Builder
+@AllArgsConstructor
+@NoArgsConstructor
+public class AIResponse {
+ // 流式响应内容
+ private String v;
+}
\ No newline at end of file
diff --git a/snails-chat/src/main/java/com/hanserwei/chat/tools/AiDBTools.java b/snails-chat/src/main/java/com/hanserwei/chat/tools/AiDBTools.java
new file mode 100644
index 0000000..193640c
--- /dev/null
+++ b/snails-chat/src/main/java/com/hanserwei/chat/tools/AiDBTools.java
@@ -0,0 +1,97 @@
+package com.hanserwei.chat.tools;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.hanserwei.chat.domain.dataobject.User;
+import com.hanserwei.chat.service.UserService;
+import jakarta.annotation.Resource;
+import org.springframework.ai.tool.annotation.Tool;
+import org.springframework.ai.tool.annotation.ToolParam;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+
+@Component
+public class AiDBTools {
+
+ @Resource
+ private UserService userService;
+
+ @Tool(name = "findAll", description = "查询所有用户")
+ public List findAll() {
+ return userService.list();
+ }
+
+ @Tool(name = "findAllByIdIn", description = "根据id列表查询用户")
+ public List findAllByIdIn(@ToolParam(description = "用户id列表") List ids) {
+ return userService.listByIds(ids);
+ }
+
+ @Tool(name = "findById", description = "根据id查询用户")
+ public User findById(Long id) {
+ return userService.getById(id);
+ }
+
+ @Tool(name = "findByName", description = "根据名称查询用户")
+ public User findByName(String name) {
+ LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class)
+ .eq(User::getName, name);
+ return userService.getOne(queryWrapper);
+ }
+
+ @Tool(name = "findByNameLike", description = "根据名称模糊查询用户")
+ public List findByNameLike(String name) {
+ LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class)
+ .like(User::getName, name);
+ return userService.list(queryWrapper);
+ }
+
+ @Tool(name = "findByAge", description = "根据年龄查询用户")
+ public List findByAge(Integer age) {
+ LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class)
+ .eq(User::getAge, age);
+ return userService.list(queryWrapper);
+ }
+
+ @Tool(name = "findByAgeBetween", description = "根据年龄范围查询用户")
+ public List findByAgeBetween(Integer start, Integer end) {
+ LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class)
+ .between(User::getAge, start, end);
+ return userService.list(queryWrapper);
+ }
+
+ // 插入 数据
+ @Tool(name = "insert",
+ description = """
+ 插入一个新的用户。
+ 需要一个用户对象作为参数,
+ 该对象必须包含以下字段:
+ name (String), email (String), 和 age (Integer)。
+ """)
+ public void insert(@ToolParam(description = "用户对象") User user) {
+ userService.save(user);
+ }
+
+ @Tool(name = "update",
+ description = """
+ 更新现有用户。需要一个用户对象作为参数,该对象 **必须包含用户 ID**,
+ 并携带要修改的字段,例如 name (String), email (String), 或 age (Integer)。
+ """)
+ public void update(@ToolParam(description = "用户对象") User user) {
+ userService.updateById(user);
+ }
+
+ @Tool(name = "delete", description = "删除用户")
+ public void delete(Long id) {
+ userService.removeById(id);
+ }
+
+ //封禁
+ @Tool(name = "ban", description = "根据用户ID封禁用户。")
+ public void ban(@ToolParam(description = "用户id") Long id) {
+ userService.update(User.builder()
+ .isActive(false)
+ .build(),
+ new LambdaQueryWrapper<>(User.class)
+ .eq(User::getId, id));
+ }
+}
diff --git a/snails-chat/src/main/resources/config/application.yml b/snails-chat/src/main/resources/config/application.yml
index 86c7b3e..04dbb8f 100644
--- a/snails-chat/src/main/resources/config/application.yml
+++ b/snails-chat/src/main/resources/config/application.yml
@@ -9,6 +9,9 @@ spring:
name: snails-ai
banner:
location: config/banner.txt
+ jackson:
+ serialization:
+ write-dates-as-timestamps: false
data:
redis:
host: localhost
@@ -25,7 +28,7 @@ spring:
time-between-eviction-runs: 10000
datasource:
driver-class-name: org.postgresql.Driver
- url: jdbc:postgresql://localhost:5432/postgres
+ url: jdbc:postgresql://localhost:5432/postgres?serverTimezone=Asia/Shanghai
username: postgres
password: postgressql
# HikariCP 连接池配置