def _arithmetic_log_likelihood(self,
                               answer_as_expressions,
                               arithmetic_template_slot_log_probs,
                               arithmetic_template_log_probs):
    """Log marginal likelihood of the gold arithmetic expressions.

    Shapes, as annotated in the original code:
        answer_as_expressions: (batch, #templates, #expressions, #slots),
            with -1 marking padding.
        arithmetic_template_slot_log_probs:
            (batch, #templates, #slots, #numbers)
        arithmetic_template_log_probs: (batch, #templates)

    The per-slot log-probs of every gold expression are summed, reduced with
    ``util.logsumexp`` over expressions, added to the template log-probs and
    reduced with ``util.logsumexp`` once more, giving one value per batch
    element.
    """
    # (batch, #templates, #slots, #expressions)
    slot_targets = answer_as_expressions.transpose(2, 3).long()

    # 1 for real indices, 0 for the -1 padding.
    valid = (slot_targets != -1).long()
    # gather() cannot take -1, so padding is pointed at index 0 and masked
    # out again further down.
    safe_targets = util.replace_masked_values(slot_targets, valid, 0)

    # (batch, #templates, #slots, #expressions)
    slot_scores = torch.gather(arithmetic_template_slot_log_probs, -1, safe_targets)

    # Sum over slots -> (batch, #templates, #expressions)
    expression_scores = slot_scores.sum(2)

    # Whether an expression is padding is read off its first slot.
    expression_valid = valid[:, :, 0, :]
    expression_scores = util.replace_masked_values(
        expression_scores, expression_valid, -1e7)

    # (batch, #templates)
    template_scores = util.logsumexp(expression_scores)
    joint_scores = template_scores + arithmetic_template_log_probs

    # (batch,)
    return util.logsumexp(joint_scores)
def _base_arithmetic_module(self, passage_vector, passage_out, number_indices, number_mask):
    """Predict a sign class for every number and return the arg-max choice.

    Only the first entry along the last axis of ``number_indices`` is used to
    look up each number's vector in ``passage_out``. When
    ``self.num_special_numbers > 0`` the learned special-number embeddings
    are prepended and ``number_mask`` is extended with ones to match.

    Returns:
        (number_sign_log_probs, best_signs_for_numbers, number_mask), where
        ``number_mask`` is the (possibly extended) mask.
    """
    first_piece = number_indices[:, :, 0].long()
    batch = first_piece.shape[0]
    device = first_piece.device

    # Padded indices are redirected to position 0 so gather() stays in range.
    safe_indices = util.replace_masked_values(first_piece, number_mask, 0)
    gather_index = safe_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1))
    encoded_numbers = torch.gather(passage_out, 1, gather_index)

    if self.num_special_numbers > 0:
        special_ids = torch.arange(self.num_special_numbers, device=device)
        special_vectors = self.special_embedding(special_ids).expand(batch, -1, -1)
        encoded_numbers = torch.cat([special_vectors, encoded_numbers], 1)

        # Special numbers are never padding.
        special_mask = torch.ones((batch, self.num_special_numbers), device=device).long()
        number_mask = torch.cat([special_mask, number_mask], -1)

    # Append the passage summary to every number vector.
    # Shape per original annotation: (batch_size, # of numbers, 2*bert_dim)
    tiled_passage = passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)
    encoded_numbers = torch.cat([encoded_numbers, tiled_passage], -1)

    # Shape per original annotation: (batch_size, # of numbers, 3)
    sign_logits = self._number_sign_predictor(encoded_numbers)
    number_sign_log_probs = torch.nn.functional.log_softmax(sign_logits, -1)

    # Padding numbers get sign 0 ("not included" in the original comment).
    best_signs_for_numbers = util.replace_masked_values(
        torch.argmax(number_sign_log_probs, -1), number_mask, 0)

    return number_sign_log_probs, best_signs_for_numbers, number_mask
def _question_span_log_likelihood(self,
                                  answer_as_question_spans,
                                  question_span_start_log_probs,
                                  question_span_end_log_probs):
    """Log marginal likelihood of the gold question spans.

    ``answer_as_question_spans[..., 0]`` / ``[..., 1]`` hold start / end
    indices, padded with -1. Each span scores ``start_log_prob +
    end_log_prob``; padded spans are set to -1e7 before the final
    ``util.logsumexp``.
    """
    # Shape per original annotation: (batch_size, # of answer spans)
    starts = answer_as_question_spans[:, :, 0]
    ends = answer_as_question_spans[:, :, 1]

    # Padding is detected on the start index only.
    span_mask = (starts != -1).long()

    # Clamp padding to 0 so gather() is valid; it is masked after the lookup.
    safe_starts = util.replace_masked_values(starts, span_mask, 0)
    safe_ends = util.replace_masked_values(ends, span_mask, 0)

    start_scores = torch.gather(question_span_start_log_probs, 1, safe_starts)
    end_scores = torch.gather(question_span_end_log_probs, 1, safe_ends)

    span_scores = util.replace_masked_values(start_scores + end_scores, span_mask, -1e7)

    # Shape per original annotation: (batch_size,)
    return util.logsumexp(span_scores)
def forward_span(self, ds_name, dialog_repr, repeated_ds_embeddings, context_masks,
                 span_labels=None, spans_start=None, spans_end=None):
    """Predict a span label and a best span for one domain-slot.

    The domain-slot embedding attends over the dialog representation; the
    attended vector feeds the span-label classifier and also provides the
    query used to score span start / end positions.

    Returns:
        (loss, (span_label_prediction, best_span)). ``loss`` stays ``0.0``
        when neither ``span_labels`` nor ``spans_start`` is given.
    """
    batch_size, max_dialog_len = context_masks.size()

    # Attention of the domain-slot embedding over the dialog.
    dropped_ds = self._dropout(repeated_ds_embeddings)
    dropped_dialog = self._dropout(dialog_repr)
    similarity = self._ds_dialog_attention(dropped_ds, dropped_dialog)
    attention = util.masked_softmax(
        similarity.view(-1, max_dialog_len),
        context_masks.view(-1, max_dialog_len))
    attention = attention.view(batch_size, max_dialog_len)
    ds_dialog_repr = (util.weighted_sum(dialog_repr, attention)
                      + repeated_ds_embeddings.squeeze(1))

    # Span label classification.
    span_label_logits = self._span_label_predictor(
        F.relu(self._dropout(ds_dialog_repr)))
    span_label_prediction = torch.argmax(span_label_logits, dim=1)

    loss = 0.0
    if span_labels is not None:
        # loss averaged by #turn
        loss = self._cross_entropy(span_label_logits, span_labels)
        self._accuracy.span_label_acc(
            ds_name, span_label_logits, span_labels, span_labels != -1)

    # Span boundary scoring: a query vector is dotted with every position.
    w = self._span_prediction_layer(self._dropout(ds_dialog_repr)).unsqueeze(1)
    int8_mask = context_masks.to(dtype=torch.int8)

    span_start_repr = self._span_start_encoder(self._dropout(dialog_repr))
    span_start_logits = torch.bmm(w, span_start_repr.transpose(1, 2)).squeeze(1)
    # Computed as in the original; the value is not consumed in this method.
    span_start_probs = util.masked_softmax(span_start_logits, context_masks)
    span_start_logits = util.replace_masked_values(span_start_logits, int8_mask, -1e7)

    # The end encoder is stacked on the start representation.
    span_end_repr = self._span_end_encoder(self._dropout(span_start_repr))
    span_end_logits = torch.bmm(w, span_end_repr.transpose(1, 2)).squeeze(1)
    # Computed as in the original; the value is not consumed in this method.
    span_end_probs = util.masked_softmax(span_end_logits, context_masks)
    span_end_logits = util.replace_masked_values(span_end_logits, int8_mask, -1e7)

    best_span = self.get_best_span(span_start_logits, span_end_logits)
    best_span = best_span.view(batch_size, -1)

    spans_loss = 0.0
    if spans_start is not None:
        spans_loss = self._cross_entropy(span_start_logits, spans_start)
        self._accuracy.span_start_acc(
            ds_name, span_start_logits, spans_start, spans_start != -1)
        spans_loss += self._cross_entropy(span_end_logits, spans_end)
        self._accuracy.span_end_acc(
            ds_name, span_end_logits, spans_end, spans_end != -1)
    loss += spans_loss

    return loss, (span_label_prediction, best_span)
