def test_masked_indices_are_handled_correctly(self):
    """Spans belonging to a fully masked batch element must be embedded as zeros."""
    embedding_dim = 7
    sequence_tensor = torch.randn([2, 5, embedding_dim])
    # The "x,y" combination concatenates the start and end point embeddings
    # to form the span representation.
    extractor = EndpointSpanExtractor(embedding_dim, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])

    # Unmasked pass first; its output is not inspected by this test.
    extractor(sequence_tensor, indices)

    # Mask out every span of the second batch element.
    indices_mask = torch.LongTensor([[1, 1], [0, 0]])
    masked_representations = extractor(sequence_tensor,
                                       indices,
                                       span_indices_mask=indices_mask)

    start_embeddings, end_embeddings = masked_representations.split(embedding_dim, -1)
    start_indices, end_indices = indices.split(1, -1)

    # The second batch element is completely masked, so the expected
    # embeddings for it are all zero.
    expected_starts = batched_index_select(sequence_tensor, start_indices.squeeze()).data
    expected_starts[1].zero_()
    expected_ends = batched_index_select(sequence_tensor, end_indices.squeeze()).data
    expected_ends[1].zero_()

    comparisons = ((start_embeddings, expected_starts), (end_embeddings, expected_ends))
    for actual, expected in comparisons:
        numpy.testing.assert_array_equal(actual.data.numpy(), expected.numpy())
def test_correct_sequence_elements_are_embedded(self):
    """The "x,y" endpoint extractor must return the start and end embeddings side by side."""
    embedding_dim = 7
    sequence_tensor = torch.randn([2, 5, embedding_dim])
    # Concatenate start and end points together to form our representation.
    extractor = EndpointSpanExtractor(embedding_dim, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])

    span_representations = extractor(sequence_tensor, indices)

    assert list(span_representations.size()) == [2, 2, 14]
    assert extractor.get_output_dim() == 14
    assert extractor.get_input_dim() == 7

    # The representation is just the start and end embeddings concatenated,
    # so splitting it apart must give back the embeddings at the original indices.
    start_indices, end_indices = indices.split(1, -1)
    start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)

    expected_starts = batched_index_select(sequence_tensor, start_indices.squeeze())
    expected_ends = batched_index_select(sequence_tensor, end_indices.squeeze())

    comparisons = ((start_embeddings, expected_starts), (end_embeddings, expected_ends))
    for actual, expected in comparisons:
        numpy.testing.assert_array_equal(actual.data.numpy(), expected.data.numpy())
def test_span_scorer_works_for_completely_masked_rows(self):
    """A batch element with no unmasked spans must get a zero mask and -inf scores."""

    def scorer(tensor):
        # Really simple scorer - sum up the embedding_dim.
        return tensor.sum(-1).unsqueeze(-1)

    pruner = SpanPruner(scorer=scorer)  # type: ignore

    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    spans[0, :2] = 1
    spans[1, 2:] = 1
    spans[2, 2:] = 1

    # Second element has its first and last span masked;
    # the last batch element is fully masked.
    mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                         [0.0, 1.0, 1.0, 0.0],
                         [0.0, 0.0, 0.0, 0.0]])

    pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

    # The indices of the last row can't be checked, because it's completely masked.
    # Instead we check below that the scores for these elements are -inf.
    expected_indices = numpy.array([[0, 1], [1, 2]])
    numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(), expected_indices)
    expected_mask = numpy.array([[1, 1], [1, 1], [0, 0]])
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), expected_mask)

    # Embeddings should be the result of index_selecting the pruned_indices.
    expected_embeddings = batched_index_select(spans, pruned_indices)
    numpy.testing.assert_array_equal(expected_embeddings.data.numpy(),
                                     pruned_embeddings.data.numpy())

    # Scores should be the sum of the selected embedding elements, with
    # masked elements equal to -inf.
    expected_scores = expected_embeddings.sum(-1).unsqueeze(-1).data.numpy()
    expected_scores[2, :] = float("-inf")
    numpy.testing.assert_array_equal(expected_scores, pruned_scores.data.numpy())
