def test_correct_sequence_elements_are_embedded(self):
    """Check that an "x,y" endpoint extractor returns [start; end] embeddings for each span."""
    sequence_tensor = torch.randn([2, 5, 7])

    # Concatenate start and end points together to form our representation.
    extractor = EndpointSpanExtractor(7, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
    span_representations = extractor(sequence_tensor, indices)

    assert list(span_representations.size()) == [2, 2, 14]
    assert extractor.get_output_dim() == 14
    assert extractor.get_input_dim() == 7

    # The representation is just the start and end embeddings concatenated, so
    # splitting it in half must give back the embeddings found at the original
    # start and end indices.
    start_indices, end_indices = indices.split(1, -1)
    start_embeddings, end_embeddings = span_representations.split(7, -1)
    comparisons = (
        (start_embeddings, start_indices),
        (end_embeddings, end_indices),
    )
    for actual_embeddings, endpoint_indices in comparisons:
        expected_embeddings = batched_index_select(sequence_tensor,
                                                   endpoint_indices.squeeze())
        numpy.testing.assert_array_equal(actual_embeddings.data.numpy(),
                                         expected_embeddings.data.numpy())
class DyGIE(Model):
    """
    Span-based model that shares endpoint span embeddings between four submodules
    (NER, coreference, relation extraction and event extraction) and combines their
    losses using per-task weights.

    Parameters
    ----------
    vocab : ``Vocabulary``
    embedder : ``TextFieldEmbedder``
        Used to embed the ``text`` ``TextField`` we get as input to the model.
    modules : ``TODO(dwadden)``
        A nested dictionary specifying parameters to be passed on to initialize
        submodules. The keys ``"ner"``, ``"coref"``, ``"relation"`` and ``"events"``
        are popped from it.
    feature_size : ``int``
        The embedding size for all the embedded features, such as distances or span widths.
    max_span_width : ``int``
        The maximum width of candidate spans.
    target_task : ``str``
        The task used to make early stopping decisions. Must be one of ``"ner"``,
        ``"relation"``, ``"coref"`` or ``"events"``; it also selects which metrics
        are displayed.
    feedforward_params : ``Dict[str, Union[int, float]]``
        Settings for the feedforward networks handed to the submodules. The keys
        ``"num_layers"``, ``"hidden_dims"`` and ``"dropout"`` are read.
    loss_weights : ``Dict[str, float]``
        Weight of each task's loss in the total loss. The keys ``"coref"``, ``"ner"``,
        ``"relation"`` and ``"events"`` are read; a task whose weight is not positive
        is not run for prediction.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    module_initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the individual modules.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    display_metrics : ``List[str]``, optional (default=``None``)
        Accepted for backward compatibility but currently not read: the displayed
        metrics are derived from ``target_task``.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 modules,  # TODO(dwadden) Add type.
                 feature_size: int,
                 max_span_width: int,
                 target_task: str,
                 feedforward_params: Dict[str, Union[int, float]],
                 loss_weights: Dict[str, float],
                 initializer: InitializerApplicator = InitializerApplicator(),
                 module_initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 display_metrics: List[str] = None) -> None:
        super(DyGIE, self).__init__(vocab, regularizer)

        ####################

        # Create span extractor.
        self._endpoint_span_extractor = EndpointSpanExtractor(
            embedder.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)

        ####################

        # Set parameters.
        self._embedder = embedder
        self._loss_weights = loss_weights
        self._max_span_width = max_span_width
        self._display_metrics = self._get_display_metrics(target_task)
        token_emb_dim = self._embedder.get_output_dim()
        span_emb_dim = self._endpoint_span_extractor.get_output_dim()

        ####################

        # Create submodules.
        modules = Params(modules)

        # Helper function to create feedforward networks.
        def make_feedforward(input_dim):
            return FeedForward(input_dim=input_dim,
                               num_layers=feedforward_params["num_layers"],
                               hidden_dims=feedforward_params["hidden_dims"],
                               activations=torch.nn.ReLU(),
                               dropout=feedforward_params["dropout"])

        # Submodules
        self._ner = NERTagger.from_params(vocab=vocab,
                                          make_feedforward=make_feedforward,
                                          span_emb_dim=span_emb_dim,
                                          feature_size=feature_size,
                                          params=modules.pop("ner"))

        self._coref = CorefResolver.from_params(vocab=vocab,
                                                make_feedforward=make_feedforward,
                                                span_emb_dim=span_emb_dim,
                                                feature_size=feature_size,
                                                params=modules.pop("coref"))

        self._relation = RelationExtractor.from_params(vocab=vocab,
                                                       make_feedforward=make_feedforward,
                                                       span_emb_dim=span_emb_dim,
                                                       feature_size=feature_size,
                                                       params=modules.pop("relation"))

        self._events = EventExtractor.from_params(vocab=vocab,
                                                  make_feedforward=make_feedforward,
                                                  token_emb_dim=token_emb_dim,
                                                  span_emb_dim=span_emb_dim,
                                                  feature_size=feature_size,
                                                  params=modules.pop("events"))

        ####################

        # Initialize text embedder and all submodules
        for module in [self._ner, self._coref, self._relation, self._events]:
            module_initializer(module)

        initializer(self)

    @staticmethod
    def _get_display_metrics(target_task):
        """
        The `target` is the name of the task used to make early stopping decisions. Show metrics
        related to this task.

        Raises a ``ValueError`` if ``target_task`` is not one of the known tasks.
        """
        lookup = {
            "ner": [f"MEAN__{name}" for name in
                    ["ner_precision", "ner_recall", "ner_f1"]],
            "relation": [f"MEAN__{name}" for name in
                         ["relation_precision", "relation_recall", "relation_f1"]],
            "coref": ["coref_precision", "coref_recall", "coref_f1", "coref_mention_recall"],
            "events": [f"MEAN__{name}" for name in
                       ["trig_class_f1", "arg_class_f1"]]}
        if target_task not in lookup:
            raise ValueError(f"Invalid value {target_task} has been given as the target task.")
        return lookup[target_task]

    @staticmethod
    def _debatch(x):
        """Drop the leading (document) dimension of ``x``; ``None`` is passed through."""
        # TODO(dwadden) Get rid of this when I find a better way to do it.
        return x if x is None else x.squeeze(0)

    @overrides
    def forward(self,
                text,
                spans,
                metadata,
                ner_labels=None,
                coref_labels=None,
                relation_labels=None,
                trigger_labels=None,
                argument_labels=None):
        """
        Embed a single document, build span embeddings, run every submodule whose loss
        weight is positive, and combine the weighted losses.

        TODO(dwadden) change this.

        Returns a dict with one entry per task (``coref``, ``relation``, ``ner``,
        ``events``) plus the total ``loss`` and the document ``metadata``.

        Raises ``NotImplementedError`` if ``metadata`` holds more than one document.
        """
        # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
        if relation_labels is not None:
            relation_labels = relation_labels.long()
        if argument_labels is not None:
            argument_labels = argument_labels.long()

        # TODO(dwadden) Multi-document minibatching isn't supported yet. For now, get rid of the
        # extra dimension in the input tensors. Will return to this once the model runs.
        if len(metadata) > 1:
            raise NotImplementedError("Multi-document minibatching not supported.")

        metadata = metadata[0]
        spans = self._debatch(spans)  # (n_sents, max_n_spans, 2)
        ner_labels = self._debatch(ner_labels)  # (n_sents, max_n_spans)
        coref_labels = self._debatch(coref_labels)  # (n_sents, max_n_spans)
        relation_labels = self._debatch(relation_labels)  # (n_sents, max_n_spans, max_n_spans)
        trigger_labels = self._debatch(trigger_labels)  # TODO(dwadden)
        argument_labels = self._debatch(argument_labels)  # TODO(dwadden)

        # Encode using BERT, then debatch.
        # Since the data are batched, we use `num_wrapping_dims=1` to unwrap the document dimension.
        # (1, n_sents, max_sentence_length, embedding_dim)

        # TODO(dwadden) Deal with the case where the input is longer than 512.
        text_embeddings = self._embedder(text, num_wrapping_dims=1)
        # (n_sents, max_n_wordpieces, embedding_dim)
        text_embeddings = self._debatch(text_embeddings)

        # (n_sents, max_sentence_length)
        text_mask = self._debatch(util.get_text_field_mask(text, num_wrapping_dims=1).float())
        sentence_lengths = text_mask.sum(dim=1).long()  # (n_sents)

        span_mask = (spans[:, :, 0] >= 0).float()  # (n_sents, max_n_spans)
        # SpanFields return -1 when they are used as padding. As we do some comparisons based on
        # span widths when we attend over the span representations that we generate from these
        # indices, we need them to be >= 0. This is only relevant in edge cases where the number of
        # spans we consider after the pruning stage is >= the total number of spans, because in this
        # case, it is possible we might consider a masked span.
        spans = F.relu(spans.float()).long()  # (n_sents, max_n_spans, 2)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)

        # Make calls out to the modules to get results. Tasks that are switched off keep
        # these zero-loss placeholders.
        output_coref = {'loss': 0}
        output_ner = {'loss': 0}
        output_relation = {'loss': 0}
        output_events = {'loss': 0}

        # Prune and compute span representations for coreference module
        if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
            output_coref, coref_indices = self._coref.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths, coref_labels, metadata)

        # Propagation of global information to enhance the span embeddings
        if self._coref.coref_prop > 0:
            output_coref = self._coref.coref_propagation(output_coref)
            span_embeddings = self._coref.update_spans(
                output_coref, span_embeddings, coref_indices)

        # Make predictions and compute losses for each module
        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(
                spans, span_mask, span_embeddings, sentence_lengths, ner_labels, metadata)

        if self._loss_weights['coref'] > 0:
            output_coref = self._coref.predict_labels(output_coref, metadata)

        if self._loss_weights['relation'] > 0:
            output_relation = self._relation(
                spans, span_mask, span_embeddings, sentence_lengths, relation_labels, metadata)

        if self._loss_weights['events'] > 0:
            # The `text_embeddings` serve as representations for event triggers.
            output_events = self._events(
                text_mask, text_embeddings, spans, span_mask, span_embeddings,
                sentence_lengths, trigger_labels, argument_labels,
                ner_labels, metadata)

        # Use `get` since there are some cases where the output dict won't have a loss - for
        # instance, when doing prediction.
        loss = (self._loss_weights['coref'] * output_coref.get("loss", 0) +
                self._loss_weights['ner'] * output_ner.get("loss", 0) +
                self._loss_weights['relation'] * output_relation.get("loss", 0) +
                self._loss_weights['events'] * output_events.get("loss", 0))

        # Multiply the loss by the weight multiplier for this document.
        weight = metadata.weight if metadata.weight is not None else 1.0
        loss *= torch.tensor(weight)

        output_dict = dict(coref=output_coref,
                           relation=output_relation,
                           ner=output_ner,
                           events=output_events)
        output_dict['loss'] = loss
        output_dict["metadata"] = metadata

        return output_dict

    def update_span_embeddings(self, span_embeddings, span_mask, top_span_embeddings,
                               top_span_mask, top_span_indices):
        """
        Return a copy of ``span_embeddings`` in which the rows addressed by
        ``top_span_indices`` are overwritten with the matching ``top_span_embeddings``.

        For each sample, copying stops at the first position where either the top-span
        mask or the span mask is zero.
        """
        # TODO(Ulme) Speed this up by tensorizing
        new_span_embeddings = span_embeddings.clone()
        for sample_nr in range(len(top_span_mask)):
            for top_span_nr, span_nr in enumerate(top_span_indices[sample_nr]):
                if top_span_mask[sample_nr, top_span_nr] == 0 or span_mask[sample_nr, span_nr] == 0:
                    break
                new_span_embeddings[sample_nr, span_nr] = top_span_embeddings[sample_nr, top_span_nr]
        return new_span_embeddings

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]):
        """
        Attach the decoded predictions of every active task to a copy of the input document.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``, required.
            The result of calling :func:`forward` on an instance or batch of instances.

        Returns
        -------
        A deep copy of ``output_dict["metadata"]``. For each task with a positive loss
        weight the copy is annotated: ``predicted_clusters`` on the document for coref,
        and ``predicted_ner``, ``predicted_relations`` and ``predicted_events`` on each
        sentence for the remaining tasks.
        """
        doc = copy.deepcopy(output_dict["metadata"])

        if self._loss_weights["coref"] > 0:
            # TODO(dwadden) Will need to get rid of the [0] when batch training is enabled.
            decoded_coref = self._coref.make_output_human_readable(
                output_dict["coref"])["predicted_clusters"][0]
            sentences = doc.sentences
            sentence_starts = [sent.sentence_start for sent in sentences]
            predicted_clusters = [document.Cluster(entry, i, sentences, sentence_starts)
                                  for i, entry in enumerate(decoded_coref)]
            doc.predicted_clusters = predicted_clusters
            # TODO(dwadden) update the sentences with cluster information.

        # The sentence-level tasks all follow the same pattern: one prediction per sentence,
        # stored on a task-specific attribute.
        sentence_level_tasks = (("ner", "predicted_ner"),
                                ("relation", "predicted_relations"),
                                ("events", "predicted_events"))
        for task, attribute in sentence_level_tasks:
            if self._loss_weights[task] > 0:
                for predictions, sentence in zip(output_dict[task]["predictions"], doc):
                    setattr(sentence, attribute, predictions)

        return doc

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Get all metrics from all modules. For the ones that shouldn't be displayed, prefix their
        keys with an underscore.
        """
        metric_groups = [module.get_metrics(reset=reset)
                         for module in (self._coref, self._ner, self._relation, self._events)]

        # Make sure that there aren't any conflicting names.
        metric_names = [name for group in metric_groups for name in group]
        assert len(set(metric_names)) == len(metric_names)

        all_metrics = {}
        for group in metric_groups:
            all_metrics.update(group)

        # If no list of desired metrics given, display them all.
        if self._display_metrics is None:
            return all_metrics

        # Otherwise only display the selected ones.
        return {(name if name in self._display_metrics else "_" + name): value
                for name, value in all_metrics.items()}
class TweetJointly(Model):
    """
    Transformer-based span selector with two optional auxiliary tasks.

    The base task scores every token as a span start / span end. When enabled, the
    sentiment task classifies the whole input, and the candidate span task re-ranks the
    candidate spans produced from the base start/end probabilities.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model_name: str = "bert-base-uncased",
        feedforward: Optional[FeedForward] = None,
        smoothing: bool = False,
        smooth_alpha: float = 0.7,
        sentiment_task: bool = False,
        sentiment_task_weight: float = 1.0,
        sentiment_classification_with_label: bool = True,
        sentiment_seq2vec: Optional[Seq2VecEncoder] = None,
        candidate_span_task: bool = False,
        candidate_span_task_weight: float = 1.0,
        candidate_delay: int = 30000,
        candidate_span_num: int = 5,
        candidate_classification_layer_units: int = 128,
        candidate_span_extractor: Optional[SpanExtractor] = None,
        candidate_span_with_logits: bool = False,
        dropout: Optional[float] = None,
        **kwargs,
    ) -> None:
        """
        Build the embedder, the span start/end scorer and the optional task heads.

        Raises ``ConfigurationError`` when ``sentiment_task`` is set without a
        ``sentiment_seq2vec`` encoder.
        """
        super().__init__(vocab, **kwargs)
        if "BERTweet" not in transformer_model_name:
            self._text_field_embedder = BasicTextFieldEmbedder({
                "tokens":
                PretrainedTransformerEmbedder(transformer_model_name)
            })
        else:
            self._text_field_embedder = BasicTextFieldEmbedder(
                {"tokens": TweetBertEmbedder(transformer_model_name)})

        # span start & end task
        if feedforward is None:
            self._linear_layer = nn.Sequential(
                nn.Linear(self._text_field_embedder.get_output_dim(), 128),
                nn.ReLU(),
                nn.Linear(128, 2),
            )
        else:
            self._linear_layer = feedforward
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._jaccard = Jaccard()
        self._candidate_delay = candidate_delay
        # Running count of instances seen by `forward`; gates the candidate span task.
        self._delay = 0

        self._smoothing = smoothing
        self._smooth_alpha = smooth_alpha
        if smoothing:
            self._loss = nn.KLDivLoss(reduction="batchmean")
        else:
            self._loss = nn.CrossEntropyLoss()

        # sentiment task
        self._sentiment_task = sentiment_task
        if self._sentiment_task:
            self._sentiment_classification_accuracy = CategoricalAccuracy()
            self._sentiment_loss_log = LossLog()
            self.register_buffer("sentiment_task_weight",
                                 torch.tensor(sentiment_task_weight))
            self._sentiment_classification_with_label = (
                sentiment_classification_with_label)
            if sentiment_seq2vec is None:
                raise ConfigurationError(
                    "sentiment task is True, we need a sentiment seq2vec encoder"
                )
            self._sentiment_encoder = sentiment_seq2vec
            self._sentiment_linear = nn.Linear(
                self._sentiment_encoder.get_output_dim(),
                vocab.get_vocab_size("labels"),
            )

        # candidate span task
        self._candidate_span_task = candidate_span_task
        if candidate_span_task:
            assert candidate_span_num > 0
            assert candidate_span_task_weight > 0
            assert candidate_classification_layer_units > 0
            self._candidate_span_num = candidate_span_num
            self.register_buffer("candidate_span_task_weight",
                                 torch.tensor(candidate_span_task_weight))
            self._candidate_classification_layer_units = (
                candidate_classification_layer_units)
            self._span_classification_accuracy = CategoricalAccuracy()
            self._candidate_loss_log = LossLog()
            self._candidate_span_linear = nn.Linear(
                self._text_field_embedder.get_output_dim(),
                self._candidate_classification_layer_units,
            )
            if candidate_span_extractor is None:
                self._candidate_span_extractor = EndpointSpanExtractor(
                    input_dim=self._candidate_classification_layer_units)
            else:
                self._candidate_span_extractor = candidate_span_extractor

            if candidate_span_with_logits:
                # One extra input feature: the summed start/end logits of the candidate.
                self._candidate_with_logits = True
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim() + 1, 1)
            else:
                self._candidate_with_logits = False
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim(), 1)

            self._candidate_jaccard = Jaccard()

        if sentiment_task or candidate_span_task:
            self._base_loss_log = LossLog()
        else:
            self._base_loss_log = None

        if dropout is not None:
            self._dropout = nn.Dropout(dropout)
        else:
            self._dropout = None

    def forward(  # type: ignore
        self,
        text: Dict[str, Dict[str, torch.LongTensor]],
        sentiment: torch.IntTensor,
        text_with_sentiment: Dict[str, Dict[str, torch.LongTensor]],
        text_span: torch.IntTensor,
        selected_text_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Score span starts/ends, optionally run the sentiment and candidate span tasks,
        and compute the training loss when ``selected_text_span`` is given.

        ``text_span`` holds, per instance, the inclusive (start, end) token range in
        which an answer may lie; logits outside that range are masked out.

        Returns a dict with the span logits/probabilities, ``best_span``,
        ``best_span_scores`` and ``best_span_str``, plus the outputs of whichever
        auxiliary tasks ran and ``loss`` when gold spans were supplied.
        """
        # batch_size * text_length * hidden_dims
        embedded_question = self._text_field_embedder(text_with_sentiment)
        if self._dropout is not None:
            embedded_question = self._dropout(embedded_question)
        self._delay += int(embedded_question.size(0))

        # span start & span end task
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            util.get_token_ids_from_text_field_tensors(
                text_with_sentiment)).bool()
        for i, (start, end) in enumerate(text_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Task losses are accumulated in place into this scalar.
        loss = torch.tensor(0.0).to(embedded_question.device)

        # sentiment task
        if self._sentiment_task:
            if self._sentiment_classification_with_label:
                global_context_vec = self._sentiment_encoder(embedded_question)
            else:
                embedded_only_text = self._text_field_embedder(text)
                if self._dropout is not None:
                    embedded_only_text = self._dropout(embedded_only_text)
                global_context_vec = self._sentiment_encoder(
                    embedded_only_text)

            sentiment_logits = self._sentiment_linear(global_context_vec)
            sentiment_probs = torch.softmax(sentiment_logits, dim=-1)

            self._sentiment_classification_accuracy(sentiment_probs, sentiment)
            sentiment_loss = cross_entropy(sentiment_logits, sentiment)
            self._sentiment_loss_log(sentiment_loss)
            loss.add_(self.sentiment_task_weight * sentiment_loss)

            predict_sentiment_idx = sentiment_probs.argmax(dim=-1)
            sentiment_predicts = [
                self.vocab.get_token_from_index(i, "labels")
                for i in predict_sentiment_idx.tolist()
            ]

            output_dict["sentiment_logits"] = sentiment_logits
            output_dict["sentiment_probs"] = sentiment_probs
            output_dict["sentiment_predicts"] = sentiment_predicts

        # span classification
        if self._candidate_span_task and (self._delay >= self._candidate_delay):
            # shape: (batch_size, passage_length, embedding_dim)
            text_features_for_candidate = self._candidate_span_linear(
                embedded_question)
            text_features_for_candidate = torch.relu(
                text_features_for_candidate)
            with torch.no_grad():
                # batch_size * candidate_num * 2
                candidate_span = get_candidate_span(span_start_probs,
                                                    span_end_probs,
                                                    self._candidate_span_num)
                candidate_span_list = candidate_span.tolist()
                output_dict["candidate_spans"] = candidate_span_list
                if selected_text_span is not None:
                    candidate_span, candidate_span_label = self.candidate_span_with_labels(
                        candidate_span, selected_text_span)
                else:
                    candidate_span_label = None
            # shape: (batch_size, candidate_num, span_extractor_output_dim)
            span_feature_vec = self._candidate_span_extractor(
                text_features_for_candidate, candidate_span)

            if self._candidate_with_logits:
                candidate_span_start_logits = torch.gather(
                    span_start_logits, 1, candidate_span[:, :, 0])
                candidate_span_end_logits = torch.gather(
                    span_end_logits, 1, candidate_span[:, :, 1])
                candidate_span_sum_logits = (candidate_span_start_logits +
                                             candidate_span_end_logits)
                span_feature_vec = torch.cat(
                    (span_feature_vec, candidate_span_sum_logits.unsqueeze(2)),
                    -1)
            # batch_size * candidate_num
            # Squeeze only the trailing unit dimension: a bare `squeeze()` would also drop
            # the batch (or candidate) dimension whenever it has size 1.
            span_classification_logits = self._candidate_span_vec_linear(
                span_feature_vec).squeeze(-1)
            span_classification_probs = torch.softmax(
                span_classification_logits, -1)
            output_dict[
                "span_classification_probs"] = span_classification_probs

            # Turn the per-instance winner into an index into the flattened candidate list.
            candidate_best_span_idx = span_classification_probs.argmax(dim=-1)
            view_idx = (
                candidate_best_span_idx +
                torch.arange(0, end=candidate_best_span_idx.shape[0]).to(
                    candidate_best_span_idx.device) * self._candidate_span_num)
            candidate_span_view = candidate_span.view(-1, 2)
            candidate_best_spans = candidate_span_view.index_select(
                0, view_idx)
            output_dict["candidate_best_spans"] = candidate_best_spans.tolist()

            if selected_text_span is not None:
                self._span_classification_accuracy(span_classification_probs,
                                                   candidate_span_label)
                candidate_span_loss = cross_entropy(span_classification_logits,
                                                    candidate_span_label)
                self._candidate_loss_log(candidate_span_loss)
                weighted_loss = self.candidate_span_task_weight * candidate_span_loss
                if candidate_span_loss > 1e2:
                    logger.warning("candidate loss: %s", candidate_span_loss)
                    logger.warning("span_classification_logits: %s",
                                   span_classification_logits)
                    logger.warning("candidate_span_label: %s",
                                   candidate_span_label)
                loss.add_(weighted_loss)

            candidate_best_spans = candidate_best_spans.detach().cpu().numpy()
            output_dict["best_candidate_span_str"] = []
            for metadata_entry, best_span in zip(metadata,
                                                 candidate_best_spans):
                text_with_sentiment_tokens = metadata_entry[
                    "text_with_sentiment_tokens"]

                predicted_start, predicted_end = tuple(best_span)
                if predicted_end >= len(text_with_sentiment_tokens):
                    predicted_end = len(text_with_sentiment_tokens) - 1
                best_span_string = self.span_tokens_to_text(
                    metadata_entry["text"],
                    text_with_sentiment_tokens,
                    predicted_start,
                    predicted_end,
                )
                output_dict["best_candidate_span_str"].append(best_span_string)
                answers = metadata_entry.get("selected_text", "")
                if len(answers) > 0:
                    self._candidate_jaccard(best_span_string, answers)

        # Compute the loss for training.
        if selected_text_span is not None:
            span_start = selected_text_span[:, 0]
            span_end = selected_text_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(
                best_spans,
                selected_text_span,
                span_mask.unsqueeze(-1).expand_as(best_spans),
            )
            if not self._smoothing:
                start_loss = cross_entropy(span_start_logits,
                                           span_start,
                                           ignore_index=-1)
                if torch.any(start_loss > 1e9):
                    logger.critical("Start loss too high (%r)", start_loss)
                    logger.critical("span_start_logits: %r", span_start_logits)
                    logger.critical("span_start: %r", span_start)
                    logger.critical("text_with_sentiment: %r",
                                    text_with_sentiment)
                    assert False

                end_loss = cross_entropy(span_end_logits,
                                         span_end,
                                         ignore_index=-1)
                if torch.any(end_loss > 1e9):
                    logger.critical("End loss too high (%r)", end_loss)
                    logger.critical("span_end_logits: %r", span_end_logits)
                    logger.critical("span_end: %r", span_end)
                    assert False
            else:
                sequence_length = span_start_logits.size(1)
                device = span_start.device
                # exp(distance * log(alpha)) == alpha ** distance: the target mass decays
                # with the distance from the gold endpoint.
                log_alpha = torch.log(
                    torch.tensor(self._smooth_alpha).to(device))

                start_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_start)
                start_smooth_probs = torch.exp(start_distance * log_alpha)
                start_smooth_probs = start_smooth_probs * possible_answer_mask
                start_smooth_probs = start_smooth_probs / start_smooth_probs.sum(
                    -1, keepdim=True)
                # log_softmax is the numerically stable form of
                # logits - log(sum(exp(logits))).
                span_start_log_probs = torch.log_softmax(span_start_logits,
                                                         dim=-1)

                end_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_end)
                end_smooth_probs = torch.exp(end_distance * log_alpha)
                end_smooth_probs = end_smooth_probs * possible_answer_mask
                end_smooth_probs = end_smooth_probs / end_smooth_probs.sum(
                    -1, keepdim=True)
                span_end_log_probs = torch.log_softmax(span_end_logits,
                                                       dim=-1)

                start_loss = self._loss(span_start_log_probs,
                                        start_smooth_probs)
                end_loss = self._loss(span_end_log_probs, end_smooth_probs)

            span_start_end_loss = (start_loss + end_loss) / 2
            if self._base_loss_log is not None:
                self._base_loss_log(span_start_end_loss)
            loss.add_(span_start_end_loss)
            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # compute best span jaccard
        best_spans = best_spans.detach().cpu().numpy()
        output_dict["best_span_str"] = []
        for metadata_entry, best_span in zip(metadata, best_spans):
            text_with_sentiment_tokens = metadata_entry[
                "text_with_sentiment_tokens"]

            predicted_start, predicted_end = tuple(best_span)
            best_span_string = self.span_tokens_to_text(
                metadata_entry["text"],
                text_with_sentiment_tokens,
                predicted_start,
                predicted_end,
            )
            output_dict["best_span_str"].append(best_span_string)
            answers = metadata_entry.get("selected_text", "")
            if len(answers) > 0:
                self._jaccard(best_span_string, answers)

        return output_dict

    @staticmethod
    def candidate_span_with_labels(
            candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Label each instance with the index of the candidate that has the highest
        ``batch_span_jaccard`` value against the gold span.

        Returns the candidates unchanged together with those label indices.
        """
        candidate_span_label = batch_span_jaccard(
            candidate_span, selected_text_span).max(-1).indices

        return candidate_span, candidate_span_label

    @staticmethod
    def get_candidate_span_mask(candidate_span: torch.Tensor,
                                passage_length: int) -> torch.Tensor:
        """
        Build a (batch_size, candidate_num, passage_length) mask that is 1 on the
        inclusive token range of every candidate span and 0 elsewhere.
        """
        device = candidate_span.device
        batch_size, candidate_num = candidate_span.size()[:-1]
        candidate_span_mask = torch.zeros(batch_size, candidate_num,
                                          passage_length).to(device)
        for i in range(batch_size):
            for j in range(candidate_num):
                span_start, span_end = candidate_span[i][j]
                candidate_span_mask[i][j][span_start:span_end + 1] = 1
        return candidate_span_mask

    @staticmethod
    def span_tokens_to_text(source_text, tokens, span_start, span_end):
        """
        Map an inclusive token span back to the substring of ``source_text`` it covers.

        Tokens whose ``idx`` (character offset) is ``None`` cannot be mapped directly, so
        both endpoints are moved left to the nearest token that has an offset. If no such
        token exists the start falls back to 0 and the end to ``len(source_text)``.

        Returns the stripped substring.
        """
        num_tokens = len(tokens)

        predicted_start = span_start
        while predicted_start >= 0 and tokens[predicted_start].idx is None:
            predicted_start -= 1
        if predicted_start < 0:
            logger.warning(
                f"Could not map the token '{tokens[span_start].text}' at index "
                f"'{span_start}' to an offset in the original text.")
            character_start = 0
        else:
            character_start = tokens[predicted_start].idx

        # The lower bound matters: without it a run of offset-less tokens walks the index
        # below zero, where Python's negative indexing silently wraps to the end of the
        # token list and eventually raises IndexError.
        predicted_end = span_end
        while (0 <= predicted_end < num_tokens
               and tokens[predicted_end].idx is None):
            predicted_end -= 1

        if predicted_end < 0 or predicted_end >= num_tokens:
            # `span_end` itself may be out of range here, so only index with it when safe.
            if 0 <= span_end < num_tokens:
                unmapped_text = tokens[span_end].text
            else:
                unmapped_text = "<out of range>"
            logger.warning(
                f"Could not map the token '{unmapped_text}' at index "
                f"'{span_end}' to an offset in the original text.")
            character_end = len(source_text)
        else:
            end_token = tokens[predicted_end]
            if end_token.idx == 0:
                # NOTE(review): one extra character is taken when the end token starts at
                # offset 0; the reason is not evident from this file - confirm.
                character_end = (end_token.idx +
                                 len(sanitize_wordpiece(end_token.text)) + 1)
            else:
                character_end = end_token.idx + len(
                    sanitize_wordpiece(end_token.text))

        best_span_string = source_text[character_start:character_end].strip()
        return best_span_string

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the base span metrics plus the metrics of every enabled auxiliary task.
        """
        jaccard = self._jaccard.get_metric(reset)
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "jaccard": jaccard,
        }
        if self._candidate_span_task:
            metrics[
                "candidate_span_acc"] = self._span_classification_accuracy.get_metric(
                    reset)
            metrics["candidate_jaccard"] = self._candidate_jaccard.get_metric(
                reset)
            metrics["candidate_loss"] = self._candidate_loss_log.get_metric(
                reset)
        if self._sentiment_task:
            metrics[
                "sentiment_acc"] = self._sentiment_classification_accuracy.get_metric(
                    reset)
            metrics["sentiment_loss"] = self._sentiment_loss_log.get_metric(
                reset)
        if self._base_loss_log is not None:
            metrics["base_loss"] = self._base_loss_log.get_metric(reset)
        return metrics
