示例#1
0
def main():
    """Download the TopDup sample dataset, preprocess it and ingest it into SQLite.

    Steps:
      1. Fetch ``post_dataset.pkl`` next to this script unless it already exists.
      2. Clean the first 1000 posts with ``ViPreProcessor`` and attach their
         metadata; list-valued metadata is flattened to a ``|``-joined string.
      3. Recreate ``topdup.db`` from scratch and write the documents into it
         through ``FAISSDocumentStore``.
    """
    cwd = Path(__file__).parent
    dataset_path = os.path.join(cwd, "post_dataset.pkl")
    if not os.path.exists(dataset_path):
        print("Downloading sample data of TopDup")
        urllib.request.urlretrieve(
            "https://storage.googleapis.com/topdup/dataset/post_dataset.pkl",
            dataset_path,
        )

    print("Preprocessing documents")
    processor = ViPreProcessor()
    # NOTE(security): unpickling executes arbitrary code; this is only
    # acceptable because the file comes from the project's own bucket.
    with open(dataset_path, "rb") as dataset_file:
        data = pickle.load(dataset_file)

    docs = []
    for d in tqdm(data[:1000]):
        content, meta = data_prep(jsonpickle.loads(d))
        doc = processor.clean({"text": content})
        # Only values of existing keys are replaced, so iterating over
        # items() while assigning is safe here.
        for key, value in meta.items():
            if isinstance(value, list):  # serialize list
                meta[key] = "|".join(value)
        doc["meta"] = meta
        docs.append(doc)

    print("Ingesting data to SQLite database")
    db_path = os.path.join(cwd, "topdup.db")
    if os.path.exists(db_path):
        os.remove(db_path)
    # sqlite3's connection context manager only manages transactions; it does
    # not close the connection, so close it explicitly.
    connection = sqlite3.connect(db_path)
    try:
        document_store = FAISSDocumentStore(sql_url=f"sqlite:///{db_path}")
        document_store.write_documents(docs)
    finally:
        connection.close()
示例#2
0
def test_construction():
    """Check ``Retriever`` construction with missing, rejected and valid arguments."""
    # A retriever built without any arguments must be rejected.
    with pytest.raises(ValueError):
        Retriever()

    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    try:
        sql_url = f"sqlite:///{temp_test_db_path}"

        # A candidate vectorizer of size 100 is expected to be rejected
        # (presumably because it does not match the store's vector dimension
        # -- confirm against Retriever.__init__).
        document_store = FAISSDocumentStore(sql_url=sql_url)
        with pytest.raises(ValueError):
            Retriever(
                document_store=document_store,
                candidate_vectorizer=TfidfDocVectorizer(100),
                retriever_vectorizer=TfidfDocVectorizer(256),
            )

        # The 128/256 combination must construct successfully.
        document_store = FAISSDocumentStore(sql_url=sql_url)
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        assert isinstance(retriever, Retriever)
    finally:
        # Remove temp test database even when an assertion above fails.
        os.remove(temp_test_db_path)
示例#3
0
def get_connection(uri: str, vector_dim: int):
    """Create a ``FAISSDocumentStore`` for ``uri``, or return ``None`` on failure.

    Args:
        uri: Passed to ``FAISSDocumentStore`` as ``sql_url``.
        vector_dim: Passed to ``FAISSDocumentStore`` as ``vector_dim``.

    Returns:
        The new document store, or ``None`` when construction raised; the
        exception is logged through ``logger.error`` rather than propagated.
    """
    try:
        store = FAISSDocumentStore(sql_url=uri, vector_dim=vector_dim)
    except Exception as err:
        logger.error(err)
        return None
    else:
        return store
示例#4
0
def test_update_embeddings():
    """Check that ``update_embeddings`` needs a trained candidate vectorizer."""
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    try:
        document_store = FAISSDocumentStore(
            sql_url=f"sqlite:///{temp_test_db_path}")

        # Untrained candidate vectorizer -> updating embeddings must fail.
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        with pytest.raises(ValueError):
            retriever.update_embeddings()

        # Fresh vectorizers, trained this time -> the update must succeed and
        # the store is saved to `save_path`.
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        retriever.train_candidate_vectorizer()
        save_path = os.path.join(parent_cwd, "document_store.test_pkl")
        retriever.update_embeddings(save_path=save_path)
    finally:
        # Remove temp test database even when the test fails.
        os.remove(temp_test_db_path)