def test_span_pruner_selects_top_scored_spans_and_respects_masking(self):
    """The pruner must keep the best-scoring unmasked spans, in their original order."""

    def scorer(tensor):
        # Really simple scorer - sum up the embedding_dim.
        return tensor.sum(-1).unsqueeze(-1)

    pruner = SpanPruner(scorer=scorer)

    spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0)
    spans[0, :2] = 1
    spans[1, 2:] = 1
    spans[2, 2:] = 1

    # Only the first and last span of the second batch element are masked.
    mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                         [0.0, 1.0, 1.0, 0.0],
                         [1.0, 1.0, 1.0, 1.0]])

    pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2)

    # The second element in the batch would have indices 2, 3, but
    # 3 and 0 are masked, so instead it has 1, 2.
    expected_indices = numpy.array([[0, 1], [1, 2], [2, 3]])
    numpy.testing.assert_array_equal(pruned_indices.data.numpy(), expected_indices)
    numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2]))

    # Embeddings should be the result of index_selecting the pruned_indices.
    expected_embeddings = batched_index_select(spans, pruned_indices)
    numpy.testing.assert_array_equal(expected_embeddings.data.numpy(),
                                     pruned_embeddings.data.numpy())

    # Scores should be the sum of the selected embedding elements.
    expected_scores = expected_embeddings.sum(-1).unsqueeze(-1)
    numpy.testing.assert_array_equal(expected_scores.data.numpy(),
                                     pruned_scores.data.numpy())
def test_batched_index_select(self):
    """``util.batched_index_select`` must pick, per batch element, the vectors at ``indices``."""
    # Each target element is a vector of its own sequence index; the second
    # batch element doubles its index so the two batch elements differ.
    targets = torch.ones([2, 10, 3]).cumsum(1) - 1
    targets[1].mul_(2)

    raw_indices = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    indices = torch.tensor(raw_indices, dtype=torch.long)

    selected = util.batched_index_select(targets, indices)
    assert list(selected.size()) == [2, 2, 2, 3]

    # Expected constant value of every selected vector, laid out like ``indices``.
    expected_values = [[[1, 2], [3, 4]],
                       [[10, 12], [14, 16]]]
    unit = numpy.ones([3])
    for batch, rows in enumerate(expected_values):
        for row, columns in enumerate(rows):
            for column, value in enumerate(columns):
                numpy.testing.assert_array_equal(
                        selected[batch, row, column, :].data.numpy(), unit * value)
