def test_flattened_index_select(self):
    """``util.flattened_index_select`` gathers rows along dim 1 for each batch element.

    ``targets[b, j, :]`` is filled with the value ``j`` (doubled for batch
    element 1), so every selected slice can be checked against the index that
    produced it.
    """
    index_tensor = torch.tensor(numpy.array([[1, 2], [3, 4]]), dtype=torch.long)
    targets = torch.ones([2, 6, 3]).cumsum(1) - 1
    # Double the second batch element's values so the two batches differ.
    targets[1, :, :] *= 2

    selected = util.flattened_index_select(targets, index_tensor)
    assert list(selected.size()) == [2, 2, 2, 3]

    ones = numpy.ones([3])
    # (batch, row, column) position in ``selected`` -> expected fill value.
    expected_values = {
        (0, 0, 0): 1,
        (0, 0, 1): 2,
        (0, 1, 0): 3,
        (0, 1, 1): 4,
        (1, 0, 0): 2,
        (1, 0, 1): 4,
        (1, 1, 0): 6,
        (1, 1, 1): 8,
    }
    for (batch, row, column), value in expected_values.items():
        numpy.testing.assert_array_equal(
            selected[batch, row, column, :].data.numpy(), ones * value
        )

    # Only 2D index tensors are accepted.
    with pytest.raises(ConfigurationError):
        util.flattened_index_select(targets, torch.ones([3, 4, 5]))
