admin管理员组

文章数量:1444647

基于 Spring AI + Ollama + Qwen2.5/DeepSeek + Milvus 实现 RAG

安装 Milvus

参考官网步骤即可完成安装，作者使用 Docker 部署。

.md

安装Ollama

参考官网步骤即可完成安装；同时可以通过 Ollama 下载文生文、文生图、OCR、Embedding、DeepSeek、Qwen 系列等模型。

前端Vue

vue3、element plus等

后端

spring boot 3.4.3、Java17

关键代码

代码语言：Java
package com.sb.rag;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.CharsetUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.HasCollectionReq;
import io.milvus.v2.service.database.request.CreateDatabaseReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.UpsertReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.SearchResp;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.TemplateFormat;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import javax.validation.constraints.NotNull;
import java.io.File;
import java.io.FileFilter;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.*;

/**
 * REST endpoints for a minimal RAG (retrieval-augmented generation) flow.
 *
 * <p>{@code POST /ai/rag/import} embeds uploaded text with Ollama and upserts it into the
 * Milvus collection {@value #COLLECTION_NAME}. {@code GET /ai/rag/dialog} embeds a question
 * with the same model, retrieves the most similar passages from Milvus, and asks the chat
 * model to answer using those passages as context.
 */
@RestController
@RequestMapping("/ai/rag")
public class RAGController {

    /** Milvus collection that stores the knowledge-base passages. */
    private static final String COLLECTION_NAME = "rag";

    /**
     * Ollama embedding model used for BOTH indexing and querying. Query vectors and document
     * vectors must come from the same model to be comparable.
     */
    private static final String EMBEDDING_MODEL = "mxbai-embed-large:latest";

    /** Ollama chat model used to generate the final answer. */
    private static final String CHAT_MODEL = "qwen2.5:3b";

    /** Number of nearest passages retrieved as context for each question. */
    private static final long SEARCH_LIMIT = 3L;

    /** Shared JSON helper; Gson instances are thread-safe. */
    private static final Gson GSON = new Gson();

    @Autowired
    private MilvusClientV2 vectorStore;
    @Autowired
    private OllamaEmbeddingModel embeddingModel;
    @Autowired
    private OllamaChatModel chatModel;
    @Autowired
    private ChatClient chatClient;

    /**
     * Imports an uploaded file into the RAG knowledge base.
     *
     * <p>The file is read as UTF-8 and stored as a single passage. The collection is created
     * on first use, sized to the embedding dimension and using inner-product (IP) similarity.
     *
     * @param files the uploaded file (multipart field {@code file})
     * @return a status message: success, read failure, or "nothing to import"
     */
    @CrossOrigin
    @PostMapping("/import")
    public String importData(@RequestParam("file") MultipartFile files) {
        // 1.1 Collect the passages to index. Currently only the user's uploaded file.
        List<String> passages = new ArrayList<>();
        if (files != null && !files.isEmpty()) {
            // try-with-resources so the upload stream is closed even if reading fails
            try (InputStream in = files.getInputStream()) {
                String content = IoUtil.read(in, CharsetUtil.UTF_8);
                if (StrUtil.isNotBlank(content)) {
                    passages.add(content);
                }
            } catch (Exception e) {
                return "文件读取失败:" + e.getMessage();
            }
        }
        // 1.2 Company-private data files or data-lake content could also be indexed, e.g.:
        /*String directoryPath = "E:/workspace/idea/wanyuanhui/src/main/resources/milvus_docs/en/faq";
        // Use Hutool's FileUtil to collect every .md file under the directory
        List<File> mdFiles = FileUtil.loopFiles(directoryPath, new FileFilter() {
            @Override
            public boolean accept(File file) {
                return file.getName().endsWith(".md");  // keep only .md files
            }
        });

        // Read each .md file (as UTF-8) and split it into passages on '#'
        for (File file : mdFiles) {
            String cont = FileUtil.readUtf8String(file);
            String[] text = cont.split("#");
            for (String line : text) {
                passages.add(line);
            }
        }*/

        // Without any passage there is no dimension to create the collection with,
        // and Milvus rejects an upsert with no rows.
        if (passages.isEmpty()) {
            return "未读取到可导入的内容";
        }

        // 2. Embed each passage and build one Milvus row per passage.
        int dimension = 0;
        List<JsonObject> rows = new ArrayList<>(passages.size());
        for (String passage : passages) {
            float[] vector = embedText(passage);
            dimension = vector.length;

            JsonObject row = new JsonObject();
            // A random non-negative Int64 id. Reusing the loop index made every upload
            // upsert over id 0, so each import replaced the previous one.
            row.addProperty("id", UUID.randomUUID().getMostSignificantBits() & Long.MAX_VALUE);
            row.add("vector", GSON.toJsonTree(vector));
            // Building the row via the JSON API escapes quotes, backslashes and control
            // characters, so the original text is stored verbatim.
            row.addProperty("text", passage);
            rows.add(row);
        }

        // 3. Add the rows to the private knowledge-base collection, creating it on first use.
        Boolean exists = vectorStore.hasCollection(HasCollectionReq.builder()
                .collectionName(COLLECTION_NAME)
                .build());
        if (Boolean.FALSE.equals(exists)) {
            vectorStore.createCollection(CreateCollectionReq.builder()
                    .collectionName(COLLECTION_NAME)
                    .dimension(dimension)
                    .metricType(IndexParam.MetricType.IP.name())
                    .consistencyLevel(ConsistencyLevel.STRONG)
                    .build());
        }

        vectorStore.upsert(UpsertReq.builder()
                .collectionName(COLLECTION_NAME)
                .data(rows)
                .build());

        return "上传成功";
    }

