diff --git a/src/main/java/com/hanserwei/airobot/controller/DashscopeAIController.java b/src/main/java/com/hanserwei/airobot/controller/DashscopeAIController.java index 86984f9..c1d7225 100644 --- a/src/main/java/com/hanserwei/airobot/controller/DashscopeAIController.java +++ b/src/main/java/com/hanserwei/airobot/controller/DashscopeAIController.java @@ -4,21 +4,40 @@ import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; import com.hanserwei.airobot.model.AIResponse; import jakarta.annotation.Resource; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.beans.factory.annotation.Value; import org.springframework.http.MediaType; +import org.springframework.ai.chat.messages.Message; +import org.springframework.util.CollectionUtils; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import reactor.core.publisher.Flux; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; + @RestController @RequestMapping("/v6/ai") public class DashscopeAIController { + // 存储聊天对话 + private final Map> chatMemoryStore = new ConcurrentHashMap<>(); + @Resource private DashScopeChatModel dashScopeChatModel; + @Value("classpath:/prompts/code-assistant.st") + private org.springframework.core.io.Resource templateResource; + /** * 普通对话 * @@ -26,13 +45,38 @@ public class DashscopeAIController { * @return 对话结果 */ @GetMapping("/generate") - public String generate(@RequestParam(value = "message", defaultValue = "你是谁?") String message) { + public String generate(@RequestParam(value = "message", defaultValue = "你是谁?") String message, + @RequestParam(value = "chatId") String chatId) + { + // 提示词模板 + PromptTemplate promptTemplate = new PromptTemplate(templateResource); + // 根据 chatId 获取对话记录 + List messages = chatMemoryStore.get(chatId); + // 若不存在,则初始化一份 + if (CollectionUtils.isEmpty(messages)) { + messages = new ArrayList<>(); + chatMemoryStore.put(chatId, messages); + } + + // 添加 “用户角色消息” 到聊天记录中 + messages.add(new UserMessage(message)); + + // 构建提示词 + Prompt prompt = new Prompt(messages); // 一次性返回结果 ChatClient chatClient = ChatClient.builder(dashScopeChatModel).build(); - return chatClient.prompt() - .user(message) - .call() - .content(); + String responseText = Objects.requireNonNull(chatClient.prompt(prompt) + .call() + .chatResponse()) + .getResult() + .getOutput() + .getText(); + // 添加 “助手角色消息” 到聊天记录中 + if (responseText != null) { + messages.add(new AssistantMessage(responseText)); + } + + return responseText; } /** diff --git a/src/main/java/com/hanserwei/airobot/controller/PromptTemplateController.java b/src/main/java/com/hanserwei/airobot/controller/PromptTemplateController.java new file mode 100644 index 0000000..43997d9 --- /dev/null +++ b/src/main/java/com/hanserwei/airobot/controller/PromptTemplateController.java @@ -0,0 +1,136 @@ +package com.hanserwei.airobot.controller; + +import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel; +import com.hanserwei.airobot.model.AIResponse; +import jakarta.annotation.Resource; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.template.st.StTemplateRenderer; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Map; + +@RestController +@RequestMapping("/v7/ai") +public class PromptTemplateController { + + @Resource + private DashScopeChatModel chatModel; + + /** + * 智能代码生成 + * @param message + * @param lang + * @return + */ + @GetMapping(value = "/generateStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux generateStream(@RequestParam(value = "message") String message, + @RequestParam(value = "lang") String lang) { + // 提示词模板 + String template = """ + 你是一位资深 {lang} 开发工程师。请严格遵循以下要求编写代码: + 1. 功能描述:{description} + 2. 代码需包含详细注释 + 3. 使用业界最佳实践 + """; + + PromptTemplate promptTemplate = new PromptTemplate(template); + + // 填充提示词占位符,转换为 Prompt 提示词对象 + Prompt prompt = promptTemplate.create(Map.of("description", message, "lang", lang)); + + // 流式输出 + return chatModel.stream(prompt) + .mapNotNull(chatResponse -> { + Generation generation = chatResponse.getResult(); + String text = generation.getOutput().getText(); + return AIResponse.builder().v(text).build(); + }); + } + + + /** + * 智能代码生成 2 + * @param message + * @param lang + * @return + */ + @GetMapping(value = "/generateStream2", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux generateStream2(@RequestParam(value = "message") String message, + @RequestParam(value = "lang") String lang) { + // 提示词模板 + PromptTemplate promptTemplate = PromptTemplate.builder() + .renderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()) // 自定义占位符 + .template(""" + 你是一位资深 开发工程师。请严格遵循以下要求编写代码: + 1. 功能描述: + 2. 代码需包含详细注释 + 3. 使用业界最佳实践 + """) + .build(); + + // 填充提示词占位符,转换为 Prompt 提示词对象 + Prompt prompt = promptTemplate.create(Map.of("description", message, "lang", lang)); + + // 流式输出 + return chatModel.stream(prompt) + .mapNotNull(chatResponse -> { + Generation generation = chatResponse.getResult(); + String text = generation.getOutput().getText(); + return AIResponse.builder().v(text).build(); + }); + } + + /** + * 智能代码生成 3 + * @param message + * @param lang + * @return + */ + @GetMapping(value = "/generateStream3", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux generateStream3(@RequestParam(value = "message") String message, + @RequestParam(value = "lang") String lang) { + + // 系统角色提示词模板 + String systemPrompt = """ + 你是一位资深 {lang} 开发工程师, 已经从业数十年,经验非常丰富。 + """; + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt); + // 填充提示词占位符,并转换为 Message 对象 + Message systemMessage = systemPromptTemplate.createMessage(Map.of("lang", lang)); + + // 用户角色提示词模板 + String userPrompt = """ + 请严格遵循以下要求编写代码: + 1. 功能描述:{description} + 2. 代码需包含详细注释 + 3. 使用业界最佳实践 + """; + PromptTemplate promptTemplate = new PromptTemplate(userPrompt); + // 填充提示词占位符,并转换为 Message 对象 + Message userMessage = promptTemplate.createMessage(Map.of("description", message)); + + + // 组合多角色消息,构建提示词 Prompt + Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); + + // 流式输出 + return chatModel.stream(prompt) + .mapNotNull(chatResponse -> { + Generation generation = chatResponse.getResult(); + String text = generation.getOutput().getText(); + return AIResponse.builder().v(text).build(); + }); + } + +} \ No newline at end of file diff --git a/src/main/resources/prompts/code-assistant.st b/src/main/resources/prompts/code-assistant.st new file mode 100644 index 0000000..54018a5 --- /dev/null +++ b/src/main/resources/prompts/code-assistant.st @@ -0,0 +1,4 @@ +你是一位资深 {lang} 开发工程师。请严格遵循以下要求编写代码: +1. 功能描述:{description} +2. 代码需包含详细注释 +3. 使用业界最佳实践