Exemplo n.º 1
0
 def text_to_instance(
         self,
         source_string: str,
         target_string: str = None) -> Instance:  # type: ignore
     # pylint: disable=arguments-differ
     """
     Build an ``Instance`` for a source sentence and, optionally, its target.

     The source is tokenized, prefixed with ``START_SYMBOL`` when
     ``self._source_add_start_token`` is set, and always suffixed with
     ``END_SYMBOL``. Every span of the resulting token list up to
     ``self._max_span_width`` tokens wide becomes a ``SpanField`` in
     ``"source_spans"``. When ``target_string`` is given it is tokenized,
     wrapped in ``START_SYMBOL`` / ``END_SYMBOL`` and added as ``"target_tokens"``.
     """
     source_tokens = self._source_tokenizer.tokenize(source_string)
     if self._source_add_start_token:
         source_tokens.insert(0, Token(START_SYMBOL))
     source_tokens.append(Token(END_SYMBOL))
     source_field = TextField(source_tokens, self._source_token_indexers)

     span_fields = [
         SpanField(span_start, span_end, source_field)
         for span_start, span_end in enumerate_spans(
             source_tokens, max_span_width=self._max_span_width)
     ]
     fields = {
         "source_spans": ListField(span_fields),
         "source_tokens": source_field,
     }

     if target_string is None:
         return Instance(fields)

     target_tokens = self._target_tokenizer.tokenize(target_string)
     target_tokens.insert(0, Token(START_SYMBOL))
     target_tokens.append(Token(END_SYMBOL))
     fields["target_tokens"] = TextField(target_tokens,
                                         self._target_token_indexers)
     return Instance(fields)
Exemplo n.º 2
0
    def _process_sentence(self, sent: Sentence, dataset: str):
        """
        Convert one ``Sentence`` into a dict of AllenNLP fields.

        ``"text"`` (a ``TextField`` over the normalized words) and ``"spans"``
        (a ``ListField`` of every ``SpanField`` up to ``self._max_span_width``
        wide) are always present. Label fields are added only for the
        annotations that are not ``None`` on ``sent``: NER, coreference,
        relations and events. NER, relation, trigger and argument labels use
        dataset-specific namespaces (``f"{dataset}__..."``); coreference labels
        are integers and are not indexed.
        """
        words = [self._normalize_word(word) for word in sent.text]
        text_field = TextField([Token(word) for word in words],
                               self._token_indexers)

        span_fields = [
            SpanField(first, last, text_field)
            for first, last in enumerate_spans(words,
                                               max_span_width=self._max_span_width)
        ]
        span_field = ListField(span_fields)
        span_tuples = [(field.span_start, field.span_end) for field in span_fields]

        # NOTE: Ideally `ner_labels` and `coref_labels` would be
        # `ListField[SequenceLabelField]` over the span `SpanField`s, but
        # `as_tensor_dict()` fails on that data type (an AllenNLP API issue), so
        # each label is a `LabelField` collected in a `ListField` instead.
        fields = {"text": text_field, "spans": span_field}

        if sent.ner is not None:
            ner_namespace = f"{dataset}__ner_labels"
            fields["ner_labels"] = ListField([
                LabelField(label, label_namespace=ner_namespace)
                for label in self._process_ner(span_tuples, sent)
            ])

        if sent.cluster_dict is not None:
            # Coreference labels are already ints, so vocabulary indexing is skipped.
            fields["coref_labels"] = ListField([
                LabelField(label, label_namespace="coref_labels", skip_indexing=True)
                for label in self._process_coref(span_tuples, sent)
            ])

        if sent.relations is not None:
            relation_labels, relation_indices = self._process_relations(
                span_tuples, sent)
            fields["relation_labels"] = AdjacencyField(
                indices=relation_indices,
                sequence_field=span_field,
                labels=relation_labels,
                label_namespace=f"{dataset}__relation_labels")

        if sent.events is not None:
            trigger_labels, argument_labels, argument_indices = self._process_events(
                span_tuples, sent)
            fields["trigger_labels"] = SequenceLabelField(
                trigger_labels, text_field,
                label_namespace=f"{dataset}__trigger_labels")
            fields["argument_labels"] = AdjacencyFieldAssym(
                indices=argument_indices,
                row_field=text_field,
                col_field=span_field,
                labels=argument_labels,
                label_namespace=f"{dataset}__argument_labels")

        return fields
Exemplo n.º 3
0
    def build_spans(self, fields, ner_spans, wordpieces, start_offsets, offsets, text_field, sentence_list):
        """
        Add per-sentence candidate wordpiece spans, with NER tags, to ``fields``.

        ``ner_spans`` are first mapped to wordpiece coordinates with
        ``_convert_span_to_wordpiece_span``. Then, sentence by sentence, every
        wordpiece span up to ``self._max_span_width`` wide is enumerated; a span
        gets the gold tag when it exactly matches a converted NER span and
        ``"O"`` otherwise.

        Mutates ``fields`` in place, setting ``"spans"`` (a ``ListField`` of
        ``SpanField``) and ``"labels"`` (a ``SequenceLabelField`` in the
        ``'labels'`` namespace). Returns ``None``.
        """
        tag_by_span = {
            (start, end): tag
            for start, end, tag in _convert_span_to_wordpiece_span(ner_spans, start_offsets, offsets)
        }

        span_fields = []
        span_tags = []
        # Word cursor into `offsets` and wordpiece cursor into `wordpieces`.
        # The wordpiece cursor starts at 1 — presumably to skip a leading
        # special token; TODO confirm against the tokenizer in use.
        word_cursor = 0
        piece_cursor = 1
        for sentence in sentence_list:
            # `offsets[i]` is used as the inclusive last wordpiece index of word i.
            last_piece = offsets[word_cursor + len(sentence) - 1]
            sentence_pieces = wordpieces[piece_cursor:last_piece + 1]
            for start, end in enumerate_spans(sentence_pieces,
                                              offset=piece_cursor,
                                              max_span_width=self._max_span_width):
                span_tags.append(tag_by_span.get((start, end), "O"))
                span_fields.append(SpanField(start, end, text_field))
            word_cursor += len(sentence)
            piece_cursor += len(sentence_pieces)

        span_field = ListField(span_fields)
        labels_field = SequenceLabelField(span_tags, span_field, 'labels')
        fields["spans"] = span_field
        fields["labels"] = labels_field
Exemplo n.º 4
0
    def get_mentions_with_gold(self,
                               text: str,
                               gold_spans,
                               gold_entities,
                               whitespace_tokenize=True,
                               keep_gold_only: bool = False):
        """
        Generate candidate mentions for ``text`` and align them with gold entities.

        Parameters
        ----------
        text : str
            Raw text to tokenize.
        gold_spans :
            Gold ``[start, end]`` token spans (inclusive end), parallel to
            ``gold_entities``.
        gold_entities :
            Gold entity for each span in ``gold_spans``.
        whitespace_tokenize : bool
            Tokenize with ``self.whitespace_tokenizer`` if True, else ``self.tokenizer``.
        keep_gold_only : bool
            If True, only the gold spans are considered. Otherwise every span of
            up to 5 tokens accepted by ``span_filter_func`` is added to them.

        Returns
        -------
        dict
            ``"tokenized_text"``, ``"candidate_spans"`` (``[start, end]`` lists in
            ascending ``(start, end)`` order), ``"candidate_entities"``,
            ``"candidate_entity_prior"`` (each span's priors divided by their
            sum) and ``"gold_entities"`` (``"@@NULL@@"`` for non-gold spans).
            Spans for which ``self.process`` returns no candidates are dropped.
            Assumes each candidate from ``self.process`` has the entity name at
            index 1 and its prior at index 2 — TODO confirm.
        """
        gold_spans_to_entities = {
            tuple(k): v
            for k, v in zip(gold_spans, gold_entities)
        }

        if whitespace_tokenize:
            tokens = self.whitespace_tokenizer(text)
        else:
            tokens = self.tokenizer(text)

        tokens = [t.text for t in tokens]
        if keep_gold_only:
            spans_with_gold = set(gold_spans_to_entities.keys())
        else:
            all_spans = enumerate_spans(tokens,
                                        max_span_width=5,
                                        filter_function=span_filter_func)
            spans_with_gold = set().union(all_spans,
                                          [tuple(span) for span in gold_spans])

        spans = []
        entities = []
        # Separate name so the `gold_entities` argument is not shadowed.
        span_gold_entities = []
        priors = []
        # The set removes duplicates; sorting yields document order instead of
        # the arbitrary iteration order of a set.
        for span in sorted(spans_with_gold):
            candidate_entities = self.process(tokens[span[0]:span[1] + 1])
            # Only keep spans which we have candidates for. A small number of
            # gold spans have no mention candidates, so they cannot be linked.
            if not candidate_entities:
                continue

            candidate_names = [x[1] for x in candidate_entities]
            candidate_priors = [x[2] for x in candidate_entities]
            sum_priors = sum(candidate_priors)
            priors.append([x / sum_priors for x in candidate_priors])

            spans.append(list(span))
            entities.append(candidate_names)
            span_gold_entities.append(
                gold_spans_to_entities.get(span, "@@NULL@@"))

        return {
            "tokenized_text": tokens,
            "candidate_spans": spans,
            "candidate_entities": entities,
            # TODO Change to priors
            "candidate_entity_prior": priors,
            "gold_entities": span_gold_entities
        }
Exemplo n.º 5
0
 def get_spans(self, text, traget_word):
     """
     Locate ``traget_word`` in ``text`` as a span at most 3 elements wide.

     Spans are visited in ``enumerate_spans`` order. The first one whose slice
     ``text[start:end + 1]`` equals ``traget_word`` is returned as
     ``[[start, end]]`` (inclusive end); ``[]`` is returned when none match.
     """
     for span_start, span_end in enumerate_spans(text, max_span_width=3):
         if text[span_start:span_end + 1] == traget_word:
             return [[span_start, span_end]]
     return []
Exemplo n.º 6
0
 def __init__(self, filename, repeat=1):
     """
     Load and preprocess the dataset stored in ``filename``.

     ``max_spans_num`` is the number of spans, at most 3 wide, that
     ``enumerate_spans`` yields for a sequence of ``max_sentence_length`` (64)
     items. The tokenizer is loaded from the local path
     ``'pretrained_models/Chinese-BERT-wwm/'``. The file is then read with
     ``self.read_file`` and processed with ``self.process_data``.
     """
     sentence_limit = 64
     self.max_sentence_length = sentence_limit
     self.max_spans_num = len(
         enumerate_spans(range(sentence_limit), max_span_width=3))
     self.repeat = repeat
     self.tokenizer = BertTokenizer.from_pretrained('pretrained_models/Chinese-BERT-wwm/')
     self.data_list = self.read_file(filename)
     self.len = len(self.data_list)
     self.process_data_list = self.process_data()
Exemplo n.º 7
0
    def get_mentions_raw_text(self, text: str, whitespace_tokenize=False):
        """
        Find candidate entity mentions in raw ``text``.

        The text is tokenized with ``self.whitespace_tokenizer`` when
        ``whitespace_tokenize`` is true, otherwise with ``self.tokenizer``. Every
        span of up to 5 tokens accepted by ``span_filter_func`` is looked up
        with ``self.process``, and spans without candidates are dropped.

        returns:
            {'tokenized_text': List[str],
             'candidate_spans': List[List[int]] list of (start, end) indices for candidates,
                    where span is tokenized_text[start:(end + 1)]
             'candidate_entities': List[List[str]] = for each entity,
                    the candidates to link to. value is synset id, e.g
                    able.a.02 or hot_dog.n.01
             'candidate_entity_priors': List[List[float]] priors per candidate,
                    divided by their sum for each span
        }
        When no span has candidates, the result is updated with
        ``get_empty_candidates()``.
        """
        tokenize = self.whitespace_tokenizer if whitespace_tokenize else self.tokenizer
        tokens = [token.text for token in tokenize(text)]

        candidates_by_span = {}
        for span in enumerate_spans(tokens,
                                    max_span_width=5,
                                    filter_function=span_filter_func):
            found = self.process(tokens[span[0]:span[1] + 1])
            # Only keep spans which we have candidates for.
            if found:
                candidates_by_span[(span[0], span[1])] = found

        spans, entities, priors = [], [], []
        for (first, last), found in candidates_by_span.items():
            spans.append([first, last])
            entities.append([entry[1] for entry in found])
            raw_priors = [entry[2] for entry in found]
            # Priors may not be normalized because lowercase and cased values
            # were merged, so renormalize them per span.
            total = sum(raw_priors)
            priors.append([p / total for p in raw_priors])

        result = {
            "tokenized_text": tokens,
            "candidate_spans": spans,
            "candidate_entities": entities,
            "candidate_entity_priors": priors
        }
        if not spans:
            result.update(get_empty_candidates())
        return result
