def predict(
    self,
    sentences: Union[List[Sentence], Sentence, List[str], str],
    mini_batch_size=32,
    embedding_storage_mode="none",
    all_tag_prob: bool = False,
    verbose: bool = False,
    use_tokenizer: Union[bool, Callable[[str], List[Token]]] = space_tokenizer,
) -> List[Sentence]:
    """
    Predict sequence tags for Named Entity Recognition task.

    :param sentences: a Sentence or a string or a List of Sentence or a List of string.
    :param mini_batch_size: size of the minibatch, usually bigger is more rapid but consume more
        memory, up to a point when it has no more effect.
    :param embedding_storage_mode: 'none' for the minimum memory footprint, 'cpu' to store
        embeddings in Ram, 'gpu' to store embeddings in GPU memory.
    :param all_tag_prob: True to compute the score for each tag on each token,
        otherwise only the score of the best tag is returned
    :param verbose: set to True to display a progress bar
    :param use_tokenizer: a custom tokenizer when string are provided (default is space based tokenizer).
    :return: List of Sentence enriched by the predicted tags
    """
    with torch.no_grad():
        # empty input (empty list / empty string / None) is returned unchanged
        if not sentences:
            return sentences

        # normalize a single Sentence or string into a one-element list
        if isinstance(sentences, (Sentence, str)):
            sentences = [sentences]

        if (flair.device.type == "cuda") and embedding_storage_mode == "cpu":
            log.warning(
                "You are inferring on GPU with parameter 'embedding_storage_mode' set to 'cpu'. "
                "This option will slow down your inference, usually 'none' (default value) "
                "is a better choice."
            )

        # reverse sort all sequences by their length (longest first) so that
        # similarly-sized sentences end up in the same mini-batch
        rev_order_len_index = sorted(
            range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True
        )
        # argsort of the permutation above = inverse permutation, used to
        # restore the caller's original ordering at the end
        original_order_index = sorted(
            range(len(rev_order_len_index)), key=lambda k: rev_order_len_index[k]
        )

        reordered_sentences: List[Union[Sentence, str]] = [
            sentences[index] for index in rev_order_len_index
        ]

        if isinstance(sentences[0], Sentence):
            # remove previous embeddings
            store_embeddings(reordered_sentences, "none")
            dataset = SentenceDataset(reordered_sentences)
        else:
            # raw strings are tokenized lazily by the dataset
            dataset = StringDataset(
                reordered_sentences, use_tokenizer=use_tokenizer
            )
        dataloader = DataLoader(
            dataset=dataset, batch_size=mini_batch_size, collate_fn=lambda x: x
        )

        # CRF transition scores are needed on CPU for label decoding
        if self.use_crf:
            transitions = self.transitions.detach().cpu().numpy()
        else:
            transitions = None

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        results: List[Sentence] = []
        for i, batch in enumerate(dataloader):

            if verbose:
                dataloader.set_description(f"Inferencing on batch {i}")

            # keep every sentence (even empty ones) in the results so the
            # output list aligns 1:1 with the input
            results += batch
            batch = self._filter_empty_sentences(batch)

            # skip forward pass if all sentences in this batch are empty
            if not batch:
                continue

            feature: torch.Tensor = self.forward(batch)
            tags, all_tags = self._obtain_labels(
                feature=feature,
                batch_sentences=batch,
                transitions=transitions,
                get_all_tags=all_tag_prob,
            )

            # annotate tokens in place with the best tag
            for (sentence, sent_tags) in zip(batch, tags):
                for (token, tag) in zip(sentence.tokens, sent_tags):
                    token.add_tag_label(self.tag_type, tag)

            # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
            for (sentence, sent_all_tags) in zip(batch, all_tags):
                for (token, token_all_tags) in zip(sentence.tokens, sent_all_tags):
                    token.add_tags_proba_dist(self.tag_type, token_all_tags)

            # clearing token embeddings to save memory
            store_embeddings(batch, storage_mode=embedding_storage_mode)

        # restore the original input order
        results = [results[index] for index in original_order_index]
        assert len(sentences) == len(results)
        return results
def predict(
    self,
    sentences: Union[List[Sentence], Sentence, List[str], str],
    mini_batch_size: int = 32,
    embedding_storage_mode="none",
    multi_class_prob: bool = False,
    verbose: bool = False,
    use_tokenizer: Union[bool, Callable[[str], List[Token]]] = space_tokenizer,
) -> List[Sentence]:
    """
    Predicts the class labels for the given sentences. The labels are directly added to the sentences.

    :param sentences: list of sentences
    :param mini_batch_size: mini batch size to use
    :param embedding_storage_mode: 'none' for the minimum memory footprint, 'cpu' to store
        embeddings in Ram, 'gpu' to store embeddings in GPU memory.
    :param multi_class_prob: return probability for all class for multiclass
    :param verbose: set to True to display a progress bar
    :param use_tokenizer: a custom tokenizer when string are provided (default is space based tokenizer).
    :return: the list of sentences containing the labels
    """
    with torch.no_grad():
        # empty input (empty list / empty string / None) is returned unchanged
        if not sentences:
            return sentences

        # normalize a single Sentence or string into a one-element list
        if isinstance(sentences, (Sentence, str)):
            sentences = [sentences]

        if (flair.device.type == "cuda") and embedding_storage_mode == "cpu":
            log.warning(
                "You are inferring on GPU with parameter 'embedding_storage_mode' set to 'cpu'. "
                "This option will slow down your inference, usually 'none' (default value) "
                "is a better choice."
            )

        # reverse sort all sequences by their length (longest first) so that
        # similarly-sized sentences end up in the same mini-batch
        rev_order_len_index = sorted(
            range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True
        )
        # argsort of the permutation above = inverse permutation, used to
        # restore the caller's original ordering at the end
        original_order_index = sorted(
            range(len(rev_order_len_index)), key=lambda k: rev_order_len_index[k]
        )

        reordered_sentences: List[Union[Sentence, str]] = [
            sentences[index] for index in rev_order_len_index
        ]

        if isinstance(sentences[0], Sentence):
            # remove previous embeddings
            store_embeddings(reordered_sentences, "none")
            dataset = SentenceDataset(reordered_sentences)
        else:
            # raw strings are tokenized lazily by the dataset
            dataset = StringDataset(
                reordered_sentences, use_tokenizer=use_tokenizer
            )
        dataloader = DataLoader(
            dataset=dataset, batch_size=mini_batch_size, collate_fn=lambda x: x
        )

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        results: List[Sentence] = []
        for i, batch in enumerate(dataloader):

            if verbose:
                dataloader.set_description(f"Inferencing on batch {i}")

            # keep every sentence (even empty ones) in the results so the
            # output list aligns 1:1 with the input
            results += batch
            batch = self._filter_empty_sentences(batch)

            # skip forward pass if all sentences in this batch are empty
            if not batch:
                continue

            scores = self.forward(batch)
            predicted_labels = self._obtain_labels(
                scores, predict_prob=multi_class_prob
            )

            # annotate sentences in place with the predicted labels
            for (sentence, labels) in zip(batch, predicted_labels):
                sentence.labels = labels

            # clearing token embeddings to save memory
            store_embeddings(batch, storage_mode=embedding_storage_mode)

        # restore the original input order
        results = [results[index] for index in original_order_index]
        assert len(sentences) == len(results)
        return results