def _base_arithmetic_log_likelihood(self,
                                    answer_as_expressions,
                                    number_sign_log_probs,
                                    number_mask,
                                    answer_as_expressions_extra):
    """Log marginal likelihood of the gold add/sub sign combinations.

    When special numbers are enabled, their gold signs
    (``answer_as_expressions_extra``) are prepended along the last axis.
    A combination whose signs sum to 0 is treated as padding.
    """
    if self.num_special_numbers > 0:
        answer_as_expressions = torch.cat(
            [answer_as_expressions_extra, answer_as_expressions], -1)

    # Padded combinations use sign 0 everywhere, so their sum is 0.
    # Shape per original annotation: (batch_size, # of combinations)
    combination_mask = (answer_as_expressions.sum(-1) > 0).float()

    # Shape per original annotation:
    # (batch_size, # of numbers in the passage, # of combinations)
    gold_signs = answer_as_expressions.transpose(1, 2)
    sign_scores = torch.gather(number_sign_log_probs, 2, gold_signs)

    # Padding numbers contribute log-prob 0 so they do not change the joint.
    sign_scores = util.replace_masked_values(
        sign_scores, number_mask.unsqueeze(-1), 0)

    # Sum over numbers -> (batch_size, # of combinations)
    combination_scores = sign_scores.sum(1)
    combination_scores = util.replace_masked_values(
        combination_scores, combination_mask, -1e7)

    # Shape per original annotation: (batch_size,)
    return util.logsumexp(combination_scores)
def _question_span_module(self, passage_vector, question_out, question_mask):
    """Score question-span starts and ends and pick the best span.

    Every question position is concatenated with ``passage_vector`` before
    being passed to the start and end predictors.

    Returns:
        (question_span_start_log_probs, question_span_end_log_probs,
        best_question_span)
    """
    tiled_passage = passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)
    span_input = torch.cat([question_out, tiled_passage], -1)

    # Shape per original annotation: (batch_size, question_length)
    start_logits = self._question_span_start_predictor(span_input).squeeze(-1)
    end_logits = self._question_span_end_predictor(span_input).squeeze(-1)

    start_log_probs = util.masked_log_softmax(start_logits, question_mask)
    end_log_probs = util.masked_log_softmax(end_logits, question_mask)

    # For decoding, masked positions are pushed to -1e7 in the raw logits.
    masked_start_logits = util.replace_masked_values(start_logits, question_mask, -1e7)
    masked_end_logits = util.replace_masked_values(end_logits, question_mask, -1e7)

    # Shape per original annotation: (batch_size, 2)
    best_question_span = get_best_span(masked_start_logits, masked_end_logits)

    return start_log_probs, end_log_probs, best_question_span
def _passage_span_module(self, passage_out, passage_mask):
    """Score passage-span starts and ends and pick the best span.

    Returns:
        (passage_span_start_log_probs, passage_span_end_log_probs,
        best_passage_span)
    """
    # Shape per original annotation: (batch_size, passage_length)
    start_logits = self._passage_span_start_predictor(passage_out).squeeze(-1)
    end_logits = self._passage_span_end_predictor(passage_out).squeeze(-1)

    start_log_probs = util.masked_log_softmax(start_logits, passage_mask)
    end_log_probs = util.masked_log_softmax(end_logits, passage_mask)

    # For decoding, masked positions are pushed to -1e7 in the raw logits.
    masked_start_logits = util.replace_masked_values(start_logits, passage_mask, -1e7)
    masked_end_logits = util.replace_masked_values(end_logits, passage_mask, -1e7)

    # Shape per original annotation: (batch_size, 2)
    best_passage_span = get_best_span(masked_start_logits, masked_end_logits)

    return start_log_probs, end_log_probs, best_passage_span
def _get_edge_probabilities(self,
                            encoded_premise,
                            mean_node_premise_attention,
                            edge_sources,
                            edge_targets,
                            edge_labels,
                            metadata) -> FloatTensor:
    """Average per-edge probability distribution over the real edges.

    Each edge is represented as ``[source node ; label embedding ; target
    node]``, scored by ``self._edge_probability`` and mean-pooled over the
    edge axis, ignoring entries where ``edge_sources == -1``.
    ``metadata`` is accepted but not used here.
    """
    # Node representations attended over the premise.
    # Shape per original annotation: batch x nodes x emb. dim
    node_repr = weighted_sum(encoded_premise, mean_node_premise_attention)

    # Shape per original annotation: batch x edges x 1
    edge_mask = (edge_sources != -1).float()

    # Padded edge endpoints / labels are replaced by index 0 before lookup.
    source_index = replace_masked_values(edge_sources.float(), edge_mask, 0)
    source_repr = self._select_embeddings_using_index(node_repr, source_index)

    target_index = replace_masked_values(edge_targets.float(), edge_mask, 0)
    target_repr = self._select_embeddings_using_index(node_repr, target_index)

    label_index = replace_masked_values(edge_labels.float(), edge_mask, 0)
    label_index = label_index.squeeze(2).long()
    # Shape per original annotation: batch x edges x edge dim
    label_repr = self._edge_embedding(label_index)

    # Shape per original annotation: batch x edges x (2 * emb dim + edge dim)
    edge_repr = torch.cat([source_repr, label_repr, target_repr], 2)
    edge_distribution = self._edge_probability(edge_repr)

    pooling_mask = edge_mask.expand_as(edge_distribution).float()
    return masked_mean(edge_distribution, 1, pooling_mask)