class SCIIE(Model):
    """
    Span-based entity classifier.

    Candidate spans are represented by concatenating an endpoint representation
    (computed on BiLSTM outputs, with a span-width embedding) and a
    self-attentive representation (computed on the token embeddings). The spans
    are pruned with a mention scorer and the surviving spans are classified.

    Parameters
    ----------
    vocab : ``Vocabulary``
        The size of its ``labels`` namespace is used as the number of classes.
    embedding_dim : ``int``
        Forwarded to ``get_embeddings``. The dimension actually used afterwards
        is the output dimension of the embedder that call returns.
    feature_size : ``int``
        Hidden size of each direction of the BiLSTM context layer, and the size
        of the span-width embedding.
    max_span_width : ``int``
        Number of span-width embeddings of the endpoint span extractor.
    spans_per_word : ``float``
        The number of spans kept after pruning is
        ``floor(spans_per_word * document_length)``, capped at the number of
        candidate spans.
    lexical_dropout : ``float``, optional (default=``0.2``)
        Dropout probability applied to the embedded text. No dropout is applied
        when it is not positive.
    mlp_dropout : ``float``, optional (default=``0.4``)
        Dropout passed to the span feed-forward network.
    embedder_type : optional (default=``None``)
        Forwarded to ``get_embeddings``.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        Forwarded to the ``Model`` constructor.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_dim: int,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 lexical_dropout: float = 0.2,
                 mlp_dropout: float = 0.4,
                 embedder_type=None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SCIIE, self).__init__(vocab, regularizer)
        self.class_num = self.vocab.get_vocab_size('labels')

        word_embeddings = get_embeddings(embedder_type, self.vocab,
                                         embedding_dim, True)
        # From here on use the dimension the embedder really produces.
        embedding_dim = word_embeddings.get_output_dim()
        self._text_field_embedder = word_embeddings

        context_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim,
                          feature_size,
                          batch_first=True,
                          bidirectional=True))
        self._context_layer = context_layer

        # Endpoint representations are read from the contextualized sequence,
        # attentive representations from the (non-contextual) embeddings.
        endpoint_span_extractor_input_dim = context_layer.get_output_dim()
        attentive_span_extractor_input_dim = word_embeddings.get_output_dim()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            endpoint_span_extractor_input_dim,
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=attentive_span_extractor_input_dim)

        entity_feedforward = FeedForward(
            self._endpoint_span_extractor.get_output_dim() +
            self._attentive_span_extractor.get_output_dim(), 2, 150, F.relu,
            mlp_dropout)

        # The same ``entity_feedforward`` instance is used by both the mention
        # pruner and the entity scorer below, so its weights are shared.
        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(), 1)))
        self._mention_pruner = Pruner(feedforward_scorer)

        # Class 0 gets a fixed zero score (see ``_compute_named_entity_scores``),
        # so only ``class_num - 1`` scores are predicted.
        self._entity_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(),
                                self.class_num - 1)))

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word

        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        self._metric_all = FBetaMeasure()
        self._metric_avg = NERF1Metric()

    @overrides
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        """
        Score candidate spans, prune them and classify the survivors.

        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``
            Input of the text field embedder.
        spans : ``torch.IntTensor``
            Candidate spans, indexed as ``(batch, span, 2)``. A negative start
            index marks a padding span.
        labels : ``torch.IntTensor``, optional (default=``None``)
            Gold label per candidate span, shape ``(batch_size, num_spans)``.
            When given, the loss is computed and the metrics are updated.
        metadata : ``List[Dict[str, Any]]``, optional (default=``None``)
            Not used by this method.

        Returns
        -------
        ``Dict[str, torch.Tensor]``
            ``top_spans`` and ``predicted_named_entities`` always; ``loss`` and
            ``pruner_loss`` in addition when ``labels`` is given.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        # Indexing with ``[:, :, 0]`` already yields a 2-D tensor. No squeeze is
        # applied: squeezing the last dimension would drop the span dimension
        # whenever there is exactly one candidate span.
        span_mask = (spans[:, :, 0] >= 0).float()

        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans,
        #         embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

        # NOTE(review): ``top_span_mask`` is not applied to the losses below, so
        # padding spans that survive pruning contribute to them - confirm that
        # this is intended.
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is None:
            return output_dict

        # Find the gold labels for the spans which we kept.
        # Shape: (batch_size, num_spans_to_keep)
        pruned_gold_labels = util.batched_index_select(
            labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        flat_gold_labels = pruned_gold_labels.reshape(-1)

        negative_log_likelihood = F.cross_entropy(
            ne_scores.reshape(-1, self.class_num), flat_gold_labels)
        # The pruner is trained to separate spans with a non-zero label from
        # spans labelled 0.
        pruner_loss = F.binary_cross_entropy_with_logits(
            top_span_mention_scores.reshape(-1),
            (flat_gold_labels != 0).float())

        output_dict["loss"] = negative_log_likelihood + pruner_loss
        output_dict["pruner_loss"] = pruner_loss

        # The metrics are computed over ALL candidate spans: spans that were
        # pruned away get a score vector that predicts class 0.
        batch_size, _ = labels.shape
        all_scores = ne_scores.new_zeros(
            [batch_size * num_spans, self.class_num])
        all_scores[:, 0] = 1
        all_scores[flat_top_span_indices] = ne_scores.reshape(
            -1, self.class_num)
        all_scores = all_scores.reshape(
            [batch_size, num_spans, self.class_num])
        self._metric_all(all_scores, labels)
        self._metric_avg(all_scores, labels)
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False, prefix=""):
        """
        Return the values of ``self._metric_all`` updated with those of
        ``self._metric_avg``.

        ``reset`` is forwarded to both metrics; ``prefix`` is accepted but not
        used.
        """
        metric = self._metric_all.get_metric(reset)
        metric.update(self._metric_avg.get_metric(reset))
        return metric

    def _compute_named_entity_scores(
            self, span_embeddings: torch.FloatTensor) -> torch.Tensor:
        """
        Score every kept span for every class.

        Parameters
        ----------
        span_embeddings: ``torch.FloatTensor``, required.
            Embedding representations of spans.
            Has shape (batch_size, num_spans_to_keep, encoding_dim)

        Returns
        -------
        ``torch.Tensor``
            Shape (batch_size, num_spans_to_keep, class_num). Column 0 is a
            constant zero score; the remaining columns come from the entity
            scorer.
        """
        # Shape: (batch_size, num_spans_to_keep, class_num - 1)
        scores = self._entity_scorer(span_embeddings)
        # Shape: (batch_size, num_spans_to_keep, 1)
        dummy_scores = scores.new_zeros([scores.size(0), scores.size(1), 1])
        return torch.cat([dummy_scores, scores], -1)