def inference_coref(self, batch, embedded_text_input_relation, mask):
    """Run the coreference sub-model (``self.model._tagger_coref``) without computing a loss.

    Parameters
    ----------
    batch : dict
        Must contain ``"spans"``, the candidate span indices
        (batch_size, num_spans, 2); padding spans have a negative start index.
    embedded_text_input_relation : torch.Tensor
        Token embeddings fed to the context layer and the attentive span
        extractor; presumably (batch_size, document_length, embedding_size)
        -- TODO confirm against the caller.
    mask : torch.Tensor
        Token mask; ``mask.size(1)`` is used as the document length.

    Returns
    -------
    dict
        ``"top_spans"``, ``"antecedent_indices"`` and ``"predicted_antecedents"``
        (``-1`` means no predicted antecedent).
    """
    submodel = self.model._tagger_coref

    # Fast inference of coreference.
    spans = batch["spans"]
    document_length = mask.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, num_spans). No ``squeeze(-1)``: it would drop the span
    # axis when num_spans == 1 and hand the pruner a mask of the wrong rank.
    span_mask = (spans[:, :, 0] >= 0).float()
    # Padding spans carry -1 indices; clamp them to 0 so the extractors can index.
    spans = F.relu(spans.float()).long()

    encoded_text_coref = submodel._context_layer(embedded_text_input_relation, mask)
    endpoint_span_embeddings = submodel._endpoint_span_extractor(encoded_text_coref, spans)
    attended_span_embeddings = submodel._attentive_span_extractor(
        embedded_text_input_relation, spans)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(submodel._spans_per_word * document_length))
    # The pruned-span mask is not needed at inference time.
    (top_span_embeddings, _, top_span_indices,
     top_span_mention_scores) = submodel._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)

    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)

    max_antecedents = min(submodel._max_antecedents, num_spans_to_keep)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        submodel._generate_valid_antecedents(
            num_spans_to_keep, max_antecedents, util.get_device_of(mask))

    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
    span_pair_embeddings = submodel._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets)
    coreference_scores = submodel._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    _, predicted_antecedents = coreference_scores.max(2)
    # Index 0 of the scores is the "no antecedent" class; shifting by one makes
    # the remaining indices line up with ``antecedent_indices``.
    predicted_antecedents -= 1

    return {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
def predict_labels_doc(self, output_dict):
    """Decode antecedent predictions and, when gold labels exist, add the coref loss.

    ``output_dict`` must provide ``"coref_labels"`` (may be ``None``),
    ``"coreference_scores"``, ``"top_span_indices"``, ``"flat_top_span_indices"``,
    ``"antecedent_indices"``, ``"valid_antecedent_log_mask"``, ``"top_spans"``,
    ``"top_span_mask"``, ``"metadata"`` and ``"sentence_lengths"``.

    Always writes ``"predicted_antecedents"``. Only when ``"coref_labels"`` is
    not ``None`` does it update the mention-recall / CoNLL metrics and write
    ``"loss"``. The same dict is mutated and returned.
    """
    # One gold cluster id per candidate span (gathered below via top_span_indices).
    gold_cluster_ids = output_dict["coref_labels"]
    scores = output_dict["coreference_scores"]

    # Column 0 of the scores is the "no antecedent" option, so after the shift
    # -1 means "no link" and k >= 0 indexes the candidate antecedents.
    predicted = scores.max(2)[1]
    predicted -= 1
    output_dict["predicted_antecedents"] = predicted

    span_indices = output_dict["top_span_indices"]
    flat_span_indices = output_dict["flat_top_span_indices"]
    antecedent_indices = output_dict["antecedent_indices"]
    antecedent_log_mask = output_dict["valid_antecedent_log_mask"]
    kept_spans = output_dict["top_spans"]
    kept_span_mask = output_dict["top_span_mask"]
    metadata = output_dict["metadata"]
    sentence_lengths = output_dict["sentence_lengths"]

    if gold_cluster_ids is None:
        return output_dict

    # Gold cluster ids of the spans that survived pruning.
    pruned_gold_labels = util.batched_index_select(
        gold_cluster_ids.unsqueeze(-1), span_indices, flat_span_indices)
    antecedent_labels = util.flattened_index_select(
        pruned_gold_labels, antecedent_indices).squeeze(-1)
    # The float log-mask is cast to long before being added; an integer
    # wrap-around happens here (inherited from the original code).
    antecedent_labels += antecedent_log_mask.long()

    # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
    gold_antecedent_labels = self._compute_antecedent_gold_labels(
        pruned_gold_labels, antecedent_labels)

    # Negative marginal log-likelihood over every gold-consistent antecedent.
    log_probs = util.masked_log_softmax(scores, kept_span_mask)
    gold_log_probs = log_probs + gold_antecedent_labels.log()
    loss = -util.logsumexp(gold_log_probs).sum()

    # Convert clusters to the layout the AllenNLP coref metrics expect.
    evaluation_metadata = self._make_evaluation_metadata(metadata, sentence_lengths)
    self._mention_recall(kept_spans, evaluation_metadata)
    # TODO(dwadden) Shouldnt need to do the unsqueeze here; figure out what's happening.
    self._conll_coref_scores(
        kept_spans, antecedent_indices.unsqueeze(0), predicted, evaluation_metadata)

    output_dict["loss"] = loss
    return output_dict
def get_coref_scores(self, top_span_embeddings, top_span_mention_scores,
                     valid_antecedent_indices, valid_antecedent_offsets,
                     valid_antecedent_log_mask):
    """Score every (span, candidate antecedent) pair.

    Gathers the embeddings and mention scores of each span's candidate
    antecedents, builds pair embeddings with ``self._compute_span_pair_embeddings``,
    and scores them with ``self._compute_coreference_scores``.

    Returns
    -------
    torch.Tensor
        Shape ``(batch_size, num_spans_to_keep, 1 + max_antecedents)``.
    """
    def gather(source):
        # Select, for every kept span, the rows of its candidate antecedents.
        return util.flattened_index_select(source, valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    antecedent_embeddings = gather(top_span_embeddings)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    antecedent_mention_scores = gather(top_span_mention_scores).squeeze(-1)

    pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, antecedent_embeddings, valid_antecedent_offsets)
    return self._compute_coreference_scores(
        pair_embeddings, top_span_mention_scores,
        antecedent_mention_scores, valid_antecedent_log_mask)
def score_spans_if_labels(
        self,
        output_dict,
        span_labels,
        metadata,
        top_span_indices,
        flat_top_span_indices,
        top_span_mask,
        top_spans,
        valid_antecedent_indices,
        valid_antecedent_log_mask,
        coreference_scores,
        predicted_antecedents,
):
    """Attach the coreference loss and update metrics when gold labels are available.

    When ``span_labels`` is not ``None``, computes the negative marginal
    log-likelihood of all gold-consistent antecedents, updates
    ``self._mention_recall`` and ``self._conll_coref_scores`` and stores the loss
    under ``output_dict["loss"]``. When ``metadata`` is not ``None``, stores each
    entry's ``"original_text"`` under ``output_dict["document"]``.

    Returns the same (mutated) ``output_dict``.
    """
    if span_labels is not None:
        # Gold cluster id of every span that survived pruning.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # A span may have several earlier mentions in its gold cluster, and any
        # of them is an acceptable antecedent. The loss therefore marginalises:
        # it is the negative log of the total probability mass assigned to all
        # antecedents consistent with the gold clustering.
        log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        gold_log_probs = log_probs + gold_antecedent_labels.log()
        marginal_nll = -util.logsumexp(gold_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)
        output_dict["loss"] = marginal_nll

    if metadata is not None:
        output_dict["document"] = [entry["original_text"] for entry in metadata]
    return output_dict
def _distance_pruning(
    self,
    top_span_embeddings: torch.FloatTensor,
    top_span_mention_scores: torch.FloatTensor,
    max_antecedents: int,
) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
    """
    Generate antecedents for each span, keeping only the `max_antecedents` closest.

    Pruning here is purely positional: the candidates for a span are the spans
    immediately before it among the kept spans, measured by the number of
    intervening kept spans rather than by word distance.

    # Parameters

    top_span_embeddings: torch.FloatTensor, required.
        Embeddings of the kept spans, (batch_size, num_spans_to_keep, embedding_size).
        Only its size along dim 1 and its device are used.
    top_span_mention_scores: torch.FloatTensor, required.
        Mention scores of the kept spans, (batch_size, num_spans_to_keep).
    max_antecedents: int, required.
        Maximum number of antecedents kept per span.

    # Returns

    top_partial_coreference_scores: torch.FloatTensor
        Sum of the span's mention score and each antecedent's mention score.
        "Partial" because the pairwise interaction term of the full coreference
        score is not included. (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_mask: torch.BoolTensor
        Which antecedent slots are valid; spans near the start of the document
        have fewer real antecedents. (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_offsets: torch.LongTensor
        Distance from the span to each antecedent, counted in kept spans.
        (batch_size, num_spans_to_keep, max_antecedents)
    top_antecedent_indices: torch.LongTensor
        Index of every antecedent with respect to the kept spans.
        (batch_size, num_spans_to_keep, max_antecedents)
    """
    num_spans_to_keep = top_span_embeddings.size(1)
    device = util.get_device_of(top_span_embeddings)

    # The kept spans are in document order, so antecedent structure depends
    # only on a span's position, not on the batch element.
    # Shapes: (num_spans_to_keep, max_antecedents), (1, max_antecedents),
    #         (1, num_spans_to_keep, max_antecedents)
    antecedent_indices, antecedent_offsets, antecedent_mask = self._generate_valid_antecedents(  # noqa
        num_spans_to_keep, max_antecedents, device
    )

    # Mention score of every candidate antecedent.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores.unsqueeze(-1), antecedent_indices
    ).squeeze(-1)

    # Span score broadcast across its antecedent slots, plus antecedent score.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    partial_scores = top_span_mention_scores.unsqueeze(-1) + antecedent_mention_scores

    # Broadcast the batch-independent tensors up to the full score shape.
    return (
        partial_scores,
        antecedent_mask.expand_as(partial_scores),
        antecedent_offsets.unsqueeze(0).expand_as(partial_scores),
        antecedent_indices.unsqueeze(0).expand_as(partial_scores),
    )
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Passed to the mention-recall and CoNLL coreference metrics when ``span_labels``
        is given. When present, each entry's ``"original_text"`` is returned under
        ``"document"``.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised (only when ``span_labels`` is given).
    document : ``List``, optional
        ``x["original_text"]`` for each entry of ``metadata`` (only when ``metadata`` is given).
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans). No ``squeeze(-1)``: it would drop the span
    # axis when num_spans == 1 and hand the pruner a mask of the wrong rank.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                        span_mask,
                                                                        num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Each span can only have prior spans as antecedents, and we only consider up
    # to max_antecedents prior spans. The mapping from a span's index to its
    # allowed antecedents is independent of the batch dimension - it's just a
    # function of the span's position in top_spans, which are in document order.
    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
) -> Dict[str, torch.Tensor]:  # add matrix from datareader
    # pylint: disable=arguments-differ
    """
    Jointly score coreference, relation extraction and (optionally) NER for a batch of documents.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document.
    metadata : ``List[Dict[str, Any]]``, required.
        One dict per document. Each must contain ``"doc_tokens"`` (a list of tokenized
        sentences; the longest sentence sets the relation-extraction pruning budget) and
        ``"original_text"`` (returned under ``"document"``). Also passed to the mention-recall
        and CoNLL coreference metrics when ``span_labels`` is given.
    doc_span_offsets : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1) indexing into
        ``spans`` for each sentence; negative entries are padding.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    doc_truth_spans : ``torch.IntTensor``, optional.
        A tensor of shape (batch_size, max_sentences, max_truth_spans, 1). Not used in this
        method's body.
    doc_spans_in_truth : ``torch.IntTensor``, optional.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1) mapping each
        sentence span to its row in ``doc_relation_labels``; used when ``doc_relations`` is given.
    doc_relation_labels : ``torch.Tensor``, optional.
        A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans) holding
        gold relation labels between truth spans; used when ``doc_relations`` is given.
    truth_spans : ``List[Set[Tuple[int, int]]]``, optional.
        Passed to the relation mention-recall metric.
    doc_relations : optional.
        When not ``None``, enables the relation-extraction loss; also passed to
        ``self._compute_relex_metrics``.
    doc_ner_labels : ``torch.IntTensor``, optional.
        When not ``None``, enables the NER loss; presumably
        (batch_size, max_sentences, max_spans_per_sentence) — TODO confirm.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    document : ``List``
        ``x["original_text"]`` for each entry of ``metadata``.
    relex_scores, top_relex_spans :
        Relation scores and the pruned per-sentence spans they refer to.
    coref_loss, relex_loss, ner_scores, ner_loss : optional
        Present when the corresponding gold labels are given.
    loss : ``torch.FloatTensor``, optional
        Weighted sum of the task losses; present only if at least one task loss was computed.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = len(spans)
    document_length = text_embeddings.size(1)
    max_sentence_length = max(
        len(sentence) for document in metadata for sentence in document['doc_tokens'])
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans). No ``squeeze(-1)``: it would drop the span
    # axis when num_spans == 1 and hand the pruner a mask of the wrong rank.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # TODO features dropout
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_relex_spans_to_keep = int(math.floor(self._relex_spans_per_word * max_sentence_length))

    # Shapes:
    # (batch_size, num_spans_to_keep, span_dim),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here we need to select spans
    # for each element in the batch, so the indices are shifted by batch offset.
    # Precomputed to make the util.batched_index_select calls below cheaper.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Each span can only have prior spans as antecedents, and we only consider up
    # to max_antecedents prior spans. The mapping from a span's index to its
    # allowed antecedents is independent of the batch dimension - it's just a
    # function of the span's position in top_spans, which are in document order.
    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = dict()
    output_dict["top_spans"] = top_spans
    output_dict["antecedent_indices"] = valid_antecedent_indices
    output_dict["predicted_antecedents"] = predicted_antecedents
    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]

    # Stays the int 0 unless at least one task loss is added below.
    loss = 0

    # Shape: (batch_size, max_sentences, max_spans)
    doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
    # Shape: (batch_size, max_sentences, num_spans, span_dim)
    doc_span_embeddings = util.batched_index_select(
        span_embeddings, doc_span_offsets.squeeze(-1).long().clamp(min=0))

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    pruned = self._relex_mention_pruner(
        doc_span_embeddings,
        doc_span_mask,
        num_items_to_keep=num_relex_spans_to_keep,
        pass_through=['num_items_to_keep'])
    (top_relex_span_embeddings, top_relex_span_mask,
     top_relex_span_indices, top_relex_span_mention_scores) = pruned

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

    # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)
    # TODO do we need for a mask?
    # ``.long()`` matches the index dtype used for ``doc_span_embeddings`` above.
    doc_spans = util.batched_index_select(
        spans, doc_span_offsets.squeeze(-1).long().clamp(min=0))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
    top_relex_spans = nd_batched_index_select(doc_spans, top_relex_span_indices)

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
    (relex_span_pair_embeddings,
     relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
         top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
    relex_scores = self._compute_relex_scores(
        relex_span_pair_embeddings, top_relex_span_mention_scores)
    output_dict['relex_scores'] = relex_scores
    output_dict['top_relex_spans'] = top_relex_spans

    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)

        antecedent_labels_ = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs)
        # Drop the contribution of masked (padding) top spans before summing.
        negative_marginal_log_likelihood *= top_span_mask.squeeze(-1).float()
        negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        coref_loss = negative_marginal_log_likelihood
        output_dict['coref_loss'] = coref_loss
        loss += self._loss_coref_weight * coref_loss

    if doc_relations is not None:
        # The adjacency matrix for relation extraction is row/column sparse: only a
        # few rows and columns are non-zero, and those are dense. Using the indices
        # of the truth spans, the prediction matrix is mapped onto the truth matrix.
        # TODO Add teacher forcing support.

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        compressed_indices = nd_batched_padded_index_select(
            doc_spans_in_truth, top_relex_span_indices)

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
        gold_pruned_rows = nd_batched_padded_index_select(
            doc_relation_labels, compressed_indices.squeeze(-1), padding_value=0)
        gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3, 2).contiguous()

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
        # NOTE(review): an earlier note here said "pad with epsilon", but padding_value is 0
        # — confirm the intended padding.
        gold_pruned_matrices = nd_batched_padded_index_select(
            gold_pruned_rows, compressed_indices.squeeze(-1), padding_value=0)
        gold_pruned_matrices = gold_pruned_matrices.permute(0, 1, 3, 2).contiguous()

        # TODO log_mask relex score before passing
        relex_loss = nd_cross_entropy_with_logits(
            relex_scores, gold_pruned_matrices, relex_span_pair_mask)
        output_dict['relex_loss'] = relex_loss

        self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2), truth_spans)
        self._compute_relex_metrics(output_dict, doc_relations)

        loss += self._loss_relex_weight * relex_loss

    if doc_ner_labels is not None:
        # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
        ner_scores = self._ner_scorer(doc_span_embeddings)
        output_dict['ner_scores'] = ner_scores

        ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels, doc_span_mask)
        output_dict['ner_loss'] = ner_loss
        loss += self._loss_ner_weight * ner_loss

    # ``loss`` is still the int 0 unless at least one task loss was added above.
    if not isinstance(loss, int):
        output_dict["loss"] = loss

    return output_dict
def forward(
    self,  # type: ignore
    text: Dict[str, torch.LongTensor],
    spans: torch.IntTensor,
    span_labels: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict an antecedent for every pruned candidate span and, when gold
    span labels are given, compute the negative marginal log-likelihood loss.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``
        Token indices, passed to ``self._text_field_embedder``.
    spans : ``torch.IntTensor``
        Candidate spans, shape (batch_size, num_spans, 2). Spans whose start
        index is negative are treated as padding.
    span_labels : ``torch.IntTensor``, optional
        One gold label per candidate span; when given, ``"loss"`` is added to
        the output. (Presumably a cluster id, -1 for non-mentions — confirm.)
    metadata : ``List[Dict[str, Any]]``, optional
        One dict per batch instance. ``"clusters"`` is read when
        ``self._use_gold_mentions`` is set, and ``"original_text"`` is copied
        into the output.

    Returns
    -------
    ``Dict[str, torch.Tensor]`` with ``"top_spans"``, ``"antecedent_indices"``
    and ``"predicted_antecedents"`` (-1 means "no antecedent"), plus
    ``"loss"`` when ``span_labels`` is given and ``"document"`` when
    ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))
    document_length = text_embeddings.size(1)
    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Length of the span dimension that ``top_span_indices`` index into; it is
    # the per-instance stride used by ``flatten_and_batch_shift_indices``, so
    # it must be ``spans.size(1)`` regardless of how many spans are masked.
    num_spans = spans.size(1)

    # Shape: (batch_size, num_spans)
    if self._use_gold_mentions:
        # Keep only the candidate spans whose (start, end) exactly equals a
        # gold mention of the *same* instance.
        span_mask = torch.zeros(spans.size(0), num_spans, device=spans.device)
        for batch_index, instance_metadata in enumerate(metadata):
            gold_mentions = [
                pair
                for cluster in instance_metadata["clusters"]
                for pair in cluster
            ]
            if not gold_mentions:
                # No gold mentions: every span of this instance stays masked.
                continue
            # Shape: (1, num_gold_mentions, 2)
            gold = torch.as_tensor(
                gold_mentions, dtype=torch.long, device=spans.device
            ).unsqueeze(0)
            # Shape: (num_spans, num_gold_mentions, 2)
            same = spans[batch_index].long().unsqueeze(1) == gold
            # Combine the start and end matches with ``&``. Adding the two
            # comparison results and testing ``== 2`` is always False on bool
            # tensors (PyTorch >= 1.2), which would mask out every span.
            exact = same[:, :, 0] & same[:, :, 1]
            span_mask[batch_index] = (exact.sum(-1) > 0).float()
    else:
        # No ``squeeze`` here: it would drop the span dimension when
        # ``num_spans == 1``.
        span_mask = (spans[:, :, 0] >= 0).float()

    # Clamp the negative padding indices to 0 so they can be used for indexing.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(
        spans, top_span_indices, flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)
    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    (valid_antecedent_indices, valid_antecedent_offsets,
     valid_antecedent_log_mask) = self._generate_valid_antecedents(
         num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings,
        top_span_mention_scores,
        candidate_antecedent_mention_scores,
        valid_antecedent_log_mask,
    )

    # Column 0 is the "no antecedent" class, so shifting by one yields -1 for
    # "no antecedent" and an index into ``valid_antecedent_indices`` otherwise.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask)
        correct_antecedent_log_probs = (
            coreference_log_probs + gold_antecedent_labels.log())
        # Shape: (batch_size, num_spans_to_keep)
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs)
        # Exclude masked (padding / non-gold) top spans from the loss.
        negative_marginal_log_likelihood = (
            negative_marginal_log_likelihood
            * top_span_mask.squeeze(-1).float()).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(
            top_spans, valid_antecedent_indices, predicted_antecedents,
            metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def get_coreference_scores(
    self,
    spans,
    top_span_mention_scores,
    num_spans_to_keep,
    top_span_indices,
    flat_top_span_indices,
    top_span_mask,
    top_span_embeddings,
    text_mask,
    get_scores,
):
    """
    Score every pruned span against its candidate antecedents.

    The caller has already pruned the candidate spans. This method gathers the
    surviving spans and pairs each one with at most ``self._max_antecedents``
    earlier kept spans. It then scores the pairs and takes the arg-max.

    Parameters
    ----------
    spans :
        All candidate spans; rows are gathered with ``top_span_indices``.
    top_span_mention_scores, top_span_embeddings :
        Mention scores / embeddings of the kept spans.
    num_spans_to_keep :
        Number of kept spans per instance.
    top_span_indices, flat_top_span_indices :
        Positions of the kept spans within ``spans`` (batched and flattened).
    top_span_mask :
        Accepted for interface compatibility; not read here.
    text_mask :
        Only used to pick the device for the antecedent index tensors.
    get_scores :
        When truthy, ``coreference_scores`` and ``top_span_indices`` are also
        copied into the returned ``output_dict``.

    Returns
    -------
    A dict with ``"output_dict"`` (``top_spans``, ``antecedent_indices``,
    ``predicted_antecedents`` with -1 meaning "no antecedent"),
    ``"ant_mask"`` (the valid-antecedent log mask) and
    ``"coreference_scores"``.
    """
    # Gather the (start, end) pairs of the spans that survived pruning.
    # Shape: (batch_size, num_spans_to_keep, 2)
    kept_spans = util.batched_index_select(
        spans, top_span_indices, flat_top_span_indices)

    # A kept span may only link to one of the preceding kept spans, and at
    # most ``antecedent_budget`` of them are considered. The index mapping
    # depends only on the position among kept spans, not on the batch.
    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    antecedent_budget = min(self._max_antecedents, num_spans_to_keep)
    antecedent_device = util.get_device_of(text_mask)
    (antecedent_ids,
     antecedent_offsets,
     antecedent_log_mask) = self._generate_valid_antecedents(
         num_spans_to_keep, antecedent_budget, antecedent_device)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, antecedent_ids)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, antecedent_ids).squeeze(-1)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, antecedent_embeddings, antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    scores = self._compute_coreference_scores(
        pair_embeddings,
        top_span_mention_scores,
        antecedent_mention_scores,
        antecedent_log_mask,
    )

    # Column 0 is the "no antecedent" class: after shifting by one it becomes
    # -1, and every other column lines up with an entry of ``antecedent_ids``.
    # Shape: (batch_size, num_spans_to_keep)
    best_column = scores.max(2)[1]
    predicted = best_column - 1

    output_dict = {
        "top_spans": kept_spans,
        "antecedent_indices": antecedent_ids,
        "predicted_antecedents": predicted,
    }
    if get_scores:
        output_dict["coreference_scores"] = scores
        output_dict["top_span_indices"] = top_span_indices

    result = {
        "output_dict": output_dict,
        "ant_mask": antecedent_log_mask,
        "coreference_scores": scores,
    }
    return result