def forward(self, **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Score span starts and ends and decode the best span.

    The input encodings and their mask are obtained from
    ``self.get_input_and_mask(kwargs)``.

    Returns a dict with keys ``start_log_probs``, ``end_log_probs`` and
    ``best_span``.
    """
    encodings, mask = self.get_input_and_mask(kwargs)

    # Shape per original annotation: (batch_size, passage_length)
    raw_start = self._start_output_layer(encodings).squeeze(-1)
    raw_end = self._end_output_layer(encodings).squeeze(-1)

    start_log_probs = masked_log_softmax(raw_start, mask)
    end_log_probs = masked_log_softmax(raw_end, mask)

    # Decoding works on raw logits with masked positions pushed to -1e7.
    # best_span shape per original annotation: (batch_size, 2)
    best_span = get_best_span(
        replace_masked_values(raw_start, mask, -1e7),
        replace_masked_values(raw_end, mask, -1e7))

    return {
        'start_log_probs': start_log_probs,
        'end_log_probs': end_log_probs,
        'best_span': best_span,
    }
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    """Classify a sentence with self-biattention and four pooling views.

    Pipeline: embed -> dropout -> pre-encode feed-forward -> encoder ->
    self-biattention -> integrator -> max / min / mean / self-attentive
    pooling -> dropout -> output layer.

    Returns a dict with ``logits`` and, when ``label`` is given, ``loss``
    (the metrics in ``self.metrics`` are updated as a side effect).
    """
    sentence_mask = util.get_text_field_mask(sentence).float()

    embedded = self._text_field_embedder(sentence)
    pre_encoded = self._pre_encode_feedforward(self._embedding_dropout(embedded))
    encoded_tokens = self._encoder(pre_encoded, sentence_mask)

    # Biattention of the sentence with itself (both inputs are the same).
    attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
    attention_weights = util.last_dim_softmax(attention_logits, sentence_mask)
    encoded_sentence = util.weighted_sum(encoded_tokens, attention_weights)

    # Integrator input: tokens, their difference from and product with the
    # attended summary.
    integrator_input = torch.cat(
        [encoded_tokens,
         encoded_tokens - encoded_sentence,
         encoded_tokens * encoded_sentence],
        2)
    integrated = self._integrator(integrator_input, sentence_mask)

    # Simple pooling. Masked positions are replaced by -1e7 / +1e7 so they
    # cannot win the max / min.
    mask_3d = sentence_mask.unsqueeze(2)
    max_pool = torch.max(util.replace_masked_values(integrated, mask_3d, -1e7), 1)[0]
    min_pool = torch.min(util.replace_masked_values(integrated, mask_3d, +1e7), 1)[0]
    mean_pool = torch.sum(integrated, 1) / torch.sum(sentence_mask, 1, keepdim=True)

    # Self-attentive pooling: project to one logit per position, then take
    # the masked-softmax weighted sum.
    # Projection shape per original annotation: (batch_size, sequence length, 1)
    self_logits = self._self_attentive_pooling_projection(integrated).squeeze(2)
    self_weights = util.masked_softmax(self_logits, sentence_mask)
    self_attentive_pool = util.weighted_sum(integrated, self_weights)

    pooled = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
    pooled_dropped = self._integrator_dropout(pooled).squeeze(1)

    logits = self._output_layer(pooled_dropped)
    output_dict = {'logits': logits}

    if label is not None:
        gold = label.squeeze(-1)
        loss = self.loss(logits, gold)
        for metric in self.metrics.values():
            metric(logits, gold)
        output_dict["loss"] = loss

    return output_dict
def compute_location_spans(self, contextual_seq_embedding, embedded_sentence_verb_entity, mask):
    """Predict the best "after" location span for every participant.

    ``contextual_seq_embedding`` is unpacked as (batch_size, num_sentences,
    num_participants, sentence_length, encoder_dim). Start logits come from
    the embedded input concatenated with the contextual encoding; end logits
    are additionally conditioned on the start-probability-weighted summary of
    the sequence.

    Returns:
        (best_span_after, span_start_logits_after, span_end_logits_after);
        both logit tensors have masked positions set to -1e7.
    """
    (batch_size, num_sentences, num_participants,
     sentence_length, encoder_dim) = contextual_seq_embedding.shape

    # ---- span start ----
    start_input = torch.cat(
        [embedded_sentence_verb_entity, contextual_seq_embedding], dim=-1)
    # Shape per original annotation: (bs, ns, np, sl)
    start_logits = self._span_start_predictor_after(start_input).squeeze(-1)
    start_probs = util.masked_softmax(start_logits, mask)

    # Start-weighted summary, tiled back over the sentence axis.
    # Shape per original annotation: (bs, ns, np, encoder_dim)
    start_summary = util.weighted_sum(contextual_seq_embedding, start_probs)
    tiled_start = start_summary.unsqueeze(3).expand(
        batch_size, num_sentences, num_participants, sentence_length, encoder_dim)

    # ---- span end ----
    end_input = torch.cat(
        [embedded_sentence_verb_entity,
         contextual_seq_embedding,
         tiled_start,
         contextual_seq_embedding * tiled_start],
        dim=-1)
    encoded_end = self.time_distributed_encoder_span_end_after(end_input, mask)
    end_logits = self._span_end_predictor_after(encoded_end).squeeze(-1)
    # Computed as in the original; not consumed in this method.
    end_probs = util.masked_softmax(end_logits, mask)

    start_logits = util.replace_masked_values(start_logits, mask, -1e7)
    end_logits = util.replace_masked_values(end_logits, mask, -1e7)

    # Fixme: we should condition this on predicted_action so that we can output '-' when needed
    # Fixme: also add a functionality to be able to output '?': the start / end
    # probabilities computed above could be used for that
    best_span_after = self.get_best_span(start_logits, end_logits)

    return best_span_after, start_logits, end_logits
def _get_span_answer_log_prob(
        answer_as_spans: torch.LongTensor,
        span_log_probs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """
    Log marginal likelihood of a set of gold spans for a single instance.

    Each span scores ``start_log_prob + end_log_prob``; the scores of all
    non-padded spans are combined with logsumexp.

    Parameters:
    -----------
    answer_as_spans: ``torch.LongTensor``
        Shape per original docstring: (number_of_spans, 2). Gold spans,
        padded with -1.
    span_log_probs: ``Tuple[torch.Tensor, torch.Tensor]``
        Start and end log-probs, each of shape (length_of_sequence) per the
        original docstring.

    Returns:
        log marginal likelihood, computed with a leading batch axis of size 1.
    """
    start_log_probs, end_log_probs = span_log_probs

    # Add a batch axis of size 1 to every input.
    spans = answer_as_spans.unsqueeze(0)
    start_log_probs = start_log_probs.unsqueeze(0)
    end_log_probs = end_log_probs.unsqueeze(0)

    # (batch_size, number_of_ans_spans)
    starts = spans[:, :, 0]
    ends = spans[:, :, 1]

    # Padding (-1) is clamped to 0 for gather() and masked out afterwards.
    span_mask = (starts != -1).long()
    safe_starts = allenutil.replace_masked_values(starts, span_mask, 0)
    safe_ends = allenutil.replace_masked_values(ends, span_mask, 0)

    start_scores = torch.gather(start_log_probs, 1, safe_starts)
    end_scores = torch.gather(end_log_probs, 1, safe_ends)

    span_scores = allenutil.replace_masked_values(
        start_scores + end_scores, span_mask, -1e7)

    # Shape per original annotation: (batch_size,)
    return allenutil.logsumexp(span_scores)
def encode(self, sentence, mask):
    """Encode ``sentence`` with three chained RNNs and max-pool each output.

    Every RNN reads the same ``sentence``; the second and third are
    initialised with the final (hidden, cell) state of the previous one.
    Masked positions are set to -1e7 before the max over dim 1, and the
    three pooled vectors are concatenated on the last axis.
    """
    pool_mask = mask.unsqueeze(-1)

    def _max_pool(outputs):
        pooled, _ = replace_masked_values(outputs, pool_mask, -1e7).max(dim=1)
        return pooled

    # The first RNN starts from its default state.
    outputs, (hidden, cell) = self.rnn1(sentence)
    pooled_layers = [_max_pool(outputs)]

    for rnn in (self.rnn2, self.rnn3):
        outputs, (hidden, cell) = rnn(sentence, (hidden, cell))
        pooled_layers.append(_max_pool(outputs))

    return torch.cat(pooled_layers, dim=-1)
def _arithmetic_module(self, arithmetic_passage_vector, passage_out, number_indices, number_mask):
    """Score every number for every slot of every arithmetic template.

    Number vectors are built either by summarising all word pieces of a
    number (``self.number_rep`` is ``'average'`` or ``'attention'``) or by
    taking the vector at the number's first index. Optional special-number
    embeddings are prepended and ``number_mask`` is extended accordingly.

    Returns:
        (arithmetic_template_slot_log_probs, arithmetic_best_template_slots,
        number_mask)
    """
    device = number_indices.device

    if self.number_rep in ('average', 'attention'):
        # Shape per original annotation: (batch_size, # of numbers, # of pieces)
        number_indices = util.replace_masked_values(
            number_indices, number_indices != -1, 0).long()
        batch_size, num_numbers = number_indices.shape[0], number_indices.shape[1]
        seqlen = passage_out.shape[1]

        # Mark the positions of each number's pieces.
        # Shape per original annotation: (batch_size, # of numbers, seqlen)
        piece_flags = torch.ones(number_indices.shape, device=device).long()
        mask = torch.zeros((batch_size, num_numbers, seqlen), device=device).long()
        mask = mask.scatter(2, number_indices, piece_flags)
        # Clamped padding was scattered onto position 0, so clear it.
        mask[:, :, 0] = 0

        # Shape per original annotation: (batch_size, # of numbers, seqlen, bert_dim)
        tiled_passage_out = passage_out.unsqueeze(1).repeat(1, num_numbers, 1, 1)
        # Shape per original annotation: (batch_size, # of numbers, bert_dim)
        encoded_numbers = self.summary_vector(tiled_passage_out, mask, "numbers")
    else:
        number_indices = number_indices[:, :, 0].long()
        safe_indices = util.replace_masked_values(number_indices, number_mask, 0)
        gather_index = safe_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1))
        encoded_numbers = torch.gather(passage_out, 1, gather_index)

    if self.num_special_numbers > 0:
        batch = number_indices.shape[0]
        special_ids = torch.arange(self.num_special_numbers, device=number_indices.device)
        special_vectors = self.special_embedding(special_ids).expand(batch, -1, -1)
        encoded_numbers = torch.cat([special_vectors, encoded_numbers], 1)

        special_mask = torch.ones(
            (batch, self.num_special_numbers), device=number_indices.device).long()
        number_mask = torch.cat([special_mask, number_mask], -1)

    # Append the passage summary to every number vector.
    # Shape per original annotation: (batch_size, # of numbers, 2*bert_dim)
    tiled_summary = arithmetic_passage_vector.unsqueeze(1).repeat(
        1, encoded_numbers.size(1), 1)
    encoded_numbers = torch.cat([encoded_numbers, tiled_summary], -1)

    # Numbers are moved to the last axis so the softmax runs over them.
    slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1, 2)
    slot_log_probs = util.masked_log_softmax(slot_logits, number_mask)
    # Shape per original annotation: (batch_size, #templates, #slots, #numbers)
    slot_log_probs = slot_log_probs.reshape(
        number_mask.shape[0],
        self.num_arithmetic_templates,
        self.num_template_slots,
        number_mask.shape[-1])

    # Shape per original annotation: (batch_size, #templates, #slots)
    best_slots = slot_log_probs.argmax(-1)

    return slot_log_probs, best_slots, number_mask
def _base_arithmetic_module(self, passage_vector, passage_out, number_indices, number_mask):
    """Predict a sign class for every number and return the arg-max choice.

    Number vectors are built either by summarising all word pieces of a
    number (``self.number_rep`` is ``'average'`` or ``'attention'``) or by
    taking the vector at the number's first index. Optional special-number
    embeddings are prepended and ``number_mask`` is extended accordingly.

    Returns:
        (number_sign_log_probs, best_signs_for_numbers, number_mask)
    """
    device = number_indices.device

    if self.number_rep in ('average', 'attention'):
        # Shape per original annotation: (batch_size, # of numbers, # of pieces)
        number_indices = util.replace_masked_values(
            number_indices, number_indices != -1, 0).long()
        batch_size, num_numbers = number_indices.shape[0], number_indices.shape[1]
        seqlen = passage_out.shape[1]

        # Mark the positions of each number's pieces.
        # Shape per original annotation: (batch_size, # of numbers, seqlen)
        piece_flags = torch.ones(number_indices.shape, device=device).long()
        mask = torch.zeros((batch_size, num_numbers, seqlen), device=device).long()
        mask = mask.scatter(2, number_indices, piece_flags)
        # Clamped padding was scattered onto position 0, so clear it.
        mask[:, :, 0] = 0

        # Shape per original annotation: (batch_size, # of numbers, seqlen, bert_dim)
        tiled_passage_out = passage_out.unsqueeze(1).repeat(1, num_numbers, 1, 1)
        # Shape per original annotation: (batch_size, # of numbers, bert_dim)
        encoded_numbers = self.summary_vector(tiled_passage_out, mask, "numbers")
    else:
        number_indices = number_indices[:, :, 0].long()
        safe_indices = util.replace_masked_values(number_indices, number_mask, 0)
        gather_index = safe_indices.unsqueeze(-1).expand(-1, -1, passage_out.size(-1))
        encoded_numbers = torch.gather(passage_out, 1, gather_index)

    if self.num_special_numbers > 0:
        batch = number_indices.shape[0]
        special_ids = torch.arange(self.num_special_numbers, device=number_indices.device)
        special_vectors = self.special_embedding(special_ids).expand(batch, -1, -1)
        encoded_numbers = torch.cat([special_vectors, encoded_numbers], 1)

        special_mask = torch.ones(
            (batch, self.num_special_numbers), device=number_indices.device).long()
        number_mask = torch.cat([special_mask, number_mask], -1)

    # Append the passage summary to every number vector.
    # Shape per original annotation: (batch_size, # of numbers, 2*bert_dim)
    tiled_summary = passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)
    encoded_numbers = torch.cat([encoded_numbers, tiled_summary], -1)

    # Shape per original annotation: (batch_size, # of numbers in the passage, 3)
    sign_logits = self._number_sign_predictor(encoded_numbers)
    number_sign_log_probs = torch.nn.functional.log_softmax(sign_logits, -1)

    # Padding numbers get sign 0 ("not included" in the original comment).
    best_signs_for_numbers = util.replace_masked_values(
        torch.argmax(number_sign_log_probs, -1), number_mask, 0)

    return number_sign_log_probs, best_signs_for_numbers, number_mask
def _count_log_likelihood(self, answer_as_counts, count_number_log_probs):
    """Log marginal likelihood of the gold count answers.

    ``answer_as_counts`` is padded with -1. Padded entries are clamped to 0
    for ``torch.gather`` and their scores are then replaced by -1e7 so they
    vanish in the final ``util.logsumexp``.
    """
    # Shape per original annotation: (batch_size, # of count answers)
    valid = (answer_as_counts != -1).long()
    safe_counts = util.replace_masked_values(answer_as_counts, valid, 0)

    count_scores = torch.gather(count_number_log_probs, 1, safe_counts)
    count_scores = util.replace_masked_values(count_scores, valid, -1e7)

    # Shape per original annotation: (batch_size,)
    return util.logsumexp(count_scores)
def module(self, bert_out, seq_mask=None):
    """Compute per-position tag log-probabilities and logits.

    Args:
        bert_out: encoder output passed to ``self.predictor``.
        seq_mask: optional mask over sequence positions. It is ignored when
            ``self._use_crf`` is true, because the mask must be handed to the
            CRF instead of being applied here.

    Returns:
        ``(log_probs, logits)``. ``log_probs`` is always a log-softmax over
        the last axis of the predictor output. Without a CRF and with a mask,
        masked positions are set to 0.0 in ``log_probs`` and to -1e7 in
        ``logits``; otherwise both are returned unmasked.
    """
    logits = self.predictor(bert_out)

    # Always normalise over the class (last) axis. The original unmasked
    # branch omitted ``dim``, which is deprecated and lets PyTorch pick the
    # axis implicitly -- for 3-D input that is not the class axis.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # With a CRF the mask is passed to the CRF, not applied here.
    if self._use_crf or seq_mask is None:
        return log_probs, logits

    position_mask = seq_mask.unsqueeze(-1)
    log_probs = replace_masked_values(log_probs, position_mask, 0.0)
    logits = replace_masked_values(logits, position_mask, -1e7)
    return log_probs, logits
def _count_loss(self, answer_as_counts, count_number, max_prob, min_prob):
    """Huber loss on the predicted count plus a selection-sharpness penalty.

    ``answer_as_counts`` is padded with -1; padded positions are zeroed in
    both the gold and the predicted counts before ``self.huber_loss`` is
    applied. The penalty is ``(1 - (max_prob - min_prob)) * 1000``.
    """
    # Shape per original annotation: (batch_size, # of count answers)
    valid = (answer_as_counts != -1).long()

    gold = util.replace_masked_values(answer_as_counts, valid, 0)
    predicted = util.replace_masked_values(count_number, valid, 0)
    regression_loss = self.huber_loss(predicted, gold.float())

    # Shrinks as the gap between max_prob and min_prob grows.
    selection_loss = (1 - (max_prob - min_prob)) * 1000

    return regression_loss + selection_loss
def aux_window_loss(ptop_attention, passage_mask, inwindow_mask):
    """Auxiliary loss encouraging p-to-p attention to stay inside a window.

    Args:
        ptop_attention: (passage_length, passage_length)
        passage_mask: (passage_length)
        inwindow_mask: (passage_length, passage_length)

    Returns:
        inwindow_aux_loss: () -- the negated sum of per-token log in-window
        attention mass, divided by the number of unmasked in-window entries.
    """
    # Drop window entries whose row or column token is padding.
    window = inwindow_mask * passage_mask.unsqueeze(0) * passage_mask.unsqueeze(1)

    # Attention mass each token places inside its window; the token may
    # distribute that mass over the window in any way.
    # Shape: (passage_length)
    mass_in_window = (ptop_attention * window).sum(1)

    # Tokens with an empty window are excluded from the likelihood.
    # Shape: (passage_length)
    has_window = (window.sum(1) > 0).float()

    safe_mass = allenutil.replace_masked_values(
        mass_in_window, has_window, replace_with=1e-40)
    log_mass = torch.log(safe_mass + 1e-40) * has_window

    likelihood_avg = torch.sum(log_mass) / torch.sum(window)
    return -1.0 * likelihood_avg
def replace_masked_values_with_big_negative_number(x: torch.Tensor, mask: torch.Tensor):
    """
    Replace the masked values in a tensor with something really negative so
    that they won't affect a max operation. The fill value is
    ``min_value_of_dtype(x.dtype)``.
    """
    fill_value = min_value_of_dtype(x.dtype)
    return replace_masked_values(x, mask, fill_value)
def masked_mean(tensor, dim, mask):
    """
    Mean over just the non-masked portions of ``tensor`` along ``dim``.

    Args:
        tensor: values to average.
        dim: dimension to reduce.
        mask: tensor with the same number of dimensions as ``tensor``;
            entries equal to 0 are excluded. ``None`` means a plain
            ``torch.mean``.

    Returns:
        The masked mean. Slices with no unmasked entries are divided by 1
        instead of 0, so they do not produce NaN.

    Raises:
        ConfigurationError: if ``tensor`` and ``mask`` differ in rank.
    """
    if mask is None:
        return torch.mean(tensor, dim)

    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))

    # Sum of the unmasked values.
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    total_tensor = torch.sum(masked_tensor, dim)

    # Number of unmasked entries.
    count_tensor = torch.sum((mask != 0), dim)

    # Set zero counts to 1 to avoid NaNs in the division.
    zero_count_mask = (count_tensor == 0).long()
    count_plus_zeros = (count_tensor + zero_count_mask).float()

    return total_tensor / count_plus_zeros
def masked_mean(tensor, dim, mask):
    """
    Mean over just the non-masked portions of ``tensor`` along ``dim``.

    With ``mask is None`` this is a plain ``torch.mean``. Otherwise entries
    where ``mask == 0`` are excluded, and slices with no unmasked entry are
    divided by 1 instead of 0 to avoid NaN.

    =====================================================================
    From Decomposable Graph Entailment Model code replicated from SciTail repo
    https://github.com/allenai/scitail
    =====================================================================
    """
    if mask is None:
        return torch.mean(tensor, dim)

    tensor_rank, mask_rank = tensor.dim(), mask.dim()
    if tensor_rank != mask_rank:
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor_rank, mask_rank))

    # Numerator: sum of the unmasked values.
    kept_sum = torch.sum(replace_masked_values(tensor, mask, 0.0), dim)

    # Denominator: number of unmasked entries, with 0 bumped up to 1.
    kept_count = torch.sum((mask != 0), dim)
    safe_count = (kept_count + (kept_count == 0).long()).float()

    return kept_sum / safe_count
def gold_log_marginal_likelihood(self,
                                 gold_answer_representations: Dict[str, torch.LongTensor],
                                 log_probs: torch.LongTensor,
                                 question_and_passage_mask: torch.LongTensor,
                                 passage_mask: torch.LongTensor,
                                 first_wordpiece_mask: torch.LongTensor,
                                 is_bio_mask: torch.LongTensor,
                                 **kwargs: Any):
    """Log likelihood of the gold BIO sequences under ``log_probs``.

    ``self._training_style`` selects the estimator: ``'soft_em'`` uses
    ``self._marginal_likelihood``, ``'hard_em'`` uses
    ``self._get_most_likely_likelihood``. Entries where ``is_bio_mask`` is 0
    are replaced by -1e7.

    Raises:
        Exception: if ``self._training_style`` is neither of the two values.
    """
    mask = self._get_mask(question_and_passage_mask, passage_mask, first_wordpiece_mask)
    gold_bio_seqs = self._get_gold_answer(gold_answer_representations, log_probs, mask)

    training_style = self._training_style
    if training_style not in ('soft_em', 'hard_em'):
        raise Exception("Illegal training_style")

    if training_style == 'soft_em':
        likelihood = self._marginal_likelihood(gold_bio_seqs, log_probs)
    else:
        likelihood = self._get_most_likely_likelihood(gold_bio_seqs, log_probs)

    # For questions without spans, the log likelihood is set to a very small
    # negative value.
    return replace_masked_values(likelihood, is_bio_mask, -1e7)
def _get_most_likely_likelihood(self, bio_seqs: torch.LongTensor, log_probs: torch.LongTensor):
    """Log-likelihood of the most likely gold BIO sequence per instance.

    Shapes, as annotated in the original code:
        bio_seqs: (batch_size, # of correct sequences, seq_length)
        log_probs: (batch_size, seq_length, 3)

    A sequence whose tags sum to 0 is treated as padding and scored -1e7.
    """
    num_sequences = bio_seqs.size()[1]

    # A sequence that is 0 everywhere is padding.
    # Shape: (batch_size, # of correct sequences)
    sequence_mask = (bio_seqs.sum(-1) > 0).long()

    # One copy of log_probs per gold sequence.
    # Shape: (batch_size, # of correct sequences, seq_length, 3)
    tiled_log_probs = log_probs.unsqueeze(1).expand(-1, num_sequences, -1, -1)

    # Log-prob of the gold tag at every position.
    # Shape: (batch_size, # of correct sequences, seq_length)
    tag_scores = torch.gather(
        tiled_log_probs, dim=-1, index=bio_seqs.unsqueeze(-1)).squeeze(-1)

    # Sequence score = sum over positions.
    # Shape: (batch_size, # of correct sequences)
    sequence_scores = tag_scores.sum(dim=-1)
    sequence_scores = replace_masked_values(sequence_scores, sequence_mask, -1e7)

    best_index = sequence_scores.argmax(dim=-1)
    return sequence_scores.gather(dim=1, index=best_index.unsqueeze(-1)).squeeze(dim=-1)
def forward(self, embedded_input, input_mask, other_input=None, other_mask=None):
    """
    Pool a sequence of vectors into one vector per instance.

    Parameters
    ----------
    embedded_input :
        Assumed to be (batch_size, num_words, embedding_dim).
    input_mask :
        (batch_size, num_words); non-zero marks real tokens.
    other_input, other_mask :
        Accepted for interface compatibility; not used here.

    Returns
    -------
    The enabled poolings (max first, then average) concatenated along dim 1,
    passed through ``self._projection_feedforward`` when one is configured.
    """
    if self._hidden_feedforward is not None:
        embedded_input = self._hidden_feedforward(embedded_input)

    to_cat = []
    if self._max_pool:
        # Push masked positions far down so they never win the max.
        input_max, _ = replace_masked_values(
            embedded_input, input_mask.unsqueeze(-1), -1e7).max(dim=1)
        to_cat.append(input_max)
    if self._avg_pool:
        float_mask = input_mask.float()
        # A fully masked row has a zero token count; without the clamp the
        # division below is 0 / 0 = NaN.  Rows with at least one token are
        # unaffected because their count is far above the clamp value.
        token_counts = torch.sum(float_mask, 1, keepdim=True).clamp(min=1e-13)
        input_avg = torch.sum(embedded_input * float_mask.unsqueeze(-1), dim=1) / token_counts
        to_cat.append(input_avg)

    output = torch.cat(to_cat, dim=1)
    if self._projection_feedforward is not None:
        output = self._projection_feedforward(output)
    return output
def take_step(
        self,
        last_predictions: torch.Tensor,
        state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Produce the class log-probabilities for the current decoding step with
    illegal tag transitions pushed down to ``-1e7``.

    Parameters
    ----------
    last_predictions : ``torch.Tensor``
        The class chosen at the previous step for each row; compared against
        1 and 2 below (presumably the "B" and "I" tags - TODO confirm).
    state : ``Dict[str, torch.Tensor]``
        Reads ``'log_probs'``, ``'step_num'`` and ``'wordpiece_mask'``.
        ``'step_num'`` is replaced by an incremented copy.

    Returns
    -------
    The masked log-probabilities for this step and the updated ``state``.
    """
    # get the relevant scores for the time step
    # (only the first entry of 'step_num' is used as the time index)
    class_log_probabilities = state['log_probs'][:, state['step_num'][0], :]
    # 1 where 'wordpiece_mask' is 0 at this step
    is_wordpiece = (
        1 - state['wordpiece_mask'][:, state['step_num'][0]]).byte()

    # mask illegal BIO transitions
    # Start from "first three classes allowed, last two classes disallowed".
    # NOTE(review): this concatenation only matches the width of
    # class_log_probabilities when there are exactly 5 classes - confirm.
    transitions_mask = torch.cat(
        (torch.ones_like(class_log_probabilities[:, :3]),
         torch.zeros_like(class_log_probabilities[:, -2:])),
        dim=-1).byte()
    # Class 2 is only allowed right after a 1 or a 2.
    transitions_mask[:, 2] &= ((last_predictions == 1) | (last_predictions == 2))
    # Classes 1 and 2 are disallowed when the first three log-probs are all exactly 0.0
    # (presumably how padded / special positions are encoded upstream - TODO confirm).
    transitions_mask[:, 1:3] &= ((class_log_probabilities[:, :3] == 0.0).sum(-1) != 3).unsqueeze(-1).repeat(
        1, 2)

    # assuming the wordpiece mask doesn't intersect with the other masks (pad, cls/sep)
    # Re-enable class 2 on wordpiece positions that follow a 1 or a 2.
    transitions_mask[:, 2] |= is_wordpiece & ((last_predictions == 1) | (last_predictions == 2))

    class_log_probabilities = replace_masked_values(
        class_log_probabilities, transitions_mask, -1e7)
    # Increment on a clone so the tensor previously stored in the state is not modified in place.
    state['step_num'] = state['step_num'].clone() + 1
    return class_log_probabilities, state
def posAttnConv(self, sentence, other_sen, interaction, sentence_mask, other_sen_mask, matrix_mask):
    """
    @brief      Position-aware attentive convolution of one sentence against another

    @param      self            The object
    @param      sentence        The embedded sentence (n x s x d)
    @param      other_sen       The other sentence (n x s' x d)
    @param      interaction     The interaction matrix (n x s x s')
    @param      sentence_mask   The mask of the sentence (n x s)
    @param      other_sen_mask  The mask of the other sentence (n x s')
    @param      matrix_mask     The mask of the interaction matrix (n x s x s')

    @return     The output of the position-aware encoder applied to the
                concatenation of the attended sentence, the sentence itself
                and the embedding of each word's best-matching position
    """
    # attention of every word over the other sentence (n x s x s')
    attention = last_dim_softmax(interaction, other_sen_mask)
    # attended representation of the other sentence (n x s x d)
    attended = weighted_sum(other_sen, attention)

    # position of the best-matched word, ignoring masked cells (n x s)
    best_match = replace_masked_values(interaction, matrix_mask, -1e7).max(dim=-1)[1]
    # embedding of that position (n x s x dm)
    position_features = self._pos_embedder(best_match)

    # (n x s x (2d + dm))
    combined = torch.cat((attended, sentence, position_features), dim=2)
    return self._pos_attn_encoder(combined, sentence_mask)
def _get_combined_likelihood(self, answer_as_list_of_bios, log_probs):
    """
    Marginal log-likelihood of the gold tag sequences: the log-sum-exp of the
    individual sequence log-likelihoods.

    answer_as_list_of_bios - Shape: (batch_size, # of correct sequences, seq_length)
    log_probs              - Shape: (batch_size, seq_length, 3)
    Result                 - Shape: (batch_size,)
    """
    gold_count = answer_as_list_of_bios.size()[1]

    # One copy of the per-token log probs for each gold sequence.
    # Shape: (batch_size, # of correct sequences, seq_length, 3)
    tiled_log_probs = log_probs.unsqueeze(1).expand(-1, gold_count, -1, -1)

    # Log-likelihood of the gold tag at each position.
    # Shape: (batch_size, # of correct sequences, seq_length)
    token_scores = torch.gather(
        tiled_log_probs, dim=-1, index=answer_as_list_of_bios.unsqueeze(-1)).squeeze(-1)

    # A sequence whose tag ids sum to zero is treated as padding.
    # Shape: (batch_size, # of correct sequences)
    not_padding = (answer_as_list_of_bios.sum(-1) > 0).long()

    # Per-sequence log-likelihood, with padding sequences set to -1e7 so they
    # contribute nothing to the log-sum-exp.
    # Shape: (batch_size, # of correct sequences)
    sequence_scores = replace_masked_values(token_scores.sum(dim=-1), not_padding, -1e7)

    # Marginalize over the correct answers.
    return logsumexp(sequence_scores, dim=-1)
def forward(self,
            token_representations: torch.LongTensor,
            passage_summary_vector: torch.LongTensor,
            number_indices: torch.LongTensor,
            **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Predict a sign class for every number.

    The representation of each number (gathered from ``token_representations``
    at the first entry of its ``number_indices`` row), optionally preceded by
    the learned special-number embeddings, is concatenated with the passage
    summary vector and fed to ``self._output_layer``.

    Returns
    -------
    A dictionary with ``'log_probs'``, ``'logits'`` and
    ``'best_signs_for_numbers'`` (the argmax class, forced to 0 for padding numbers).
    """
    hidden_size = token_representations.size(-1)

    # Padding numbers are redirected to index 0 so that the gather is valid.
    passage_number_mask = self._get_mask(number_indices, with_special_numbers=False)
    safe_indices = replace_masked_values(
        number_indices[:, :, 0].long(), passage_number_mask, 0)
    gather_index = safe_indices.unsqueeze(-1).expand(-1, -1, hidden_size)
    encoded_numbers = torch.gather(token_representations, 1, gather_index)

    if self._num_special_numbers > 0:
        special_ids = torch.arange(self._num_special_numbers, device=number_indices.device)
        special_numbers = self._special_embeddings(special_ids)
        special_numbers = special_numbers.expand(number_indices.shape[0], -1, -1)
        # Special numbers come first along the number dimension.
        encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)

    # Shape: (batch_size, # of numbers, 2*bert_dim)
    tiled_summary = passage_summary_vector.unsqueeze(1).repeat(
        1, encoded_numbers.size(1), 1)
    encoded_numbers = torch.cat([encoded_numbers, tiled_summary], -1)

    # Shape: (batch_size, # of numbers in the passage, 3)
    logits = self._output_layer(encoded_numbers)
    log_probs = torch.nn.functional.log_softmax(logits, -1)

    # This mask also covers the special numbers.
    full_number_mask = self._get_mask(number_indices, with_special_numbers=True)

    # Shape: (batch_size, # of numbers in passage).
    # For padding numbers, the best sign is masked as 0 (not included).
    best_signs_for_numbers = replace_masked_values(
        torch.argmax(log_probs, -1), full_number_mask, 0)

    return {
        'log_probs': log_probs,
        'logits': logits,
        'best_signs_for_numbers': best_signs_for_numbers
    }
def forward(self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            logical_forms: Dict[str, torch.LongTensor],
            utterance_string: List[str],
            logical_form_strings: List[List[str]]) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Pick, for each utterance, the candidate logical form whose translated
    bag-of-embeddings is most cosine-similar to the utterance's bag-of-embeddings.

    Parameters
    ----------
    utterance : ``Dict[str, torch.LongTensor]``
        Input to ``self.utterance_embedder``.
    logical_forms : ``Dict[str, torch.LongTensor]``
        Candidate logical forms, embedded with one wrapping dimension.
    utterance_string : ``List[str]``
        Passed through to the output unchanged.
    logical_form_strings : ``List[List[str]]``
        String form of every candidate; indexed with the selected candidate.

    Returns
    -------
    A dictionary with ``"loss"`` (the sum over the batch of one minus the best
    similarity), ``"most_similar"`` (the selected logical form string for each
    instance) and ``"utterance"``.
    """
    # (batch_size, num_utterance_tokens, utterance_embedding_dim)
    embedded_utterance = self.utterance_embedder(utterance)
    # (batch_size, num_logical_forms, num_lf_tokens, lf_embedding_dim)
    embedded_logical_forms = self.logical_form_embedder(logical_forms, num_wrapping_dims=1)

    # (batch_size, num_logical_forms, num_lf_tokens)
    token_mask = util.get_text_field_mask(logical_forms, num_wrapping_dims=1)
    # A logical form is real when it has at least one unmasked token.
    # (batch_size, num_logical_forms)
    form_mask = token_mask.sum(dim=-1).clamp(max=1)

    # Everything is summed in the end, so the sums are taken up front.
    # (batch_size, utterance_embedding_dim)
    utterance_vector = embedded_utterance.sum(dim=1)
    # (batch_size, num_logical_forms, lf_embedding_dim)
    form_vectors = embedded_logical_forms.sum(dim=2)

    # (batch_size, num_logical_forms, utterance_embedding_dim)
    translated_forms = self.translation_layer(form_vectors)

    # (batch_size, num_logical_forms)
    similarities = torch.nn.functional.cosine_similarity(
        translated_forms, utterance_vector.unsqueeze(1), dim=2)
    # Padded logical forms must never win the max.
    similarities = util.replace_masked_values(similarities, form_mask, -1e7)

    max_similarity, most_similar = similarities.max(dim=-1)
    loss = (1 - max_similarity).sum()

    most_similar_strings = [
        candidates[best_index]
        for best_index, candidates in zip(most_similar.tolist(), logical_form_strings)
    ]
    return {
        "loss": loss,
        "most_similar": most_similar_strings,
        "utterance": utterance_string
    }
def forward(self,  # pylint: disable=arguments-differ
            embeddings: torch.FloatTensor,
            mask: torch.LongTensor,
            num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                             torch.LongTensor, torch.FloatTensor]:
    """
    Keep the ``num_items_to_keep`` highest-scoring items according to the scorer.

    The kept items are returned in the order in which they appeared in
    ``embeddings`` (not in score order), so downstream components can rely on
    the original ordering (e.g., to know which spans are valid antecedents in
    a coreference resolution model).

    Parameters
    ----------
    embeddings : ``torch.FloatTensor``, required.
        Shape (batch_size, num_items, embedding_size); one embedding per item
        in the list to prune.
    mask : ``torch.LongTensor``, required.
        Shape (batch_size, num_items); marks the unpadded elements of ``embeddings``.
    num_items_to_keep : ``int``, required.
        How many items survive pruning.

    Returns
    -------
    top_embeddings : ``torch.FloatTensor``
        Representations of the kept items.
        Shape (batch_size, num_items_to_keep, embedding_size).
    top_mask : ``torch.LongTensor``
        Mask for ``top_embeddings``. Shape (batch_size, num_items_to_keep).
    top_indices : ``torch.IntTensor``
        Positions of the kept items in the original ``embeddings`` tensor,
        useful for keeping pointers to the original items (for instance when
        several distinct scorers score the same items).
        Shape (batch_size, num_items_to_keep).
    top_item_scores : ``torch.FloatTensor``
        Scores of the kept items. Shape (batch_size, num_items_to_keep, 1).

    Raises
    ------
    ValueError
        If the scorer output is not of shape (batch_size, num_items, 1).
    """
    item_mask = mask.unsqueeze(-1)
    total_items = embeddings.size(1)

    # Shape: (batch_size, num_items, 1)
    scores = self._scorer(embeddings)
    if scores.size(-1) != 1 or scores.dim() != 3:
        raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape"
                         f"(batch_size, num_items, 1), but found shape {scores.size()}")

    # Masked items get a very negative score so they are never selected.  These are
    # typically logits, so -1e20 should be plenty negative.
    scores = util.replace_masked_values(scores, item_mask, -1e20)

    # Shape: (batch_size, num_items_to_keep, 1)
    top_indices = scores.topk(num_items_to_keep, 1)[1]
    # Re-sort the selection by position so that the kept items appear in the
    # same order as in ``embeddings``, then drop the trailing singleton dimension.
    # Shape: (batch_size, num_items_to_keep)
    top_indices = torch.sort(top_indices, 1)[0].squeeze(-1)

    # torch.index_select only takes 1D indices, so the per-batch indices are
    # shifted and flattened.
    # Shape: (batch_size * num_items_to_keep)
    flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, total_items)

    # Shape: (batch_size, num_items_to_keep, embedding_size)
    top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
    # Shape: (batch_size, num_items_to_keep, 1) before the squeeze in the return
    top_mask = util.batched_index_select(item_mask, top_indices, flat_top_indices)
    # Shape: (batch_size, num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

    return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
matrix_attention = LegacyMatrixAttention(similarity_function) passage_question_similarity = matrix_attention(encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) print ("passage question similarity: ", passage_question_similarity.shape) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. masked_similarity = util.replace_masked_values(passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([encoded_passage, passage_question_vectors,
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Encode the question and passage, attend between them, and predict the start
    and end of the answer span inside the passage.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance in the batch.  When present, this method reads the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage`` and ``token_offsets``
        from every dictionary, and ``answer_texts`` when it is there (used to update the
        SQuAD metrics).

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        The masked softmax of the passage-question similarity over question tokens.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position (masked positions are set to ``-1e7``).
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive; masked positions are set to ``-1e7``).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.  Only present when ``span_start`` is given.
    best_span_str : List[str]
        If metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    question_tokens, passage_tokens : List
        Copied from ``metadata`` when it is provided.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    # The encoders only see the masks when ``self._mask_lstms`` is set.
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # The same phrase layer encodes both the question and the passage.
    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    # NOTE(review): the max already removes the question dimension, so the extra squeeze(-1)
    # would drop the passage dimension when passage_length == 1 - confirm this is intended.
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # The end predictor is conditioned on the expected start representation.
    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Padding positions get very negative logits so they cannot be picked as a span boundary.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            }

    # Compute the loss for training.
    if span_start is not None:
        # The loss is the sum of the start and end negative log-likelihoods;
        # the accuracy metrics are updated as a side effect.
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Each offsets entry is indexed [0] for the start and [1] for the end of a token
            # in the passage string (presumably character offsets - TODO confirm).
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
    """Masked rows are overwritten with the given finite value; other rows are kept."""
    rows = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    tensor = torch.FloatTensor([rows])
    # The third row is masked out; the mask is broadcast over the last dimension.
    row_mask = torch.FloatTensor([[1, 1, 0]]).unsqueeze(-1)

    result = util.replace_masked_values(tensor, row_mask, 2)

    expected = [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]
    assert_almost_equal(result.data.numpy(), expected)
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Answer every question of every dialog in the batch: predict the answer span in the
    passage together with a yes/no marker and a follow-up marker.

    NOTE(review): ``followup_list`` and ``metadata`` default to ``None`` but are used
    unconditionally below (``torch.ge(followup_list, 0)`` and ``metadata[i]``), so in
    practice both must be supplied.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.  The ``token_characters`` entry is read to obtain
        ``(batch_size, max_qa_count, max_q_len, _)``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    p1_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 0.
        This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
        Most passage token will have assigned 'O', except the passage tokens belongs to the previous answer
        in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
        For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
    p2_answer_marker :  ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 1.
        It is similar to p1_answer_marker, but marking previous previous answer in passage.
    p3_answer_marker :  ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 2.
        It is similar to p1_answer_marker, but marking previous previous previous answer in passage.
    yesno_list :  ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three way classification (the yes/no/not a yes no question).
    followup_list :  ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three way classification (followup / maybe followup / don't followup).
        Entries ``>= 0`` also define which question slots are real (``qa_mask``).
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per dialog in the batch.  This method reads the keys
        ``original_passage``, ``token_offsets``, ``instance_id`` and ``answer_texts_list``
        from every dictionary.

    Returns
    -------
    An output dictionary consisting of the followings.
    Each of the followings is a nested list because first iterates over dialog, then questions in dialog.

    qid : List[List[str]]
        A list of list, consisting of question ids.
    followup : List[List[int]]
        A list of list, consisting of continuation marker prediction index.
        (y :yes, m: maybe follow up, n: don't follow up)
    yesno : List[List[int]]
        A list of list, consisting of affirmation marker prediction index.
        (y :yes, x: not a yes/no question, n: no)
    best_span_str : List[List[str]]
        The string from the original passage that the model thinks is the best answer
        to the question.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.  Only present when ``span_start`` is given.
    """
    batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
    # Dialogs and their questions are flattened into one "total_qa_count" batch dimension.
    total_qa_count = batch_size * max_qa_count
    # Question slots with a negative follow-up label are treated as padding.
    qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
    embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
    embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                  self._text_field_embedder.get_output_dim())
    embedded_question = self._variational_dropout(embedded_question)
    embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
    passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
    question_mask = question_mask.reshape(total_qa_count, max_q_len)
    passage_mask = util.get_text_field_mask(passage).float()

    # The passage mask is repeated once per question of the dialog.
    repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
    repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

    if self._num_context_answers > 0:
        # Encode question turn number inside the dialog into question embedding.
        question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
        question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
        question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
        question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
        question_num_marker_emb = self._question_num_marker(question_num_ind)
        embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

        # Encode the previous answers in passage embedding.
        repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
            view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
        # batch_size * max_qa_count, passage_length, word_embed_dim
        p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
        p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
        repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
        if self._num_context_answers > 1:
            p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
            p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
            if self._num_context_answers > 2:
                p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                      dim=-1)

        repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                repeated_passage_mask))
    else:
        # Without answer markers the passage is identical for every question, so it is
        # encoded once and then repeated.
        encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                 passage_length,
                                                                 self._encoding_dim)

    encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)

    # NOTE(review): the max already removes the question dimension, so the extra squeeze(-1)
    # would drop the passage dimension when passage_length == 1 - confirm this is intended.
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
    # Shape: (batch_size * max_qa_count, encoding_dim)
    question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                passage_length,
                                                                                self._encoding_dim)

    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([repeated_encoded_passage,
                                      passage_question_vectors,
                                      repeated_encoded_passage * passage_question_vectors,
                                      repeated_encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

    residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                      repeated_passage_mask))
    self_attention_matrix = self._self_attention(residual_layer, residual_layer)

    # Pairwise mask: a (row, column) cell is kept only when both passage positions are unmasked.
    mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
           * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
    # Zero the diagonal so that a token does not attend to itself.
    self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
    self_mask = self_mask.reshape(1, passage_length, passage_length)
    mask = mask * (1 - self_mask)

    self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

    # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
    self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
    self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                     residual_layer * self_attention_vecs],
                                    dim=-1)
    residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

    # Residual connection around the self-attention block.
    final_merged_passage = final_merged_passage + residual_layer
    # batch_size * maxqa_pair_len * max_passage_len * 200
    final_merged_passage = self._variational_dropout(final_merged_passage)
    start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
    span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

    # The end representation also feeds the yes/no and follow-up predictors.
    end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                     repeated_passage_mask)
    span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

    span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
    span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

    span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
    # batch_size * maxqa_len_pair, max_document_len
    span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

    best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                   span_yesno_logits, span_followup_logits,
                                                   self._max_span_length)

    output_dict: Dict[str, Any] = {}

    # Compute the loss.
    if span_start is not None:
        # Targets equal to -1 are ignored by the loss (ignore_index=-1).
        loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                        ignore_index=-1)
        self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
        loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                 repeated_passage_mask), span_end.view(-1), ignore_index=-1)
        self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
        self._span_accuracy(best_span[:, 0:2],
                            torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                            mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
        # add a select for the right span to compute loss
        # The yes/no and follow-up logits are indexed in their flattened form with three
        # consecutive values per passage position; the three flat indices of the gold end
        # position are collected here.  max(..., 0) keeps the index non-negative
        # (presumably for padded questions whose span_end is -1 - TODO confirm).
        gold_span_end_loc = []
        span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
        for i in range(0, total_qa_count):
            gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
            gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
            gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
        gold_span_end_loc = span_start.new(gold_span_end_loc)

        # Same flat indices, but for the predicted end position.
        pred_span_end_loc = []
        for i in range(0, total_qa_count):
            pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
            pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
            pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
        predicted_end = span_start.new(pred_span_end_loc)

        # The loss uses the logits at the gold end position ...
        _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
        loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

        # ... while the accuracy metrics use the logits at the predicted end position.
        _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
        self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
        output_dict["loss"] = loss

    # Compute F1 and preparing the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict['followup'] = []
    output_dict['yesno'] = []
    best_span_cpu = best_span.detach().cpu().numpy()
    for i in range(batch_size):
        passage_str = metadata[i]['original_passage']
        offsets = metadata[i]['token_offsets']
        f1_score = 0.0
        per_dialog_best_span_list = []
        per_dialog_yesno_list = []
        per_dialog_followup_list = []
        per_dialog_query_id_list = []
        for per_dialog_query_index, (iid, answer_texts) in enumerate(
                zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
            # Each best_span row holds (start, end, yes/no prediction, follow-up prediction).
            predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]

            yesno_pred = predicted_span[2]
            followup_pred = predicted_span[3]
            per_dialog_yesno_list.append(yesno_pred)
            per_dialog_followup_list.append(followup_pred)
            per_dialog_query_id_list.append(iid)

            best_span_string = passage_str[start_offset:end_offset]
            per_dialog_best_span_list.append(best_span_string)
            if answer_texts:
                if len(answer_texts) > 1:
                    t_f1 = []
                    # Compute F1 over N-1 human references and averages the scores.
                    for answer_index in range(len(answer_texts)):
                        idxes = list(range(len(answer_texts)))
                        idxes.pop(answer_index)
                        refs = [answer_texts[z] for z in idxes]
                        t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                             best_span_string,
                                                                             refs))
                    f1_score = 1.0 * sum(t_f1) / len(t_f1)
                else:
                    f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                        best_span_string,
                                                                        answer_texts)
            # NOTE(review): f1_score is not reset between questions of a dialog, so a question
            # without answer_texts records the previous question's score - confirm.
            self._official_f1(100 * f1_score)
        output_dict['qid'].append(per_dialog_query_id_list)
        output_dict['best_span_str'].append(per_dialog_best_span_list)
        output_dict['yesno'].append(per_dialog_yesno_list)
        output_dict['followup'].append(per_dialog_followup_list)
    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify a batch of texts with biattention, an integrator and pooling.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``. An optional ``"elmo"`` entry is
        routed to the ELMo module instead of the text field embedder. The
        dictionary itself is never modified by this method.
    label : torch.LongTensor, optional (default = None)
        A variable representing the label for each instance in the batch.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        The raw output of the output layer.
    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_classes)`` representing a
        distribution over the label classes for each instance.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised. Only present when ``label`` is given.

    Raises
    ------
    ConfigurationError
        If the model has an ELMo module but ``tokens`` carries no ``"elmo"`` entry.
    """
    text_mask = util.get_text_field_mask(tokens).float()

    # The text field embedder must not see the "elmo" entry. Hand it a filtered
    # shallow copy instead of popping from (and later restoring) the caller's
    # dict: tests and subsequent training epochs rely on ``tokens`` not being
    # modified, and a copy keeps that true even if the embedder raises.
    elmo_tokens = tokens.get("elmo")
    embedder_inputs = {key: value for key, value in tokens.items() if key != "elmo"}
    embedded_text = self._text_field_embedder(embedder_inputs)

    # Create ELMo embeddings if applicable
    if self._elmo:
        if elmo_tokens is None:
            raise ConfigurationError(
                    "Model was built to use Elmo, but input text is not tokenized for Elmo.")
        elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
        # Pop from the end is more performant with list
        if self._use_integrator_output_elmo:
            integrator_output_elmo = elmo_representations.pop()
        if self._use_input_elmo:
            input_elmo = elmo_representations.pop()
        # Every representation ELMo produced must have been consumed above.
        assert not elmo_representations

    if self._use_input_elmo:
        embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

    dropped_embedded_text = self._embedding_dropout(embedded_text)
    pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
    encoded_tokens = self._encoder(pre_encoded_text, text_mask)

    # Compute biattention. This is a special case since the inputs are the same.
    attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
    attention_weights = util.last_dim_softmax(attention_logits, text_mask)
    encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

    # Build the input to the integrator
    integrator_input = torch.cat([encoded_tokens,
                                  encoded_tokens - encoded_text,
                                  encoded_tokens * encoded_text], 2)
    integrated_encodings = self._integrator(integrator_input, text_mask)

    # Concatenate ELMo representations to integrated_encodings if specified
    if self._use_integrator_output_elmo:
        integrated_encodings = torch.cat([integrated_encodings,
                                          integrator_output_elmo], dim=-1)

    # Simple Pooling layers
    broadcast_mask = text_mask.unsqueeze(2)
    max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, broadcast_mask, -1e7)
    max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
    min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, broadcast_mask, +1e7)
    min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
    # Zero out padded positions before summing: the denominator only counts
    # unmasked tokens, so any non-zero value left at a padded timestep (by the
    # integrator or the concatenated ELMo output) would otherwise skew the mean
    # and make it depend on how much padding the batch happens to contain.
    mean_pool = (torch.sum(integrated_encodings * broadcast_mask, 1)
                 / torch.sum(text_mask, 1, keepdim=True))

    # Self-attentive pooling layer
    # Run through linear projection. Shape: (batch_size, sequence length, 1)
    # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
    self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
    self_weights = util.masked_softmax(self_attentive_logits, text_mask)
    self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

    pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
    pooled_representations_dropped = self._integrator_dropout(pooled_representations)

    logits = self._output_layer(pooled_representations_dropped)
    class_probabilities = F.softmax(logits, dim=-1)

    output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
    if label is not None:
        loss = self.loss(logits, label)
        # Update every registered metric with this batch's predictions.
        for metric in self.metrics.values():
            metric(logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score a batch of premise/hypothesis pairs.

    The two sentences are embedded and encoded, softly aligned against each
    other, "enhanced" with the aligned representations, re-encoded by the
    inference encoder, pooled (masked average and masked max) and finally
    classified.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original tokenization of the premise and
        hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Not used by this method.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
        entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised. Only present when ``label`` is given.
    """
    def enhance(own, aligned):
        # The "enhancement" layer: the encoding, what it aligned to, and their
        # difference and element-wise product, concatenated on the feature axis.
        return torch.cat([own, aligned, own - aligned, own * aligned], dim=-1)

    def pool(sequence, mask):
        # Masked average and masked max over the sequence dimension.
        # Each result: (batch_size, model_dim)
        column_mask = mask.unsqueeze(-1)
        pooled_max = replace_masked_values(sequence, column_mask, -1e7).max(dim=1)[0]
        pooled_avg = torch.sum(sequence * column_mask, dim=1) / torch.sum(mask, 1, keepdim=True)
        return pooled_avg, pooled_max

    premise_embeddings = self._text_field_embedder(premise)
    hypothesis_embeddings = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # Dropout on the encoder (LSTM) input, when configured.
    if self.rnn_input_dropout:
        premise_embeddings = self.rnn_input_dropout(premise_embeddings)
        hypothesis_embeddings = self.rnn_input_dropout(hypothesis_embeddings)

    # Contextual encoding of both sentences.
    premise_encoding = self._encoder(premise_embeddings, premise_mask)
    hypothesis_encoding = self._encoder(hypothesis_embeddings, hypothesis_mask)

    # Shape: (batch_size, premise_length, hypothesis_length)
    similarity = self._matrix_attention(premise_encoding, hypothesis_encoding)

    # Premise -> hypothesis alignment.
    # Attention shape: (batch_size, premise_length, hypothesis_length)
    # Aligned shape: (batch_size, premise_length, embedding_dim)
    premise_to_hypothesis = last_dim_softmax(similarity, hypothesis_mask)
    hypothesis_for_premise = weighted_sum(hypothesis_encoding, premise_to_hypothesis)

    # Hypothesis -> premise alignment.
    # Attention shape: (batch_size, hypothesis_length, premise_length)
    # Aligned shape: (batch_size, hypothesis_length, embedding_dim)
    hypothesis_to_premise = last_dim_softmax(similarity.transpose(1, 2).contiguous(), premise_mask)
    premise_for_hypothesis = weighted_sum(premise_encoding, hypothesis_to_premise)

    enhanced_premise = enhance(premise_encoding, hypothesis_for_premise)
    enhanced_hypothesis = enhance(hypothesis_encoding, premise_for_hypothesis)

    # Project back down to the model dimension. No dropout before the projection.
    premise_projection = self._projection_feedforward(enhanced_premise)
    hypothesis_projection = self._projection_feedforward(enhanced_hypothesis)

    # Inference layer, again with optional input dropout.
    if self.rnn_input_dropout:
        premise_projection = self.rnn_input_dropout(premise_projection)
        hypothesis_projection = self.rnn_input_dropout(hypothesis_projection)
    premise_inference = self._inference_encoder(premise_projection, premise_mask)
    hypothesis_inference = self._inference_encoder(hypothesis_projection, hypothesis_mask)

    # Pool each sentence, then concatenate as [avg, max] per sentence.
    # (batch_size, model_dim * 2 * 4)
    premise_avg, premise_max = pool(premise_inference, premise_mask)
    hypothesis_avg, hypothesis_max = pool(hypothesis_inference, hypothesis_mask)
    pooled = torch.cat([premise_avg, premise_max, hypothesis_avg, hypothesis_max], dim=1)

    # Final MLP: dropout on its input; the MLP handles its own hidden/output dropout.
    if self.dropout:
        pooled = self.dropout(pooled)
    hidden = self._output_feedforward(pooled)
    label_logits = self._output_logit(hidden)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}
    if label is None:
        return output_dict

    output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
    self._accuracy(label_logits, label)
    return output_dict