Exemplo n.º 8
0
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Build an ``Instance`` from ``json_dict["tokens"]`` (a list of token strings).

        The instance contains ``"tokens"`` (a ``TextField``), ``"metadata"`` (the
        raw words under ``"words"``) and ``"spans"`` (every span up to 10 tokens
        wide, as a ``ListField`` of ``SpanField``).
        """
        tokens = [Token(t) for t in json_dict["tokens"]]
        # The attribute `_dataset_reader._token_indexers` is added by our own
        # DatasetReader! This does not hold for dataset readers in general.
        token_indexers = self._dataset_reader._token_indexers
        sequence = TextField(tokens, token_indexers=token_indexers)

        span_field = ListField([
            SpanField(start, end, sequence)
            for start, end in enumerate_spans(tokens, max_span_width=10)
        ])

        instance_fields: Dict[str, Field] = {
            "tokens": sequence,
            "metadata": MetadataField({"words": [x.text for x in tokens]}),
            "spans": span_field
        }
        return Instance(instance_fields)
Exemplo n.º 9
0
    def text_to_instance(self,
                         tokens: List[Token],
                         entities: List = None,
                         relations: List = None) -> Instance:
        """
        Build an ``Instance`` with multi-label NER annotations over candidate spans.

        Parameters
        ----------
        tokens : List[Token]
            The sentence tokens.
        entities : List, optional
            Objects exposing ``start``, ``end`` (inclusive token indices) and
            ``role`` (the label). ``None`` is treated as "no entities".
        relations : List, optional
            Stored unchanged in the metadata under ``"relations"``.

        Returns
        -------
        Instance
            With ``"tokens"``, ``"spans"`` (every span up to
            ``self._max_span_width`` wide), ``"ner_labels"`` (one
            ``MultiLabelField`` per span in ``self.label_namespace``, empty for
            spans without an entity) and ``"metadata"``.

        Raises
        ------
        ValueError
            If an entity span that ``self._too_long`` does not reject is not
            among the enumerated spans.
        """
        sequence = TextField(tokens, self._token_indexers)
        instance_fields: Dict[str, Field] = {"tokens": sequence}
        words = [x.text for x in tokens]
        spans = []
        for start, end in enumerate_spans(words,
                                          max_span_width=self._max_span_width):
            assert start >= 0
            assert end >= 0
            spans.append(SpanField(start, end, sequence))

        span_field = ListField(spans)
        span_tuples = [(span.span_start, span.span_end) for span in spans]
        instance_fields["spans"] = span_field

        # Index spans once (enumerate_spans yields each span once) instead of a
        # linear `list.index` scan per entity.
        span_to_index = {span: ix for ix, span in enumerate(span_tuples)}
        ner_labels = [[] for _ in span_tuples]

        # `entities` defaults to None; treat that as an empty annotation list.
        for entity in entities or []:
            span = (entity.start, entity.end)
            if self._too_long(span):
                continue
            ix = span_to_index.get(span)
            if ix is None:
                # Same error type and message that `list.index` used to raise.
                raise ValueError(f"{span} is not in list")
            ner_labels[ix].append(entity.role)

        instance_fields["ner_labels"] = ListField([
            MultiLabelField(entry, label_namespace=self.label_namespace)
            for entry in ner_labels
        ])

        metadata = {"words": words, "relations": relations}
        instance_fields["metadata"] = MetadataField(metadata)

        return Instance(instance_fields)
Exemplo n.º 10
0
def create_instance(sentence, vocab, token_indexers):
    """
    Create a batch tensor dict from a single input sentence.

    ``sentence`` is a sequence of words. It is wrapped in a ``TextField`` along
    with every candidate span up to ``COREF_MAX_SPAN_WIDTH`` wide. The result is
    indexed with ``vocab`` as a one-instance ``Batch`` and converted with
    ``as_tensor_dict`` using the batch's own padding lengths.
    """
    text = TextField([Token(word) for word in sentence],
                     token_indexers=token_indexers)
    span_field = ListField([
        SpanField(start, end, text)
        for start, end in enumerate_spans(sentence,
                                          offset=0,
                                          max_span_width=COREF_MAX_SPAN_WIDTH)
    ])

    batch = Batch([Instance({"tokens": text, "spans": span_field})])
    batch.index_instances(vocab)
    return batch.as_tensor_dict(batch.get_padding_lengths())
Exemplo n.º 11
0
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Build an ``Instance`` with tokens, candidate spans and span masks from JSON.

        If ``json_dict`` has a ``"text"`` key, each word is sliced out of it with
        the character offsets ``t["span"]["start"]`` / ``t["span"]["end"]`` of
        the entries in ``json_dict["tokens"]``. Otherwise ``json_dict["tokens"]``
        is used directly as the list of words.

        For every span up to ``self._dataset_reader._max_span_width`` wide, a
        ``SpanField`` and a mask from ``create_mask(start, end, len(words) + 1)``
        are produced. The masks are stored as integer ``ArrayField``s.
        """
        if "text" in json_dict:
            text = json_dict["text"]
            words = [
                text[t["span"]["start"]:t["span"]["end"]]
                for t in json_dict["tokens"]
            ]
        else:
            words = json_dict["tokens"]

        tokens = [Token(w) for w in words]
        # The attribute `_dataset_reader._token_indexers` is added by our own
        # DatasetReader! This does not hold for dataset readers in general.
        token_indexers = self._dataset_reader._token_indexers
        sequence = TextField(tokens, token_indexers=token_indexers)

        context_size = len(words) + 1
        spans = []
        span_masks = []
        for start, end in enumerate_spans(
                tokens, max_span_width=self._dataset_reader._max_span_width):
            spans.append(SpanField(start, end, sequence))
            span_masks.append(create_mask(start, end, context_size))

        span_field = ListField(spans)
        # `np.int` was only an alias for the builtin `int` and was removed in
        # NumPy 1.24; the builtin gives the same dtype and works on all versions.
        span_mask_field = ListField([
            ArrayField(np.array(mask, dtype=int), dtype=int)
            for mask in span_masks
        ])
        instance_fields: Dict[str, Field] = {
            "tokens": sequence,
            "metadata": MetadataField({"words": [x.text for x in tokens]}),
            "spans": span_field,
            "span_masks": span_mask_field
        }
        return Instance(instance_fields)
Exemplo n.º 12
0
def make_coref_instance(
    sentences: List[List[str]],
    token_indexers: Dict[str, TokenIndexer],
    max_span_width: int,
    gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
    wordpiece_modeling_tokenizer: Optional[PretrainedTransformerTokenizer] = None,
    max_sentences: Optional[int] = None,
) -> Instance:
    """
    Build a coreference `Instance` from tokenized sentences and optional gold clusters.

    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenised words and sentences in the document.
    token_indexers : `Dict[str, TokenIndexer]`
        This is used to index the words in the document.  See :class:`TokenIndexer`.
    max_span_width : `int`, required.
        The maximum width of candidate spans to consider. With wordpiece modeling the
        limit is applied to the word-level spans and again to the resulting wordpiece spans.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
        A list of all clusters in the document, represented as word spans with absolute indices
        in the entire document. Each cluster contains some number of spans, which can be nested
        and overlap. If there are exact matches between clusters, they will be resolved
        using `_canonicalize_clusters`.
    wordpiece_modeling_tokenizer: `PretrainedTransformerTokenizer`, optional (default = None)
        If not None, this dataset reader does subword tokenization using the supplied tokenizer
        and distributes the labels to the resulting wordpieces. All the modeling will be based on
        wordpieces. If this is `None` (default), the user is expected to use
        `PretrainedTransformerMismatchedIndexer` and `PretrainedTransformerMismatchedEmbedder`,
        and the modeling will be on the word-level.
    max_sentences: int, optional (default = None)
        The maximum number of sentences in each document to keep. By default keeps all sentences.
        When truncating, gold mentions that end beyond the kept tokens are dropped, as are
        clusters left empty.

    # Returns

    An `Instance` containing the following `Fields`:
        text : `TextField`
            The text of the full document.
        spans : `ListField[SpanField]`
            A ListField containing the spans represented as `SpanFields`
            with respect to the document text.
        metadata : `MetadataField`
            `original_text` (the flattened, normalized tokens, which are wordpieces when a
            wordpiece tokenizer is given) and, when gold clusters are given, `clusters`.
        span_labels : `SequenceLabelField`, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
                not belong to a cluster. As these labels have variable length (it depends on
                how many spans we are considering), we represent this as a `SequenceLabelField`
                with respect to the `spans` `ListField`.
    """
    if max_sentences is not None and len(sentences) > max_sentences:
        # Truncate the document, then drop gold mentions whose end index falls
        # beyond the kept tokens and any cluster left empty as a result.
        sentences = sentences[:max_sentences]
        total_length = sum(len(sentence) for sentence in sentences)

        if gold_clusters is not None:
            new_gold_clusters = []

            for cluster in gold_clusters:
                new_cluster = []
                for mention in cluster:
                    if mention[1] < total_length:
                        new_cluster.append(mention)
                if new_cluster:
                    new_gold_clusters.append(new_cluster)

            gold_clusters = new_gold_clusters

    flattened_sentences = [
        _normalize_word(word) for sentence in sentences for word in sentence
    ]

    if wordpiece_modeling_tokenizer is not None:
        # `offsets[i]` is used below as the (start, end) wordpiece indices of word i.
        flat_sentences_tokens, offsets = wordpiece_modeling_tokenizer.intra_word_tokenize(
            flattened_sentences)
        flattened_sentences = [t.text for t in flat_sentences_tokens]
    else:
        flat_sentences_tokens = [Token(word) for word in flattened_sentences]

    text_field = TextField(flat_sentences_tokens, token_indexers)

    # Maps each gold mention (start, end) to the index of its cluster.
    cluster_dict = {}
    if gold_clusters is not None:
        gold_clusters = _canonicalize_clusters(gold_clusters)

        if wordpiece_modeling_tokenizer is not None:
            # Convert word-level mentions to wordpiece spans in place.
            # NOTE(review): this assigns into the lists returned by
            # `_canonicalize_clusters`; confirm they never alias the caller's input.
            for cluster in gold_clusters:
                for mention_id, mention in enumerate(cluster):
                    start = offsets[mention[0]][0]
                    end = offsets[mention[1]][1]
                    cluster[mention_id] = (start, end)

        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    # Parallel to `spans`; only built when gold clusters are available.
    span_labels: Optional[
        List[int]] = [] if gold_clusters is not None else None

    # Spans are enumerated within each sentence, using document-level word
    # indices via `sentence_offset`.
    sentence_offset = 0
    for sentence in sentences:
        for start, end in enumerate_spans(sentence,
                                          offset=sentence_offset,
                                          max_span_width=max_span_width):
            if wordpiece_modeling_tokenizer is not None:
                start = offsets[start][0]
                end = offsets[end][1]

                # `enumerate_spans` uses word-level width limit; here we apply it to wordpieces
                # We have to do this check here because we use a span width embedding that has
                # only `max_span_width` entries, and since we are doing wordpiece
                # modeling, the span width embedding operates on wordpiece lengths. So a check
                # here is necessary or else we wouldn't know how many entries there would be.
                if end - start + 1 > max_span_width:
                    continue
                # We also don't generate spans that contain special tokens
                if start < wordpiece_modeling_tokenizer.num_added_start_tokens:
                    continue
                if (end >= len(flat_sentences_tokens) -
                        wordpiece_modeling_tokenizer.num_added_end_tokens):
                    continue

            if span_labels is not None:
                # -1 marks a span that is not a gold mention.
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)

            spans.append(SpanField(start, end, text_field))
        sentence_offset += len(sentence)

    span_field = ListField(spans)

    metadata: Dict[str, Any] = {"original_text": flattened_sentences}
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": span_field,
        "metadata": metadata_field,
    }
    if span_labels is not None:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field)

    return Instance(fields)
Exemplo n.º 13
0
    def text_to_instance(
        self,  # type: ignore
        sentences: List[List[str]],
        gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> Instance:
        """
        # Parameters

        sentences : `List[List[str]]`, required.
            A list of lists representing the tokenised words and sentences in the document.
        gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters. The argument is never modified: with a
            wordpiece tokenizer, a converted copy (wordpiece indices) is used instead.

        # Returns

        An `Instance` containing the following `Fields`:
            text : `TextField`
                The text of the full document.
            spans : `ListField[SpanField]`
                A ListField containing the spans represented as `SpanFields`
                with respect to the document text.
            metadata : `MetadataField`
                `original_text` (the flattened, normalized tokens, which are wordpieces
                when a wordpiece tokenizer is configured) and, when given, `clusters`.
            span_labels : `SequenceLabelField`, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                 not belong to a cluster. As these labels have variable length (it depends on
                 how many spans we are considering), we represent this as a `SequenceLabelField`
                 with respect to the `spans` `ListField`.
        """
        flattened_sentences = [
            self._normalize_word(word) for sentence in sentences
            for word in sentence
        ]

        if self._wordpiece_modeling_tokenizer is not None:
            flat_sentences_tokens, offsets = self._wordpiece_modeling_tokenizer.intra_word_tokenize(
                flattened_sentences)
            flattened_sentences = [t.text for t in flat_sentences_tokens]
        else:
            flat_sentences_tokens = [
                Token(word) for word in flattened_sentences
            ]

        text_field = TextField(flat_sentences_tokens, self._token_indexers)

        # Maps each gold mention (start, end) to the index of its cluster.
        cluster_dict = {}
        if gold_clusters is not None:
            if self._wordpiece_modeling_tokenizer is not None:
                # Map each word-level mention to a wordpiece span (first wordpiece
                # of the start word to last wordpiece of the end word). Build new
                # lists instead of assigning into the caller's clusters: in-place
                # assignment corrupted the input (a second call would convert it
                # again) and failed outright on tuple clusters.
                gold_clusters = [
                    [(offsets[mention[0]][0], offsets[mention[1]][1])
                     for mention in cluster]
                    for cluster in gold_clusters
                ]

            for cluster_id, cluster in enumerate(gold_clusters):
                for mention in cluster:
                    cluster_dict[tuple(mention)] = cluster_id

        spans: List[Field] = []
        span_labels: Optional[
            List[int]] = [] if gold_clusters is not None else None

        sentence_offset = 0
        for sentence in sentences:
            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                if self._wordpiece_modeling_tokenizer is not None:
                    start = offsets[start][0]
                    end = offsets[end][1]

                    # `enumerate_spans` uses word-level width limit; here we apply it to wordpieces
                    # We have to do this check here because we use a span width embedding that has
                    # only `self._max_span_width` entries, and since we are doing wordpiece
                    # modeling, the span width embedding operates on wordpiece lengths. So a check
                    # here is necessary or else we wouldn't know how many entries there would be.
                    if end - start + 1 > self._max_span_width:
                        continue
                    # We also don't generate spans that contain special tokens
                    if start < self._wordpiece_modeling_tokenizer.num_added_start_tokens:
                        continue
                    if (end >= len(flat_sentences_tokens) -
                            self._wordpiece_modeling_tokenizer.
                            num_added_end_tokens):
                        continue

                if span_labels is not None:
                    # -1 marks a span that is not a gold mention.
                    if (start, end) in cluster_dict:
                        span_labels.append(cluster_dict[(start, end)])
                    else:
                        span_labels.append(-1)

                spans.append(SpanField(start, end, text_field))
            sentence_offset += len(sentence)

        span_field = ListField(spans)

        metadata: Dict[str, Any] = {"original_text": flattened_sentences}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field,
        }
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)

        return Instance(fields)