    /**
     * Answers a question using passages retrieved from the knowledge base as context.
     *
     * <p>NOTE(review): the {@code @NotNull} constraint is not enforced here. The class has no
     * {@code @Validated}, and Spring Boot 3 validates {@code jakarta.validation} annotations,
     * not {@code javax.validation}. The {@code @RequestParam} already rejects a missing
     * parameter with HTTP 400.
     *
     * @param question the user's question (query parameter {@code question})
     * @return the chat model's answer text
     * @throws InterruptedException declared for backward compatibility; not thrown here
     */
    @CrossOrigin
    @GetMapping("/dialog")
    public String dialog(@RequestParam("question") @NotNull(message = "不能为空") String question)
            throws InterruptedException {
        // 1. Embed the question with the SAME model used at import time.
        float[] queryVector = embedText(question);

        // 2. Similarity search, using inner product to match the collection's metric.
        Map<String, Object> searchParams = new HashMap<>();
        searchParams.put("metric_type", "IP");
        SearchResp searchResp = vectorStore.search(SearchReq.builder()
                .collectionName(COLLECTION_NAME)
                .data(Collections.singletonList(new FloatVec(queryVector)))
                .limit(SEARCH_LIMIT)
                .searchParams(searchParams)
                .outputFields(List.of("text"))
                .build());

        // Outer list: one entry per query vector (only one here). Inner list: its hits.
        // Use every hit, not only the first, so the context holds up to SEARCH_LIMIT passages.
        // An empty collection simply yields an empty context instead of throwing.
        List<String> passages = new ArrayList<>();
        for (List<SearchResp.SearchResult> hits : searchResp.getSearchResults()) {
            for (SearchResp.SearchResult hit : hits) {
                Object passage = hit.getEntity().get("text");
                if (passage != null) {
                    passages.add(passage.toString());
                }
            }
        }
        String context = String.join("\n", passages);

        // 3. LLM call. Alternative: build a local client with a default system prompt:
//        ChatClient chatClient = ChatClient.builder(chatModel)
//                .defaultSystem("Human: You are an AI assistant. You are able to find answers to the questions from the contextual passage snippets provided.")
//                .build();
        // The user prompt wraps the passages retrieved from Milvus as context (the RAG step).
        ChatClient.CallResponseSpec responseSpec = chatClient.prompt(
                        new Prompt(StrUtil.format("Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.\n" +
                                "<context>{}</context>\n" +
                                "<question>{}</question>", context, question),
                                OllamaOptions.builder()
                                        .model(CHAT_MODEL)
                                        .build())
                )
                // Presumably fills a {voice} placeholder in the injected ChatClient's default
                // system prompt; confirm against the ChatClient bean definition.
                .system(sys -> sys.param("voice", "parameter customization"))
                .call();

        // 4. Return the generated answer text.
        return responseSpec.chatResponse().getResult().getOutput().getText();
    }

    /**
     * Embeds {@code text} with {@link #EMBEDDING_MODEL}.
     *
     * <p>Used by both import and query so that their vectors share the same space.
     *
     * @param text the text to embed
     * @return the embedding vector of the first (only) result
     */
    private float[] embedText(String text) {
        EmbeddingResponse response = embeddingModel.call(new EmbeddingRequest(
                List.of(text),
                OllamaOptions.builder()
                        .model(EMBEDDING_MODEL)
                        .build()));
        return response.getResult().getOutput();
    }

}

实现效果

本文标签: 基于 Spring AI + Ollama + Qwen2.5/DeepSeek + Milvus 实现 RAG