Code Example #1
File: model.py  Project: MaxDall/flair
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ):
        """
        Predicts the class labels for the given sentences. The labels are directly added to the sentences.
        :param sentences: list of sentences
        :param mini_batch_size: mini batch size to use
        :param return_probabilities_for_all_classes: return probabilities for all classes instead of only the best predicted one
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        """
        if label_name is None:
            label_name = self.label_type if self.label_type is not None else "label"

        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, DataPoint):
                sentences = [sentences]

            # filter empty sentences
            if isinstance(sentences[0], DataPoint):
                sentences = [
                    sentence for sentence in sentences if len(sentence) > 0
                ]
            if len(sentences) == 0:
                return sentences

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(range(len(sentences)),
                                         key=lambda k: len(sentences[k]),
                                         reverse=True)

            reordered_sentences: List[Union[DataPoint, str]] = [
                sentences[index] for index in rev_order_len_index
            ]

            dataloader = DataLoader(
                dataset=SentenceDataset(reordered_sentences),
                batch_size=mini_batch_size)
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            label_count = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(
                        f"Inferencing on batch {batch_no}")

                # stop if all sentences are empty
                if not batch:
                    continue

                scores, gold_labels, data_points, label_candidates = self.forward_pass(
                    batch, return_label_candidates=True)
                # remove previously predicted labels of this type
                for sentence in data_points:
                    sentence.remove_labels(label_name)

                if return_loss:
                    overall_loss += self._calculate_loss(scores,
                                                         gold_labels)[0]
                    label_count += len(label_candidates)

                # if anything could possibly be predicted
                if len(label_candidates) > 0:
                    if self.multi_label:
                        sigmoided = torch.sigmoid(
                            scores)  # size: (n_sentences, n_classes)
                        n_labels = sigmoided.size(1)
                        for s_idx, (data_point, label_candidate) in enumerate(
                                zip(data_points, label_candidates)):
                            for l_idx in range(n_labels):
                                label_value = self.label_dictionary.get_item_for_index(
                                    l_idx)
                                if label_value == 'O': continue
                                label_threshold = self._get_label_threshold(
                                    label_value)
                                label_score = sigmoided[s_idx, l_idx].item()
                                if label_score > label_threshold or return_probabilities_for_all_classes:
                                    label = label_candidate.spawn(
                                        value=label_value, score=label_score)
                                    data_point.add_complex_label(
                                        label_name, label)
                    else:
                        softmax = torch.nn.functional.softmax(scores, dim=-1)

                        if return_probabilities_for_all_classes:
                            n_labels = softmax.size(1)
                            for s_idx, (data_point,
                                        label_candidate) in enumerate(
                                            zip(data_points,
                                                label_candidates)):
                                for l_idx in range(n_labels):
                                    label_value = self.label_dictionary.get_item_for_index(
                                        l_idx)
                                    if label_value == 'O': continue
                                    label_score = softmax[s_idx, l_idx].item()
                                    label = label_candidate.spawn(
                                        value=label_value, score=label_score)
                                    data_point.add_complex_label(
                                        label_name, label)
                        else:
                            conf, idx = torch.max(softmax, dim=-1)
                            for data_point, label_candidate, c, i in zip(
                                    data_points, label_candidates, conf, idx):
                                label_value = self.label_dictionary.get_item_for_index(
                                    i.item())
                                if label_value == 'O': continue
                                label = label_candidate.spawn(
                                    value=label_value, score=c.item())
                                data_point.add_complex_label(label_name, label)

                store_embeddings(batch, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count
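A minimal usage sketch for this predict() API. This is hedged: it assumes the pretrained 'en-sentiment' text classifier is available for download; any flair classifier built on this base class should behave the same way.

from flair.data import Sentence
from flair.models import TextClassifier

# load a trained classifier; 'en-sentiment' is only an illustrative model id
classifier = TextClassifier.load('en-sentiment')

sentence = Sentence('The movie was surprisingly good.')

# predict() attaches the labels directly to the sentence object
classifier.predict(sentence, return_probabilities_for_all_classes=True)
for label in sentence.get_labels():
    print(label.value, label.score)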
Code Example #2
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ):
        """
        Predicts labels for the given sentences with CRF or softmax.
        :param sentences: list of sentences in the batch
        :param mini_batch_size: batch size for test data
        :param return_probabilities_for_all_classes: whether to return probabilities for all classes
        :param verbose: whether to use a progress bar
        :param label_name: which label to predict
        :param return_loss: whether to return the loss value
        :param embedding_storage_mode: determines where to store embeddings - can be "gpu", "cpu" or "none".
        """
        if label_name is None:
            label_name = self.tag_type

        with torch.no_grad():
            if not sentences:
                return sentences

            # make sure it's a list
            if not isinstance(sentences, list) and not isinstance(
                    sentences, flair.data.Dataset):
                sentences = [sentences]

            # filter empty sentences
            sentences = [
                sentence for sentence in sentences if len(sentence) > 0
            ]

            # reverse sort all sequences by their length
            reordered_sentences = sorted(sentences,
                                         key=lambda s: len(s),
                                         reverse=True)

            if len(reordered_sentences) == 0:
                return sentences

            dataloader = DataLoader(
                dataset=FlairDatapointDataset(reordered_sentences),
                batch_size=mini_batch_size,
            )
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            label_count = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(
                        f"Inferencing on batch {batch_no}")

                # stop if all sentences are empty
                if not batch:
                    continue

                # get features from forward propagation
                features, gold_labels = self.forward(batch)

                # remove previously predicted labels of this type
                for sentence in batch:
                    sentence.remove_labels(label_name)

                # if return_loss, get loss value
                if return_loss:
                    loss = self._calculate_loss(features, gold_labels)
                    overall_loss += loss[0]
                    label_count += loss[1]

                # Sort batch in same way as forward propagation
                lengths = torch.LongTensor(
                    [len(sentence) for sentence in batch])
                lengths = lengths.sort(dim=0, descending=True)
                batch = [batch[i] for i in lengths.indices]

                if self.use_crf:
                    predictions, all_tags = self.viterbi_decoder.decode(
                        features, return_probabilities_for_all_classes)
                else:
                    predictions, all_tags = self._standard_inference(
                        features, batch, return_probabilities_for_all_classes)

                for sentence, labels in zip(batch, predictions):
                    for token, label in zip(sentence.tokens, labels):
                        token.add_tag_label(label_name, label)

                # all_tags will be empty if return_probabilities_for_all_classes is False, so this loop is then skipped
                for (sentence, sent_all_tags) in zip(batch, all_tags):
                    for (token, token_all_tags) in zip(sentence.tokens,
                                                       sent_all_tags):
                        token.add_tags_proba_dist(label_name, token_all_tags)

            store_embeddings(sentences, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count
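A usage sketch for this sequence-tagger variant, assuming the pretrained 'ner' model can be downloaded. With use_crf=True the labels come from Viterbi decoding; otherwise from _standard_inference.

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load('ner')
sentence = Sentence('George Washington went to Washington.')

# token-level labels are attached in place; loss is only returned on request
tagger.predict(sentence)
for entity in sentence.get_spans('ner'):
    print(entity)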
Code Example #3
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        multi_class_prob: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ):
        """
        Predicts the class labels for the given sentences. The labels are directly added to the sentences.
        :param sentences: list of sentences
        :param mini_batch_size: mini batch size to use
        :param multi_class_prob: return probabilities for all classes in the multi-class case
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        """
        if label_name is None:
            label_name = self.label_type if self.label_type is not None else 'label'

        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, DataPoint):
                sentences = [sentences]

            # filter empty sentences
            if isinstance(sentences[0], Sentence):
                sentences = [
                    sentence for sentence in sentences if len(sentence) > 0
                ]
            if len(sentences) == 0:
                return sentences

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(range(len(sentences)),
                                         key=lambda k: len(sentences[k]),
                                         reverse=True)

            reordered_sentences: List[Union[DataPoint, str]] = [
                sentences[index] for index in rev_order_len_index
            ]

            dataloader = DataLoader(
                dataset=SentenceDataset(reordered_sentences),
                batch_size=mini_batch_size)
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(
                        f"Inferencing on batch {batch_no}")

                # stop if all sentences are empty
                if not batch:
                    continue

                scores = self.forward(batch)

                if return_loss:
                    overall_loss += self._calculate_loss(scores, batch)

                predicted_labels = self._obtain_labels(
                    scores, predict_prob=multi_class_prob)

                for (sentence, labels) in zip(batch, predicted_labels):
                    for label in labels:
                        if self.multi_label or multi_class_prob:
                            sentence.add_label(label_name, label.value,
                                               label.score)
                        else:
                            sentence.set_label(label_name, label.value,
                                               label.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss / batch_no
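In this older variant, multi_class_prob plays the role of return_probabilities_for_all_classes above. A sketch, assuming a flair release where this signature applies:

from flair.data import Sentence
from flair.models import TextClassifier

classifier = TextClassifier.load('en-sentiment')
sentence = Sentence('A disappointing sequel.')

# with multi_class_prob=True, one label per class is added with its probability
classifier.predict(sentence, multi_class_prob=True)
print(sentence.labels)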
Code Example #4
    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            all_tag_prob: bool = False,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
    ):
        """
        Predicts sequence tags for the Named Entity Recognition task.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
        up to a point beyond which it has no further effect.
        :param all_tag_prob: True to compute the score for each tag on each token,
        otherwise only the score of the best tag is returned
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        """
        if label_name is None:
            label_name = self.tag_type

        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, Sentence):
                sentences = [sentences]

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(
                range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True
            )

            reordered_sentences: List[Union[Sentence, str]] = [
                sentences[index] for index in rev_order_len_index
            ]

            dataloader = DataLoader(
                dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size
            )

            if self.use_crf:
                transitions = self.transitions.detach().cpu().numpy()
            else:
                transitions = None

            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                feature = self.forward(batch)

                if return_loss:
                    overall_loss += self._calculate_loss(feature, batch)

                tags, all_tags = self._obtain_labels(
                    feature=feature,
                    batch_sentences=batch,
                    transitions=transitions,
                    get_all_tags=all_tag_prob,
                )

                for (sentence, sent_tags) in zip(batch, tags):
                    for (token, tag) in zip(sentence.tokens, sent_tags):
                        token.add_tag_label(label_name, tag)

                # all_tags will be empty if all_tag_prob is set to False, so this loop is then skipped
                for (sentence, sent_all_tags) in zip(batch, all_tags):
                    for (token, token_all_tags) in zip(sentence.tokens, sent_all_tags):
                        token.add_tags_proba_dist(label_name, token_all_tags)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss / batch_no
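A sketch of the all_tag_prob option, assuming an older flair release where Token.get_tags_proba_dist() is the accessor that pairs with the add_tags_proba_dist() call used above:

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load('ner')
sentence = Sentence('Berlin is a city.')

# request the full per-token tag distribution in addition to the best tag
tagger.predict(sentence, all_tag_prob=True)
for token in sentence:
    # each distribution entry is a Label carrying a value and a score
    print(token.text, token.get_tags_proba_dist('ner'))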
Code Example #5
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size=32,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ):
        """
        Predicts sequence tags for the Named Entity Recognition task.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
        up to a point beyond which it has no further effect.
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        """
        if label_name is None:
            label_name = self.get_current_tag_type()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # set context if not set already
        previous_sentence = None
        for sentence in sentences:
            if sentence.is_context_set(): continue
            sentence._previous_sentence = previous_sentence
            sentence._next_sentence = None
            if previous_sentence: previous_sentence._next_sentence = sentence
            previous_sentence = sentence

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)),
                                     key=lambda k: len(sentences[k]),
                                     reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [
            sentences[index] for index in rev_order_len_index
        ]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences),
                                batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(
                        f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    for token in sentence:
                        token.remove_labels(label_name)

                    all_labels = [
                        label.decode("utf-8")
                        for label in self.get_current_tag_dictionary().idx2item
                    ]

                    all_detected = {}
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(
                            label, sentence)

                        label_length = 0 if not self.prefix else len(
                            label.split(" ")) + len(self.separator.split(" "))

                        loss_and_count = self.tars_model.predict(
                            tars_sentence,
                            label_name=label_name,
                            all_tag_prob=True,
                            return_loss=True)
                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for span in tars_sentence.get_spans(label_name):
                            span.set_label('tars_temp_label', label)
                            all_detected[span] = span.score

                        for span in tars_sentence.get_spans(label_name):
                            for token in span:
                                corresponding_token = sentence.get_token(
                                    token.idx - label_length)
                                if corresponding_token is None: continue
                                if corresponding_token.get_tag(label_name).value != '' and \
                                        corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                    continue
                                corresponding_token.add_tag(
                                    label_name,
                                    token.get_tag(label_name).value + label,
                                    token.get_tag(label_name).score,
                                )

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
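A zero-shot usage sketch for this TARS tagger. Hedged: the task-setup call and its argument order vary between flair releases, so treat the add_and_switch_to_new_task() signature shown here as an assumption.

from flair.data import Sentence
from flair.models import TARSTagger

tars = TARSTagger.load('tars-ner')

# declare the label names to probe for; each one is run through the
# underlying tars_model.predict() call shown above
tars.add_and_switch_to_new_task('demo', ['person', 'city'], label_type='ner')

sentence = Sentence('Angela lives in Berlin.')
tars.predict(sentence)
print(sentence.to_tagged_string())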
Code Example #6
File: tars_model.py  Project: sckevmit/flair
    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size=32,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
        label_threshold: float = 0.5,
        multi_label: Optional[bool] = None,
    ):
        """
        Predicts the class labels for the given sentences using the current TARS task. The labels are
        directly added to the sentences.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
        up to a point beyond which it has no further effect.
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param label_threshold: only add labels whose confidence is above this threshold
        :param multi_label: set to True to allow several labels per sentence; defaults to the current task's setting
        """
        if not label_name:
            label_name = self.get_current_label_type()

        if multi_label is None:
            multi_label = self.is_current_task_multi_label()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # set context if not set already
        previous_sentence = None
        for sentence in sentences:
            if sentence.is_context_set(): continue
            sentence._previous_sentence = previous_sentence
            sentence._next_sentence = None
            if previous_sentence: previous_sentence._next_sentence = sentence
            previous_sentence = sentence

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)),
                                     key=lambda k: len(sentences[k]),
                                     reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [
            sentences[index] for index in rev_order_len_index
        ]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences),
                                batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(
                        f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    sentence.remove_labels(label_name)

                    all_labels = [
                        label.decode("utf-8") for label in
                        self.get_current_label_dictionary().idx2item
                    ]

                    best_label = None
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(
                            label, sentence)

                        loss_and_count = self.tars_model.predict(
                            tars_sentence,
                            label_name=label_name,
                            return_loss=True,
                            return_probabilities_for_all_classes=True
                            if label_threshold < 0.5 else False,
                        )

                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        # add all labels that according to TARS match the text and are above threshold
                        for predicted_tars_label in tars_sentence.get_labels(
                                label_name):
                            if predicted_tars_label.value == self.LABEL_MATCH \
                                    and predicted_tars_label.score > label_threshold:
                                # do not add labels below confidence threshold
                                sentence.add_label(label_name, label,
                                                   predicted_tars_label.score)

                    # only use label with highest confidence if enforcing single-label predictions
                    if not multi_label:
                        if len(sentence.get_labels(label_name)) > 0:
                            # get all label scores and do an argmax to get the best label
                            label_scores = torch.tensor([
                                label.score
                                for label in sentence.get_labels(label_name)
                            ],
                                                        dtype=torch.float)
                            best_label = sentence.get_labels(label_name)[
                                torch.argmax(label_scores)]

                            # remove previously added labels and only add the best label
                            sentence.remove_labels(label_name)
                            sentence.add_label(typename=label_name,
                                               value=best_label.value,
                                               score=best_label.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
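A sketch for this classification variant. Hedged: predict_zero_shot() is the convenience entry point in recent flair releases and internally drives the predict() shown above; the 'tars-base' checkpoint name is assumed to be available.

from flair.data import Sentence
from flair.models import TARSClassifier

tars = TARSClassifier.load('tars-base')
sentence = Sentence('The match ended in a penalty shoot-out.')

# zero-shot classification over caller-supplied candidate labels;
# each candidate is scored via the predict() method shown above
tars.predict_zero_shot(sentence, ['sports', 'politics', 'science'])
print(sentence.get_labels())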