Exemplo n.º 14
0
    def text_to_instance(self,  # type: ignore
                         sentences: List[List[str]],
                         gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
                         speaker_ids: Optional[List[int]] = None,
                         genre: Optional[int] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Convert a tokenized document (and optional gold clusters) into an ``Instance``.

        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            The document as a list of sentences, each a list of word strings.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            Gold coreference clusters as lists of (start, end) word spans over the
            whole document. Spans may nest and overlap, but never exactly match
            across clusters.
        speaker_ids : ``Optional[List[int]]``, optional (default = None)
            Stored unchanged in the metadata under ``"speaker_ids"``.
        genre : ``Optional[int]``, optional (default = None)
            Stored unchanged in the metadata under ``"genre"``.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The normalized words of the full document.
            spans : ``ListField[SpanField]``
                Every span, up to ``self._max_span_width`` wide, inside each sentence,
                indexed with respect to the document text.
            metadata : ``MetadataField``
                ``original_text`` plus whichever of ``clusters``, ``speaker_ids``
                and ``genre`` were given.
            span_labels : ``SequenceLabelField``, optional
                Present only with gold clusters: the cluster id of each span in
                ``spans``, or -1 when the span is not a gold mention.
        """
        words = [self._normalize_word(word)
                 for sentence in sentences
                 for word in sentence]

        metadata: Dict[str, Any] = {"original_text": words}
        for key, value in (("clusters", gold_clusters),
                           ("speaker_ids", speaker_ids),
                           ("genre", genre)):
            if value is not None:
                metadata[key] = value

        text_field = TextField([Token(word) for word in words], self._token_indexers)

        # Gold mention (start, end) -> cluster index.
        mention_to_cluster = {}
        if gold_clusters is not None:
            mention_to_cluster = {
                tuple(mention): cluster_id
                for cluster_id, cluster in enumerate(gold_clusters)
                for mention in cluster
            }

        span_fields: List[Field] = []
        span_labels: Optional[List[int]] = None if gold_clusters is None else []

        word_offset = 0
        for sentence in sentences:
            for start, end in enumerate_spans(sentence,
                                              offset=word_offset,
                                              max_span_width=self._max_span_width):
                if span_labels is not None:
                    span_labels.append(mention_to_cluster.get((start, end), -1))
                span_fields.append(SpanField(start, end, text_field))
            word_offset += len(sentence)

        span_field = ListField(span_fields)
        fields: Dict[str, Field] = {"text": text_field,
                                    "spans": span_field,
                                    "metadata": MetadataField(metadata)}
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)

        return Instance(fields)
Exemplo n.º 15
0
    def text_to_instance(
            self,  # type: ignore
            sentences: List[List[str]],
            document_id: str,
            sentence_id: int,
            gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
            user_threshold: Optional[float] = 0.0) -> Instance:
        # pylint: disable=arguments-differ
        """
        Convert a tokenised document into an ``Instance``, optionally simulating
        user-provided coreference labels.

        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenised words and sentences in the document.
        document_id : ``str``, required.
            A string representing the document ID.
        sentence_id : ``int``, required.
            An int representing the sentence ID. Together with ``document_id`` it forms the
            instance ID ``"<document_id>;<sentence_id>"``, which is stored in the metadata and
            used as the lookup key into ``self._saved_labels``.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.
        user_threshold : ``Optional[float]``, optional (default = 0.0)
            Approximate fraction of gold mentions to hold out as simulated user input,
            e.g. 0.5, 0.33, 0.25, 0.125. Only used when ``self._simulate_user_inputs`` is set;
            ``None`` or a non-positive value disables the simulation.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            metadata : ``MetadataField``
                The normalised text, the instance ID and, when given, the gold clusters
                and their count.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                not belong to a cluster (held-out mentions are also labelled -1). When saved
                labels exist for this instance ID they are used instead.
            user_labels : ``SequenceLabelField``, optional
                Like ``span_labels`` but built from *all* gold mentions (the simulated user
                input); present only when simulation is active and ``gold_clusters`` is given.
            must_link, cannot_link : ``ListField[PairField]``, optional
                Pairs of span indices loaded from ``self._saved_labels``; present only when
                the saved entry has a non-empty ``must_link`` list.
        """
        flattened_sentences = [
            self._normalize_word(word) for sentence in sentences
            for word in sentence
        ]

        metadata: Dict[str, Any] = {
            "original_text": flattened_sentences,
            "ID": document_id + ";" + str(sentence_id)
        }
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters
            metadata["num_gold_clusters"] = len(gold_clusters)

        text_field = TextField([Token(word) for word in flattened_sentences],
                               self._token_indexers)

        # ``user_threshold`` is Optional, so guard against ``None`` before comparing.
        simulate_users = bool(self._simulate_user_inputs
                              and user_threshold is not None
                              and user_threshold > 0)
        # Every ``user_threshold_mod``-th mention of a cluster is held out (0 = none held out).
        user_threshold_mod = int(1 / user_threshold) if simulate_users else 0
        cluster_dict: Dict[Tuple[int, int], int] = {}
        simulated_user_cluster_dict: Dict[Tuple[int, int], int] = {}

        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for i, mention in enumerate(cluster):
                    mention_key = tuple(mention)
                    # Use modulo to get a relatively even distribution of user labels across the
                    # length of the document (clusters are sorted), so simulated user mentions
                    # are spread evenly instead of bunched together.
                    if user_threshold_mod == 0 or i % user_threshold_mod != user_threshold_mod - 1:
                        cluster_dict[mention_key] = cluster_id
                    simulated_user_cluster_dict[mention_key] = cluster_id

        # simulated_user_cluster_dict holds ALL gold mentions (a superset of cluster_dict),
        # so user_labels covers every gold label.
        spans: List[Field] = []
        span_labels: Optional[List[int]] = None
        user_labels: Optional[List[int]] = None
        if gold_clusters is not None:
            span_labels = []
            if simulate_users:
                user_labels = []

        # Must-link / cannot-link constraints, as pairs of span indices. They start empty when
        # gold clusters are given (gold_clusters being None is used as the "not training"
        # indicator) and are only filled from saved labels below.
        must_link: Optional[List[List[int]]] = [] if gold_clusters is not None else None
        cannot_link: Optional[List[List[int]]] = [] if gold_clusters is not None else None

        # Reuse previously saved labels for this instance ID, if any.
        doc_info = None
        if self._saved_labels is not None and metadata['ID'] in self._saved_labels:
            doc_info = self._saved_labels[metadata['ID']]
            span_labels = doc_info['span_labels'].tolist()
            if 'must_link' in doc_info:
                must_link = doc_info['must_link'].squeeze(-1).tolist()
                cannot_link = doc_info['cannot_link'].squeeze(-1).tolist()

        sentence_offset = 0
        for sentence in sentences:
            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                if span_labels is not None:
                    if doc_info is None:
                        # Only compute labels if they were not loaded from saved labels.
                        span_labels.append(cluster_dict.get((start, end), -1))
                    # Check the list itself rather than the simulation flag: when labels are
                    # loaded without gold clusters, user_labels is None and must be skipped.
                    if user_labels is not None:
                        user_labels.append(
                            simulated_user_cluster_dict.get((start, end), -1))

                spans.append(SpanField(start, end, text_field))
            sentence_offset += len(sentence)

        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field
        }

        if must_link:
            fields["must_link"] = ListField([
                PairField(
                    IndexField(link[0], span_field),
                    IndexField(link[1], span_field),
                ) for link in must_link
            ])
            fields["cannot_link"] = ListField([
                PairField(
                    IndexField(link[0], span_field),
                    IndexField(link[1], span_field),
                ) for link in cannot_link
            ])

        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)
            if user_labels is not None:
                fields["user_labels"] = SequenceLabelField(
                    user_labels, span_field)

        # Sanity checks: fields rebuilt from saved labels must match the saved tensors.
        if doc_info is not None:
            assert (fields["span_labels"].as_tensor(
                fields["span_labels"].get_padding_lengths()) !=
                    doc_info['span_labels']).nonzero().size(0) == 0
            if 'must_link' in doc_info:
                assert 'must_link' in fields
                assert (fields["must_link"].as_tensor(
                    fields["must_link"].get_padding_lengths()) !=
                        doc_info['must_link']).nonzero().size(0) == 0
                assert (fields["cannot_link"].as_tensor(
                    fields["cannot_link"].get_padding_lengths()) !=
                        doc_info['cannot_link']).nonzero().size(0) == 0

        return Instance(fields)
Exemplo n.º 16
0
def make_coref_instance(
    sentences: List[List[str]],
    token_indexers: Dict[str, TokenIndexer],
    max_span_width: int,
    document_id: Optional[str] = None,
    words: List[str] = None,
    gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
    srl_frames: Optional[List[Tuple[int, List[Tuple[int, int, str]]]]] = None,
    include_srl: bool = False,
    named_entities: Optional[List[str]] = None,
    named_entity_spans: Optional[List[Tuple[int, int, str]]] = None,
    include_ner: bool = False,
    include_coref: bool = True,
    wordpiece_modeling_tokenizer: PretrainedTransformerTokenizer = None,
    max_sentences: int = None,
    remove_singleton_clusters: bool = True,
    span_label_map: Dict[Tuple[int,int], str] = None,
    language: str = None,
    sentence_objects: List[OntonotesSentence] = None,
    parallel_sentences: List[List[str]] = None,
) -> Optional[Instance]:

    """
    Build a coreference `Instance` (optionally with SRL and NER fields) for one document.

    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenised words and sentences in the document.
    token_indexers : `Dict[str, TokenIndexer]`
        This is used to index the words in the document.  See :class:`TokenIndexer`.
    max_span_width : `int`, required.
        The maximum width of candidate spans to consider.
    document_id : `str`, optional (default = `None`)
        Stored in the metadata under `"document_id"`.
    words : `List[str]`, optional (default = `None`)
        If given, these (normalised) words replace the flattened `sentences` as the document
        text. Sentence boundaries are still taken from `sentences`.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = `None`)
        A list of all clusters in the document, represented as word spans with absolute indices
        in the entire document. Each cluster contains some number of spans, which can be nested
        and overlap. If there are exact matches between clusters, they will be resolved
        using `_canonicalize_clusters`.
    srl_frames : `List[Tuple[int, List[Tuple[int, int, str]]]]`, optional (default = `None`)
        `(predicate_index, [(start, end, arg_type), ...])` frames. Only used when `include_srl`.
    include_srl : `bool`, optional (default = `False`)
        Whether to add the SRL fields (`srl_labels`, `word_span_mask` and, when
        `named_entity_spans` is `None`, the `srl_seq_*` fields).
    named_entities : `List[str]`, optional (default = `None`)
        One BIO-style tag per word. Only used when `include_ner`.
    named_entity_spans : `List[Tuple[int, int, str]]`, optional (default = `None`)
        `(start, end, label)` entity spans. With `include_ner` and `named_entities`, they
        produce a `ner_span_labels` field (unlisted spans get the label `"None"`).
    include_ner : `bool`, optional (default = `False`)
        Whether to add the NER fields.
    include_coref : `bool`, optional (default = `True`)
        Whether to add the `span_labels` field.
    wordpiece_modeling_tokenizer: `PretrainedTransformerTokenizer`, optional (default = `None`)
        If not None, this dataset reader does subword tokenization using the supplied tokenizer
        and distribute the labels to the resulting wordpieces. All the modeling will be based on
        wordpieces. If this is `None` (default), the user is expected to use
        `PretrainedTransformerMismatchedIndexer` and `PretrainedTransformerMismatchedEmbedder`,
        and the modeling will be on the word-level.
    max_sentences: `int`, optional (default = `None`)
        The maximum number of sentences in each document to keep. By default keeps all sentences.
    remove_singleton_clusters : `bool`, optional (default = `True`)
        Some datasets contain clusters that are singletons (i.e. no coreferents). This option allows
        the removal of them.
    span_label_map : `Dict[Tuple[int, int], str]`, optional (default = `None`)
        If given, each candidate span is labelled with `span_label_map[span]` (or `"O"` when
        absent) instead of its cluster id.
    language : `str`, optional (default = `None`)
        Stored in the metadata. For `"arabic"`, each word is cut at its first `"#"` and
        passed through `clean_arabic_text`.
    sentence_objects : `List[OntonotesSentence]`, optional (default = `None`)
        Stored in the metadata under `"sentence_objects"`.
    parallel_sentences : `List[List[str]]`, optional (default = `None`)
        Tokenised parallel text, added as the `parallel_text` field.

    # Returns

    `None` if no candidate span survives filtering. Otherwise an `Instance` containing:

        text : `TextField`
            The text of the full document.
        spans : `ListField[SpanField]`
            A ListField containing the spans represented as `SpanFields`
            with respect to the document text.
        metadata : `MetadataField`
            Text, sentence bookkeeping, span index maps and, when available, clusters,
            language, SRL/NER information and the document id.
        span_labels : `SequenceLabelField`, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster (or the `span_label_map` label). As these labels have
            variable length (it depends on how many spans we are considering), we represent
            this as a `SequenceLabelField` with respect to the spans `ListField`.
        parallel_text : `TextField`, optional
        srl_labels, word_span_mask, srl_seq_labels, srl_seq_indices, srl_seq_predicates : optional
        ner_seq_labels, ner_span_labels : optional
    """
    import logging  # local import: only used for SRL diagnostic messages

    logger = logging.getLogger(__name__)

    if max_sentences is not None and len(sentences) > max_sentences:
        sentences = sentences[:max_sentences]
        total_length = sum(len(sentence) for sentence in sentences)

        if gold_clusters is not None:
            # Drop mentions ending beyond the kept sentences, then drop emptied clusters.
            new_gold_clusters = []
            for cluster in gold_clusters:
                new_cluster = [mention for mention in cluster if mention[1] < total_length]
                if new_cluster:
                    new_gold_clusters.append(new_cluster)
            gold_clusters = new_gold_clusters

    flattened_sentences = [_normalize_word(word) for sentence in sentences for word in sentence]
    sentences = [[_normalize_word(word) for word in sentence] for sentence in sentences]
    if parallel_sentences is not None:
        parallel_sentences = [[_normalize_word(word) for word in sentence] for sentence in parallel_sentences]
        flattened_parallel_sentences = [word for sentence in parallel_sentences for word in sentence]
    if words is not None:
        flattened_sentences = [_normalize_word(word) for word in words]
    if language == "arabic":
        # Keep only the part of each word before the first "#", then clean it.
        flattened_sentences = [clean_arabic_text(word.split("#")[0]) for word in flattened_sentences]
        sentences = [[clean_arabic_text(word.split("#")[0]) for word in sentence] for sentence in sentences]
        if parallel_sentences is not None:
            parallel_sentences = [
                [clean_arabic_text(word.split("#")[0]) for word in sentence]
                for sentence in parallel_sentences
            ]
            flattened_parallel_sentences = [word for sentence in parallel_sentences for word in sentence]

    if wordpiece_modeling_tokenizer is not None:
        flat_sentences_tokens, offsets = wordpiece_modeling_tokenizer.intra_word_tokenize(
            flattened_sentences
        )
        flattened_sentences = [t.text for t in flat_sentences_tokens]
        if parallel_sentences is not None:
            # Tokenise the parallel text itself and discard its offsets, so the main-text
            # `offsets` used below to map word spans onto wordpieces are not overwritten.
            flat_parallel_sentences_tokens, _ = wordpiece_modeling_tokenizer.intra_word_tokenize(
                flattened_parallel_sentences
            )
    else:
        flat_sentences_tokens = [Token(word) for word in flattened_sentences]
        if parallel_sentences is not None:
            flat_parallel_sentences_tokens = [Token(word) for word in flattened_parallel_sentences]

    text_field = TextField(flat_sentences_tokens, token_indexers)
    num_tokens = len(flat_sentences_tokens)

    cluster_dict: Dict[Tuple[int, int], int] = {}
    if gold_clusters is not None:
        gold_clusters = _canonicalize_clusters(gold_clusters)
        if remove_singleton_clusters:
            gold_clusters = [cluster for cluster in gold_clusters if len(cluster) > 1]

        if wordpiece_modeling_tokenizer is not None:
            # Re-express word-level mentions as inclusive wordpiece spans (in place).
            for cluster in gold_clusters:
                for mention_id, mention in enumerate(cluster):
                    start = offsets[mention[0]][0]
                    end = offsets[mention[1]][1]
                    cluster[mention_id] = (start, end)

        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    span_index_map: Dict[Tuple[int, int], int] = {}
    token_same_sentence_spans: Dict[int, List[int]] = {}
    token_sentence_start_end_map: Dict[int, Tuple[int, int]] = {}
    sentence_index_span_map: Dict[int, List[Tuple[int, int]]] = {}
    # Span labels are needed for coreference clusters, and also when only an explicit
    # span label map is given.
    span_labels: Optional[List[Union[int, str]]] = (
        [] if gold_clusters is not None or span_label_map is not None else None
    )

    sentence_offset = 0
    sentence_offsets = []
    for sent_index, sentence in enumerate(sentences):
        sentence_spans: List[int] = []
        sentence_index_span_map[sent_index] = []
        for start, end in enumerate_spans(
            sentence, offset=sentence_offset, max_span_width=max_span_width
        ):
            if wordpiece_modeling_tokenizer is not None:
                start = offsets[start][0]
                end = offsets[end][1]

                # `enumerate_spans` uses word-level width limit; here we apply it to wordpieces
                # We have to do this check here because we use a span width embedding that has
                # only `max_span_width` entries, and since we are doing wordpiece
                # modeling, the span width embedding operates on wordpiece lengths. So a check
                # here is necessary or else we wouldn't know how many entries there would be.
                if end - start + 1 > max_span_width:
                    continue
                # We also don't generate spans that contain special tokens
                if start < len(wordpiece_modeling_tokenizer.single_sequence_start_tokens):
                    continue
                if end >= num_tokens - len(
                    wordpiece_modeling_tokenizer.single_sequence_end_tokens
                ):
                    continue

            # Skip spans past the end of the text (possible when `words` is shorter than
            # `sentences`) *before* labelling, so `span_labels` stays aligned with `spans`.
            if end >= num_tokens:
                continue

            if span_labels is not None:
                if span_label_map is not None:
                    span_labels.append(span_label_map.get((start, end), "O"))
                else:
                    span_labels.append(cluster_dict.get((start, end), -1))

            span_index_map[(start, end)] = len(spans)
            sentence_spans.append(len(spans))
            spans.append(SpanField(start, end, text_field))
            sentence_index_span_map[sent_index].append((start, end))

        sentence_end = sentence_offset + len(sentence) - 1
        for token_index in range(sentence_offset, sentence_end + 1):
            token_same_sentence_spans[token_index] = sentence_spans
            token_sentence_start_end_map[token_index] = (sentence_offset, sentence_end)
        sentence_offsets.append(sentence_offset)
        sentence_offset += len(sentence)

    if not spans:
        return None
    span_field = ListField(spans)

    metadata: Dict[str, Any] = {
        "original_text": flattened_sentences,
        "sentence_offsets": sentence_offsets,
        "sentences": sentences,
        "sentence_index_span_map": sentence_index_span_map,
        "span_index_map": span_index_map,
    }
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters
    if language is not None:
        metadata["language"] = language
    if sentence_objects is not None:
        metadata["sentence_objects"] = sentence_objects
    # NOTE: keys are added to `metadata` after this point (SRL/NER info, document_id);
    # this relies on MetadataField holding a reference to the dict rather than a copy.
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": span_field,
        "metadata": metadata_field,
    }
    if span_labels is not None and include_coref:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field, label_namespace="span_labels")
    if parallel_sentences is not None:
        fields["parallel_text"] = TextField(flat_parallel_sentences_tokens, token_indexers)

    if include_srl and srl_frames is not None:
        predicate_span_pairs = []
        pair_labels = []
        filtered_srl_frames = []
        for predicate_index, arguments in srl_frames:
            filtered_arguments = []
            covered_spans = set()
            for (start, end, arg_type) in arguments:
                # Keep each candidate argument span once; skip the predicate's own span.
                if (start, end) in span_index_map and (start, end) not in covered_spans:
                    if start == predicate_index == end:
                        continue
                    predicate_span_pairs.append((predicate_index, span_index_map[(start, end)]))
                    pair_labels.append(arg_type)
                    filtered_arguments.append((start, end, arg_type))
                    covered_spans.add((start, end))
            arguments_without_predicate = [arg for arg in arguments if arg[-1] != "V"]
            if arguments_without_predicate:
                filtered_srl_frames.append((predicate_index, arguments_without_predicate))
            if len({arg[:2] for arg in arguments}) < len(arguments):
                logger.debug(
                    "SRL frame for predicate %s has duplicate argument spans: %s (text: %s)",
                    predicate_index, arguments, flattened_sentences,
                )
            if len(filtered_srl_frames) < len(srl_frames) and predicate_index == srl_frames[-1][0]:
                logger.debug(
                    "Some SRL frames have no non-predicate arguments; last predicate %s: "
                    "arguments=%s kept=%s; frames=%s kept frames=%s (text: %s)",
                    predicate_index, arguments, filtered_arguments, srl_frames,
                    filtered_srl_frames, flattened_sentences,
                )
        fields["srl_labels"] = AsymmetricAdjacencyField(predicate_span_pairs, text_field, span_field, labels=pair_labels, label_namespace="srl_labels")

        srl_seq_label_fields = []
        srl_seq_labels = []
        srl_seq_indices = []
        srl_seq_words = []
        predicate_indices = []
        max_seq_length = 0
        for predicate_index, arguments in srl_frames:
            if predicate_index >= num_tokens:
                continue
            sentence_start, sentence_end = token_sentence_start_end_map[predicate_index]
            # BIO tags over the predicate's sentence; arguments overlapping already-tagged
            # tokens are skipped.
            seq_labels = ["O"] * (sentence_end + 1 - sentence_start)
            seq_labels[predicate_index - sentence_start] = "B-V"
            for (start, end, arg_type) in arguments:
                if any([seq_labels[idx - sentence_start] != "O" for idx in range(start, end + 1)]):
                    continue
                seq_labels[start - sentence_start] = "B-" + arg_type
                for i in range(start + 1, end + 1):
                    seq_labels[i - sentence_start] = "I-" + arg_type
            srl_seq_indices.append(list(range(sentence_start, sentence_end + 1)))
            sentence_field = TextField(flat_sentences_tokens[sentence_start:sentence_end + 1], token_indexers)
            srl_seq_label_fields.append(
                SequenceLabelField(seq_labels, sentence_field, label_namespace="srl_seq_labels")
            )
            predicate_indices.append(predicate_index)
            srl_seq_labels.append(seq_labels)
            srl_seq_words.append(flattened_sentences[sentence_start:sentence_end + 1])
            max_seq_length = max(max_seq_length, sentence_end + 1 - sentence_start)
        if srl_seq_label_fields and named_entity_spans is None:
            fields["srl_seq_labels"] = ListField(srl_seq_label_fields)
            # Right-pad index rows with -1 so they form a rectangular array.
            srl_seq_indices = [seq + [-1] * (max_seq_length - len(seq)) for seq in srl_seq_indices]
            fields["srl_seq_indices"] = ArrayField(np.array(srl_seq_indices, dtype=np.int64), dtype=np.int64, padding_value=-1)
            fields["srl_seq_predicates"] = ArrayField(np.array(predicate_indices, dtype=np.int64), dtype=np.int64, padding_value=-1)
            metadata["srl_seq_labels"] = srl_seq_labels
            metadata["srl_seq_words"] = srl_seq_words
        metadata["srl_frames"] = filtered_srl_frames
        # (token, span) pairs where the span lies in the same sentence as the token.
        word_span_coincidence = [
            (token, span_index)
            for token in range(num_tokens)
            for span_index in token_same_sentence_spans[token]
        ]
        fields["word_span_mask"] = AsymmetricAdjacencyField(word_span_coincidence, text_field, span_field, padding_value=0)

    if include_ner and named_entities is not None:
        remap = {"B-OTHER": "O", "I-OTHER": "O", "B-NUMBER": "B-QUANTITY", "I-NUMBER": "I-QUANTITY"}
        named_entities = [remap.get(ent, ent) for ent in named_entities]
        if wordpiece_modeling_tokenizer is not None:
            # Project word-level tags onto wordpieces: the first piece keeps the tag and the
            # remaining pieces of the word get the matching "I-" tag.
            converted_named_entities = ["O"] * num_tokens
            for index, ne in enumerate(named_entities):
                if ne != "O":
                    converted_named_entities[offsets[index][0]] = ne
                    for i in range(offsets[index][0] + 1, offsets[index][1] + 1):
                        converted_named_entities[i] = "I-" + ne[2:]
            named_entities = converted_named_entities
        fields["ner_seq_labels"] = SequenceLabelField(named_entities[:num_tokens], text_field, label_namespace="ner_seq_labels")
        metadata["ner_seq_labels"] = named_entities[:num_tokens]
        if named_entity_spans is not None:
            ner_span_label_map = {(start, end): label for (start, end, label) in named_entity_spans}
            ner_span_labels = [None] * len(span_index_map)
            for span, span_index in span_index_map.items():
                ner_span_labels[span_index] = ner_span_label_map.get(span, "None")
            fields["ner_span_labels"] = SequenceLabelField(ner_span_labels, span_field, label_namespace="ner_span_labels")
    metadata["document_id"] = document_id

    return Instance(fields)
Exemplo n.º 17
0
    def text_to_instance(
        self,  # type: ignore
        sentence: List[Token],
        gold_clusters: Optional[List[List[Tuple[int, int]]]] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Build a coreference ``Instance`` for a single, already tokenised sentence.

        Parameters
        ----------
        sentence : ``List[Token]``, required.
            The already tokenised sentence to analyse.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            Clusters of mentions in the sentence, each mention an inclusive ``(start, end)``
            word span. Spans may nest or overlap; they are expected not to repeat exactly
            across clusters (if one does, the later cluster's id is used).

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The tokens of the sentence.
            spans : ``ListField[SpanField]``
                Every candidate span up to ``self._max_span_width`` wide, as ``SpanFields``
                over the sentence text.
            metadata : ``MetadataField``
                The original tokens and, when given, the gold clusters.
            span_labels : ``SequenceLabelField``, optional
                Present only when ``gold_clusters`` is given: for each candidate span, the
                index of the cluster containing it, or -1 if it is in no cluster.
        """
        metadata: Dict[str, Any] = {"original_text": sentence}
        text_field = TextField(sentence, self._token_indexers)

        # Look-up table from a gold mention span to the index of its cluster.
        mention_to_cluster: Dict[Tuple[int, int], int] = {}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters
            mention_to_cluster = {
                tuple(mention): cluster_index
                for cluster_index, cluster in enumerate(gold_clusters)
                for mention in cluster
            }

        candidate_spans = list(
            enumerate_spans(sentence, max_span_width=self._max_span_width))
        span_field = ListField(
            [SpanField(first, last, text_field) for first, last in candidate_spans])

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": MetadataField(metadata),
        }
        if gold_clusters is not None:
            labels = [mention_to_cluster.get(tuple(span), -1) for span in candidate_spans]
            fields["span_labels"] = SequenceLabelField(labels, span_field)

        return Instance(fields)
