예제 #1
0
파일: utils.py 프로젝트: voxlogic/haystack
def eval_data_from_file(filename: str) -> Tuple[List[Document], List[Label]]:
    """
    Read Documents + Labels from a SQuAD-style file.
    Document and Labels can then be indexed to the DocumentStore and be used for evaluation.

    One Document is created per paragraph. Its meta holds the parent document's title under
    ``"name"`` (``None`` if the document has no title), plus every extra paragraph-level and
    document-level field; document-level fields win on key clashes.
    One Label is created per gold answer. A question without any answer yields a single Label
    with an empty answer text.

    The ``is_impossible`` flag is optional (it is absent in SQuAD v1.1-style files). When it is
    missing, a question is treated as "no answer" exactly when its answer list is empty.

    :param filename: Path to file in SQuAD format
    :return: (List of Documents, List of Labels)
    """
    docs = []
    labels = []

    # JSON is UTF-8 encoded; do not rely on the platform's default encoding.
    with open(filename, "r", encoding="utf-8") as file:
        data = json.load(file)

    for document in data["data"]:
        # get all extra fields from document level (everything except paragraphs and title)
        meta_doc = {k: v for k, v in document.items() if k not in ("paragraphs", "title")}
        for paragraph in document["paragraphs"]:
            cur_meta = {"name": document.get("title")}
            # all other fields from paragraph level
            meta_paragraph = {k: v for k, v in paragraph.items() if k not in ("qas", "context")}
            cur_meta.update(meta_paragraph)
            # meta from parent document (applied last, so it overrides paragraph-level keys)
            cur_meta.update(meta_doc)
            # Create Document
            cur_doc = Document(text=paragraph["context"], meta=cur_meta)
            docs.append(cur_doc)

            # Get Labels
            for qa in paragraph["qas"]:
                answers = qa["answers"]
                # Fall back to "no answers given" when the file carries no is_impossible flag.
                no_answer = qa.get("is_impossible", not answers)
                if answers:
                    for answer in answers:
                        labels.append(
                            Label(
                                question=qa["question"],
                                answer=answer["text"],
                                is_correct_answer=True,
                                is_correct_document=True,
                                document_id=cur_doc.id,
                                offset_start_in_doc=answer["answer_start"],
                                no_answer=no_answer,
                                origin="gold_label",
                            )
                        )
                else:
                    # Question without gold answers: a single empty-answer label at offset 0.
                    labels.append(
                        Label(
                            question=qa["question"],
                            answer="",
                            is_correct_answer=True,
                            is_correct_document=True,
                            document_id=cur_doc.id,
                            offset_start_in_doc=0,
                            no_answer=no_answer,
                            origin="gold_label",
                        )
                    )
    return docs, labels
예제 #2
0
    def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
        """
        Store labels in the given index, each under a freshly generated UUID key.

        :param labels: Labels to store. Dict entries are converted via ``Label.from_dict``;
                       all other entries are stored as they are.
        :param index: Name of the index to write to. Falls back to ``self.label_index``
                      when not given.
        """
        target_index = index or self.label_index

        # Convert everything up front so that a failing conversion happens before any write.
        converted = []
        for entry in labels:
            if isinstance(entry, dict):
                entry = Label.from_dict(entry)
            converted.append(entry)

        for entry in converted:
            key = uuid.uuid4()
            self.indexes[target_index][key] = entry
예제 #3
0
    def write_labels(self,
                     labels: Union[List[Label], List[dict]],
                     index: Optional[str] = None):
        """
        Write labels to the label index with a single bulk request.

        The index is created via ``self._create_label_index`` if the client reports that it
        does not exist yet.

        :param labels: Labels to write. Dict entries are converted via ``Label.from_dict``.
        :param index: Name of the index to write to. Falls back to ``self.label_index``
                      when not given.
        """
        target_index = index or self.label_index
        if target_index:
            if not self.client.indices.exists(index=target_index):
                self._create_label_index(target_index)

        # Make sure we comply to Label class format
        normalized = []
        for entry in labels:
            normalized.append(Label.from_dict(entry) if isinstance(entry, dict) else entry)

        # One bulk action per label: "index" when existing entries may be updated,
        # otherwise "create".
        actions: List[Dict[str, Any]] = [
            {
                "_op_type": "index" if self.update_existing_documents else "create",
                "_index": target_index,
                **entry.to_dict(),
            }
            for entry in normalized
        ]

        bulk(self.client, actions, request_timeout=300, refresh="wait_for")
예제 #4
0
 def get_all_labels(
         self,
         index: Optional[str] = None,
         filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
     """
     Return all labels stored in an index.

     :param index: Name of the index to read from. Falls back to ``self.label_index``
                   when not given.
     :param filters: Passed through unchanged to ``self.get_all_documents_in_index``.
     :return: One Label per returned hit, built from the hit's ``"_source"`` entry.
     """
     target_index = index or self.label_index
     hits = self.get_all_documents_in_index(index=target_index, filters=filters)
     return [Label.from_dict(entry["_source"]) for entry in hits]
예제 #5
0
 def _convert_sql_row_to_label(self, row) -> Label:
     """
     Build a Label from a row object by copying its label attributes one-to-one.

     :param row: Object exposing the label fields listed below as attributes.
     :return: Label carrying the same field values as the row.
     """
     # Attributes copied from the row; each is passed to Label under the same keyword name.
     copied_fields = (
         "document_id",
         "no_answer",
         "origin",
         "question",
         "is_correct_answer",
         "is_correct_document",
         "answer",
         "offset_start_in_doc",
         "model_id",
     )
     values = {name: getattr(row, name) for name in copied_fields}
     return Label(**values)
예제 #6
0
def test_labels(document_store):
    """
    Write a single label to a dedicated test index and read it back.

    Also checks that reading labels without an explicit index returns nothing, i.e. the
    label written to the test index does not show up there.
    """
    label = Label(
        question="question",
        answer="answer",
        is_correct_answer=True,
        is_correct_document=True,
        document_id="123",
        offset_start_in_doc=12,
        no_answer=False,
        origin="gold_label",
    )
    document_store.write_labels([label], index="haystack_test_label")
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels) == 1

    labels = document_store.get_all_labels()
    assert len(labels) == 0

    # clean up, so a later run against the same store starts from an empty test index
    document_store.delete_all_documents(index="haystack_test_label")
