def forward(
    self,  # type: ignore
    text: TextFieldTensors,
    spans: torch.IntTensor,
    span_labels: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Score candidate mention spans, prune them, and predict an antecedent for
    every span that survives pruning.

    # Parameters

    text : `TextFieldTensors`, required.
        The output of a `TextField` representing the text of
        the document.
    spans : `torch.IntTensor`, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
        indices into the text of the document. Padding spans are marked with -1.
    span_labels : `torch.IntTensor`, optional (default = None).
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
        When given, the loss is computed and the metrics are updated.
    metadata : `List[Dict[str, Any]]`, optional (default = None).
        A metadata dictionary for each instance in the batch. The "original_text" key is
        read here to fill the "document" output; the whole list is also handed to the
        mention-recall and CoNLL metrics whenever `span_labels` is given.
        NOTE(review): the metrics are called with `metadata` as-is, so callers that pass
        `span_labels` are presumably expected to pass `metadata` too - confirm.

    # Returns

    An output dictionary consisting of:

    top_spans : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, max_antecedents)` representing for
        each top span the index (with respect to top_spans) of the possible antecedents the
        model considered.
    predicted_antecedents : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised. Only present when `span_labels` is given.
    document : `List`, optional
        The "original_text" entry of every metadata dictionary. Only present when
        `metadata` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = spans.size(0)
    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text)

    # Shape: (batch_size, num_spans)
    # Indexing with `0` already removes the last dimension. A trailing
    # `.squeeze(-1)` must NOT be applied here: it would also drop the span
    # dimension whenever num_spans == 1 and leave the mask with a different
    # shape than the mention scores.
    span_mask = spans[:, :, 0] >= 0

    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, num_spans)

    # Shape: (batch_size, num_spans)
    span_mention_scores = self._mention_scorer(
        self._mention_feedforward(span_embeddings)
    ).squeeze(-1)
    # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
    top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
        span_mention_scores, span_mask, num_spans_to_keep
    )

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
    # Shape: (batch_size, num_spans_to_keep, embedding_size)
    top_span_embeddings = util.batched_index_select(
        span_embeddings, top_span_indices, flat_top_span_indices
    )

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.
    if self._coarse_to_fine:
        pruned_antecedents = self._coarse_to_fine_pruning(
            top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
        )
    else:
        pruned_antecedents = self._distance_pruning(
            top_span_embeddings, top_span_mention_scores, max_antecedents
        )

    # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
    (
        top_partial_coreference_scores,
        top_antecedent_mask,
        top_antecedent_offsets,
        top_antecedent_indices,
    ) = pruned_antecedents

    flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
        top_antecedent_indices, num_spans_to_keep
    )

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    top_antecedent_embeddings = util.batched_index_select(
        top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
    )
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        top_span_embeddings,
        top_antecedent_embeddings,
        top_partial_coreference_scores,
        top_antecedent_mask,
        top_antecedent_offsets,
    )

    # Higher-order inference: each extra round replaces every span embedding
    # with a gated mix of itself and the attention-weighted average of its
    # antecedents, then rescores all span pairs.
    for _ in range(self._inference_order - 1):
        dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
        top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        attention_weight = util.masked_softmax(
            coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
        top_antecedent_with_dummy_embeddings = torch.cat(
            [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
        )
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        attended_embeddings = util.weighted_sum(
            top_antecedent_with_dummy_embeddings, attention_weight
        )
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = self._span_updating_gated_sum(
            top_span_embeddings, attended_embeddings
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": top_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        # Shape: (batch_size, num_spans_to_keep, 1)
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        antecedent_labels = util.batched_index_select(
            pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
        ).squeeze(-1)
        antecedent_labels = util.replace_masked_values(
            antecedent_labels, top_antecedent_mask, -100
        )

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels
        )
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask.unsqueeze(-1)
        )
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(
            top_spans, top_antecedent_indices, predicted_antecedents, metadata
        )

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def _coarse_to_fine_pruning(
    self,
    top_span_embeddings: torch.FloatTensor,
    top_span_mention_scores: torch.FloatTensor,
    top_span_mask: torch.BoolTensor,
    max_antecedents: int,
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
    """
    Build every span/antecedent pairing among the top spans and keep, for each
    span, only the `max_antecedents` best candidates according to a cheap
    bilinear score.

    # Parameters

    top_span_embeddings: torch.FloatTensor, required.
        Embeddings of the spans kept by mention pruning.
        (batch_size, num_spans_to_keep, embedding_size).
    top_span_mention_scores: torch.FloatTensor, required.
        Mention score of each kept span.
        (batch_size, num_spans_to_keep).
    top_span_mask: torch.BoolTensor, required.
        Mask over the kept spans.
        (batch_size, num_spans_to_keep).
    max_antecedents: int, required.
        How many antecedents survive for each span.

    # Returns

    top_partial_coreference_scores: torch.FloatTensor
        For every retained span/antecedent pair, the sum of both spans'
        mention scores and a bilinear interaction term. It is "partial"
        because the feed-forward pair term of the full coreference score
        is not included here.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_mask: torch.BoolTensor
        Marks which of the retained antecedents are valid; spans differ in
        how many valid antecedents they have.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_offsets: torch.LongTensor
        Distance from each span to each retained antecedent, counted in
        kept spans rather than in words.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_indices: torch.LongTensor
        Index of each retained antecedent with respect to the top spans.
        (batch_size, num_spans_to_keep, max_antecedents)
    """
    embedding_dims = top_span_embeddings.size()
    batch_size = embedding_dims[0]
    num_spans_to_keep = embedding_dims[1]
    device = util.get_device_of(top_span_embeddings)

    # Only the mask (third return value) is needed; every kept span is a
    # candidate antecedent at this point.
    # Shape: (1, num_spans_to_keep, num_spans_to_keep)
    _, _, antecedent_validity = self._generate_valid_antecedents(
        num_spans_to_keep, num_spans_to_keep, device
    )

    # Mention scores laid out once along each pair axis so that they
    # broadcast against the pairwise bilinear term.
    scores_along_columns = top_span_mention_scores.unsqueeze(1)
    scores_along_rows = top_span_mention_scores.unsqueeze(2)
    projected_embeddings = self._coarse2fine_scorer(top_span_embeddings)
    pairwise_bilinear = torch.matmul(
        top_span_embeddings, projected_embeddings.transpose(1, 2)
    )
    # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
    coarse_pair_scores = scores_along_columns + scores_along_rows + pairwise_bilinear

    # A pair is eligible only if the span itself is unmasked and the
    # antecedent position is a valid one for it.
    # Shape: (batch_size, num_spans_to_keep, num_spans_to_keep); broadcast op
    eligible_pairs = top_span_mask.unsqueeze(-1) & antecedent_validity

    # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 3 tensors
    kept_scores, kept_mask, kept_indices = util.masked_topk(
        coarse_pair_scores, eligible_pairs, max_antecedents
    )

    # Pairwise difference of span positions (row position minus column position).
    # Shape: (num_spans_to_keep, num_spans_to_keep); broadcast op
    span_positions = util.get_range_vector(num_spans_to_keep, device)
    pairwise_offsets = span_positions.unsqueeze(-1) - span_positions.unsqueeze(0)

    # `batched_index_select` selects along dimension 1 of a 3D target, so each
    # (batch, span) row of the offset matrix becomes its own "batch" entry and
    # the result is folded back afterwards.
    offsets_per_row = (
        pairwise_offsets.unsqueeze(0)
        .expand(batch_size, num_spans_to_keep, num_spans_to_keep)
        .reshape(batch_size * num_spans_to_keep, num_spans_to_keep, 1)
    )
    indices_per_row = kept_indices.view(-1, max_antecedents)
    selected_offsets = util.batched_index_select(offsets_per_row, indices_per_row)
    kept_offsets = selected_offsets.reshape(batch_size, num_spans_to_keep, max_antecedents)

    return kept_scores, kept_mask, kept_offsets, kept_indices
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    verb_indicator: torch.Tensor,
    sentence_end: torch.LongTensor,
    spans: torch.LongTensor,
    span_labels: torch.LongTensor,
    metadata: List[Any],
    tags: torch.LongTensor = None,
):
    """
    Embed the sentence with BERT, keep the highest-scoring candidate spans and
    classify each kept span given the embedding of the verb.

    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
        indexes wordpieces from the BERT vocabulary.
    verb_indicator: `torch.LongTensor`, required.
        An integer `SequenceFeatureField` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate (the verb embedding
        is then replaced by zeros).
    sentence_end : `torch.LongTensor`, required.
        Accepted for interface compatibility; not read by this method.
    spans : `torch.LongTensor`, required.
        Candidate spans of shape (batch_size, num_spans, 2). Padding spans are marked
        with -1.
    span_labels : `torch.LongTensor`, required.
        Gold label of every candidate span, shape (batch_size, num_spans). It is always
        gathered for the kept spans and is the target of the loss.
    metadata : `List[Dict[str, Any]]`, required.
        One dictionary per instance. The 'words', 'verb' and 'offsets' keys are always
        read; 'verb_index' and 'gold_tags' are read when the span metric is updated.
    tags : `torch.LongTensor`, optional (default = `None`)
        Gold tag sequence. In this method it only acts as a switch: when it is not
        `None` the loss is computed (against `span_labels`) and, outside of training,
        the span metric is updated.

    # Returns

    An output dictionary consisting of:

    mask : `torch.Tensor`
        The text field mask of `tokens`.
    words, verb, wordpiece_offsets : `List`
        The 'words', 'verb' and 'offsets' metadata entries of every instance.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised. Only present when `tags` is given.
    """
    mask = get_text_field_mask(tokens)
    # `verb_indicator` is not passed to BERT as `token_type_ids`; it is only
    # used further down to pick out the verb's embedding.
    bert_embeddings, _ = self.bert_model(
        input_ids=util.get_token_ids_from_text_field_tensors(tokens),
        attention_mask=mask,
    )

    # Shape: (batch_size, num_spans)
    # Indexing with `0` already removes the last dimension. A trailing
    # `.squeeze(-1)` must NOT be applied here: it would also drop the span
    # dimension whenever num_spans == 1 and leave the mask with a different
    # shape than the mention scores.
    span_mask = spans[:, :, 0] >= 0

    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    embedded_text_input = self.embedding_dropout(bert_embeddings)
    _, sequence_length, _ = embedded_text_input.size()

    # Shape: (batch_size, num_spans, embedding_size)
    # NOTE(review): this result is only consumed when there is no context
    # layer; it is still computed unconditionally, as before.
    attended_span_embeddings = self._attentive_span_extractor(bert_embeddings, spans)

    if self._context_layer is not None:
        contextualized_embeddings = self._context_layer(embedded_text_input, mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans
        )
        # Only the endpoint representation is used; the attended embeddings
        # are not concatenated onto it.
        span_embeddings = endpoint_span_embeddings
    else:
        span_embeddings = attended_span_embeddings

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * sequence_length))
    num_spans = spans.shape[1]
    num_spans_to_keep = min(num_spans_to_keep, num_spans)

    # Shape: (batch_size, num_spans)
    span_mention_scores = self._mention_scorer(
        self._mention_feedforward(span_embeddings)
    ).squeeze(-1)
    # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
    top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
        span_mention_scores, span_mask, num_spans_to_keep
    )

    # Gather the embedding at the verb position: argmax over the indicator
    # gives one token index per instance, repeated over the embedding
    # dimension so it can drive `torch.gather` along the token axis.
    verb_index = (
        verb_indicator.argmax(1)
        .unsqueeze(1)
        .unsqueeze(2)
        .repeat(1, 1, embedded_text_input.shape[-1])
    )
    verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
    assert len(verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
    verb_embeddings = verb_embeddings.squeeze(1)
    # argmax returns a position even for an all-zero indicator, so instances
    # without any verb get a zero vector instead of an arbitrary token.
    has_verb = (verb_indicator.sum(1, keepdim=True) > 0).repeat(
        1, verb_embeddings.shape[-1]
    )
    verb_embeddings = torch.where(
        has_verb, verb_embeddings, torch.zeros_like(verb_embeddings)
    )

    # Restrict embeddings, span boundaries and gold labels to the kept spans.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)
    span_embeddings = util.batched_index_select(
        span_embeddings, top_span_indices, flat_top_span_indices
    )
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
    top_span_labels = util.batched_index_select(
        span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
    ).squeeze(-1)

    # Every kept span is classified from its own embedding concatenated with
    # the (shared) verb embedding of its instance.
    concatenated_span_embeddings = torch.cat(
        (
            span_embeddings,
            verb_embeddings.unsqueeze(1).repeat(1, span_embeddings.shape[1], 1),
        ),
        dim=2,
    )
    hidden = self.hidden_layer(concatenated_span_embeddings)
    predictions = self.output_layer(hidden)
    # Prepend a constant zero logit, so class index 0 always scores 0 and the
    # output layer only produces the remaining classes.
    predictions = torch.cat(
        (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1
    )

    output_dict = {}
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.make_output_human_readable.
    output_dict["mask"] = mask
    # We add in the offsets here so we can compute the un-wordpieced tags.
    words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"]) for x in metadata])
    output_dict["words"] = list(words)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(offsets)

    if tags is not None:
        # Masked mean of the per-span loss over the kept spans.
        # NOTE(review): this relies on `self._ce_loss` returning one loss value
        # per span (no reduction), since it is multiplied by the flat mask - confirm.
        loss = (
            self._ce_loss(
                predictions.view(-1, predictions.shape[-1]), top_span_labels.view(-1)
            )
            * top_span_mask.float().view(-1)
        ).sum() / top_span_mask.float().sum()

        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [
                example_metadata["verb_index"] for example_metadata in metadata
            ]
            batch_sentences = [example_metadata["words"] for example_metadata in metadata]
            # Turn the span predictions into BIO tags so that they can be
            # compared with the gold tags in CoNLL format.
            batch_bio_predicted_tags = self.get_tags(
                top_spans, predictions, mask.shape[1], top_span_mask, output_dict
            )
            from allennlp_models.structured_prediction.models.srl import (
                convert_bio_tags_to_conll_format,
            )

            batch_conll_predicted_tags = [
                convert_bio_tags_to_conll_format(bio_tags)
                for bio_tags in batch_bio_predicted_tags
            ]
            batch_bio_gold_tags = [
                example_metadata["gold_tags"] for example_metadata in metadata
            ]
            batch_conll_gold_tags = [
                convert_bio_tags_to_conll_format(bio_tags) for bio_tags in batch_bio_gold_tags
            ]
            self.span_metric(
                batch_verb_indices,
                batch_sentences,
                batch_conll_predicted_tags,
                batch_conll_gold_tags,
            )
        output_dict["loss"] = loss
    return output_dict