Exemplo n.º 18
0
    def text_to_instance(self, sentence: List[str],
                         ner_dict: Dict[Tuple[int, int],
                                        str], relation_dict, cluster_dict,
                         trigger_dict, argument_dict, doc_key: str,
                         dataset: str, sentence_num: int, groups: List[str],
                         start_ix: int, end_ix: int, tree: Dict[str, Any],
                         syntax_dict: Dict[Tuple[int, int], str],
                         children_dict: Dict[Tuple[int, int],
                                             List[Tuple[int, int]]],
                         dep_children_dict: Dict[Tuple[int, int],
                                                 List[Tuple[int, int]]],
                         tf_dict: Dict[Tuple[int, int], Any]):
        """
        TODO(dwadden) document me.
        """

        sentence = [self._normalize_word(word) for word in sentence]

        text_field = TextField([Token(word) for word in sentence],
                               self._token_indexers)
        text_field_with_context = TextField([Token(word) for word in groups],
                                            self._token_indexers)

        # feili, NER labels. One label per token
        ner_sequence_labels = self._generate_ner_label(sentence, ner_dict)
        ner_sequence_label_field = SequenceLabelField(
            ner_sequence_labels,
            text_field,
            label_namespace="ner_sequence_labels")

        # Put together the metadata.
        metadata = dict(sentence=sentence,
                        ner_dict=ner_dict,
                        relation_dict=relation_dict,
                        cluster_dict=cluster_dict,
                        trigger_dict=trigger_dict,
                        argument_dict=argument_dict,
                        doc_key=doc_key,
                        dataset=dataset,
                        groups=groups,
                        start_ix=start_ix,
                        end_ix=end_ix,
                        sentence_num=sentence_num,
                        seq_dict=ner_sequence_labels,
                        tree=tree,
                        syntax_dict=syntax_dict,
                        children_dict=children_dict,
                        dep_children_dict=dep_children_dict)
        metadata_field = MetadataField(metadata)

        # Trigger labels. One label per token in the input.
        token_trigger_labels = []
        for i in range(len(text_field)):
            token_trigger_labels.append(trigger_dict[i])

        trigger_label_field = SequenceLabelField(
            token_trigger_labels, text_field, label_namespace="trigger_labels")

        # Generate fields for text spans, ner labels, coref labels.
        spans = []
        span_ner_labels = []
        # feili
        span_labels = []
        span_coref_labels = []
        span_syntax_labels = []
        span_children_labels = []
        dep_span_children_labels = []
        # span_children_syntax_labels = []
        span_tree_labels = []
        raw_spans = []
        assert len(syntax_dict) == len(children_dict)
        for start, end in enumerate_spans(sentence,
                                          max_span_width=self._max_span_width):
            span_ix = (start, end)
            # here we need to consider how to use tree info
            # for example, use_tree, span is in tree, match is true or false
            # if self._tree_span_filter and not self._is_span_in_tree(span_ix, syntax_dict, children_dict):
            #     if len(raw_spans) == 0: # in case that there is no span for this instance
            #         pass
            #     else:
            #         continue
            span_tree_labels.append('1' if self._is_span_in_tree(
                span_ix, syntax_dict, children_dict) else '')

            span_ner_labels.append(ner_dict[span_ix])
            span_labels.append('' if ner_dict[span_ix] == '' else '1')
            span_coref_labels.append(cluster_dict[span_ix])
            spans.append(SpanField(start, end, text_field))
            span_syntax_labels.append(syntax_dict[span_ix])
            raw_spans.append(span_ix)

            # if len(children_dict[span_ix]) == 0:
            #     children_field = ListField([SpanField(-1, -1, text_field)])
            #     children_syntax_field = SequenceLabelField([''], children_field,
            #                                            label_namespace="span_syntax_labels")
            # else:
            #     children_field = ListField([SpanField(children_span[0], children_span[1], text_field)
            #                for children_span in children_dict[span_ix]])
            #     children_syntax_field = SequenceLabelField([syntax_dict[children_span] for children_span in children_dict[span_ix]],
            #                                                children_field, label_namespace="span_syntax_labels")
            # span_children_labels.append(children_field)
            # span_children_syntax_labels.append(children_syntax_field)

        span_field = ListField(spans)

        for span in raw_spans:

            if len(children_dict[span]) == 0:
                children_field = ListField([IndexField(-1, span_field)])
            else:
                children_field = []
                for children_span in children_dict[span]:
                    if children_span in raw_spans:
                        children_field.append(
                            IndexField(raw_spans.index(children_span),
                                       span_field))
                    else:
                        children_field.append(IndexField(-1, span_field))
                children_field = ListField(children_field)

            span_children_labels.append(children_field)

        # for span in raw_spans:
        #     if len(dep_children_dict[span]) == 0:
        #         children_field = ListField([IndexField(-1, span_field)])
        #     else:
        #         children_field = []
        #         for children_span in dep_children_dict[span]:
        #             if children_span in raw_spans:
        #                 children_field.append(IndexField(raw_spans.index(children_span), span_field))
        #             else:
        #                 children_field.append(IndexField(-1, span_field))
        #         children_field = ListField(children_field)
        #     dep_span_children_labels.append(children_field)

        n_tokens = len(sentence)
        candidate_indices = [(i, j) for i in range(n_tokens)
                             for j in range(n_tokens)]
        dep_adjs = []
        dep_adjs_indices = []
        # tf_indices = {}
        # tf_features = {}
        # for k, v in tf_dict.items():
        #     tf_indices[k] = []
        #     tf_features[k] = []
        tf_indices = []
        tf_features = []
        for token_pair in candidate_indices:
            dep_adj_label = dep_children_dict[token_pair]
            if dep_adj_label:
                dep_adjs_indices.append(token_pair)
                dep_adjs.append(dep_adj_label)

            # for k,v in tf_dict.items():
            #     feature = tf_dict[k][token_pair]
            #     if feature:
            #         tf_indices[k].append(token_pair)
            #         tf_features[k].append(feature)

            feature = tf_dict[token_pair]
            if feature:
                tf_indices.append(token_pair)
                tf_features.append(feature)

        ner_label_field = SequenceLabelField(span_ner_labels,
                                             span_field,
                                             label_namespace="ner_labels")
        coref_label_field = SequenceLabelField(span_coref_labels,
                                               span_field,
                                               label_namespace="coref_labels")
        # feili
        span_label_field = SequenceLabelField(span_labels,
                                              span_field,
                                              label_namespace="span_labels")

        # Generate labels for relations and arguments. Only store non-null values.
        # For the arguments, by convention the first span specifies the trigger, and the second
        # specifies the argument. Ideally we'd have an adjacency field between (token, span) pairs
        # for the event arguments field, but AllenNLP doesn't make it possible to express
        # adjacencies between two different sequences.
        n_spans = len(spans)
        span_tuples = [(span.span_start, span.span_end) for span in spans]
        candidate_indices = [(i, j) for i in range(n_spans)
                             for j in range(n_spans)]

        relations = []
        relation_indices = []
        for i, j in candidate_indices:
            span_pair = (span_tuples[i], span_tuples[j])
            relation_label = relation_dict[span_pair]
            if relation_label:
                relation_indices.append((i, j))
                relations.append(relation_label)

        relation_label_field = AdjacencyField(
            indices=relation_indices,
            sequence_field=span_field,
            labels=relations,
            label_namespace="relation_labels")

        arguments = []
        argument_indices = []
        n_tokens = len(sentence)
        candidate_indices = [(i, j) for i in range(n_tokens)
                             for j in range(n_spans)]
        for i, j in candidate_indices:
            token_span_pair = (i, span_tuples[j])
            argument_label = argument_dict[token_span_pair]
            if argument_label:
                argument_indices.append((i, j))
                arguments.append(argument_label)

        argument_label_field = AdjacencyFieldAssym(
            indices=argument_indices,
            row_field=text_field,
            col_field=span_field,
            labels=arguments,
            label_namespace="argument_labels")

        # Syntax
        span_syntax_field = SequenceLabelField(
            span_syntax_labels,
            span_field,
            label_namespace="span_syntax_labels")
        span_children_field = ListField(span_children_labels)
        span_tree_field = SequenceLabelField(
            span_tree_labels, span_field, label_namespace="span_tree_labels")
        # span_children_syntax_field = ListField(span_children_syntax_labels)
        # dep_span_children_field = ListField(dep_span_children_labels)
        dep_span_children_field = AdjacencyField(
            indices=dep_adjs_indices,
            sequence_field=text_field,
            labels=dep_adjs,
            label_namespace="dep_adj_labels")

        # tf_f1_field = AdjacencyField(indices=tf_indices['F1'], sequence_field=text_field, labels=tf_features['F1'],
        #     label_namespace="tf_f1_labels")
        # tf_f2_field = AdjacencyField(indices=tf_indices['F2'], sequence_field=text_field, labels=tf_features['F2'],
        #                              label_namespace="tf_f2_labels")
        # tf_f3_field = AdjacencyField(indices=tf_indices['F3'], sequence_field=text_field, labels=tf_features['F3'],
        #                              label_namespace="tf_f3_labels")
        # tf_f4_field = AdjacencyField(indices=tf_indices['F4'], sequence_field=text_field, labels=tf_features['F4'],
        #                              label_namespace="tf_f4_labels")
        # tf_f5_field = AdjacencyField(indices=tf_indices['F5'], sequence_field=text_field, labels=tf_features['F5'],
        #                              label_namespace="tf_f5_labels")

        tf_field = AdjacencyField(indices=tf_indices,
                                  sequence_field=text_field,
                                  labels=tf_features,
                                  label_namespace="tf_labels")

        # Pull it  all together.
        fields = dict(
            text=text_field_with_context,
            spans=span_field,
            ner_labels=ner_label_field,
            coref_labels=coref_label_field,
            trigger_labels=trigger_label_field,
            argument_labels=argument_label_field,
            relation_labels=relation_label_field,
            metadata=metadata_field,
            span_labels=span_label_field,
            ner_sequence_labels=ner_sequence_label_field,
            syntax_labels=span_syntax_field,
            span_children=span_children_field,
            span_tree_labels=span_tree_field,
            dep_span_children=dep_span_children_field,
            # tf_f1 = tf_f1_field,
            # tf_f2 = tf_f2_field,
            # tf_f3 = tf_f3_field,
            # tf_f4 = tf_f4_field,
            # tf_f5 = tf_f5_field)
            tf=tf_field)
        # span_children_syntax=span_children_syntax_field)

        return Instance(fields)