示例#5
0
def test_retriever_with_database():
    """Retrieve a document that is already stored and expect an exact self-match.

    NOTE(review): this test works directly on ``test_db_path`` (no temporary
    copy), so ``update_embeddings`` operates on the shared fixture database.
    """
    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(128)
    rtrv_vectorizer = TfidfDocVectorizer(256)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=f"sqlite:///{test_db_path}")

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching
    print("Training vectorizers")
    retriever.train_candidate_vectorizer()
    retriever.train_retriever_vectorizer()

    # Update trained embeddings to DocumentStore
    print("Updating embeddings to DocumentStore")
    retriever.update_embeddings()

    # Get a document from the database as input for retriever
    print("Query sample input from database")
    engine = create_engine(f"sqlite:///{test_db_path}")
    ORMBase.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        # Read the text while the session is open, then release it.
        input_doc = session.query(DocumentORM).filter_by().limit(5).all()[0].text
    finally:
        session.close()
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    # A stored document queried against itself should come back unchanged
    # with a similarity score of 1.
    expected_text_result = input_doc
    expected_score_result = 1

    print("Retrieving")

    result = retriever.retrieve([input_doc], top_k_candidates=10)
    print(
        " ".join(result[input_doc]["retrieve_result"].split(" ")[:50])
    )  # print the retrieved doc
    assert result[input_doc]["retrieve_result"] == expected_text_result
    assert result[input_doc]["similarity_score"] == expected_score_result
示例#6
0
def test_get_candidates():
    """Check that ``get_candidates`` raises on a freshly built retriever.

    Nothing has been trained or indexed at the point of the call.
    """
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    try:
        document_store = FAISSDocumentStore(
            sql_url=f"sqlite:///{temp_test_db_path}")
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        with pytest.raises(ValueError):
            retriever.get_candidates(query_docs=["Test candidate"])
    finally:
        # Remove temp test database even when the test fails.
        os.remove(temp_test_db_path)
示例#7
0
    def update_embeddings(self,
                          retrain: bool = True,
                          save_path: str = None,
                          sql_url: str = None):
        """Update the embeddings in `document_store`, or reload a saved store.

        Args:
            retrain: When True, the embeddings of `self.document_store` are
                updated with `self.candidate_vectorizer`. When False,
                `self.document_store` is replaced by the store returned from
                `FAISSDocumentStore.load`.
            save_path: With `retrain=True`, optional path the updated store is
                saved to. With `retrain=False`, required path handed to
                `FAISSDocumentStore.load` as `faiss_file_path`.
            sql_url: Forwarded to `FAISSDocumentStore.load`; only used when
                `retrain` is False.

        Raises:
            ValueError: If `retrain` is True and the candidate vectorizer is
                not trained, or if `retrain` is False and no `save_path` is
                given.
        """
        if not retrain:
            # Fail early with a clear message instead of handing None to the
            # loader.
            if not save_path:
                raise ValueError(
                    "save_path is required to load a saved document store"
                    " when retrain is False.")
            self.document_store = FAISSDocumentStore.load(
                faiss_file_path=save_path, sql_url=sql_url)
            return

        if not self.candidate_vectorizer.is_trained:
            raise ValueError(
                "Candidate vectorizer is not trained yet."
                " Try to call train_candidate_vectorizer first.")

        self.document_store.update_embeddings(self.candidate_vectorizer)
        if save_path:
            self.document_store.save(file_path=save_path)
