Beispiel #1
0
def test_finder_get_answers():
    """Smoke-test the Finder pipeline end to end on a tiny SQLite corpus.

    Three small documents are written to ``qa_test.db``. A TfidfRetriever and a
    DistilBERT TransformersReader are combined in a Finder, and the test only
    checks that ``get_answers`` returns something other than ``None``.
    """
    corpus = [
        {"name": "testing the finder 1",
         "text": "testing the finder with pyhton unit test 1",
         "meta": {"test": "test"}},
        {"name": "testing the finder 2",
         "text": "testing the finder with pyhton unit test 2",
         "meta": {"test": "test"}},
        {"name": "testing the finder 3",
         "text": "testing the finder with pyhton unit test 3",
         "meta": {"test": "test"}},
    ]

    store = SQLDocumentStore(url="sqlite:///qa_test.db")
    store.write_documents(corpus)

    tfidf = TfidfRetriever(document_store=store)
    # use_gpu=-1 presumably selects CPU execution for TransformersReader — TODO confirm
    distilbert = TransformersReader(
        model="distilbert-base-uncased-distilled-squad",
        tokenizer="distilbert-base-uncased",
        use_gpu=-1,
    )

    answers = Finder(distilbert, tfidf).get_answers(
        question="testing finder",
        top_k_retriever=10,
        top_k_reader=5,
    )
    assert answers is not None
Beispiel #2
0
def test_db_write_read():
    """Load ``samples/docs`` into a default SQLDocumentStore and inspect the result.

    Expects exactly two stored documents, and expects the document with id "1"
    to expose exactly the keys ``id``, ``name``, ``text`` and ``tags``.
    """
    store = SQLDocumentStore()
    write_documents_to_db(document_store=store, document_dir="samples/docs")

    assert len(store.get_all_documents()) == 2

    first = store.get_document_by_id("1")
    expected_fields = {"id", "name", "text", "tags"}
    assert first.keys() == expected_fields
Beispiel #3
0
def test_sql_write_read():
    """Round-trip ``samples/docs`` through SQLDocumentStore via write_documents_to_db.

    Expects exactly two stored documents and a document with id "1" whose
    ``id`` and ``text`` attributes are both truthy.
    """
    # NOTE(review): another ``test_sql_write_read`` is defined further down in this
    # listing; if both land in one module, only the later definition is collected.
    store = SQLDocumentStore()
    write_documents_to_db(document_store=store, document_dir="samples/docs")

    stored = store.get_all_documents()
    assert len(stored) == 2

    first = store.get_document_by_id("1")
    assert first.id
    assert first.text
Beispiel #4
0
def test_sql_write_read():
    """Convert ``samples/docs`` to dicts, write them to SQL, and read them back.

    Expects exactly two stored documents and a document with id "1" whose
    ``id`` and ``text`` attributes are both truthy.
    """
    store = SQLDocumentStore()
    file_dicts = convert_files_to_dicts(dir_path="samples/docs")
    store.write_documents(file_dicts)

    stored = store.get_all_documents()
    assert len(stored) == 2

    first = store.get_document_by_id("1")
    assert first.id
    assert first.text
Beispiel #5
0
def document_store(request, test_docs_xs, elasticsearch_fixture):
    """Create an empty document store for the backend named by ``request.param``.

    ``request.param`` selects one of ``"sql"``, ``"memory"`` or ``"elasticsearch"``.
    ``test_docs_xs`` and ``elasticsearch_fixture`` are not used in the body.
    Presumably they are requested only so those fixtures are set up first — TODO confirm.

    Raises:
        ValueError: if ``request.param`` names an unsupported backend. Previously
            this fell through every branch and surfaced as an ``UnboundLocalError``.
    """
    backend = request.param
    if backend == "sql":
        # Drop any database left over from an earlier run so the store starts empty.
        if os.path.exists("qa_test.db"):
            os.remove("qa_test.db")
        document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
    elif backend == "memory":
        document_store = InMemoryDocumentStore()
    elif backend == "elasticsearch":
        # make sure we start from a fresh index
        client = Elasticsearch()
        client.indices.delete(index='haystack_test', ignore=[404])
        document_store = ElasticsearchDocumentStore(index="haystack_test")
    else:
        raise ValueError(f"No document store fixture for '{backend}'")

    return document_store
