A Look at Spring AI's EmbeddingModel

This article takes a look at Spring AI's EmbeddingModel.

EmbeddingModel

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java

public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {

	@Override
	EmbeddingResponse call(EmbeddingRequest request);

	/**
	 * Embeds the given text into a vector.
	 * @param text the text to embed.
	 * @return the embedded vector.
	 */
	default float[] embed(String text) {
		Assert.notNull(text, "Text must not be null");
		List<float[]> response = this.embed(List.of(text));
		return response.iterator().next();
	}

	/**
	 * Embeds the given document's content into a vector.
	 * @param document the document to embed.
	 * @return the embedded vector.
	 */
	float[] embed(Document document);

	/**
	 * Embeds a batch of texts into vectors.
	 * @param texts list of texts to embed.
	 * @return list of embedded vectors.
	 */
	default List<float[]> embed(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
			.getResults()
			.stream()
			.map(Embedding::getOutput)
			.toList();
	}

	/**
	 * Embeds a batch of {@link Document}s into vectors based on a
	 * {@link BatchingStrategy}.
	 * @param documents list of {@link Document}s.
	 * @param options {@link EmbeddingOptions}.
	 * @param batchingStrategy {@link BatchingStrategy}.
	 * @return a list of float[] that represents the vectors for the incoming
	 * {@link Document}s. The returned list is expected to be in the same order of the
	 * {@link Document} list.
	 */
	default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
		Assert.notNull(documents, "Documents must not be null");
		List<float[]> embeddings = new ArrayList<>(documents.size());
		List<List<Document>> batch = batchingStrategy.batch(documents);
		for (List<Document> subBatch : batch) {
			List<String> texts = subBatch.stream().map(Document::getText).toList();
			EmbeddingRequest request = new EmbeddingRequest(texts, options);
			EmbeddingResponse response = this.call(request);
			for (int i = 0; i < subBatch.size(); i++) {
				embeddings.add(response.getResults().get(i).getOutput());
			}
		}
		Assert.isTrue(embeddings.size() == documents.size(),
				"Embeddings must have the same number as that of the documents");
		return embeddings;
	}

	/**
	 * Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}.
	 * @param texts list of texts to embed.
	 * @return the embedding response.
	 */
	default EmbeddingResponse embedForResponse(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()));
	}

	/**
	 * Get the number of dimensions of the embedded vectors. Note that by default, this
	 * method will call the remote Embedding endpoint to get the dimensions of the
	 * embedded vectors. If the dimensions are known ahead of time, it is recommended to
	 * override this method.
	 * @return the number of dimensions of the embedded vectors.
	 */
	default int dimensions() {
		return embed("Test String").length;
	}

}

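EmbeddingModel extends the Model interface, with EmbeddingRequest as its request type and EmbeddingResponse as its response type. It declares two abstract methods, call and embed(Document), and provides default implementations of embed (for a single text, a list of texts, and a batch of documents), embedForResponse, and dimensions.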
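The delegation pattern described above can be illustrated with a small self-contained sketch. MiniEmbeddingModel and CountingModel are hypothetical stand-ins written for this example, not Spring AI types; they only assume that every default method ends up in call, as the interface listing suggests.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified stand-in for EmbeddingModel (hypothetical type, not the Spring AI interface).
interface MiniEmbeddingModel {

    // The only method an implementation has to supply: one vector per input text.
    List<float[]> call(List<String> texts);

    // Single-text convenience: wraps the text in a list and returns the first vector.
    default float[] embed(String text) {
        if (text == null) {
            throw new IllegalArgumentException("Text must not be null");
        }
        return embed(List.of(text)).get(0);
    }

    // Batch convenience: delegates straight to call.
    default List<float[]> embed(List<String> texts) {
        if (texts == null) {
            throw new IllegalArgumentException("Texts must not be null");
        }
        return call(texts);
    }

    // Default dimension probe: embeds a throwaway string and measures the vector.
    default int dimensions() {
        return embed("Test String").length;
    }
}

// Toy implementation that counts how often call is invoked.
class CountingModel implements MiniEmbeddingModel {

    int calls = 0;

    @Override
    public List<float[]> call(List<String> texts) {
        calls++;
        List<float[]> vectors = new ArrayList<>();
        for (String text : texts) {
            // Deterministic fake vector: [text length, 1, 0].
            vectors.add(new float[] { text.length(), 1f, 0f });
        }
        return vectors;
    }
}

class DefaultMethodsDemo {
    public static void main(String[] args) {
        CountingModel model = new CountingModel();
        System.out.println("first component: " + model.embed("hello")[0]);
        System.out.println("dimensions: " + model.dimensions());
        System.out.println("calls to call(): " + model.calls);
    }
}
```

Note that in this sketch dimensions() triggers a second invocation of call, which mirrors the Javadoc warning that the default dimensions() hits the remote endpoint.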

EmbeddingRequest

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java

public class EmbeddingRequest implements ModelRequest<List<String>> {

	private final List<String> inputs;

	private final EmbeddingOptions options;

	public EmbeddingRequest(List<String> inputs, EmbeddingOptions options) {
		this.inputs = inputs;
		this.options = options;
	}

	@Override
	public List<String> getInstructions() {
		return this.inputs;
	}

	@Override
	public EmbeddingOptions getOptions() {
		return this.options;
	}

}

EmbeddingRequest implements the ModelRequest interface; its getInstructions method returns a List<String>.
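Because a request carries a list of inputs, the batch overload of embed shown earlier can send one request per sub-batch and append the results in order. The sketch below reproduces that loop with plain Java types. The fixed-size batch method and the fakeCall method are assumptions made for illustration (BatchingStrategy is pluggable in Spring AI, and the real call goes to a remote endpoint), and batchSize is assumed to be positive.

```java
import java.util.ArrayList;
import java.util.List;

class BatchedEmbedSketch {

    // Hypothetical fixed-size batching; stands in for a BatchingStrategy.
    static List<List<String>> batch(List<String> texts, int batchSize) {
        List<List<String>> batches = new ArrayList<>();
        for (int i = 0; i < texts.size(); i += batchSize) {
            batches.add(texts.subList(i, Math.min(i + batchSize, texts.size())));
        }
        return batches;
    }

    // Stand-in for the remote embedding call: one vector per input, in input order.
    static List<float[]> fakeCall(List<String> inputs) {
        List<float[]> vectors = new ArrayList<>();
        for (String input : inputs) {
            vectors.add(new float[] { input.length() });
        }
        return vectors;
    }

    // Mirrors the loop in the batch overload: one call per sub-batch, results appended in order.
    static List<float[]> embedAll(List<String> texts, int batchSize) {
        List<float[]> embeddings = new ArrayList<>(texts.size());
        for (List<String> subBatch : batch(texts, batchSize)) {
            List<float[]> response = fakeCall(subBatch);
            for (int i = 0; i < subBatch.size(); i++) {
                embeddings.add(response.get(i));
            }
        }
        // Same sanity check as the original: one embedding per input.
        if (embeddings.size() != texts.size()) {
            throw new IllegalStateException("Embeddings must have the same number as that of the documents");
        }
        return embeddings;
    }

    public static void main(String[] args) {
        List<float[]> vectors = embedAll(List.of("a", "bb", "ccc", "dddd", "eeeee"), 2);
        for (float[] vector : vectors) {
            System.out.println(vector[0]);
        }
    }
}
```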

EmbeddingResponse

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java

public class EmbeddingResponse implements ModelResponse<Embedding> {

	/**
	 * Embedding data.
	 */
	private final List<Embedding> embeddings;

	/**
	 * Embedding metadata.
	 */
	private final EmbeddingResponseMetadata metadata;

	/**
	 * Creates a new {@link EmbeddingResponse} instance with empty metadata.
	 * @param embeddings the embedding data.
	 */
	public EmbeddingResponse(List<Embedding> embeddings) {
		this(embeddings, new EmbeddingResponseMetadata());
	}

	/**
	 * Creates a new {@link EmbeddingResponse} instance.
	 * @param embeddings the embedding data.
	 * @param metadata the embedding metadata.
	 */
	public EmbeddingResponse(List<Embedding> embeddings, EmbeddingResponseMetadata metadata) {
		this.embeddings = embeddings;
		this.metadata = metadata;
	}

	/**
	 * @return Get the embedding metadata.
	 */
	public EmbeddingResponseMetadata getMetadata() {
		return this.metadata;
	}

	@Override
	public Embedding getResult() {
		Assert.notEmpty(this.embeddings, "No embedding data available.");
		return this.embeddings.get(0);
	}

	/**
	 * @return Get the embedding data.
	 */
	@Override
	public List<Embedding> getResults() {
		return this.embeddings;
	}

	//......
}	

EmbeddingResponse implements the ModelResponse interface, with Embedding as its result type.

AbstractEmbeddingModel

spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

public abstract class AbstractEmbeddingModel implements EmbeddingModel {

	private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();

	/**
	 * Default constructor.
	 */
	public AbstractEmbeddingModel() {
	}

	/**
	 * Cached embedding dimensions.
	 */
	protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1);

	/**
	 * Return the dimension of the requested embedding generative name. If the generative
	 * name is unknown uses the EmbeddingModel to perform a dummy EmbeddingModel#embed and
	 * count the response dimensions.
	 * @param embeddingModel Fall-back client to determine, empirically the dimensions.
	 * @param modelName Embedding generative name to retrieve the dimensions for.
	 * @param dummyContent Dummy content to use for the empirical dimension calculation.
	 * @return Returns the embedding dimensions for the modelName.
	 */
	public static int dimensions(EmbeddingModel embeddingModel, String modelName, String dummyContent) {

		if (KNOWN_EMBEDDING_DIMENSIONS.containsKey(modelName)) {
			// Retrieve the dimension from a pre-configured file.
			return KNOWN_EMBEDDING_DIMENSIONS.get(modelName);
		}
		else {
			// Determine the dimensions empirically.
			// Generate an embedding and count the dimension size;
			return embeddingModel.embed(dummyContent).length;
		}
	}

	private static Map<String, Integer> loadKnownModelDimensions() {
		try {
			Properties properties = new Properties();
			properties.load(new DefaultResourceLoader()
				.getResource("classpath:/embedding/embedding-model-dimensions.properties")
				.getInputStream());
			return properties.entrySet()
				.stream()
				.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
		}
		catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public int dimensions() {
		if (this.embeddingDimensions.get() < 0) {
			this.embeddingDimensions.set(dimensions(this, "Test", "Hello World"));
		}
		return this.embeddingDimensions.get();
	}

}

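AbstractEmbeddingModel implements the dimensions method defined by EmbeddingModel and caches the result in an AtomicInteger. Different modules provide their own subclasses, such as OpenAiEmbeddingModel in spring-ai-openai, OllamaEmbeddingModel in spring-ai-ollama, and MiniMaxEmbeddingModel in spring-ai-minimax.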
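The lookup-then-probe-then-cache logic can be sketched in isolation as follows. The KNOWN table, its "example-model" entry, and the fake embed method are hypothetical; in Spring AI the table is loaded from a properties file and the probe is a real embedding call.

```java
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

class DimensionCacheSketch {

    // Hypothetical table of known dimensions (illustrative entry only).
    static final Map<String, Integer> KNOWN = Map.of("example-model", 1024);

    // -1 means "not resolved yet", as in AbstractEmbeddingModel.
    final AtomicInteger cached = new AtomicInteger(-1);

    // Counts how many times the fake remote probe runs.
    int probes = 0;

    // Stand-in for a remote embed call; always returns a vector of length 4.
    float[] embed(String text) {
        probes++;
        return new float[4];
    }

    // Known name: table lookup. Unknown name: embed dummy content and measure the vector.
    int resolve(String modelName, String dummyContent) {
        Integer known = KNOWN.get(modelName);
        return known != null ? known : embed(dummyContent).length;
    }

    // Resolves once, then serves the cached value.
    int dimensions() {
        if (cached.get() < 0) {
            cached.set(resolve("Test", "Hello World"));
        }
        return cached.get();
    }

    public static void main(String[] args) {
        DimensionCacheSketch sketch = new DimensionCacheSketch();
        System.out.println(sketch.dimensions());
        System.out.println(sketch.dimensions());
        System.out.println("probes: " + sketch.probes);
    }
}
```

Two details are worth noting. The check-then-set is not atomic, so two threads could both probe, which is harmless because they would store the same value. And since the instance method passes the literal name "Test", the table lookup presumably misses, so in practice the first call falls through to the empirical probe unless a subclass overrides dimensions().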

OllamaEmbeddingAutoConfiguration

org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfiguration.java

@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaEmbeddingModel.class)
@ConditionalOnProperty(name = SpringAIModelProperties.EMBEDDING_MODEL, havingValue = SpringAIModels.OLLAMA,
		matchIfMissing = true)
@EnableConfigurationProperties({ OllamaEmbeddingProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { OllamaApiAutoConfiguration.class, RestClientAutoConfiguration.class,
		WebClientAutoConfiguration.class })
public class OllamaEmbeddingAutoConfiguration {

	@Bean
	@ConditionalOnMissingBean
	public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
			OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
			ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
		var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
				? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;

		var embeddingModel = OllamaEmbeddingModel.builder()
			.ollamaApi(ollamaApi)
			.defaultOptions(properties.getOptions())
			.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
			.modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
					initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
					initProperties.getMaxRetries()))
			.build();

		observationConvention.ifAvailable(embeddingModel::setObservationConvention);

		return embeddingModel;
	}

}

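OllamaEmbeddingAutoConfiguration is enabled when spring.ai.model.embedding is set to ollama or, because of matchIfMissing = true, when that property is not set at all. It auto-configures an OllamaEmbeddingModel bean.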
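The effect of havingValue combined with matchIfMissing = true can be approximated with a small function. This is only a sketch of the decision rule, not Spring Boot's condition implementation; the case-insensitive comparison is believed to match Spring Boot's behaviour but should be treated as an assumption.

```java
import java.util.Map;

class ConditionSketch {

    // Approximates @ConditionalOnProperty(name = ..., havingValue = ..., matchIfMissing = ...).
    static boolean matches(Map<String, String> props, String name, String havingValue, boolean matchIfMissing) {
        String value = props.get(name);
        if (value == null) {
            // Property absent: the outcome is decided by matchIfMissing alone.
            return matchIfMissing;
        }
        // Property present: it must equal the required value (case-insensitive here, an assumption).
        return havingValue.equalsIgnoreCase(value);
    }

    public static void main(String[] args) {
        String name = "spring.ai.model.embedding";
        System.out.println(matches(Map.of(), name, "ollama", true));
        System.out.println(matches(Map.of(name, "ollama"), name, "ollama", true));
        System.out.println(matches(Map.of(name, "openai"), name, "ollama", true));
    }
}
```

Under this rule, setting the property to any other value switches the Ollama embedding auto-configuration off, while leaving it unset keeps it on.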

OllamaEmbeddingProperties

org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java

@ConfigurationProperties(OllamaEmbeddingProperties.CONFIG_PREFIX)
public class OllamaEmbeddingProperties {

	public static final String CONFIG_PREFIX = "spring.ai.ollama.embedding";

	/**
	 * Client level Ollama options. Use this property to configure generative temperature,
	 * topK and topP and alike parameters. The null values are ignored defaulting to the
	 * generative's defaults.
	 */
	@NestedConfigurationProperty
	private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();

	public String getModel() {
		return this.options.getModel();
	}

	public void setModel(String model) {
		this.options.setModel(model);
	}

	public OllamaOptions getOptions() {
		return this.options;
	}

}

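OllamaEmbeddingProperties mainly exposes the OllamaOptions configuration under the spring.ai.ollama.embedding prefix, with OllamaModel.MXBAI_EMBED_LARGE as the default model. For details on the individual options, see .cpp/blob/master/examples/main/README.md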

OllamaInitializationProperties

org/springframework/ai/model/ollama/autoconfigure/OllamaInitializationProperties.java

@ConfigurationProperties(OllamaInitializationProperties.CONFIG_PREFIX)
public class OllamaInitializationProperties {

	public static final String CONFIG_PREFIX = "spring.ai.ollama.init";

	/**
	 * Chat models initialization settings.
	 */
	private final ModelTypeInit chat = new ModelTypeInit();

	/**
	 * Embedding models initialization settings.
	 */
	private final ModelTypeInit embedding = new ModelTypeInit();

	/**
	 * Whether to pull models at startup-time and how.
	 */
	private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;

	/**
	 * How long to wait for a model to be pulled.
	 */
	private Duration timeout = Duration.ofMinutes(5);

	/**
	 * Maximum number of retries for the model pull operation.
	 */
	private int maxRetries = 0;

	public PullModelStrategy getPullModelStrategy() {
		return this.pullModelStrategy;
	}

	public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
		this.pullModelStrategy = pullModelStrategy;
	}

	public ModelTypeInit getChat() {
		return this.chat;
	}

	public ModelTypeInit getEmbedding() {
		return this.embedding;
	}

	public Duration getTimeout() {
		return this.timeout;
	}

	public void setTimeout(Duration timeout) {
		this.timeout = timeout;
	}

	public int getMaxRetries() {
		return this.maxRetries;
	}

	public void setMaxRetries(int maxRetries) {
		this.maxRetries = maxRetries;
	}

	public static class ModelTypeInit {

		/**
		 * Include this type of models in the initialization task.
		 */
		private boolean include = true;

		/**
		 * Additional models to initialize besides the ones configured via default
		 * properties.
		 */
		private List<String> additionalModels = List.of();

		public boolean isInclude() {
			return this.include;
		}

		public void setInclude(boolean include) {
			this.include = include;
		}

		public List<String> getAdditionalModels() {
			return this.additionalModels;
		}

		public void setAdditionalModels(List<String> additionalModels) {
			this.additionalModels = additionalModels;
		}

	}

}

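OllamaInitializationProperties holds the spring.ai.ollama.init settings that control Ollama initialization: the pull strategy, the timeout, and the maximum number of retries. Its ModelTypeInit entries let each model type (chat, embedding) be included or excluded and list additional models to initialize.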
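How the include flag interacts with the pull strategy can be seen in the ollamaEmbeddingModel bean method shown earlier; the sketch below isolates that decision. The Strategy enum is a local stand-in: only NEVER appears in the code quoted in this article, and ALWAYS and WHEN_MISSING are assumed here to be the other values of PullModelStrategy.

```java
class PullStrategySketch {

    // Local stand-in for PullModelStrategy (ALWAYS and WHEN_MISSING are assumptions).
    enum Strategy { ALWAYS, WHEN_MISSING, NEVER }

    // Same decision as in the auto-configuration: an excluded model type never pulls.
    static Strategy effective(boolean include, Strategy configured) {
        return include ? configured : Strategy.NEVER;
    }

    public static void main(String[] args) {
        // Embedding models included: the configured strategy applies.
        System.out.println(effective(true, Strategy.WHEN_MISSING));
        // Embedding models excluded: the configured strategy is ignored.
        System.out.println(effective(false, Strategy.ALWAYS));
    }
}
```

Since pullModelStrategy defaults to NEVER, nothing is pulled at startup unless the strategy is changed, regardless of the include flag.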

Example

pom.xml

<dependency>
   <groupId>org.springframework.ai</groupId>
   <artifactId>spring-ai-starter-model-ollama</artifactId>
</dependency>

Configuration

spring:
  ai:
    model:
      embedding: ollama
    ollama:
      init:
        timeout: 5m
        max-retries: 0
        embedding:
          include: true
          additional-models: []
      base-url: http://localhost:11434
      embedding:
        enabled: true
        options:
          model: bge-m3:latest
          truncate: true

example

    @Test
    public void testCall() {
        EmbeddingRequest request = new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
                OllamaOptions.builder()
                        .model("bge-m3:latest")
                        .truncate(false)
                        .build());
        EmbeddingResponse embeddingResponse = embeddingModel.call(request);
        log.info("resp:{}", JSON.toJSONString(embeddingResponse));
    }
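Once the response is available, a common next step is to compare the returned float[] vectors. The following is a minimal cosine similarity sketch in plain Java; it is not part of Spring AI, and the class and method names are made up for this example.

```java
class CosineSketch {

    // Cosine similarity of two vectors of equal length: dot(a, b) / (|a| * |b|).
    static double cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            throw new IllegalArgumentException("Cosine similarity is undefined for a zero vector");
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Toy vectors; real inputs would come from Embedding.getOutput().
        System.out.println(cosine(new float[] { 1f, 0f }, new float[] { 1f, 0f }));
        System.out.println(cosine(new float[] { 1f, 0f }, new float[] { 0f, 1f }));
    }
}
```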

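Summary

Spring AI defines the EmbeddingModel interface, which extends Model with EmbeddingRequest as its request type and EmbeddingResponse as its response type. It declares the call and embed(Document) methods and provides default implementations of embed, embedForResponse, and dimensions. AbstractEmbeddingModel implements the dimensions method and caches its result; different modules provide their own subclasses, such as OpenAiEmbeddingModel in spring-ai-openai, OllamaEmbeddingModel in spring-ai-ollama, and MiniMaxEmbeddingModel in spring-ai-minimax. OllamaEmbeddingAutoConfiguration is enabled when spring.ai.model.embedding is set to ollama (or is not set at all), and it auto-configures an OllamaEmbeddingModel.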

doc

  • embeddings
  • ollama-embeddings
Shared from a WeChat official account; originally published on 2025-04-02.
