Example #1
0
    def _get_tars_formatted_sentence(self, label, sentence):
        """Build the TARS input sentence that pairs ``label`` with ``sentence``.

        The label text and the tokenized sentence are joined with
        ``self.separator``. The label goes in front when ``self.prefix`` is set,
        otherwise it goes after the sentence. Every span in ``sentence`` whose
        ``self.label_type`` value equals ``label`` is re-created on the new
        sentence and tagged ``"entity"`` under ``self.static_label_type``. All
        other tokens stay unlabeled.

        :param label: the candidate label text
        :param sentence: the original sentence
        :return: the newly built TARS sentence
        """
        tokenized = sentence.to_tokenized_string()

        if self.prefix:
            pair_text = f"{label} {self.separator} {tokenized}"
            # Original tokens are shifted right by the label words plus the
            # separator words. This assumes use_tokenizer=False splits on
            # single spaces (TODO confirm).
            offset = len(label.split(" ")) + len(self.separator.split(" "))
        else:
            pair_text = f"{tokenized} {self.separator} {label}"
            # The label is appended after the text, so token positions are unchanged.
            offset = 0

        tars_sentence = Sentence(pair_text, use_tokenizer=False)

        matching = (
            gold for gold in sentence.get_labels(self.label_type)
            if gold.value == label
        )
        for gold in matching:
            shifted_tokens = [
                tars_sentence.get_token(tok.idx + offset)
                for tok in gold.data_point
            ]
            Span(shifted_tokens).add_label(self.static_label_type, value="entity")

        return tars_sentence
Example #2
0
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size=32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
        most_probable_first: bool = True,
    ):
        """
        Predict entity spans for every label of the current label dictionary.

        For each sentence and each candidate label, a TARS-formatted sentence is
        built and passed to ``self.tars_model``. Spans it detects are mapped back
        onto the original sentence and stored under ``label_name``. Predictions
        are written onto the sentences in place, and existing labels of type
        ``label_name`` are removed first.

        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch. Bigger is usually faster but
            consumes more memory, up to a point where it has no more effect.
        :param return_probabilities_for_all_classes: accepted for interface
            compatibility; not used by this method
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is
            predicted (defaults to ``self.get_current_label_type()``)
        :param return_loss: set to True to return ``(overall_loss, overall_count)``
        :param embedding_storage_mode: default is 'none' which is always best. Only
            set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the
            generated embeddings in CPU or GPU memory respectively.
        :param most_probable_first: if True, detected spans are added in descending
            score order. A span is skipped when any of its tokens is already covered
            by a span that was added earlier.
            NOTE(review): when False, detected spans are collected but never added
            to the sentence.
        :return: ``(overall_loss, overall_count)`` if ``return_loss`` is True,
            otherwise None. An empty ``sentences`` argument is returned unchanged
            unless ``return_loss`` is set, in which case ``(0, 0)`` is returned.
        """
        if label_name is None:
            label_name = self.get_current_label_type()

        if not sentences:
            # Keep the (loss, count) contract for callers that unpack the result.
            return (0, 0) if return_loss else sentences

        if not isinstance(sentences, list):
            sentences = [sentences]

        # The candidate labels do not depend on the sentence, so decode them once.
        all_labels = [
            label.decode("utf-8")
            for label in self.get_current_label_dictionary().idx2item
        ]

        reordered_sentences = sorted(sentences, key=len, reverse=True)

        dataloader = DataLoader(
            dataset=FlairDatapointDataset(reordered_sentences),
            batch_size=mini_batch_size,
        )

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        with torch.no_grad():
            for batch in dataloader:
                batch = self._filter_empty_sentences(batch)
                # skip batches in which every sentence is empty
                if not batch:
                    continue

                for sentence in batch:
                    # always remove tags first
                    sentence.remove_labels(label_name)

                    # Run the TARS model once per candidate label and collect
                    # every detected span together with its score.
                    all_detected = {}
                    for candidate in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(
                            candidate, sentence)

                        loss_and_count = self.tars_model.predict(
                            tars_sentence,
                            label_name=label_name,
                            return_loss=True,
                        )
                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for predicted in tars_sentence.get_labels(label_name):
                            # Replace the binary TARS value with the candidate label text.
                            predicted.value = candidate
                            all_detected[predicted] = predicted.score

                    if most_probable_first:
                        # Token indices of the ORIGINAL sentence that already carry a span.
                        already_set_indices = set()

                        # Sort ascending, then iterate in reverse, so the highest score
                        # comes first. Ties are visited in reverse detection order.
                        ranked = sorted(all_detected.items(),
                                        key=lambda item: item[1])
                        for detected, _score in reversed(ranked):
                            # TARS tokens are shifted right by the label-prefix length.
                            label_length = (
                                len(detected.value.split(" ")) +
                                len(self.separator.split(" "))
                                if self.prefix else 0)

                            corresponding_tokens = [
                                sentence.get_token(token.idx - label_length)
                                for token in detected.data_point
                            ]

                            # Skip spans that fall outside the original sentence or
                            # overlap a higher-scoring span. Compare ORIGINAL token
                            # indices, because label_length differs per candidate label.
                            if any(token is None or token.idx in already_set_indices
                                   for token in corresponding_tokens):
                                continue

                            already_set_indices.update(
                                token.idx for token in corresponding_tokens)
                            predicted_span = Span(corresponding_tokens)
                            predicted_span.add_label(label_name,
                                                     value=detected.value,
                                                     score=detected.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count