示例#8
0
def test_train_candidate_vectorizer(mocker):
    """Exercise ``Retriever.train_candidate_vectorizer`` in its main modes.

    Covers training from the store and saving, loading a saved vectorizer
    (rejected with ``vector_dim=100``, accepted with ``vector_dim=128``),
    training without any vectorizer, training when the store returns no
    documents, and training from explicitly passed documents.
    """
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    try:
        sql_url = f"sqlite:///{temp_test_db_path}"
        save_path = os.path.join(parent_cwd, "cand_vector.test_pkl")

        # 1. Train on the documents of the store and save the vectorizer.
        document_store = FAISSDocumentStore(sql_url=sql_url)
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        if os.path.exists(save_path):
            os.remove(save_path)
        retriever.train_candidate_vectorizer(save_path=save_path)
        assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
        assert retriever.candidate_vectorizer.is_trained
        # The original evaluated this expression without asserting it.
        assert os.path.exists(save_path)

        # 2. Loading the saved vectorizer into a store with vector_dim=100
        #    must be rejected.
        document_store = FAISSDocumentStore(sql_url=sql_url, vector_dim=100)
        retriever = Retriever(document_store=document_store)
        with pytest.raises(ValueError):
            retriever.train_candidate_vectorizer(save_path=save_path,
                                                 retrain=False)

        # 3. Loading it into a store with vector_dim=128 must work.
        document_store = FAISSDocumentStore(sql_url=sql_url, vector_dim=128)
        retriever = Retriever(document_store=document_store)
        retriever.train_candidate_vectorizer(save_path=save_path,
                                             retrain=False)
        assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
        assert retriever.candidate_vectorizer.is_trained

        # 4. Training a retriever that was given no vectorizer must fail.
        document_store = FAISSDocumentStore(sql_url=sql_url)
        retriever = Retriever(document_store=document_store)
        with pytest.raises(ValueError):
            retriever.train_candidate_vectorizer(save_path=save_path)

        # 5. Training must fail when the store returns no documents.
        document_store = FAISSDocumentStore(sql_url=sql_url)
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        mocker.patch(
            'modules.ml.document_store.faiss.FAISSDocumentStore.get_all_documents',
            return_value=[])
        with pytest.raises(ValueError):
            retriever.train_candidate_vectorizer()

        # 6. Training on explicitly supplied documents read from the database.
        engine = create_engine(sql_url)
        ORMBase.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)
        session = Session()
        try:
            rows = session.query(DocumentORM).filter_by().limit(20).all()
            training_documents = [row.text for row in rows]
        finally:
            session.close()

        document_store = FAISSDocumentStore(sql_url=sql_url)
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=TfidfDocVectorizer(128),
            retriever_vectorizer=TfidfDocVectorizer(256),
        )
        retriever.train_candidate_vectorizer(
            training_documents=training_documents)
        assert isinstance(retriever.candidate_vectorizer, DocVectorizerBase)
        assert retriever.candidate_vectorizer.is_trained
    finally:
        # Remove temp test database even when an assertion above fails.
        os.remove(temp_test_db_path)
示例#9
0
def test_sequential_retrieve():
    """Check ``Retriever.sequential_retrieve`` on documents taken from the store.

    Each query document is expected to retrieve itself with a similarity
    score of 1, and every call is expected to grow the document count of the
    store by the number of query documents.
    """
    number_input_doc = 3

    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    try:
        sql_url = f"sqlite:///{temp_test_db_path}"

        # Get documents from the database as input for retriever
        print("Query sample input from database")
        engine = create_engine(sql_url)
        ORMBase.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)
        session = Session()
        try:
            # Run the query once instead of once per index.
            rows = session.query(DocumentORM).filter_by().limit(
                number_input_doc).all()
            input_docs = [row.text for row in rows]
        finally:
            session.close()
        assert len(input_docs) == number_input_doc

        print("Init vectorizers")
        cand_vectorizer = TfidfDocVectorizer(128)
        rtrv_vectorizer = TfidfDocVectorizer(256)

        print("Init DocumentStore")
        document_store = FAISSDocumentStore(sql_url=sql_url)

        print("Init retriever")
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=cand_vectorizer,
            retriever_vectorizer=rtrv_vectorizer,
        )

        def retrieve_and_check(**retrieve_kwargs):
            """Run one retrieval and verify the results and the store growth."""
            num_doc_before = retriever.document_store.get_document_count()
            results = retriever.sequential_retrieve(input_docs,
                                                    top_k_candidates=10,
                                                    **retrieve_kwargs)
            for i, input_doc in enumerate(input_docs):
                assert results[i]["query_doc"] == input_doc
                # Expected result: the query document itself, with score 1.
                assert results[i]["retrieve_result"] == input_doc
                assert results[i]["similarity_score"] == 1
            num_doc_after = retriever.document_store.get_document_count()
            assert num_doc_after == num_doc_before + number_input_doc

        # Train vectorizers for two phases of searching
        print("Training vectorizers")
        retriever.train_candidate_vectorizer()
        retriever.train_retriever_vectorizer()

        # Subtest 1: retrieving before the embeddings are updated must raise.
        with pytest.raises(ValueError):
            retriever.sequential_retrieve(input_docs, top_k_candidates=10)

        # Update trained embeddings to DocumentStore
        print("Updating embeddings to DocumentStore")
        retriever.update_embeddings()

        print("Retrieving")

        # Test without processing input data
        retrieve_and_check(processe_query_docs=False)

        # Test with processing input data
        retrieve_and_check(processe_query_docs=True)

        # Test with meta_docs (the original literal repeated the "author" key)
        retrieve_and_check(
            processe_query_docs=True,
            meta_docs=[{
                "author": "duclt",
                "task": ["test", "retrieve", "query"],
            }],
        )
    finally:
        # Remove temp test database even when an assertion above fails.
        os.remove(temp_test_db_path)
