def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, sub_obj_ids=None, sub_obj_masks=None, input_position=None): """ attention_mask: [batch_size, from_seq_length, to_seq_length] """ batch_size = input_ids.size(0) outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=False, output_attentions=False, position_ids=input_position) sequence_output = outputs[0] sub_ids = sub_obj_ids[:, :, 0].view(batch_size, -1) sub_embeddings = batched_index_select(sequence_output, sub_ids) obj_ids = sub_obj_ids[:, :, 1].view(batch_size, -1) obj_embeddings = batched_index_select(sequence_output, obj_ids) rep = torch.cat((sub_embeddings, obj_embeddings), dim=-1) rep = self.layer_norm(rep) rep = self.dropout(rep) logits = self.classifier(rep) if labels is not None: loss_fct = CrossEntropyLoss() active_loss = (sub_obj_masks.view(-1) == 1) active_logits = logits.view(-1, logits.shape[-1]) active_labels = torch.where( active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) ) loss = loss_fct(active_logits, active_labels) return loss else: return logits
def _get_span_embeddings(self, input_ids, spans, token_type_ids=None, attention_mask=None):
    sequence_output, pooled_output = self.bert(
        input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    sequence_output = self.hidden_dropout(sequence_output)
    """
    spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
    spans_mask: (batch_size, num_spans, )
    """
    spans_start = spans[:, :, 0].view(spans.size(0), -1)
    spans_start_embedding = batched_index_select(sequence_output, spans_start)
    spans_end = spans[:, :, 1].view(spans.size(0), -1)
    spans_end_embedding = batched_index_select(sequence_output, spans_end)
    spans_width = spans[:, :, 2].view(spans.size(0), -1)
    spans_width_embedding = self.width_embedding(spans_width)

    # Concatenate embeddings of left/right points and the width embedding.
    spans_embedding = torch.cat(
        (spans_start_embedding, spans_end_embedding, spans_width_embedding), dim=-1)
    """
    spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
    """
    return spans_embedding
def test_correct_sequence_elements_are_embedded(self):
    sequence_tensor = torch.randn([2, 5, 7])
    # Concatenate start and end points together to form our representation.
    extractor = EndpointSpanExtractor(7, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
    span_representations = extractor(sequence_tensor, indices)

    assert list(span_representations.size()) == [2, 2, 14]
    assert extractor.get_output_dim() == 14
    assert extractor.get_input_dim() == 7

    start_indices, end_indices = indices.split(1, -1)
    # We just concatenated the start and end embeddings together, so
    # we can check they match the original indices if we split them apart.
    start_embeddings, end_embeddings = span_representations.split(7, -1)

    correct_start_embeddings = batched_index_select(sequence_tensor, start_indices.squeeze())
    correct_end_embeddings = batched_index_select(sequence_tensor, end_indices.squeeze())
    numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                     correct_start_embeddings.data.numpy())
    numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                     correct_end_embeddings.data.numpy())
def test_masked_indices_are_handled_correctly(self):
    sequence_tensor = torch.randn([2, 5, 7])
    # Concatenate start and end points together to form our representation.
    extractor = EndpointSpanExtractor(7, "x,y")
    indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
    span_representations = extractor(sequence_tensor, indices)

    # Make a mask with the second batch element completely masked.
    indices_mask = torch.LongTensor([[1, 1], [0, 0]])

    span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)
    start_embeddings, end_embeddings = span_representations.split(7, -1)
    start_indices, end_indices = indices.split(1, -1)

    correct_start_embeddings = batched_index_select(sequence_tensor, start_indices.squeeze()).data
    # Completely masked second batch element, so it should all be zero.
    correct_start_embeddings[1, :, :].fill_(0)
    correct_end_embeddings = batched_index_select(sequence_tensor, end_indices.squeeze()).data
    correct_end_embeddings[1, :, :].fill_(0)
    numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                     correct_start_embeddings.numpy())
    numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                     correct_end_embeddings.numpy())
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                           + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
def _compute_span_representations( self, text_embeddings: torch.FloatTensor, text_mask: torch.FloatTensor, span_starts: torch.IntTensor, span_ends: torch.IntTensor) -> torch.FloatTensor: """ Computes an embedded representation of every candidate span. This is a concatenation of the contextualized endpoints of the span, an embedded representation of the width of the span and a representation of the span's predicted head. Parameters ---------- text_embeddings : ``torch.FloatTensor``, required. The embedded document of shape (batch_size, document_length, embedding_dim) over which we are computing a weighted sum. text_mask : ``torch.FloatTensor``, required. A mask of shape (batch_size, document_length) representing non-padding entries of ``text_embeddings``. span_starts : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans) representing the start of each span candidate. span_ends : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans) representing the end of each span candidate. Returns ------- span_embeddings : ``torch.FloatTensor`` An embedded representation of every candidate span with shape: (batch_size, num_spans, context_layer.get_output_dim() * 2 + embedding_size + feature_size) """ # Shape: (batch_size, document_length, encoding_dim) contextualized_embeddings = self._context_layer( text_embeddings, text_mask) # Shape: (batch_size, num_spans, encoding_dim) start_embeddings = util.batched_index_select(contextualized_embeddings, span_starts.squeeze(-1)) end_embeddings = util.batched_index_select(contextualized_embeddings, span_ends.squeeze(-1)) # Compute and embed the span_widths (strictly speaking the span_widths - 1) # Shape: (batch_size, num_spans, 1) span_widths = span_ends - span_starts # Shape: (batch_size, num_spans, encoding_dim) span_width_embeddings = self._span_width_embedding( span_widths.squeeze(-1)) # Shape: (batch_size, document_length, 1) head_scores = self._head_scorer(contextualized_embeddings) # Shape: (batch_size, num_spans, embedding_dim) # Note that we used the original text embeddings, not the contextual ones here. attended_text_embeddings = self._create_attended_span_representations( head_scores, text_embeddings, span_ends, span_widths) # (batch_size, num_spans, context_layer.get_output_dim() * 2 + embedding_dim + feature_dim) span_embeddings = torch.cat([ start_embeddings, end_embeddings, span_width_embeddings, attended_text_embeddings ], -1) return span_embeddings
def forward(
    self,
    input_ids,
    spans,
    spans_mask,
    token_type_ids=None,
    attention_mask=None,
    spans_ner_label=None,
):
    sequence_output = self.albert(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
    )["last_hidden_state"]
    sequence_output = self.hidden_dropout(sequence_output)

    """
    spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
    spans_mask: (batch_size, num_spans, )
    """
    spans_start = spans[:, :, 0]
    spans_start_embedding = batched_index_select(sequence_output, spans_start)
    spans_end = spans[:, :, 1]
    spans_end_embedding = batched_index_select(sequence_output, spans_end)
    spans_width = spans[:, :, 2]
    spans_width_embedding = self.width_embedding(spans_width)
    spans_embedding = torch.cat(
        (
            spans_start_embedding,
            spans_end_embedding,
            spans_width_embedding,
        ),
        dim=-1,
    )
    """
    spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
    """
    logits = self.ner_classifier(spans_embedding)[:, :, 0]
    logits = torch.sigmoid(logits)

    if spans_ner_label is not None:
        gold = spans_ner_label.type(torch.float)
        pred = logits
        # 0 is 98.56% of the time, 1 is 1.44% of the time.
        loss_fct = BCELoss(weight=(gold > 0.5) * 0.97 + 0.03, reduction="sum")
        loss = loss_fct(pred, gold)
        # print(gold[0], pred[0])
        return loss, f1_loss(gold, pred), logits, spans_embedding
    else:
        return logits, spans_embedding
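# A quick note on the BCELoss weighting above (worked arithmetic, not from the source):
# (gold > 0.5) * 0.97 + 0.03 gives positive spans weight 1.0 and negative spans weight 0.03,
# roughly counteracting the ~1.44% positive rate mentioned in the comment:
#   positive: 1 * 0.97 + 0.03 = 1.00
#   negative: 0 * 0.97 + 0.03 = 0.03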
def forward( self, span: torch.Tensor, # SHAPE: (batch_size, num_spans, span_dim) span_pairs: torch.LongTensor # SHAPE: (batch_size, num_span_pairs) ): span1 = span2 = span if self.dim_reduce_layer1 is not None: span1 = self.dim_reduce_layer1(span) if self.dim_reduce_layer2 is not None: span2 = self.dim_reduce_layer2(span) if not self.pair: return span1, span2 num_spans = span.size(1) # get span pair embedding span_pairs_p = span_pairs[:, :, 0] span_pairs_c = span_pairs[:, :, 1] # SHAPE: (batch_size * num_span_pairs) flat_span_pairs_p = util.flatten_and_batch_shift_indices( span_pairs_p, num_spans) flat_span_pairs_c = util.flatten_and_batch_shift_indices( span_pairs_c, num_spans) # SHAPE: (batch_size, num_span_pairs, span_dim) span_pair_p_emb = util.batched_index_select(span1, span_pairs_p, flat_span_pairs_p) span_pair_c_emb = util.batched_index_select(span2, span_pairs_c, flat_span_pairs_c) if self.combine == 'concat': # SHAPE: (batch_size, num_span_pairs, span_dim * 2) span_pair_emb = torch.cat([span_pair_p_emb, span_pair_c_emb], -1) elif self.combine == 'coref': # use the indices gap as distance, which requires the indices to be consistent # with the order they appear in the sentences distance = span_pairs_p - span_pairs_c # SHAPE: (batch_size, num_span_pairs, dist_emb_dim) distance_embeddings = self.distance_embedding( util.bucket_values( distance, num_total_buckets=self.num_distance_buckets)) # SHAPE: (batch_size, num_span_pairs, span_dim * 3) span_pair_emb = torch.cat([ span_pair_p_emb, span_pair_c_emb, span_pair_p_emb * span_pair_c_emb, distance_embeddings ], -1) if self.repr_layer is not None: # SHAPE: (batch_size, num_span_pairs, out_dim) span_pair_emb = self.repr_layer(span_pair_emb) return span_pair_emb
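# The two util calls above pre-compute flattened indices and reuse them for both selects.
# Below is a minimal sketch of the assumed semantics of `flatten_and_batch_shift_indices`
# (an illustration, not the AllenNLP source): each batch element's indices are offset by
# batch_index * sequence_length so that they address a tensor whose batch and sequence
# axes have been flattened together.
import torch

def flatten_and_batch_shift_indices_sketch(indices: torch.LongTensor,
                                            sequence_length: int) -> torch.LongTensor:
    batch_size = indices.size(0)
    offsets = torch.arange(batch_size).view(batch_size, *([1] * (indices.dim() - 1)))
    return (indices + offsets * sequence_length).view(-1)

# e.g. indices [[1, 2], [0, 3]] with sequence_length 5 -> tensor([1, 2, 5, 8])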
def test_scorer_works_for_completely_masked_rows(self): # Really simple scorer - sum up the embedding_dim. scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1) pruner = Pruner(scorer=scorer) # type: ignore items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0) items[0, :2, :] = 1 items[1, 2:, :] = 1 items[2, 2:, :] = 1 mask = torch.ones([3, 4]) mask[1, 0] = 0 mask[1, 3] = 0 mask[2, :] = 0 # fully masked last batch element. pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(items, mask, 2) # We can't check the last row here, because it's completely masked. # Instead we'll check that the scores for these elements are very small. numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(), numpy.array([[0, 1], [1, 2]])) numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1], [1, 1], [0, 0]])) # embeddings should be the result of index_selecting the pruned_indices. correct_embeddings = batched_index_select(items, pruned_indices) numpy.testing.assert_array_equal(correct_embeddings.data.numpy(), pruned_embeddings.data.numpy()) # scores should be the sum of the correct embedding elements, with masked elements very # small (but not -inf, because that can cause problems). We'll test these two cases # separately. correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy() numpy.testing.assert_array_equal(correct_scores[:2], pruned_scores[:2].data.numpy()) numpy.testing.assert_array_equal(pruned_scores[2] < -1e15, [[1], [1]]) numpy.testing.assert_array_equal(pruned_scores[2] == float('-inf'), [[0], [0]])
def coref_propagation_doc(self, output_dict): coreference_scores = output_dict["coreference_scores"] top_span_embeddings = output_dict["top_span_embeddings"] antecedent_indices = output_dict["antecedent_indices"] for t in range(self.coref_prop): assert coreference_scores.shape[1] == antecedent_indices.shape[0] assert coreference_scores.shape[2] - 1 == antecedent_indices.shape[1] assert top_span_embeddings.shape[1] == coreference_scores.shape[1] assert antecedent_indices.max() <= top_span_embeddings.shape[1] antecedent_distribution = self.antecedent_softmax(coreference_scores)[:, :, 1:] top_span_emb_repeated = top_span_embeddings.repeat(antecedent_distribution.shape[2],1,1) if antecedent_indices.shape[0]==antecedent_indices.shape[1]: selected_top_span_embs = util.batched_index_select(top_span_emb_repeated, antecedent_indices).unsqueeze(0) entity_embs = (selected_top_span_embs.permute([3,0,1,2]) * antecedent_distribution).permute([1, 2, 3, 0]).sum(dim=2) else: ant_var1 = antecedent_indices.unsqueeze(0).unsqueeze(-1).repeat(1,1,1,top_span_embeddings.shape[-1]) top_var1 = top_span_embeddings.unsqueeze(1).repeat(1,antecedent_distribution.shape[1],1,1) entity_embs = (torch.gather(top_var1, 2, ant_var1).permute([3,0,1,2]) * antecedent_distribution).permute([1, 2, 3, 0]).sum(dim=2) #entity_embs = F.dropout(entity_embs) f_network_input = torch.cat([top_span_embeddings, entity_embs], dim=-1) f_weights = self._f_network(f_network_input) top_span_embeddings = f_weights * top_span_embeddings + (1.0 - f_weights) * entity_embs #f_weights2 = self._f_network2(f_network_input) #top_span_embeddings = f_weights2 * top_span_embeddings + (1.0 - f_weights2) * entity_embs coreference_scores = self.get_coref_scores(top_span_embeddings, self._mention_pruner._scorer(top_span_embeddings), output_dict["antecedent_indices"], output_dict["valid_antecedent_offsets"], output_dict["valid_antecedent_log_mask"]) output_dict["coreference_scores"] = coreference_scores output_dict["top_span_embeddings"] = top_span_embeddings return output_dict
def forward(self, text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, **kwargs): # slot_name -> Shape: batch_size, 1 gold_slot_labels = self._get_gold_slot_labels(kwargs) if gold_slot_labels is None: raise ConfigurationError( "QfirstQuestionGenerator requires gold labels for teacher forcing when running forward. " "You may wish to run beam_decode instead.") # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim() encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) # Shape: batch_size, self._sentence_encoder.get_output_dim() pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) # slot_name -> Shape: batch_size, slot_name_vocab_size slot_logits = self._question_generator(pred_rep, **gold_slot_labels) batch_size, _ = pred_rep.size() # Shape: <scalar> slot_nlls, neg_log_likelihood = self._get_cross_entropy( slot_logits, gold_slot_labels) self.metric(slot_logits, gold_slot_labels, torch.ones([batch_size]), slot_nlls, neg_log_likelihood) return {**slot_logits, "loss": neg_log_likelihood}
def forward(self, text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, clause_dist: torch.FloatTensor = None, **kwargs): # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim() encoded_text, text_mask = self._sentence_encoder(text, predicate_indicator) # Shape: batch_size, encoder_output_dim pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) # Shape: batch_size, get_vocab_size(self._label_namespace) frame_logits = self._frame_pred(pred_rep) frame_probs = F.softmax(frame_logits, dim = 1) frames = F.softmax(self._frames_matrix, dim = 1) clause_probs = torch.matmul(frame_probs, frames) clause_log_probs = clause_probs.log() output_dict = { "probs": frame_probs } # TODO figure out how to do this with logits # TODO figure out how to handle the null case if clause_dist is not None and clause_dist.sum().item() > 0.1: gold_clause_probs = F.normalize(clause_dist.float(), p = 1, dim = 1) cross_entropy = torch.sum(-gold_clause_probs * clause_log_probs, 1) output_dict["loss"] = torch.mean(cross_entropy) gold_entropy = -gold_clause_probs * gold_clause_probs.log() # entropy summands gold_entropy[gold_entropy != gold_entropy] = 0.0 # zero out nans gold_entropy = torch.sum(gold_entropy, dim = 1) # compute sum for per-batch-item entropy kl_divergence = cross_entropy - gold_entropy self._metric(clause_probs, clause_dist > 0.0) self._kl_divergence_metric(kl_divergence) return output_dict
def forward( self, # type: ignore text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, clause: torch.LongTensor, answer_slot: torch.LongTensor, answer_spans: torch.LongTensor = None, span_counts: torch.LongTensor = None, num_answers: torch.LongTensor = None, metadata=None, **kwargs): embedded_clause = self._clause_embedding(clause) embedded_slot = self._slot_embedding(answer_slot) encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) combined_embedding = torch.cat( [embedded_clause, embedded_slot, pred_rep], -1) question_embedding = self._question_projection(combined_embedding) return self._span_selector(encoded_text, text_mask, extra_input_embedding=question_embedding, answer_spans=answer_spans, span_counts=span_counts, num_answers=num_answers, metadata=metadata)
def forward( self, sequence_tensor: torch.FloatTensor, span_indices: torch.LongTensor, span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor: # both of shape (batch_size, num_spans, 1) span_starts, span_ends = span_indices.split(1, dim=-1) # shape (batch_size, num_spans, 1) # These span widths are off by 1, because the span ends are `inclusive`. span_widths = span_ends - span_starts # We need to know the maximum span width so we can # generate indices to extract the spans from the sequence tensor. # These indices will then get masked below, such that if the length # of a given span is smaller than the max, the rest of the values # are masked. max_batch_span_width = span_widths.max().item() + 1 # Shape: (1, 1, max_batch_span_width) max_span_range_indices = util.get_range_vector( max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1) # Shape: (batch_size, num_spans, max_batch_span_width) # This is a broadcasted comparison - for each span we are considering, # we are creating a range vector of size max_span_width, but masking values # which are greater than the actual length of the span. # # We're using <= here (and for the mask below) because the span ends are # inclusive, so we want to include indices which are equal to span_widths rather # than using it as a non-inclusive upper bound. span_mask = (max_span_range_indices <= span_widths).float() raw_span_indices = span_ends - max_span_range_indices # We also don't want to include span indices which are less than zero, # which happens because some spans near the beginning of the sequence # have an end index < max_batch_span_width, so we add this to the mask here. span_mask = span_mask * (raw_span_indices >= 0).float() span_indices = torch.nn.functional.relu( raw_span_indices.float()).long() # Shape: (batch_size * num_spans * max_batch_span_width) flat_span_indices = util.flatten_and_batch_shift_indices( span_indices, sequence_tensor.size(1)) # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim) span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices) # text_embeddings = span_embeddings * span_mask.unsqueeze(-1) batch_size, num_spans, max_batch_span_width, _ = span_embeddings.size() view_text_embeddings = span_embeddings.view(batch_size * num_spans, max_batch_span_width, -1) span_mask = span_mask.view(batch_size * num_spans, max_batch_span_width) cnn_text_embeddings = self.cnn(view_text_embeddings, span_mask) cnn_text_embeddings = cnn_text_embeddings.view(batch_size, num_spans, self._output_dim) return cnn_text_embeddings
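# The broadcasted range trick above is easiest to see with concrete numbers; this is a
# tiny worked example with made-up values, not part of the original module.
import torch

span_end = torch.tensor([[4]])        # (batch_size=1, num_spans=1)
span_width = torch.tensor([[2]])      # inclusive span [2, 4] has width 4 - 2 = 2
max_batch_span_width = 4              # widest span in this hypothetical batch has width 3, plus 1

# Shape (1, 1, max_batch_span_width)
max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)   # [[[0, 1, 2, 3]]]
span_mask = (max_span_range_indices <= span_width.unsqueeze(-1)).float()     # [[[1., 1., 1., 0.]]]
raw_span_indices = span_end.unsqueeze(-1) - max_span_range_indices           # [[[4, 3, 2, 1]]]
# Token positions 4, 3, 2 are kept; the last slot falls outside the span and is masked out.
# Negative raw indices (spans near the start of the sequence) are additionally zeroed by the
# relu and masked, exactly as in the forward pass above.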
def predict_labels_doc(self, output_dict): # Shape: (batch_size, num_spans_to_keep) coref_labels = output_dict["coref_labels"] coreference_scores = output_dict["coreference_scores"] _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 output_dict["predicted_antecedents"] = predicted_antecedents top_span_indices = output_dict["top_span_indices"] flat_top_span_indices = output_dict["flat_top_span_indices"] valid_antecedent_indices = output_dict["antecedent_indices"] valid_antecedent_log_mask = output_dict["valid_antecedent_log_mask"] top_spans = output_dict["top_spans"] top_span_mask = output_dict["top_span_mask"] metadata = output_dict["metadata"] sentence_lengths = output_dict["sentence_lengths"] if coref_labels is not None: # Find the gold labels for the spans which we kept. pruned_gold_labels = util.batched_index_select( coref_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select( pruned_gold_labels, valid_antecedent_indices).squeeze(-1) # There's an integer wrap-around happening here. It occurs in the original code. antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels) # Now, compute the loss using the negative marginal log-likelihood. coreference_log_probs = util.masked_log_softmax( coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log( ) negative_marginal_log_likelihood = -util.logsumexp( correct_antecedent_log_probs).sum() # Need to get cluster data in same form as for original AllenNLP coref code so that the # evaluation code works. evaluation_metadata = self._make_evaluation_metadata( metadata, sentence_lengths) self._mention_recall(top_spans, evaluation_metadata) # TODO(dwadden) Shouldnt need to do the unsqueeze here; figure out what's happening. self._conll_coref_scores(top_spans, valid_antecedent_indices.unsqueeze(0), predicted_antecedents, evaluation_metadata) output_dict["loss"] = negative_marginal_log_likelihood return output_dict
def test_pruner_selects_top_scored_items_and_respects_masking(self): # Really simple scorer - sum up the embedding_dim. scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1) pruner = Pruner(scorer=scorer) items = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0) items[0, :2, :] = 1 items[1, 2:, :] = 1 items[2, 2:, :] = 1 mask = torch.ones([3, 4]) mask[1, 0] = 0 mask[1, 3] = 0 pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner( items, mask, 2) # Second element in the batch would have indices 2, 3, but # 3 and 0 are masked, so instead it has 1, 2. numpy.testing.assert_array_equal(pruned_indices.data.numpy(), numpy.array([[0, 1], [1, 2], [2, 3]])) numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2])) # embeddings should be the result of index_selecting the pruned_indices. correct_embeddings = batched_index_select(items, pruned_indices) numpy.testing.assert_array_equal(correct_embeddings.data.numpy(), pruned_embeddings.data.numpy()) # scores should be the sum of the correct embedding elements. numpy.testing.assert_array_equal( correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(), pruned_scores.data.numpy())
def test_batched_index_select(self):
    indices = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    # Each element is a vector of its index.
    targets = torch.ones([2, 10, 3]).cumsum(1) - 1
    # Make the second batch double its index so they're different.
    targets[1, :, :] *= 2
    indices = Variable(torch.LongTensor(indices))
    targets = Variable(targets)
    selected = util.batched_index_select(targets, indices)

    assert list(selected.size()) == [2, 2, 2, 3]
    ones = numpy.ones([3])
    numpy.testing.assert_array_equal(selected[0, 0, 0, :].data.numpy(), ones)
    numpy.testing.assert_array_equal(selected[0, 0, 1, :].data.numpy(), ones * 2)
    numpy.testing.assert_array_equal(selected[0, 1, 0, :].data.numpy(), ones * 3)
    numpy.testing.assert_array_equal(selected[0, 1, 1, :].data.numpy(), ones * 4)

    numpy.testing.assert_array_equal(selected[1, 0, 0, :].data.numpy(), ones * 10)
    numpy.testing.assert_array_equal(selected[1, 0, 1, :].data.numpy(), ones * 12)
    numpy.testing.assert_array_equal(selected[1, 1, 0, :].data.numpy(), ones * 14)
    numpy.testing.assert_array_equal(selected[1, 1, 1, :].data.numpy(), ones * 16)
def test_span_scorer_works_for_completely_masked_rows(self): # Really simple scorer - sum up the embedding_dim. scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1) pruner = SpanPruner(scorer=scorer) # type: ignore spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0) spans[0, :2, :] = 1 spans[1, 2:, :] = 1 spans[2, 2:, :] = 1 mask = torch.ones([3, 4]) mask[1, 0] = 0 mask[1, 3] = 0 mask[2, :] = 0 # fully masked last batch element. pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2) # We can't check the last row here, because it's completely masked. # Instead we'll check that the scores for these elements are -inf. numpy.testing.assert_array_equal(pruned_indices[:2].data.numpy(), numpy.array([[0, 1], [1, 2]])) numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.array([[1, 1], [1, 1], [0, 0]])) # embeddings should be the result of index_selecting the pruned_indices. correct_embeddings = batched_index_select(spans, pruned_indices) numpy.testing.assert_array_equal(correct_embeddings.data.numpy(), pruned_embeddings.data.numpy()) # scores should be the sum of the correct embedding elements, with # masked elements equal to -inf. correct_scores = correct_embeddings.sum(-1).unsqueeze(-1).data.numpy() correct_scores[2, :] = float("-inf") numpy.testing.assert_array_equal(correct_scores, pruned_scores.data.numpy())
def test_span_pruner_selects_top_scored_spans_and_respects_masking(self): # Really simple scorer - sum up the embedding_dim. scorer = lambda tensor: tensor.sum(-1).unsqueeze(-1) pruner = SpanPruner(scorer=scorer) spans = torch.randn([3, 4, 5]).clamp(min=0.0, max=1.0) spans[0, :2, :] = 1 spans[1, 2:, :] = 1 spans[2, 2:, :] = 1 mask = torch.ones([3, 4]) mask[1, 0] = 0 mask[1, 3] = 0 pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(spans, mask, 2) # Second element in the batch would have indices 2, 3, but # 3 and 0 are masked, so instead it has 1, 2. numpy.testing.assert_array_equal(pruned_indices.data.numpy(), numpy.array([[0, 1], [1, 2], [2, 3]])) numpy.testing.assert_array_equal(pruned_mask.data.numpy(), numpy.ones([3, 2])) # embeddings should be the result of index_selecting the pruned_indices. correct_embeddings = batched_index_select(spans, pruned_indices) numpy.testing.assert_array_equal(correct_embeddings.data.numpy(), pruned_embeddings.data.numpy()) # scores should be the sum of the correct embedding elements. numpy.testing.assert_array_equal(correct_embeddings.sum(-1).unsqueeze(-1).data.numpy(), pruned_scores.data.numpy())
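# The pruner tests above all exercise the same behaviour: score each item, ignore masked
# items, and keep the top-k per batch element. This is a minimal sketch of that behaviour
# (an assumption for illustration, not the AllenNLP Pruner/SpanPruner source; the real
# modules differ in details such as how masked scores are represented).
import torch

def prune_sketch(embeddings, mask, scorer, num_items_to_keep):
    scores = scorer(embeddings)                              # (batch, num_items, 1)
    scores = scores + (1 - mask.unsqueeze(-1)) * -1e20       # demote masked items
    _, top_indices = scores.squeeze(-1).topk(num_items_to_keep, dim=1)
    top_indices, _ = top_indices.sort(dim=1)                 # keep items in their original order
    batch_index = torch.arange(embeddings.size(0)).unsqueeze(-1)
    return (embeddings[batch_index, top_indices],            # pruned embeddings
            mask[batch_index, top_indices],                  # pruned mask
            top_indices,                                     # pruned indices
            scores[batch_index, top_indices])                # pruned scores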
def inference_coref(self, batch, embedded_text_input_relation, mask): submodel = self.model._tagger_coref ### Fast inference of coreference ### spans = batch["spans"] document_length = mask.size(1) num_spans = spans.size(1) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float() spans = F.relu(spans.float()).long() encoded_text_coref = submodel._context_layer( embedded_text_input_relation, mask) endpoint_span_embeddings = submodel._endpoint_span_extractor( encoded_text_coref, spans) attended_span_embeddings = submodel._attentive_span_extractor( embedded_text_input_relation, spans) span_embeddings = torch.cat( [endpoint_span_embeddings, attended_span_embeddings], -1) num_spans_to_keep = int( math.floor(submodel._spans_per_word * document_length)) (top_span_embeddings, top_span_mask, top_span_indices, top_span_mention_scores) = submodel._mention_pruner( span_embeddings, span_mask, num_spans_to_keep) top_span_mask = top_span_mask.unsqueeze(-1) flat_top_span_indices = util.flatten_and_batch_shift_indices( top_span_indices, num_spans) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) max_antecedents = min(submodel._max_antecedents, num_spans_to_keep) valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \ submodel._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(mask)) candidate_antecedent_embeddings = util.flattened_index_select( top_span_embeddings, valid_antecedent_indices) candidate_antecedent_mention_scores = util.flattened_index_select( top_span_mention_scores, valid_antecedent_indices).squeeze(-1) span_pair_embeddings = submodel._compute_span_pair_embeddings( top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) coreference_scores = submodel._compute_coreference_scores( span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) _, predicted_antecedents = coreference_scores.max(2) predicted_antecedents -= 1 output_dict = { "top_spans": top_spans, "antecedent_indices": valid_antecedent_indices, "predicted_antecedents": predicted_antecedents } return output_dict
def score_spans_if_labels( self, output_dict, span_labels, metadata, top_span_indices, flat_top_span_indices, top_span_mask, top_spans, valid_antecedent_indices, valid_antecedent_log_mask, coreference_scores, predicted_antecedents, ): if span_labels is not None: # Find the gold labels for the spans which we kept. pruned_gold_labels = util.batched_index_select( span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select( pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels) # Now, compute the loss using the negative marginal log-likelihood. # This is equal to the log of the sum of the probabilities of all antecedent predictions # that would be consistent with the data, in the sense that we are minimising, for a # given span, the negative marginal log likelihood of all antecedents which are in the # same gold cluster as the span we are currently considering. Each span i predicts a # single antecedent j, but there might be several prior mentions k in the same # coreference cluster that would be valid antecedents. Our loss is the sum of the # probability assigned to all valid antecedents. This is a valid objective for # clustering as we don't mind which antecedent is predicted, so long as they are in # the same coreference cluster. coreference_log_probs = util.masked_log_softmax( coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log( ) negative_marginal_log_likelihood = -util.logsumexp( correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
def test_masked_indices_are_handled_correctly_with_exclusive_indices(self):
    sequence_tensor = Variable(torch.randn([2, 5, 8]))
    # Concatenate start and end points together to form our representation
    # for both the forward and backward directions.
    extractor = EndpointSpanExtractor(8, "x,y", use_exclusive_start_indices=True)
    indices = Variable(torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [0, 1]]]))
    sequence_mask = Variable(torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]))

    span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask)

    # We just concatenated the start and end embeddings together, so
    # we can check they match the original indices if we split them apart.
    start_embeddings, end_embeddings = span_representations.split(8, -1)

    correct_start_indices = Variable(torch.LongTensor([[0, 1], [-1, -1]]))
    # These indices should be -1, so they'll be replaced with a sentinel. Here,
    # we'll set them to a value other than -1 so we can index select the indices and
    # replace them later.
    correct_start_indices[1, 0] = 1
    correct_start_indices[1, 1] = 1

    correct_end_indices = Variable(torch.LongTensor([[3, 4], [2, 1]]))

    correct_start_embeddings = batched_index_select(sequence_tensor.contiguous(), correct_start_indices)
    # This element had a sequence_tensor index of 0, so its exclusive index is the start sentinel.
    correct_start_embeddings[1, 0] = extractor._start_sentinel.data
    correct_start_embeddings[1, 1] = extractor._start_sentinel.data
    numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                     correct_start_embeddings.data.numpy())

    correct_end_embeddings = batched_index_select(sequence_tensor.contiguous(), correct_end_indices)
    numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                     correct_end_embeddings.data.numpy())
def forward(self,  # pylint: disable=arguments-differ
            sequence_tensor: torch.FloatTensor,
            indices: torch.LongTensor) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in indices.split(1, dim=-1)]
    start_embeddings = batched_index_select(sequence_tensor, span_starts)
    end_embeddings = batched_index_select(sequence_tensor, span_ends)

    combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = bucket_values(span_ends - span_starts,
                                        num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    return combined_tensors
def forward(self, text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, qarg_labeled_clauses, qarg_labeled_spans, qarg_labels=None, **kwargs): # Shape: batch_size, num_tokens, encoder_output_dim encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) # Shape: batch_size, encoder_output_dim pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) batch_size, num_labeled_instances, _ = qarg_labeled_spans.size() # Shape: batch_size, num_labeled_instances qarg_labeled_mask = (qarg_labeled_spans[:, :, 0] >= 0).squeeze(-1).long() if len(qarg_labeled_mask.size()) == 1: qarg_labeled_mask = qarg_labeled_mask.unsqueeze(-1) # max to prevent the padded labels from messing up the embedding module # Shape: batch_size, num_labeled_instances, self._clause_embedding_dim input_clauses = self._clause_embedding( qarg_labeled_clauses.max(torch.zeros_like(qarg_labeled_clauses))) # Shape: batch_size, num_spans, 2 * encoder_output_projected_dim span_embeddings = self._span_extractor(encoded_text, qarg_labeled_spans, text_mask, qarg_labeled_mask) # Shape: batch_size, num_spans, self._span_hidden_dim input_span_hidden = self._span_hidden(span_embeddings) # Shape: batch_size, 1, self._final_input_dim pred_embedding = self._predicate_hidden(pred_rep) # Shape: batch_size, num_labeled_instances, self._final_input_dim qarg_inputs = F.relu( pred_embedding.unsqueeze(1) + input_clauses + input_span_hidden) # Shape: batch_size, num_labeled_instances, get_vocab_size("qarg-labels") qarg_logits = self._qarg_predictor(self._qarg_ffnn(qarg_inputs)) final_mask = qarg_labeled_mask.unsqueeze(-1) \ .expand(batch_size, num_labeled_instances, self.vocab.get_vocab_size("qarg-labels")) \ .float() qarg_probs = torch.sigmoid(qarg_logits).squeeze(-1) * final_mask output_dict = {"logits": qarg_logits, "probs": qarg_probs} if qarg_labels is not None: output_dict["loss"] = F.binary_cross_entropy_with_logits( qarg_logits, qarg_labels, weight=final_mask, reduction="sum") self._metric(qarg_probs, qarg_labels) return output_dict
def forward(self, text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, **kwargs): # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim() encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) # Shape: batch_size, encoder_output_dim pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) # Shape: batch_size, get_vocab_size(self._label_namespace) logits = self._final_pred(pred_rep) return self._classifier(logits, None, kwargs.get(self._label_name), kwargs.get(self._label_name + "_counts"))
def forward(self, sequence_tensor: torch.FloatTensor, span_indices: torch.LongTensor, sequence_mask: torch.LongTensor = None, span_indices_mask: torch.LongTensor = None) -> None: # shape (batch_size, num_spans) span_starts, span_ends = [ index.squeeze(-1) for index in span_indices.split(1, dim=-1) ] if span_indices_mask is not None: # It's not strictly necessary to multiply the span indices by the mask here, # but it's possible that the span representation was padded with something other # than 0 (such as -1, which would be an invalid index), so we do so anyway to # be safe. span_starts = span_starts * span_indices_mask span_ends = span_ends * span_indices_mask start_embeddings = batched_index_select(sequence_tensor, span_starts) end_embeddings = batched_index_select(sequence_tensor, span_ends) combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings]) if self._span_width_embedding is not None: # Embed the span widths and concatenate to the rest of the representations. if self._bucket_widths: span_widths = bucket_values( span_ends - span_starts, num_total_buckets=self._num_width_embeddings) else: span_widths = span_ends - span_starts span_width_embeddings = self._span_width_embedding(span_widths) return torch.cat([combined_tensors, span_width_embeddings], -1) if span_indices_mask is not None: return combined_tensors * span_indices_mask.unsqueeze(-1).float() return combined_tensors
def beam_decode(self, text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, max_beam_size: int, min_beam_probability: float, clause_mode: bool = False): # Shape: batch_size, num_tokens, self._sentence_encoder.get_output_dim() encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) # Shape: batch_size, self._sentence_encoder.get_output_dim() pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) return self._question_generator.beam_decode(pred_rep, max_beam_size, min_beam_probability, clause_mode)
def nd_batched_index_select(target: torch.Tensor, indices: torch.IntTensor) -> torch.Tensor:
    """
    Multidimensional version of `util.batched_index_select`.
    """
    batch_axes = target.size()[:-2]
    num_batch_axes = len(batch_axes)
    target_shape = target.size()
    indices_shape = indices.size()
    # Collapse all leading batch axes into a single batch dimension, select, then restore them.
    target_reshaped = target.view(-1, *target_shape[num_batch_axes:])
    indices_reshaped = indices.view(-1, *indices_shape[num_batch_axes:])
    output_reshaped = util.batched_index_select(target_reshaped, indices_reshaped)
    return output_reshaped.view(*indices_shape, -1)
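# A usage sketch for the helper above with made-up shapes (assumes the surrounding module
# already imports torch and AllenNLP's `util`): the two leading batch axes are flattened,
# selection runs over the penultimate axis, and the batch axes are restored at the end.
target = torch.randn(2, 3, 7, 5)             # (*batch_axes, sequence_length, dim)
indices = torch.randint(0, 7, (2, 3, 4))     # (*batch_axes, num_indices)
selected = nd_batched_index_select(target, indices)
assert selected.shape == (2, 3, 4, 5)        # (*batch_axes, num_indices, dim)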
def forward(self, text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, tan_spans, tan_labels=None, **kwargs): # Shape: batch_size, num_tokens, encoder_output_dim encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) # Shape: batch_size, num_labeled_instances span_mask = (tan_spans[:, :, 0] >= 0).squeeze(-1).float() if len(span_mask.size()) == 1: span_mask = span_mask.unsqueeze(-1) # Shape: batch_size, num_spans, 2 * encoder_output_projected_dim span_embeddings = self._span_extractor(encoded_text, tan_spans, text_mask, span_mask.long()) batch_size, num_spans, _ = span_embeddings.size() if self._inject_predicate: expanded_pred_embedding = batched_index_select(encoded_text, predicate_index) \ .expand(batch_size, num_spans, self._sentence_encoder.get_output_dim()) input_embeddings = torch.cat( [expanded_pred_embedding, span_embeddings], -1) else: input_embeddings = span_embeddings tan_mask = span_mask.unsqueeze(-1) # broadcast ops to all tan IDs # Shape: batch_size, num_labeled_instances, self.vocab.get_vocab_size("tan-string-labels") tan_logits = self._tan_pred(input_embeddings) tan_probs = torch.sigmoid(tan_logits) * tan_mask output_dict = {"logits": tan_logits, "probs": tan_probs} if tan_labels is not None: output_dict["loss"] = F.binary_cross_entropy_with_logits( tan_logits, tan_labels.float(), weight=tan_mask, reduction="sum") self._metric(tan_probs, tan_labels, tan_mask.long()) return output_dict
def _prune_spans(self, spans, span_mask, span_embeddings, sentence_lengths): # Prune num_spans = spans.size(1) # Max number of spans for the minibatch. # Keep different number of spans for each minibatch entry. num_spans_to_keep = torch.ceil(sentence_lengths.float() * self._spans_per_word).long() (top_span_embeddings, top_span_mask, top_span_indices, top_span_mention_scores, num_spans_kept) = self._mention_pruner( span_embeddings, span_mask, num_spans_to_keep) top_span_mask = top_span_mask.unsqueeze(-1) flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) return top_span_embeddings, top_span_mention_scores, num_spans_to_keep, top_span_mask, top_span_indices, top_spans
def _get_question_inputs(self, text, predicate_indicator, predicate_index, answer_spans): encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) span_mask = (answer_spans[:, :, 0] >= 0).long() # Shape: batch_size, num_spans, 2 * self._sentence_encoder.get_output_dim() span_reps = self._span_extractor(encoded_text, answer_spans, sequence_mask=text_mask, span_indices_mask=span_mask) batch_size, _, encoding_dim = encoded_text.size() num_spans = span_reps.size(1) if self._inject_predicate: pred_rep_expanded = batched_index_select(encoded_text, predicate_index) \ .expand(batch_size, num_spans, encoding_dim) question_inputs = torch.cat([pred_rep_expanded, span_reps], -1) else: question_inputs = span_reps return question_inputs, span_mask
def forward( self, # type: ignore text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor, predicate_index: torch.LongTensor, answer_spans: torch.LongTensor = None, num_answers: torch.LongTensor = None, num_invalids: torch.LongTensor = None, metadata=None, **kwargs): # each of gold_slot_labels[slot_name] is of # Shape: batch_size slot_labels = self._get_slot_labels(**kwargs) if slot_labels is None: raise ConfigurationError( "QuestionAnswerer must receive question slots as input.") encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) pred_rep = batched_index_select(encoded_text, predicate_index).squeeze(1) question_encoding = self._question_encoder(pred_rep, slot_labels) question_rep = torch.cat([pred_rep, question_encoding], -1) output_dict = self._span_selector(encoded_text, text_mask, extra_input_embedding=question_rep, answer_spans=answer_spans, num_answers=num_answers, metadata=metadata) if self._classify_invalids: invalid_logits = self._invalid_pred(question_rep).squeeze(-1) invalid_probs = torch.sigmoid(invalid_logits) output_dict["invalid_prob"] = invalid_probs if num_invalids is not None: invalid_labels = (num_invalids > 0.0).float() invalid_loss = F.binary_cross_entropy_with_logits( invalid_logits, invalid_labels, reduction="sum") output_dict["loss"] += invalid_loss self._invalid_metric(invalid_probs, invalid_labels) return output_dict
def forward( self, # type: ignore text: Dict[str, torch.LongTensor], predicate_indicator: torch.LongTensor = None, predicate_index: torch.LongTensor = None, answer_spans: torch.LongTensor = None, span_counts: torch.LongTensor = None, num_answers: torch.LongTensor = None, metadata=None, **kwargs): encoded_text, text_mask = self._sentence_encoder( text, predicate_indicator) extra_input = batched_index_select( encoded_text, predicate_index) if self._inject_predicate else None return self._span_selector(encoded_text, text_mask, extra_input_embedding=extra_input, answer_spans=answer_spans, span_counts=span_counts, num_answers=num_answers, metadata=metadata)
def test_batched_index_select(self):
    indices = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    # Each element is a vector of its index.
    targets = torch.ones([2, 10, 3]).cumsum(1) - 1
    # Make the second batch double its index so they're different.
    targets[1, :, :] *= 2
    indices = torch.tensor(indices, dtype=torch.long)
    selected = util.batched_index_select(targets, indices)

    assert list(selected.size()) == [2, 2, 2, 3]
    ones = numpy.ones([3])
    numpy.testing.assert_array_equal(selected[0, 0, 0, :].data.numpy(), ones)
    numpy.testing.assert_array_equal(selected[0, 0, 1, :].data.numpy(), ones * 2)
    numpy.testing.assert_array_equal(selected[0, 1, 0, :].data.numpy(), ones * 3)
    numpy.testing.assert_array_equal(selected[0, 1, 1, :].data.numpy(), ones * 4)

    numpy.testing.assert_array_equal(selected[1, 0, 0, :].data.numpy(), ones * 10)
    numpy.testing.assert_array_equal(selected[1, 0, 1, :].data.numpy(), ones * 12)
    numpy.testing.assert_array_equal(selected[1, 1, 0, :].data.numpy(), ones * 14)
    numpy.testing.assert_array_equal(selected[1, 1, 1, :].data.numpy(), ones * 16)
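# For readers of the test above: a gather-based sketch of what `batched_index_select` is
# expected to return (illustrative only; the AllenNLP implementation goes through
# flatten_and_batch_shift_indices rather than torch.gather directly).
import torch

def batched_index_select_sketch(target: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
    # target: (batch, sequence_length, dim); indices: (batch, ...) of positions into
    # sequence_length. Returns (batch, *indices.shape[1:], dim).
    batch_size, _, dim = target.size()
    flat_indices = indices.reshape(batch_size, -1)
    gathered = target.gather(1, flat_indices.unsqueeze(-1).expand(-1, -1, dim))
    return gathered.reshape(*indices.shape, dim)

targets = torch.ones([2, 10, 3]).cumsum(1) - 1
targets[1, :, :] *= 2
indices = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long)
assert list(batched_index_select_sketch(targets, indices).size()) == [2, 2, 2, 3]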
def test_correct_sequence_elements_are_embedded_with_a_masked_sequence(self): sequence_tensor = torch.randn([2, 5, 8]) # concatentate start and end points together to form our representation # for both the forward and backward directions. extractor = BidirectionalEndpointSpanExtractor(input_dim=8, forward_combination="x,y", backward_combination="x,y") indices = torch.LongTensor([[[1, 3], [2, 4]], # This span has an end index at the # end of the padded sequence. [[0, 2], [0, 1]]]) sequence_mask = torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]) span_representations = extractor(sequence_tensor, indices, sequence_mask=sequence_mask) # We just concatenated the start and end embeddings together, so # we can check they match the original indices if we split them apart. (forward_start_embeddings, forward_end_embeddings, backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1) forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(4, -1) # Forward direction => subtract 1 from start indices to make them exlusive. correct_forward_start_indices = torch.LongTensor([[0, 1], [-1, -1]]) # These indices should be -1, so they'll be replaced with a sentinel. Here, # we'll set them to a value other than -1 so we can index select the indices and # replace them later. correct_forward_start_indices[1, 0] = 1 correct_forward_start_indices[1, 1] = 1 # Forward direction => end indices are the same. correct_forward_end_indices = torch.LongTensor([[3, 4], [2, 1]]) # Backward direction => start indices are exclusive, so add 1 to the end indices. correct_backward_start_indices = torch.LongTensor([[4, 5], [3, 2]]) # These exclusive backward start indices are outside the tensor, so will be replaced # with the end sentinel. Here we replace them with ones so we can index select using # these indices without torch complaining. correct_backward_start_indices[0, 1] = 1 # Backward direction => end indices are inclusive and equal to the forward start indices. correct_backward_end_indices = torch.LongTensor([[1, 2], [0, 0]]) correct_forward_start_embeddings = batched_index_select(forward_sequence_tensor.contiguous(), correct_forward_start_indices) # This element had sequence_tensor index of 0, so it's exclusive index is the start sentinel. correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data correct_forward_start_embeddings[1, 1] = extractor._start_sentinel.data numpy.testing.assert_array_equal(forward_start_embeddings.data.numpy(), correct_forward_start_embeddings.data.numpy()) correct_forward_end_embeddings = batched_index_select(forward_sequence_tensor.contiguous(), correct_forward_end_indices) numpy.testing.assert_array_equal(forward_end_embeddings.data.numpy(), correct_forward_end_embeddings.data.numpy()) correct_backward_end_embeddings = batched_index_select(backward_sequence_tensor.contiguous(), correct_backward_end_indices) numpy.testing.assert_array_equal(backward_end_embeddings.data.numpy(), correct_backward_end_embeddings.data.numpy()) correct_backward_start_embeddings = batched_index_select(backward_sequence_tensor.contiguous(), correct_backward_start_indices) # This element had sequence_tensor index == sequence_tensor.size(1), # so it's exclusive index is the end sentinel. correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data # This element has sequence_tensor index == the masked length of the batch element, # so it should be the end_sentinel even though it isn't greater than sequence_tensor.size(1). 
correct_backward_start_embeddings[1, 0] = extractor._end_sentinel.data numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(), correct_backward_start_embeddings.data.numpy())
def test_correct_sequence_elements_are_embedded(self): sequence_tensor = Variable(torch.randn([2, 5, 8])) # concatentate start and end points together to form our representation # for both the forward and backward directions. extractor = BidirectionalEndpointSpanExtractor(input_dim=8, forward_combination="x,y", backward_combination="x,y") indices = Variable(torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])) span_representations = extractor(sequence_tensor, indices) assert list(span_representations.size()) == [2, 2, 16] assert extractor.get_output_dim() == 16 assert extractor.get_input_dim() == 8 # We just concatenated the start and end embeddings together, so # we can check they match the original indices if we split them apart. (forward_start_embeddings, forward_end_embeddings, backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1) forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(4, -1) # Forward direction => subtract 1 from start indices to make them exlusive. correct_forward_start_indices = Variable(torch.LongTensor([[0, 1], [-1, 2]])) # This index should be -1, so it will be replaced with a sentinel. Here, # we'll set it to a value other than -1 so we can index select the indices and # replace it later. correct_forward_start_indices[1, 0] = 1 # Forward direction => end indices are the same. correct_forward_end_indices = Variable(torch.LongTensor([[3, 4], [2, 4]])) # Backward direction => start indices are exclusive, so add 1 to the end indices. correct_backward_start_indices = Variable(torch.LongTensor([[4, 5], [3, 5]])) # These exclusive end indices are outside the tensor, so will be replaced with the end sentinel. # Here we replace them with ones so we can index select using these indices without torch # complaining. correct_backward_start_indices[0, 1] = 1 correct_backward_start_indices[1, 1] = 1 # Backward direction => end indices are inclusive and equal to the forward start indices. correct_backward_end_indices = Variable(torch.LongTensor([[1, 2], [0, 3]])) correct_forward_start_embeddings = batched_index_select(forward_sequence_tensor.contiguous(), correct_forward_start_indices) # This element had sequence_tensor index of 0, so it's exclusive index is the start sentinel. correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data numpy.testing.assert_array_equal(forward_start_embeddings.data.numpy(), correct_forward_start_embeddings.data.numpy()) correct_forward_end_embeddings = batched_index_select(forward_sequence_tensor.contiguous(), correct_forward_end_indices) numpy.testing.assert_array_equal(forward_end_embeddings.data.numpy(), correct_forward_end_embeddings.data.numpy()) correct_backward_end_embeddings = batched_index_select(backward_sequence_tensor.contiguous(), correct_backward_end_indices) numpy.testing.assert_array_equal(backward_end_embeddings.data.numpy(), correct_backward_end_embeddings.data.numpy()) correct_backward_start_embeddings = batched_index_select(backward_sequence_tensor.contiguous(), correct_backward_start_indices) # This element had sequence_tensor index == sequence_tensor.size(1), # so it's exclusive index is the end sentinel. correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data correct_backward_start_embeddings[1, 1] = extractor._end_sentinel.data numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(), correct_backward_start_embeddings.data.numpy())
def forward(self, sequence_tensor: torch.FloatTensor, span_indices: torch.LongTensor, sequence_mask: torch.LongTensor = None, span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor: # Both of shape (batch_size, sequence_length, embedding_size / 2) forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1) forward_sequence = forward_sequence.contiguous() backward_sequence = backward_sequence.contiguous() # shape (batch_size, num_spans) span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)] if span_indices_mask is not None: span_starts = span_starts * span_indices_mask span_ends = span_ends * span_indices_mask # We want `exclusive` span starts, so we remove 1 from the forward span starts # as the AllenNLP ``SpanField`` is inclusive. # shape (batch_size, num_spans) exclusive_span_starts = span_starts - 1 # shape (batch_size, num_spans, 1) start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1) # We want `exclusive` span ends for the backward direction # (so that the `start` of the span in that direction is exlusive), so # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive. exclusive_span_ends = span_ends + 1 if sequence_mask is not None: # shape (batch_size) sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask) else: # shape (batch_size), filled with the sequence length size of the sequence_tensor. sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1) # shape (batch_size, num_spans, 1) end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1) # As we added 1 to the span_ends to make them exclusive, which might have caused indices # equal to the sequence_length to become out of bounds, we multiply by the inverse of the # end_sentinel mask to erase these indices (as we will replace them anyway in the block below). # The same argument follows for the exclusive span start indices. exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1)) exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1)) # We'll check the indices here at runtime, because it's difficult to debug # if this goes wrong and it's tricky to get right. if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any(): raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, " f"but found: exclusive_span_starts: {exclusive_span_starts}, " f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths " f"{sequence_lengths}.") # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2) forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts) # Forward Direction: end indices are inclusive, so we can just use span_ends. # Shape (batch_size, num_spans, input_size / 2) forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends) # Backward Direction: The backward start embeddings use the `forward` end # indices, because we are going backwards. # Shape (batch_size, num_spans, input_size / 2) backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends) # Backward Direction: The backward end embeddings use the `forward` start # indices, because we are going backwards. 
# Shape (batch_size, num_spans, input_size / 2) backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts) if self._use_sentinels: # If we're using sentinels, we need to replace all the elements which were # outside the dimensions of the sequence_tensor with either the start sentinel, # or the end sentinel. float_end_sentinel_mask = end_sentinel_mask.float() float_start_sentinel_mask = start_sentinel_mask.float() forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \ + float_start_sentinel_mask * self._start_sentinel backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \ + float_end_sentinel_mask * self._end_sentinel # Now we combine the forward and backward spans in the manner specified by the # respective combinations and concatenate these representations. # Shape (batch_size, num_spans, forward_combination_dim) forward_spans = util.combine_tensors(self._forward_combination, [forward_start_embeddings, forward_end_embeddings]) # Shape (batch_size, num_spans, backward_combination_dim) backward_spans = util.combine_tensors(self._backward_combination, [backward_start_embeddings, backward_end_embeddings]) # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim) span_embeddings = torch.cat([forward_spans, backward_spans], -1) if self._span_width_embedding is not None: # Embed the span widths and concatenate to the rest of the representations. if self._bucket_widths: span_widths = util.bucket_values(span_ends - span_starts, num_total_buckets=self._num_width_embeddings) else: span_widths = span_ends - span_starts span_width_embeddings = self._span_width_embedding(span_widths) return torch.cat([span_embeddings, span_width_embeddings], -1) if span_indices_mask is not None: return span_embeddings * span_indices_mask.float().unsqueeze(-1) return span_embeddings
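# NOTE: a hedged sketch (invented names, not the extractor's API) of the sentinel trick used in
# the forward pass above: endpoints whose exclusive index fell outside the sequence were clamped
# to a valid position for the index_select, and are then overwritten with a learned sentinel
# vector via the 0/1 sentinel masks.
import torch

def apply_sentinel(endpoint_embeddings: torch.Tensor,
                   sentinel_mask: torch.Tensor,
                   sentinel: torch.Tensor) -> torch.Tensor:
    # endpoint_embeddings: (batch_size, num_spans, dim)
    # sentinel_mask: (batch_size, num_spans, 1), 1 where the index was out of range
    # sentinel: (1, 1, dim), a learned nn.Parameter in the real extractor
    mask = sentinel_mask.float()
    return endpoint_embeddings * (1 - mask) + mask * sentinel

embeddings = torch.randn(2, 3, 4)
sentinel_mask = torch.tensor([[[1], [0], [0]], [[0], [0], [1]]])
start_sentinel = torch.randn(1, 1, 4)  # stands in for self._start_sentinel
patched = apply_sentinel(embeddings, sentinel_mask, start_sentinel)
assert torch.equal(patched[0, 0], start_sentinel[0, 0])  # masked slot replaced by the sentinel
assert torch.equal(patched[0, 1], embeddings[0, 1])      # unmasked slot left untouched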
def forward(self, # type: ignore text: Dict[str, torch.LongTensor], spans: torch.IntTensor, span_labels: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- text : ``Dict[str, torch.LongTensor]``, required. The output of a ``TextField`` representing the text of the document. spans : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of indices into the text of the document. span_labels : ``torch.IntTensor``, optional (default = None) A tensor of shape (batch_size, num_spans), representing the cluster ids of each span, or -1 for those which do not appear in any clusters. Returns ------- An output dictionary consisting of: top_spans : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing the start and end word indices of the top spans that survived the pruning stage. antecedent_indices : ``torch.IntTensor`` A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span the index (with respect to top_spans) of the possible antecedents the model considered. predicted_antecedents : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the index (with respect to antecedent_indices) of the most likely antecedent. -1 means there was no predicted link. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. """ # Shape: (batch_size, document_length, embedding_size) text_embeddings = self._lexical_dropout(self._text_field_embedder(text)) document_length = text_embeddings.size(1) num_spans = spans.size(1) # Shape: (batch_size, document_length) text_mask = util.get_text_field_mask(text).float() # Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float() # SpanFields return -1 when they are used as padding. As we do # some comparisons based on span widths when we attend over the # span representations that we generate from these indices, we # need them to be >= 0. This is only relevant in edge cases where # the number of spans we consider after the pruning stage is >= the # total number of spans, because in this case, it is possible we might # consider a masked span. # Shape: (batch_size, num_spans, 2) spans = F.relu(spans.float()).long() # Shape: (batch_size, document_length, encoding_dim) contextualized_embeddings = self._context_layer(text_embeddings, text_mask) # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size) endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans) # Shape: (batch_size, num_spans, embedding_size) attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans) # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size) span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1) # Prune based on mention scores. num_spans_to_keep = int(math.floor(self._spans_per_word * document_length)) (top_span_embeddings, top_span_mask, top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings, span_mask, num_spans_to_keep) top_span_mask = top_span_mask.unsqueeze(-1) # Shape: (batch_size * num_spans_to_keep) # torch.index_select only accepts 1D indices, but here # we need to select spans for each element in the batch.
# This reformats the indices to take into account their # index into the batch. We precompute this here to make # the multiple calls to util.batched_index_select below more efficient. flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans) # Compute final predictions for which spans to consider as mentions. # Shape: (batch_size, num_spans_to_keep, 2) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) # Compute indices for antecedent spans to consider. max_antecedents = min(self._max_antecedents, num_spans_to_keep) # Now that we have our variables in terms of num_spans_to_keep, we need to # compare span pairs to decide each span's antecedent. Each span can only # have prior spans as antecedents, and we only consider up to max_antecedents # prior spans. So the first thing we do is construct a matrix mapping a span's # index to the indices of its allowed antecedents. Note that this is independent # of the batch dimension - it's just a function of the span's position in # top_spans. The spans are in document order, so we can just use the relative # index of the spans to know which other spans are allowed antecedents. # Once we have this matrix, we reformat our variables again to get embeddings # for all valid antecedents for each span. This gives us variables with shapes # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which # we can use to make coreference decisions between valid span pairs. # Shapes: # (num_spans_to_keep, max_antecedents), # (1, max_antecedents), # (1, num_spans_to_keep, max_antecedents) valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \ self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask)) # Select tensors relating to the antecedent spans. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings, valid_antecedent_indices) # Shape: (batch_size, num_spans_to_keep, max_antecedents) candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores, valid_antecedent_indices).squeeze(-1) # Compute antecedent scores. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores(span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) # We now have, for each span which survived the pruning stage, # a predicted antecedent. This implies a clustering if we group # mentions which refer to each other in a chain. # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 output_dict = {"top_spans": top_spans, "antecedent_indices": valid_antecedent_indices, "predicted_antecedents": predicted_antecedents} if span_labels is not None: # Find the gold labels for the spans which we kept. 
pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select(pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels, antecedent_labels) # Now, compute the loss using the negative marginal log-likelihood. # This is equal to the log of the sum of the probabilities of all antecedent predictions # that would be consistent with the data, in the sense that we are minimising, for a # given span, the negative marginal log likelihood of all antecedents which are in the # same gold cluster as the span we are currently considering. Each span i predicts a # single antecedent j, but there might be several prior mentions k in the same # coreference cluster that would be valid antecedents. Our loss is the sum of the # probability assigned to all valid antecedents. This is a valid objective for # clustering as we don't mind which antecedent is predicted, so long as they are in # the same coreference cluster. coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log() negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
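# NOTE: a small hedged sketch of the marginal log-likelihood objective described in the comments
# above (shapes, names, and tensors are illustrative): take log-probabilities over the dummy
# "no antecedent" class plus the candidate antecedents, keep only the gold-consistent entries,
# and marginalise over them with logsumexp.
import torch

def coref_marginal_nll(coreference_log_probs: torch.Tensor,
                       gold_antecedent_labels: torch.Tensor) -> torch.Tensor:
    # coreference_log_probs: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    # gold_antecedent_labels: same shape, 1.0 for gold-consistent choices, else 0.0
    # log(0) = -inf removes inconsistent antecedents from the logsumexp.
    correct_log_probs = coreference_log_probs + gold_antecedent_labels.log()
    return -torch.logsumexp(correct_log_probs, dim=-1).sum()

log_probs = torch.log_softmax(torch.randn(1, 2, 4), dim=-1)
gold = torch.tensor([[[0., 0., 1., 1.],     # two prior mentions are in the gold cluster
                      [1., 0., 0., 0.]]])   # no gold antecedent -> the dummy class is correct
loss = coref_marginal_nll(log_probs, gold)
assert loss.item() > 0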
def _get_initial_rnn_and_grammar_state(self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRule]], outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet], List[LambdaGrammarStatelet]]: """ Encodes the question and table, computes a linking between the two, and constructs an initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the decoder. We take ``outputs`` as a parameter here and `modify` it, adding things that we want to visualize in a demo. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: tensor with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table) entity_type_embeddings = self._entity_type_encoder_embedding(entity_types) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. 
question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i], linking_scores[i], entity_types[i]) for i in range(batch_size)] if not self.training: # We add a few things to the outputs that will be returned from `forward` at evaluation # time, for visualization in a demo. outputs['linking_scores'] = linking_scores if feature_scores is not None: outputs['feature_scores'] = feature_scores outputs['similarity_scores'] = question_entity_similarity_max_score return initial_rnn_state, initial_grammar_state
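# NOTE: a hedged toy example (tensors invented for illustration) of the padding convention used
# for `neighbor_indices` above: padding is -1 because 0 is a valid neighbor index, so we gather
# with abs(indices) and build the mask from indices + 1, which a masked average then uses to
# ignore the padded slots.
import torch

encoded_table = torch.randn(1, 3, 4)                            # (batch, num_entities, dim)
neighbor_indices = torch.tensor([[[2, -1], [0, 2], [-1, -1]]])  # -1 marks padding
batch, num_entities, num_neighbors = neighbor_indices.shape
dim = encoded_table.size(-1)

gather_index = neighbor_indices.abs().view(batch, -1, 1).expand(-1, -1, dim)
embedded_neighbors = encoded_table.gather(1, gather_index).view(batch, num_entities, num_neighbors, dim)

neighbor_mask = (neighbor_indices + 1 > 0).float()              # 1 for real neighbors, 0 for padding
masked_average = (embedded_neighbors * neighbor_mask.unsqueeze(-1)).sum(2) \
    / neighbor_mask.sum(2, keepdim=True).clamp(min=1)           # (batch, num_entities, dim)
assert masked_average.shape == (batch, num_entities, dim)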
def forward(self, sequence_tensor: torch.FloatTensor, span_indices: torch.LongTensor, sequence_mask: torch.LongTensor = None, span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor: # both of shape (batch_size, num_spans, 1) span_starts, span_ends = span_indices.split(1, dim=-1) # shape (batch_size, num_spans, 1) # These span widths are off by 1, because the span ends are `inclusive`. span_widths = span_ends - span_starts # We need to know the maximum span width so we can # generate indices to extract the spans from the sequence tensor. # These indices will then get masked below, such that if the length # of a given span is smaller than the max, the rest of the values # are masked. max_batch_span_width = span_widths.max().item() + 1 # shape (batch_size, sequence_length, 1) global_attention_logits = self._global_attention(sequence_tensor) # Shape: (1, 1, max_batch_span_width) max_span_range_indices = util.get_range_vector(max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1) # Shape: (batch_size, num_spans, max_batch_span_width) # This is a broadcasted comparison - for each span we are considering, # we are creating a range vector of size max_span_width, but masking values # which are greater than the actual length of the span. # # We're using <= here (and for the mask below) because the span ends are # inclusive, so we want to include indices which are equal to span_widths rather # than using it as a non-inclusive upper bound. span_mask = (max_span_range_indices <= span_widths).float() raw_span_indices = span_ends - max_span_range_indices # We also don't want to include span indices which are less than zero, # which happens because some spans near the beginning of the sequence # have an end index < max_batch_span_width, so we add this to the mask here. span_mask = span_mask * (raw_span_indices >= 0).float() span_indices = torch.nn.functional.relu(raw_span_indices.float()).long() # Shape: (batch_size * num_spans * max_batch_span_width) flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1)) # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim) span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices) # Shape: (batch_size, num_spans, max_batch_span_width) span_attention_logits = util.batched_index_select(global_attention_logits, span_indices, flat_span_indices).squeeze(-1) # Shape: (batch_size, num_spans, max_batch_span_width) span_attention_weights = util.masked_softmax(span_attention_logits, span_mask) # Do a weighted sum of the embedded spans with # respect to the normalised attention distributions. # Shape: (batch_size, num_spans, embedding_dim) attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights) if span_indices_mask is not None: # Above we were masking the widths of spans with respect to the max # span width in the batch. Here we are masking the spans which were # originally passed in as padding. return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float() return attended_text_embeddings
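# NOTE: a hedged walk-through of the index/mask arithmetic above for one toy span
# (start=1, end=3 inclusive) with max_batch_span_width = 4; variable names mirror the method
# but the tensors themselves are invented for illustration.
import torch

span_starts = torch.tensor([[[1]]])
span_ends = torch.tensor([[[3]]])
span_widths = span_ends - span_starts                                        # 2, off by one: ends are inclusive
max_batch_span_width = 4

max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)   # [0, 1, 2, 3]
span_mask = (max_span_range_indices <= span_widths).float()                  # [1, 1, 1, 0]
raw_span_indices = span_ends - max_span_range_indices                        # [3, 2, 1, 0]
span_mask = span_mask * (raw_span_indices >= 0).float()                      # still [1, 1, 1, 0]
span_indices = torch.relu(raw_span_indices.float()).long()                   # tokens 3, 2, 1 kept; last slot masked
assert span_indices.tolist() == [[[3, 2, 1, 0]]]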
def forward(self, # pylint: disable=arguments-differ embeddings: torch.FloatTensor, mask: torch.LongTensor, num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor, torch.FloatTensor]: """ Extracts the top-k scoring items with respect to the scorer. We additionally return the indices of the top-k in their original order, not ordered by score, so that downstream components can rely on the original ordering (e.g., for knowing what spans are valid antecedents in a coreference resolution model). Parameters ---------- embeddings : ``torch.FloatTensor``, required. A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for each item in the list that we want to prune. mask : ``torch.LongTensor``, required. A tensor of shape (batch_size, num_items), denoting unpadded elements of ``embeddings``. num_items_to_keep : ``int``, required. The number of items to keep when pruning. Returns ------- top_embeddings : ``torch.FloatTensor`` The representations of the top-k scoring items. Has shape (batch_size, num_items_to_keep, embedding_size). top_mask : ``torch.LongTensor`` The corresponding mask for ``top_embeddings``. Has shape (batch_size, num_items_to_keep). top_indices : ``torch.IntTensor`` The indices of the top-k scoring items into the original ``embeddings`` tensor. This is returned because it can be useful to retain pointers to the original items, if each item is being scored by multiple distinct scorers, for instance. Has shape (batch_size, num_items_to_keep). top_item_scores : ``torch.FloatTensor`` The values of the top-k scoring items. Has shape (batch_size, num_items_to_keep, 1). """ mask = mask.unsqueeze(-1) num_items = embeddings.size(1) # Shape: (batch_size, num_items, 1) scores = self._scorer(embeddings) if scores.size(-1) != 1 or scores.dim() != 3: raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape " f"(batch_size, num_items, 1), but found shape {scores.size()}") # Make sure that we don't select any masked items by setting their scores to be very # negative. These are logits, typically, so -1e20 should be plenty negative. scores = util.replace_masked_values(scores, mask, -1e20) # Shape: (batch_size, num_items_to_keep, 1) _, top_indices = scores.topk(num_items_to_keep, 1) # Now we order the selected indices in increasing order with # respect to their indices (and hence, with respect to the # order they originally appeared in the ``embeddings`` tensor). top_indices, _ = torch.sort(top_indices, 1) # Shape: (batch_size, num_items_to_keep) top_indices = top_indices.squeeze(-1) # Shape: (batch_size * num_items_to_keep) # torch.index_select only accepts 1D indices, but here # we need to select items for each element in the batch. flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items) # Shape: (batch_size, num_items_to_keep, embedding_size) top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices) # Shape: (batch_size, num_items_to_keep) top_mask = util.batched_index_select(mask, top_indices, flat_top_indices) # Shape: (batch_size, num_items_to_keep, 1) top_scores = util.batched_index_select(scores, top_indices, flat_top_indices) return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
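# NOTE: a hedged mini-example of the pruning pattern in the forward pass above (scores and mask
# invented for illustration): padded items get a very negative score, top-k picks the best
# surviving items, and a final sort restores original order so downstream code still sees the
# kept items in document order.
import torch

scores = torch.tensor([[[0.2], [1.5], [-0.3], [0.9]]])    # (batch=1, num_items=4, 1)
mask = torch.tensor([[[1], [1], [1], [0]]])               # last item is padding
masked_scores = scores.masked_fill(mask == 0, -1e20)

_, top_indices = masked_scores.topk(2, dim=1)             # picks items 1 and 0 (scores 1.5 and 0.2)
top_indices, _ = torch.sort(top_indices, dim=1)           # restore original ordering
top_indices = top_indices.squeeze(-1)                     # shape (1, 2)
assert top_indices.tolist() == [[0, 1]]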
def _get_initial_state_and_scores(self, question: Dict[str, torch.LongTensor], table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld], actions: List[List[ProductionRuleArray]], example_lisp_string: List[str] = None, add_world_to_initial_state: bool = False, checklist_states: List[ChecklistState] = None) -> Dict: """ Does initial preparation and creates an initial state for both the semantic parsers. Note that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to pass it. """ table_text = table['text'] # (batch_size, question_length, embedding_dim) embedded_question = self._question_embedder(question) question_mask = util.get_text_field_mask(question).float() # (batch_size, num_entities, num_entity_tokens, embedding_dim) embedded_table = self._question_embedder(table_text, num_wrapping_dims=1) table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float() batch_size, num_entities, num_entity_tokens, _ = embedded_table.size() num_question_tokens = embedded_question.size(1) # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_table, table_mask) # (batch_size, num_entities, num_neighbors) neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types) # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table) entity_type_embeddings = self._type_params(entity_types.float()) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1].
question_entity_similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_question, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = table['linking'] linking_scores = question_entity_similarity_max_score if self._use_neighbor_similarity_for_linking: # The linking score is computed as a linear projection of two terms. The first is the # maximum similarity score over the entity's words and the question token. The second # is the maximum similarity over the words in the entity's neighbors and the question # token. # # The second term, projected_question_neighbor_similarity, is useful when a column # needs to be selected. For example, the question token might have no similarity with # the column name, but is similar with the cells in the column. # # Note that projected_question_neighbor_similarity is intended to capture the same # information as the related_column feature. # # Also note that this block needs to be _before_ the `linking_params` block, because # we're overwriting `linking_scores`, not adding to it. # (batch_size, num_entities, num_neighbors, num_question_tokens) question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score, torch.abs(neighbor_indices)) # (batch_size, num_entities, num_question_tokens) question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2) projected_question_entity_similarity = self._question_entity_params( question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1) projected_question_neighbor_similarity = self._question_neighbor_params( question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1) linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity feature_scores = None if self._linking_params is not None: feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2), question_mask, entity_type_dict) # (batch_size, num_question_tokens, embedding_dim) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_question], 2) # (batch_size, question_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask)) # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, question_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim()) initial_score = embedded_question.data.new_zeros(batch_size) action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions) _, num_entities, num_question_tokens = linking_scores.size() flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores, world, actions) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. 
# For instance, the encoder outputs have shape # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] question_mask_list = [question_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnState(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_question, encoder_output_list, question_mask_list)) initial_grammar_state = [self._create_grammar_state(world[i], actions[i]) for i in range(batch_size)] initial_state_world = world if add_world_to_initial_state else None initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, action_embeddings=action_embeddings, output_action_embeddings=output_action_embeddings, action_biases=action_biases, action_indices=action_indices, possible_actions=actions, flattened_linking_scores=flattened_linking_scores, actions_to_entities=actions_to_entities, entity_types=entity_type_dict, world=initial_state_world, example_lisp_string=example_lisp_string, checklist_state=checklist_states, debug_info=None) return {"initial_state": initial_state, "linking_scores": linking_scores, "feature_scores": feature_scores, "similarity_scores": question_entity_similarity_max_score}
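# NOTE: a hedged sketch of the bmm-based question/entity similarity computed in the two parser
# methods above, with made-up dimensions: flatten the entity-token axis, take batched dot
# products against the question tokens, then max-pool over entity tokens.
import torch

batch, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 5, 8
embedded_table = torch.randn(batch, num_entities, num_entity_tokens, dim)
embedded_question = torch.randn(batch, num_question_tokens, dim)

question_entity_similarity = torch.bmm(
    embedded_table.view(batch, num_entities * num_entity_tokens, dim),
    embedded_question.transpose(1, 2))
question_entity_similarity = question_entity_similarity.view(
    batch, num_entities, num_entity_tokens, num_question_tokens)

# One score per (entity, question token) pair; this is what feeds the linking scores.
question_entity_similarity_max_score, _ = question_entity_similarity.max(dim=2)
assert question_entity_similarity_max_score.shape == (batch, num_entities, num_question_tokens)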