A Look at Spring AI's Chat Model

This article takes a look at Spring AI's Chat Model.

Model

spring-ai-core/src/main/java/org/springframework/ai/model/Model.java

public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {

	/**
	 * Executes a method call to the AI model.
	 * @param request the request object to be sent to the AI model
	 * @return the response from the AI model
	 */
	TRes call(TReq request);

}

The Model interface defines a call method whose parameter is a ModelRequest and whose return value is a ModelResponse.
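
As a concrete illustration, here is a minimal sketch of how those type parameters get bound: ChatModel (covered later in this article) fixes TReq to Prompt and TRes to ChatResponse, so a toy implementation only has to supply call(Prompt). The EchoChatModel class is hypothetical, not part of Spring AI.

import java.util.List;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;

// Hypothetical toy model: echoes the prompt contents back as a single Generation
public class EchoChatModel implements ChatModel {

	@Override
	public ChatResponse call(Prompt prompt) {
		AssistantMessage output = new AssistantMessage("echo: " + prompt.getContents());
		return new ChatResponse(List.of(new Generation(output)));
	}

}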

ModelRequest

spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java

public interface ModelRequest<T> {

	/**
	 * Retrieves the instructions or input required by the AI model.
	 * @return the instructions or input required by the AI model
	 */
	T getInstructions(); // required input

	/**
	 * Retrieves the customizable options for AI model interactions.
	 * @return the customizable options for AI model interactions
	 */
	ModelOptions getOptions();

}

ModelRequest defines the getInstructions and getOptions methods: getInstructions returns the input required by the model, and getOptions returns the customizable per-call options.
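
For chat models, Prompt is the ModelRequest implementation: its instructions are the list of chat messages, and its options are a ChatOptions. A short sketch, assuming Spring AI's Prompt and ChatOptions APIs (constructor overloads may vary by version):

// the message text is the "instructions", ChatOptions the per-call options
Prompt prompt = new Prompt("Hello", ChatOptions.builder().temperature(0.2).build());

List<Message> instructions = prompt.getInstructions(); // required input
var options = prompt.getOptions();                     // customizable model options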

ModelResponse

spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java

public interface ModelResponse<T extends ModelResult<?>> {

	/**
	 * Retrieves the result of the AI model.
	 * @return the result generated by the AI model
	 */
	T getResult();

	/**
	 * Retrieves the list of generated outputs by the AI model.
	 * @return the list of generated outputs
	 */
	List<T> getResults();

	/**
	 * Retrieves the response metadata associated with the AI model's response.
	 * @return the response metadata
	 */
	ResponseMetadata getMetadata();

}

ModelResponse defines the getResult, getResults, and getMetadata methods; each result is of type ModelResult.
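
On the chat side, ChatResponse implements ModelResponse<Generation>. A usage sketch, assuming a ChatModel bean named chatModel:

ChatResponse response = chatModel.call(new Prompt("Tell me a joke"));

Generation first = response.getResult();      // the first (or only) result
List<Generation> all = response.getResults(); // all generated results
var metadata = response.getMetadata();        // response metadata, e.g. token usage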

ModelResult

spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java

public interface ModelResult<T> {

	/**
	 * Retrieves the output generated by the AI model.
	 * @return the output generated by the AI model
	 */
	T getOutput();

	/**
	 * Retrieves the metadata associated with the result of an AI model.
	 * @return the metadata associated with the result
	 */
	ResultMetadata getMetadata();

}

ModelResult defines the getOutput and getMetadata methods.
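
Generation is the chat-side ModelResult, with AssistantMessage as its output type. Continuing the sketch above:

Generation generation = response.getResult();
AssistantMessage output = generation.getOutput(); // the generated assistant message
String text = output.getText();                   // its text content
var resultMetadata = generation.getMetadata();    // per-result metadata, e.g. finish reason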

StreamingModel

spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java

public interface StreamingModel<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {

	/**
	 * Executes a method call to the AI model.
	 * @param request the request object to be sent to the AI model
	 * @return the streaming response from the AI model
	 */
	Flux<TResChunk> stream(TReq request);

}

The StreamingModel interface defines a stream method whose parameter is a ModelRequest and whose return value is a Flux<ModelResponse>.
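
For chat models the emitted chunks are partial ChatResponse objects. A sketch, assuming a ChatModel bean named chatModel (the chat specialization is shown next); the guards mirror the null checks in StreamingChatModel's default methods below:

Flux<ChatResponse> chunks = chatModel.stream(new Prompt("Summarize Spring AI"));
chunks.subscribe(chunk -> {
	Generation g = chunk.getResult();
	// partial chunks may carry no output yet, hence the guards
	if (g != null && g.getOutput() != null && g.getOutput().getText() != null) {
		System.out.print(g.getOutput().getText());
	}
});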

StreamingChatModel

spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java

@FunctionalInterface
public interface StreamingChatModel extends StreamingModel<Prompt, ChatResponse> {

	default Flux<String> stream(String message) {
		Prompt prompt = new Prompt(message);
		return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
				|| response.getResult().getOutput().getText() == null) ? ""
						: response.getResult().getOutput().getText());
	}

	default Flux<String> stream(Message... messages) {
		Prompt prompt = new Prompt(Arrays.asList(messages));
		return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
				|| response.getResult().getOutput().getText() == null) ? ""
						: response.getResult().getOutput().getText());
	}

	@Override
	Flux<ChatResponse> stream(Prompt prompt);

}

StreamingChatModel extends the StreamingModel interface, binding the request type to Prompt and the response type to Flux<ChatResponse>, and provides two default methods: Flux<String> stream(String message) and Flux<String> stream(Message... messages).
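
The default overloads make plain-text streaming a one-liner. A usage sketch:

Flux<String> tokens = chatModel.stream("Write a haiku about rivers");
tokens.subscribe(System.out::print); // print each text chunk as it arrives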

ChatModel

spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java

public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

	default String call(String message) {
		Prompt prompt = new Prompt(new UserMessage(message));
		Generation generation = call(prompt).getResult();
		return (generation != null) ? generation.getOutput().getText() : "";
	}

	default String call(Message... messages) {
		Prompt prompt = new Prompt(Arrays.asList(messages));
		Generation generation = call(prompt).getResult();
		return (generation != null) ? generation.getOutput().getText() : "";
	}

	@Override
	ChatResponse call(Prompt prompt);

	default ChatOptions getDefaultOptions() {
		return ChatOptions.builder().build();
	}

	default Flux<ChatResponse> stream(Prompt prompt) {
		throw new UnsupportedOperationException("streaming is not supported");
	}

}

ChatModel extends both the Model and StreamingChatModel interfaces, binding Model's request type to Prompt and its response type to ChatResponse; it also adds call(String) and call(Message...) convenience default methods, a getDefaultOptions default, and a default stream(Prompt) that throws UnsupportedOperationException. ChatModel has different implementations in different modules, for example spring-ai-ollama (OllamaChatModel), spring-ai-openai (OpenAiChatModel), spring-ai-minimax (MiniMaxChatModel), spring-ai-moonshot (MoonshotChatModel), and spring-ai-zhipuai (ZhiPuAiChatModel).
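
A usage sketch showing that the default call(String) overload is just shorthand for wrapping the text in a UserMessage, delegating to call(Prompt), and unwrapping the generated text:

String answer = chatModel.call("What is Spring AI?");

// the explicit equivalent of the default method above
ChatResponse response = chatModel.call(new Prompt(new UserMessage("What is Spring AI?")));
String sameAnswer = response.getResult().getOutput().getText();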

OllamaAutoConfiguration

org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java

@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaApi.class)
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
		OllamaConnectionProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class })
public class OllamaAutoConfiguration {

	@Bean
	@ConditionalOnMissingBean(OllamaConnectionDetails.class)
	public PropertiesOllamaConnectionDetails ollamaConnectionDetails(OllamaConnectionProperties properties) {
		return new PropertiesOllamaConnectionDetails(properties);
	}

	@Bean
	@ConditionalOnMissingBean
	public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
			ObjectProvider<RestClient.Builder> restClientBuilderProvider,
			ObjectProvider<WebClient.Builder> webClientBuilderProvider) {
		return new OllamaApi(connectionDetails.getBaseUrl(),
				restClientBuilderProvider.getIfAvailable(RestClient::builder),
				webClientBuilderProvider.getIfAvailable(WebClient::builder));
	}

	@Bean
	@ConditionalOnMissingBean
	@ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
			matchIfMissing = true)
	public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
			OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
			FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
			ObjectProvider<ChatModelObservationConvention> observationConvention) {
		var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
				: PullModelStrategy.NEVER;

		var chatModel = OllamaChatModel.builder()
			.ollamaApi(ollamaApi)
			.defaultOptions(properties.getOptions())
			.functionCallbackResolver(functionCallbackResolver)
			.toolFunctionCallbacks(toolFunctionCallbacks)
			.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
			.modelManagementOptions(
					new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
							initProperties.getTimeout(), initProperties.getMaxRetries()))
			.build();

		observationConvention.ifAvailable(chatModel::setObservationConvention);

		return chatModel;
	}

	@Bean
	@ConditionalOnMissingBean
	@ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
			matchIfMissing = true)
	public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
			OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
			ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
		var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
				? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;

		var embeddingModel = OllamaEmbeddingModel.builder()
			.ollamaApi(ollamaApi)
			.defaultOptions(properties.getOptions())
			.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
			.modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
					initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
					initProperties.getMaxRetries()))
			.build();

		observationConvention.ifAvailable(embeddingModel::setObservationConvention);

		return embeddingModel;
	}

	@Bean
	@ConditionalOnMissingBean
	public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
		DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
		manager.setApplicationContext(context);
		return manager;
	}

	static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails {

		private final OllamaConnectionProperties properties;

		PropertiesOllamaConnectionDetails(OllamaConnectionProperties properties) {
			this.properties = properties;
		}

		@Override
		public String getBaseUrl() {
			return this.properties.getBaseUrl();
		}

	}

}

spring-ai-spring-boot-autoconfigure provides a series of auto-configurations; for example, OllamaAutoConfiguration auto-configures the OllamaApi client along with the OllamaChatModel and OllamaEmbeddingModel beans.
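
With the Ollama starter on the classpath, the connection and model are driven by configuration properties under the prefixes referenced above, and the auto-configured model can be injected wherever a ChatModel is needed. A sketch (property values and the /chat endpoint are illustrative):

# application.properties
spring.ai.ollama.base-url=http://localhost:11434
spring.ai.ollama.chat.options.model=llama3
spring.ai.ollama.chat.options.temperature=0.7

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

// injects the auto-configured OllamaChatModel through the ChatModel interface
@RestController
public class ChatController {

	private final ChatModel chatModel;

	public ChatController(ChatModel chatModel) {
		this.chatModel = chatModel;
	}

	@GetMapping("/chat")
	public String chat(@RequestParam String message) {
		return this.chatModel.call(message);
	}

}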

Summary

Spring AI's Model interface defines a call method that takes a ModelRequest and returns a ModelResponse; the StreamingModel interface defines a stream method that takes a ModelRequest and returns a Flux<ModelResponse>. StreamingChatModel extends StreamingModel, binding the request type to Prompt and the response type to Flux<ChatResponse>, and provides two default methods, Flux<String> stream(String message) and Flux<String> stream(Message... messages). ChatModel extends both Model and StreamingChatModel, with Prompt as the request type and ChatResponse as the response type. ChatModel has different implementations in different modules, for example spring-ai-ollama (OllamaChatModel), spring-ai-openai (OpenAiChatModel), spring-ai-minimax (MiniMaxChatModel), spring-ai-moonshot (MoonshotChatModel), and spring-ai-zhipuai (ZhiPuAiChatModel).

doc

  • chatmodel
  • chat/comparison
