def main():
    cwd = Path(__file__).parent
    if not os.path.exists(os.path.join(cwd, "post_dataset.pkl")):
        print("Downloading sample data of TopDup")
        urllib.request.urlretrieve(
            "https://storage.googleapis.com/topdup/dataset/post_dataset.pkl",
            os.path.join(cwd, "post_dataset.pkl"),
        )

    print("Preprocessing documents")
    processor = ViPreProcessor()
    with open(os.path.join(cwd, "post_dataset.pkl"), "rb") as f:
        data = pickle.load(f)
    docs = list()
    for d in tqdm(data[:1000]):
        content, meta = data_prep(jsonpickle.loads(d))
        doc = processor.clean({"text": content})
        for m in meta.keys():
            if isinstance(meta[m], list):  # serialize list
                meta[m] = "|".join(meta[m])
        doc["meta"] = meta
        docs.append(doc)

    print("Ingesting data to SQLite database")
    db_path = os.path.join(cwd, "topdup.db")
    if os.path.exists(db_path):
        os.remove(db_path)
    with sqlite3.connect(db_path):
        document_store = FAISSDocumentStore(sql_url=f"sqlite:///{db_path}")
        document_store.write_documents(docs)
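# data_prep is not defined in this script. Below is a minimal sketch of the
# contract the loop above assumes: each jsonpickle-decoded record yields the
# article text plus a metadata dict. The field names here are illustrative
# assumptions, not the real TopDup schema.
def data_prep(record: dict):
    content = record.get("content", "")
    meta = {k: v for k, v in record.items() if k != "content"}
    return content, meta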
def test_construction():
    # Constructing a Retriever without a document store must fail
    with pytest.raises(ValueError):
        retriever = Retriever()
        print(type(retriever))

    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    # A mis-dimensioned candidate vectorizer (100 vs. the usual 128) must be rejected
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(100)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    with pytest.raises(ValueError):
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=cand_vectorizer,
            retriever_vectorizer=rtrv_vectorizer,
        )
        assert isinstance(retriever, Retriever)  # unreachable: the constructor raises

    # A matching configuration constructs successfully
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    assert isinstance(retriever, Retriever)

    # Remove temp test database
    os.remove(temp_test_db_path)
def get_connection(uri: str, vector_dim: int):
    try:
        conn = FAISSDocumentStore(sql_url=uri, vector_dim=vector_dim)
        return conn
    except Exception as e:
        logger.error(e)
        return None
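# A minimal usage sketch for get_connection, assuming a local SQLite database
# file; get_document_count is the same document-store call the tests in this
# section rely on.
def example_get_connection():
    conn = get_connection(uri="sqlite:///topdup.db", vector_dim=128)
    if conn is None:
        raise SystemExit("Could not connect to the FAISS document store")
    print(f"{conn.get_document_count()} documents available")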
def test_update_embeddings():
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    with pytest.raises(ValueError):
        retriever.update_embeddings()

    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    retriever.train_candidate_vectorizer()
    save_path = os.path.join(parent_cwd, "document_store.test_pkl")
    retriever.update_embeddings(save_path=save_path)

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_retriever_with_database():
    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    # Get a document from the database as input for the retriever
    print("Query sample input from database")
    url = f"sqlite:///{test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(5)

    # Get text from query to input
    input_doc = query.all()[0].text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    # Init expected result
    expected_text_result = input_doc
    expected_score_result = 1

    print("Retrieving")
    result = retriever.retrieve([input_doc], top_k_candidates=10)
    # print the retrieved doc
    print(" ".join(result[input_doc]["retrieve_result"].split(" ")[:50]))
    assert result[input_doc]["retrieve_result"] == expected_text_result
    assert result[input_doc]["similarity_score"] == expected_score_result
def test_get_candidates():
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    with pytest.raises(ValueError):
        retriever.get_candidates(query_docs=["Test candidate"])

    # Remove temp test database
    os.remove(temp_test_db_path)
def update_embeddings(self, retrain: bool = True, save_path: str = None, sql_url: str = None):
    """Updates embeddings of documents in `document_store` with the candidate
    vectorizer, or loads a previously saved FAISS index when `retrain` is False.
    """
    if retrain:
        if not self.candidate_vectorizer.is_trained:
            raise ValueError(
                "Candidate vectorizer is not trained yet."
                " Try to call train_candidate_vectorizer first.")
        self.document_store.update_embeddings(self.candidate_vectorizer)
        if save_path:
            self.document_store.save(file_path=save_path)
    else:
        self.document_store = FAISSDocumentStore.load(
            faiss_file_path=save_path, sql_url=sql_url)
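# Usage sketch for update_embeddings, mirroring how the tests in this section
# call it. The store URL, index path, and vector dimensions are illustrative
# assumptions; imports for Retriever and TfidfDocVectorizer are omitted
# because their module paths are not shown here.
def example_update_embeddings():
    document_store = FAISSDocumentStore(sql_url="sqlite:///topdup.db")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=TfidfDocVectorizer(128),
        retriever_vectorizer=TfidfDocVectorizer(256),
    )
    # First run: embed all stored documents and persist the FAISS index.
    retriever.train_candidate_vectorizer()
    retriever.update_embeddings(retrain=True, save_path="index.bin")
    # Later runs: skip re-embedding and load the saved index instead.
    retriever.update_embeddings(
        retrain=False, save_path="index.bin", sql_url="sqlite:///topdup.db")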
def test_train_candidate_vectorizer(mocker):
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    save_path = os.path.join(parent_cwd, "cand_vector.test_pkl")

    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train from scratch and save the trained vectorizer
    if os.path.exists(save_path):
        os.remove(save_path)
    retriever.train_candidate_vectorizer(save_path=save_path)
    assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
    assert retriever.candidate_vectorizer.is_trained
    assert os.path.exists(save_path)

    # Loading the saved vectorizer into a store with a mismatched vector_dim must fail
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}", vector_dim=100)
    retriever = Retriever(document_store=document_store)
    with pytest.raises(ValueError):
        retriever.train_candidate_vectorizer(save_path=save_path, retrain=False)

    # Loading the saved vectorizer with a matching vector_dim must succeed
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{temp_test_db_path}", vector_dim=128)
    retriever = Retriever(document_store=document_store)
    retriever.train_candidate_vectorizer(save_path=save_path, retrain=False)
    assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
    assert retriever.candidate_vectorizer.is_trained

    # Retraining without a candidate vectorizer must fail
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    retriever = Retriever(document_store=document_store)
    with pytest.raises(ValueError):
        retriever.train_candidate_vectorizer(save_path=save_path)

    # Training on an empty document store must fail
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    mocker.patch(
        'modules.ml.document_store.faiss.FAISSDocumentStore.get_all_documents',
        return_value=[])
    with pytest.raises(ValueError):
        retriever.train_candidate_vectorizer()

    # Training on explicitly passed documents
    url = f"sqlite:///{temp_test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(20)
    training_documents = [query_result.text for query_result in query.all()]

    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )
    retriever.train_candidate_vectorizer(training_documents=training_documents)
    assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
    assert retriever.candidate_vectorizer.is_trained

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_sequential_retrieve():
    number_input_doc = 3

    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    # Get documents from the database as input for the retriever
    print("Query sample input from database")
    url = f"sqlite:///{temp_test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(number_input_doc)

    # Get text from query to input
    input_docs = [query.all()[i].text for i in range(number_input_doc)]

    # Init expected results
    expected_text_results = []
    expected_score_results = []
    for i in range(number_input_doc):
        expected_text_results.append(input_docs[i])
        expected_score_results.append(1)

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Subtest 1: retrieving before the embeddings are updated must raise an error
    with pytest.raises(ValueError):
        results = retriever.sequential_retrieve(input_docs, top_k_candidates=10)

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    print("Retrieving")

    # Test without processing the input data
    num_doc_before = retriever.document_store.get_document_count()
    results = retriever.sequential_retrieve(
        input_docs, top_k_candidates=10, processe_query_docs=False)
    for i in range(number_input_doc):
        assert results[i]["query_doc"] == input_docs[i]
        assert results[i]["retrieve_result"] == expected_text_results[i]
        assert results[i]["similarity_score"] == expected_score_results[i]
    num_doc_after = retriever.document_store.get_document_count()
    assert num_doc_after == num_doc_before + number_input_doc

    # Test with processing the input data
    num_doc_before = retriever.document_store.get_document_count()
    results = retriever.sequential_retrieve(
        input_docs, top_k_candidates=10, processe_query_docs=True)
    for i in range(number_input_doc):
        assert results[i]["query_doc"] == input_docs[i]
        assert results[i]["retrieve_result"] == expected_text_results[i]
        assert results[i]["similarity_score"] == expected_score_results[i]
    num_doc_after = retriever.document_store.get_document_count()
    assert num_doc_after == num_doc_before + number_input_doc

    # Test with meta_docs
    num_doc_before = retriever.document_store.get_document_count()
    results = retriever.sequential_retrieve(
        input_docs,
        top_k_candidates=10,
        processe_query_docs=True,
        meta_docs=[{
            "author": "duclt",
            "task": ["test", "retrieve", "query"],
        }],
    )
    for i in range(number_input_doc):
        assert results[i]["query_doc"] == input_docs[i]
        assert results[i]["retrieve_result"] == expected_text_results[i]
        assert results[i]["similarity_score"] == expected_score_results[i]
    num_doc_after = retriever.document_store.get_document_count()
    assert num_doc_after == num_doc_before + number_input_doc

    # Remove temp test database
    os.remove(temp_test_db_path)
def test_batch_retriever():
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)

    # Get a document from the database as input for the retriever
    print("Query sample input from database")
    url = f"sqlite:///{temp_test_db_path}"
    engine = create_engine(url)
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(DocumentORM).filter_by().limit(5)

    # Get text from query to input
    input_doc = query.all()[0].text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    # Init expected result
    expected_text_result = input_doc
    expected_score_result = 1

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{temp_test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Subtest 1: retrieving before the embeddings are updated must raise an error
    with pytest.raises(ValueError):
        result = retriever.batch_retrieve([input_doc], top_k_candidates=10)

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    print("Retrieving")

    # Test without processing the input data
    result = retriever.batch_retrieve([input_doc], top_k_candidates=10)
    # print the retrieved doc
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))
    assert result[0]["query_doc"] == input_doc
    assert result[0]["retrieve_result"] == expected_text_result
    assert result[0]["similarity_score"] == expected_score_result

    # Test with processing the input data
    result = retriever.batch_retrieve(
        [input_doc], top_k_candidates=10, processe_query_docs=True)
    # print the retrieved doc
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))
    assert result[0]["query_doc"] == input_doc
    assert result[0]["retrieve_result"] == expected_text_result
    assert result[0]["similarity_score"] == expected_score_result

    # Remove temp test database
    os.remove(temp_test_db_path)
    FROM meta m
    WHERE lower(name) IN ('href', 'url')
) AS url_table
INNER JOIN "document" d ON url_table.document_id = d.id
WHERE CAST(levenshtein('{0}', url) AS DECIMAL) / CAST(length(url) AS DECIMAL) < {1}
ORDER BY levenshtein('{0}', url)
LIMIT 1
"""

# Default methods
preprocessor = ViPreProcessor(split_by="sentence")
document_store = FAISSDocumentStore(
    sql_url=POSTGRES_URI, vector_dim=CAND_DIM, index_buffer_size=5000)
retriever = Retriever(
    document_store=document_store,
    candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
    retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
)
retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)
remote_doc_store = FAISSDocumentStore(sql_url=POSTGRES_URI, vector_dim=CAND_DIM)

tags_metadata = [
    {
        "name": "get",
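# The levenshtein function in the query above is not built into PostgreSQL;
# it comes from the fuzzystrmatch extension, which must be enabled on the
# target database. A sketch of rendering and running the template, assuming
# the template string is named similar_url_query (the actual variable name is
# not visible in this fragment) and a relative-distance threshold of 0.2.
def example_find_similar_url(query_url: str):
    from sqlalchemy import create_engine, text

    engine = create_engine(POSTGRES_URI)
    with engine.begin() as conn:
        # One-time setup; requires sufficient privileges on the database.
        conn.execute(text("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch"))
        # {0} is the URL to match, {1} the maximum levenshtein/length ratio.
        rendered = similar_url_query.format(query_url, 0.2)
        return conn.execute(text(rendered)).fetchone()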
def test_retriever_with_database():
    cand_dim = 768
    rtrv_dim = 1024
    sql_url = "postgresql+psycopg2://user:pwd@host/topdup_articles"

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(cand_dim)
    rtrv_vectorizer = TfidfDocVectorizer(rtrv_dim)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(
        sql_url=sql_url, vector_dim=cand_dim, index_buffer_size=5000)

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for the two phases of searching,
    # reusing saved vectorizers when they exist
    if os.path.exists(os.path.join(parent_cwd, "cand.bin")):
        print("Loading vectorizers")
        retriever.train_candidate_vectorizer(
            retrain=False, save_path=os.path.join(parent_cwd, "cand.bin"))
        retriever.train_retriever_vectorizer(
            retrain=False, save_path=os.path.join(parent_cwd, "rtrv.bin"))
    else:
        print("Training vectorizers")
        retriever.train_candidate_vectorizer(
            retrain=True, save_path=os.path.join(parent_cwd, "cand.bin"))
        retriever.train_retriever_vectorizer(
            retrain=True, save_path=os.path.join(parent_cwd, "rtrv.bin"))

    # Update trained embeddings to the index of FAISSDocumentStore,
    # reusing a saved index when it exists
    if os.path.exists(os.path.join(parent_cwd, "index.bin")):
        print("Loading index of FAISSDocumentStore")
        retriever.update_embeddings(
            retrain=False,
            save_path=os.path.join(parent_cwd, "index.bin"),
            sql_url=sql_url,
        )
    else:
        print("Updating embeddings to index of FAISSDocumentStore")
        retriever.update_embeddings(
            retrain=True,
            save_path=os.path.join(parent_cwd, "index.bin"),
            sql_url=sql_url,
        )

    # Get a pair of duplicated articles from topdup.xyz to test
    input_doc = document_store.get_document_by_id(
        id="76a0874a-b0db-477c-a0ca-9e65b8ccf2f3").text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    print("Retrieving")
    result = retriever.batch_retrieve([input_doc], top_k_candidates=10)
    # print the retrieved doc
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))