Esempio n. 1
0
    def token_list_to_sentence(self, token_list: conllu.TokenList) -> Sentence:
        """Convert a parsed CoNLL-U token list into a flair ``Sentence``.

        The "ner", "ner-2" and "lemma" token fields are attached as token labels when present,
        and ``SpaceAfter=No`` in the MISC column turns off the token's trailing whitespace.
        Sentence metadata then contributes the sentence id and any relation annotations.
        Finally, every span found in a token label layer whose name starts with "ner" is
        re-added as a sentence-level "entity" label.

        :param token_list: a token list produced by the ``conllu`` parser
        :return: the constructed ``Sentence``
        """
        sentence: Sentence = Sentence()

        for conllu_token in token_list:
            token = Token(conllu_token["form"])

            # copy the optional annotation columns over as token labels (same order as before)
            for field in ("ner", "ner-2", "lemma"):
                if field in conllu_token:
                    token.add_label(field, conllu_token[field])

            if "misc" in conllu_token and conllu_token["misc"] is not None:
                space_after = conllu_token["misc"].get("SpaceAfter")
                if space_after == "No":
                    token.whitespace_after = False

            sentence.add_token(token)

        if "sentence_id" in token_list.metadata:
            sentence.add_label("sentence_id", token_list.metadata["sentence_id"])

        if "relations" in token_list.metadata:
            for head_start, head_end, tail_start, tail_end, label in token_list.metadata["relations"]:
                # head and tail span indices are 1-indexed and the end index is inclusive,
                # so [start - 1 : end] is the equivalent 0-indexed half-open slice
                head = Span(sentence.tokens[head_start - 1 : head_end])
                tail = Span(sentence.tokens[tail_start - 1 : tail_end])

                sentence.add_complex_label("relation", RelationLabel(value=label, head=head, tail=tail))

        # determine all NER label types in sentence (first-seen order) and add all NER spans
        # as sentence-level labels
        ner_label_types = []
        for token in sentence.tokens:
            for annotation in token.annotation_layers.keys():
                if annotation.startswith("ner") and annotation not in ner_label_types:
                    ner_label_types.append(annotation)

        for label_type in ner_label_types:
            for span in sentence.get_spans(label_type):
                sentence.add_complex_label("entity", label=SpanLabel(span=span, value=span.tag, score=span.score))

        return sentence
Esempio n. 2
0
    def _get_tars_formatted_sentence(self, label, sentence):
        """Build the TARS input sentence that pairs ``label`` with ``sentence``.

        The label and the tokenized sentence text are joined by ``self.separator``. The label
        comes first when ``self.prefix`` is truthy and last otherwise. Every label of type
        ``self.label_type`` on ``sentence`` whose value equals ``label`` is copied onto the new
        sentence as a span labelled "entity" under ``self.static_label_type``. When the label is
        a prefix, the token indices are shifted past the prepended label and separator words.

        :param label: the candidate label text
        :param sentence: the original sentence
        :return: the newly built TARS sentence
        """
        tokenized = sentence.to_tokenized_string()

        if self.prefix:
            pair_text = f"{label} {self.separator} {tokenized}"
            # original tokens now sit behind the label words and the separator words
            shift = len(label.split(" ")) + len(self.separator.split(" "))
        else:
            pair_text = f"{tokenized} {self.separator} {label}"
            shift = 0

        # fresh sentence: no token carries a label yet (i.e. everything is O by default)
        tars_sentence = Sentence(pair_text, use_tokenizer=False)

        matching_labels = (
            candidate for candidate in sentence.get_labels(self.label_type) if candidate.value == label
        )
        for entity_label in matching_labels:
            shifted_tokens = [tars_sentence.get_token(tok.idx + shift) for tok in entity_label.span]
            tars_sentence.add_complex_label(
                self.static_label_type,
                SpanLabel(Span(shifted_tokens), value="entity"),
            )

        return tars_sentence