예제 #7
0
    def write_labels(self, labels, index=None):
        """
        Add one ``LabelORM`` row per label to the session and commit once at the end.

        :param labels: Labels to store. Dict entries are converted via ``Label.from_dict``.
        :param index: Index name stored with every row. Falls back to ``self.label_index``
                      when not given.
        """
        label_objects = []
        for entry in labels:
            if isinstance(entry, dict):
                entry = Label.from_dict(entry)
            label_objects.append(entry)

        target_index = index or self.label_index
        for current in label_objects:
            orm_row = LabelORM(
                document_id=current.document_id,
                no_answer=current.no_answer,
                origin=current.origin,
                question=current.question,
                is_correct_answer=current.is_correct_answer,
                is_correct_document=current.is_correct_document,
                answer=current.answer,
                offset_start_in_doc=current.offset_start_in_doc,
                model_id=current.model_id,
                index=target_index,
            )
            self.session.add(orm_row)

        # Single commit for the whole batch.
        self.session.commit()
예제 #8
0
 def get_all_labels(self,
                    index: str = "label",
                    filters: Optional[dict] = None) -> List[Label]:
     """
     Return all labels stored in an index.

     :param index: Name of the index to read from (defaults to ``"label"``).
     :param filters: Passed through unchanged to ``self.get_all_documents_in_index``.
     :return: One Label per returned hit, built from the hit's ``"_source"`` entry.
     """
     found = []
     for entry in self.get_all_documents_in_index(index=index, filters=filters):
         found.append(Label.from_dict(entry["_source"]))
     return found
예제 #9
0
def test_multilabel_no_answer(document_store):
    """
    Aggregate several "no answer" labels for the same question.

    Four labels are written; the test expects all four to be stored, and expects them to
    aggregate into a single multi-label that spans two document ids.
    """
    test_index = "haystack_test_multilabel_no_answer"

    def make_no_answer_label(document_id, is_correct_answer=True):
        # All labels in this test differ only in document id and is_correct_answer.
        return Label(
            question="question",
            answer="",
            is_correct_answer=is_correct_answer,
            is_correct_document=True,
            document_id=document_id,
            offset_start_in_doc=0,
            no_answer=True,
            origin="gold_label",
        )

    gold_labels = [
        make_no_answer_label("777"),
        # no answer in different doc
        make_no_answer_label("123"),
        # no answer in same doc, should be excluded
        make_no_answer_label("777"),
        # no answer with is_correct_answer=False, should be excluded
        make_no_answer_label("321", is_correct_answer=False),
    ]

    document_store.write_labels(gold_labels, index=test_index)
    aggregated = document_store.get_all_labels_aggregated(index=test_index)
    stored = document_store.get_all_labels(index=test_index)

    assert len(aggregated) == 1
    assert len(stored) == 4

    merged = aggregated[0]
    assert len(merged.multiple_document_ids) == 2
    # answers, document ids and offsets must stay aligned
    assert (
        len(merged.multiple_answers)
        == len(merged.multiple_document_ids)
        == len(merged.multiple_offset_start_in_docs)
    )

    # clean up
    document_store.delete_all_documents(index=test_index)
예제 #10
0
def test_multilabel(document_store):
    """
    Aggregate several labels for the same question into one multi-label.

    Five labels are written; the test expects all five to be stored, and expects them to
    aggregate into a single multi-label holding three answers. Aggregating without an
    explicit index is expected to return nothing.
    """
    test_index = "haystack_test_multilabel"

    def make_label(answer, document_id, offset, no_answer=False, is_correct_answer=True):
        # Question, document correctness and origin are identical for every label here.
        return Label(
            question="question",
            answer=answer,
            is_correct_answer=is_correct_answer,
            is_correct_document=True,
            document_id=document_id,
            offset_start_in_doc=offset,
            no_answer=no_answer,
            origin="gold_label",
        )

    gold_labels = [
        make_label("answer1", "123", 12),
        # different answer in same doc
        make_label("answer2", "123", 42),
        # answer in different doc
        make_label("answer3", "321", 7),
        # 'no answer', should be excluded from MultiLabel
        make_label("", "777", 0, no_answer=True),
        # is_correct_answer=False, should be excluded from MultiLabel
        make_label("answer5", "123", 99, no_answer=True, is_correct_answer=False),
    ]

    document_store.write_labels(gold_labels, index=test_index)
    aggregated = document_store.get_all_labels_aggregated(index=test_index)
    stored = document_store.get_all_labels(index=test_index)

    assert len(aggregated) == 1
    assert len(stored) == 5

    merged = aggregated[0]
    assert len(merged.multiple_answers) == 3
    # answers, document ids and offsets must stay aligned
    assert (
        len(merged.multiple_answers)
        == len(merged.multiple_document_ids)
        == len(merged.multiple_offset_start_in_docs)
    )

    # Without an explicit index, nothing written above should be visible.
    default_aggregated = document_store.get_all_labels_aggregated()
    assert len(default_aggregated) == 0

    # clean up
    document_store.delete_all_documents(index=test_index)