feat(ai): 实现基于模板的智能代码生成功能
- 新增 PromptTemplateController 控制器,支持多种提示词模板方式 - 支持流式输出智能生成的代码内容- 提供系统角色与用户角色组合的提示词构建方式 - 新增 code-assistant.st 模板文件用于代码生成场景 - 扩展 DashscopeAIController,增加对话记忆功能 - 支持通过 chatId 维护多轮对话上下文- 引入 Spring AI 相关依赖以支持提示词模板和消息管理
This commit is contained in:
@@ -4,21 +4,40 @@ import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||
import com.hanserwei.airobot.model.AIResponse;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/v6/ai")
|
||||
public class DashscopeAIController {
|
||||
|
||||
// 存储聊天对话
|
||||
private final Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();
|
||||
|
||||
@Resource
|
||||
private DashScopeChatModel dashScopeChatModel;
|
||||
|
||||
@Value("classpath:/prompts/code-assistant.st")
|
||||
private org.springframework.core.io.Resource templateResource;
|
||||
|
||||
/**
|
||||
* 普通对话
|
||||
*
|
||||
@@ -26,13 +45,38 @@ public class DashscopeAIController {
|
||||
* @return 对话结果
|
||||
*/
|
||||
@GetMapping("/generate")
|
||||
public String generate(@RequestParam(value = "message", defaultValue = "你是谁?") String message) {
|
||||
public String generate(@RequestParam(value = "message", defaultValue = "你是谁?") String message,
|
||||
@RequestParam(value = "chatId") String chatId)
|
||||
{
|
||||
// 提示词模板
|
||||
PromptTemplate promptTemplate = new PromptTemplate(templateResource);
|
||||
// 根据 chatId 获取对话记录
|
||||
List<Message> messages = chatMemoryStore.get(chatId);
|
||||
// 若不存在,则初始化一份
|
||||
if (CollectionUtils.isEmpty(messages)) {
|
||||
messages = new ArrayList<>();
|
||||
chatMemoryStore.put(chatId, messages);
|
||||
}
|
||||
|
||||
// 添加 “用户角色消息” 到聊天记录中
|
||||
messages.add(new UserMessage(message));
|
||||
|
||||
// 构建提示词
|
||||
Prompt prompt = new Prompt(messages);
|
||||
// 一次性返回结果
|
||||
ChatClient chatClient = ChatClient.builder(dashScopeChatModel).build();
|
||||
return chatClient.prompt()
|
||||
.user(message)
|
||||
String responseText = Objects.requireNonNull(chatClient.prompt(prompt)
|
||||
.call()
|
||||
.content();
|
||||
.chatResponse())
|
||||
.getResult()
|
||||
.getOutput()
|
||||
.getText();
|
||||
// 添加 “助手角色消息” 到聊天记录中
|
||||
if (responseText != null) {
|
||||
messages.add(new AssistantMessage(responseText));
|
||||
}
|
||||
|
||||
return responseText;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
package com.hanserwei.airobot.controller;
|
||||
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
|
||||
import com.hanserwei.airobot.model.AIResponse;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/v7/ai")
|
||||
public class PromptTemplateController {
|
||||
|
||||
@Resource
|
||||
private DashScopeChatModel chatModel;
|
||||
|
||||
/**
|
||||
* 智能代码生成
|
||||
* @param message
|
||||
* @param lang
|
||||
* @return
|
||||
*/
|
||||
@GetMapping(value = "/generateStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<AIResponse> generateStream(@RequestParam(value = "message") String message,
|
||||
@RequestParam(value = "lang") String lang) {
|
||||
// 提示词模板
|
||||
String template = """
|
||||
你是一位资深 {lang} 开发工程师。请严格遵循以下要求编写代码:
|
||||
1. 功能描述:{description}
|
||||
2. 代码需包含详细注释
|
||||
3. 使用业界最佳实践
|
||||
""";
|
||||
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template);
|
||||
|
||||
// 填充提示词占位符,转换为 Prompt 提示词对象
|
||||
Prompt prompt = promptTemplate.create(Map.of("description", message, "lang", lang));
|
||||
|
||||
// 流式输出
|
||||
return chatModel.stream(prompt)
|
||||
.mapNotNull(chatResponse -> {
|
||||
Generation generation = chatResponse.getResult();
|
||||
String text = generation.getOutput().getText();
|
||||
return AIResponse.builder().v(text).build();
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 智能代码生成 2
|
||||
* @param message
|
||||
* @param lang
|
||||
* @return
|
||||
*/
|
||||
@GetMapping(value = "/generateStream2", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<AIResponse> generateStream2(@RequestParam(value = "message") String message,
|
||||
@RequestParam(value = "lang") String lang) {
|
||||
// 提示词模板
|
||||
PromptTemplate promptTemplate = PromptTemplate.builder()
|
||||
.renderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()) // 自定义占位符
|
||||
.template("""
|
||||
你是一位资深 <lang> 开发工程师。请严格遵循以下要求编写代码:
|
||||
1. 功能描述:<description>
|
||||
2. 代码需包含详细注释
|
||||
3. 使用业界最佳实践
|
||||
""")
|
||||
.build();
|
||||
|
||||
// 填充提示词占位符,转换为 Prompt 提示词对象
|
||||
Prompt prompt = promptTemplate.create(Map.of("description", message, "lang", lang));
|
||||
|
||||
// 流式输出
|
||||
return chatModel.stream(prompt)
|
||||
.mapNotNull(chatResponse -> {
|
||||
Generation generation = chatResponse.getResult();
|
||||
String text = generation.getOutput().getText();
|
||||
return AIResponse.builder().v(text).build();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 智能代码生成 3
|
||||
* @param message
|
||||
* @param lang
|
||||
* @return
|
||||
*/
|
||||
@GetMapping(value = "/generateStream3", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<AIResponse> generateStream3(@RequestParam(value = "message") String message,
|
||||
@RequestParam(value = "lang") String lang) {
|
||||
|
||||
// 系统角色提示词模板
|
||||
String systemPrompt = """
|
||||
你是一位资深 {lang} 开发工程师, 已经从业数十年,经验非常丰富。
|
||||
""";
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
|
||||
// 填充提示词占位符,并转换为 Message 对象
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("lang", lang));
|
||||
|
||||
// 用户角色提示词模板
|
||||
String userPrompt = """
|
||||
请严格遵循以下要求编写代码:
|
||||
1. 功能描述:{description}
|
||||
2. 代码需包含详细注释
|
||||
3. 使用业界最佳实践
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(userPrompt);
|
||||
// 填充提示词占位符,并转换为 Message 对象
|
||||
Message userMessage = promptTemplate.createMessage(Map.of("description", message));
|
||||
|
||||
|
||||
// 组合多角色消息,构建提示词 Prompt
|
||||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
|
||||
|
||||
// 流式输出
|
||||
return chatModel.stream(prompt)
|
||||
.mapNotNull(chatResponse -> {
|
||||
Generation generation = chatResponse.getResult();
|
||||
String text = generation.getOutput().getText();
|
||||
return AIResponse.builder().v(text).build();
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
4
src/main/resources/prompts/code-assistant.st
Normal file
4
src/main/resources/prompts/code-assistant.st
Normal file
@@ -0,0 +1,4 @@
|
||||
你是一位资深 {lang} 开发工程师。请严格遵循以下要求编写代码:
|
||||
1. 功能描述:{description}
|
||||
2. 代码需包含详细注释
|
||||
3. 使用业界最佳实践
|
||||
Reference in New Issue
Block a user