示例#10
0
def test_batch_retriever():
    """Check ``Retriever.batch_retrieve`` on a document taken from the store.

    The query document is expected to retrieve itself with a similarity score
    of 1, both with and without ``processe_query_docs``.
    """
    # Copy new temp test database
    copyfile(test_db_path, temp_test_db_path)
    try:
        sql_url = f"sqlite:///{temp_test_db_path}"

        # Get a document from the database as input for retriever
        print("Query sample input from database")
        engine = create_engine(sql_url)
        ORMBase.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)
        session = Session()
        try:
            # Read the text while the session is open, then release it.
            input_doc = session.query(DocumentORM).filter_by().limit(
                5).all()[0].text
        finally:
            session.close()
        print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

        # Init expected result
        expected_text_result = input_doc
        expected_score_result = 1

        print("Init vectorizers")
        cand_vectorizer = TfidfDocVectorizer(128)
        rtrv_vectorizer = TfidfDocVectorizer(256)

        print("Init DocumentStore")
        document_store = FAISSDocumentStore(sql_url=sql_url)

        print("Init retriever")
        retriever = Retriever(
            document_store=document_store,
            candidate_vectorizer=cand_vectorizer,
            retriever_vectorizer=rtrv_vectorizer,
        )

        # Train vectorizers for two phases of searching
        print("Training vectorizers")
        retriever.train_candidate_vectorizer()
        retriever.train_retriever_vectorizer()

        # Subtest 1: retrieving before the embeddings are updated must raise.
        with pytest.raises(ValueError):
            retriever.batch_retrieve([input_doc], top_k_candidates=10)

        # Update trained embeddings to DocumentStore
        print("Updating embeddings to DocumentStore")
        retriever.update_embeddings()

        print("Retrieving")

        # Test without processing input data
        result = retriever.batch_retrieve([input_doc], top_k_candidates=10)
        # print the retrieved doc
        print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))

        assert result[0]["query_doc"] == input_doc
        assert result[0]["retrieve_result"] == expected_text_result
        assert result[0]["similarity_score"] == expected_score_result

        # Test with processing input data
        result = retriever.batch_retrieve([input_doc],
                                          top_k_candidates=10,
                                          processe_query_docs=True)
        # print the retrieved doc
        print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))

        assert result[0]["query_doc"] == input_doc
        assert result[0]["retrieve_result"] == expected_text_result
        assert result[0]["similarity_score"] == expected_score_result
    finally:
        # Remove temp test database even when an assertion above fails.
        os.remove(temp_test_db_path)
示例#11
0
        FROM meta m
        WHERE lower(name) IN (
                'href'
                ,'url'
                )
        ) AS url_table
    INNER JOIN "document" d ON url_table.document_id = d.id
    WHERE CAST(levenshtein('{0}', url) AS DECIMAL) / CAST(length(url) AS DECIMAL) < {1}
    ORDER BY levenshtein('{0}', url) LIMIT 1
