コード例 #1
0
    def get_all_labels_aggregated(self,
                                  index: Optional[str] = None,
                                  filters: Optional[Dict[str, List[str]]] = None) -> List[MultiLabel]:
        """
        Return the stored labels bundled into one MultiLabel per question.

        Only labels flagged as correct answers are aggregated. Exact duplicates are dropped, and if a question
        has both text answers and empty ("no answer") labels, the empty ones are discarded with a warning.

        :param index: Passed through unchanged to ``self.get_all_labels``.
        :param filters: Passed through unchanged to ``self.get_all_labels``.
        :return: One MultiLabel per distinct question. Its scalar fields (is_correct_answer, is_correct_document,
                 origin, no_answer, model_id) are copied from the first remaining label of that question; the
                 ``multiple_*`` lists hold one entry per remaining label, in the order the labels were returned
                 by ``self.get_all_labels``.
        """
        all_labels = self.get_all_labels(index=index, filters=filters)

        # Group labels by question text. Only labels with correct answers are kept,
        # as only those can currently be used in evaluation.
        labels_by_question: dict = {}
        for label in all_labels:
            if not label.is_correct_answer:
                continue
            labels_by_question.setdefault(label.question, []).append(label)

        aggregated_labels = []
        for question, labels in labels_by_question.items():
            # Get rid of exact duplicates while keeping first-seen order. A plain set() would make the
            # order arbitrary, and the first label decides the scalar fields of the MultiLabel below.
            labels = list(dict.fromkeys(labels))

            # An empty answer marks a "no answer" label. If text answers exist for the same question,
            # the "no answer" labels contradict them and are removed.
            text_labels = [label for label in labels if label.answer]
            if text_labels and len(text_labels) < len(labels):
                logger.warning(
                    f"Both text label and 'no answer possible' label is present for question: {question}")
                labels = text_labels

            # Every group holds at least one label, and the filtering above never empties it,
            # so labels[0] always exists.
            first = labels[0]
            aggregated_labels.append(
                MultiLabel(question=first.question,
                           multiple_answers=[label.answer for label in labels],
                           is_correct_answer=first.is_correct_answer,
                           is_correct_document=first.is_correct_document,
                           origin=first.origin,
                           multiple_document_ids=[label.document_id for label in labels],
                           multiple_offset_start_in_docs=[label.offset_start_in_doc for label in labels],
                           no_answer=first.no_answer,
                           model_id=first.model_id,
                           )
            )
        return aggregated_labels
コード例 #2
0
ファイル: base.py プロジェクト: stmnk/haystack
    def get_all_labels_aggregated(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        open_domain: bool = True,
        aggregate_by_meta: Optional[Union[str,
                                          list]] = None) -> List[MultiLabel]:
        """
        Return all labels in the DocumentStore, aggregated into MultiLabel objects.
        This aggregation step helps, for example, if you collected multiple possible answers for one question and
        now want all answers bundled together in one place for evaluation.
        How they are aggregated is defined by the open_domain and aggregate_by_meta parameters.
        If the questions are being asked to a single document (i.e. SQuAD style), you should set open_domain=False to aggregate by question and document.
        If the questions are being asked to your full collection of documents, you should set open_domain=True to aggregate just by question.
        If the questions are being asked to a subslice of your document set (e.g. product review use cases),
        you should set open_domain=True and populate aggregate_by_meta with the names of Label meta fields to aggregate by question and your custom meta fields.
        For example, in a product review use case, you might set aggregate_by_meta=["product_id"] so that Labels
        with the same question but different answers from different documents are aggregated into the one MultiLabel
        object, provided that they have the same product_id (to be found in Label.meta["product_id"])

        Only labels flagged as correct answers are aggregated. Exact duplicates are dropped, and if a group has
        both text answers and empty ("no answer") labels, the empty ones are discarded with a warning.

        :param index: Name of the index to get the labels from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the labels to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param open_domain: When True, labels are aggregated purely based on the question text alone.
                            When False, labels are aggregated in a closed domain fashion based on the question text
                            and also the id of the document that the label is tied to. In this setting, this function
                            might return multiple MultiLabel objects with the same question string.
        :param aggregate_by_meta: The names of the Label meta fields by which to aggregate. For example: ["product_id"].
                                  A single field name may be passed as a plain string. Labels that lack a field
                                  are grouped together (their value for that field is treated as None).
        :return: One MultiLabel per group. Its scalar fields (is_correct_answer, is_correct_document, origin,
                 no_answer, model_id) are copied from the first remaining label of the group, its meta contains
                 only the aggregate_by_meta fields of that label, and the ``multiple_*`` lists hold one entry
                 per remaining label in the order the labels were returned by ``self.get_all_labels``.
        """
        # Normalise the meta field names once, instead of once per label.
        if isinstance(aggregate_by_meta, str):
            meta_keys = [aggregate_by_meta]
        else:
            meta_keys = list(aggregate_by_meta or [])

        all_labels = self.get_all_labels(index=index, filters=filters)

        # Collect all labels that belong together in a dict keyed by the grouping tuple
        labels_by_group: dict = {}
        for label in all_labels:
            # only aggregate labels with correct answers, as only those can be currently used in evaluation
            if not label.is_correct_answer:
                continue

            # The group key determines which labels are aggregated together. Its contents depend on
            # whether we are in an open / closed domain setting, and on the meta fields named in
            # aggregate_by_meta.
            if open_domain:
                group_key = [label.question]
            else:
                group_key = [label.document_id, label.question]
            # One slot per meta field, filled with None when the label lacks it. The fixed layout keeps
            # values of different fields from colliding, and keeps falsy values such as 0 or ""
            # distinct from "field missing".
            group_key.extend(label.meta.get(meta_key) for meta_key in meta_keys)

            labels_by_group.setdefault(tuple(group_key), []).append(label)

        aggregated_labels = []
        for labels in labels_by_group.values():
            # Get rid of exact duplicates while keeping first-seen order. A plain set() would make the
            # order arbitrary, and the first label decides the scalar fields of the MultiLabel below.
            labels = list(dict.fromkeys(labels))

            # An empty answer marks a "no answer" label. If text answers exist in the same group,
            # the "no answer" labels contradict them and are removed.
            text_labels = [label for label in labels if label.answer]
            if text_labels and len(text_labels) < len(labels):
                logger.warning(
                    f"Both text label and 'no answer possible' label is present for question: {labels[0].question}"
                )
                labels = text_labels

            # Every group holds at least one label, and the filtering above never empties it,
            # so labels[0] always exists.
            first = labels[0]

            # Keep only the label metadata that we are aggregating by
            if meta_keys:
                meta_new = {
                    key: value
                    for key, value in first.meta.items()
                    if key in meta_keys
                }
            else:
                meta_new = {}

            aggregated_labels.append(
                MultiLabel(
                    question=first.question,
                    multiple_answers=[label.answer for label in labels],
                    is_correct_answer=first.is_correct_answer,
                    is_correct_document=first.is_correct_document,
                    origin=first.origin,
                    multiple_document_ids=[label.document_id for label in labels],
                    multiple_offset_start_in_docs=[label.offset_start_in_doc for label in labels],
                    no_answer=first.no_answer,
                    model_id=first.model_id,
                    meta=meta_new))
        return aggregated_labels