Example #1
    def make_marginal_drop_instance(
            question_tokens: List[Token],
            passage_tokens: List[Token],
            number_tokens: List[Token],
            number_indices: List[int],
            token_indexers: Dict[str, TokenIndexer],
            passage_text: str,
            answer_info: Dict[str, Any] = None,
            additional_metadata: Dict[str, Any] = None) -> Instance:
        additional_metadata = additional_metadata or {}
        fields: Dict[str, Field] = {}
        passage_offsets = [(token.idx, token.idx + len(token.text))
                           for token in passage_tokens]
        question_offsets = [(token.idx, token.idx + len(token.text))
                            for token in question_tokens]

        # This is separate so we can reference it later with a known type.
        fields["passage"] = TextField(passage_tokens, token_indexers)
        fields["question"] = TextField(question_tokens, token_indexers)
        number_index_fields = [
            IndexField(index, fields["passage"]) for index in number_indices
        ]
        fields["number_indices"] = ListField(number_index_fields)
        # This field is not actually required by the model; it is used to create the
        # `answer_as_plus_minus_combinations` field, which is a `SequenceLabelField`.
        # We cannot use the `number_indices` field for that, because its `ListField`
        # will not be empty when we want to create a new empty field, which leads to an error.
        fields["numbers_in_passage"] = TextField(number_tokens, token_indexers)
        metadata = {
            "original_passage": passage_text,
            "passage_token_offsets": passage_offsets,
            "question_token_offsets": question_offsets,
            "question_tokens": [token.text for token in question_tokens],
            "passage_tokens": [token.text for token in passage_tokens],
            "number_tokens": [token.text for token in number_tokens],
            "number_indices": number_indices
        }
        if answer_info:
            metadata["answer_texts"] = answer_info["answer_texts"]

            passage_span_fields = \
                [SpanField(span[0], span[1], fields["passage"]) for span in answer_info["answer_passage_spans"]]
            if not passage_span_fields:
                passage_span_fields.append(SpanField(-1, -1,
                                                     fields["passage"]))
            fields["answer_as_passage_spans"] = ListField(passage_span_fields)

            question_span_fields = \
                [SpanField(span[0], span[1], fields["question"]) for span in answer_info["answer_question_spans"]]
            if not question_span_fields:
                question_span_fields.append(
                    SpanField(-1, -1, fields["question"]))
            fields["answer_as_question_spans"] = ListField(
                question_span_fields)

            add_sub_signs_field = []
            for signs_for_one_add_sub_expression in answer_info[
                    "signs_for_add_sub_expressions"]:
                add_sub_signs_field.append(
                    SequenceLabelField(signs_for_one_add_sub_expression,
                                       fields["numbers_in_passage"]))
            if not add_sub_signs_field:
                add_sub_signs_field.append(
                    SequenceLabelField([0] * len(fields["numbers_in_passage"]),
                                       fields["numbers_in_passage"]))
            fields["answer_as_add_sub_expressions"] = ListField(
                add_sub_signs_field)

            count_fields = [
                LabelField(count_label, skip_indexing=True)
                for count_label in answer_info["counts"]
            ]
            if not count_fields:
                count_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_as_counts"] = ListField(count_fields)

        metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)
        return Instance(fields)
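A minimal sketch of how this helper might be invoked. The tokenizer/indexer choices and the concrete passage are illustrative assumptions, following the AllenNLP 0.9-era API these examples use:

from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer

tokenizer = WordTokenizer()
indexers = {"tokens": SingleIdTokenIndexer()}

passage_text = "The Bears scored 17 points and the Lions scored 10 points ."
passage_tokens = tokenizer.tokenize(passage_text)
question_tokens = tokenizer.tokenize("How many points were scored in total ?")

# The numbers in the passage and their token positions.
number_indices = [i for i, t in enumerate(passage_tokens) if t.text.isdigit()]
number_tokens = [passage_tokens[i] for i in number_indices]

instance = make_marginal_drop_instance(
    question_tokens, passage_tokens, number_tokens, number_indices,
    indexers, passage_text,
    answer_info={
        "answer_texts": ["27"],
        "answer_passage_spans": [],                 # falls back to the (-1, -1) sentinel
        "answer_question_spans": [],
        "signs_for_add_sub_expressions": [[1, 1]],  # 17 + 10
        "counts": [],
    })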
Example #2
 def test_span_field_raises_on_ill_defined_span(self):
     with pytest.raises(ValueError):
         _ = SpanField(4, 1, self.text)
Example #3
 def test_empty_span_field_works(self):
     span_field = SpanField(1, 3, self.text)
     empty_span = span_field.empty_field()
     assert empty_span.span_start == -1
     assert empty_span.span_end == -1
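Together with the test above, this pins down the small `SpanField` API surface; a self-contained sketch (AllenNLP import paths assumed):

from allennlp.data.fields import SpanField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

text = TextField([Token(t) for t in "This is a sentence .".split()],
                 {"tokens": SingleIdTokenIndexer()})
span = SpanField(1, 3, text)                # start/end are inclusive token indices
empty = span.empty_field()                  # the (-1, -1) padding sentinel
tensor = span.as_tensor(span.get_padding_lengths())  # tensor([1, 3])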
Example #4
    def text_to_instance(
            self,  # type: ignore
            sentences: List[List[str]],
            document_id: str,
            sentence_id: int,
            gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
            user_threshold: Optional[float] = 0.0) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenised words and sentences in the document.
        document_id : ``str``, required.
            A string representing the document ID.
        sentence_id : ``int``, required.
            An int representing the sentence ID.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.
        user_threshold: ``Optional[float]``, optional (default = 0.0)
            The approximate fraction of gold labels to hold out as simulated user input,
            e.g. 0.5, 0.33, 0.25, 0.125.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                 not belong to a cluster. As these labels have variable length (it depends on
                 how many spans we are considering), we represent this as a ``SequenceLabelField``
                 with respect to the ``spans`` ``ListField``.
        """
        flattened_sentences = [
            self._normalize_word(word) for sentence in sentences
            for word in sentence
        ]

        metadata: Dict[str, Any] = {
            "original_text": flattened_sentences,
            "ID": document_id + ";" + str(sentence_id)
        }
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters
            metadata["num_gold_clusters"] = len(gold_clusters)

        text_field = TextField([Token(word) for word in flattened_sentences],
                               self._token_indexers)

        user_threshold_mod = int(
            1 / user_threshold
        ) if self._simulate_user_inputs and user_threshold > 0 else 0
        cluster_dict = {}
        simulated_user_cluster_dict = {}

        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for i in range(len(cluster)):
                    # Use modulo so that user labels are distributed relatively evenly across the
                    # document (clusters are sorted), spreading simulated user clusters out evenly.
                    if user_threshold_mod == 0 or i % user_threshold_mod != user_threshold_mod - 1:
                        cluster_dict[tuple(cluster[i])] = cluster_id
                    simulated_user_cluster_dict[tuple(cluster[i])] = cluster_id

        # Note simulated_user_cluster_dict encompasses ALL gold labels, including those in cluster_dict
        # Consequently user_labels encompasses all gold labels
        spans: List[Field] = []
        if gold_clusters is not None:
            span_labels: Optional[List[int]] = []
            user_labels: Optional[List[
                int]] = [] if self._simulate_user_inputs and user_threshold > 0 else None
        else:
            span_labels = user_labels = None

        # Our must-link and cannot-link constraints, derived from user labels.
        # We use `gold_clusters is None` to indicate whether we are running training.
        must_link: Optional[
            List[int]] = [] if gold_clusters is not None else None
        cannot_link: Optional[
            List[int]] = [] if gold_clusters is not None else None

        sentence_offset = 0
        doc_info = None
        if self._saved_labels is not None and metadata[
                'ID'] in self._saved_labels:
            doc_info = self._saved_labels[metadata['ID']]
            span_labels = doc_info['span_labels'].tolist()
            if 'must_link' in doc_info:
                must_link = doc_info['must_link'].squeeze(-1).tolist()
                cannot_link = doc_info['cannot_link'].squeeze(-1).tolist()
        for sentence in sentences:
            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                if span_labels is not None:
                    if doc_info is None:
                        # only do if we haven't already loaded span labels
                        if (start, end) in cluster_dict:
                            span_labels.append(cluster_dict[(start, end)])
                        else:
                            span_labels.append(-1)
                    if self._simulate_user_inputs and user_threshold > 0:
                        if (start, end) in simulated_user_cluster_dict:
                            user_labels.append(
                                simulated_user_cluster_dict[(start, end)])
                        else:
                            user_labels.append(-1)

                spans.append(SpanField(start, end, text_field))
            sentence_offset += len(sentence)

        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field
        }

        if must_link is not None and len(must_link) > 0:
            must_link_field = []
            cannot_link_field = []
            for link in must_link:
                must_link_field.append(
                    PairField(
                        IndexField(link[0], span_field),
                        IndexField(link[1], span_field),
                    ))
            for link in cannot_link:
                cannot_link_field.append(
                    PairField(
                        IndexField(link[0], span_field),
                        IndexField(link[1], span_field),
                    ))
            must_link_field = ListField(must_link_field)
            cannot_link_field = ListField(cannot_link_field)
            fields["must_link"] = must_link_field
            fields["cannot_link"] = cannot_link_field

        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)
            if user_labels is not None:
                fields["user_labels"] = SequenceLabelField(
                    user_labels, span_field)

        # sanity checks
        if doc_info is not None:
            assert (fields["span_labels"].as_tensor(
                fields["span_labels"].get_padding_lengths()) !=
                    doc_info['span_labels']).nonzero().size(0) == 0
            if 'must_link' in doc_info:
                assert 'must_link' in fields
                assert (fields["must_link"].as_tensor(
                    fields["must_link"].get_padding_lengths()) !=
                        doc_info['must_link']).nonzero().size(0) == 0
                assert (fields["cannot_link"].as_tensor(
                    fields["cannot_link"].get_padding_lengths()) !=
                        doc_info['cannot_link']).nonzero().size(0) == 0

        return Instance(fields)
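To make the holdout logic above concrete: with `user_threshold = 0.25`, every fourth mention of each cluster is withheld from `cluster_dict` and appears only in `simulated_user_cluster_dict`:

user_threshold = 0.25
user_threshold_mod = int(1 / user_threshold)  # 4
withheld = [i for i in range(8)
            if i % user_threshold_mod == user_threshold_mod - 1]
assert withheld == [3, 7]  # these mentions become simulated user input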
Example #5
    def text_to_instance(self,  # type: ignore
                         sentences: List[List[str]],
                         gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
                         speaker_ids: Optional[List[int]] = None,
                         genre: Optional[int] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenised words and sentences in the document.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.
        speaker_ids : ``Optional[List[int]]``, optional (default = None)
            A list that maps each gold cluster with a speaker. Each speaker from the text
            is given a unique id.
        genre : ``Optional[int]``, optional (default = None)
            A unique id that represents the genre of the text
        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                 not belong to a cluster. As these labels have variable length (it depends on
                 how many spans we are considering), we represent this as a ``SequenceLabelField``
                 with respect to the ``spans`` ``ListField``.
        """
        flattened_sentences = [self._normalize_word(word)
                               for sentence in sentences
                               for word in sentence]

        metadata: Dict[str, Any] = {"original_text": flattened_sentences}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters
        if speaker_ids is not None:
            metadata["speaker_ids"] = speaker_ids
        if genre is not None:
            metadata["genre"] = genre

        text_field = TextField([Token(word) for word in flattened_sentences], self._token_indexers)

        cluster_dict = {}
        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for mention in cluster:
                    cluster_dict[tuple(mention)] = cluster_id

        spans: List[Field] = []
        span_labels: Optional[List[int]] = [] if gold_clusters is not None else None

        sentence_offset = 0
        for sentence in sentences:
            for start, end in enumerate_spans(sentence,
                                              offset=sentence_offset,
                                              max_span_width=self._max_span_width):
                if span_labels is not None:
                    if (start, end) in cluster_dict:
                        span_labels.append(cluster_dict[(start, end)])
                    else:
                        span_labels.append(-1)

                spans.append(SpanField(start, end, text_field))
            sentence_offset += len(sentence)

        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {"text": text_field,
                                    "spans": span_field,
                                    "metadata": metadata_field}
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)

        return Instance(fields)
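A hypothetical call, assuming `reader` is an instance of this coreference dataset reader; spans are word-level and indexed over the flattened document:

sentences = [["Barack", "Obama", "went", "to", "Paris", "."],
             ["He", "loved", "it", "."]]
gold_clusters = [[(0, 1), (6, 6)],  # {Barack Obama, He}
                 [(4, 4), (8, 8)]]  # {Paris, it}
instance = reader.text_to_instance(sentences, gold_clusters)
# One label per enumerated span: its cluster id, or -1.
print(instance.fields["span_labels"].labels[:10])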
Example #6
 def _make_span_field(cls, s, text_field, offset=1):
     return SpanField(s[0] + offset, s[1] - 1 + offset, text_field)
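The `offset=1` default and the `- 1` on the end index suggest, as an assumption about this project's conventions, that incoming spans are exclusive-end and were computed before a single token (e.g. `[CLS]`) was prepended to `text_field`:

# s = (2, 5): an exclusive-end span over raw-text tokens 2..4. With one
# prepended token and offset=1, the resulting inclusive SpanField is (3, 5).
span = SomeReader._make_span_field((2, 5), text_field, offset=1)  # hypothetical owner class
assert (span.span_start, span.span_end) == (3, 5)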
Example #7
    def text_to_instance(self,
                         question_text: str,
                         passage_text: str,
                         passage_tokens: List[Token],
                         numbers_in_passage: List[Any],
                         number_words: List[str],
                         number_indices: List[int],
                         number_len: List[int],
                         question_id: str = None,
                         answer_annotations: List[List[str]] = None,
                         dataset: str = None) -> Union[Instance, None]:
        # Tokenize question and passage
        question_tokens = self.tokenizer.tokenize(question_text)
        qlen = len(question_tokens)
        plen = len(passage_tokens)

        question_passage_tokens = [Token('[CLS]')] + question_tokens + [
            Token('[SEP]')
        ] + passage_tokens
        if len(question_passage_tokens) > self.max_pieces - 1:
            question_passage_tokens = question_passage_tokens[:self.max_pieces - 1]
            passage_tokens = passage_tokens[:self.max_pieces - qlen - 3]
            plen = len(passage_tokens)
            if len(number_indices) > 0:
                number_indices, number_len, numbers_in_passage = \
                    clipped_passage_num(number_indices, number_len, numbers_in_passage, plen)

        question_passage_tokens += [Token('[SEP]')]
        number_indices = [index + qlen + 2 for index in number_indices] + [-1]
        # Not done in-place so they won't change the numbers saved for the passage
        number_len = number_len + [1]
        numbers_in_passage = numbers_in_passage + [0]
        number_tokens = [Token(str(number)) for number in numbers_in_passage]
        extra_number_tokens = [Token(str(num)) for num in self.extra_numbers]

        mask_indices = [0, qlen + 1, len(question_passage_tokens) - 1]

        fields: Dict[str, Field] = {}

        # Add feature fields
        question_passage_field = TextField(question_passage_tokens,
                                           self.token_indexers)
        fields["question_passage"] = question_passage_field

        number_token_indices = \
            [ArrayField(np.arange(start_ind, start_ind + number_len[i]), padding_value=-1)
             for i, start_ind in enumerate(number_indices)]
        fields["number_indices"] = ListField(number_token_indices)
        numbers_in_passage_field = TextField(number_tokens,
                                             self.token_indexers)
        extra_numbers_field = TextField(extra_number_tokens,
                                        self.token_indexers)
        all_numbers_field = TextField(extra_number_tokens + number_tokens,
                                      self.token_indexers)
        mask_index_fields: List[Field] = [
            IndexField(index, question_passage_field) for index in mask_indices
        ]
        fields["mask_indices"] = ListField(mask_index_fields)

        # Compile question, passage, answer metadata
        metadata = {
            "original_passage": passage_text,
            "original_question": question_text,
            "original_numbers": numbers_in_passage,
            "original_number_words": number_words,
            "extra_numbers": self.extra_numbers,
            "passage_tokens": passage_tokens,
            "question_tokens": question_tokens,
            "question_passage_tokens": question_passage_tokens,
            "question_id": question_id,
            "dataset": dataset
        }

        if answer_annotations:
            answer_texts = answer_annotations[0]
            answer_type = "span"
            tokenized_answer_texts = []
            num_spans = min(len(answer_texts), self.max_spans)
            for answer_text in answer_texts:
                answer_tokens = self.tokenizer.tokenize(answer_text)
                tokenized_answer_texts.append(' '.join(
                    token.text for token in answer_tokens))

            metadata["answer_annotations"] = answer_annotations
            metadata["answer_texts"] = answer_texts
            metadata["answer_tokens"] = tokenized_answer_texts

            # Find answer text in question and passage
            valid_question_spans = DropReader.find_valid_spans(
                question_tokens, tokenized_answer_texts)
            for span_ind, span in enumerate(valid_question_spans):
                valid_question_spans[span_ind] = (span[0] + 1, span[1] + 1)
            valid_passage_spans = DropReader.find_valid_spans(
                passage_tokens, tokenized_answer_texts)
            for span_ind, span in enumerate(valid_passage_spans):
                valid_passage_spans[span_ind] = (span[0] + qlen + 2,
                                                 span[1] + qlen + 2)

            # Get target numbers
            target_numbers = []
            for answer_text in answer_texts:
                if answer_text.strip().count(" ") == 0:
                    number = self.word_to_num(answer_text, True)
                    if number is not None:
                        target_numbers.append(number)

            # Get possible ways to arrive at target numbers with add/sub

            valid_expressions: List[List[int]] = []
            exp_strings = None
            if answer_type in ["number", "date"]:
                if self.exp_search == 'full':
                    expressions = get_full_exp(
                        list(enumerate(self.extra_numbers +
                                       numbers_in_passage)), target_numbers,
                        self.operations, self.op_dict, self.max_depth)
                    zipped = list(zip(*expressions))
                    if zipped:
                        valid_expressions = list(zipped[0])
                        exp_strings = list(zipped[1])
                elif self.exp_search == 'add_sub':
                    valid_expressions = \
                        DropReader.find_valid_add_sub_expressions(self.extra_numbers + numbers_in_passage,
                                                                  target_numbers,
                                                                  self.max_numbers_expression)
                elif self.exp_search == 'template':
                    valid_expressions, exp_strings = \
                        get_template_exp(self.extra_numbers + numbers_in_passage,
                                         target_numbers,
                                         self.templates,
                                         self.template_strings)
                    exp_strings = sum(exp_strings, [])

            # Get possible ways to arrive at target numbers with counting
            valid_counts: List[int] = []
            if answer_type in ["number"]:
                numbers_for_count = list(range(self.max_count + 1))
                valid_counts = DropReader.find_valid_counts(
                    numbers_for_count, target_numbers)

            # Update metadata with answer info
            answer_info = {
                "answer_passage_spans": valid_passage_spans,
                "answer_question_spans": valid_question_spans,
                "num_spans": num_spans,
                "expressions": valid_expressions,
                "counts": valid_counts
            }
            if self.exp_search in ['template', 'full']:
                answer_info['expr_text'] = exp_strings
            metadata["answer_info"] = answer_info

            # Add answer fields
            passage_span_fields: List[Field] = [
                SpanField(span[0], span[1], question_passage_field)
                for span in valid_passage_spans
            ]
            if not passage_span_fields:
                passage_span_fields.append(
                    SpanField(-1, -1, question_passage_field))
            fields["answer_as_passage_spans"] = ListField(passage_span_fields)

            question_span_fields: List[Field] = [
                SpanField(span[0], span[1], question_passage_field)
                for span in valid_question_spans
            ]
            if not question_span_fields:
                question_span_fields.append(
                    SpanField(-1, -1, question_passage_field))
            fields["answer_as_question_spans"] = ListField(
                question_span_fields)

            if self.exp_search == 'add_sub':
                add_sub_signs_field: List[Field] = []
                extra_signs_field: List[Field] = []
                for signs_for_one_add_sub_expressions in valid_expressions:
                    extra_signs = signs_for_one_add_sub_expressions[:len(
                        self.extra_numbers)]
                    normal_signs = signs_for_one_add_sub_expressions[
                        len(self.extra_numbers):]
                    add_sub_signs_field.append(
                        SequenceLabelField(normal_signs,
                                           numbers_in_passage_field))
                    extra_signs_field.append(
                        SequenceLabelField(extra_signs, extra_numbers_field))
                if not add_sub_signs_field:
                    add_sub_signs_field.append(
                        SequenceLabelField([0] * len(number_tokens),
                                           numbers_in_passage_field))
                if not extra_signs_field:
                    extra_signs_field.append(
                        SequenceLabelField([0] * len(self.extra_numbers),
                                           extra_numbers_field))
                fields["answer_as_expressions"] = ListField(
                    add_sub_signs_field)
                if self.extra_numbers:
                    fields["answer_as_expressions_extra"] = ListField(
                        extra_signs_field)
            elif self.exp_search in ['template', 'full']:
                expression_indices = []
                for expression in valid_expressions:
                    if not expression:
                        expression.append(3 * [-1])
                    expression_indices.append(
                        ArrayField(np.array(expression), padding_value=-1))
                if not expression_indices:
                    expression_indices = \
                        [ArrayField(np.array([3 * [-1]]), padding_value=-1) for _ in range(len(self.templates))]
                fields["answer_as_expressions"] = ListField(expression_indices)

            count_fields: List[Field] = [
                LabelField(count_label, skip_indexing=True)
                for count_label in valid_counts
            ]
            if not count_fields:
                count_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_as_counts"] = ListField(count_fields)
            fields["impossible_answer"] = LabelField(0, skip_indexing=True)

            #fields["num_spans"] = LabelField(num_spans, skip_indexing=True)

        else:
            fields["answer_as_passage_spans"] = ListField(
                [SpanField(-1, -1, question_passage_field)])
            fields["answer_as_counts"] = ListField(
                [LabelField(-1, skip_indexing=True)])
            fields["answer_as_expressions"] = ListField([
                SequenceLabelField([0] * len(numbers_in_passage_field),
                                   numbers_in_passage_field)
            ])
            fields["impossible_answer"] = LabelField(1, skip_indexing=True)
            metadata["answer_annotations"] = [{'spans': [""]}]
            fields["answer_as_question_spans"] = ListField(
                [SpanField(-1, -1, question_passage_field)])

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
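The recurring `+ qlen + 2` shift above follows from the layout of the joint sequence, `[CLS] q_1 ... q_qlen [SEP] p_1 ... p_plen [SEP]`: passage token `i` lands at position `i + qlen + 2`, and `mask_indices` picks out `[CLS]` and the two `[SEP]`s:

qlen = 3  # toy question length
layout = ["[CLS]", "q1", "q2", "q3", "[SEP]", "p1", "p2", "[SEP]"]
assert layout.index("p1") == 0 + qlen + 2           # first passage token
assert layout.index("[SEP]") == qlen + 1            # first separator
assert [0, qlen + 1, len(layout) - 1] == [0, 4, 7]  # the mask_indices pattern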
Example #8
    def text_to_instance(
            self,  # type: ignore
            tokens: List[str],
            pos_tags: List[str],
            gold_tree: Tree = None) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        pos_tags : ``List[str]``, required.
            The pos tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The pos tags of the words in the sentence.
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
        """
        # pylint: disable=arguments-differ
        text_field = TextField([Token(x) for x in tokens],
                               token_indexers=self._token_indexers)
        pos_tag_field = SequenceLabelField(pos_tags, text_field, "pos_tags")
        fields = {"tokens": text_field, "pos_tags": pos_tag_field}
        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
            gold_spans = {
                span: label
                for (span, label) in gold_spans_with_pos_tags.items()
                if "-POS" not in label
            }
        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                if (start, end) in gold_spans.keys():
                    gold_labels.append(gold_spans[(start, end)])
                else:
                    gold_labels.append("NO-LABEL")

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field
        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(gold_labels,
                                                       span_list_field)

        return Instance(fields)
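`enumerate_spans` (import path as in AllenNLP 0.9) yields every candidate span with inclusive end indices, which is exactly what the readers in these examples label:

from allennlp.data.dataset_readers.dataset_utils import enumerate_spans

tokens = ["The", "dog", "barked"]
assert enumerate_spans(tokens) == [
    (0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
assert enumerate_spans(tokens, max_span_width=2) == [
    (0, 0), (0, 1), (1, 1), (1, 2), (2, 2)]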
Example #9
 def test_as_tensor_converts_span_field_correctly(self):
     span_field = SpanField(2, 3, self.text)
     tensor = span_field.as_tensor(span_field.get_padding_lengths()).detach().cpu().numpy()
     numpy.testing.assert_array_equal(tensor, numpy.array([2, 3]))
Example #10
    def text_to_instance(
        self,  # type: ignore
        tokens: List[Token],
        #pos_tags: List[str] = None,
        #chunk_tags: List[str] = None,
        ner_tags: List[str] = None
    ) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
        """
        sequence = TextField(tokens, self._token_indexers)
        instance_fields: Dict[str, Field] = {'tokens': sequence}

        def _remove_BI(_one_tag):
            if _one_tag == 'O':
                return _one_tag
            else:
                return _one_tag[2:]

        if self.coding_scheme == "BIOUL":
            coded_ner = to_bioul(ner_tags,
                                 encoding=self._original_coding_scheme
                                 ) if ner_tags is not None else None
        else:
            # the default IOB1
            coded_ner = ner_tags

        # TODO:
        # ner_tags -> spans of NE
        # return something like spans, span_labels ("O" if span not in golden_spans, "PER", "LOC"... otherwise)
        spans: List[Field] = []
        span_labels: List[str] = []

        gold_spans: List[Field] = []
        gold_span_labels: List[str] = []

        assert len(ner_tags) == len(tokens), "sentence:%s but ner_tags:%s" % (
            str(tokens), str(ner_tags))
        ner_gold_spans = _extract_spans(
            ner_tags
        )  # ner_gold_spans: Dict[tuple(startid, endid), str(entity_type)]
        #print(self._max_span_width), exit(0)
        for start, end in enumerate_spans(ner_tags,
                                          offset=0,
                                          max_span_width=self._max_span_width):
            span_labels.append(ner_gold_spans.get((start, end), 'O'))
            spans.append(SpanField(start, end, sequence))

        _dict_gold_spans = {}
        for ky, val in ner_gold_spans.items():
            gold_span_labels.append(val)
            gold_spans.append(SpanField(ky[0], ky[1], sequence))
            if val != 'O':
                _dict_gold_spans[ky] = val

        instance_fields["metadata"] = MetadataField({
            "words": [x.text for x in tokens],
            "gold_spans":
            _dict_gold_spans
        })

        assert len(spans) == len(
            span_labels), "span length not equal to span label length..."
        span_field = ListField(spans)  # a list of (start, end) tuples...

        # contains all possible spans and their tags
        instance_fields['spans'] = span_field
        instance_fields['span_labels'] = SequenceLabelField(
            span_labels, span_field, "span_tags")

        # only contains gold_spans and their tags,
        # e.g. (0,0,O), (1,1,O), (2,3,PER), (4,4,O) for 'I am Donald Trump .'
        gold_span_field = ListField(gold_spans)
        instance_fields['gold_spans'] = gold_span_field
        instance_fields['gold_span_labels'] = SequenceLabelField(
            gold_span_labels, gold_span_field, "span_tags")

        # Add "feature labels" to instance
        if 'ner' in self.feature_labels:
            if coded_ner is None:
                raise ConfigurationError(
                    "Dataset reader was specified to use NER tags as "
                    " features. Pass them to text_to_instance.")
            instance_fields['ner_tags'] = SequenceLabelField(
                coded_ner, sequence, "token_tags")

        # Add "tag label" to instance
        if self.tag_label == 'ner' and coded_ner is not None:
            instance_fields['tags'] = SequenceLabelField(
                coded_ner, sequence, 'token_tags')

        return Instance(instance_fields)
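`_extract_spans` is project-specific and not shown. A plausible sketch of its contract, assuming BIO tags in and `{(start, end): label}` out, including the single-token `'O'` spans the comments above imply:

def _extract_spans_sketch(tags):
    spans, start, label = {}, None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel flushes a trailing entity
        if start is not None and (tag == "O" or tag.startswith("B-")):
            spans[(start, i - 1)] = label
            start, label = None, None
        if i < len(tags):
            if tag.startswith("B-"):
                start, label = i, tag[2:]
            elif tag == "O":
                spans[(i, i)] = "O"  # single-token O spans, per the comment above
    return spans

assert _extract_spans_sketch(["O", "B-PER", "I-PER", "O"]) == \
    {(0, 0): "O", (1, 2): "PER", (3, 3): "O"}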
Example #11
File: dygie.py  Project: MSLars/mare
    def _process_sentence(self, sent: Sentence, dataset: str):
        # Get the sentence text and define the `text_field`.
        sentence_text = [self._normalize_word(word) for word in sent.text]
        text_field = TextField([Token(word) for word in sentence_text],
                               self._token_indexers)

        # Enumerate trigger spans.
        trigger_spans = []
        for start, end in enumerate_spans(
                sentence_text, max_span_width=self._max_trigger_span_width):
            trigger_spans.append(SpanField(start, end, text_field))
        trigger_span_field = ListField(trigger_spans)
        trigger_span_tuples = [(span.span_start, span.span_end)
                               for span in trigger_spans]

        # Enumerate spans.
        spans = []
        for start, end in enumerate_spans(sentence_text,
                                          max_span_width=self._max_span_width):
            spans.append(SpanField(start, end, text_field))
        span_field = ListField(spans)
        span_tuples = [(span.span_start, span.span_end) for span in spans]

        # Convert data to fields.
        # NOTE: The `ner_labels` and `coref_labels` would ideally have type
        # `ListField[SequenceLabelField]`, where the sequence labels are over the `SpanField` of
        # `spans`. But calling `as_tensor_dict()` fails on this specific data type. Matt G
        # recognized that this is an AllenNLP API issue and suggested representing these as
        # `ListField[ListField[LabelField]]` instead.
        fields = {}
        fields["text"] = text_field
        fields["trigger_spans"] = trigger_span_field
        fields["spans"] = span_field

        if sent.ner is not None:
            ner_labels = self._process_ner(span_tuples, sent)
            fields["ner_labels"] = ListField([
                LabelField(entry, label_namespace=f"{dataset}__ner_labels")
                for entry in ner_labels
            ])
        if sent.cluster_dict is not None:
            # Skip indexing for coref labels, which are ints.
            coref_labels = self._process_coref(span_tuples, sent)
            fields["coref_labels"] = ListField([
                LabelField(entry,
                           label_namespace="coref_labels",
                           skip_indexing=True) for entry in coref_labels
            ])
        if sent.relations is not None:
            relation_labels, relation_indices = self._process_relations(
                span_tuples, sent)
            fields["relation_labels"] = AdjacencyField(
                indices=relation_indices,
                sequence_field=span_field,
                labels=relation_labels,
                label_namespace=f"{dataset}__relation_labels")
        if sent.events is not None:
            trigger_labels, argument_labels, argument_indices = self._process_events(
                trigger_span_tuples, span_tuples, sent)
            fields["trigger_labels"] = ListField([
                LabelField(entry, label_namespace=f"{dataset}__trigger_labels")
                for entry in trigger_labels
            ])
            fields["argument_labels"] = AdjacencyFieldAssym(
                indices=argument_indices,
                row_field=trigger_span_field,
                col_field=span_field,
                labels=argument_labels,
                label_namespace=f"{dataset}__argument_labels")

        return fields
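A self-contained toy of the `AdjacencyField` relation encoding used above; the sentence, spans, and label are invented for illustration:

from allennlp.data.fields import AdjacencyField, ListField, SpanField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

text = TextField([Token(w) for w in "Alice works at Acme".split()],
                 {"tokens": SingleIdTokenIndexer()})
spans = ListField([SpanField(0, 0, text), SpanField(3, 3, text)])
# One labelled edge between candidate spans 0 ("Alice") and 1 ("Acme").
relations = AdjacencyField(indices=[(0, 1)], sequence_field=spans,
                           labels=["works-at"],
                           label_namespace="relation_labels")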
Example #12
    def text_to_instance(
            self,  # type: ignore
            words: List[str],
            upos_tags: List[str],
            dependencies: List[Tuple[str, int]] = None,
            entities: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : ``List[str]``, required.
            The words in the sentence to be encoded.
        upos_tags : ``List[str]``, required.
            The universal dependencies POS tags for each word.
        dependencies : ``List[Tuple[str, int]]``, optional (default = None)
            A list of (head tag, head index) tuples. Indices are 1-indexed,
            meaning an index of 0 corresponds to that word being the root of
            the dependency tree.
        entities : ``List[str]``, optional (default = None)
            Entity tags for the words, stored only in the instance metadata.
        Returns
        -------
        An instance containing words, upos tags, dependency head tags and head
        indices as fields.
        """
        fields: Dict[str, Field] = {}

        # if self.tokenizer is not None:
        #     tokens = self.tokenizer.tokenize(" ".join(words))
        # else:
        #     tokens = [Token(t) for t in words]

        characters = [c for word in words for c in word]

        characters = [Token(c) for c in characters]
        character_field = TextField(characters, self._token_indexers)

        spans = []
        start = 0
        for word in words:
            spans.append(
                SpanField(start, start + len(word) - 1, character_field))
            start += len(word)
        character_span_field = ListField(spans)
        fields["character_spans"] = character_span_field

        # text_field = TextField(tokens, self._token_indexers)
        fields["characters"] = character_field

        fields["pos_tags"] = SequenceLabelField(upos_tags,
                                                character_span_field,
                                                label_namespace="pos")
        if dependencies is not None:
            # We don't want to expand the label namespace with an additional dummy token, so we'll
            # always give the 'ROOT_HEAD' token a label of 'root'.
            fields["head_tags"] = SequenceLabelField(
                [x[0] for x in dependencies],
                character_span_field,
                label_namespace="head_tags")
            fields["head_indices"] = SequenceLabelField(
                [int(x[1]) for x in dependencies],
                character_span_field,
                label_namespace="head_index_tags")

        fields["metadata"] = MetadataField({
            "words": words,
            "pos": upos_tags,
            "entities": entities
        })
        return Instance(fields)
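To make the character-span indexing concrete: words map to inclusive spans over the flattened character sequence, as in this worked example of the loop above:

words = ["The", "cat"]
spans, start = [], 0
for word in words:  # characters "Thecat" -> indices 0..5
    spans.append((start, start + len(word) - 1))
    start += len(word)
assert spans == [(0, 2), (3, 5)]  # "The" -> (0, 2), "cat" -> (3, 5)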
Example #13
    def text_to_instance(
        self,  # type: ignore
        sentences: List[List[str]],
        gold_clusters: Optional[List[List[Tuple[int,
                                                int]]]] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenised words and sentences in the document.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.
        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                 not belong to a cluster. As these labels have variable length (it depends on
                 how many spans we are considering), we represent this as a ``SequenceLabelField``
                 with respect to the ``spans`` ``ListField``.
        """
        flattened_sentences = [
            self._normalize_word(word) for sentence in sentences
            for word in sentence
        ]
        # align clusters
        gold_clusters = self.align_clusters_to_tokens(flattened_sentences,
                                                      gold_clusters)

        def tokenizer(s: str):
            return self.token_indexer.wordpiece_tokenizer(s)

        # we need to try this with the other one.
        flattened_sentences = tokenizer(" ".join(flattened_sentences))
        metadata: Dict[str, Any] = {"original_text": flattened_sentences}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters

        text_field = TextField([Token("[CLS]")] +
                               [Token(word) for word in flattened_sentences] +
                               [Token("[SEP]")], self._token_indexers)

        cluster_dict = {}
        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for mention in cluster:
                    cluster_dict[tuple(mention)] = cluster_id

        spans: List[Field] = []
        span_labels: Optional[
            List[int]] = [] if gold_clusters is not None else None
        sentence_offset = 0
        normal = []
        for sentence in sentences:
            # enumerate the spans.
            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                if span_labels is not None:
                    if (start, end) in cluster_dict:
                        span_labels.append(cluster_dict[(start, end)])
                    else:
                        span_labels.append(-1)
                # align the spans to the BERT tokenization
                normal.append((start, end))
                span_field = TextField(
                    [Token("[CLS]")] +
                    [Token(word)
                     for word in flattened_sentences] + [Token("[SEP]")],
                    self._token_indexers)
                # span field for the Span, which needs to be over the flattened sentence.
                spans.append(SpanField(start, end, span_field))
            sentence_offset += len(sentence)

        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field
        }
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)
        return Instance(fields)
Example #14
    def predict_json(self, inputs: JsonDict, cuda_device: int = 0) -> JsonDict:

        instances, results, words, verb_indexes = self._sentence_to_qasrl_instances(
            inputs)

        # Expand vocab
        cleansed_words = cleanse_sentence_text(words)
        added_words = []
        added_vectors = []
        for w in cleansed_words:
            w = w.lower()
            if self._model_vocab.get_token_index(
                    w) == 1 and w in self._pretrained_vectors:
                added_words.append(w)
                added_vectors.append(self._pretrained_vectors[w])
        if added_words:
            first_ind = self._model_vocab.get_vocab_size("tokens")
            for w in added_words:
                self._model_vocab.add_token_to_namespace(w, "tokens")

            num_added_words = len(added_words)
            added_weights = torch.cat(added_vectors, dim=0)

            span_weights = self._model.span_detector.text_field_embedder.token_embedder_tokens.weight.data
            num_words, embsize = span_weights.size()
            new_weights = span_weights.new().resize_(
                num_words + num_added_words, embsize)
            new_weights[:num_words].copy_(span_weights)
            new_weights[num_words:].copy_(
                torch.reshape(added_weights, (num_added_words, embsize)))
            self._model.span_detector.text_field_embedder.token_embedder_tokens.weight = Parameter(
                new_weights)

            ques_weights = self._model.question_predictor.text_field_embedder.token_embedder_tokens.weight.data
            num_words, embsize = ques_weights.size()
            new_weights = ques_weights.new().resize_(
                num_words + num_added_words, embsize)
            new_weights[:num_words].copy_(ques_weights)
            new_weights[num_words:].copy_(
                torch.reshape(added_weights, (num_added_words, embsize)))
            self._model.question_predictor.text_field_embedder.token_embedder_tokens.weight = Parameter(
                new_weights)

        verbs_for_instances = results["verbs"]
        results["verbs"] = []

        span_outputs = self._model.span_detector.forward_on_instances(
            instances)

        instances_with_spans = []
        instance_spans = []
        for instance, span_output in zip(instances, span_outputs):
            field_dict = instance.fields
            text_field = field_dict['text']

            spans = [s[0] for s in span_output['spans'] if s[1] >= 0.5]
            instance_spans.append(spans)

            labeled_span_field = ListField([
                SpanField(span.start(), span.end(), text_field)
                for span in spans
            ])
            field_dict['labeled_spans'] = labeled_span_field
            instances_with_spans.append(Instance(field_dict))

        outputs = self._model.question_predictor.forward_on_instances(
            instances)

        for output, spans, verb, index in zip(outputs, instance_spans,
                                              verbs_for_instances,
                                              verb_indexes):
            questions = {}
            for question, span in zip(output['questions'], spans):
                question_text = self.make_question_text(question, verb)
                span_text = " ".join(
                    [words[i] for i in range(span.start(),
                                             span.end() + 1)])
                questions.setdefault(question_text, []).append(span_text)

            qa_pairs = []
            for question, spans in questions.items():
                qa_pairs.append({"question": question, "spans": spans})

            results["verbs"].append({
                "verb": verb,
                "qa_pairs": qa_pairs,
                "index": index
            })

        return results
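The vocabulary-expansion step above reduces to appending one embedding row per added word; a minimal sketch of the pattern in isolation:

import torch

old_weights = torch.randn(5, 4)    # (num_words, embsize)
added_weights = torch.randn(2, 4)  # one row per newly added word
new_weights = torch.cat([old_weights, added_weights], dim=0)
assert new_weights.shape == (7, 4)  # rows num_words: hold the new words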
Example #15
    def text_to_instance(
        self,  # type: ignore
        sentence: List[Token],
        gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> Instance:
        """
        # Parameters

        sentence : `List[Token]`, required.
            The already tokenised sentence to analyse.
        gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
            A list of all clusters in the sentence, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.

        # Returns

        An `Instance` containing the following `Fields`:
            text : `TextField`
                The text of the full sentence.
            spans : `ListField[SpanField]`
                A ListField containing the spans represented as `SpanFields`
                with respect to the sentence text.
            span_labels : `SequenceLabelField`, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                 not belong to a cluster. As these labels have variable length (it depends on
                 how many spans we are considering), we represent this as a `SequenceLabelField`
                 with respect to the `spans` `ListField`.
        """
        metadata: Dict[str, Any] = {"original_text": sentence}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters

        text_field = TextField(sentence, self._token_indexers)

        cluster_dict = {}
        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for mention in cluster:
                    cluster_dict[tuple(mention)] = cluster_id

        spans: List[Field] = []
        span_labels: Optional[
            List[int]] = [] if gold_clusters is not None else None

        for start, end in enumerate_spans(sentence,
                                          max_span_width=self._max_span_width):
            if span_labels is not None:
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)

            spans.append(SpanField(start, end, text_field))

        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field,
        }
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)

        return Instance(fields)
Example #17
    def text_to_instance(
            self,  # type: ignore
            item_id: Any,
            question_text: str,
            choice_text_list: List[str],
            fact_text: str,
            answer_span: List[str],
            answer_relations: List[str],
            answer_starts: List[int] = None,
            answer_id: int = None,
            prefetched_sentences: Dict[str, List[str]] = None,
            prefetched_indices: str = None) -> Instance:
        fields: Dict[str, Field] = {}
        question_tokens = self._tokenizer.tokenize(question_text)
        fact_tokens = self._tokenizer.tokenize(fact_text)
        choices_tokens_list = [
            self._tokenizer.tokenize(x) for x in choice_text_list
        ]
        choice_kb_fields = []
        selected_tuples = []
        for choice in choice_text_list:
            kb_fields = []

            if self._use_cskg and self._use_elastic_search:
                max_sents_per_source = int(self._max_tuples / 2)
            else:
                max_sents_per_source = self._max_tuples
            selected_hits = []
            if self._use_elastic_search:
                elastic_search_hits = self.get_elasticsearch_sentences(
                    prefetched_sentences, prefetched_indices, answer_span,
                    choice, question_text, fact_text, max_sents_per_source)
                selected_hits.extend(elastic_search_hits)

            if self._use_cskg:
                cskg_sentences = self.get_cskg_sentences(
                    fact_text, answer_span, choice, max_sents_per_source)
                selected_hits.extend(cskg_sentences)
            # add a dummy entry to capture the embedding link
            if self._ignore_spans:
                fact_choice_sentence = fact_text + " || " + choice
                selected_hits.append(fact_choice_sentence)
            else:
                for answer in set(answer_span):
                    answer_choice_sentence = answer + " || " + choice
                    selected_hits.append(answer_choice_sentence)

            selected_tuples.append(selected_hits)
            for hit_text in selected_hits:
                kb_fields.append(
                    TextField(self._tokenizer.tokenize(hit_text),
                              self._token_indexers))

            choice_kb_fields.append(ListField(kb_fields))

        fields["choice_kb"] = ListField(choice_kb_fields)
        fields['fact'] = TextField(fact_tokens, self._token_indexers)

        if self._add_relation_labels:
            if answer_relations and len(answer_relations):
                relation_fields = []
                for relation in set(answer_relations):
                    relation_fields.append(
                        LabelField(relation,
                                   label_namespace="relation_labels"))
                fields["relations"] = ListField(relation_fields)
                selected_relations = self.collate_relations(answer_relations)
                fields["relation_label"] = MultiLabelField(
                    selected_relations, "relation_labels")
            else:
                fields["relations"] = ListField([
                    LabelField(-1,
                               label_namespace="relation_labels",
                               skip_indexing=True)
                ])
                fields["relation_label"] = MultiLabelField([],
                                                           "relation_labels")

        answer_fields = []
        answer_span_fields = []
        fact_offsets = [(token.idx, token.idx + len(token.text))
                        for token in fact_tokens]

        for idx, answer in enumerate(answer_span):
            answer_fields.append(
                TextField(self._tokenizer.tokenize(answer),
                          self._token_indexers))
            if answer_starts:
                if len(answer_starts) <= idx:
                    raise ValueError("Only {} answer_starts in json. "
                                     "Expected {} in {}".format(
                                         len(answer_starts), len(answer_span),
                                         item_id))
                offset = answer_starts[idx]
            else:
                offset = fact_text.find(answer)
                if offset == -1:
                    raise ValueError("Span: {} not found in fact: {}".format(
                        answer, fact_text))

            tok_span, err = char_span_to_token_span(
                fact_offsets, (offset, offset + len(answer)))
            if err:
                logger.info("Could not find token spans for '{}' in '{}'."
                            "Best guess: {} in {} at {}".format(
                                answer, fact_text,
                                [offset, offset + len(answer)], fact_offsets,
                                tok_span))
            answer_span_fields.append(
                SpanField(tok_span[0], tok_span[1], fields['fact']))

        fields["answer_text"] = ListField(answer_fields)
        fields["answer_spans"] = ListField(answer_span_fields)
        fields['question'] = TextField(question_tokens, self._token_indexers)

        fields['choices_list'] = ListField(
            [TextField(x, self._token_indexers) for x in choices_tokens_list])
        if answer_id is not None:
            fields['answer_id'] = LabelField(answer_id, skip_indexing=True)

        metadata = {
            "id": item_id,
            "question_text": question_text,
            "fact_text": fact_text,
            "choice_text_list": choice_text_list,
            "question_tokens": [x.text for x in question_tokens],
            "fact_tokens": [x.text for x in fact_tokens],
            "choice_tokens_list": [[x.text for x in ct]
                                   for ct in choices_tokens_list],
            "answer_text": answer_span,
            "answer_start": answer_starts,
            "answer_span_fields": [(x.span_start, x.span_end)
                                   for x in answer_span_fields],
            "relations": answer_relations,
            "selected_tuples": selected_tuples
        }

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
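A pattern that recurs in these examples: an AllenNLP ListField cannot be built from an empty list, so whenever a list of answer or label fields may be empty, the readers append a dummy entry such as SpanField(-1, -1, ...) or LabelField(-1, skip_indexing=True) for the model to mask out later. A minimal sketch of the idiom; the helper name make_span_list_field is ours, not taken from any of these examples:

    from typing import List, Tuple

    from allennlp.data.fields import ListField, SpanField, TextField

    def make_span_list_field(spans: List[Tuple[int, int]],
                             text_field: TextField) -> ListField:
        # Pad with a (-1, -1) dummy span so the ListField is never empty;
        # downstream models treat negative indices as "no span" and mask them.
        span_fields = [SpanField(start, end, text_field)
                       for start, end in spans]
        if not span_fields:
            span_fields.append(SpanField(-1, -1, text_field))
        return ListField(span_fields)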
Example #18
def make_coref_instance(
    sentences: List[List[str]],
    token_indexers: Dict[str, TokenIndexer],
    max_span_width: int,
    gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
    wordpiece_modeling_tokenizer: PretrainedTransformerTokenizer = None,
    max_sentences: int = None,
    remove_singleton_clusters: bool = True,
) -> Instance:
    """
    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenised words and sentences in the document.
    token_indexers : `Dict[str, TokenIndexer]`
        This is used to index the words in the document.  See :class:`TokenIndexer`.
    max_span_width : `int`, required.
        The maximum width of candidate spans to consider.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
        A list of all clusters in the document, represented as word spans with absolute indices
        in the entire document. Each cluster contains some number of spans, which can be nested
        and overlap. If there are exact matches between clusters, they will be resolved
        using `_canonicalize_clusters`.
    wordpiece_modeling_tokenizer : `PretrainedTransformerTokenizer`, optional (default = None)
        If not None, this dataset reader does subword tokenization using the supplied tokenizer
        and distributes the labels to the resulting wordpieces. All the modeling will be based
        on wordpieces. If this is left as `None` (the default), the user is expected to use
        `PretrainedTransformerMismatchedIndexer` and `PretrainedTransformerMismatchedEmbedder`,
        and the modeling will be on the word level.
    max_sentences : `int`, optional (default = None)
        The maximum number of sentences in each document to keep. By default, all sentences
        are kept.
    remove_singleton_clusters : `bool`, optional (default = True)
        Some datasets contain singleton clusters, i.e. clusters with a single mention and no
        coreferents. This option allows such clusters to be removed.

    # Returns

    An `Instance` containing the following `Fields`:
        text : `TextField`
            The text of the full document.
        spans : `ListField[SpanField]`
            A ListField containing the spans represented as `SpanFields`
            with respect to the document text.
        span_labels : `SequenceLabelField`, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster. As these labels have variable length (it depends on
            how many spans we are considering), we represent this as a `SequenceLabelField`
            with respect to the spans `ListField`.
    """
    if max_sentences is not None and len(sentences) > max_sentences:
        sentences = sentences[:max_sentences]
        total_length = sum(len(sentence) for sentence in sentences)

        if gold_clusters is not None:
            new_gold_clusters = []

            for cluster in gold_clusters:
                new_cluster = []
                for mention in cluster:
                    if mention[1] < total_length:
                        new_cluster.append(mention)
                if new_cluster:
                    new_gold_clusters.append(new_cluster)

            gold_clusters = new_gold_clusters

    flattened_sentences = [
        _normalize_word(word) for sentence in sentences for word in sentence
    ]

    if wordpiece_modeling_tokenizer is not None:
        flat_sentences_tokens, offsets = wordpiece_modeling_tokenizer.intra_word_tokenize(
            flattened_sentences)
        flattened_sentences = [t.text for t in flat_sentences_tokens]
    else:
        flat_sentences_tokens = [Token(word) for word in flattened_sentences]

    text_field = TextField(flat_sentences_tokens, token_indexers)

    cluster_dict = {}
    if gold_clusters is not None:
        gold_clusters = _canonicalize_clusters(gold_clusters)
        if remove_singleton_clusters:
            gold_clusters = [
                cluster for cluster in gold_clusters if len(cluster) > 1
            ]

        if wordpiece_modeling_tokenizer is not None:
            for cluster in gold_clusters:
                for mention_id, mention in enumerate(cluster):
                    start = offsets[mention[0]][0]
                    end = offsets[mention[1]][1]
                    cluster[mention_id] = (start, end)

        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    span_labels: Optional[
        List[int]] = [] if gold_clusters is not None else None

    sentence_offset = 0
    for sentence in sentences:
        for start, end in enumerate_spans(sentence,
                                          offset=sentence_offset,
                                          max_span_width=max_span_width):
            if wordpiece_modeling_tokenizer is not None:
                start = offsets[start][0]
                end = offsets[end][1]

                # `enumerate_spans` uses word-level width limit; here we apply it to wordpieces
                # We have to do this check here because we use a span width embedding that has
                # only `max_span_width` entries, and since we are doing wordpiece
                # modeling, the span width embedding operates on wordpiece lengths. So a check
                # here is necessary or else we wouldn't know how many entries there would be.
                if end - start + 1 > max_span_width:
                    continue
                # We also don't generate spans that contain special tokens
                if start < wordpiece_modeling_tokenizer.num_added_start_tokens:
                    continue
                if (end >= len(flat_sentences_tokens) -
                        wordpiece_modeling_tokenizer.num_added_end_tokens):
                    continue

            if span_labels is not None:
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)

            spans.append(SpanField(start, end, text_field))
        sentence_offset += len(sentence)

    span_field = ListField(spans)

    metadata: Dict[str, Any] = {"original_text": flattened_sentences}
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": span_field,
        "metadata": metadata_field,
    }
    if span_labels is not None:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field)

    return Instance(fields)
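To make the expected inputs concrete, here is a minimal, hypothetical invocation of make_coref_instance; the toy data and the choice of SingleIdTokenIndexer are assumptions, not taken from any real config:

    from allennlp.data.token_indexers import SingleIdTokenIndexer

    sentences = [["John", "saw", "Mary", "."], ["He", "waved", "."]]
    # One gold cluster linking "John" (document-level span (0, 0))
    # and "He" (document-level span (4, 4)).
    gold_clusters = [[(0, 0), (4, 4)]]

    instance = make_coref_instance(
        sentences=sentences,
        token_indexers={"tokens": SingleIdTokenIndexer()},
        max_span_width=3,
        gold_clusters=gold_clusters,
    )
    # The instance now holds "text", "spans", "span_labels", and "metadata".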
Example #19
    def read_instances(self,
                       token_indexers: Dict[str, TokenIndexer],
                       sentence_id: str,
                       sentence_tokens: List[str],
                       verb_index: int,
                       verb_inflected_forms: Dict[str, str],
                       question_labels): # Iterable[Dict[str, ?Field]]
        verb_fields = get_verb_fields(token_indexers, sentence_tokens, verb_index)
        tan_strings = []
        tan_string_fields = []
        all_answer_fields = []
        if self._clause_info is not None:
            clause_string_fields = []
            clause_strings = []
            qarg_fields = []
            gold_tuples = []

        if len(question_labels) == 0:
            tan_string_list_field = ListField([LabelField(label = -1, label_namespace = "tan-string-labels", skip_indexing = True)])
            clause_string_list_field = ListField([LabelField(label = -1, label_namespace = "abst-clause-labels", skip_indexing = True)])
            qarg_list_field = ListField([LabelField(label = -1, label_namespace = "qarg-labels", skip_indexing = True)])
            answer_spans_field = ListField([ListField([SpanField(-1, -1, verb_fields["text"])])])
            num_answers_field = ListField([LabelField(-1, skip_indexing = True)])
            num_invalids_field = ListField([LabelField(-1, skip_indexing = True)])
        else:
            for question_label in question_labels:

                tan_string = get_tan_string(question_label)
                tan_string_field = LabelField(label = tan_string, label_namespace = "tan-string-labels")
                tan_strings.append(tan_string)
                tan_string_fields.append(tan_string_field)

                answer_fields = get_answer_fields(question_label, verb_fields["text"])
                all_answer_fields.append(answer_fields)

                if self._clause_info is not None:
                    clause_slots = {}
                    try:
                        clause_slots = self._clause_info[sentence_id][verb_index][question_label["questionString"]]["slots"]
                    except KeyError:
                        logger.info("Omitting instance without clause data: %s / %s / %s" % (sentence_id, verb_index, question_label["questionString"]))
                        continue

                    def abst_noun(x):
                        return "something" if (x == "someone") else x
                    clause_slots["abst-subj"] = abst_noun(clause_slots["subj"])
                    clause_slots["abst-verb"] = "verb[pss]" if question_label["isPassive"] else "verb"
                    clause_slots["abst-obj"] = abst_noun(clause_slots["obj"])
                    clause_slots["abst-prep1-obj"] = abst_noun(clause_slots["prep1-obj"])
                    clause_slots["abst-prep2-obj"] = abst_noun(clause_slots["prep2-obj"])
                    clause_slots["abst-misc"] = abst_noun(clause_slots["misc"])
                    abst_slot_names = ["abst-subj", "abst-verb", "abst-obj", "prep1", "abst-prep1-obj", "prep2", "abst-prep2-obj", "abst-misc"]
                    clause_string = " ".join([clause_slots[slot_name] for slot_name in abst_slot_names])
                    clause_string_field = LabelField(label = clause_string, label_namespace = "abst-clause-labels")
                    clause_strings.append(clause_string)
                    clause_string_fields.append(clause_string_field)

                    qarg_fields.append(LabelField(label = clause_slots["qarg"], label_namespace = "qarg-labels"))

                    for span_field in answer_fields["answer_spans"]:
                        if span_field.span_start > -1:
                            s = (span_field.span_start, span_field.span_end)
                            gold_tuples.append((clause_string, clause_slots["qarg"], s))

            tan_string_list_field = ListField(tan_string_fields)
            answer_spans_field = ListField([f["answer_spans"] for f in all_answer_fields])
            num_answers_field = ListField([f["num_answers"] for f in all_answer_fields])
            num_invalids_field = ListField([f["num_invalids"] for f in all_answer_fields])

            if self._clause_info is not None:
                clause_string_list_field = ListField(clause_string_fields)
                qarg_list_field = ListField(qarg_fields)

        if self._clause_info is not None:
            all_clause_strings = set(clause_strings)
            all_spans = set([t[2] for t in gold_tuples])
            all_qargs = set([t[1] for t in gold_tuples])
            qarg_pretrain_clause_fields = []
            qarg_pretrain_span_fields = []
            qarg_pretrain_multilabel_fields = []
            for clause_string in all_clause_strings:
                for span in all_spans:
                    valid_qargs = [qarg for qarg in all_qargs if (clause_string, qarg, span) in gold_tuples]
                    qarg_pretrain_clause_fields.append(LabelField(clause_string, label_namespace = "abst-clause-labels"))
                    qarg_pretrain_span_fields.append(SpanField(span[0], span[1], verb_fields["text"]))
                    qarg_pretrain_multilabel_fields.append(MultiLabelField_New(valid_qargs, label_namespace = "qarg-labels"))

            if len(qarg_pretrain_clause_fields) > 0:
                qarg_labeled_clauses_field = ListField(qarg_pretrain_clause_fields)
                qarg_labeled_spans_field = ListField(qarg_pretrain_span_fields)
                qarg_labels_field = ListField(qarg_pretrain_multilabel_fields)
            else:
                qarg_labeled_clauses_field = ListField([LabelField(-1, label_namespace = "abst-clause-labels", skip_indexing = True)])
                qarg_labeled_spans_field = ListField([SpanField(-1, -1, verb_fields["text"])])
                qarg_labels_field = ListField([MultiLabelField_New(set(), label_namespace = "qarg-labels")])

        tan_multilabel_field = MultiLabelField_New(list(set(tan_strings)), label_namespace = "tan-string-labels")

        if self._clause_info is not None:
            yield {
                **verb_fields,
                "clause_strings": clause_string_list_field,
                "clause_set": MultiLabelField_New(clause_strings, label_namespace = "abst-clause-labels"),
                "tan_strings": tan_string_list_field,
                "tan_set": tan_multilabel_field,
                "qargs": qarg_list_field,
                "answer_spans": answer_spans_field,
                "num_answers": num_answers_field,
                "num_invalids": num_invalids_field,
                "metadata": MetadataField({
                    "gold_set": set(gold_tuples) # TODO make it a multiset so we can change span selection policy?
                }),
                "qarg_labeled_clauses": qarg_labeled_clauses_field,
                "qarg_labeled_spans": qarg_labeled_spans_field,
                "qarg_labels": qarg_labels_field,
            }
        else:
            yield {
                **verb_fields,
                "tan_strings": tan_string_list_field,
                "tan_set": tan_multilabel_field,
                "answer_spans": answer_spans_field,
                "num_answers": num_answers_field,
                "num_invalids": num_invalids_field,
                "metadata": MetadataField({}),
            }
Example #20
    def text_to_instance(
        self,  # type: ignore
        tokens: List[str],
        pos_tags: List[str] = None,
        gold_tree: Tree = None,
    ) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        # Parameters

        tokens : `List[str]`, required.
            The tokens in a given sentence.
        pos_tags : `List[str]`, optional (default = `None`).
            The POS tags for the words in the sentence.
        gold_tree : `Tree`, optional (default = `None`).
            The gold parse tree to create span labels from.

        # Returns

        An `Instance` containing the following fields:
            tokens : `TextField`
                The tokens in the sentence.
            pos_tags : `SequenceLabelField`
                The POS tags of the words in the sentence.
                Only returned if `use_pos_tags` is `True`.
            spans : `ListField[SpanField]`
                A ListField containing all possible subspans of the
                sentence.
            span_labels : `SequenceLabelField`, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a `NO-LABEL` label.
            gold_tree : `MetadataField(Tree)`
                The gold NLTK parse tree for use in evaluation.
        """

        if self._convert_parentheses:
            tokens = [PTB_PARENTHESES.get(token, token) for token in tokens]
        text_field = TextField([Token(x) for x in tokens], token_indexers=self._token_indexers)
        fields: Dict[str, Field] = {"tokens": text_field}

        pos_namespace = self._label_namespace_prefix + self._pos_label_namespace
        if self._use_pos_tags and pos_tags is not None:
            pos_tag_field = SequenceLabelField(pos_tags, text_field, label_namespace=pos_namespace)
            fields["pos_tags"] = pos_tag_field
        elif self._use_pos_tags:
            raise ConfigurationError(
                "use_pos_tags was set to True but no gold pos"
                " tags were passed to the dataset reader."
            )
        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans)
        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                gold_labels.append(gold_spans.get((start, end), "NO-LABEL"))

        metadata = {"tokens": tokens}
        if gold_tree:
            metadata["gold_tree"] = gold_tree
        if self._use_pos_tags:
            metadata["pos_tags"] = pos_tags

        fields["metadata"] = MetadataField(metadata)

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field
        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(
                gold_labels,
                span_list_field,
                label_namespace=self._label_namespace_prefix + "labels",
            )
        return Instance(fields)
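Both constituency-parser examples (this one and Example #30 below) build their candidate list with enumerate_spans, which yields every (start, end) pair with inclusive endpoints, optionally offset and width-limited. A quick illustration of its output, assuming the stock AllenNLP helper:

    from allennlp.data.dataset_readers.dataset_utils import enumerate_spans

    tokens = ["a", "short", "sentence"]
    print(list(enumerate_spans(tokens)))
    # [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]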
Example #21
    def text_to_instance(
            self,
            para_id: str,
            sentence_texts: List[str],
            participants: List[str],
            states: List[
                List[str]] = None,  # states[i][j] is ith participant at time j
            filename: str = '',
            score: float = None) -> Instance:

        tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(
            pos_tags=True))

        paragraph = " ".join(sentence_texts)

        # Tokenize the sentences
        sentences = [
            tokenizer.tokenize(sentence_text)
            for sentence_text in sentence_texts
        ]

        # Find the verbs
        verb_indexes = [[
            1 if token.pos_ == "VERB" else 0 for token in sentence
        ] for sentence in sentences]

        if states is not None:
            # Actions is (num_participants, num_events)
            actions = [_infer_actions(states_i) for states_i in states]

            tokenized_states = [[
                tokenizer.tokenize(state_ij) for state_ij in states_i
            ] for states_i in states]

            location_spans = [
                _compute_location_spans(states_i, sentences)
                for states_i in tokenized_states
            ]

        # Create indicators for the participants.
        participant_tokens = [
            tokenizer.tokenize(participant) for participant in participants
        ]
        participant_indicators: List[List[List[int]]] = []

        for participant_i_tokens in participant_tokens:
            targets = [
                list(token_group)
                for is_semicolon, token_group in itertools.groupby(
                    participant_i_tokens, lambda t: t.text == ";")
                if not is_semicolon
            ]

            participant_i_indicators: List[List[int]] = []

            for sentence in sentences:
                sentence_indicator = [0 for _ in sentence]

                for target in targets:
                    start = 0
                    while True:
                        span_start, span_end = _find_span(target,
                                                          sentence,
                                                          start,
                                                          target_is_noun=True)
                        if span_start >= 0:
                            for j in range(span_start, span_end + 1):
                                sentence_indicator[j] = 1
                            start = span_start + 1
                        else:
                            break

                participant_i_indicators.append(sentence_indicator)

            participant_indicators.append(participant_i_indicators)

        fields: Dict[str, Field] = {}
        fields["paragraph"] = TextField(tokenizer.tokenize(paragraph),
                                        self._token_indexers)
        fields["participants"] = ListField([
            TextField(tokenizer.tokenize(participant), self._token_indexers)
            for participant in participants
        ])

        # One per sentence
        fields["sentences"] = ListField([
            TextField(sentence, self._token_indexers) for sentence in sentences
        ])

        # One per sentence
        fields["verbs"] = ListField([
            SequenceLabelField(verb_indexes[i],
                               fields["sentences"].field_list[i])
            for i in range(len(sentences))
        ])
        # And also at the paragraph level
        fields["paragraph_verbs"] = SequenceLabelField([
            verb_indicator for verb_indexes_i in verb_indexes
            for verb_indicator in verb_indexes_i
        ], fields["paragraph"])

        if states is not None:
            # Outer ListField is one per participant
            fields["actions"] = ListField([
                # Inner ListField is one per sentence
                ListField([
                    # action is an Enum, so call .value to get an int
                    LabelField(action.value, skip_indexing=True)
                    for action in participant_actions
                ]) for participant_actions in actions
            ])

            # Outer ListField is one per participant
            fields["before_locations"] = ListField([
                # Inner ListField is one per sentence
                ListField([
                    SpanField(start, end, fields["sentences"].field_list[i])
                    for i, ((start, end),
                            _) in enumerate(participant_location_spans)
                ]) for participant_location_spans in location_spans
            ])
            # Outer ListField is one per participant
            fields["after_locations"] = ListField([
                # Inner ListField is one per sentence
                ListField([
                    SpanField(start, end, fields["sentences"].field_list[i])
                    for i, (_, (start,
                                end)) in enumerate(participant_location_spans)
                ]) for participant_location_spans in location_spans
            ])

        # one per participant
        fields["participant_indicators"] = ListField([
            # one per sentence
            ListField([
                SequenceLabelField(sentence_indicator,
                                   fields["sentences"].field_list[i]) for i,
                sentence_indicator in enumerate(participant_i_indicators)
            ]) for participant_i_indicators in participant_indicators
        ])

        # and also at the paragraph level
        # one per participant
        fields["paragraph_participant_indicators"] = ListField([
            SequenceLabelField([
                indicator for sentence_indicator in participant_i_indicators
                for indicator in sentence_indicator
            ], fields["paragraph"])
            for participant_i_indicators in participant_indicators
        ])

        # Finally, we want to indicate before / inside / after for each sentence.
        paragraph_sentence_indicators: List[SequenceLabelField] = []
        for i in range(len(sentences)):
            before_length = sum(len(sentence) for sentence in sentences[:i])
            sentence_length = len(sentences[i])
            after_length = sum(
                len(sentence) for sentence in sentences[(i + 1):])
            paragraph_sentence_indicators.append(
                SequenceLabelField([0] * before_length +
                                   [1] * sentence_length + [2] * after_length,
                                   fields["paragraph"]))

        fields["paragraph_sentence_indicators"] = ListField(
            paragraph_sentence_indicators)

        # These fields are passed on to the decoder trainer that internally uses it
        # to compute commonsense scores for predicted actions
        fields["para_id"] = MetadataField(para_id)
        fields["participant_strings"] = MetadataField(participants)

        fields["filename"] = MetadataField(filename)

        if score is not None:
            fields["score"] = MetadataField(score)

        return Instance(fields)
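The before/inside/after encoding at the end of this reader is easy to misread: for sentence i, every token before it is labeled 0, its own tokens 1, and every token after it 2, across the whole paragraph. A standalone sketch with assumed sentence lengths:

    sentence_lengths = [4, 3, 5]  # tokens per sentence (hypothetical)

    indicators = []
    for i, length in enumerate(sentence_lengths):
        before = sum(sentence_lengths[:i])
        after = sum(sentence_lengths[i + 1:])
        indicators.append([0] * before + [1] * length + [2] * after)

    # indicators[1] == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]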
Example #22
    def predict(self, inputs: JsonDict) -> JsonDict:
        # Produce different sets of instances to account for the possibility
        # of different token indexers as well as different vocabularies.
        span_verb_instances = list(
            self._span_model_dataset_reader.sentence_json_to_instances(
                inputs, verbs_only=True))
        span_to_question_verb_instances = list(
            self._span_to_question_model_dataset_reader
            .sentence_json_to_instances(inputs, verbs_only=True))
        span_outputs = self._span_model.forward_on_instances(
            span_verb_instances)
        verb_dicts = []
        for verb_instance, span_output in zip(
                span_to_question_verb_instances, span_outputs):
            beam = []
            scored_spans = [(s, p) for s, p in span_output["spans"]
                            if p >= self._span_minimum_threshold]
            span_fields = [
                SpanField(span.start(), span.end(), verb_instance["text"])
                for span, _ in scored_spans
            ]
            if len(span_fields) > 0:
                verb_instance.index_fields(self._span_to_question_model.vocab)
                verb_instance.add_field("answer_spans", ListField(span_fields),
                                        self._span_to_question_model.vocab)
                qgen_input_tensors = move_to_device(
                    Batch([verb_instance]).as_tensor_dict(),
                    self._span_to_question_model._get_prediction_device())
                question_beams = self._span_to_question_model.beam_decode(
                    text=qgen_input_tensors["text"],
                    predicate_indicator=qgen_input_tensors[
                        "predicate_indicator"],
                    predicate_index=qgen_input_tensors["predicate_index"],
                    answer_spans=qgen_input_tensors["answer_spans"],
                    max_beam_size=self._question_beam_size,
                    min_beam_probability=self._question_minimum_threshold)
                for (span, span_prob), (_, slot_values, question_probs) in zip(
                        scored_spans, question_beams):
                    for i in range(len(question_probs)):
                        question_slots = {
                            slot_name: slot_values[slot_name][i]
                            for slot_name in
                            self._span_to_question_model.get_slot_names()
                        }
                        beam.append({
                            "questionSlots": question_slots,
                            "questionProb": question_probs[i],
                            "span": [span.start(), span.end() + 1],
                            "spanProb": span_prob,
                        })
            verb_dicts.append({
                "verbIndex": verb_instance["metadata"]["verb_index"],
                "verbInflectedForms":
                    verb_instance["metadata"]["verb_inflected_forms"],
                "beam": beam,
            })
        return {
            "sentenceId": inputs["sentenceId"],
            "sentenceTokens": inputs["sentenceTokens"],
            "verbs": verb_dicts,
        }
Example #23
File: drop.py  Project: MyPaperCode/RAIN
    def make_marginal_bert_drop_instance(
            passage_question_tokens: List[Token],
            # passage_tokens: List[Token],
            implicit_tokens: List[Token],
            number_tokens: List[Token],
            number_indices: List[int],
            token_indexers: Dict[str, TokenIndexer],
            passage_text: str,
            answer_info: Dict[str, Any] = None,
            additional_metadata: Dict[str, Any] = None) -> Instance:

        additional_metadata = additional_metadata or {}
        fields: Dict[str, Field] = {}

        passage_question_field = TextField(passage_question_tokens, token_indexers)
        fields["passage_question"] = passage_question_field
        
        number_index_fields: List[Field] = [IndexField(index, passage_question_field) for index in number_indices]
        fields["number_indices"] = ListField(number_index_fields)

        numbers_in_passage_question_field = TextField(number_tokens, token_indexers)
        
        implicit_token_field = TextField(implicit_tokens, token_indexers)

        metadata = {"original_passage": passage_text,
                    "passage_question_tokens": [token.text for token in passage_question_tokens],
                    "number_tokens": [token.text for token in number_tokens],
                    "number_indices": number_indices}

        if answer_info:
       
            metadata["answer_texts"] = answer_info["answer_texts"]

            """
            spans
            """
            span_fields: List[Field] = \
                [SpanField(span[0], span[1], passage_question_field) for span in answer_info["answer_spans"]]
            if not span_fields:
                span_fields.append(SpanField(-1, -1, passage_question_field))
            fields["answer_as_spans"] = ListField(span_fields)


            """
            number and date  
            """
            add_sub_signs_field: List[Field] = []
            for signs_for_one_add_sub_expression in answer_info["signs_for_add_sub_expressions"]:
                add_sub_signs_field.append(SequenceLabelField(signs_for_one_add_sub_expression,
                                                              numbers_in_passage_question_field))
            if not add_sub_signs_field:
                add_sub_signs_field.append(SequenceLabelField([0] * len(number_tokens),
                                                              numbers_in_passage_question_field))
            fields["answer_as_add_sub_expressions"] = ListField(add_sub_signs_field)

            """
            count
            """
            count_fields: List[Field] = [LabelField(count_label, skip_indexing=True)
                                         for count_label in answer_info["counts"]]
            if not count_fields:
                count_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_as_counts"] = ListField(count_fields)



            answer_label = np.zeros(3)
            if answer_info["answer_spans"]:
                answer_label[0] = 1.0
            if answer_info["signs_for_add_sub_expressions"]:
                answer_label[1] = 1.0
            if answer_info["counts"]:
                answer_label[2] = 1.0
            if sum(answer_label) != 0:
                answer_label = answer_label / float(sum(answer_label))
            fields["answer_type"] = ArrayField(answer_label, -1)

        metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)
        return Instance(fields)
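The answer_type array above is a soft distribution over the three supported answer abilities (span, add/sub expression, count): when the gold answer is recoverable in more than one way, the probability mass is split evenly between them. For example:

    import numpy as np

    answer_label = np.zeros(3)
    answer_label[0] = 1.0  # the answer appears as a span ...
    answer_label[2] = 1.0  # ... and is also derivable as a count
    answer_label = answer_label / float(sum(answer_label))
    print(answer_label)  # [0.5 0.  0.5]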
Example #24
    def text_to_instance(self, sentence: List[str],
                         ner_dict: Dict[Tuple[int, int], str], relation_dict,
                         cluster_dict, trigger_dict, argument_dict,
                         doc_key: str, dataset: str, sentence_num: int,
                         groups: List[str], start_ix: int, end_ix: int):
        """
        TODO(dwadden) document me.
        """

        sentence = [self._normalize_word(word) for word in sentence]

        text_field = TextField([Token(word) for word in sentence],
                               self._token_indexers)
        text_field_with_context = TextField([Token(word) for word in groups],
                                            self._token_indexers)

        # Put together the metadata.
        metadata = dict(sentence=sentence,
                        ner_dict=ner_dict,
                        relation_dict=relation_dict,
                        cluster_dict=cluster_dict,
                        trigger_dict=trigger_dict,
                        argument_dict=argument_dict,
                        doc_key=doc_key,
                        dataset=dataset,
                        groups=groups,
                        start_ix=start_ix,
                        end_ix=end_ix,
                        sentence_num=sentence_num)
        metadata_field = MetadataField(metadata)

        # Trigger labels. One label per token in the input.
        token_trigger_labels = []
        for i in range(len(text_field)):
            token_trigger_labels.append(trigger_dict[i])

        trigger_label_field = SequenceLabelField(
            token_trigger_labels, text_field, label_namespace="trigger_labels")

        # Generate fields for text spans, ner labels, coref labels.
        spans = []
        span_ner_labels = []
        span_coref_labels = []
        for start, end in enumerate_spans(sentence,
                                          max_span_width=self._max_span_width):
            span_ix = (start, end)
            span_ner_labels.append(ner_dict[span_ix])
            span_coref_labels.append(cluster_dict[span_ix])
            spans.append(SpanField(start, end, text_field))

        span_field = ListField(spans)
        ner_label_field = SequenceLabelField(span_ner_labels,
                                             span_field,
                                             label_namespace="ner_labels")
        coref_label_field = SequenceLabelField(span_coref_labels,
                                               span_field,
                                               label_namespace="coref_labels")

        # Generate labels for relations and arguments. Only store non-null values.
        # For the arguments, by convention the first span specifies the trigger, and the second
        # specifies the argument. Ideally we'd have an adjacency field between (token, span) pairs
        # for the event arguments field, but AllenNLP doesn't make it possible to express
        # adjacencies between two different sequences.
        n_spans = len(spans)
        span_tuples = [(span.span_start, span.span_end) for span in spans]
        candidate_indices = [(i, j) for i in range(n_spans)
                             for j in range(n_spans)]

        relations = []
        relation_indices = []
        for i, j in candidate_indices:
            span_pair = (span_tuples[i], span_tuples[j])
            relation_label = relation_dict[span_pair]
            if relation_label:
                relation_indices.append((i, j))
                relations.append(relation_label)

        relation_label_field = AdjacencyField(
            indices=relation_indices,
            sequence_field=span_field,
            labels=relations,
            label_namespace="relation_labels")

        arguments = []
        argument_indices = []
        n_tokens = len(sentence)
        candidate_indices = [(i, j) for i in range(n_tokens)
                             for j in range(n_spans)]
        for i, j in candidate_indices:
            token_span_pair = (i, span_tuples[j])
            argument_label = argument_dict[token_span_pair]
            if argument_label:
                argument_indices.append((i, j))
                arguments.append(argument_label)

        argument_label_field = AdjacencyFieldAssym(
            indices=argument_indices,
            row_field=text_field,
            col_field=span_field,
            labels=arguments,
            label_namespace="argument_labels")

        # Pull it all together.
        fields = dict(text=text_field_with_context,
                      spans=span_field,
                      ner_labels=ner_label_field,
                      coref_labels=coref_label_field,
                      trigger_labels=trigger_label_field,
                      argument_labels=argument_label_field,
                      relation_labels=relation_label_field,
                      metadata=metadata_field)

        return Instance(fields)
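The relation and argument labels above are deliberately sparse: out of all candidate pairs, only those with a non-null entry in the lookup dict contribute an (index, label) pair to the AdjacencyField. A toy rendering of the index construction, with a plain dict standing in for the reader's lookup and a hypothetical label:

    span_tuples = [(0, 1), (3, 4), (6, 6)]
    relation_dict = {((0, 1), (3, 4)): "USED-FOR"}  # hypothetical

    relation_indices, relations = [], []
    for i, span_i in enumerate(span_tuples):
        for j, span_j in enumerate(span_tuples):
            label = relation_dict.get((span_i, span_j))
            if label:
                relation_indices.append((i, j))
                relations.append(label)

    # relation_indices == [(0, 1)], relations == ["USED-FOR"]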
Example #25
    def test_span_field_raises_on_incorrect_label_type(self):
        with pytest.raises(TypeError):
            _ = SpanField("hello", 3, self.text)
Example #26
    def predictions_to_labeled_instances(
        self, instance: Instance, outputs: Dict[str, numpy.ndarray]
    ) -> List[Instance]:
        new_instance = instance.duplicate()
        # For BiDAF
        if "best_span" in outputs:
            span_start_label = outputs["best_span"][0]
            span_end_label = outputs["best_span"][1]
            passage_field: SequenceField = new_instance["passage"]  # type: ignore
            new_instance.add_field(
                "span_start", IndexField(int(span_start_label), passage_field), self._model.vocab
            )
            new_instance.add_field(
                "span_end", IndexField(int(span_end_label), passage_field), self._model.vocab
            )

        # For NAQANet model. It has the fields: answer_as_passage_spans, answer_as_question_spans,
        # answer_as_add_sub_expressions, answer_as_counts. We need labels for all.
        elif "answer" in outputs:
            answer_type = outputs["answer"]["answer_type"]

            # When the problem is a counting problem
            if answer_type == "count":
                field = ListField([LabelField(int(outputs["answer"]["count"]), skip_indexing=True)])
                new_instance.add_field("answer_as_counts", field, self._model.vocab)

            # When the answer is in the passage
            elif answer_type == "passage_span":
                # TODO(mattg): Currently we only handle one predicted span.
                span = outputs["answer"]["spans"][0]

                # Convert character span indices into word span indices
                word_span_start = None
                word_span_end = None
                offsets = new_instance["metadata"].metadata["passage_token_offsets"]  # type: ignore
                for index, offset in enumerate(offsets):
                    if offset[0] == span[0]:
                        word_span_start = index
                    if offset[1] == span[1]:
                        word_span_end = index

                passage_field: SequenceField = new_instance["passage"]  # type: ignore
                field = ListField([SpanField(word_span_start, word_span_end, passage_field)])
                new_instance.add_field("answer_as_passage_spans", field, self._model.vocab)

            # When the answer is an arithmetic calculation
            elif answer_type == "arithmetic":
                # The different numbers in the passage that the model encounters
                sequence_labels = outputs["answer"]["numbers"]

                numbers_field: ListField = instance["number_indices"]  # type: ignore

                # The numbers in the passage are given signs, that's what we are labeling here.
                # Negative signs are given the class label 2 (for 0 and 1, the sign matches the
                # label).
                labels = []
                for label in sequence_labels:
                    if label["sign"] == -1:
                        labels.append(2)
                    else:
                        labels.append(label["sign"])
                # There's a dummy number added in the dataset reader to handle passages with no
                # numbers; it has a label of 0 (not included).
                labels.append(0)

                field = ListField([SequenceLabelField(labels, numbers_field)])
                new_instance.add_field("answer_as_add_sub_expressions", field, self._model.vocab)

            # When the answer is in the question
            elif answer_type == "question_span":
                span = outputs["answer"]["spans"][0]

                # Convert character span indices into word span indices
                word_span_start = None
                word_span_end = None
                question_offsets = new_instance["metadata"].metadata[  # type: ignore
                    "question_token_offsets"
                ]
                for index, offset in enumerate(question_offsets):
                    if offset[0] == span[0]:
                        word_span_start = index
                    if offset[1] == span[1]:
                        word_span_end = index

                question_field: SequenceField = new_instance["question"]  # type: ignore
                field = ListField([SpanField(word_span_start, word_span_end, question_field)])
                new_instance.add_field("answer_as_question_spans", field, self._model.vocab)

        return [new_instance]
Example #27
    def test_span_field_raises_if_span_end_is_greater_than_sentence_length(self):
        with pytest.raises(ValueError):
            _ = SpanField(1, 30, self.text)
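Examples #25 and #27 together pin down SpanField's input validation: endpoints must be integers, and the inclusive end index must fall inside the sequence field being indexed. A self-contained sketch of what triggers each error:

    from allennlp.data.fields import SpanField, TextField
    from allennlp.data.token_indexers import SingleIdTokenIndexer
    from allennlp.data.tokenizers import Token

    text = TextField([Token(t) for t in ["a", "tiny", "sentence"]],
                     {"tokens": SingleIdTokenIndexer()})

    SpanField(0, 2, text)  # fine: inclusive endpoints within the 3-token field

    # Each of the following raises, matching the two tests:
    # SpanField("hello", 3, text)  -> TypeError (span indices must be ints)
    # SpanField(1, 30, text)       -> ValueError (span end past sequence length)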
示例#28
0
    def text_to_instance(
        self,  # type: ignore
        question: str,
        tokenized_question: List[Token],
        context: str,
        tokenized_context: List[Token],
        answers: List[str] = None,
        token_answer_span: Optional[Tuple[int, int]] = None,
        additional_metadata: Dict[str, Any] = None,
        always_add_answer_span: bool = False,
    ) -> Instance:
        fields = {}

        # make the question field
        question_field = TextField(
            self._tokenizer.add_special_tokens(tokenized_question, tokenized_context),
            self._token_indexers,
        )
        fields["question_with_context"] = question_field

        cls_index = self._find_cls_index(question_field.tokens)
        if self._include_cls_index:
            fields["cls_index"] = IndexField(cls_index, question_field)

        start_of_context = (
            len(self._tokenizer.sequence_pair_start_tokens)
            + len(tokenized_question)
            + len(self._tokenizer.sequence_pair_mid_tokens)
        )

        # make the answer span
        if token_answer_span is not None:
            assert all(i >= 0 for i in token_answer_span)
            assert token_answer_span[0] <= token_answer_span[1]

            fields["answer_span"] = SpanField(
                token_answer_span[0] + start_of_context,
                token_answer_span[1] + start_of_context,
                question_field,
            )
        elif always_add_answer_span:
            fields["answer_span"] = SpanField(cls_index, cls_index, question_field)

        # make the context span, i.e., the span of text from which possible answers should be drawn
        fields["context_span"] = SpanField(
            start_of_context, start_of_context + len(tokenized_context) - 1, question_field
        )

        # make the metadata
        metadata = {
            "question": question,
            "question_tokens": tokenized_question,
            "context": context,
            "context_tokens": tokenized_context,
            "answers": answers or [],
        }
        if additional_metadata is not None:
            metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
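The start_of_context arithmetic is the subtle step in this reader: token_answer_span arrives relative to tokenized_context, so it must be shifted past the question and the tokenizer's special tokens before being attached to question_field. For a layout with one start token and a two-token mid separator (the hard-coded 1 + len(question) + 2 in the final example of this section), the shift works out as in this sketch with assumed lengths:

    # Hypothetical special-token counts:
    #   sequence_pair_start_tokens -> 1 token (a CLS-style token)
    #   sequence_pair_mid_tokens   -> 2 tokens (a double separator)
    num_start, num_mid = 1, 2
    len_question = 7  # length of tokenized_question

    start_of_context = num_start + len_question + num_mid  # 10
    # A context-relative answer span (2, 4) becomes, in the combined field:
    answer_span = (2 + start_of_context, 4 + start_of_context)  # (12, 14)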
Example #29
    def test_printing_doesnt_crash(self):
        span_field = SpanField(2, 3, self.text)
        print(span_field)
Example #30
    def text_to_instance(
            self,  # type: ignore
            tokens: List[str],
            pos_tags: List[str] = None,
            gold_tree: Tree = None) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            The tokens in a given sentence.
        pos_tags : ``List[str]``, optional (default = None).
            The POS tags for the words in the sentence.
        gold_tree : ``Tree``, optional (default = None).
            The gold parse tree to create span labels from.

        Returns
        -------
        An ``Instance`` containing the following fields:
            tokens : ``TextField``
                The tokens in the sentence.
            pos_tags : ``SequenceLabelField``
                The POS tags of the words in the sentence.
                Only returned if ``use_pos_tags`` is ``True``.
            spans : ``ListField[SpanField]``
                A ListField containing all possible subspans of the
                sentence.
            span_labels : ``SequenceLabelField``, optional.
                The constituency tags for each of the possible spans, with
                respect to a gold parse tree. If a span is not contained
                within the tree, a span will have a ``NO-LABEL`` label.
            gold_tree : ``MetadataField(Tree)``
                The gold NLTK parse tree for use in evaluation.
        """
        # pylint: disable=arguments-differ
        text_field = TextField([Token(x) for x in tokens],
                               token_indexers=self._token_indexers)
        fields: Dict[str, Field] = {"tokens": text_field}

        if self._use_pos_tags and pos_tags is not None:
            pos_tag_field = SequenceLabelField(pos_tags,
                                               text_field,
                                               label_namespace="pos")
            fields["pos_tags"] = pos_tag_field
        elif self._use_pos_tags:
            raise ConfigurationError(
                "use_pos_tags was set to True but no gold pos"
                " tags were passed to the dataset reader.")
        spans: List[Field] = []
        gold_labels = []

        if gold_tree is not None:
            gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
            self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
            gold_spans = {
                span: label
                for (span, label) in gold_spans_with_pos_tags.items()
                if "-POS" not in label
            }
        else:
            gold_spans = None
        for start, end in enumerate_spans(tokens):
            spans.append(SpanField(start, end, text_field))

            if gold_spans is not None:
                if (start, end) in gold_spans.keys():
                    gold_labels.append(gold_spans[(start, end)])
                else:
                    gold_labels.append("NO-LABEL")

        metadata = {"tokens": tokens}
        if gold_tree:
            metadata["gold_tree"] = gold_tree
        if self._use_pos_tags:
            metadata["pos_tags"] = pos_tags

        fields["metadata"] = MetadataField(metadata)

        span_list_field: ListField = ListField(spans)
        fields["spans"] = span_list_field
        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(gold_labels,
                                                       span_list_field)
        return Instance(fields)
Example #31
    def test_as_tensor_converts_span_field_correctly(self):
        span_field = SpanField(2, 3, self.text)
        tensor = span_field.as_tensor(
            span_field.get_padding_lengths()).data.cpu().numpy()
        numpy.testing.assert_array_equal(tensor, numpy.array([2, 3]))
Example #32
    def text_to_instance(
        self,  # type: ignore
        question: str,
        tokenized_question: List[Token],
        context: str,
        tokenized_context: List[Token],
        answers: List[str],
        token_answer_span: Optional[Tuple[int, int]],
        additional_metadata: Dict[str, Any] = None,
    ) -> Instance:
        fields = {}

        # make the question field
        cls_token = Token(
            self._tokenizer.tokenizer.cls_token,
            text_id=self._tokenizer.tokenizer.cls_token_id,
            type_id=self.non_content_type_id,
        )

        sep_token = Token(
            self._tokenizer.tokenizer.sep_token,
            text_id=self._tokenizer.tokenizer.sep_token_id,
            type_id=self.non_content_type_id,
        )

        question_field = TextField(
            ([cls_token] + tokenized_question + [sep_token, sep_token] +
             tokenized_context + [sep_token]),
            self._token_indexers,
        )
        fields["question_with_context"] = question_field
        start_of_context = 1 + len(tokenized_question) + 2

        # make the answer span
        if token_answer_span is not None:
            assert all(i >= 0 for i in token_answer_span)
            assert token_answer_span[0] <= token_answer_span[1]

            fields["answer_span"] = SpanField(
                token_answer_span[0] + start_of_context,
                token_answer_span[1] + start_of_context,
                question_field,
            )
        else:
            # We have to put in something even when we don't have an answer, so that this instance can be batched
            # together with other instances that have answers.
            fields["answer_span"] = SpanField(-1, -1, question_field)

        # make the context span, i.e., the span of text from which possible answers should be drawn
        fields["context_span"] = SpanField(
            start_of_context, start_of_context + len(tokenized_context) - 1,
            question_field)

        # make the metadata
        metadata = {
            "question": question,
            "question_tokens": tokenized_question,
            "context": context,
            "context_tokens": tokenized_context,
            "answers": answers,
        }
        if additional_metadata is not None:
            metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)