"""

# Default module-level objects. Everything below is top-level code, so it is
# executed when this module is imported.

# Text preprocessor configured with split_by="sentence".
preprocessor = ViPreProcessor(split_by="sentence")

# Document store on the database at POSTGRES_URI. vector_dim is CAND_DIM, the
# same value the candidate vectorizer below is constructed with.
document_store = FAISSDocumentStore(sql_url=POSTGRES_URI,
                                    vector_dim=CAND_DIM,
                                    index_buffer_size=5000)

# Retriever with a candidate vectorizer built from CAND_DIM and a retriever
# vectorizer built from RTRV_DIM.
retriever = Retriever(
    document_store=document_store,
    candidate_vectorizer=TfidfDocVectorizer(CAND_DIM),
    retriever_vectorizer=TfidfDocVectorizer(RTRV_DIM),
)
# NOTE(review): retrain=False presumably loads a previously saved candidate
# vectorizer from CAND_PATH instead of fitting a new one -- confirm against
# Retriever.train_candidate_vectorizer.
retriever.train_candidate_vectorizer(retrain=False, save_path=CAND_PATH)

# Second store on the same POSTGRES_URI, created without index_buffer_size.
# Its purpose is not visible in this part of the file.
remote_doc_store = FAISSDocumentStore(sql_url=POSTGRES_URI,
                                      vector_dim=CAND_DIM)

tags_metadata = [
    {
        "name": "get",
示例#12
0
def test_retriever_with_database():
    """Smoke-run the retriever against the PostgreSQL articles database.

    Vectorizers and the FAISS index are loaded from ``parent_cwd`` when their
    files exist and are trained / rebuilt (and saved there) otherwise. The
    test prints the query and the retrieved document; it makes no assertions.
    """
    cand_dim = 768
    rtrv_dim = 1024
    sql_url = "postgresql+psycopg2://user:pwd@host/topdup_articles"

    cand_path = os.path.join(parent_cwd, "cand.bin")
    rtrv_path = os.path.join(parent_cwd, "rtrv.bin")
    index_path = os.path.join(parent_cwd, "index.bin")

    print("Init vectorizers")
    cand_vectorizer = TfidfDocVectorizer(cand_dim)
    rtrv_vectorizer = TfidfDocVectorizer(rtrv_dim)

    print("Init DocumentStore")
    document_store = FAISSDocumentStore(sql_url=sql_url,
                                        vector_dim=cand_dim,
                                        index_buffer_size=5000)

    print("Init retriever")
    retriever = Retriever(
        document_store=document_store,
        candidate_vectorizer=cand_vectorizer,
        retriever_vectorizer=rtrv_vectorizer,
    )

    # Train vectorizers for two phases of searching.
    # Both saved files are needed for the load branch; the original checked
    # only cand.bin although rtrv.bin is loaded as well.
    if os.path.exists(cand_path) and os.path.exists(rtrv_path):
        print("Loading vectorizers")
        retrain_vectorizers = False
    else:
        print("Training vectorizers")
        retrain_vectorizers = True
    retriever.train_candidate_vectorizer(retrain=retrain_vectorizers,
                                         save_path=cand_path)
    retriever.train_retriever_vectorizer(retrain=retrain_vectorizers,
                                         save_path=rtrv_path)

    # Update trained embeddings to index of FAISSDocumentStore
    if os.path.exists(index_path):
        print("Loading index of FAISSDocumentStore")
        rebuild_index = False
    else:
        print("Updating embeddings to index of FAISSDocumentStore")
        rebuild_index = True
    retriever.update_embeddings(
        retrain=rebuild_index,
        save_path=index_path,
        sql_url=sql_url,
    )

    # Get a pair of duplicated articles from topdup.xyz to test
    input_doc = document_store.get_document_by_id(
        id="76a0874a-b0db-477c-a0ca-9e65b8ccf2f3").text
    print(" ".join(input_doc.split(" ")[:50]))  # print the query doc

    print("Retrieving")
    result = retriever.batch_retrieve([input_doc], top_k_candidates=10)
    # print the retrieved doc
    print(" ".join(result[0]["retrieve_result"].split(" ")[:50]))