# Example No. 1
# Score: 0
def test_elasticsearch_retrieval_filters(document_store_with_docs):
    """Check that ``ElasticsearchRetriever.retrieve`` honours metadata filters.

    Relies on the ``document_store_with_docs`` fixture defined elsewhere.
    Every case asserts the number of results *before* indexing into them, so
    an unexpectedly empty result fails with a readable assertion rather than
    an ``IndexError``.
    """
    query = "Who lives in Berlin?"
    carla_text = "My name is Carla and I live in Berlin"
    retriever = ElasticsearchRetriever(document_store=document_store_with_docs)

    # A single matching filter value returns exactly Carla's document.
    res = retriever.retrieve(query=query, filters={"name": ["filename1"]})
    assert len(res) == 1
    assert res[0].text == carla_text
    assert res[0].meta["name"] == "filename1"

    # Filters on different fields must all match: an unmatched value on a
    # second field excludes the otherwise-matching document.
    res = retriever.retrieve(query=query, filters={"name": ["filename1"], "meta_field": ["not_existing_value"]})
    assert len(res) == 0

    # Filtering on a field that no document carries also yields nothing.
    res = retriever.retrieve(query=query, filters={"name": ["filename1"], "not_existing_field": ["not_existing_value"]})
    assert len(res) == 0

    # Values listed for one field are alternatives: the document matches
    # because one of ["test1", "test2"] fits ...
    res = retriever.retrieve(query=query, filters={"name": ["filename1"], "meta_field": ["test1", "test2"]})
    assert len(res) == 1
    assert res[0].text == carla_text
    assert res[0].meta["name"] == "filename1"

    # ... while "test2" on its own does not.
    res = retriever.retrieve(query=query, filters={"name": ["filename1"], "meta_field": ["test2"]})
    assert len(res) == 0
# Example No. 2
# Score: 0
def get_hard_negative_context(retriever: "ElasticsearchRetriever",
                              question: str,
                              answer: str,
                              n_ctxs: int = 30,
                              n_chars: int = 600,
                              index: str = "document"):
    """Collect hard-negative passages for ``question``.

    Retrieves up to ``n_ctxs`` documents for ``question`` from ``index`` and
    keeps only those whose text does NOT contain ``answer``. The test is a
    case-insensitive substring check. Note that an empty ``answer`` is a
    substring of every text, so it filters out all documents.

    :param retriever: object exposing ``retrieve(query=..., top_k=..., index=...)``
        whose results have ``.text`` and ``.meta["name"]`` (an
        ``ElasticsearchRetriever`` in this codebase).
    :param question: query sent to the retriever.
    :param answer: gold answer; retrieved documents containing it are skipped.
    :param n_ctxs: number of documents requested from the retriever (``top_k``).
    :param n_chars: maximum length of each returned ``"text"`` value.
    :param index: index to search; defaults to ``"document"``.
    :return: list of ``{"title": doc.meta["name"], "text": doc.text[:n_chars]}``
        dicts, in the order the retriever returned them.
    """
    # The lower-cased answer is loop-invariant; compute it once.
    needle = answer.lower()
    retrieved_docs = retriever.retrieve(query=question,
                                        top_k=n_ctxs,
                                        index=index)
    return [
        {"title": doc.meta["name"], "text": doc.text[:n_chars]}
        for doc in retrieved_docs
        if needle not in doc.text.lower()
    ]
