feat(user): 添加用户查询工具和测试数据接口

- 新增 QueryTool 类,提供 findAll 和 findAllByIdIn 方法用于查询用户
- 在 ChatClientConfiguration 中注册 QueryTool 为默认工具
- 创建 TestDataController,提供生成测试用户数据的接口- 新增 UserService 和 UserRepository,实现用户数据的批量插入和查询功能
- 将 ChatMessageDTO 从 model 包移动到 dto 包,优化包结构
-为 UserEntity 添加 createTime 和 updateTime 字段,完善实体类审计信息
- 新增 RedisConfig 配置类,为后续 Redis 功能做准备
This commit is contained in:
2025-10-23 18:08:30 +08:00
parent f8ff5808e5
commit 40c05838f7
9 changed files with 162 additions and 4 deletions

View File

@@ -3,6 +3,7 @@ package com.hanserwei.snailsai.config;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.memory.redis.BaseRedisChatMemoryRepository; import com.alibaba.cloud.ai.memory.redis.BaseRedisChatMemoryRepository;
import com.alibaba.cloud.ai.memory.redis.LettuceRedisChatMemoryRepository; import com.alibaba.cloud.ai.memory.redis.LettuceRedisChatMemoryRepository;
import com.hanserwei.snailsai.tools.QueryTool;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
@@ -28,6 +29,8 @@ public class ChatClientConfiguration {
@Resource @Resource
private DashScopeChatModel dashScopeChatModel; private DashScopeChatModel dashScopeChatModel;
@Resource
private QueryTool queryTool;
@Bean @Bean
public BaseRedisChatMemoryRepository redisChatMemoryRepository() { public BaseRedisChatMemoryRepository redisChatMemoryRepository() {
@@ -51,6 +54,7 @@ public class ChatClientConfiguration {
@Bean @Bean
public ChatClient dashScopeChatClient(ChatMemory chatMemory) { public ChatClient dashScopeChatClient(ChatMemory chatMemory) {
return ChatClient.builder(dashScopeChatModel) return ChatClient.builder(dashScopeChatModel)
.defaultTools(queryTool)
.defaultAdvisors(PromptChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor()) .defaultAdvisors(PromptChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor())
.build(); .build();
} }

View File

@@ -0,0 +1,7 @@
package com.hanserwei.snailsai.config;
import org.springframework.context.annotation.Configuration;
@Configuration
public class RedisConfig {
}

View File

@@ -1,7 +1,7 @@
package com.hanserwei.snailsai.controller; package com.hanserwei.snailsai.controller;
import com.hanserwei.snailsai.model.AIResponse; import com.hanserwei.snailsai.model.AIResponse;
import com.hanserwei.snailsai.model.ChatMessageDTO; import com.hanserwei.snailsai.dto.ChatMessageDTO;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;

View File

@@ -0,0 +1,29 @@
package com.hanserwei.snailsai.controller;
import com.hanserwei.snailsai.entity.UserEntity;
import com.hanserwei.snailsai.service.UserService;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@RestController
public class TestDataController {
private final UserService userService;
public TestDataController(UserService userService) {
this.userService = userService;
}
/**
* POST /api/test/generate-users?count=100
* 插入假数据用于分页测试
*/
@PostMapping("/api/test/generate-users")
public String generateTestData(@RequestParam(defaultValue = "100") int count) {
List<UserEntity> insertedUsers = userService.insertDummyUsers(count);
return String.format("成功插入了 %d 条假数据!", insertedUsers.size());
}
}

View File

@@ -1,4 +1,4 @@
package com.hanserwei.snailsai.model; package com.hanserwei.snailsai.dto;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;

View File

@@ -7,6 +7,7 @@ import lombok.NoArgsConstructor;
import java.io.Serial; import java.io.Serial;
import java.io.Serializable; import java.io.Serializable;
import java.time.LocalDateTime;
@Entity @Entity
@@ -45,9 +46,9 @@ public class UserEntity implements Serializable {
// 5. 补充审计字段 (推荐) // 5. 补充审计字段 (推荐)
@Column(name = "create_time", nullable = false) @Column(name = "create_time", nullable = false)
private java.time.LocalDateTime createTime; private LocalDateTime createTime;
@Column(name = "update_time") @Column(name = "update_time")
private java.time.LocalDateTime updateTime; private LocalDateTime updateTime;
} }

View File

@@ -0,0 +1,21 @@
package com.hanserwei.snailsai.repository;
import com.hanserwei.snailsai.entity.UserEntity;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
public interface UserRepository extends JpaRepository<UserEntity, Long> {
@NotNull List<UserEntity> findAll();
@NotNull List<UserEntity> findAllByIdIn(List<Long> ids);
@NotNull List<UserEntity> findAllByUsernameContaining(String name);
}

View File

@@ -0,0 +1,69 @@
package com.hanserwei.snailsai.service;
import com.hanserwei.snailsai.entity.UserEntity;
import com.hanserwei.snailsai.repository.UserRepository;
import jakarta.transaction.Transactional;
import org.springframework.stereotype.Service;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
@Service
public class UserService {
private final UserRepository userRepository;
// 构造器注入
public UserService(UserRepository userRepository) {
this.userRepository = userRepository;
}
/**
* 批量生成并插入指定数量的假用户数据
* @param count 要插入的用户数量
* @return 插入成功的用户列表
*/
@Transactional // 确保整个批量操作在一个事务中完成
public List<UserEntity> insertDummyUsers(int count) {
if (count <= 0) {
return List.of(); // 返回空列表
}
List<UserEntity> dummyUsers = new ArrayList<>(count);
LocalDateTime now = LocalDateTime.now();
// 确保用户名是唯一的
// 我们可以先获取当前数据库中用户数量,作为生成唯一用户名的起始点
long startId = userRepository.count();
for (int i = 1; i <= count; i++) {
// 使用 startId + i 来保证生成的用户名在多次运行时尽可能不重复
long userIndex = startId + i;
// 注意:密码通常应该是加密后的,这里为了演示使用明文
UserEntity user = new UserEntity(
null, // ID 设为 null让 JPA 自动生成
"testUser_" + userIndex, // 确保用户名唯一
"123456", // 模拟一个加密后的密码,或者使用一个测试用的明文,例如 "password"
(long) (i % 3) + 1, // 随机分配一个 hobbyId (例如 1, 2, 3)
now.plusSeconds(i), // 模拟创建时间略微递增
null // updateTime 初始为 null
);
dummyUsers.add(user);
}
// 使用 JpaRepository 的 saveAll 方法进行批量插入,效率比单个 save 要高
return userRepository.saveAll(dummyUsers);
}
public List<UserEntity> findAll() {
return userRepository.findAll();
}
public List<UserEntity> findAllByIdIn(List<Long> ids) {
return userRepository.findAllByIdIn(ids);
}
}

View File

@@ -0,0 +1,27 @@
package com.hanserwei.snailsai.tools;
import com.hanserwei.snailsai.entity.UserEntity;
import com.hanserwei.snailsai.service.UserService;
import jakarta.annotation.Resource;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
public class QueryTool {
@Resource
private UserService userService;
@Tool(name = "findAll", description = "查询所有用户")
public List<UserEntity> findAll() {
return userService.findAll();
}
@Tool(name = "findAllByIdIn", description = "根据id列表查询用户")
public List<UserEntity> findAllByIdIn(List<Long> ids) {
return userService.findAllByIdIn(ids);
}
}