def test_construction():
    with pytest.raises(ValueError):
        retriever = Retriever()

    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    # Candidate vectorizer dimension (100) does not match the document
    # store's vector dimension, so construction must fail
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(100)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    with pytest.raises(ValueError):
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=cand_vectorizer,
            retriever_vectorizer=rtrv_vectorizer,
        )

    # Matching dimensions: construction succeeds
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    assert isinstance(retriever, Retriever)

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_update_embeddings():
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    with pytest.raises(ValueError):
        retriever.update_embeddings()

    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    retriever.train_candidate_vectorizer()
    save_path = os.path.join(parent_cwd, "document_store.test_pkl")
    retriever.update_embeddings(save_path=save_path)

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_get_candidates():
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    with pytest.raises(ValueError):
        retriever.get_candidates(query_docs=["Test candidate"])

    # Remove temp test database
    os.remove(temp_test_db_path)
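
# The copyfile/os.remove pattern above is repeated in every test in this
# module. A pytest fixture could encapsulate the setup and teardown; this is
# a minimal sketch under that assumption — the `temp_db` fixture name is
# hypothetical and not part of the existing suite.
@pytest.fixture
def temp_db():
    # Copy a fresh temp test database before the test...
    copyfile(test_db_path, temp_test_db_path)
    yield temp_test_db_path
    # ...and remove it afterwards, regardless of the test outcome
    os.remove(temp_test_db_path)

# Usage sketch: a test accepting `temp_db` gets the path to a fresh copy,
# e.g. document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_db}")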
def test_retriever_with_database():
    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    # Get a document from the database as input for retriever
    print("Query sample input from database")
    url = f"sqlite:///{test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(5)

    # Get text from query to input
    input_doc = query.all()[0].text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    # Init expected result
    expected_text_result = input_doc
    expected_score_result = 1

    print("Retrieving")
    result = retriever.retrieve([input_doc], top_k_candidates=10)
    print(
        " ".join(result[input_doc]["retrieve_result"].split(" ")[:50])
    )  # print the retrieved doc

    assert result[input_doc]["retrieve_result"] == expected_text_result
    assert result[input_doc]["similarity_score"] == expected_score_result
def update_remote_db(remote_doc_store):
    """This method updates embeddings and vector ids on the remote database."""
    remote_retriever = Retriever(
        document_store=remote_doc_store,
        candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
        retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
    )
    remote_retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)
    remote_retriever.update_embeddings(retrain=True)
    logger.info("Remote embeddings and vector ids updated")
def update_remote_db(local_doc_store, remote_doc_store):
    """This method runs in serial as follows:

    1. Update embeddings on the large FAISS index
    2. Update vector ids on the remote db
    3. Update metadata of documents on the local db to the remote db
    4. Clear the local db
    """
    remote_retriever = Retriever(
        document_store=remote_doc_store,
        candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
        retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
    )
    remote_retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)
    remote_retriever.update_embeddings(retrain=True)
    logger.info("Remote embeddings and vector ids updated")

    local_doc_store.delete_all_documents()
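
# update_remote_db is designed to run as a periodic job. A minimal scheduling
# sketch, assuming APScheduler; the scheduler choice, the one-day interval,
# and the module-level doc store names are illustrative, not taken from the
# source:
from apscheduler.schedulers.blocking import BlockingScheduler

scheduler = BlockingScheduler()
# Hypothetical wiring: local_doc_store / remote_doc_store are assumed to
# exist at module level by the time the job fires
scheduler.add_job(lambda: update_remote_db(local_doc_store, remote_doc_store),
                  trigger="interval", days=1)
scheduler.start()  # blocks the process and fires the job on schedule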
def update_local_db(local_doc_store, remote_doc_store):
    """This method runs in serial as follows:

    1. Get document ids from the remote and local db
    2. Check whether there are new documents; if yes:
    3. Write the new documents to the local db
    4. Update embeddings on the small FAISS index
    5. Update vector ids on the local db
    6. Run the sequential retriever to pre-calculate the similarity scores
       and update the local db metadata
    """
    if not local_doc_store or not remote_doc_store:
        logger.warning("DB connection not initialized, trying re-connect...")
        local_doc_store = get_connection(LOCAL_DB_URI, CAND_DIM)
        remote_doc_store = get_connection(POSTGRES_URI, CAND_DIM)
        if not local_doc_store or not remote_doc_store:
            logger.error("DB initialization failed, quit local_update...")
            return

    remote_reindex = not os.path.exists(REMOTE_IDX_PATH)
    if remote_reindex:
        new_ids = remote_doc_store.get_document_ids(
            from_time=datetime.now() - timedelta(days=365), index=INDEX)
    else:
        new_ids = remote_doc_store.get_document_ids(
            from_time=datetime.now() - timedelta(minutes=3), index=INDEX)

    if not new_ids:
        logger.info(f"No new updates in local db at {datetime.now()}")
        return

    local_ids = local_doc_store.get_document_ids(index=INDEX)

    # Filter existing ids in local out of recently updated ids from remote db
    new_ids = sorted([_id for _id in new_ids if _id not in local_ids])
    docs = remote_doc_store.get_documents_by_id(new_ids, index=INDEX)
    logger.info(f"Retrieved {len(docs)} docs at {datetime.now()}")

    local_doc_store.write_documents(docs)
    logger.info("Stored documents to local db")

    local_retriever = Retriever(
        document_store=local_doc_store,
        candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
        retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
    )
    remote_retriever = Retriever(
        document_store=remote_doc_store,
        candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
        retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
    )

    if not os.path.exists(CAND_PATH) or not os.path.exists(RTRV_PATH):
        remote_retriever.train_candidate_vectorizer(retrain=True, save_path=CAND_PATH)
        remote_retriever.train_retriever_vectorizer(retrain=True, save_path=RTRV_PATH)
        logger.info("Vectorizers retrained")
    else:
        remote_retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)
        remote_retriever.train_retriever_vectorizer(retrain=False, save_path=RTRV_PATH)
    local_retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)
    local_retriever.train_retriever_vectorizer(retrain=False, save_path=RTRV_PATH)
    logger.info("Vectorizers loaded")

    local_retriever.update_embeddings(retrain=True,
                                      save_path=LOCAL_IDX_PATH,
                                      sql_url=LOCAL_DB_URI)
    remote_retriever.update_embeddings(retrain=remote_reindex,
                                       save_path=REMOTE_IDX_PATH,
                                       sql_url=POSTGRES_URI)
    logger.info("Embeddings updated")

    docs = [doc.text for doc in docs]
    local_results = local_retriever.batch_retrieve(docs)
    if remote_reindex:
        remote_results = local_results.copy()
    else:
        remote_results = remote_retriever.batch_retrieve(docs)

    for _id, l, r in tqdm(zip(new_ids, local_results, remote_results),
                          total=len(new_ids)):
        local_sim = l.get("similarity_score", 0)
        remote_sim = r.get("similarity_score", 0)
        if local_sim > HARD_SIM_THRESHOLD and remote_sim > HARD_SIM_THRESHOLD:
            if local_sim >= remote_sim:
                sim_data = {
                    "sim_score": local_sim,
                    "similar_to": l["retrieve_result"]
                }
            else:
                sim_data = {
                    "sim_score": remote_sim,
                    "similar_to": r["retrieve_result"]
                }
            remote_doc_store.update_document_meta(_id, sim_data)
    logger.info("Similarity scores updated into metadata")
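
# get_connection is called above but not defined in this snippet. A plausible
# sketch, assuming it merely wraps FAISSDocumentStore construction and returns
# None on failure so callers can test truthiness (the real helper may differ):
def get_connection(sql_url, vector_dim):
    try:
        return FAISSDocumentStore(sql_url=sql_url, vector_dim=vector_dim)
    except Exception as exc:
        logger.error(f"Could not connect to {sql_url}: {exc}")
        return None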
def test_train_candidate_vectorizer(mocker):
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    save_path = os.path.join(parent_cwd, "cand_vector.test_pkl")

    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    if os.path.exists(save_path):
        os.remove(save_path)
    retriever.train_candidate_vectorizer(save_path=save_path)
    assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
    assert retriever.candidate_vectorizer.is_trained
    assert os.path.exists(save_path)

    # Loading a saved vectorizer whose dimension does not match the document
    # store's vector dimension must fail
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}", vector_dim=100)
    retriever = Retriever(document_store=document_store)
    with pytest.raises(ValueError):
        retriever.train_candidate_vectorizer(save_path=save_path, retrain=False)

    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}", vector_dim=128)
    retriever = Retriever(document_store=document_store)
    retriever.train_candidate_vectorizer(save_path=save_path, retrain=False)
    assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
    assert retriever.candidate_vectorizer.is_trained

    # Retraining without a candidate vectorizer must fail
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    retriever = Retriever(document_store=document_store)
    with pytest.raises(ValueError):
        retriever.train_candidate_vectorizer(save_path=save_path)

    # Training on an empty document store must fail
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    mocker.patch(
        'modules.ml.document_store.faiss.FAISSDocumentStore.get_all_documents',
        return_value=[])
    with pytest.raises(ValueError):
        retriever.train_candidate_vectorizer()

    # Training with explicitly provided documents
    url = f"sqlite:///{temp_test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(20)
    training_documents = [query_result.text for query_result in query.all()]

    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    retriever.train_candidate_vectorizer(training_documents=training_documents)
    assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
    assert retriever.candidate_vectorizer.is_trained

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_sequential_retrieve():
    number_input_doc = 3

    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    # Get documents from the database as input for retriever
    print("Query sample input from database")
    url = f"sqlite:///{temp_test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(number_input_doc)

    # Get text from query to input
    input_docs = [query.all()[i].text for i in range(number_input_doc)]

    # Init expected results
    expected_text_results = []
    expected_score_results = []
    for i in range(number_input_doc):
        expected_text_results.append(input_docs[i])
        expected_score_results.append(1)

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Subtest 1: retrieving before embeddings are updated must raise ValueError
    with pytest.raises(ValueError):
        results = retriever.sequential_retrieve(input_docs, top_k_candidates=10)

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    print("Retrieving")

    # Test without processing input data
    num_doc_before = retriever.document_store.get_document_count()
    results = retriever.sequential_retrieve(input_docs,
                                            top_k_candidates=10,
                                            processe_query_docs=False)
    for i in range(number_input_doc):
        assert results[i]["query_doc"] == input_docs[i]
        assert results[i]["retrieve_result"] == expected_text_results[i]
        assert results[i]["similarity_score"] == expected_score_results[i]
    num_doc_after = retriever.document_store.get_document_count()
    assert num_doc_after == num_doc_before + number_input_doc

    # Test with processing input data
    num_doc_before = retriever.document_store.get_document_count()
    results = retriever.sequential_retrieve(input_docs,
                                            top_k_candidates=10,
                                            processe_query_docs=True)
    for i in range(number_input_doc):
        assert results[i]["query_doc"] == input_docs[i]
        assert results[i]["retrieve_result"] == expected_text_results[i]
        assert results[i]["similarity_score"] == expected_score_results[i]
    num_doc_after = retriever.document_store.get_document_count()
    assert num_doc_after == num_doc_before + number_input_doc

    # Test with meta_docs
    num_doc_before = retriever.document_store.get_document_count()
    results = retriever.sequential_retrieve(input_docs,
                                            top_k_candidates=10,
                                            processe_query_docs=True,
                                            meta_docs=[{
                                                "author": "duclt",
                                                "task": ["test", "retrieve", "query"],
                                            }])
    for i in range(number_input_doc):
        assert results[i]["query_doc"] == input_docs[i]
        assert results[i]["retrieve_result"] == expected_text_results[i]
        assert results[i]["similarity_score"] == expected_score_results[i]
    num_doc_after = retriever.document_store.get_document_count()
    assert num_doc_after == num_doc_before + number_input_doc

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_batch_retriever():
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    # Get a document from the database as input for retriever
    print("Query sample input from database")
    url = f"sqlite:///{temp_test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(5)

    # Get text from query to input
    input_doc = query.all()[0].text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    # Init expected result
    expected_text_result = input_doc
    expected_score_result = 1

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Subtest 1: retrieving before embeddings are updated must raise ValueError
    with pytest.raises(ValueError):
        result = retriever.batch_retrieve([input_doc], top_k_candidates=10)

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    print("Retrieving")

    # Test without processing input data
    result = retriever.batch_retrieve([input_doc], top_k_candidates=10)
    # print the retrieved doc
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))
    assert result[0]["query_doc"] == input_doc
    assert result[0]["retrieve_result"] == expected_text_result
    assert result[0]["similarity_score"] == expected_score_result

    # Test with processing input data
    result = retriever.batch_retrieve([input_doc],
                                      top_k_candidates=10,
                                      processe_query_docs=True)
    # print the retrieved doc
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))
    assert result[0]["query_doc"] == input_doc
    assert result[0]["retrieve_result"] == expected_text_result
    assert result[0]["similarity_score"] == expected_score_result

    # Remove temp test database
    os.remove(temp_test_db_path)
) AS url_table
INNER JOIN "document" d ON url_table.document_id = d.id
WHERE CAST(levenshtein('{0}', url) AS DECIMAL) / CAST(length(url) AS DECIMAL) < {1}
ORDER BY levenshtein('{0}', url)
LIMIT 1
"""

# Default methods
preprocessor = ViPreProcessor(split_by="sentence")
document_store = FAISSDocumentStore(sql_url=POSTGRES_URI,
                                    vector_dim=CAND_DIM,
                                    index_buffer_size=5000)
retriever = Retriever(
    document_store=document_store,
    candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
    retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
)
retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)

remote_doc_store = FAISSDocumentStore(sql_url=POSTGRES_URI,
                                      vector_dim=CAND_DIM)

tags_metadata = [
    {
        "name": "get",
        "description": "Currently does nothing"
    },
    {
        "name": "compare",
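
# The SQL near the top of this snippet keeps a URL only when its Levenshtein
# distance to the query URL, normalized by the stored URL's length, falls
# below a threshold. The same check in pure Python, for illustration only
# (the service relies on Postgres's levenshtein() function; the function
# names here are hypothetical):
def levenshtein(a: str, b: str) -> int:
    # Classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]

def is_near_duplicate_url(query_url: str, url: str, threshold: float) -> bool:
    # Mirrors: CAST(levenshtein(query, url) AS DECIMAL) / length(url) < threshold
    return levenshtein(query_url, url) / len(url) < threshold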
def update_local_db(local_doc_store, remote_doc_store):
    """This method runs in serial as follows:

    - Compares the list of `document_id` between local and remote database
    - Fetches and writes new documents into local database
    - Updates embeddings and vector ids on small FAISS index
    - Runs batch retriever to pre-calculate the similarity scores and
      updates metadata on remote database
    """
    if not local_doc_store or not remote_doc_store:
        logger.warning("DB connection not initialized, trying to re-connect...")
        local_doc_store = get_connection(LOCAL_DB_URI, CAND_DIM)
        remote_doc_store = get_connection(POSTGRES_URI, CAND_DIM)
        if not local_doc_store or not remote_doc_store:
            logger.error("DB initialization failed, quit local_update...")
            return

    remote_reindex = not os.path.exists(REMOTE_IDX_PATH)
    now = datetime.now()
    if remote_reindex:
        new_ids = remote_doc_store.get_document_ids(
            from_time=now - timedelta(days=365), index=INDEX)
    else:
        new_ids = remote_doc_store.get_document_ids(
            from_time=now - timedelta(days=1), index=INDEX)

    local_ids = local_doc_store.get_document_ids(index=INDEX)

    # Filter existing ids in local out of recently updated ids from remote db
    new_ids = sorted([_id for _id in new_ids if _id not in local_ids])
    if not new_ids:
        logger.info("No new updates in local db")
        return

    docs = remote_doc_store.get_documents_by_id(new_ids, index=INDEX)
    logger.info(f"Retrieved {len(docs)} docs")

    local_doc_store.write_documents(docs)
    logger.info(f"Stored {len(docs)} docs to local db")

    local_retriever = Retriever(
        document_store=local_doc_store,
        candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
        retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
    )
    remote_retriever = Retriever(
        document_store=remote_doc_store,
        candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
        retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
    )

    if not os.path.exists(CAND_PATH) or not os.path.exists(RTRV_PATH):
        remote_retriever.train_candidate_vectorizer(retrain=True, save_path=CAND_PATH)
        remote_retriever.train_retriever_vectorizer(retrain=True, save_path=RTRV_PATH)
        logger.info("Vectorizers retrained")

    local_retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)
    local_retriever.train_retriever_vectorizer(retrain=False, save_path=RTRV_PATH)
    logger.info("Vectorizers loaded")

    local_retriever.update_embeddings(retrain=True,
                                      save_path=LOCAL_IDX_PATH,
                                      sql_url=LOCAL_DB_URI)
    logger.info("Embeddings updated")

    results = local_retriever.batch_retrieve(docs)

    # Split payloads into chunks to reduce pressure on the database
    results_chunks = list(chunks(results, 1000))
    for i in tqdm(range(len(results_chunks)), desc="Updating meta..... "):
        id_meta = list()
        for r in results_chunks[i]:
            rank = "_".join(list(r.keys())[-1].split("_")[-2:])
            sim_score = r.get(f"sim_score_{rank}", 0)
            if sim_score > HARD_SIM_THRESHOLD:
                sim_data = {
                    "document_id": r["document_id"],
                    f"sim_score_{rank}": sim_score,
                    f"similar_to_{rank}": r[f"sim_document_id_{rank}"],
                }
                id_meta.append(sim_data)
        if id_meta:
            remote_doc_store.update_documents_meta(id_meta=id_meta)
    logger.info("Similarity scores updated into metadata")

    consolidate_sim_docs(remote_doc_store)
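
# The chunks() helper used in update_local_db above is not defined in this
# snippet. A minimal sketch of the usual implementation — yield successive
# fixed-size slices of a list (the real helper may differ):
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]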
def test_retriever_with_database():
    cand_dim = 768
    rtrv_dim = 1024
    sql_url = "postgresql+psycopg2://user:pwd@host/topdup_articles"

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(cand_dim)
    rtrv_vectorizer = TfidfDocVectorizer(rtrv_dim)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=sql_url,
                                        vector_dim=cand_dim,
                                        index_buffer_size=5000)

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    if os.path.exists(os.path.join(parent_cwd, "cand.bin")):
        print("Loading vectorizers")
        retriever.train_candidate_vectorizer(
            retrain=False, save_path=os.path.join(parent_cwd, "cand.bin"))
        retriever.train_retriever_vectorizer(
            retrain=False, save_path=os.path.join(parent_cwd, "rtrv.bin"))
    else:
        print("Training vectorizers")
        retriever.train_candidate_vectorizer(
            retrain=True, save_path=os.path.join(parent_cwd, "cand.bin"))
        retriever.train_retriever_vectorizer(
            retrain=True, save_path=os.path.join(parent_cwd, "rtrv.bin"))

    # Update trained embeddings to index of FAISSDocumentStore
    if os.path.exists(os.path.join(parent_cwd, "index.bin")):
        print("Loading index of FAISSDocumentStore")
        retriever.update_embeddings(
            retrain=False,
            save_path=os.path.join(parent_cwd, "index.bin"),
            sql_url=sql_url,
        )
    else:
        print("Updating embeddings to index of FAISSDocumentStore")
        retriever.update_embeddings(
            retrain=True,
            save_path=os.path.join(parent_cwd, "index.bin"),
            sql_url=sql_url,
        )

    # Get a pair of duplicated articles from topdup.xyz to test
    input_doc = document_store.get_document_by_id(
        id="76a0874a-b0db-477c-a0ca-9e65b8ccf2f3").text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    print("Retrieving")
    result = retriever.batch_retrieve([input_doc], top_k_candidates=10)
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))  # print the retrieved doc