class SrlE2e(Model):
    """
    Span-based semantic role labeller on top of a BERT encoder.

    Candidate spans are embedded, pruned by a mention score, concatenated with
    the embedding of the verb and classified into span labels. For evaluation
    the labelled spans are converted to a BIO tag sequence.

    # Parameters

    vocab : `Vocabulary`, required
        The size of its `span_labels` namespace is used as the number of classes.
    bert_model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    mention_feedforward : `FeedForward`, required
        Applied to the span representations before the linear mention scorer.
    context_layer : `Seq2SeqEncoder`, optional (default = `None`)
        When given, it is run on the (dropped-out) BERT embeddings and spans are
        represented by an endpoint span extractor over its output. Otherwise spans
        are represented by a self-attentive span extractor over the BERT embeddings.
    embedding_dropout : `float`, optional (default = `0.0`)
        Dropout probability applied to the BERT embeddings.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    max_span_width : `int`, optional (default = `30`)
        Number of span-width embeddings of the endpoint span extractor.
    feature_size : `int`, optional (default = `10`)
        Size of the span-width embedding of the endpoint span extractor.
    spans_per_word : `float`, optional (default = `100`)
        The number of spans kept after pruning is
        `floor(spans_per_word * sequence_length)`, capped at the number of
        candidate spans.
    label_smoothing : `float`, optional (default = `None`)
        Stored on the model; the code of this class does not read it afterwards.
    ignore_span_metric : `bool`, optional (default = `False`)
        When true, the span metric is neither updated nor reported.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with
        allennlp, which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not
        used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model
        hidden_size = self.bert_model.config.hidden_size

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None

        self.tag_projection_layer = Linear(hidden_size, self.num_classes)
        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=hidden_size)
        self.span_representation_dim = (
            self._attentive_span_extractor.get_output_dim())

        self._context_layer = context_layer
        if context_layer is not None:
            # With a context layer, endpoint representations replace the
            # attentive ones, and so does their dimension.
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = (
                self._endpoint_span_extractor.get_output_dim())

        # The classifier input is a span representation concatenated with the
        # verb embedding (of size ``hidden_size``).
        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim + hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        # One score per class except class 0, whose score is fixed to zero in
        # ``forward``.
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word

        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`; its token ids and mask are fed
            to the BERT model.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        sentence_end : `torch.LongTensor`, required
            Not used by this method.
        spans : `torch.LongTensor`, required
            Candidate spans, indexed as `(batch, span, 2)`. A negative start index
            marks a padding span.
        span_labels : `torch.LongTensor`, required
            Gold label of every candidate span, shape `(batch_size, num_spans)`.
        metadata : `List[Dict[str, Any]]`, required
            Per-instance metadata. The keys `words`, `verb` and `offsets` are always
            read; `verb_index` and `gold_tags` are read when the span metric is updated.
        tags : `torch.LongTensor`, optional (default = `None`)
            Only tested against `None`: when given, the loss is computed from
            `span_labels` and, outside of training, the span metric is updated.

        # Returns

        An output dictionary consisting of:

        mask : the text field mask of `tokens`.
        words, verb, wordpiece_offsets : lists copied from `metadata`.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised; present when `tags` is given.
        """
        mask = get_text_field_mask(tokens)
        # Only the token ids and the attention mask are passed to BERT;
        # ``verb_indicator`` is not used as ``token_type_ids``.
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        # Indexing with ``[:, :, 0]`` already yields a 2-D tensor. No squeeze is
        # applied: squeezing the last dimension would drop the span dimension
        # whenever there is exactly one candidate span.
        span_mask = spans[:, :, 0] >= 0

        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()
        num_spans = spans.shape[1]

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        _, sequence_length, _ = embedded_text_input.size()

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)
        else:
            # Computed only in the branch that uses it. Note that it reads the
            # BERT embeddings before dropout.
            # Shape: (batch_size, num_spans, embedding_size)
            span_embeddings = self._attentive_span_extractor(
                bert_embeddings, spans)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)

        # Gather the embedding of the verb: the position of the maximum of
        # ``verb_indicator`` in every sequence.
        embedding_size = embedded_text_input.shape[-1]
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedding_size)
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(
            verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        # Sequences without any marked verb get an all-zero verb embedding
        # (argmax would otherwise silently select position 0).
        has_verb = (verb_indicator.sum(1, keepdim=True) > 0).repeat(
            1, verb_embeddings.shape[-1])
        verb_embeddings = torch.where(has_verb, verb_embeddings,
                                      torch.zeros_like(verb_embeddings))

        # Select the embeddings, boundaries and gold labels of the kept spans.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)

        # Every kept span is classified together with the verb embedding.
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        hidden = self.hidden_layer(concatenated_span_embeddings)
        predictions = self.output_layer(hidden)
        # Prepend a constant zero score for class 0.
        # Shape: (batch_size, num_spans_to_keep, num_classes)
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)

        output_dict = {}
        # The mask is retained in the output dictionary so that sequences can
        # be cropped to remove padding in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # Cross entropy averaged over the kept spans that are not padding.
            flat_top_span_mask = top_span_mask.float().view(-1)
            per_span_loss = self._ce_loss(
                predictions.view(-1, predictions.shape[-1]),
                top_span_labels.view(-1))
            # The denominator is clamped so that a batch without any valid
            # span gives a zero loss instead of NaN.
            loss = (per_span_loss * flat_top_span_mask
                    ).sum() / flat_top_span_mask.sum().clamp(min=1)

            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(bio_tags)
                    for bio_tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(bio_tags)
                    for bio_tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict

    def get_tags(self, spans, logits, sequence_length, span_mask, output_dict):
        """
        Convert per-span class scores into one BIO tag sequence per instance.

        Spans are visited in the order given. A span whose predicted class is
        not "O" is written as `B-x` / `I-x` tags, unless one of its positions
        already carries a tag from an earlier span. The wordpiece-level sequence
        is finally reduced to the positions in
        `output_dict["wordpiece_offsets"][i]`.

        # Parameters

        spans : span boundaries, indexed as `[instance, span, 0 or 1]` with
            inclusive end indices.
        logits : class scores, indexed as `[instance, span, class]`.
        sequence_length : `int`, length of the wordpiece-level tag sequence.
        span_mask : indexed as `[instance, span]`; spans with value 0 are skipped.
        output_dict : only `"wordpiece_offsets"` is read.

        # Returns

        A list (one entry per instance) of lists of BIO tag strings.
        """
        predicted_tag_ids = logits.argmax(2)
        # Loop invariant: index of the "O" label.
        outside_index = self.vocab.get_token_index("O",
                                                   namespace="span_labels")
        predicted_tags = []
        for i in range(spans.shape[0]):
            sequence = ["O"] * sequence_length
            for j in range(spans.shape[1]):
                if span_mask[i, j].item() == 0:
                    continue
                tag_id = predicted_tag_ids[i, j].item()
                if tag_id == outside_index:
                    continue
                start = spans[i, j, 0].item()
                end = spans[i, j, 1].item()
                # Earlier spans win: skip spans overlapping a tagged position.
                if any(el != "O" for el in sequence[start:end + 1]):
                    continue
                label = self.vocab.get_token_from_index(
                    tag_id, namespace="span_labels")
                sequence[start] = "B-" + label
                for index in range(start + 1, end + 1):
                    sequence[index] = "I-" + label
            predicted_tags.append(
                [sequence[ind] for ind in output_dict["wordpiece_offsets"][i]])
        return predicted_tags

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on the `"class_probabilities"` entry of
        `output_dict`. The constraint simply specifies that the output tags must be a
        valid BIO sequence. We add `"wordpiece_tags"` and `"tags"` keys to the
        dictionary with the result.

        NOTE(review): `forward` of this class does not write
        `"class_probabilities"` into its output dictionary, and the tags here are
        looked up in the `labels` namespace while the rest of the class uses
        `span_labels` - confirm whether this method is still usable as-is.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. Secondly,
        the indices we use to recover words from the wordpieces are taken from
        `output_dict["wordpiece_offsets"]`; they should be start offsets (the first
        wordpiece of every word), as otherwise we might get an ill-formed BIO
        sequence when we select out the word tags from the wordpiece tags, e.g. an
        I-V tag that is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]

        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]
            wordpiece_tags.append(tags)
            # Keep only the tag of the wordpiece each offset points to.
            word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        """
        Return the "overall" entries of the span metric.

        An empty dictionary is returned when the span metric is ignored or when
        no scorer was created (`srl_eval_path=None`).
        """
        if self.ignore_span_metric or self.span_metric is None:
            return {}

        metric_dict = self.span_metric.get_metric(reset=reset)
        # This can be a lot of metrics, as there are 3 per class.
        # we only really care about the overall metrics, so we filter for them here.
        return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if (i != j and label[0] == "I"
                        and previous_label != "B" + label[1:]):
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        start_transitions = torch.zeros(len(all_labels))

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")
        return start_transitions

    default_predictor = "semantic_role_labeling"