Exemplo n.º 19
0
    def build_instance(
            self,  # type: ignore
            doc: List[List[str]],
            clusters: List[List[Tuple[int, int]]] = None,
            doc_relations: List[Dict[Tuple[Tuple[int, int], Tuple[int, int]],
                                     str]] = None,
            doc_ner_labels: List[Dict[Tuple[int, int], str]] = None,
            **kwargs) -> Instance:
        """
        Build an ``Instance`` for a whole document, optionally with gold annotations.

        Parameters
        ----------
        doc : ``List[List[str]]``, required.
            A list of lists representing the tokenized words and sentences in the document.
        clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.
        doc_relations : ``Optional[List[Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]]]``,
            optional (default = None)
            One dict per sentence mapping a ``(span_a, span_b)`` pair to its relation label.
            Spans must use the same document-level ``(start, end)`` coordinates as the
            enumerated ``SpanField``s, otherwise they are reported as "not found".
        doc_ner_labels : ``Optional[List[Dict[Tuple[int, int], str]]]``, optional (default = None)
            One dict per sentence mapping document-level ``(start, end)`` spans to an NER
            label. Enumerated spans missing from the dict are labelled ``'O'``.
        **kwargs :
            Extra values; each one is stored as a ``MetadataField`` under its own key.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                 not belong to a cluster. As these labels have variable length (it depends on
                 how many spans we are considering), we represent this as a ``SequenceLabelField``
                 with respect to the ``spans`` ``ListField``.

        Only ``text``, ``spans``, ``doc_span_offsets`` and the metadata fields are
        produced when ``clusters`` or ``doc_relations`` is None; the NER field is
        only produced when ``doc_ner_labels`` is also given.

        Extra fields:

            spans : see docstring
                Shape:  (num_spans)
                 0   1   2   3   4   5   6   7   8   9   10  11  12  13  14  15

            sentences_span_indices : list spans (absolute indices) for every sentence,
                can be used in order to isolate spans between different sentences.
                By design, the RelEx part considers intra-sentence truth_relations
                and is able to extract inter-sentence truth_relations with the help
                of already predicted sets of coreferences.

                Shape: (sentences_padded, spans_in_sentence_padded)
                 0   1   2   3
                 4   5   6   #
                 7   8   9   #
                10  11  12  13
                14  15   #   #
                Range: [0, ..., num_spans-1], # is padding

            TODO
            sentences_truth_spans : relative indices in sentence_spans
                correspond to at least one relation from truth.
                Intended to be used for effective packing and padding of
                the sparse matrix.

                PyTorch lacks (at least stable) support of sparse tensors,
                and we aim to implement it ourselves. The matrix is not going
                to be encoded using COO because the sparsity of matrix is just
                an effect of sparsity of the truth spans. This matrix is simply
                compressed matrix w.r.t. COO-encoded spans that are going
                to be used for encoding the relation matrix.

                Shape: (sentences_padded, gold_spans_in_sentence_padded)
                 1   3
                 0   2
                 0   1
                 2   #
                 1   #
                Range: [0, ..., spans_in_sentence - 1], # is padding

            TODO
            sentences_spans_in_truth : simply the inverse of `sentences_truth_spans`
                This matrix can be also interpreted as boolean matrix of whether the
                span is a truth span: if the element is not padded, it is,
                and the element points out where they occur in compressed matrix.

                Shape: (sentences_padded, spans_in_sentence_padded)
                 #   0   #   1
                 0   #   1   #
                 0   1   #   #
                 #   #   0   #
                 #   0   #   #
                Range: [0, ..., gold_spans_in_sentence_padded - 1], # is padding

            TODO
            sentences_relations : TODO

                Shape: (sentences_padded, gold_spans_in_sentence_padded, gold_spans_in_sentence_padded)
                Range: [0, ..., num_classes - 1], # is padding

            sentences_ner_labels : TODO

                Shape: TODO
                Range: TODO

        """

        metadatas: Dict[str, Any] = {}

        flattened_doc = [
            self._normalize_word(word) for sentence in doc for word in sentence
        ]
        metadatas["doc_tokens"] = doc
        metadatas["original_text"] = flattened_doc

        metadatas.update(kwargs)

        text_field = TextField([Token(word) for word in flattened_doc],
                               self._token_indexers)

        spans: List[SpanField] = []
        doc_span_offsets: List[List[int]] = []

        # Construct spans (document-level coordinates) and, per sentence, the
        # absolute indices of that sentence's spans in ``spans``.
        sentence_offset = 0
        for sentence in doc:
            sentence_spans: List[int] = []

            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                absolute_index = len(spans)
                spans.append(SpanField(start, end, text_field))
                sentence_spans.append(absolute_index)

            sentence_offset += len(sentence)
            doc_span_offsets.append(sentence_spans)

        # Just making fields out of the lists
        spans_field = OptionalListField(spans,
                                        empty_field=SpanField(
                                            -1, -1, text_field).empty_field())
        doc_span_offsets_field = ListField([
            OptionalListField([
                IndexField(span_offset, spans_field)
                for span_offset in sentence_span_offsets
            ],
                              empty_field=IndexField(
                                  -1, spans_field).empty_field())
            for sentence_span_offsets in doc_span_offsets
        ])

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": spans_field,
            "doc_span_offsets": doc_span_offsets_field
        }

        # TODO rename sentences to doc, sentence to snt

        for key, value in metadatas.items():
            fields[key] = MetadataField(value)

        if clusters is None or doc_relations is None:
            return Instance(fields)

        # Here we can be sure both `clusters` and `doc_relations` are given.
        # However, we cannot be sure yet whether `doc_ner_labels` is given or not.

        #
        #               TRUTH AFTER THIS ONLY
        #

        fields["clusters"] = MetadataField(clusters)
        cluster_dict = {(start, end): cluster_id
                        for cluster_id, cluster in enumerate(clusters)
                        for start, end in cluster}

        # Every span that takes part in at least one gold relation.
        truth_spans = {
            span
            for sentence_relations in doc_relations
            for span_pair in sentence_relations
            for span in span_pair
        }
        fields["truth_spans"] = MetadataField(truth_spans)

        span_labels: List[int] = []
        doc_truth_spans: List[List[int]] = []
        doc_spans_in_truth: List[List[int]] = []

        for sentence, sentence_spans_field in zip(doc, doc_span_offsets_field):
            sentence_truth_spans: List[IndexField] = []
            sentence_spans_in_truth: List[int] = []

            for relative_index, span in enumerate(sentence_spans_field):
                absolute_index = cast(IndexField, span).sequence_index
                span_field: SpanField = cast(SpanField,
                                             spans_field[absolute_index])

                start = span_field.span_start
                end = span_field.span_end

                # Cluster id of the span, or -1 when it is in no cluster.
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)

                # Position of the span inside this sentence's compressed list
                # of truth spans, or -1 when it is not a truth span.
                compressed_index = -1
                if (start, end) in truth_spans:
                    compressed_index = len(sentence_truth_spans)
                    sentence_truth_spans.append(
                        IndexField(relative_index, sentence_spans_field))

                sentence_spans_in_truth.append(compressed_index)

            sentence_truth_spans_field = OptionalListField(
                sentence_truth_spans,
                empty_field=IndexField(-1, sentence_spans_field).empty_field())
            doc_truth_spans.append(sentence_truth_spans_field)

            sentence_spans_in_truth_field = OptionalListField(
                [
                    IndexField(compressed_index, sentence_truth_spans_field)
                    for compressed_index in sentence_spans_in_truth
                ],
                empty_field=IndexField(
                    -1, sentence_truth_spans_field).empty_field())
            doc_spans_in_truth.append(sentence_spans_in_truth_field)

        span_labels_field = SequenceLabelField(span_labels, spans_field)
        doc_truth_spans_field = ListField(doc_truth_spans)
        doc_spans_in_truth_field = ListField(doc_spans_in_truth)

        fields["span_labels"] = span_labels_field
        fields["doc_truth_spans"] = doc_truth_spans_field
        fields["doc_spans_in_truth"] = doc_spans_in_truth_field

        # Map (start, end) -> absolute span index once, instead of scanning
        # ``spans`` with ``list.index`` for every gold relation. ``setdefault``
        # keeps the first occurrence, matching ``list.index`` semantics.
        span_to_absolute_index: Dict[Tuple[int, int], int] = {}
        for absolute_index, span_field in enumerate(spans):
            span_to_absolute_index.setdefault(
                (span_field.span_start, span_field.span_end), absolute_index)

        # TODO reverse matrix generation tactic
        # TODO Add dummy
        doc_relex_matrices: List[AdjacencyField] = []
        for truth_relations, sentence_spans, truth_spans_field in zip(
                doc_relations, doc_span_offsets, doc_truth_spans_field):

            # Sparse (absolute_a, absolute_b) -> label map; missing pairs
            # read back as '' through the defaultdict.
            relations = collections.defaultdict(str)
            for (span_a, span_b), label in truth_relations.items():
                a_absolute_index = span_to_absolute_index.get(span_a)
                b_absolute_index = span_to_absolute_index.get(span_b)
                if a_absolute_index is None or b_absolute_index is None:
                    logger.warning('Span not found: %s, %s', span_a, span_b)
                    continue
                relations[a_absolute_index, b_absolute_index] = label

            indices: List[Tuple[int, int]] = []
            labels: List[str] = []

            # Dense matrix over every ordered pair of this sentence's truth
            # spans, indexed by their compressed positions.
            for span_a, span_b in itertools.product(
                    enumerate(truth_spans_field), repeat=2):
                a_compressed_index, a_relative = cast(Tuple[int, IndexField],
                                                      span_a)
                b_compressed_index, b_relative = cast(Tuple[int, IndexField],
                                                      span_b)

                a_absolute = sentence_spans[a_relative.sequence_index]
                b_absolute = sentence_spans[b_relative.sequence_index]

                label = relations[a_absolute, b_absolute]

                indices.append((a_compressed_index, b_compressed_index))
                labels.append(label)

            doc_relex_matrices.append(
                AdjacencyField(indices=indices,
                               labels=labels,
                               sequence_field=truth_spans_field,
                               label_namespace="relation_labels")
            )  # TODO pad with zeros maybe?

        fields["doc_relation_labels"] = ListField(doc_relex_matrices)

        if doc_ner_labels is None:
            return Instance(fields)

        # NER: one label per enumerated span, in the same order as ``spans``.
        doc_ner: List[OptionalListField[LabelField]] = []

        sentence_offset = 0
        for sentence, sentence_ner_dict in zip(doc, doc_ner_labels):
            sentence_ner_labels: List[LabelField] = []

            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                if (start, end) in sentence_ner_dict:
                    label = sentence_ner_dict[(start, end)]
                    sentence_ner_labels.append(LabelField(label, 'ner_labels'))
                else:
                    sentence_ner_labels.append(LabelField('O', 'ner_labels'))

            sentence_offset += len(sentence)
            # The padding field uses the same namespace as the real NER labels.
            sentence_ner_labels_field = OptionalListField(
                sentence_ner_labels,
                empty_field=LabelField('*', 'ner_labels').empty_field())
            doc_ner.append(sentence_ner_labels_field)

        doc_ner_field = ListField(doc_ner)
        fields["doc_ner_labels"] = doc_ner_field

        return Instance(fields)