def test_masked_indices_are_handled_correctly_with_exclusive_indices(self):
    """With exclusive start indices, spans starting at position 0 must use the start sentinel."""
    embedding_dim = 8
    sequence_tensor = torch.randn([2, 5, embedding_dim])
    # Concatenate start and end points together to form our representation
    # for both the forward and backward directions.
    extractor = EndpointSpanExtractor(embedding_dim, "x,y", use_exclusive_start_indices=True)
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [0, 1]]])
    sequence_mask = torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

    span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask)

    # The representation is just the start and end embeddings concatenated,
    # so we can check each half against the original indices.
    start_embeddings, end_embeddings = span_representations.split(embedding_dim, -1)

    # Both spans of the second batch element have an exclusive start of -1, which is
    # replaced with a sentinel. A placeholder index of 1 is used here so that the
    # index select is valid; the selected rows are overwritten below.
    placeholder_start_indices = torch.LongTensor([[0, 1], [1, 1]])
    expected_end_indices = torch.LongTensor([[3, 4], [2, 1]])

    expected_starts = batched_index_select(sequence_tensor.contiguous(),
                                           placeholder_start_indices)
    # These spans had a sequence_tensor start index of 0, so their exclusive
    # start is the start sentinel.
    for span_position in (0, 1):
        expected_starts[1, span_position] = extractor._start_sentinel.data
    numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                     expected_starts.data.numpy())

    expected_ends = batched_index_select(sequence_tensor.contiguous(), expected_end_indices)
    numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                     expected_ends.data.numpy())
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Embeds spans from the endpoints of a bidirectional sequence encoding.

    The last dimension of ``sequence_tensor`` is split in half into a forward and a
    backward direction. For each direction the (exclusive) start and (inclusive) end
    embeddings of every span are selected, combined as specified by the forward and
    backward combination strings, and the two directions are concatenated.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, sequence_length, embedding_size).
    span_indices : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_spans, 2) holding the inclusive
        start and end index of each span.
    sequence_mask : ``torch.LongTensor``, optional (default = None).
        A binary mask over ``sequence_tensor`` used to compute the per-element sequence
        lengths. When omitted, every sequence is taken to have the full sequence length.
    span_indices_mask : ``torch.LongTensor``, optional (default = None).
        A tensor of shape (batch_size, num_spans). Masked spans have their indices
        replaced by 0 and their final representation zeroed out.

    Returns
    -------
    A tensor of shape (batch_size, num_spans, forward_combination_dim +
    backward_combination_dim), with the span width embedding appended to the
    last dimension when one is configured.

    Raises
    ------
    ValueError
        If the adjusted span indices fall outside of the sequence.
    """
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(self._input_dim // 2, dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = (torch.ones_like(sequence_tensor[:, 0, 0], dtype=torch.long) *
                            sequence_tensor.size(1))

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                    + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                    + float_end_sentinel_mask * self._end_sentinel

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # Do not return here: the span mask below must also cover the width
        # embedding, otherwise masked spans would come back non-zero whenever
        # width embeddings are enabled.
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)
    return span_embeddings
def forward(self,  # pylint: disable=arguments-differ
            span_embeddings: torch.FloatTensor,
            span_mask: torch.LongTensor,
            num_spans_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                             torch.LongTensor, torch.FloatTensor]:
    """
    Extracts the top-k scoring spans with respect to the scorer. We additionally return
    the indices of the top-k in their original order, not ordered by score, so that we
    can rely on the ordering to consider the previous k spans as antecedents for each
    span later.

    Parameters
    ----------
    span_embeddings : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_spans, embedding_size), representing
        the set of embedded span representations.
    span_mask : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_spans), denoting unpadded elements
        of ``span_embeddings``. Any numeric dtype is accepted; it is converted to the
        dtype of the scores before use.
    num_spans_to_keep : ``int``, required.
        The number of spans to keep when pruning.

    Returns
    -------
    top_span_embeddings : ``torch.FloatTensor``
        The span representations of the top-k scoring spans.
        Has shape (batch_size, num_spans_to_keep, embedding_size).
    top_span_mask : ``torch.LongTensor``
        The corresponding mask for ``top_span_embeddings``, in the dtype of the
        ``span_mask`` that was passed in.
        Has shape (batch_size, num_spans_to_keep).
    top_span_indices : ``torch.LongTensor``
        The indices of the top-k scoring spans into the original ``span_embeddings``
        tensor. This is returned because it can be useful to retain pointers to
        the original spans, if each span is being scored by multiple distinct
        scorers, for instance. Has shape (batch_size, num_spans_to_keep).
    top_span_scores : ``torch.FloatTensor``
        The values of the top-k scoring spans, with masked spans scored as -inf.
        Has shape (batch_size, num_spans_to_keep, 1).

    Raises
    ------
    ValueError
        If the scorer does not produce a tensor of shape (batch_size, num_spans, 1).
    """
    span_mask = span_mask.unsqueeze(-1)
    num_spans = span_embeddings.size(1)
    # Shape: (batch_size, num_spans, 1)
    span_scores = self._scorer(span_embeddings)

    # Check the rank first so that ``size(-1)`` is never asked of a 0-dim tensor.
    if span_scores.dim() != 3 or span_scores.size(-1) != 1:
        raise ValueError(f"The scorer passed to SpanPruner must produce a tensor of shape"
                         f"(batch_size, num_spans, 1), but found shape {span_scores.size()}")

    # Make sure that we don't select any masked spans by setting their scores to
    # be -inf. The mask is converted to the dtype of the scores first so that
    # integer masks are supported, and the addition is out-of-place so that the
    # tensor returned by the scorer is never mutated.
    span_scores = span_scores + span_mask.to(dtype=span_scores.dtype).log()

    # Shape: (batch_size, num_spans_to_keep, 1)
    _, top_span_indices = span_scores.topk(num_spans_to_keep, 1)

    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``span_embeddings`` tensor).
    top_span_indices, _ = torch.sort(top_span_indices, 1)

    # Shape: (batch_size, num_spans_to_keep)
    top_span_indices = top_span_indices.squeeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Shape: (batch_size, num_spans_to_keep, embedding_size)
    top_span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
    # Shape: (batch_size, num_spans_to_keep, 1); the trailing dimension is squeezed on return.
    top_span_mask = util.batched_index_select(span_mask,
                                              top_span_indices,
                                              flat_top_span_indices)

    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_scores = util.batched_index_select(span_scores,
                                                top_span_indices,
                                                flat_top_span_indices)

    return top_span_embeddings, top_span_mask.squeeze(-1), top_span_indices, top_span_scores
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of
        the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Per-instance metadata. It is handed to the metrics when ``span_labels`` is given,
        and the ``"original_text"`` entry of each element is returned under ``"document"``.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    document : ``List``, optional
        The ``"original_text"`` of each metadata entry, present only when ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # Indexing with ``0`` already removes the last dimension. An additional
    # ``squeeze(-1)`` would also drop the span dimension whenever there is
    # exactly one candidate span, producing a mask of the wrong shape.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores. We can never keep more spans than there are
    # candidates, so the budget is clamped to ``num_spans``; otherwise the top-k
    # selection inside the pruner fails for documents with few candidate spans.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, num_spans)

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    #  index to the indices of its allowed antecedents. Note that this is independent
    #  of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    #  we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        #  the same coreference cluster.
        coreference_log_probs = util.last_dim_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Represents every span as an attention-weighted sum of the sequence elements inside it.

    Attention logits are computed once for the whole sequence and then normalised
    within each span. ``sequence_mask`` is accepted but not read by this method.
    Spans masked by ``span_indices_mask`` are zeroed in the result.
    """
    # Both of shape (batch_size, num_spans, 1)
    starts, ends = span_indices.split(1, dim=-1)

    # Shape: (batch_size, num_spans, 1)
    # The span ends are `inclusive`, so these widths are one less than the
    # number of elements in each span.
    widths = ends - starts

    # The widest span in the batch decides how many positions are gathered for
    # every span; positions beyond a span's own width are masked out below.
    max_width = widths.max().item() + 1

    # Shape: (batch_size, sequence_length, 1)
    sequence_logits = self._global_attention(sequence_tensor)

    # Shape: (1, 1, max_width)
    device = util.get_device_of(sequence_tensor)
    offsets = util.get_range_vector(max_width, device).view(1, 1, -1)

    # Shape: (batch_size, num_spans, max_width)
    # Broadcasted: position k of a span refers to the sequence index ``end - k``.
    raw_positions = ends - offsets

    # An offset is valid when it lies within the span (<= rather than <, because
    # the span ends are inclusive) and when the resulting index is not negative,
    # which happens for spans near the beginning of the sequence.
    within_span = (offsets <= widths).float()
    non_negative = (raw_positions >= 0).float()
    position_mask = within_span * non_negative
    positions = torch.nn.functional.relu(raw_positions.float()).long()

    # Shape: (batch_size * num_spans * max_width)
    flat_positions = util.flatten_and_batch_shift_indices(positions, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_width, embedding_dim)
    gathered_embeddings = util.batched_index_select(sequence_tensor, positions, flat_positions)

    # Shape: (batch_size, num_spans, max_width)
    gathered_logits = util.batched_index_select(sequence_logits,
                                                positions,
                                                flat_positions).squeeze(-1)
    # Shape: (batch_size, num_spans, max_width)
    attention_weights = util.last_dim_softmax(gathered_logits, position_mask)

    # Weighted sum of the gathered embeddings with respect to the
    # normalised attention distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended = util.weighted_sum(gathered_embeddings, attention_weights)

    if span_indices_mask is None:
        return attended
    # Above, positions were masked relative to the widest span in the batch.
    # Here the spans which were originally passed in as padding are masked.
    return attended * span_indices_mask.unsqueeze(-1).float()
