回到目录

【鉴赏】当点击dify的召回测试按钮时,后台程序做了一些什么?

*dify构建知识库,召回率是最重要的指标,如果没法正确召回知识库文档,后续的答复质量堪忧,因此非常有必要深入研究dify的召回逻辑。用户点击召回测试按钮,可以参考 《GoogleChrome浏览器开发者模式查看dify接口》 ,知道对应的后台接口 http://[serverIP]/console/api/datasets/[datasetID]/hit-testing *

1. 知识库 dataset id 作为URL子路径传递到api后台程序 hit_testing.py

class HitTestingApi(Resource, DatasetsHitTestingBase):
...
    def post(self, dataset_id):
...
            return self.perform_hit_testing(dataset, args) 

  api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

类 HitTestingApi的post方法接受请求,调用perform_hit_testing()处理逻辑

2. 执行召回测试操作 perform_hit_testing()

程序:hit_testing_base.py

class DatasetsHitTestingBase:
...
    def perform_hit_testing(dataset, args):
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,   <-- 知识库ID
                query=args["query"],   <-- 用户提的问题

下面有一大堆的异常捕获

3. 做两件事情:1是调用 RetrievalService.retrieve() 获得向量数据库里面匹配文档; 2. 把本次召回测试记录到数据库

程序:hit_testing_service.py

class HitTestingService:
    @classmethod
    def retrieve(
        cls,
...
        all_documents = RetrievalService.retrieve(
            retrieval_method=retrieval_model.get("search_method", "semantic_search"),
            dataset_id=dataset.id,   <-- 知识库ID
            query=query,  <-- 用户提的问题

4. 接下来是具体召回操作逻辑,包括三类召回的操作

程序:retrieval_service.py

  • class RetrievalService - retrieve()
    三段逻辑:
            if retrieval_method == "keyword_search":   <-- 逻辑1:关键字搜索
            if RetrievalMethod.is_support_semantic_search(retrieval_method):    <-- 逻辑2:向量搜索
...
                futures.append(
                    executor.submit(
                        cls.embedding_search,  <--丢到celery worker的作业
                        flask_app=current_app._get_current_object(),  # type: ignore
                        dataset_id=dataset_id,   <-- 知识库ID
                        query=query,  <-- 用户提的问题
...
            if RetrievalMethod.is_support_fulltext_search(retrieval_method):   <-- 逻辑:全文搜索
...

        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:  <-- 推荐的混合模式
            data_post_processor = DataPostProcessor(
                str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
            )
            all_documents = data_post_processor.invoke(  <-- 这里调用rerank的工厂类进行重排序,逻辑有点复杂,后续开一个新文章描述
                query=query,
                documents=all_documents,
                score_threshold=score_threshold,
                top_n=top_k,
            )

5. 以文本嵌入搜索为例, embedding_search()

程序:retrieval_service.py

    @classmethod
    def embedding_search(
        cls,
        flask_app: Flask,
        dataset_id: str,
        query: str,
...
                vector = Vector(dataset=dataset)   <-- vector_factory.py 定义工厂类Vector embed_query() 
                documents = vector.search_by_vector(   <-- 这个向量搜索包含了大量的实现,各类逻辑在 core/rag/datasource/vdb/目录下
                    query,
                    search_type="similarity_score_threshold",
                    top_k=top_k,
...

6. Vector工厂类缓存搜索:生成query的hash值,并base64压缩,写入redis

程序:cached_embedding.py

    def embed_query(self, text: str) -> list[float]:
...
        hash = helper.generate_text_hash(text)   <-- 生成query问题的hash...
            # encode embedding to base64
            embedding_vector = np.array(embedding_results)
            vector_bytes = embedding_vector.tobytes()
            # Transform to Base64
            encoded_vector = base64.b64encode(vector_bytes)
            # Transform to string
            encoded_str = encoded_vector.decode("utf-8")
            redis_client.setex(embedding_cache_key, 600, encoded_str)   <-- 写入redis库,这样同样的问题,直接从redis拿结果

7. dify默认的向量数据库 weaviate为例 search_by_vector()

程序:weaviate_vector.py

...
class WeaviateVector(BaseVector):
    def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
...
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Look up similar documents by embedding vector in Weaviate."""
...
        vector = {"vector": query_vector}  <-- 这里的query_vector是向量化后的query问题
        document_ids_filter = kwargs.get("document_ids_filter")
...
        result = (  <-- 这里拿到对应的document数据集
            query_obj.with_near_vector(vector)  <-- 这里是向量数据库搜索的方法
            .with_limit(kwargs.get("top_k", 4))
            .with_additional(["vector", "distance"])
            .do()
        )
...
        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)  <-- 文档根据score排序
        return docs

全文完毕

回到目录

Logo

更多推荐