def _get_tars_formatted_sentence(self, label, sentence):
    original_text = sentence.to_tokenized_string()

    # place the label before or after the original text, joined by the separator
    label_text_pair = (
        f"{label} {self.separator} {original_text}"
        if self.prefix
        else f"{original_text} {self.separator} {label}"
    )

    # if the label is prefixed, all token indices shift by the length of "<label> <separator>"
    label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

    # make a TARS sentence where all labels are O by default
    tars_sentence = Sentence(label_text_pair, use_tokenizer=False)

    # copy over gold spans that match the queried label, marking them with the
    # generic "entity" label at their shifted positions
    for entity_label in sentence.get_labels(self.label_type):
        if entity_label.value == label:
            new_span = Span(
                [tars_sentence.get_token(token.idx + label_length) for token in entity_label.data_point]
            )
            new_span.add_label(self.static_label_type, value="entity")

    return tars_sentence
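
# Illustration (not part of the original source): assuming prefix=True and
# separator="[SEP]", formatting the sentence "George lives in Berlin" for the
# label "person" yields the TARS input "person [SEP] George lives in Berlin".
# A gold "person" span over "George" is copied to the shifted token position
# and tagged with the generic "entity" label, so the underlying binary tagger
# only has to decide entity vs. non-entity for one candidate label at a time.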
def predict(
    self,
    sentences: Union[List[Sentence], Sentence],
    mini_batch_size=32,
    return_probabilities_for_all_classes: bool = False,
    verbose: bool = False,
    label_name: Optional[str] = None,
    return_loss=False,
    embedding_storage_mode="none",
    most_probable_first: bool = True,
):
    """
    Predict sequence tags for the Named Entity Recognition task.
    :param sentences: a Sentence or a List of Sentence
    :param mini_batch_size: size of the minibatch; larger batches are usually faster but
        consume more memory, up to a point where increasing the size has no further effect
    :param return_probabilities_for_all_classes: set to True to compute the score for each
        tag on each token, otherwise only the score of the best tag is returned
    :param verbose: set to True to display a progress bar
    :param label_name: set this to change the name of the label type that is predicted
    :param return_loss: set to True to also return the loss
    :param embedding_storage_mode: default is 'none', which is always best. Only set to
        'cpu' or 'gpu' if you wish to not only predict but also keep the generated
        embeddings in CPU or GPU memory, respectively
    :param most_probable_first: set to True to resolve overlapping predictions by keeping
        the highest-scoring span and discarding lower-scoring spans that share tokens with it
    """
    if label_name is None:
        label_name = self.get_current_label_type()

    if not sentences:
        return sentences

    if not isinstance(sentences, list):
        sentences = [sentences]

    # sort sentences by length, longest first, for more efficient batching
    reordered_sentences = sorted(sentences, key=lambda s: len(s), reverse=True)

    dataloader = DataLoader(
        dataset=FlairDatapointDataset(reordered_sentences),
        batch_size=mini_batch_size,
    )

    # progress bar for verbosity
    if verbose:
        dataloader = tqdm(dataloader)

    overall_loss = 0
    overall_count = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = self._filter_empty_sentences(batch)

            # skip this batch if all its sentences are empty
            if not batch:
                continue

            # go through each sentence in the batch
            for sentence in batch:

                # always remove previously predicted tags first
                sentence.remove_labels(label_name)

                all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                # run the binary TARS model once per candidate label and collect
                # every detected span together with its confidence score
                all_detected = {}
                for label in all_labels:
                    tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                    loss_and_count = self.tars_model.predict(
                        tars_sentence,
                        label_name=label_name,
                        return_loss=True,
                    )

                    overall_loss += loss_and_count[0].item()
                    overall_count += loss_and_count[1]

                    for predicted in tars_sentence.get_labels(label_name):
                        predicted.value = label
                        all_detected[predicted] = predicted.score

                if most_probable_first:
                    already_set_indices: List[int] = []

                    # process detected spans in order of descending confidence
                    sorted_spans = sorted(all_detected.items(), key=lambda item: item[1], reverse=True)
                    for label, _ in sorted_spans:

                        # if the label was used as a prefix, token indices in the TARS
                        # sentence are shifted by the length of "<label> <separator>"
                        label_length = (
                            0 if not self.prefix else len(label.value.split(" ")) + len(self.separator.split(" "))
                        )

                        # determine whether tokens in this span already carry a label
                        tag_this = True
                        for token in label.data_point:
                            corresponding_token = sentence.get_token(token.idx - label_length)
                            if corresponding_token is None:
                                tag_this = False
                                continue
                            if token.idx in already_set_indices:
                                tag_this = False
                                continue

                        # only add the span if none of its tokens are labeled yet
                        if tag_this:
                            already_set_indices.extend(token.idx for token in label.data_point)
                            predicted_span = Span(
                                [sentence.get_token(token.idx - label_length) for token in label.data_point]
                            )
                            predicted_span.add_label(label_name, value=label.value, score=label.score)

            # clear token embeddings to save memory
            store_embeddings(batch, storage_mode=embedding_storage_mode)

    if return_loss:
        return overall_loss, overall_count
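
# Usage sketch (illustrative, not from the original source; the "tars-ner" model
# name and the task setup below are assumptions based on the public Flair TARS
# models and may differ from your checkpoint):
#
#     from flair.data import Sentence
#     from flair.models import TARSTagger
#
#     tagger = TARSTagger.load("tars-ner")
#     tagger.add_and_switch_to_new_task(
#         task_name="demo",
#         label_dictionary=["person", "location"],
#         label_type="ner",
#     )
#     sentence = Sentence("George Washington went to Washington")
#     tagger.predict(sentence, most_probable_first=True)
#     for span in sentence.get_labels("ner"):
#         print(span)
#
# With most_probable_first=True, overlapping candidate spans (e.g. "Washington"
# detected as both "person" and "location") are resolved in favor of the
# higher-scoring prediction.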