Beispiel #6
0
def get_document_store(document_store_type):
    """Return a freshly reset document store of the requested kind.

    Supported kinds are ``"sql"``, ``"memory"``, ``"elasticsearch"`` and ``"faiss"``.
    File-backed stores delete their SQLite file first. The Elasticsearch store
    deletes matching test indices first.

    Raises:
        Exception: if ``document_store_type`` is not one of the supported kinds.
    """
    if document_store_type == "sql":
        # Remove a stale database so the store starts out empty.
        if os.path.exists("haystack_test.db"):
            os.remove("haystack_test.db")
        return SQLDocumentStore(url="sqlite:///haystack_test.db")

    if document_store_type == "memory":
        return InMemoryDocumentStore()

    if document_store_type == "elasticsearch":
        # make sure we start from a fresh index; the pattern is a wildcard, so
        # every index whose name starts with "haystack_test" is targeted
        es = Elasticsearch()
        es.indices.delete(index='haystack_test*', ignore=[404])
        return ElasticsearchDocumentStore(index="haystack_test")

    if document_store_type == "faiss":
        # FAISS keeps its document metadata in a separate SQLite file; reset it too.
        if os.path.exists("haystack_test_faiss.db"):
            os.remove("haystack_test_faiss.db")
        return FAISSDocumentStore(sql_url="sqlite:///haystack_test_faiss.db")

    raise Exception(f"No document store fixture for '{document_store_type}'")
Beispiel #7
0
    faq_question_field=FAQ_QUESTION_FIELD_NAME,
)




if EMBEDDING_MODEL_PATH:
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=EMBEDDING_MODEL_PATH,
        model_format=EMBEDDING_MODEL_FORMAT,
        gpu=USE_GPU
    )  # type: BaseRetriever
else:
    retriever = ElasticsearchRetriever(document_store=document_store)'''
# Use a TfidfRetriever over the local SQLite store at qa.db.
documentstore = SQLDocumentStore(url="sqlite:///qa.db")
retriever = TfidfRetriever(document_store=documentstore)

if READER_MODEL_PATH:  # for extractive doc-qa
    # NOTE(review): in this excerpt the branch body is only a string literal that holds
    # a commented-out FARMReader(...) construction, so no reader is created here.
    # Confirm whether the original file assigns ``reader`` after this point.
    '''reader = FARMReader(
        model_name_or_path=str(READER_MODEL_PATH),
        batch_size=BATCHSIZE,
        use_gpu=USE_GPU,
        context_window_size=CONTEXT_WINDOW_SIZE,
        top_k_per_candidate=TOP_K_PER_CANDIDATE,
        no_ans_boost=NO_ANS_BOOST,
        num_processes=MAX_PROCESSES,
        max_seq_len=MAX_SEQ_LEN,
        doc_stride=DOC_STRIDE,
    )  # type: Optional[FARMReader]'''
Beispiel #8
0
# TODO: Enable CORS

# Candidate local model directories (not referenced in this section).
MODELS_DIRS = ["saved_models", "models", "model"]
USE_GPU = False
BATCH_SIZE = 16
# SQLite database backing the document store.
DATABASE_URL = "sqlite:///qa.db"
# One Finder is built per entry; each entry is handed to FARMReader as model_name_or_path.
MODEL_PATHS = ['deepset/bert-base-cased-squad2']

app = FastAPI(title="Haystack API", version="0.1")

if not MODEL_PATHS:
    logger.error(
        "No model to load. Please specify one via MODEL_PATHS "
        "(e.g. ['deepset/bert-base-cased-squad2'])"
    )

datastore = SQLDocumentStore(url=DATABASE_URL)
# Pass the store via ``document_store=``, the keyword every other TfidfRetriever call
# in this file uses; the former ``datastore=`` keyword does not match that parameter.
retriever = TfidfRetriever(document_store=datastore)

# Finder IDs start at 1, so the first model is served at /finders/1/ask
# (see the curl example logged below). All finders share the one retriever.
FINDERS = {}
for idx, model_dir in enumerate(MODEL_PATHS, start=1):
    reader = FARMReader(model_name_or_path=str(model_dir),
                        batch_size=BATCH_SIZE,
                        use_gpu=USE_GPU)
    FINDERS[idx] = Finder(reader, retriever)
    logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")

logger.info(
    "Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
logger.info(
    """ Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/finders/1/ask' --data '{"question": "Who is the father of Arya Starck?"}'"""
)