Esempio n. 3
0
    def add_entity_markers(self, sentence, span_1, span_2):
        """Rebuild ``sentence`` with entity marker tokens inserted around two spans.

        ``<e1>``/``</e1>`` are placed around ``span_1`` and ``<e2>``/``</e2>`` around ``span_2``.
        The returned spans each cover only the *opening* marker token of the corresponding
        entity in the expanded sentence.

        :param sentence: the sentence to expand (iterated token by token)
        :param span_1: first entity span; indexed with [0] and [-1], so it must be non-empty
        :param span_2: second entity span; indexed with [0] and [-1], so it must be non-empty
        :return: ``(expanded_sentence, (first, second))``, where the inner pair is ordered by
            which entity's opening marker appears first in the text
        :raises ValueError: if the first token of either span is not found in ``sentence``
            (previously this surfaced as an ``UnboundLocalError``)
        """
        pieces = []
        entity_one_is_first = None
        # 1-based position in ``pieces`` of each entity's opening marker
        span_1_startid = None
        span_2_startid = None

        for token in sentence:
            if token == span_2[0]:
                if entity_one_is_first is None:
                    entity_one_is_first = False
                pieces.append("<e2>")
                span_2_startid = len(pieces)
            if token == span_1[0]:
                pieces.append("<e1>")
                if entity_one_is_first is None:
                    entity_one_is_first = True
                span_1_startid = len(pieces)

            pieces.append(token.text)

            if token == span_1[-1]:
                pieces.append("</e1>")
            if token == span_2[-1]:
                pieces.append("</e2>")

        if span_1_startid is None or span_2_startid is None:
            raise ValueError("both entity spans must start with a token of the given sentence")

        # every piece is prefixed with a single space, exactly as the incremental build did
        text = "".join(f" {piece}" for piece in pieces)
        expanded_sentence = Sentence(text, use_tokenizer=False)

        # convert the 1-based piece positions into indices of the expanded sentence;
        # NOTE(review): assumes each piece becomes exactly one token (no spaces in token texts)
        expanded_span_1 = Span([expanded_sentence[span_1_startid - 1]])
        expanded_span_2 = Span([expanded_sentence[span_2_startid - 1]])

        if entity_one_is_first:
            ordered_spans = (expanded_span_1, expanded_span_2)
        else:
            ordered_spans = (expanded_span_2, expanded_span_1)

        return expanded_sentence, ordered_spans
Esempio n. 4
0
        def get_token_span(self, span: Tuple[int, int]) -> Span:
            """
            Return the Span covering the tokens between two positions.

            ``span[0]`` must equal the start position of some token and ``span[1]`` the end
            position of some token (presumably character offsets). The result runs from the
            first of these tokens through the second, inclusive. If either position does not
            coincide with a token boundary, the underlying ``list.index`` lookup raises
            ValueError.

            :param span: start and end position of the requested interval
            :return: a Span over the tokens in the interval
            """
            requested_start, requested_end = span[0], span[1]
            first_token = self.__tokens_start_pos.index(requested_start)
            last_token = self.__tokens_end_pos.index(requested_end)
            covered = self.tokens[first_token : last_token + 1]
            return Span(covered)
Esempio n. 5
0
def mock_ner_span(text, tag, start, end):
    """Create a stand-in NER ``Span`` for tests.

    The span carries ``tag`` and the given ``start``/``end`` positions. It holds a single
    token whose text is ``text[start:end]``.
    """
    mock = Span([])
    mock.tag = tag
    mock.start_pos, mock.end_pos = start, end
    mock.tokens = [Token(text[start:end])]
    return mock
Esempio n. 6
0
def mock_ner_span(text, tag, start, end):
    """Create a stand-in NER ``Span`` for tests.

    ``tag`` is stored under the "class" label type. This relies on ``set_label`` returning
    the span itself, since its result is what gets returned. The span also carries the given
    ``start``/``end`` positions and a single token whose text is ``text[start:end]``.
    """
    mock = Span([]).set_label("class", tag)
    mock.start_pos, mock.end_pos = start, end
    mock.tokens = [Token(text[start:end])]
    return mock
Esempio n. 7
0
def mock_ner_span(tag, start, end):
    """Create a token-less stand-in NER ``Span`` carrying ``tag`` and ``start``/``end`` positions."""
    mock = Span([])
    mock.tag = tag
    mock.start_pos, mock.end_pos = start, end
    return mock
