diff --git a/pom.xml b/pom.xml index b08430d..ab5d8b5 100644 --- a/pom.xml +++ b/pom.xml @@ -79,6 +79,10 @@ com.baomidou mybatis-plus-jsqlparser + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + diff --git a/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java b/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java index ca0d99f..6c5b481 100644 --- a/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java +++ b/snails-chat/src/main/java/com/hanserwei/chat/config/ChatClientConfiguration.java @@ -3,6 +3,7 @@ package com.hanserwei.chat.config; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; import com.alibaba.cloud.ai.memory.redis.BaseRedisChatMemoryRepository; import com.alibaba.cloud.ai.memory.redis.LettuceRedisChatMemoryRepository; +import com.hanserwei.chat.tools.AiDBTools; import jakarta.annotation.Resource; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; @@ -28,6 +29,8 @@ public class ChatClientConfiguration { @Resource private DashScopeChatModel dashScopeChatModel; + @Resource + private AiDBTools aiDBTools; @Bean public BaseRedisChatMemoryRepository redisChatMemoryRepository() { @@ -51,6 +54,7 @@ public class ChatClientConfiguration { @Bean public ChatClient dashScopeChatClient(ChatMemory chatMemory) { return ChatClient.builder(dashScopeChatModel) + .defaultTools(aiDBTools) .defaultAdvisors(PromptChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor()) .build(); } diff --git a/snails-chat/src/main/java/com/hanserwei/chat/controller/AiChatController.java b/snails-chat/src/main/java/com/hanserwei/chat/controller/AiChatController.java new file mode 100644 index 0000000..1256208 --- /dev/null +++ b/snails-chat/src/main/java/com/hanserwei/chat/controller/AiChatController.java @@ -0,0 +1,36 @@ +package com.hanserwei.chat.controller; + +import com.hanserwei.chat.model.dto.ChatMessageDTO; +import com.hanserwei.chat.model.vo.AIResponse; +import jakarta.annotation.Resource; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; + +@RestController +@RequestMapping("/ai") +public class AiChatController { + + @Resource + private ChatClient dashScopeChatClient; + + @PostMapping(path = "/chat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux chatWithAi(@RequestBody ChatMessageDTO chatMessageDTO) { + + return dashScopeChatClient.prompt() + .user(chatMessageDTO.getMessage()) + .advisors(p -> p.param(ChatMemory.CONVERSATION_ID, chatMessageDTO.getConversionId())) + .stream() + .chatResponse() + .mapNotNull(chatResponse -> AIResponse.builder() + .v(chatResponse.getResult().getOutput().getText()) + .build()); + + } + +} diff --git a/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java b/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java index 6ca919a..32e26c9 100644 --- a/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java +++ b/snails-chat/src/main/java/com/hanserwei/chat/domain/dataobject/User.java @@ -4,12 +4,13 @@ import com.baomidou.mybatisplus.annotation.IdType; import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; +import com.fasterxml.jackson.annotation.JsonFormat; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; -import java.time.LocalDateTime; +import java.time.OffsetDateTime; @Data @Builder @@ -30,7 +31,8 @@ public class User { private Integer age; @TableField(value = "created_at") - private LocalDateTime createdAt; + @JsonFormat(shape = JsonFormat.Shape.STRING, pattern = "yyyy-MM-dd'T'HH:mm:ssXXX") + private OffsetDateTime createdAt; @TableField(value = "is_active") private Boolean isActive; diff --git a/snails-chat/src/main/java/com/hanserwei/chat/model/dto/ChatMessageDTO.java b/snails-chat/src/main/java/com/hanserwei/chat/model/dto/ChatMessageDTO.java new file mode 100644 index 0000000..56090d5 --- /dev/null +++ b/snails-chat/src/main/java/com/hanserwei/chat/model/dto/ChatMessageDTO.java @@ -0,0 +1,17 @@ +package com.hanserwei.chat.model.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +@Builder +public class ChatMessageDTO { + + private String message; + + private Long conversionId; +} \ No newline at end of file diff --git a/snails-chat/src/main/java/com/hanserwei/chat/model/vo/AIResponse.java b/snails-chat/src/main/java/com/hanserwei/chat/model/vo/AIResponse.java new file mode 100644 index 0000000..8823a69 --- /dev/null +++ b/snails-chat/src/main/java/com/hanserwei/chat/model/vo/AIResponse.java @@ -0,0 +1,15 @@ +package com.hanserwei.chat.model.vo; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class AIResponse { + // 流式响应内容 + private String v; +} \ No newline at end of file diff --git a/snails-chat/src/main/java/com/hanserwei/chat/tools/AiDBTools.java b/snails-chat/src/main/java/com/hanserwei/chat/tools/AiDBTools.java new file mode 100644 index 0000000..193640c --- /dev/null +++ b/snails-chat/src/main/java/com/hanserwei/chat/tools/AiDBTools.java @@ -0,0 +1,97 @@ +package com.hanserwei.chat.tools; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.hanserwei.chat.domain.dataobject.User; +import com.hanserwei.chat.service.UserService; +import jakarta.annotation.Resource; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.annotation.ToolParam; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Component +public class AiDBTools { + + @Resource + private UserService userService; + + @Tool(name = "findAll", description = "查询所有用户") + public List findAll() { + return userService.list(); + } + + @Tool(name = "findAllByIdIn", description = "根据id列表查询用户") + public List findAllByIdIn(@ToolParam(description = "用户id列表") List ids) { + return userService.listByIds(ids); + } + + @Tool(name = "findById", description = "根据id查询用户") + public User findById(Long id) { + return userService.getById(id); + } + + @Tool(name = "findByName", description = "根据名称查询用户") + public User findByName(String name) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class) + .eq(User::getName, name); + return userService.getOne(queryWrapper); + } + + @Tool(name = "findByNameLike", description = "根据名称模糊查询用户") + public List findByNameLike(String name) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class) + .like(User::getName, name); + return userService.list(queryWrapper); + } + + @Tool(name = "findByAge", description = "根据年龄查询用户") + public List findByAge(Integer age) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class) + .eq(User::getAge, age); + return userService.list(queryWrapper); + } + + @Tool(name = "findByAgeBetween", description = "根据年龄范围查询用户") + public List findByAgeBetween(Integer start, Integer end) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(User.class) + .between(User::getAge, start, end); + return userService.list(queryWrapper); + } + + // 插入 数据 + @Tool(name = "insert", + description = """ + 插入一个新的用户。 + 需要一个用户对象作为参数, + 该对象必须包含以下字段: + name (String), email (String), 和 age (Integer)。 + """) + public void insert(@ToolParam(description = "用户对象") User user) { + userService.save(user); + } + + @Tool(name = "update", + description = """ + 更新现有用户。需要一个用户对象作为参数,该对象 **必须包含用户 ID**, + 并携带要修改的字段,例如 name (String), email (String), 或 age (Integer)。 + """) + public void update(@ToolParam(description = "用户对象") User user) { + userService.updateById(user); + } + + @Tool(name = "delete", description = "删除用户") + public void delete(Long id) { + userService.removeById(id); + } + + //封禁 + @Tool(name = "ban", description = "根据用户ID封禁用户。") + public void ban(@ToolParam(description = "用户id") Long id) { + userService.update(User.builder() + .isActive(false) + .build(), + new LambdaQueryWrapper<>(User.class) + .eq(User::getId, id)); + } +} diff --git a/snails-chat/src/main/resources/config/application.yml b/snails-chat/src/main/resources/config/application.yml index 86c7b3e..04dbb8f 100644 --- a/snails-chat/src/main/resources/config/application.yml +++ b/snails-chat/src/main/resources/config/application.yml @@ -9,6 +9,9 @@ spring: name: snails-ai banner: location: config/banner.txt + jackson: + serialization: + write-dates-as-timestamps: false data: redis: host: localhost @@ -25,7 +28,7 @@ spring: time-between-eviction-runs: 10000 datasource: driver-class-name: org.postgresql.Driver - url: jdbc:postgresql://localhost:5432/postgres + url: jdbc:postgresql://localhost:5432/postgres?serverTimezone=Asia/Shanghai username: postgres password: postgressql # HikariCP 连接池配置