Exemplo n.º 20
0
    def __getitem__(self, item):
        """
        Build the model inputs for sample ``item`` of ``self.all_data``.

        Args:
            item: int, index into ``self.all_data``.
        Returns:
            A list with, in order:
            tokens: subword ids of the (cleaned) context, as a LongTensor.
            type_ids: token type ids from the tokenizer, as a LongTensor.
            all_span_idxs_ltoken: (start, end) boundaries of every enumerated span
                as converted by ``convert2tokenIdx`` (presumably subword-level).
            morph_idxs: per-span features from ``case_feature_tokenLevel``.
            span_label_ltoken: label id (via ``args.label2idx_list``) of every span.
            all_span_lens: length of every span in words (end - start + 1).
            all_span_weights: 1.0 for gold spans, ``args.neg_span_weight`` otherwise.
            real_span_mask_ltoken: 1 for real spans, 0 for padding.
            words: whitespace-split words of the context.
            all_span_word: as returned by ``convert2tokenIdx``.
            all_span_idxs: word-level (start, end) of every enumerated span.
            Span-level lists are truncated to ``self.max_num_span`` and, when
            ``self.pad_to_maxlen`` is set, padded to it.
        """
        import numpy as np

        is_roberta = 'roberta' in self.args.bert_config_dir
        sep_tok = "</s>" if is_roberta else "[SEP]"

        # Vocabularies are given as (name, idx) pairs.
        label2idx = {lab: int(idx) for lab, idx in self.args.label2idx_list}
        morph2idx = {
            morph: int(idx)
            for morph, idx in self.args.morph2idx_list
        }

        data = self.all_data[item]
        tokenizer = self.tokenzier

        # qas_id has the form "<sample>.<label>"; parsed here but not part of
        # the returned list.
        qas_id = data.get("qas_id", "0.0")
        sample_idx, label_idx = qas_id.split(".")
        sample_idx = torch.LongTensor([int(sample_idx)])
        label_idx = torch.LongTensor([int(label_idx)])

        # Apply every cleanup independently: removing zero-width characters can
        # itself leave double spaces, and runs of 3+ spaces need repeated
        # collapsing so words end up separated by exactly one space.
        context = data["context"].strip()
        context = context.replace('\u200b', '').replace('\ufeff', '')
        while '  ' in context:
            context = context.replace('  ', ' ')
        context = context.strip()

        # Gold spans are keyed "start;end" (word indices).
        span_position_label = data["span_position_label"]
        pos_span_idxs = []
        for seidx in span_position_label:
            sidx, eidx = seidx.split(';')
            pos_span_idxs.append((int(sidx), int(eidx)))
        pos_span_set = set(pos_span_idxs)

        words = context.split()

        # All candidate word-level spans (sidx, eidx).
        all_span_idxs = enumerate_spans(words,
                                        offset=0,
                                        max_span_width=self.args.max_span_len)

        neg_span_weight = self.args.neg_span_weight
        all_span_weights = [
            1.0 if span_idx in pos_span_set else neg_span_weight
            for span_idx in all_span_idxs
        ]
        all_span_lens = [eid - sid + 1 for sid, eid in all_span_idxs]

        morph_idxs = self.case_feature_tokenLevel(morph2idx, all_span_idxs,
                                                  words,
                                                  self.args.max_span_len)

        if is_roberta:
            tokenizer.post_processor = TemplateProcessing(
                single="<s> $A </s>",
                pair="<s> $A </s> $B:1 </s>:1",
                special_tokens=[
                    ("<s>", tokenizer.token_to_id("<s>")),
                    ("</s>", tokenizer.token_to_id("</s>")),
                ],
            )
            tokenizer._tokenizer.post_processor = BertProcessing(
                ("</s>", tokenizer.token_to_id("</s>")),
                ("<s>", tokenizer.token_to_id("<s>")),
            )

        # Example (query + context pair, bert-base-uncased):
        #   query = "you are beautiful ."; context = 'i love you .'
        #   tokens:   [101, 2017, 2024, 3376, 1012, 102, 1045, 2293, 2017, 1012, 102]
        #   type_ids: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
        #   offsets:  [(0, 0), (0, 3), (4, 7), (8, 17), (18, 19), (0, 0),
        #              (0, 1), (2, 6), (7, 10), (11, 12), (0, 0)]
        query_context_tokens = tokenizer.encode(context,
                                                add_special_tokens=True)
        tokens = query_context_tokens.ids  # subword ids
        type_ids = query_context_tokens.type_ids  # 0 first sentence, 1 second
        offsets = query_context_tokens.offsets  # character (start, end) per subword

        all_span_idxs_ltoken, all_span_word, all_span_idxs_new_label = self.convert2tokenIdx(
            words, tokens, type_ids, offsets, all_span_idxs,
            span_position_label)
        # NOTE(review): the result of this call is never used; it is kept in
        # case convert2tokenIdx validates or has side effects — confirm and drop.
        self.convert2tokenIdx(words, tokens, type_ids, offsets, pos_span_idxs,
                              span_position_label)

        span_label_ltoken = [
            label2idx[label] for label in all_span_idxs_new_label.values()
        ]

        # truncate
        tokens = tokens[:self.max_length]
        type_ids = type_ids[:self.max_length]
        all_span_idxs_ltoken = all_span_idxs_ltoken[:self.max_num_span]
        span_label_ltoken = span_label_ltoken[:self.max_num_span]
        all_span_lens = all_span_lens[:self.max_num_span]
        morph_idxs = morph_idxs[:self.max_num_span]
        all_span_weights = all_span_weights[:self.max_num_span]

        # make sure last token is the separator
        sep_token = tokenizer.token_to_id(sep_tok)
        if tokens[-1] != sep_token:
            assert len(tokens) == self.max_length
            tokens = tokens[:-1] + [sep_token]

        # padding to the max length.
        real_span_mask_ltoken = np.ones_like(span_label_ltoken)
        if self.pad_to_maxlen:
            tokens = self.pad(tokens, 0)
            type_ids = self.pad(type_ids, 1)
            all_span_idxs_ltoken = self.pad(all_span_idxs_ltoken,
                                            value=(0, 0),
                                            max_length=self.max_num_span)
            real_span_mask_ltoken = self.pad(real_span_mask_ltoken,
                                             value=0,
                                             max_length=self.max_num_span)
            span_label_ltoken = self.pad(span_label_ltoken,
                                         value=0,
                                         max_length=self.max_num_span)
            all_span_lens = self.pad(all_span_lens,
                                     value=0,
                                     max_length=self.max_num_span)
            morph_idxs = self.pad(morph_idxs,
                                  value=0,
                                  max_length=self.max_num_span)
            all_span_weights = self.pad(all_span_weights,
                                        value=0,
                                        max_length=self.max_num_span)

        tokens = torch.LongTensor(tokens)
        type_ids = torch.LongTensor(type_ids)
        all_span_idxs_ltoken = torch.LongTensor(all_span_idxs_ltoken)
        real_span_mask_ltoken = torch.LongTensor(real_span_mask_ltoken)
        span_label_ltoken = torch.LongTensor(span_label_ltoken)
        all_span_lens = torch.LongTensor(all_span_lens)
        morph_idxs = torch.LongTensor(morph_idxs)
        all_span_weights = torch.Tensor(all_span_weights)

        return [
            tokens,
            type_ids,  # use to split the first and second sentence.
            all_span_idxs_ltoken,
            morph_idxs,
            span_label_ltoken,
            all_span_lens,
            all_span_weights,
            real_span_mask_ltoken,
            words,
            all_span_word,
            all_span_idxs,
        ]