Esempio n. 8
0
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ):
        """
        Predicts labels for current batch with CRF or Softmax.

        Empty sentences are skipped. Before new predictions are added, previously predicted
        labels of type ``label_name`` are removed from each processed sentence.

        :param sentences: a Sentence, a list of Sentences, or a flair Dataset
        :param mini_batch_size: batch size for inference
        :param return_probabilities_for_all_classes: whether to also attach the tag
            probability distribution to every token
        :param verbose: whether to use progress bar
        :param label_name: label type to write predictions to; defaults to ``self.tag_type``
        :param return_loss: whether to compute and return the loss
        :param embedding_storage_mode: passed to ``store_embeddings``; "none" (default), "cpu" or "gpu"
        :return: ``(overall_loss, label_count)`` if ``return_loss`` is set, otherwise ``None``.
            When there is nothing to predict, ``(0, 0)`` is returned if ``return_loss`` is set
            (so callers indexing the tuple keep working); otherwise the (possibly filtered)
            ``sentences`` are returned.
        """
        if label_name is None:
            label_name = self.tag_type

        with torch.no_grad():
            if not sentences:
                return (0, 0) if return_loss else sentences

            # make sure its a list
            if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
                sentences = [sentences]

            # filter empty sentences
            sentences = [sentence for sentence in sentences if len(sentence) > 0]

            # reverse sort all sequences by their length
            reordered_sentences = sorted(sentences, key=lambda s: len(s), reverse=True)

            if not reordered_sentences:
                return (0, 0) if return_loss else sentences

            dataloader = DataLoader(
                dataset=FlairDatapointDataset(reordered_sentences),
                batch_size=mini_batch_size,
            )
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            label_count = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                # stop if all sentences are empty
                if not batch:
                    continue

                # get features from forward propagation
                features, gold_labels = self.forward(batch)

                # remove previously predicted labels of this type
                for sentence in batch:
                    sentence.remove_labels(label_name)

                # if return_loss, get loss value
                if return_loss:
                    loss = self._calculate_loss(features, gold_labels)
                    overall_loss += loss[0]
                    label_count += loss[1]

                # Sort batch in same way as forward propagation so predictions line up
                lengths = torch.LongTensor([len(sentence) for sentence in batch])
                lengths = lengths.sort(dim=0, descending=True)
                batch = [batch[i] for i in lengths.indices]

                # make predictions
                if self.use_crf:
                    predictions, all_tags = self.viterbi_decoder.decode(features, return_probabilities_for_all_classes)
                else:
                    predictions, all_tags = self._standard_inference(
                        features, batch, return_probabilities_for_all_classes
                    )

                # add predictions to Sentence
                for sentence, sentence_predictions in zip(batch, predictions):

                    # BIOES-labels need to be converted to spans
                    if self.predict_spans:
                        sentence_tags = [label.value for label in sentence_predictions]
                        sentence_scores = [label.score for label in sentence_predictions]
                        predicted_spans = get_spans_from_bio(sentence_tags, sentence_scores)
                        for predicted_span in predicted_spans:
                            # predicted_span[0] holds the token indices, [1] the score, [2] the tag
                            span = Span(sentence[predicted_span[0][0] : predicted_span[0][-1] + 1])
                            sentence.add_complex_label(
                                typename=label_name,
                                label=SpanLabel(span=span, value=predicted_span[2], score=predicted_span[1]),
                            )
                    # token-labels can be added directly
                    else:
                        for token, label in zip(sentence.tokens, sentence_predictions):
                            token.add_tag_label(label_name, label)

                # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
                for sentence, sent_all_tags in zip(batch, all_tags):
                    for token, token_all_tags in zip(sentence.tokens, sent_all_tags):
                        token.add_tags_proba_dist(label_name, token_all_tags)

            store_embeddings(sentences, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count
Esempio n. 9
0
    def token_list_to_sentence(self, token_list: conllu.TokenList) -> Sentence:
        """Convert a parsed CoNLL-U token list into a flair ``Sentence``.

        For each token, every field named in ``self.token_annotation_fields`` becomes a token
        label.
        - Dict-valued fields contribute one label per key, using the key as the label type.
        - Fields that are absent from the token, or whose value is None, are skipped.

        ``SpaceAfter=No`` in the MISC column turns off the token's trailing whitespace.
        Sentence metadata then contributes the sentence id and any relation annotations.
        Finally, every span found in a token label layer whose name starts with "ner" is
        re-added as a sentence-level "entity" label.

        :param token_list: a token list produced by the ``conllu`` parser
        :return: the constructed ``Sentence``
        """
        sentence: Sentence = Sentence()

        # Build the sentence tokens and add the annotations.
        for conllu_token in token_list:
            token = Token(conllu_token["form"])

            for field in self.token_annotation_fields:
                field_value: Any = conllu_token.get(field)
                # An absent or null field carries no annotation. Indexing would raise KeyError
                # for an absent field, and str(None) would add a literal "None" label
                # (presumably conllu yields None for "_" in nullable columns).
                if field_value is None:
                    continue
                if isinstance(field_value, dict):
                    # For fields that contain key-value annotations,
                    # we add the key as label type-name and the value as the label value.
                    for key, value in field_value.items():
                        token.add_label(typename=key, value=str(value))
                else:
                    token.add_label(typename=field, value=str(field_value))

            if conllu_token.get("misc") is not None:
                space_after: Optional[str] = conllu_token["misc"].get("SpaceAfter")
                if space_after == "No":
                    token.whitespace_after = False

            sentence.add_token(token)

        if "sentence_id" in token_list.metadata:
            sentence.add_label("sentence_id", token_list.metadata["sentence_id"])

        if "relations" in token_list.metadata:
            for head_start, head_end, tail_start, tail_end, label in token_list.metadata["relations"]:
                # head and tail span indices are 1-indexed and end index is inclusive,
                # so [start - 1 : end] is the equivalent 0-indexed half-open slice
                head = Span(sentence.tokens[head_start - 1 : head_end])
                tail = Span(sentence.tokens[tail_start - 1 : tail_end])

                sentence.add_complex_label("relation", RelationLabel(value=label, head=head, tail=tail))

        # determine all NER label types in sentence (first-seen order) and add all NER spans
        # as sentence-level labels
        ner_label_types = []
        for token in sentence.tokens:
            for annotation in token.annotation_layers.keys():
                if annotation.startswith("ner") and annotation not in ner_label_types:
                    ner_label_types.append(annotation)

        for label_type in ner_label_types:
            for span in sentence.get_spans(label_type):
                sentence.add_complex_label(
                    "entity",
                    label=SpanLabel(span=span, value=span.tag, score=span.score),
                )

        return sentence
    def get_spans(self, tag_type: str, min_score=-1) -> List[Span]:
        """
        Decode the ``tag_type`` token tags into spans.

        Tags are read as BIOES:
        - Missing ("") or "O" tags are outside.
        - A tag without a BIOES prefix is treated as a single-token "S-" tag.
        - Any tag not in ``self.tag_set`` is treated as outside.

        A new span starts at every "B-"/"S-" tag, and also when an "S-" tag is followed by an
        in-span tag of a different type.

        Each span is scored with the average score of its tokens. Its tag is the type with the
        highest accumulated weight: span-opening tokens weigh 1.1, others 1.0, and ties go to
        the type seen first.

        :param tag_type: the token label layer to decode
        :param min_score: spans are kept only if their score is strictly greater than this
        :return: the decoded spans in sentence order
        """
        spans: List[Span] = []

        def close_span(span_tokens, type_weights):
            # score the finished span and keep it only if it beats ``min_score``
            scores = [t.get_tag(tag_type).score for t in span_tokens]
            span_score = sum(scores) / len(scores)
            if span_score > min_score:
                # max() returns the first maximal key in insertion order, matching the
                # previous stable descending sort
                best_type = max(type_weights, key=type_weights.get)
                spans.append(Span(span_tokens, tag=best_type, score=span_score))

        current_span = []
        tags = defaultdict(float)

        previous_tag_value: str = "O"
        for token in self:

            tag: Label = token.get_tag(tag_type)
            tag_value = tag.value

            # non-set tags are OUT tags
            if tag_value == "" or tag_value == "O":
                tag_value = "O-"

            # anything that is not a BIOES tag is a SINGLE tag
            if tag_value[0:2] not in ["B-", "I-", "O-", "E-", "S-"]:
                tag_value = "S-" + tag_value

            # anything that is not in the given tag_set is OUT
            if tag_value not in self.tag_set:
                tag_value = "O-"

            # anything that is not OUT is IN
            in_span = tag_value[0:2] != "O-"

            # single and begin tags start a new span, as does a type change right after a single tag
            starts_new_span = tag_value[0:2] in ["B-", "S-"] or (
                previous_tag_value[0:2] == "S-" and previous_tag_value[2:] != tag_value[2:] and in_span
            )

            if (starts_new_span or not in_span) and current_span:
                close_span(current_span, tags)
                current_span = []
                tags = defaultdict(float)

            if in_span:
                current_span.append(token)
                tags[tag_value[2:]] += 1.1 if starts_new_span else 1.0

            # remember previous tag
            previous_tag_value = tag_value

        if current_span:
            close_span(current_span, tags)

        return spans
Esempio n. 11
0
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size=32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
        most_probable_first: bool = True,
    ):
        """
        Predict sequence tags for Named Entity Recognition task.

        For every sentence, one TARS-formatted sentence is built per candidate label and
        tagged by ``self.tars_model``. The detected spans are then mapped back onto the
        original sentence.

        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; bigger is usually faster but uses more
            memory, up to a point where it has no further effect
        :param return_probabilities_for_all_classes: accepted for interface compatibility; not used here
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
            you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param most_probable_first: if True, detected spans are added greedily by decreasing
            score. A span is skipped if any of its tokens is already covered by a previously
            added span. NOTE(review): when False, no spans are added at all; confirm this is
            intended.
        :return: ``(overall_loss, overall_count)`` if ``return_loss`` is set, otherwise ``None``.
            For empty input, ``(0, 0)`` is returned if ``return_loss`` is set, else ``sentences``.
        """
        if label_name is None:
            label_name = self.get_current_label_type()

        if not sentences:
            return (0, 0) if return_loss else sentences

        if not isinstance(sentences, list):
            sentences = [sentences]

        reordered_sentences = sorted(sentences, key=lambda s: len(s), reverse=True)

        dataloader = DataLoader(
            dataset=FlairDatapointDataset(reordered_sentences),
            batch_size=mini_batch_size,
        )

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        with torch.no_grad():
            # candidate labels do not change while predicting, so decode them once
            all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

            for batch in dataloader:

                batch = self._filter_empty_sentences(batch)
                # skip batches in which all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    sentence.remove_labels(label_name)

                    # maps each detected span label (renamed to its candidate label) to its score
                    all_detected = {}
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        loss_and_count = self.tars_model.predict(
                            tars_sentence,
                            label_name=label_name,
                            return_loss=True,
                        )

                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for predicted in tars_sentence.get_labels(label_name):
                            predicted.value = label
                            all_detected[predicted] = predicted.score

                    if most_probable_first:
                        # idx values of ORIGINAL-sentence tokens already covered by an added span.
                        # TARS-sentence indices must not be used here: they are shifted by a
                        # label-dependent offset, so they are not comparable across labels.
                        already_set_indices = set()

                        # highest score first; ascending sort + reverse keeps the previous tie order
                        ranked = sorted(all_detected.items(), key=lambda item: item[1])
                        ranked.reverse()
                        for predicted_label, _score in ranked:
                            # number of label + separator words in front of the original text
                            label_length = (
                                len(predicted_label.value.split(" ")) + len(self.separator.split(" "))
                                if self.prefix
                                else 0
                            )

                            # map the TARS tokens back onto the original sentence (None if absent)
                            predicted_span = [
                                sentence.get_token(token.idx - label_length) for token in predicted_label.span
                            ]

                            # only add if all tokens exist and none is covered by a better span
                            if any(
                                token is None or token.idx in already_set_indices for token in predicted_span
                            ):
                                continue

                            already_set_indices.update(token.idx for token in predicted_span)
                            sentence.add_complex_label(
                                label_name,
                                label=SpanLabel(
                                    Span(predicted_span),
                                    value=predicted_label.value,
                                    score=predicted_label.score,
                                ),
                            )

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count