def _get_initial_state_and_scores(self,
                                  question: Dict[str, torch.LongTensor],
                                  table: Dict[str, torch.LongTensor],
                                  world: List[WikiTablesWorld],
                                  actions: List[List[ProductionRuleArray]],
                                  example_lisp_string: List[str] = None,
                                  add_world_to_initial_state: bool = False,
                                  checklist_states: List[ChecklistState] = None) -> Dict:
    """
    Does initial preparation and creates an initial state for both the semantic parsers. Note
    that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
    pass it.

    Returns a dictionary with the keys ``"initial_state"``, ``"linking_scores"``,
    ``"feature_scores"`` and ``"similarity_scores"``.
    """
    table_text = table['text']
    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, embedding_dim)
    encoded_table = self._entity_encoder(embedded_table, table_mask)
    # (batch_size, num_entities, num_neighbors)
    neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

    # neighbor_indices is padded with -1 since 0 is a potential neighbor index.
    # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
    # be added for the mask since that method expects 0 for padding.
    absolute_neighbor_indices = torch.abs(neighbor_indices)
    # (batch_size, num_entities, num_neighbors, embedding_dim)
    neighbor_embeddings = util.batched_index_select(encoded_table, absolute_neighbor_indices)
    shifted_neighbor_indices = {'ignored': neighbor_indices + 1}
    neighbor_mask = util.get_text_field_mask(shifted_neighbor_indices,
                                             num_wrapping_dims=1).float()

    # Encoder initialized to easily obtain a masked average.
    neighbor_averager = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim,
                                                               averaged=True))
    # (batch_size, num_entities, embedding_dim)
    averaged_neighbors = neighbor_averager(neighbor_embeddings, neighbor_mask)

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

    entity_type_embeddings = self._type_params(entity_types.float())
    projected_neighbor_embeddings = self._neighbor_params(averaged_neighbors.float())
    # (batch_size, num_entities, embedding_dim)
    entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

    # Compute entity and question word similarity.  We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    flattened_table = embedded_table.view(batch_size,
                                          num_entities * num_entity_tokens,
                                          self._embedding_dim)
    question_entity_similarity = torch.bmm(flattened_table,
                                           torch.transpose(embedded_question, 1, 2))
    question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                 num_entities,
                                                                 num_entity_tokens,
                                                                 num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = table['linking']

    linking_scores = question_entity_similarity_max_score

    if self._use_neighbor_similarity_for_linking:
        # The linking score is computed as a linear projection of two terms. The first is the
        # maximum similarity score over the entity's words and the question token. The second
        # is the maximum similarity over the words in the entity's neighbors and the question
        # token.
        #
        # The second term, projected_question_neighbor_similarity, is useful when a column
        # needs to be selected. For example, the question token might have no similarity with
        # the column name, but is similar with the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score, torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)

        entity_term = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        neighbor_term = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        linking_scores = entity_term + neighbor_term

    feature_scores = None
    if self._linking_params is not None:
        feature_scores = self._linking_params(linking_features).squeeze(3)
        linking_scores = linking_scores + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(world,
                                                            linking_scores.transpose(1, 2),
                                                            question_mask,
                                                            entity_type_dict)

    # (batch_size, num_question_tokens, embedding_dim)
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
    encoder_input = torch.cat([link_embedding, embedded_question], 2)

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         question_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

    initial_score = embedded_question.data.new_zeros(batch_size)

    (action_embeddings, output_action_embeddings,
     action_biases, action_indices) = self._embed_actions(actions)

    _, num_entities, num_question_tokens = linking_scores.size()
    flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                 world,
                                                                                 actions)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list.  For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    batch_range = range(batch_size)
    initial_score_list = [initial_score[i] for i in batch_range]
    encoder_output_list = [encoder_outputs[i] for i in batch_range]
    question_mask_list = [question_mask[i] for i in batch_range]
    initial_rnn_state = [RnnState(final_encoder_output[i],
                                  memory_cell[i],
                                  self._first_action_embedding,
                                  self._first_attended_question,
                                  encoder_output_list,
                                  question_mask_list)
                         for i in batch_range]
    initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                             for i in batch_range]

    # The world is only attached to the state on request.
    initial_state_world = world if add_world_to_initial_state else None

    initial_state = WikiTablesDecoderState(batch_indices=list(batch_range),
                                           action_history=[[] for _ in batch_range],
                                           score=initial_score_list,
                                           rnn_state=initial_rnn_state,
                                           grammar_state=initial_grammar_state,
                                           action_embeddings=action_embeddings,
                                           output_action_embeddings=output_action_embeddings,
                                           action_biases=action_biases,
                                           action_indices=action_indices,
                                           possible_actions=actions,
                                           flattened_linking_scores=flattened_linking_scores,
                                           actions_to_entities=actions_to_entities,
                                           entity_types=entity_type_dict,
                                           world=initial_state_world,
                                           example_lisp_string=example_lisp_string,
                                           checklist_state=checklist_states,
                                           debug_info=None)

    outputs = {"initial_state": initial_state,
               "linking_scores": linking_scores,
               "feature_scores": feature_scores,
               "similarity_scores": question_entity_similarity_max_score}
    return outputs
