【鉴赏】当点击dify的召回测试按钮时,后台程序做了一些什么?
本文深入分析了Dify项目源代码的知识库文档的召回测试的业务逻辑和架构
·
【鉴赏】当点击dify的召回测试按钮时,后台程序做了一些什么?
*dify构建知识库,召回率是最重要的指标,如果没法正确召回知识库文档,后续的答复质量堪忧,因此非常有必要深入研究dify的召回逻辑。用户点击召回测试按钮,可以参考 《GoogleChrome浏览器开发者模式查看dify接口》 ,知道对应的后台接口 http://[serverIP]/console/api/datasets/[datasetID]/hit-testing *
1. 知识库 dataset id 作为URL子路径传递到api后台程序 hit_testing.py
class HitTestingApi(Resource, DatasetsHitTestingBase):
...
def post(self, dataset_id):
...
return self.perform_hit_testing(dataset, args)
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
类 HitTestingApi的post方法接受请求,调用perform_hit_testing()处理逻辑
2. 执行召回测试操作 perform_hit_testing()
程序:hit_testing_base.py
class DatasetsHitTestingBase:
...
def perform_hit_testing(dataset, args):
try:
response = HitTestingService.retrieve(
dataset=dataset, <-- 知识库ID
query=args["query"], <-- 用户提的问题
下面有一大堆的异常捕获
3. 做两件事情:1是调用 RetrievalService.retrieve() 获得向量数据库里面匹配文档; 2. 把本次召回测试记录到数据库
程序:hit_testing_service.py
class HitTestingService:
@classmethod
def retrieve(
cls,
...
all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id, <-- 知识库ID
query=query, <-- 用户提的问题
4. 接下来是具体召回操作逻辑,包括三类召回的操作
程序:retrieval_service.py
- class RetrievalService - retrieve()
三段逻辑:
if retrieval_method == "keyword_search": <-- 逻辑1:关键字搜索
if RetrievalMethod.is_support_semantic_search(retrieval_method): <-- 逻辑2:向量搜索
...
futures.append(
executor.submit(
cls.embedding_search, <--丢到celery worker的作业
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id, <-- 知识库ID
query=query, <-- 用户提的问题
...
if RetrievalMethod.is_support_fulltext_search(retrieval_method): <-- 逻辑:全文搜索
...
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: <-- 推荐的混合模式
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke( <-- 这里调用rerank的工厂类进行重排序,逻辑有点复杂,后续开一个新文章描述
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
5. 以文本嵌入搜索为例, embedding_search()
程序:retrieval_service.py
@classmethod
def embedding_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
...
vector = Vector(dataset=dataset) <-- vector_factory.py 定义工厂类Vector embed_query()
documents = vector.search_by_vector( <-- 这个向量搜索包含了大量的实现,各类逻辑在 core/rag/datasource/vdb/目录下
query,
search_type="similarity_score_threshold",
top_k=top_k,
...
6. Vector工厂类缓存搜索:生成query的hash值,并base64压缩,写入redis
程序:cached_embedding.py
def embed_query(self, text: str) -> list[float]:
...
hash = helper.generate_text_hash(text) <-- 生成query问题的hash值
...
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str) <-- 写入redis库,这样同样的问题,直接从redis拿结果
7. dify默认的向量数据库 weaviate为例 search_by_vector()
程序:weaviate_vector.py
...
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
...
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
...
vector = {"vector": query_vector} <-- 这里的query_vector是向量化后的query问题
document_ids_filter = kwargs.get("document_ids_filter")
...
result = ( <-- 这里拿到对应的document数据集
query_obj.with_near_vector(vector) <-- 这里是向量数据库搜索的方法
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
)
...
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) <-- 文档根据score排序
return docs
全文完毕
更多推荐
所有评论(0)