# Example No. 3
# Score: 0
def test_elasticsearch_custom_query(elasticsearch_fixture):
    """Exercise ``ElasticsearchRetriever`` with user-supplied query templates.

    Recreates the ``haystack_test_custom`` index and writes five documents
    carrying a ``year`` meta field: one from 2019, one from 2020 and three
    from 2021. It then runs a ``terms`` template (list of years) and a
    ``term`` template (single year). The ``${query}`` / ``${years}``
    placeholders are expected to be filled from the query string and the
    ``years`` filter. The expected hit counts follow from the year
    distribution.
    """
    index_name = "haystack_test_custom"
    # Start from a clean slate; a missing index (HTTP 404) is not an error.
    Elasticsearch().indices.delete(index=index_name, ignore=[404])

    document_store = ElasticsearchDocumentStore(
        index=index_name,
        text_field="custom_text_field",
        embedding_field="custom_embedding_field",
    )

    doc_years = ["2019", "2020", "2021", "2021", "2021"]
    documents = [
        {"text": f"test_{number}", "meta": {"year": year}}
        for number, year in enumerate(doc_years, start=1)
    ]
    document_store.write_documents(documents)

    # NOTE(review): the templates match on "text" although the store was
    # created with text_field="custom_text_field". Because the bool query also
    # has a filter clause, the should clause is presumably optional, so the
    # counts below come from the year filter alone. Confirm this is intended.

    # "terms" takes a list: 2020 (1 doc) + 2021 (3 docs) -> 4 hits.
    terms_query = """
            {
                "size": 10, 
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["text"]}}],
                            "filter": [{"terms": {"year": ${years}}}]}}}"""
    terms_retriever = ElasticsearchRetriever(document_store=document_store, custom_query=terms_query)
    hits = terms_retriever.retrieve(query="test", filters={"years": ["2020", "2021"]})
    assert len(hits) == 4

    # "term" takes a single value: only the three 2021 docs match.
    term_query = """
                {
                    "size": 10, 
                    "query": {
                        "bool": {
                            "should": [{
                                "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["text"]}}],
                                "filter": [{"term": {"year": ${years}}}]}}}"""
    term_retriever = ElasticsearchRetriever(document_store=document_store, custom_query=term_query)
    hits = term_retriever.retrieve(query="test", filters={"years": "2021"})
    assert len(hits) == 3
# Example No. 4
# Score: 0
        paragraphs = [
            x['text'] for x in example['positive_ctxs']
            if len(x['text']) < 1200
        ]
        gold_paragraphs.update(paragraphs)
    return gold_paragraphs


train_para = get_all_positive_contexts(StrategyQADataset().train_set())
dev_para = get_all_positive_contexts(StrategyQADataset().dev_set())
all_paras = train_para.union(dev_para)

import dpr.experiments.document_store as doc_store_utils

# Sanity check: query the Elastic index with each gold paragraph and verify
# the paragraph itself comes back. Mismatches at top-1 are printed. A
# paragraph that still appears within the top 20 is reported as "found on
# second try" and not counted; anything else increments `mistakes`.
elastic_ds = doc_store_utils.get_elastic_document_store()
retriever = ElasticsearchRetriever(document_store=elastic_ds)
mistakes = 0
for i, s in enumerate(all_paras):
    top_docs = retriever.retrieve(s, top_k=1)
    # An empty result used to abort the whole run with IndexError; treat it
    # as a top-1 miss instead.
    top_text = top_docs[0].text if top_docs else None
    if top_text == s:
        continue
    print('expected ', s)
    print('got', top_text)
    matches = [doc.text for doc in retriever.retrieve(s, top_k=20) if doc.text == s]
    if matches:
        print('found on second try', matches)
        continue
    mistakes += 1
    print('mistakes', mistakes)
    print('total', i + 1)
# Example No. 5
# Score: 0
def test_elasticsearch_retrieval(document_store_with_docs):
    """Unfiltered retrieval returns three documents with Carla's ranked first.

    Relies on the ``document_store_with_docs`` fixture defined elsewhere.
    The result count is asserted before ``res[0]`` is accessed, so an empty
    result fails with a clear assertion instead of an ``IndexError``.
    """
    retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    res = retriever.retrieve(query="Who lives in Berlin?")
    assert len(res) == 3
    assert res[0].text == "My name is Carla and I live in Berlin"
    assert res[0].meta["name"] == "filename1"
# Example No. 6
# Score: 0
from pprint import pprint
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever

if __name__ == '__main__':

    # Connection settings for the demo store; adjust host/credentials/index
    # for your own Elasticsearch deployment.
    document_store = ElasticsearchDocumentStore(
        host="192.168.8.106",
        username="",
        password="",
        index="drqa_wiki",
    )

    retriever = ElasticsearchRetriever(document_store=document_store)
    # Interactive loop: read a question and print the text of the top-3
    # retrieved documents. EOF (Ctrl-D) or Ctrl-C at the prompt ends the
    # session cleanly instead of dumping a traceback.
    while True:
        try:
            q = input("utter question: ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        # Don't send blank input to Elasticsearch as a query.
        if not q.strip():
            continue
        documents = retriever.retrieve(q, top_k=3)
        pprint([d.text for d in documents])