def test_correct_sequence_elements_are_embedded(self):
    """
    Check that the bidirectional extractor picks the expected endpoint embeddings.

    The extractor output is split back into its four pieces (forward start/end,
    backward start/end) and each piece is compared with embeddings selected by
    hand from the matching half of the input tensor.
    """
    def as_numpy(tensor):
        return tensor.data.numpy()

    sequence_tensor = torch.randn([2, 5, 8])
    # Start and end points are concatenated to form the representation,
    # for both the forward and the backward direction.
    extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                                   forward_combination="x,y",
                                                   backward_combination="x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]],
                                [[0, 2], [3, 4]]])
    span_representations = extractor(sequence_tensor, indices)

    assert list(span_representations.size()) == [2, 2, 16]
    assert extractor.get_output_dim() == 16
    assert extractor.get_input_dim() == 8

    # The representation is just the start and end embeddings concatenated,
    # so splitting it apart lets us compare each piece against the indices.
    actual_fwd_start, actual_fwd_end, actual_bwd_start, actual_bwd_end = \
        span_representations.split(4, -1)
    forward_half, backward_half = sequence_tensor.split(4, -1)
    forward_half = forward_half.contiguous()
    backward_half = backward_half.contiguous()

    # Forward direction: start indices are exclusive, i.e. inclusive start - 1,
    # which gives [[0, 1], [-1, 2]]. The -1 entry cannot be used for selection,
    # so a placeholder of 1 is selected here and overwritten with the start
    # sentinel afterwards.
    expected_fwd_start = batched_index_select(forward_half,
                                              torch.LongTensor([[0, 1], [1, 2]]))
    # This span started at sequence index 0, so its exclusive start is the sentinel.
    expected_fwd_start[1, 0] = extractor._start_sentinel.data
    numpy.testing.assert_array_equal(as_numpy(actual_fwd_start),
                                     as_numpy(expected_fwd_start))

    # Forward direction: end indices are used unchanged.
    expected_fwd_end = batched_index_select(forward_half,
                                            torch.LongTensor([[3, 4], [2, 4]]))
    numpy.testing.assert_array_equal(as_numpy(actual_fwd_end),
                                     as_numpy(expected_fwd_end))

    # Backward direction: end indices are inclusive and equal the forward
    # (inclusive) start indices.
    expected_bwd_end = batched_index_select(backward_half,
                                            torch.LongTensor([[1, 2], [0, 3]]))
    numpy.testing.assert_array_equal(as_numpy(actual_bwd_end),
                                     as_numpy(expected_bwd_end))

    # Backward direction: start indices are exclusive, i.e. inclusive end + 1,
    # which gives [[4, 5], [3, 5]]. The entries equal to 5 lie outside the
    # tensor, so a placeholder of 1 is selected for them (keeping torch from
    # complaining) and they are overwritten with the end sentinel afterwards.
    expected_bwd_start = batched_index_select(backward_half,
                                              torch.LongTensor([[4, 1], [3, 1]]))
    # These spans had an exclusive index == sequence_tensor.size(1),
    # so their exclusive start is the end sentinel.
    expected_bwd_start[0, 1] = extractor._end_sentinel.data
    expected_bwd_start[1, 1] = extractor._end_sentinel.data
    numpy.testing.assert_array_equal(as_numpy(actual_bwd_start),
                                     as_numpy(expected_bwd_start))
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Build one representation per span by combining its start and end embeddings.

    Parameters
    ----------
    sequence_tensor : the embedded sequence that the span indices select from
        (presumably ``(batch_size, sequence_length, embedding_dim)`` - confirm
        against callers).
    span_indices : inclusive ``(start, end)`` index pairs in the last dimension;
        splitting and squeezing that dimension yields ``(batch_size, num_spans)``.
    sequence_mask : accepted for interface compatibility; not read by this method.
    span_indices_mask : optional ``(batch_size, num_spans)`` mask. Masked spans
        have their indices zeroed before selection and their representation
        zeroed in the result.

    Returns
    -------
    The combined endpoint embeddings for every span, with the span width
    embedding concatenated on the last dimension when one is configured.

    Raises
    ------
    ValueError
        If exclusive start indices are in use and an adjusted start index is
        still negative after the sentinel positions have been zeroed.
    """
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        # Zero out the -1 positions so they are valid indices for the select below;
        # their embeddings are swapped for the start sentinel afterwards.
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}."
            )

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
            + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # Do not return here: the concatenated result still has to go through the
        # masking step below, otherwise masked spans would come back non-zero
        # whenever width embeddings are enabled.
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        # Zero the whole representation (including any width embedding) of masked spans.
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors