Extending Spring AI: How to Contribute a New Provider
Opening story: Huang Lei's "customization requirement"

In early 2026, Huang Lei, an architect at a large state-owned bank, ran into a problem.

The bank's AI transformation mandated domestically developed large models (a compliance requirement), but at the time Spring AI had no official support for the privately deployed, finance-specific model the bank used.

There were two options:
- Call the HTTP API directly from business code (and lose the entire Spring AI ecosystem: memory management, vector retrieval, streaming responses, the Advisor mechanism...)
- Implement a Spring AI Provider (a steeper learning curve, but a one-time investment that benefits every project on the team)

Huang Lei chose the second path. He spent three days studying Spring AI's extension interfaces and another two days implementing a complete Provider.

He then shared the Provider with the Spring AI community through the appropriate channels, received plenty of positive feedback, and helped technical teams at other banks along the way.

The experience taught him a lasting lesson: understanding a framework's extension mechanisms is a bigger competitive advantage than merely knowing how to use the framework.

This article walks you through implementing a complete Spring AI Provider, using Tongyi Qianwen (Qwen) as the example, covering ChatModel, EmbeddingModel, streaming responses, Spring Boot AutoConfiguration, and unit tests.
TL;DR
- Provider interface system: ChatModel / EmbeddingModel / SpeechModel are the core SPI
- ChatModel implementation: the full Prompt → ChatResponse conversion
- Streaming support: implement stream(), returning Flux<ChatResponse>
- EmbeddingModel: vectorize text and integrate with RAG systems
- Spring Boot integration: AutoConfiguration + @ConditionalOn* + Properties
- Testing strategy: mocked HTTP + Spring Boot integration tests
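Before diving into the real implementation, here is a minimal, dependency-free sketch of the shape of the ChatModel SPI contract. This is an illustration only: the actual Spring AI interfaces live in `org.springframework.ai.chat.model`, take richer `Prompt`/`ChatResponse` types, and use Reactor's `Flux<ChatResponse>` for streaming; plain `List` is used here to keep the sketch self-contained.

```java
import java.util.Arrays;
import java.util.List;

// Simplified stand-ins for Spring AI's Prompt / ChatResponse / ChatModel.
// Real Spring AI uses Flux<ChatResponse> for stream(); List keeps this runnable alone.
public class ChatModelSketch {

    record Prompt(String userText) {}
    record ChatResponse(String content) {}

    interface ChatModel {
        ChatResponse call(Prompt prompt);         // synchronous call
        List<ChatResponse> stream(Prompt prompt); // streaming (Flux in real Spring AI)
    }

    // A toy provider: echoes the prompt, and "streams" it token by token.
    // A real provider would translate Prompt into an HTTP request instead.
    static class EchoChatModel implements ChatModel {
        @Override
        public ChatResponse call(Prompt prompt) {
            return new ChatResponse("echo: " + prompt.userText());
        }

        @Override
        public List<ChatResponse> stream(Prompt prompt) {
            return Arrays.stream(prompt.userText().split(" "))
                    .map(ChatResponse::new)
                    .toList();
        }
    }

    public static void main(String[] args) {
        ChatModel model = new EchoChatModel();
        System.out.println(model.call(new Prompt("hello world")).content()); // echo: hello world
    }
}
```

Implementing a Provider is essentially filling in these two methods against a concrete vendor API, which is what the rest of the article does for Qwen.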
1. Spring AI Provider Architecture

1.1 The core SPI interfaces

The Spring AI Provider hierarchy:

Model / StreamingModel (top-level generic interfaces)
├── ChatModel (chat model)
│   ├── call(Prompt) → ChatResponse
│   ├── stream(Prompt) → Flux<ChatResponse>
│   └── getDefaultOptions() → ChatOptions
│
├── EmbeddingModel (embedding model)
│   ├── embed(String) → float[]
│   ├── embed(Document) → float[]
│   └── embedForResponse(List<String>) → EmbeddingResponse
│
├── ImageModel (image generation model)
│   └── call(ImagePrompt) → ImageResponse
│
└── SpeechModel (speech model)
    └── call(SpeechPrompt) → SpeechResponse

1.2 The complete file layout of a Provider
spring-ai-qwen/
├── pom.xml
└── src/
    ├── main/java/org/springframework/ai/qwen/
    │   ├── QwenChatModel.java          # Core ChatModel implementation
    │   ├── QwenEmbeddingModel.java     # EmbeddingModel implementation
    │   ├── api/
    │   │   ├── QwenApi.java            # HTTP client wrapper
    │   │   └── dto/                    # Request/response DTOs
    │   └── metadata/
    │       └── QwenChatResponseMetadata.java
    ├── main/resources/META-INF/
    │   └── spring/
    │       └── org.springframework.boot.autoconfigure.AutoConfiguration.imports
    └── test/java/
        └── QwenChatModelTests.java

2. Implementing QwenApi: Wrapping the HTTP Client
2.1 The Tongyi Qianwen (DashScope) API

// QwenApi.java
@Slf4j
public class QwenApi {

    public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";

    private final WebClient webClient;
    // Jackson mapper for parsing streaming chunks (missing from the original listing)
    private final ObjectMapper objectMapper = new ObjectMapper();

    public QwenApi(String apiKey) {
        this(DEFAULT_BASE_URL, apiKey);
    }

    public QwenApi(String baseUrl, String apiKey) {
        this.webClient = WebClient.builder()
                .baseUrl(baseUrl)
                .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey)
                .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
                .codecs(configurer -> configurer.defaultCodecs()
                        .maxInMemorySize(10 * 1024 * 1024))
                .build();
    }

    // Synchronous chat call
    public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) {
        return webClient.post()
                .uri("/services/aigc/text-generation/generation")
                .bodyValue(request)
                .retrieve()
                .onStatus(status -> status.is4xxClientError(), this::handle4xxError)
                .onStatus(status -> status.is5xxServerError(), this::handle5xxError)
                .bodyToMono(ChatCompletionResponse.class)
                .timeout(Duration.ofSeconds(120))
                .block();
    }

    // Streaming chat call (Server-Sent Events)
    public Flux<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
        ChatCompletionRequest streamRequest = request.toBuilder()
                .stream(true)
                .streamOptions(new StreamOptions(true))
                .build();
        return webClient.post()
                .uri("/services/aigc/text-generation/generation")
                .header("X-DashScope-SSE", "enable") // Qwen's header to enable SSE streaming
                .bodyValue(streamRequest)
                .retrieve()
                .bodyToFlux(String.class)
                .filter(line -> line.startsWith("data:"))
                .map(line -> line.substring(5).trim())
                .filter(data -> !data.equals("[DONE]"))
                .flatMap(data -> {
                    try {
                        return Flux.just(objectMapper.readValue(data, ChatCompletionChunk.class));
                    }
                    catch (JsonProcessingException e) {
                        log.warn("Failed to parse streaming chunk: {}", data);
                        return Flux.empty();
                    }
                });
    }

    // Embedding call
    public EmbeddingResponse embedding(EmbeddingRequest request) {
        return webClient.post()
                .uri("/services/embeddings/text-embedding/text-embedding")
                .bodyValue(request)
                .retrieve()
                .bodyToMono(EmbeddingResponse.class)
                .timeout(Duration.ofSeconds(30))
                .block();
    }

    private Mono<? extends Throwable> handle4xxError(ClientResponse response) {
        return response.bodyToMono(String.class)
                .map(body -> new QwenApiException("API request failed: " + response.statusCode() + " - " + body));
    }

    private Mono<? extends Throwable> handle5xxError(ClientResponse response) {
        return response.bodyToMono(String.class)
                .map(body -> new QwenApiException("Server error: " + response.statusCode() + " - " + body));
    }
}

2.2 Request/response DTOs
// ChatCompletionRequest.java
@Data
// toBuilder = true lets streamChatCompletion() copy the request via request.toBuilder()
@Builder(toBuilder = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class ChatCompletionRequest {

    private String model;
    @Builder.Default
    private Input input = new Input();
    @Builder.Default
    private Parameters parameters = new Parameters();
    private Boolean stream;
    private StreamOptions streamOptions;

    @Data
    @Builder
    public static class Input {
        private List<Message> messages;
    }

    @Data
    @Builder
    public static class Parameters {
        private Integer maxTokens;
        private Float temperature;
        private Float topP;
        private Integer topK;
        private List<String> stop;
        @Builder.Default
        private String resultFormat = "message"; // "text" or "message"
        // Tool calling
        private List<Tool> tools;
        private String toolChoice;
    }

    @Data
    @Builder
    public static class Message {
        private String role;            // "system" / "user" / "assistant" / "tool"
        private String content;
        private String name;            // tool name when replying to a tool call
        private List<ToolCall> toolCalls;
        private String toolCallId;
    }
}
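A detail worth calling out: the `@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)` annotation means every camelCase DTO field serializes as snake_case on the wire (`maxTokens` → `max_tokens`, `resultFormat` → `result_format`), so no per-field `@JsonProperty` is needed. A dependency-free sketch of that conversion:

```java
// Illustrates the camelCase -> snake_case mapping that Jackson's
// SnakeCaseStrategy applies to the DTO fields above (sketch, not Jackson's code).
public class SnakeCaseDemo {

    public static String toSnakeCase(String camel) {
        StringBuilder sb = new StringBuilder();
        for (char c : camel.toCharArray()) {
            if (Character.isUpperCase(c)) {
                sb.append('_').append(Character.toLowerCase(c)); // word boundary
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(toSnakeCase("maxTokens"));    // max_tokens
        System.out.println(toSnakeCase("resultFormat")); // result_format
    }
}
```

Getting the naming strategy wrong is one of the most common causes of "my parameters are silently ignored" bugs when wrapping a vendor API.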
// ChatCompletionResponse.java
@Data
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class ChatCompletionResponse {

    private String requestId;
    private Output output;
    private Usage usage;
    private String code;
    private String message;

    @Data
    public static class Output {
        private String text;
        private String finishReason;
        private List<Choice> choices;
    }

    @Data
    public static class Choice {
        private String finishReason;
        private Message message;
    }

    @Data
    public static class Usage {
        private Integer inputTokens;
        private Integer outputTokens;
        private Integer totalTokens;
    }
}

3. Implementing QwenChatModel
3.1 The core ChatModel implementation

// QwenChatModel.java
@Slf4j
public class QwenChatModel implements ChatModel {

    private static final String DEFAULT_MODEL = "qwen-turbo";

    private final QwenApi qwenApi;
    private final QwenChatOptions defaultOptions;
    private final RetryTemplate retryTemplate;
    private final ObservationRegistry observationRegistry;
    // Jackson mapper for serializing multimodal content (missing from the original listing)
    private final ObjectMapper objectMapper = new ObjectMapper();

    // Convenience constructor
    public QwenChatModel(QwenApi qwenApi, QwenChatOptions defaultOptions) {
        this(qwenApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE,
                ObservationRegistry.NOOP);
    }

    public QwenChatModel(QwenApi qwenApi, QwenChatOptions defaultOptions,
                         RetryTemplate retryTemplate,
                         ObservationRegistry observationRegistry) {
        Assert.notNull(qwenApi, "QwenApi must not be null");
        Assert.notNull(defaultOptions, "QwenChatOptions must not be null");
        this.qwenApi = qwenApi;
        this.defaultOptions = defaultOptions;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    /**
     * Synchronous call: Prompt → ChatResponse
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        return this.retryTemplate.execute(ctx -> {
            ChatCompletionRequest request = createRequest(prompt, false);
            ChatCompletionResponse response = this.qwenApi.chatCompletion(request);
            return toChatResponse(response);
        });
    }

    /**
     * Streaming call: Prompt → Flux<ChatResponse>
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        return Flux.deferContextual(contextView -> {
            ChatCompletionRequest request = createRequest(prompt, true);
            Flux<ChatCompletionChunk> chunkFlux =
                    this.qwenApi.streamChatCompletion(request);
            // toChatResponseFromChunk (not shown) mirrors toChatResponse for chunks
            return chunkFlux
                    .map(chunk -> toChatResponseFromChunk(chunk))
                    .filter(response -> response != null);
        });
    }

    /**
     * Returns the default options
     */
    @Override
    public QwenChatOptions getDefaultOptions() {
        return this.defaultOptions;
    }

    // Convert a Spring AI Prompt into a Qwen request
    private ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        // Merge default options with request-level options
        QwenChatOptions runtimeOptions = mergeOptions(prompt.getOptions());
        // Convert the message list
        List<ChatCompletionRequest.Message> messages = prompt.getInstructions()
                .stream()
                .map(this::toQwenMessage)
                .toList();
        // Convert tools, if any (toQwenTool is analogous to toQwenMessage, omitted here)
        List<ChatCompletionRequest.Tool> tools = null;
        if (runtimeOptions.getTools() != null && !runtimeOptions.getTools().isEmpty()) {
            tools = runtimeOptions.getTools().stream()
                    .map(this::toQwenTool)
                    .toList();
        }
        return ChatCompletionRequest.builder()
                .model(runtimeOptions.getModel() != null ?
                        runtimeOptions.getModel() : DEFAULT_MODEL)
                .input(ChatCompletionRequest.Input.builder()
                        .messages(messages)
                        .build())
                .parameters(ChatCompletionRequest.Parameters.builder()
                        .maxTokens(runtimeOptions.getMaxTokens())
                        .temperature(runtimeOptions.getTemperature() != null ?
                                runtimeOptions.getTemperature().floatValue() : null)
                        .topP(runtimeOptions.getTopP() != null ?
                                runtimeOptions.getTopP().floatValue() : null)
                        .stop(runtimeOptions.getStopSequences())
                        .tools(tools)
                        .resultFormat("message")
                        .build())
                .stream(stream)
                .build();
    }

    // Map message roles and content
    private ChatCompletionRequest.Message toQwenMessage(Message message) {
        String role = switch (message.getMessageType()) {
            case SYSTEM -> "system";
            case USER -> "user";
            case ASSISTANT -> "assistant";
            case TOOL -> "tool";
        };
        // Handle multimodal content (images etc.)
        if (message instanceof UserMessage userMsg &&
                !CollectionUtils.isEmpty(userMsg.getMedia())) {
            List<Map<String, Object>> multimodalContent = new ArrayList<>();
            multimodalContent.add(Map.of("text", message.getContent()));
            for (Media media : userMsg.getMedia()) {
                multimodalContent.add(Map.of(
                        "image", media.getData().toString()
                ));
            }
            try {
                return ChatCompletionRequest.Message.builder()
                        .role(role)
                        .content(objectMapper.writeValueAsString(multimodalContent))
                        .build();
            }
            catch (JsonProcessingException e) {
                // writeValueAsString throws a checked exception; rethrow unchecked
                throw new QwenApiException("Failed to serialize multimodal content", e);
            }
        }
        return ChatCompletionRequest.Message.builder()
                .role(role)
                .content(message.getContent())
                .build();
    }

    // Convert a Qwen response into a Spring AI ChatResponse
    private ChatResponse toChatResponse(ChatCompletionResponse response) {
        if (response == null) {
            throw new RuntimeException("Empty response from Qwen API");
        }
        List<Generation> generations = new ArrayList<>();
        if (response.getOutput().getChoices() != null) {
            for (ChatCompletionResponse.Choice choice : response.getOutput().getChoices()) {
                AssistantMessage assistantMessage = new AssistantMessage(
                        choice.getMessage().getContent(),
                        Map.of(),
                        toToolCalls(choice.getMessage().getToolCalls()) // helper omitted here
                );
                ChatGenerationMetadata metadata = ChatGenerationMetadata.from(
                        choice.getFinishReason(), null);
                generations.add(new Generation(assistantMessage, metadata));
            }
        }
        ChatResponseMetadata metadata = QwenChatResponseMetadata.from(response);
        return new ChatResponse(generations, metadata);
    }

    // Merge options (runtime takes precedence over defaults)
    private QwenChatOptions mergeOptions(ChatOptions runtimeOptions) {
        QwenChatOptions merged = QwenChatOptions.builder()
                .withModel(this.defaultOptions.getModel())
                .withMaxTokens(this.defaultOptions.getMaxTokens())
                .withTemperature(this.defaultOptions.getTemperature())
                .build();
        if (runtimeOptions instanceof QwenChatOptions qwenOptions) {
            if (qwenOptions.getModel() != null) merged.setModel(qwenOptions.getModel());
            if (qwenOptions.getMaxTokens() != null) merged.setMaxTokens(qwenOptions.getMaxTokens());
            if (qwenOptions.getTemperature() != null) merged.setTemperature(qwenOptions.getTemperature());
        }
        return merged;
    }
}

4. Implementing QwenChatOptions
// QwenChatOptions.java
@Data
// setterPrefix = "with" makes Lombok generate withModel(), withMaxTokens(), ...
// matching the builder calls used throughout this article
@Builder(setterPrefix = "with")
@JsonInclude(JsonInclude.Include.NON_NULL)
public class QwenChatOptions implements ChatOptions {

    // Models supported by Tongyi Qianwen
    public static final String MODEL_QWEN_TURBO = "qwen-turbo";
    public static final String MODEL_QWEN_PLUS = "qwen-plus";
    public static final String MODEL_QWEN_MAX = "qwen-max";
    public static final String MODEL_QWEN_LONG = "qwen-long";
    public static final String MODEL_QWEN_VL_PLUS = "qwen-vl-plus"; // vision model

    private String model;
    @JsonProperty("max_tokens")
    private Integer maxTokens;
    private Double temperature;
    @JsonProperty("top_p")
    private Double topP;
    private Integer topK;
    private List<String> stopSequences;
    // Tool calling
    private List<FunctionCallback> tools;
    private String toolChoice;
    // Qwen-specific parameters
    private Boolean enableSearch;      // enable web search
    private Boolean incrementalOutput; // incremental mode for streaming output

    @Override
    public Double getFrequencyPenalty() { return null; }

    @Override
    public Double getPresencePenalty() { return null; }

    @Override
    public List<String> getStopSequences() { return stopSequences; }

    // Factory method
    public static QwenChatOptions fromOptions(ChatOptions options) {
        if (options instanceof QwenChatOptions qo) return qo;
        return QwenChatOptions.builder()
                .withModel(options.getModel())
                .withMaxTokens(options.getMaxTokens())
                .withTemperature(options.getTemperature())
                .build();
    }
}

5. Implementing the EmbeddingModel
// QwenEmbeddingModel.java
@Slf4j
public class QwenEmbeddingModel extends AbstractEmbeddingModel {

    public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-v3";

    private final QwenApi qwenApi;
    private final QwenEmbeddingOptions defaultOptions;
    private final RetryTemplate retryTemplate;

    // Constructor (missing from the original listing)
    public QwenEmbeddingModel(QwenApi qwenApi, QwenEmbeddingOptions defaultOptions,
                              RetryTemplate retryTemplate) {
        this.qwenApi = qwenApi;
        this.defaultOptions = defaultOptions;
        this.retryTemplate = retryTemplate;
    }

    @Override
    public EmbeddingResponse call(EmbeddingRequest request) {
        return this.retryTemplate.execute(ctx -> {
            QwenEmbeddingOptions options = mergeOptions(request.getOptions());
            // Build the Qwen embedding request from the input texts
            // (EmbeddingRequest.getInstructions() already returns List<String>)
            QwenApi.EmbeddingRequest qwenRequest = new QwenApi.EmbeddingRequest(
                    options.getModel(),
                    new QwenApi.EmbeddingInput(request.getInstructions()),
                    new QwenApi.EmbeddingParameters("float") // output format
            );
            QwenApi.EmbeddingResponse qwenResponse = this.qwenApi.embedding(qwenRequest);
            // Convert the response
            List<Embedding> embeddings = qwenResponse.getOutput().getEmbeddings()
                    .stream()
                    .map(e -> new Embedding(
                            toFloatArray(e.getEmbedding()),
                            e.getTextIndex()
                    ))
                    .toList();
            EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(
                    qwenResponse.getModel(),
                    new DefaultUsage(
                            qwenResponse.getUsage().getInputTokens(),
                            0
                    )
            );
            return new EmbeddingResponse(embeddings, metadata);
        });
    }

    // Required by the EmbeddingModel contract
    @Override
    public float[] embed(Document document) {
        return embed(document.getContent());
    }

    @Override
    public int dimensions() {
        // text-embedding-v3 defaults to 1024 dimensions (configurable)
        return defaultOptions.getDimension() != null ?
                defaultOptions.getDimension() : 1024;
    }

    private float[] toFloatArray(List<Double> doubles) {
        float[] floats = new float[doubles.size()];
        for (int i = 0; i < doubles.size(); i++) {
            floats[i] = doubles.get(i).floatValue();
        }
        return floats;
    }
}

6. Spring Boot AutoConfiguration
6.1 The properties classes

// QwenConnectionProperties.java
@ConfigurationProperties(prefix = "spring.ai.qwen")
@Data
public class QwenConnectionProperties {

    public static final String DEFAULT_BASE_URL =
            "https://dashscope.aliyuncs.com/api/v1";

    // Connection settings; chat and embedding options are bound separately
    // by their own @ConfigurationProperties classes below
    private String apiKey;
    private String baseUrl = DEFAULT_BASE_URL;
}

// QwenChatProperties.java
@Data
@ConfigurationProperties(prefix = QwenChatProperties.CONFIG_PREFIX)
public class QwenChatProperties {

    public static final String CONFIG_PREFIX = "spring.ai.qwen.chat";

    @NestedConfigurationProperty
    private QwenChatOptions options = QwenChatOptions.builder()
            .withModel(QwenChatOptions.MODEL_QWEN_TURBO)
            .withMaxTokens(1024)
            .withTemperature(0.7)
            .build();

    private boolean enabled = true;
}

6.2 The AutoConfiguration classes
// QwenAutoConfiguration.java
@AutoConfiguration
@ConditionalOnClass(QwenApi.class)
@EnableConfigurationProperties({QwenConnectionProperties.class,
        QwenChatProperties.class,
        QwenEmbeddingProperties.class})
public class QwenAutoConfiguration {

    // QwenChatAutoConfiguration and QwenEmbeddingAutoConfiguration are registered
    // through the AutoConfiguration.imports file (section 6.3), so no @Import is needed

    @Bean
    @ConditionalOnMissingBean
    public QwenApi qwenApi(QwenConnectionProperties connectionProperties) {
        String apiKey = connectionProperties.getApiKey();
        Assert.hasText(apiKey,
                "Qwen API key must not be empty. Set the spring.ai.qwen.api-key property.");
        String baseUrl = StringUtils.hasText(connectionProperties.getBaseUrl()) ?
                connectionProperties.getBaseUrl() : QwenConnectionProperties.DEFAULT_BASE_URL;
        return new QwenApi(baseUrl, apiKey);
    }
}
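The enabled/disabled switch on these autoconfiguration classes relies on `@ConditionalOnProperty` semantics: with `matchIfMissing = true`, an absent property counts as a match, so the chat model is on by default and users opt out explicitly. A plain-Java illustration of that evaluation rule (a sketch, not Spring's actual condition code):

```java
import java.util.Map;

// Mimics @ConditionalOnProperty(name = "...", havingValue = "true", matchIfMissing = true):
// the condition matches when the property equals havingValue, or when the property
// is absent and matchIfMissing is true. Illustration only.
public class ConditionalOnPropertyDemo {

    public static boolean matches(Map<String, String> props, String name,
                                  String havingValue, boolean matchIfMissing) {
        String value = props.get(name);
        if (value == null) {
            return matchIfMissing; // property absent
        }
        return value.equalsIgnoreCase(havingValue);
    }

    public static void main(String[] args) {
        String key = "spring.ai.qwen.chat.enabled";
        System.out.println(matches(Map.of(), key, "true", true));             // absent -> enabled
        System.out.println(matches(Map.of(key, "false"), key, "true", true)); // explicitly disabled
    }
}
```

This "on unless explicitly disabled" default is what lets a user turn off only the chat model (`spring.ai.qwen.chat.enabled=false`) while keeping the embedding model active.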
// QwenChatAutoConfiguration.java
@AutoConfiguration(after = QwenAutoConfiguration.class)
@ConditionalOnProperty(prefix = QwenChatProperties.CONFIG_PREFIX,
        name = "enabled", havingValue = "true",
        matchIfMissing = true)
public class QwenChatAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean
    public QwenChatModel qwenChatModel(
            QwenApi qwenApi,
            QwenChatProperties chatProperties,
            ObjectProvider<RetryTemplate> retryTemplateProvider,
            ObjectProvider<ObservationRegistry> observationRegistryProvider) {
        RetryTemplate retryTemplate = retryTemplateProvider
                .getIfAvailable(() -> RetryUtils.DEFAULT_RETRY_TEMPLATE);
        ObservationRegistry registry = observationRegistryProvider
                .getIfUnique(() -> ObservationRegistry.NOOP);
        return new QwenChatModel(
                qwenApi,
                chatProperties.getOptions(),
                retryTemplate,
                registry
        );
    }
}

6.3 Registering the AutoConfiguration

# src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
org.springframework.ai.qwen.autoconfigure.QwenAutoConfiguration
org.springframework.ai.qwen.autoconfigure.QwenChatAutoConfiguration
org.springframework.ai.qwen.autoconfigure.QwenEmbeddingAutoConfiguration

7. Complete Unit Tests
7.1 Mock-based tests

// QwenChatModelTests.java
@ExtendWith(MockitoExtension.class)
class QwenChatModelTests {

    @Mock
    private QwenApi qwenApi;

    private QwenChatModel chatModel;

    @BeforeEach
    void setup() {
        chatModel = new QwenChatModel(qwenApi, QwenChatOptions.builder()
                .withModel(QwenChatOptions.MODEL_QWEN_TURBO)
                .withMaxTokens(1024)
                .build());
    }

    @Test
    @DisplayName("Normal call returns the expected ChatResponse")
    void testNormalCall() {
        // Given
        ChatCompletionResponse mockResponse = buildMockResponse(
                "Hello, I am Qwen!", "stop");
        when(qwenApi.chatCompletion(any())).thenReturn(mockResponse);
        // When
        ChatResponse response = chatModel.call(new Prompt("Hello"));
        // Then
        assertThat(response).isNotNull();
        assertThat(response.getResult().getOutput().getContent())
                .isEqualTo("Hello, I am Qwen!");
        // Verify the request parameters
        ArgumentCaptor<ChatCompletionRequest> requestCaptor =
                ArgumentCaptor.forClass(ChatCompletionRequest.class);
        verify(qwenApi).chatCompletion(requestCaptor.capture());
        ChatCompletionRequest capturedRequest = requestCaptor.getValue();
        assertThat(capturedRequest.getModel()).isEqualTo(QwenChatOptions.MODEL_QWEN_TURBO);
        assertThat(capturedRequest.getInput().getMessages()).hasSize(1);
        assertThat(capturedRequest.getInput().getMessages().get(0).getContent())
                .isEqualTo("Hello");
    }

    @Test
    @DisplayName("Multi-turn conversation: system prompt and history are passed correctly")
    void testMultiTurnConversation() {
        // Given
        ChatCompletionResponse mockResponse = buildMockResponse("OK", "stop");
        when(qwenApi.chatCompletion(any())).thenReturn(mockResponse);
        // Build a multi-turn conversation
        List<Message> messages = List.of(
                new SystemMessage("You are a Java expert"),
                new UserMessage("What is Spring AI?"),
                new AssistantMessage("Spring AI is..."),
                new UserMessage("Which models does it support?")
        );
        // When
        chatModel.call(new Prompt(messages));
        // Then
        ArgumentCaptor<ChatCompletionRequest> requestCaptor =
                ArgumentCaptor.forClass(ChatCompletionRequest.class);
        verify(qwenApi).chatCompletion(requestCaptor.capture());
        List<ChatCompletionRequest.Message> sentMessages =
                requestCaptor.getValue().getInput().getMessages();
        assertThat(sentMessages).hasSize(4);
        assertThat(sentMessages.get(0).getRole()).isEqualTo("system");
        assertThat(sentMessages.get(1).getRole()).isEqualTo("user");
        assertThat(sentMessages.get(2).getRole()).isEqualTo("assistant");
        assertThat(sentMessages.get(3).getRole()).isEqualTo("user");
    }

    @Test
    @DisplayName("Runtime options take precedence over default options")
    void testRuntimeOptionsOverride() {
        when(qwenApi.chatCompletion(any()))
                .thenReturn(buildMockResponse("answer", "stop"));
        // Override the default model with runtime options
        chatModel.call(new Prompt("question",
                QwenChatOptions.builder()
                        .withModel(QwenChatOptions.MODEL_QWEN_MAX)
                        .withTemperature(0.2)
                        .build()));
        ArgumentCaptor<ChatCompletionRequest> requestCaptor =
                ArgumentCaptor.forClass(ChatCompletionRequest.class);
        verify(qwenApi).chatCompletion(requestCaptor.capture());
        // Verify the runtime model was used
        assertThat(requestCaptor.getValue().getModel())
                .isEqualTo(QwenChatOptions.MODEL_QWEN_MAX);
        assertThat(requestCaptor.getValue().getParameters().getTemperature())
                .isEqualTo(0.2f);
    }

    @Test
    @DisplayName("API errors are propagated as exceptions")
    void testApiException() {
        when(qwenApi.chatCompletion(any()))
                .thenThrow(new QwenApiException("API call failed"));
        assertThatThrownBy(() -> chatModel.call(new Prompt("question")))
                .isInstanceOf(QwenApiException.class)
                .hasMessageContaining("API call failed");
    }

    @Test
    @DisplayName("Streaming call returns a Flux")
    void testStreamCall() {
        // Given
        List<ChatCompletionChunk> chunks = List.of(
                buildChunk("Hello", null),
                buildChunk("!", "stop")
        );
        when(qwenApi.streamChatCompletion(any())).thenReturn(Flux.fromIterable(chunks));
        // When
        List<String> tokens = new ArrayList<>();
        chatModel.stream(new Prompt("question"))
                .map(response -> response.getResult().getOutput().getContent())
                .doOnNext(tokens::add)
                .blockLast();
        // Then
        assertThat(tokens).containsExactly("Hello", "!");
    }

    // Helper: build a mock response (buildChunk, used above, is analogous and omitted)
    private ChatCompletionResponse buildMockResponse(String content, String finishReason) {
        ChatCompletionResponse response = new ChatCompletionResponse();
        ChatCompletionResponse.Output output = new ChatCompletionResponse.Output();
        ChatCompletionResponse.Choice choice = new ChatCompletionResponse.Choice();
        ChatCompletionRequest.Message message = new ChatCompletionRequest.Message();
        message.setContent(content);
        message.setRole("assistant");
        choice.setMessage(message);
        choice.setFinishReason(finishReason);
        output.setChoices(List.of(choice));
        response.setOutput(output);
        ChatCompletionResponse.Usage usage = new ChatCompletionResponse.Usage();
        usage.setInputTokens(10);
        usage.setOutputTokens(5);
        usage.setTotalTokens(15);
        response.setUsage(usage);
        response.setRequestId("test-request-id");
        return response;
    }
}

7.2 Spring Boot integration tests
// QwenChatModelIntegrationTests.java
@SpringBootTest
@ActiveProfiles("integration-test")
@Disabled("Requires a real API key to run")
class QwenChatModelIntegrationTests {

    @Autowired
    private ChatClient chatClient;

    @Test
    void testBasicChat() {
        String response = chatClient.prompt()
                .user("Describe Java in one sentence")
                .call()
                .content();
        assertThat(response).isNotBlank();
        System.out.println("Response: " + response);
    }

    @Test
    void testStreamingChat() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        StringBuilder sb = new StringBuilder();
        chatClient.prompt()
                .user("Count to 5")
                .stream()
                .chatResponse()
                .subscribe(
                        response -> sb.append(response.getResult().getOutput().getContent()),
                        error -> latch.countDown(),
                        latch::countDown
                );
        latch.await(30, TimeUnit.SECONDS);
        System.out.println("Streaming response: " + sb.toString());
        assertThat(sb.toString()).isNotBlank();
    }
}

8. Using Your New Provider
8.1 application.yml configuration

spring:
  ai:
    qwen:
      api-key: ${QWEN_API_KEY}
      base-url: https://dashscope.aliyuncs.com/api/v1
      chat:
        options:
          model: qwen-turbo
          max-tokens: 2048
          temperature: 0.7
      embedding:
        options:
          model: text-embedding-v3
          dimension: 1024

8.2 Using it in a Spring AI application
// Used exactly the same way as any other Provider
@Service
public class MyAiService {

    private final ChatClient chatClient;         // built on the auto-configured QwenChatModel
    private final EmbeddingModel embeddingModel; // the auto-configured QwenEmbeddingModel

    // Constructor injection (ChatClient is built from the auto-configured builder)
    public MyAiService(ChatClient.Builder chatClientBuilder, EmbeddingModel embeddingModel) {
        this.chatClient = chatClientBuilder.build();
        this.embeddingModel = embeddingModel;
    }

    public String chat(String question) {
        return chatClient.prompt()
                .user(question)
                .call()
                .content();
    }

    public float[] embed(String text) {
        return embeddingModel.embed(text);
    }
}

9. FAQ
Q1: Which Spring AI internals do I need to understand to build a Provider?
A: The essentials are:
- The ChatModel / EmbeddingModel interface contracts
- The Prompt / Message / ChatResponse data model
- The ChatOptions merging mechanism (runtime options override defaults)
- The execution order of the Advisor chain
- Spring Boot AutoConfiguration conditional annotations
Q2: How do I handle differences between model APIs?
A: Keep provider-specific differences isolated in the adapter layer:
- Code against the clean interfaces (ChatModel and friends)
- Implement DTO conversions specific to that API
- Handle all differences inside createRequest() and toChatResponse()
- Expose model-specific features (such as enableSearch) through a provider-specific Options class
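As a concrete instance of the DTO-conversion point: mapping Spring AI's message types onto a provider's role strings is typically one exhaustive switch, as in `toQwenMessage()` earlier. The enum below is a toy stand-in for Spring AI's `MessageType`:

```java
// Toy version of the role mapping done in QwenChatModel.toQwenMessage():
// each message type maps to the provider's wire-format role string.
public class RoleMappingDemo {

    enum MessageType { SYSTEM, USER, ASSISTANT, TOOL }

    public static String toQwenRole(MessageType type) {
        return switch (type) { // exhaustive: the compiler flags any missing case
            case SYSTEM -> "system";
            case USER -> "user";
            case ASSISTANT -> "assistant";
            case TOOL -> "tool";
        };
    }

    public static void main(String[] args) {
        System.out.println(toQwenRole(MessageType.USER)); // user
    }
}
```

The exhaustive switch is a deliberate choice: if Spring AI ever adds a new message type, the Provider fails to compile instead of silently dropping messages.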
Q3: Can I contribute a Provider to the official Spring AI project?
A: Absolutely:
- Fork spring-projects/spring-ai
- Mirror the structure of the existing spring-ai-openai module
- Implement all required interfaces and tests (80%+ coverage)
- Submit a Pull Request and respond to code review feedback
- You can also publish it as a standalone library first and gather user feedback before opening a PR
Q4: How do I support function/tool calling?
A: Qwen's tool calling is similar to OpenAI's:
- Add a tools field to ChatCompletionRequest.Parameters
- Handle the tool_calls field in responses
- Implement the conversion from FunctionCallback to Qwen's tool format
- See OpenAiFunctionCallingHelper in spring-ai-openai for reference
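The FunctionCallback-to-tool conversion mostly boils down to building a nested map in the OpenAI-style tool schema, which Qwen's tool calling closely follows. The field names below reflect that general schema and are illustrative, not an exact, verified Qwen payload:

```java
import java.util.Map;

// Sketch of converting a function definition into an OpenAI-style tool object
// ({"type": "function", "function": {name, description, parameters}}).
// Shape is illustrative of the schema, not a verified Qwen payload.
public class ToolConversionDemo {

    public static Map<String, Object> toTool(String name, String description,
                                             Map<String, Object> jsonSchema) {
        return Map.of(
                "type", "function",
                "function", Map.of(
                        "name", name,
                        "description", description,
                        "parameters", jsonSchema)); // JSON Schema of the arguments
    }

    public static void main(String[] args) {
        Map<String, Object> schema = Map.of(
                "type", "object",
                "properties", Map.of("city", Map.of("type", "string")));
        System.out.println(toTool("get_weather", "Query current weather", schema));
    }
}
```

In a real Provider this map would become the `Tool` DTO serialized by Jackson, and the inverse mapping would turn the response's `tool_calls` back into Spring AI tool-call objects.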
Q5: How do I add Micrometer observability to a Provider?
A: Spring AI has observability support built in:
- Inject an ObservationRegistry through the constructor
- Create an Observation with a ChatModelObservationContext inside call()
- Observation data is then published automatically to Prometheus, Zipkin, and similar backends
10. Summary

Implementing a Spring AI Provider takes:

| Step | Effort | Key files |
|---|---|---|
| HTTP client wrapper | 1 day | QwenApi.java |
| ChatModel implementation | 1 day | QwenChatModel.java |
| EmbeddingModel | 0.5 day | QwenEmbeddingModel.java |
| Spring Boot integration | 0.5 day | AutoConfiguration |
| Unit tests | 1 day | *Tests.java |

Huang Lei building a complete Provider in five days is not a miracle; it is the payoff of good framework design. Spring AI's extension interfaces are clean enough that once you understand one Provider (say, the OpenAI implementation), you can implement others quickly.

The value of this skill goes beyond integrating one more model: it is a deep understanding of the framework's internals, which makes you the genuine AI expert on your team rather than an engineer who can only call APIs.