Exemplo n.º 21
0
    def text_to_instance(
        self,  # type: ignore
        sentences: List[List[str]],
        gold_clusters: Optional[List[List[Tuple[int,
                                                int]]]] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenised words and sentences in the document.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters. The clusters are passed through
            ``self.align_clusters_to_tokens`` before they are used.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField`` or ``ListField[TextField]``
                The wordpiece-tokenised text of the full document. If the document has
                more than 512 wordpieces, this is a ``ListField`` whose first element holds
                the first 512 wordpieces and whose remaining elements hold consecutive,
                non-overlapping chunks of at most 100 wordpieces.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to ``text``.
            metadata : ``MetadataField``
                ``original_text`` (the wordpieces) and, when given, the aligned ``clusters``.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                not belong to a cluster. As these labels have variable length (it depends on
                how many spans we are considering), we represent this as a ``SequenceLabelField``
                with respect to the ``spans`` ``ListField``.
        """
        flattened_sentences = [
            self._normalize_word(word) for sentence in sentences
            for word in sentence
        ]
        # Align the gold clusters using the normalised word-level text (before the
        # wordpiece split below). NOTE(review): exact semantics of
        # align_clusters_to_tokens are defined elsewhere — confirm.
        gold_clusters = self.align_clusters_to_tokens(flattened_sentences,
                                                      gold_clusters)

        # Re-tokenise the normalised document into wordpieces.
        flattened_sentences = self.token_indexer.wordpiece_tokenizer(
            " ".join(flattened_sentences))
        metadata: Dict[str, Any] = {"original_text": flattened_sentences}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters

        first_window = 512  # wordpieces in the first TextField
        chunk_size = 100  # wordpieces in each following (non-overlapping) chunk
        if len(flattened_sentences) > first_window:
            # Long documents become a ListField: the first 512 wordpieces, followed by
            # consecutive non-overlapping chunks of up to 100 wordpieces.
            chunk_fields: List[Field] = [
                TextField(
                    [Token(word) for word in flattened_sentences[:first_window]],
                    self._token_indexers)
            ]
            for chunk_start in range(first_window, len(flattened_sentences),
                                     chunk_size):
                chunk = flattened_sentences[chunk_start:chunk_start + chunk_size]
                chunk_fields.append(
                    TextField([Token(word) for word in chunk],
                              self._token_indexers))
            text_field = ListField(chunk_fields)
        else:
            text_field = TextField(
                [Token(word) for word in flattened_sentences],
                self._token_indexers)

        # Map every gold mention to its cluster id.
        cluster_dict = {}
        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for mention in cluster:
                    cluster_dict[tuple(mention)] = cluster_id

        spans: List[Field] = []
        span_labels: Optional[
            List[int]] = [] if gold_clusters is not None else None
        sentence_offset = 0
        for sentence in sentences:
            # NOTE(review): spans are enumerated over the original word-level
            # sentences, while ``text_field`` holds wordpieces (or a ListField of
            # wordpiece chunks for long documents), so these are word indices, not
            # wordpiece indices — confirm this is intended.
            for start, end in enumerate_spans(
                    sentence,
                    offset=sentence_offset,
                    max_span_width=self._max_span_width):
                if span_labels is not None:
                    span_labels.append(cluster_dict.get((start, end), -1))
                spans.append(SpanField(start, end, text_field))
            sentence_offset += len(sentence)

        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field
        }
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)
        return Instance(fields)
Exemplo n.º 22
0
    def text_to_instance(self, # type: ignore
                         tokens: List[Token],
                         ner_tags: List[str] = None) -> Instance:
        """
        Build an ``Instance`` of span candidates (and, when tags are given, their labels).

        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        Parameters
        ----------
        tokens : ``List[Token]``
            The tokens of the sentence.
        ner_tags : ``List[str]``, optional (default = None)
            One NER tag per token, encoded in ``self._original_coding_scheme``. When
            ``None`` (e.g. at prediction time) no label fields are produced.

        Returns
        -------
        An ``Instance`` with ``tokens``, ``metadata`` and ``spans`` always, plus
        ``span_labels``, ``gold_spans`` and ``gold_span_labels`` when ``ner_tags`` is
        given, and ``tags`` when ``self.tag_label == 'ner'`` and tags are given.
        """
        sequence = TextField(tokens, self._token_indexers)
        instance_fields: Dict[str, Field] = {'tokens': sequence}

        if ner_tags is None:
            coded_ner = None
        elif self.coding_scheme == "BIOUL":
            coded_ner = to_bioul(ner_tags, encoding=self._original_coding_scheme)
        else:
            # the default IOB1
            coded_ner = ner_tags

        # Every candidate (start, end) span up to the maximum width. Enumerated over
        # the tokens (not the tags) so that this also works without tags.
        candidate_spans = list(enumerate_spans(tokens, offset=0,
                                               max_span_width=self._max_span_width))

        ner_gold_spans = {}
        if ner_tags is not None:
            assert len(ner_tags) == len(tokens), "sentence:%s but ner_tags:%s"%(str(tokens), str(ner_tags))
            # ner_gold_spans: Dict[tuple(startid, endid), str(entity_type)]
            ner_gold_spans = _extract_spans(ner_tags)

        # The metadata only exposes the non-"O" gold spans.
        _dict_gold_spans = {ky: val for ky, val in ner_gold_spans.items() if val != 'O'}
        instance_fields["metadata"] = MetadataField({"words": [x.text for x in tokens],
                                                     "gold_spans": _dict_gold_spans})

        # contains all possible spans: a list of (start, end) SpanFields
        span_field = ListField([SpanField(start, end, sequence) for start, end in candidate_spans])
        instance_fields['spans'] = span_field

        if ner_tags is not None:
            # label of every candidate span: its entity type, or "O" if it is not a gold span
            span_labels = [ner_gold_spans.get((start, end), 'O') for start, end in candidate_spans]
            instance_fields['span_labels'] = SequenceLabelField(span_labels, span_field, "span_tags")

            # only contain gold_spans and their tags
            # e.g. (0,0,O), (1,1,O), (2,3,PER), (4,4,O) for 'I am Donald Trump .'
            gold_span_field = ListField([SpanField(ky[0], ky[1], sequence) for ky in ner_gold_spans])
            instance_fields['gold_spans'] = gold_span_field
            instance_fields['gold_span_labels'] = SequenceLabelField(list(ner_gold_spans.values()),
                                                                     gold_span_field, "span_tags")

        # Add "tag label" to instance
        if self.tag_label == 'ner' and coded_ner is not None:
            instance_fields['tags'] = SequenceLabelField(coded_ner, sequence,
                                                         'token_tags')
        return Instance(instance_fields)
Exemplo n.º 23
0
    def text_to_instance(self, sentence: List[str],
                         ner_dict: Dict[Tuple[int, int], str], relation_dict,
                         cluster_dict, trigger_dict, argument_dict,
                         doc_key: str, dataset: str, sentence_num: int,
                         groups: List[str], start_ix: int, end_ix: int):
        """
        Build an ``Instance`` for joint NER / coreference / relation / event extraction.

        ``sentence`` is normalised word by word. Label lookups are performed as
        ``ner_dict[span]`` and ``cluster_dict[span]`` for every candidate span,
        ``trigger_dict[i]`` for every token position, ``relation_dict[(span, span)]``
        for every ordered span pair, and ``argument_dict[(token, span)]`` for every
        token/span pair. Only truthy relation/argument labels are kept. Presumably
        these are defaultdicts, since every candidate is looked up — confirm against
        the caller.

        Returns an ``Instance`` with fields ``text`` (built from ``groups``), ``spans``,
        ``ner_labels``, ``coref_labels``, ``trigger_labels``, ``argument_labels``,
        ``relation_labels`` and ``metadata``.
        """
        normalized = [self._normalize_word(word) for word in sentence]

        text_field = TextField([Token(word) for word in normalized],
                               self._token_indexers)
        text_field_with_context = TextField([Token(word) for word in groups],
                                            self._token_indexers)

        metadata_field = MetadataField({
            "sentence": normalized,
            "ner_dict": ner_dict,
            "relation_dict": relation_dict,
            "cluster_dict": cluster_dict,
            "trigger_dict": trigger_dict,
            "argument_dict": argument_dict,
            "doc_key": doc_key,
            "dataset": dataset,
            "groups": groups,
            "start_ix": start_ix,
            "end_ix": end_ix,
            "sentence_num": sentence_num,
        })

        # Exactly one trigger label per token of the sentence.
        trigger_label_field = SequenceLabelField(
            [trigger_dict[position] for position in range(len(text_field))],
            text_field,
            label_namespace="trigger_labels")

        # Candidate spans, with their NER and coreference labels.
        span_tuples = []
        spans = []
        span_ner_labels = []
        span_coref_labels = []
        for start, end in enumerate_spans(normalized,
                                          max_span_width=self._max_span_width):
            key = (start, end)
            span_ner_labels.append(ner_dict[key])
            span_coref_labels.append(cluster_dict[key])
            spans.append(SpanField(start, end, text_field))
            span_tuples.append(key)

        span_field = ListField(spans)
        ner_label_field = SequenceLabelField(span_ner_labels, span_field,
                                             label_namespace="ner_labels")
        coref_label_field = SequenceLabelField(span_coref_labels, span_field,
                                               label_namespace="coref_labels")

        # Relations: keep only ordered span pairs whose label is truthy.
        relation_indices = []
        relations = []
        for row, first_span in enumerate(span_tuples):
            for col, second_span in enumerate(span_tuples):
                label = relation_dict[(first_span, second_span)]
                if label:
                    relation_indices.append((row, col))
                    relations.append(label)

        relation_label_field = AdjacencyField(
            indices=relation_indices,
            sequence_field=span_field,
            labels=relations,
            label_namespace="relation_labels")

        # Event arguments: row = trigger token, column = argument span. AllenNLP's
        # AdjacencyField cannot relate two different sequences, hence the asymmetric field.
        argument_indices = []
        arguments = []
        for token_position in range(len(normalized)):
            for col, argument_span in enumerate(span_tuples):
                label = argument_dict[(token_position, argument_span)]
                if label:
                    argument_indices.append((token_position, col))
                    arguments.append(label)

        argument_label_field = AdjacencyFieldAssym(
            indices=argument_indices,
            row_field=text_field,
            col_field=span_field,
            labels=arguments,
            label_namespace="argument_labels")

        return Instance({
            "text": text_field_with_context,
            "spans": span_field,
            "ner_labels": ner_label_field,
            "coref_labels": coref_label_field,
            "trigger_labels": trigger_label_field,
            "argument_labels": argument_label_field,
            "relation_labels": relation_label_field,
            "metadata": metadata_field,
        })
Exemplo n.º 24
0
    def text_to_instance(self, sample: Dict[str, Any], training: bool = True):
        """
        Build an ``Instance`` for joint entity and relation classification from one sample.

        Parameters
        ----------
        sample : ``Dict[str, Any]``
            Must contain ``"text"`` and ``"tokens"``. Each token carries a ``"span"`` with
            ``"start"``/``"end"`` character offsets into ``"text"``. The sample is also
            passed to ``extract_entities`` and ``extract_relations_from_smart_sample``.
        training : ``bool``, optional (default = True)
            Currently unused by this method.

        Returns
        -------
        An ``Instance`` with fields ``tokens``, ``span_masks``, ``spans``, ``ner_labels``,
        ``rels_sample_masks``, ``relation_masks``, ``rel_span_indices``, ``rel_labels``
        and ``metadata``.
        """
        import logging

        text = sample["text"]
        words = [text[t["span"]["start"]: t["span"]["end"]] for t in sample["tokens"]]
        tokens = [Token(w) for w in words]
        entities = extract_entities(sample)

        relations = extract_relations_from_smart_sample(sample, include_trigger=True)

        sequence = TextField(tokens, self._token_indexers)
        instance_fields: Dict[str, Field] = {"tokens": sequence}
        spans = []
        span_masks = []

        # Masks are sized one longer than the number of words.
        context_size = len(words) + 1
        # TODO: during training the candidate list is not actually complete.
        for start, end in enumerate_spans(words,
                                          max_span_width=self._max_span_width):
            assert start >= 0
            assert end >= 0
            spans.append(SpanField(start, end, sequence))
            span_masks.append(create_mask(start, end, context_size))

        # np.int / np.bool were removed in NumPy 1.24; they were aliases of the builtins.
        instance_fields["span_masks"] = ListField(
            [ArrayField(np.array(si, dtype=int), dtype=int) for si in span_masks])

        span_field = ListField(spans)

        span_tuples = [(span.span_start, span.span_end) for span in spans]
        instance_fields["spans"] = span_field  # TODO: what about negative sampling?

        ner_labels = ["O"] * len(span_tuples)

        ner_list = [((e.start, e.end), e.role) for e in entities]

        for span, label in ner_list:
            if self._too_long(span):
                continue
            ner_labels[span_tuples.index(span)] = label

        # TODO Evaluate if this should be a MultiLabel instead of Label
        instance_fields["ner_labels"] = ListField(
            [LabelField(entry, label_namespace="ner_labels") for entry in ner_labels])

        pos_span_pairs = []
        pos_span_labels = []
        pos_span_masks = []

        for rel in relations:
            mand_arg_roles = relation_args_names[rel.label]
            try:
                # TODO handle special case Merger
                s1, s2 = sorted([s for s in rel.spans if s.role in mand_arg_roles], key=lambda x: x.role)
                pair = (span_tuples.index(s1.span), span_tuples.index(s2.span))
                mask = create_rel_mask((s1.start, s1.end), (s2.start, s2.end), context_size)
            except ValueError:
                # Not exactly two mandatory arguments, or an argument span is not a
                # candidate span (e.g. longer than max_span_width): skip the relation.
                continue
            except Exception:  # pylint: disable=broad-except
                # Best-effort: skip relations that cannot be converted, but leave a trace.
                logging.getLogger(__name__).debug(
                    "Skipping relation %r that could not be converted", rel.label, exc_info=True)
                continue
            # Append only after everything succeeded so the three lists stay aligned.
            pos_span_pairs.append(pair)
            pos_span_labels.append([rel.label])
            pos_span_masks.append(mask)

        if len(ner_list) < 2:
            # Too few entities to form pairs: add random candidate spans as extra
            # (unlabelled) arguments for negative relation samples.
            ner_cands = random.sample(span_tuples, min(len(span_tuples), 7))

            ner_cands = [nc for nc in ner_cands if not self._too_long(nc)]

            ner_list += [(s, "") for s in ner_cands]

        # Negative relation candidates, excluding:
        # - relations from an entity to itself,
        # - pairs that are positive relations in the ground truth,
        # - duplicates (a random candidate may repeat an entity span).
        seen_pairs = set(pos_span_pairs)
        neg_candidates = []
        for span1, _label1 in ner_list:
            for span2, _label2 in ner_list:
                if self._too_long(span1) or self._too_long(span2):
                    continue
                cand = (span_tuples.index(span1), span_tuples.index(span2))
                if cand[0] == cand[1] or cand in seen_pairs:
                    continue
                seen_pairs.add(cand)
                neg_candidates.append((cand, [], create_rel_mask(span1, span2, context_size)))

        negative_samples = random.sample(
                neg_candidates,
                min(len(neg_candidates), self._max_relation_negative_samples)
        )
        neg_span_pairs = [ns[0] for ns in negative_samples]
        neg_span_labels = [ns[1] for ns in negative_samples]
        neg_span_masks = [ns[2] for ns in negative_samples]

        relation_spans = pos_span_pairs + neg_span_pairs
        relation_labels = pos_span_labels + neg_span_labels
        relation_masks = pos_span_masks + neg_span_masks

        if relation_spans:
            rels_sample_masks = np.ones(len(relation_spans))
        else:
            rels_sample_masks = np.zeros(1)

        instance_fields["rels_sample_masks"] = ArrayField(rels_sample_masks, dtype=bool)

        instance_fields["relation_masks"] = ListField(
            [ArrayField(np.array(si, dtype=int), dtype=int) for si in relation_masks])

        instance_fields["rel_span_indices"] = ListField(
            [ArrayField(np.array(si, dtype=int), dtype=int) for si in relation_spans])

        instance_fields["rel_labels"] = ListField(
            [MultiLabelField(rel_label, label_namespace="rel_labels") for rel_label in relation_labels])

        metadata = {"words": words, "relations": relations}
        instance_fields["metadata"] = MetadataField(metadata)

        return Instance(instance_fields)
Exemplo n.º 25
0
    def text_to_instance(self,  # type: ignore
                         sentence: List[Token],
                         gold_clusters: Optional[List[List[Tuple[int, int]]]] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Turn one pre-tokenised sentence, plus optional gold coreference clusters, into an ``Instance``.

        Parameters
        ----------
        sentence : ``List[Token]``, required.
            The already tokenised sentence to analyse.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the sentence, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full sentence.
            spans : ``ListField[SpanField]``
                Every candidate span (up to ``self._max_span_width``) as a ``SpanField``
                over ``text``.
            metadata : ``MetadataField``
                ``original_text`` and, when given, ``clusters``.
            span_labels : ``SequenceLabelField``, optional
                Only present when ``gold_clusters`` is given: for each candidate span, the id
                of the cluster it belongs to, or -1 if it belongs to none.
        """
        has_gold = gold_clusters is not None

        metadata: Dict[str, Any] = {"original_text": sentence}
        if has_gold:
            metadata["clusters"] = gold_clusters

        text_field = TextField(sentence, self._token_indexers)

        # mention (as a tuple) -> cluster id; a later cluster wins for a repeated mention.
        mention_to_cluster = {}
        if has_gold:
            mention_to_cluster = {
                tuple(mention): cluster_id
                for cluster_id, cluster in enumerate(gold_clusters)
                for mention in cluster
            }

        candidate_spans = list(enumerate_spans(sentence, max_span_width=self._max_span_width))
        span_field = ListField([SpanField(start, end, text_field) for start, end in candidate_spans])

        fields: Dict[str, Field] = {"text": text_field,
                                    "spans": span_field,
                                    "metadata": MetadataField(metadata)}
        if has_gold:
            labels = [mention_to_cluster.get((start, end), -1) for start, end in candidate_spans]
            fields["span_labels"] = SequenceLabelField(labels, span_field)

        return Instance(fields)
Exemplo n.º 26
0
    def text_to_instance(self, sentence: List[str],
                         ner_dict: Dict[Tuple[int, int],
                                        str], relation_dict, doc_key: str,
                         dataset: str, sentence_num: int, groups: List[str],
                         start_ix: int, end_ix: int, tree: Dict[str, Any],
                         children_dict: Dict[Tuple[int, int],
                                             List[Tuple[int, int]]],
                         dep_children_dict: Dict[Tuple[int, int],
                                                 List[Tuple[int, int]]],
                         tf_dict: Dict[Tuple[int, int], Any]):
        """
        Build an ``Instance`` for span NER / relation extraction with syntactic features.

        Parameters
        ----------
        sentence : ``List[str]``
            Tokens of the sentence; each is passed through ``self._normalize_word``.
        ner_dict :
            Looked up for every candidate ``(start, end)`` span to get its NER label.
        relation_dict :
            Looked up for every ordered pair of candidate spans. Truthy values are kept
            as relation labels.
        doc_key, dataset, sentence_num, groups, start_ix, end_ix, tree :
            Stored in the metadata. ``groups`` also supplies the tokens of ``text``.
        children_dict :
            Looked up for every candidate span to get its child spans.
        dep_children_dict, tf_dict :
            Looked up for every ``(i, j)`` token-index pair. Truthy values become labels
            of the ``dep_span_children`` / ``tf`` adjacency fields.
            NOTE(review): since every candidate is looked up, these are presumably
            defaultdicts — confirm against the caller.

        Returns
        -------
        An ``Instance`` with fields ``text``, ``spans``, ``ner_labels``, ``relation_labels``,
        ``metadata``, ``span_children``, ``dep_span_children`` and ``tf``.
        """
        sentence = [self._normalize_word(word) for word in sentence]

        text_field = TextField([Token(word) for word in sentence],
                               self._token_indexers)
        text_field_with_context = TextField([Token(word) for word in groups],
                                            self._token_indexers)

        # Put together the metadata.
        metadata = dict(sentence=sentence,
                        ner_dict=ner_dict,
                        relation_dict=relation_dict,
                        doc_key=doc_key,
                        dataset=dataset,
                        groups=groups,
                        start_ix=start_ix,
                        end_ix=end_ix,
                        sentence_num=sentence_num,
                        tree=tree,
                        children_dict=children_dict,
                        dep_children_dict=dep_children_dict)
        metadata_field = MetadataField(metadata)

        # Generate fields for text spans, ner labels
        spans = []
        span_ner_labels = []
        raw_spans = []

        for start, end in enumerate_spans(sentence,
                                          max_span_width=self._max_span_width):
            span_ix = (start, end)
            span_ner_labels.append(ner_dict[span_ix])
            spans.append(SpanField(start, end, text_field))
            raw_spans.append(span_ix)

        span_field = ListField(spans)

        # Position of each candidate span (first occurrence, as list.index would give).
        # Replaces repeated linear ``in`` / ``index`` scans over raw_spans.
        span_position = {}
        for position, span_ix in enumerate(raw_spans):
            span_position.setdefault(span_ix, position)

        def _child_position(child_span) -> int:
            """Index of ``child_span`` among the candidate spans, or -1 if it is not one."""
            try:
                return span_position.get(child_span, -1)
            except TypeError:
                # An unhashable child (e.g. a list) never compares equal to a tuple span.
                return -1

        # For each span, a ListField of IndexFields pointing at its children among the
        # candidate spans. -1 is used when there are no children or a child is not a candidate.
        span_children_labels = []
        for span in raw_spans:
            children = children_dict[span]
            if len(children) == 0:
                child_positions = [-1]
            else:
                child_positions = [_child_position(child) for child in children]
            span_children_labels.append(
                ListField([IndexField(pos, span_field) for pos in child_positions]))

        # Token-pair features: keep only truthy dependency / tf labels.
        n_tokens = len(sentence)
        candidate_indices = [(i, j) for i in range(n_tokens)
                             for j in range(n_tokens)]
        dep_adjs = []
        dep_adjs_indices = []
        tf_indices = []
        tf_features = []
        for token_pair in candidate_indices:
            dep_adj_label = dep_children_dict[token_pair]
            if dep_adj_label:
                dep_adjs_indices.append(token_pair)
                dep_adjs.append(dep_adj_label)

            feature = tf_dict[token_pair]
            if feature:
                tf_indices.append(token_pair)
                tf_features.append(feature)

        ner_label_field = SequenceLabelField(span_ner_labels,
                                             span_field,
                                             label_namespace="ner_labels")

        # Relations between ordered pairs of candidate spans (truthy labels only).
        n_spans = len(spans)
        span_tuples = [(span.span_start, span.span_end) for span in spans]
        candidate_indices = [(i, j) for i in range(n_spans)
                             for j in range(n_spans)]

        relations = []
        relation_indices = []
        for i, j in candidate_indices:
            span_pair = (span_tuples[i], span_tuples[j])
            relation_label = relation_dict[span_pair]
            if relation_label:
                relation_indices.append((i, j))
                relations.append(relation_label)

        relation_label_field = AdjacencyField(
            indices=relation_indices,
            sequence_field=span_field,
            labels=relations,
            label_namespace="relation_labels")

        # Syntax
        span_children_field = ListField(span_children_labels)
        dep_span_children_field = AdjacencyField(
            indices=dep_adjs_indices,
            sequence_field=text_field,
            labels=dep_adjs,
            label_namespace="dep_adj_labels")

        tf_field = AdjacencyField(indices=tf_indices,
                                  sequence_field=text_field,
                                  labels=tf_features,
                                  label_namespace="tf_labels")

        fields = dict(text=text_field_with_context,
                      spans=span_field,
                      ner_labels=ner_label_field,
                      relation_labels=relation_label_field,
                      metadata=metadata_field,
                      span_children=span_children_field,
                      dep_span_children=dep_span_children_field,
                      tf=tf_field)

        return Instance(fields)
Exemplo n.º 27
0
 def enumerate_spans(self, *args, **kwargs):
     """Yield ``(start, end)`` spans over ``self.doc`` with an exclusive ``end``.

     All arguments are forwarded to the module-level ``enumerate_spans``, whose
     ``end`` is inclusive; each ``end`` is shifted by one here.
     """
     inclusive_spans = enumerate_spans(self.doc, *args, **kwargs)
     yield from ((begin, last + 1) for begin, last in inclusive_spans)
Exemplo n.º 28
0
    def text_to_instance(
        self,  # type: ignore
        sentences: List[List[str]],
        gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
        *,
        mention_token_spans: Optional[Sequence[Tuple[int, int]]] = None
    ) -> Instance:  # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenized words and sentences in the document.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.
        mention_token_spans: optional
            A Sequence of spans which should be considered for coref. This overrides
            the usual behavior of including all spans up to the maximum width. The spans should
            be specified in terms of token indices with inclusive end token indices.
            Cannot be combined with ``gold_clusters`` (raises ``NotImplementedError``).

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            metadata : ``MetadataField``
                ``original_text`` and, when given, ``clusters``.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                not belong to a cluster. As these labels have variable length (it depends on
                how many spans we are considering), we represent this as a ``SequenceLabelField``
                with respect to the ``spans`` ``ListField``.
        """
        flattened_sentences = [self._normalize_word(word)
                               for sentence in sentences
                               for word in sentence]

        metadata: Dict[str, Any] = {"original_text": flattened_sentences}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters

        text_field = TextField([Token(word) for word in flattened_sentences],
                               self._token_indexers)

        # mention (as a tuple) -> cluster id; a later cluster wins for a repeated mention.
        mention_to_cluster = {}
        if gold_clusters is not None:
            mention_to_cluster = {
                tuple(mention): cluster_id
                for cluster_id, cluster in enumerate(gold_clusters)
                for mention in cluster
            }

        span_labels: Optional[List[int]] = None if gold_clusters is None else []

        if mention_token_spans is not None:
            # The mention spans are already known; only wrap them as SpanFields.
            if span_labels is not None:
                raise NotImplementedError(
                    "We currently don't handle known mentions plus gold labels")
            span_fields: List[Field] = [SpanField(start, end, text_field)
                                        for start, end in mention_token_spans]
        else:
            # Every span up to the maximum width, within each sentence, is a candidate.
            span_fields = []
            sentence_offset = 0
            for sentence in sentences:
                for start, end in enumerate_spans(sentence,
                                                  offset=sentence_offset,
                                                  max_span_width=self._max_span_width):
                    if span_labels is not None:
                        span_labels.append(mention_to_cluster.get((start, end), -1))
                    span_fields.append(SpanField(start, end, text_field))
                sentence_offset += len(sentence)

        span_field = ListField(span_fields)

        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": MetadataField(metadata),
        }
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)

        return Instance(fields)
    def text_to_instance(self, tokenized_text: List[str], passage: Passage, dataset_label: str, lang: str,
                         id: str = None, gold_tree: Passage = None) -> Instance:
        """Build an ``Instance`` for one pre-tokenized UCCA passage.

        Parameters
        ----------
        tokenized_text:
            The sentence tokens. They are space-joined and re-tokenized with
            ``self.word_tokenizer``. The result must reproduce this list token for token.
        passage:
            The UCCA passage, stored as metadata under ``"passage"``.
        dataset_label:
            Stored as metadata under ``"dataset_label"``.
        lang:
            Passed to the word tokenizer and stored as metadata under ``"lang"``.
        id:
            Optional identifier, stored under ``"id"`` when given. The name shadows the
            builtin but is part of the public signature.
        gold_tree:
            Optional gold passage. When given, the instance also gets gold metadata
            (``gold_ucca_tree``, ``gold_primary_tree``), remote-edge fields
            (``remote_heads``, ``remote_deps``, ``remote_labels``,
            ``remote_nodes_spans``) and per-span oracle labels (``span_labels``).

        Returns
        -------
        Instance
            Always contains ``tokens``, ``lang``, ``passage``, ``dataset_label`` and
            ``spans`` (one ``SpanField`` per span from ``enumerate_spans``).

        Raises
        ------
        AssertionError
            If the word tokenizer does not reproduce ``tokenized_text`` exactly.
        ValueError
            If ``gold_tree`` has no remote edges and the placeholder span field cannot
            be built for this sentence. Presumably this happens only for an empty
            sentence; TODO confirm.
        """
        fields: Dict[str, Field] = {}

        word_tokens = self.word_tokenizer.split_words(" ".join(tokenized_text), lang)
        # Span indices below are computed on ``tokenized_text`` but attached to the
        # re-tokenized field, so both tokenizations must agree exactly.
        assert len(tokenized_text) == len(word_tokens), (
            f"tokenizer produced {len(word_tokens)} tokens for {len(tokenized_text)} input tokens")
        for expected, token in zip(tokenized_text, word_tokens):
            assert token.text == expected, f"tokenizer changed {expected!r} into {token.text!r}"
        sentence_field = TextField(word_tokens, self.token_indexers)
        fields["tokens"] = sentence_field

        fields["lang"] = MetadataField(lang)
        fields["passage"] = MetadataField(passage)
        fields["dataset_label"] = MetadataField(dataset_label)

        if id is not None:
            fields["id"] = MetadataField(id)

        if gold_tree is not None:
            fields["gold_ucca_tree"] = MetadataField(gold_tree)

            # The converter gets its own deep copy, so ``gold_tree`` (stored above) is
            # never handed to it directly. Presumably conversion may mutate its input.
            gold_primary_tree = UCCA2tree(copy.deepcopy(gold_tree)).convert()
            fields["gold_primary_tree"] = MetadataField(gold_primary_tree)

            remote_spans, (heads, deps, labels) = gerenate_remote(gold_tree)
            remote_head_list: List[int] = []
            remote_dep_list: List[int] = []
            remote_label_list: List[LabelField] = []
            for head_sublist, dep_sublist, label_sublist in zip(heads, deps, labels):
                for head, dep, label in zip(head_sublist, dep_sublist, label_sublist):
                    remote_head_list.append(head)
                    remote_dep_list.append(dep)
                    remote_label_list.append(LabelField(label, label_namespace="remote_labels"))

            # SUDA code uses (i, i+1) to represent the span covering token i,
            # while AllenNLP uses (i, i): shift the exclusive end to an inclusive one.
            remote_nodes_span_list: List[SpanField] = [
                SpanField(start, end - 1, sentence_field) for start, end in remote_spans
            ]

            if not remote_head_list:
                # No remote edges (remote_dep_list and remote_label_list are empty too).
                # Emit empty placeholder fields so the same keys are present either way.
                try:
                    empty_span_list_field = ListField([SpanField(0, 0, sentence_field)]).empty_field()
                except ValueError as err:
                    # A (0, 0) dummy span cannot be built for this sentence. Fail with the
                    # offending tokens instead of continuing without the field.
                    raise ValueError(
                        f"cannot build an empty remote-span field for tokens {tokenized_text!r}"
                    ) from err
                fields["remote_heads"] = ArrayField(np.zeros(1)).empty_field()
                fields["remote_deps"] = ArrayField(np.zeros(1)).empty_field()
                # Same namespace as the non-empty branch for consistency.
                fields["remote_labels"] = ListField(
                    [LabelField("dummy", label_namespace="remote_labels")]).empty_field()
                fields["remote_nodes_spans"] = empty_span_list_field
            else:
                fields["remote_heads"] = ArrayField(np.array(remote_head_list))
                fields["remote_deps"] = ArrayField(np.array(remote_dep_list))
                fields["remote_labels"] = ListField(remote_label_list)
                fields["remote_nodes_spans"] = ListField(remote_nodes_span_list)

        spans: List[Field] = []
        gold_labels: List[str] = []
        for start, end in enumerate_spans(tokenized_text):
            spans.append(SpanField(start, end, sentence_field))
            # TODO: Use gold labels in the training instead of the oracle_label function. Right now they are needed
            #  for creating the vocabulary of labels
            if gold_tree is not None:
                # SUDA code uses (i, i+1) to represent the span covering token i,
                # while AllenNLP uses (i, i): convert back to the exclusive end here.
                gold_labels.append(str(gold_primary_tree.oracle_label(start, end + 1)))

        span_list_field = ListField(spans)
        fields["spans"] = span_list_field

        if gold_tree is not None:
            fields["span_labels"] = SequenceLabelField(gold_labels, span_list_field)

        return Instance(fields)