def test_masked_mean(self):
    # Testing the general masked 1D case.
    vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
    mask_1d = torch.FloatTensor([1.0, 0.0, 1.0])
    vector_1d_mean = util.masked_mean(vector_1d, mask_1d, dim=0).data.numpy()
    assert_array_almost_equal(vector_1d_mean, 3.0)

    # Testing that if the mask is all zeros, the output is arbitrary but not NaN.
    vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
    mask_1d = torch.FloatTensor([0.0, 0.0, 0.0])
    vector_1d_mean = util.masked_mean(vector_1d, mask_1d, dim=0).data.numpy()
    assert not numpy.isnan(vector_1d_mean).any()

    # Testing batched values and batched masks.
    matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
    mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
    matrix_mean = util.masked_mean(matrix, mask, dim=-1).data.numpy()
    assert_array_almost_equal(matrix_mean, numpy.array([3.0, -1.5]))

    # Testing keepdim for batched values and batched masks.
    matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
    mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
    matrix_mean = util.masked_mean(matrix, mask, dim=-1, keepdim=True).data.numpy()
    assert_array_almost_equal(matrix_mean, numpy.array([[3.0], [-1.5]]))

    # Testing broadcast.
    matrix = torch.FloatTensor([[[1.0, 2.0], [12.0, 3.0], [5.0, -1.0]],
                                [[-1.0, -3.0], [-2.0, -0.5], [3.0, 8.0]]])
    mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]).unsqueeze(-1)
    matrix_mean = util.masked_mean(matrix, mask, dim=1).data.numpy()
    assert_array_almost_equal(matrix_mean, numpy.array([[3.0, 0.5], [-1.5, -1.75]]))
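# A minimal reference sketch of the semantics the test above relies on. This is an
# assumption about the behavior, not AllenNLP's actual implementation of
# `util.masked_mean`: sum the kept values, divide by the (clamped) count of kept
# values, so an all-zero mask yields a finite value instead of NaN.
import torch

def masked_mean_sketch(vector: torch.Tensor, mask: torch.Tensor,
                       dim: int, keepdim: bool = False,
                       eps: float = 1e-13) -> torch.Tensor:
    # Zero out masked positions, then normalize by how many positions were kept.
    value_sum = torch.sum(vector * mask, dim=dim, keepdim=keepdim)
    value_count = torch.sum(mask, dim=dim, keepdim=keepdim)
    return value_sum / value_count.clamp(min=eps)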
def _encode_utt_schema(self, enc, offsets, relation, lengths):
    embedded_utterance_schema = self.emb_q(enc)

    (
        embedded_utterance_schema,
        embedded_utterance_schema_mask,
    ) = vec_utils.batched_span_select(embedded_utterance_schema, offsets)

    embedded_utterance_schema = masked_mean(
        embedded_utterance_schema,
        embedded_utterance_schema_mask.unsqueeze(-1),
        dim=-2,
    )

    relation_mask = (relation >= 0).float()  # TODO: fixme
    torch.abs(relation, out=relation)
    embedded_utterance_schema = self._emb_to_action_dim(embedded_utterance_schema)
    enriched_utterance_schema = self._schema_encoder(
        embedded_utterance_schema, relation.long(), relation_mask)

    utterance_schema, utterance_schema_mask = vec_utils.batched_span_select(
        enriched_utterance_schema, lengths)
    utterance, schema = torch.split(utterance_schema, 1, dim=1)
    utterance_mask, schema_mask = torch.split(utterance_schema_mask, 1, dim=1)
    utterance_mask = torch.squeeze(utterance_mask, 1)
    schema_mask = torch.squeeze(schema_mask, 1)
    embedded_utterance = torch.squeeze(utterance, 1)
    schema = torch.squeeze(schema, 1)
    return schema, schema_mask, embedded_utterance, utterance_mask
def _init_state(self, triples: Dict[str, torch.LongTensor],
                predicate: Dict[str, torch.LongTensor],
                draft: Dict[str, torch.LongTensor],
                triple_ids: torch.LongTensor) -> Dict[str, torch.Tensor]:
    emb_pred = util.masked_mean(
        self.EMB(predicate),
        util.get_text_field_mask(predicate, num_wrapping_dims=1).unsqueeze(-1),
        2)
    emb_triple = self.EMB(triples)
    triple_mask = util.get_text_field_mask(triples)
    flat_triples = torch.cat((emb_triple.flatten(2, 3), emb_pred), dim=-1)
    encoded_triples = self.FACT_ENCODER(flat_triples)

    emb_draft = self.EMB(draft)
    draft_mask = util.get_text_field_mask(draft)
    end_point = draft_mask.sum(dim=1) - 1
    encoded_draft = self.BUFFER(emb_draft, draft_mask)

    return {
        "draft_mask": draft_mask,
        "triple_mask": triple_mask,
        "end_point": end_point,
        "encoded_triple": encoded_triples,
        "encoded_draft": encoded_draft,
        "triple_tokens": triples["tokens"][:, :, -1],
        "triple_token_ids": triple_ids,
    }
def forward(self, **kwargs) -> torch.FloatTensor:
    mask = kwargs['mask']
    embedded_text = kwargs['embedded_text']
    encoded_output = self._architecture(embedded_text, mask)
    encoded_repr = []
    for aggregation in self._aggregations:
        if aggregation == "meanpool":
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoded_output * broadcast_mask
            encoded_text = masked_mean(context_vectors, broadcast_mask,
                                       dim=1, keepdim=False)
        elif aggregation == 'maxpool':
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoded_output * broadcast_mask
            encoded_text = masked_max(context_vectors, broadcast_mask, dim=1)
        elif aggregation == 'final_state':
            is_bi = self._architecture.is_bidirectional()
            encoded_text = get_final_encoder_states(encoded_output, mask, is_bi)
        elif aggregation == 'attention':
            alpha = self._attention_layer(encoded_output)
            alpha = masked_log_softmax(alpha, mask.unsqueeze(-1), dim=1).exp()
            encoded_text = alpha * encoded_output
            encoded_text = encoded_text.sum(dim=1)
        else:
            raise ConfigurationError(f"{aggregation} aggregation not available.")
        encoded_repr.append(encoded_text)

    encoded_repr = torch.cat(encoded_repr, 1)
    return encoded_repr
def _encode_definition(
        self, definition: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # [batch_size, seq_len]
    definition_mask = util.get_text_field_mask(definition)
    # [batch_size, seq_len, emb_dim]
    embedded_definition = self.text_embedder(definition)
    # either [batch_size, emb_dim] or [batch_size, seq_len, emb_dim]
    encoded_definition = self.definition_encoder(embedded_definition,
                                                 definition_mask)

    # if len(encoded_definition.size()) == 3:
    if self.definition_pooling == 'last':
        # [batch_size, emb_dim]
        encoded_definition = util.get_final_encoder_states(
            encoded_definition, definition_mask)
    elif self.definition_pooling == 'max':
        # encoded_definition = F.adaptive_max_pool1d(encoded_definition.transpose(1, 2), 1).squeeze(2)
        encoded_definition = util.masked_max(encoded_definition,
                                             definition_mask.unsqueeze(2),
                                             dim=1)
    elif self.definition_pooling == 'mean':
        # encoded_definition = F.adaptive_avg_pool1d(encoded_definition.transpose(1, 2), 1).squeeze(2)
        encoded_definition = util.masked_mean(encoded_definition,
                                              definition_mask.unsqueeze(2),
                                              dim=1)
    elif self.definition_pooling == 'self-attentive':
        self_attentive_logits = self.self_attentive_pooling_projection(
            encoded_definition).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, definition_mask)
        encoded_definition = util.weighted_sum(encoded_definition, self_weights)

    # [batch_size, emb_dim]
    definition_embedding = self.definition_feedforward(encoded_definition)
    # [batch_size, vocab_size(num_class)]
    definition_logits = self.definition_decoder(definition_embedding)
    # [batch_size, seq_len, vocab_size]
    sequence_definition_logits = definition_logits.unsqueeze(1).repeat(
        1, definition_mask.size(1), 1)

    # ``average`` can be None, "batch", or "token"; the loss for ``average=None``
    # is a vector of shape (batch_size,); otherwise, a scalar.
    targets = definition['tokens'].clone()
    if self.limited_word_vocab_size is not None:
        targets[targets >= self.limited_word_vocab_size] = self._oov_index
    cross_entropy_loss = util.sequence_cross_entropy_with_logits(
        sequence_definition_logits,
        targets,  # definition['tokens'],
        weights=definition_mask,
        average='token')

    return {
        "definition_embedding": definition_embedding,
        "cross_entropy_loss": cross_entropy_loss
    }
def forward(self,  # pylint: disable=arguments-differ
            premises_relevance_logits: torch.Tensor,
            premises_presence_mask: torch.Tensor,
            relevance_presence_mask: torch.Tensor) -> torch.Tensor:
    # pylint: disable=unused-argument
    premises_relevance_logits = replace_masked_values(
        premises_relevance_logits, premises_presence_mask, -1e10)
    binary_losses = self._loss(premises_relevance_logits,
                               relevance_presence_mask)
    coverage_losses = masked_mean(binary_losses, premises_presence_mask, dim=1)
    coverage_loss = coverage_losses.mean()
    return coverage_loss
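# A toy check of the coverage loss above, on hedged assumptions: `self._loss` is taken
# to be torch.nn.BCEWithLogitsLoss(reduction="none"), and masks are floats as in the
# snippet (which assumes an AllenNLP version that accepts float masks). Replacing
# logits of absent premises with -1e10 drives their sigmoid to ~0, and masked_mean then
# averages the per-premise losses only over premises that are actually present.
import torch
from allennlp.nn.util import masked_mean, replace_masked_values

loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
logits = torch.randn(2, 3)                              # (batch, num_premises)
presence = torch.tensor([[1., 1., 0.], [1., 0., 0.]])   # which premises exist
relevance = torch.tensor([[1., 0., 0.], [1., 0., 0.]])  # gold relevance labels
logits = replace_masked_values(logits, presence, -1e10)
binary_losses = loss_fn(logits, relevance)              # (batch, num_premises)
coverage_loss = masked_mean(binary_losses, presence, dim=1).mean()  # scalar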
def pred(self, ws, ctxs):
    # ws   : unused
    # ctxs : B,C,S
    ctxs = torch.stack(ctxs)
    x = self.wvec[ctxs].cuda()                    # B,C,S,D
    mask = (ctxs != -1).cuda()                    # B,C,S
    B, C, S, D = x.shape
    x = x.reshape(B * C, S, D)
    mask = mask.reshape(B * C, S)
    x = self.posenc(x)                            # B*C,S,D
    x = self.ctxenc(x, mask)                      # B*C,S,D
    x = masked_mean(x, mask[:, :, None], dim=-2)  # B*C,D
    x = x.reshape(B, C, D)
    mask = mask.reshape(B, C, S).any(-1)          # B,C
    x = self.ctxagg(x, mask)                      # B,C,D
    x = masked_mean(x, mask[:, :, None], dim=-2)  # B,D
    return x
def _compute_answer(self, premise_memory: torch.Tensor,
                    hypothesis_memory: torch.Tensor,
                    premise_mask: torch.Tensor,
                    hypothesis_mask: torch.Tensor) -> torch.Tensor:
    batch_size = premise_memory.size(0)
    num_labels = self._output_logit.get_output_dim()

    # Shape: (batch_size, hypothesis_length)
    hypothesis_attention = util.masked_softmax(
        self._answer_attention(hypothesis_memory).squeeze(),
        hypothesis_mask,
    )
    # Shape: (batch_size, embedding_dim)
    answer_state = util.weighted_sum(hypothesis_memory, hypothesis_attention)

    label_prob_steps: torch.Tensor = answer_state.new_zeros(
        (batch_size, num_labels, self._answer_steps))
    for step in range(self._answer_steps):
        # Shape: (batch_size, premise_length)
        premise_attention = self._answer_bilinear(answer_state,
                                                  premise_memory, premise_mask)
        # Shape: (batch_size, embedding_dim)
        cell_input = util.weighted_sum(premise_memory, premise_attention)
        answer_state = self._answer_gru_cell(cell_input, answer_state)
        output_hidden = torch.cat([
            answer_state,
            cell_input,
            (answer_state - cell_input).abs(),
            answer_state * cell_input,
        ], dim=-1)
        label_logits = self._output_logit(self._output_feedforward(output_hidden))
        label_prob_steps[:, :, step] = label_logits.softmax(-1)

    if self.training and self._dropout:
        # Stochastic prediction dropout: randomly drop per-step predictions
        # and average the remaining ones.
        binary_mask = (torch.rand(
            (batch_size, self._answer_steps)) > self._dropout.p).to(
                label_prob_steps.device)
        label_probs = util.masked_mean(label_prob_steps,
                                       binary_mask.float().unsqueeze(1),
                                       dim=2)
        # If every step was dropped for an instance, fall back to a uniform
        # distribution over labels.
        label_probs = util.replace_masked_values(
            label_probs,
            binary_mask.sum(1, keepdim=True).bool().float(),
            1.0 / num_labels)
    else:
        label_probs = label_prob_steps.mean(2)
    return label_probs
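# A toy illustration (assumed shapes) of the "stochastic prediction dropout" averaging
# above: per-step label distributions are averaged only over randomly kept steps, and
# any instance whose steps were all dropped falls back to a uniform distribution.
# Float masks assume an AllenNLP version that accepts them, as the snippet does.
import torch
from allennlp.nn.util import masked_mean, replace_masked_values

batch_size, num_labels, steps = 2, 3, 4
label_prob_steps = torch.rand(batch_size, num_labels, steps)
keep = torch.tensor([[1., 0., 1., 0.],   # keep steps 0 and 2
                     [0., 0., 0., 0.]])  # all steps dropped for this instance
label_probs = masked_mean(label_prob_steps, keep.unsqueeze(1), dim=2)
label_probs = replace_masked_values(
    label_probs, keep.sum(1, keepdim=True).bool().float(), 1.0 / num_labels)
# Row 1 is now uniform: tensor([0.3333, 0.3333, 0.3333])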
def _decoder_init(self, state):
    mean_draft = util.masked_mean(state["encoded_draft"],
                                  state["draft_mask"].unsqueeze(-1), 1)
    mean_triple = util.masked_mean(state["encoded_triple"],
                                   state["triple_mask"].unsqueeze(-1), 1)
    concatenated = torch.cat((mean_draft, mean_triple), dim=-1)
    batch_size = state["draft_mask"].size(0)
    zeros = mean_draft.new_zeros((batch_size, self.decoder_size))
    state["stream_hidden"], state["stream_context"] = self.U(concatenated), zeros
    state["draft_pointer"] = state["draft_mask"].new_ones((batch_size, ))
    action_mask = mean_draft.new_ones((batch_size, self.vocab_size))
    action_mask[:, self.PAD] = 0
    action_mask[:, self.END] = 0
    state["action_mask"] = action_mask
    return state
def _get_summary_of_encoder_outputs(self, encoder_outputs, source_mask):
    # Returns the final encoder state for RNN encoders, and the masked mean
    # of the outputs for all other encoders.
    if type(self._encoder) == PytorchSeq2SeqWrapper:
        summary = util.get_final_encoder_states(
            encoder_outputs, source_mask, self._encoder.is_bidirectional())
    else:
        summary = masked_mean(encoder_outputs,
                              source_mask.unsqueeze(-1).to(encoder_outputs.device),
                              dim=1,
                              keepdim=False)
    return summary
def pool(vector: torch.Tensor, mask: torch.Tensor, dim: int, pooling: str,
         is_bidirectional: bool) -> torch.Tensor:
    if pooling == "max":
        return masked_max(vector, mask, dim)
    elif pooling == "mean":
        return masked_mean(vector, mask, dim)
    elif pooling == "sum":
        return torch.sum(vector, dim)
    elif pooling == "final":
        return get_final_encoder_states(vector, mask, is_bidirectional)
    else:
        raise ValueError(f"'{pooling}' is not a valid pooling operation.")
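# Hypothetical usage of `pool` above, assuming masked_max / masked_mean /
# get_final_encoder_states come from allennlp.nn.util. For "mean" and "max" the mask
# is unsqueezed so it broadcasts over the hidden dimension; "final" would instead
# expect the plain (batch, seq_len) mask.
import torch

encoder_outputs = torch.randn(2, 5, 8)             # (batch, seq_len, hidden)
token_mask = torch.tensor([[1., 1., 1., 0., 0.],
                           [1., 1., 1., 1., 1.]])  # (batch, seq_len)
pooled = pool(encoder_outputs, token_mask.unsqueeze(-1), dim=1,
              pooling="mean", is_bidirectional=False)  # (batch, hidden)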
def forward(self, text: Dict[str, torch.LongTensor], metadata=None,
            page: torch.IntTensor = None):  # pylint: disable=arguments-differ
    input_ids: torch.LongTensor = text["text"]
    # Grab the representation of the CLS token, which is always first.
    if self._pool == "cls":
        bert_emb = self._bert(input_ids)[:, 0, :]
    elif self._pool == "mean":
        mask = (input_ids != 0).long()[:, :, None]
        bert_seq_emb = self._bert(input_ids)
        bert_emb = util.masked_mean(bert_seq_emb, mask, dim=1)
    else:
        raise ValueError("Invalid config")
    return self._hidden_to_output(bert_emb, page)
def find_max_window(p_prob, mask, offset):
    batch_size = p_prob.size(0)
    out_idx = []
    for b in range(batch_size):
        mean_prob = allenutil.masked_mean(p_prob[b], mask[b], dim=-1)
        max_idx = np.argmax(p_prob[b].detach().cpu().numpy())
        # There are many possible ways to determine max_idx; the above simply
        # picks the highest-probability token. You could use other ideas, like
        # scoring the total probability of a 3-gram window instead of each token:
        # max_idx = find_max_ind(p_prob[b])
        max_value = p_prob[b][max_idx]
        start = find_surrounding_with_max(p_prob[b], max_idx,
                                          max(4 * mean_prob, 0.0), 'L')
        end = find_surrounding_with_max(p_prob[b], max_idx,
                                        max(4 * mean_prob, 0.0), 'R')
        start += offset[b].tolist()
        end += offset[b].tolist()
        out_idx.append([start, end])
    return out_idx
def pool_node_embeddings(self, last_layers, masks, gdata, batch_num_nodes):
    """
    Convert wordpiece embeddings into word (i.e. node) embeddings using the
    alignment in wpidx2graphid = gdata['wpidx2graphid'].

    Parameters:
        gdata: dictionary with values having shape: (bsz, ...)
        masks: (bsz, max_sent_pair_len)
        last_layers: (bsz, max_sent_pair_len, emb_dim)

    Returns:
        node_embeddings: (bsz, max_num_nodes, emb_dim)
        node_embeddings_mask: (bsz, max_num_nodes)
    """
    wpidx2graphid = gdata['wpidx2graphid']  # (bsz, max_sent_len, max_n_nodes)
    device = last_layers.device
    bsz, max_sent_len, max_n_nodes = wpidx2graphid.shape
    emb_dim = last_layers.shape[-1]
    assert max(batch_num_nodes) == wpidx2graphid.shape[-1]

    # The following logic happens to work if the graph is empty, in which case
    # its sentence_end is guaranteed to be 1 (exclusive).
    masks_cumsum = masks.cumsum(1)
    sentence_starts = first_true_idx(masks, 1, masks_cumsum)
    sentence_ends = last_true_idx(masks, 1, masks_cumsum) + 1  # exclusive
    max_sentence_len = (sentence_ends - sentence_starts).max()
    # We use a for loop here since rolling only across the batch dimension
    # shouldn't be very expensive. That said, can we do it without a loop?
    rolled_last_layers = torch.stack([
        last_layer.roll(-sentence_start.item(), dims=0)
        for last_layer, sentence_start in zip(last_layers, sentence_starts)
    ])
    segmented_last_layers = rolled_last_layers[:, :max_sentence_len, :]  # (bsz, max_sent_len, emb_dim)
    assert segmented_last_layers.shape[:2] == wpidx2graphid.shape[:2]

    # (bsz, max_sent_len, max_n_nodes, emb_dim)
    expanded_wpidx2graphid = wpidx2graphid.unsqueeze(-1).expand(-1, -1, -1, emb_dim)
    expanded_segmented_last_layers = segmented_last_layers.unsqueeze(2).expand(
        -1, -1, max_n_nodes, -1)
    # (bsz, max_n_nodes, emb_dim)
    node_embeddings = masked_mean(expanded_segmented_last_layers,
                                  expanded_wpidx2graphid, 1)
    # Some nodes don't have corresponding wordpieces; zero them out.
    node_embeddings = torch.where(expanded_wpidx2graphid.any(1), node_embeddings,
                                  torch.tensor(0., device=device))
    node_embeddings_mask = (
        torch.arange(max(batch_num_nodes), device=device).expand(bsz, -1)
        < torch.tensor(batch_num_nodes, dtype=torch.long, device=device).unsqueeze(1))
    return node_embeddings, node_embeddings_mask
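# A small worked example (toy numbers, hedged) of the alignment pooling above: each
# node's embedding is the masked mean of the wordpiece embeddings aligned to it via
# wpidx2graphid, with mask and embeddings expanded to a common
# (bsz, n_wordpieces, n_nodes, emb_dim) shape.
import torch
from allennlp.nn.util import masked_mean

wp2node = torch.tensor([[[1, 0], [1, 0], [0, 1]]], dtype=torch.bool)  # (1, 3 wp, 2 nodes)
embs = torch.tensor([[[1., 1.], [3., 3.], [5., 5.]]])                 # (1, 3 wp, emb=2)
expanded_mask = wp2node.unsqueeze(-1).expand(-1, -1, -1, 2)
expanded_embs = embs.unsqueeze(2).expand(-1, -1, 2, -1)
node_embs = masked_mean(expanded_embs, expanded_mask, 1)
# node 0 <- mean of wordpieces 0,1 = [2., 2.];  node 1 <- wordpiece 2 = [5., 5.]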
def forward(self, document, query=None, label=None, metadata=None,
            rationale=None, **kwargs) -> Dict[str, Any]:
    # pylint: disable=arguments-differ
    bert_document = self.combine_document_query(document, query)

    last_hidden_states, _ = self._bert_model(
        bert_document["bert"]["wordpiece-ids"],
        attention_mask=bert_document["bert"]["wordpiece-mask"],
        position_ids=bert_document["bert"]["position-ids"],
        token_type_ids=bert_document["bert"]["type-ids"],
    )

    token_embeddings, span_mask = generate_embeddings_for_pooling(
        last_hidden_states,
        bert_document["bert"]['document-starting-offsets'],
        bert_document["bert"]['document-ending-offsets'])

    token_embeddings = util.masked_mean(token_embeddings,
                                        span_mask.unsqueeze(-1), dim=2)
    token_embeddings = token_embeddings * bert_document['bert']["mask"].unsqueeze(-1)

    logits = torch.nn.functional.softplus(
        self._classification_layer(self._dropout(token_embeddings)))
    a, b = logits[:, :, 0], logits[:, :, 1]

    mask = bert_document['bert']['mask']

    output_dict = {}
    output_dict["a"] = a * mask
    output_dict["b"] = b * mask
    output_dict['mask'] = mask
    output_dict['wordpiece-to-token'] = bert_document['bert']['wordpiece-to-token']

    return output_dict
def _average_image_features(
        self, image_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Perform mean pooling of bottom-up image features, while taking care of
    variable ``num_boxes`` in case of adaptive features.

    Extended Summary
    ----------------
    For a single training/evaluation instance, the image features remain the
    same from the first time-step to the maximum decoding step. To keep a
    clean API, we use an LRU cache, which maintains a cache of the last 10
    return values keyed on the call signature, and does not actually execute
    if called with image features seen at least once in the last 10 calls.
    This saves some computation.

    Parameters
    ----------
    image_features: torch.Tensor
        A tensor of shape ``(batch_size, num_boxes, image_feature_size)``.
        ``num_boxes`` for each instance in a batch might be different.
        Instances with fewer boxes are padded with zeros up to ``num_boxes``.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        Averaged image features of shape ``(batch_size, image_feature_size)``
        and a binary mask of shape ``(batch_size, num_boxes)`` which is zero
        for padded features.
    """
    # shape: (batch_size, num_boxes)
    image_features_mask = torch.sum(torch.abs(image_features), dim=-1) > 0

    # shape: (batch_size, image_feature_size)
    averaged_image_features = masked_mean(
        image_features, image_features_mask.unsqueeze(-1), dim=1)

    return averaged_image_features, image_features_mask
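# A toy check (assumed sizes) of the padding-box mask logic above: padded boxes are
# all-zero rows, so the sum of absolute feature values is positive exactly at real boxes.
import torch
from allennlp.nn.util import masked_mean

feats = torch.zeros(1, 3, 4)                         # (batch, num_boxes, feat)
feats[0, 0] = torch.tensor([1., 2., 3., 4.])
feats[0, 1] = torch.tensor([4., 3., 2., 1.])         # box 2 stays all-zero padding
mask = torch.sum(torch.abs(feats), dim=-1) > 0       # tensor([[True, True, False]])
avg = masked_mean(feats, mask.unsqueeze(-1), dim=1)  # tensor([[2.5, 2.5, 2.5, 2.5]])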
def forward(self, document, rationale=None) -> Dict[str, Any]:
    embedded_text = self._text_field_embedder(document)
    mask = util.get_text_field_mask(document).float()

    embedded_text = self._dropout(self._seq2seq_encoder(embedded_text, mask=mask))
    embedded_text = self._feedforward_encoder(embedded_text)

    logits = self._classification_layer(embedded_text).squeeze(-1)
    probs = torch.sigmoid(logits)

    output_dict = {}

    predicted_rationale = (probs > 0.5).long()
    output_dict['predicted_rationale'] = predicted_rationale * mask
    output_dict["prob_z"] = probs * mask

    class_probs = torch.cat([1 - probs.unsqueeze(-1), probs.unsqueeze(-1)], dim=-1)

    average_rationale_length = util.masked_mean(
        output_dict['predicted_rationale'], mask, dim=-1).mean()
    self._rationale_length(average_rationale_length.item())

    if rationale is not None:
        rationale_loss = F.binary_cross_entropy_with_logits(
            logits, rationale.float(), weight=mask)
        output_dict['rationale_supervision_loss'] = rationale_loss
        output_dict['gold_rationale'] = rationale * mask
        self._rationale_f1_metric(predictions=class_probs,
                                  gold_labels=rationale, mask=mask)
        self._rationale_supervision_loss(rationale_loss.item())

    return output_dict
def forward(self, definition: Dict[str, torch.LongTensor],
            word: Dict[str, torch.LongTensor] = None,
            word_to_definition: torch.Tensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    output_dict = {}
    output_dict.update(self._encode_definition(definition))
    output_dict['loss'] = self.alpha * output_dict['cross_entropy_loss']

    if self.beta > 0 and word is not None:
        # [batch_size, seq_len(1)]
        word_in_definition_mask = (word_to_definition != self._oov_index).float()
        # [batch_size]
        word_in_definition_mask = word_in_definition_mask.squeeze(dim=1)
        # [batch_size, seq_len(1), emb_dim]
        embedded_word = self.text_embedder({'tokens': word_to_definition})
        # [batch_size, emb_dim]
        embedded_word = embedded_word.squeeze(dim=1)
        mse = self.pdist(output_dict['definition_embedding'], embedded_word)
        consistency_loss = util.masked_mean(mse, word_in_definition_mask, dim=0)
        output_dict['consistency_loss'] = consistency_loss
        output_dict['loss'] += self.beta * output_dict['consistency_loss']

        for metric in self.metrics.values():
            metric(output_dict['definition_embedding'], embedded_word,
                   word_in_definition_mask)

    return output_dict
def forward(self, document, query=None, label=None, metadata=None,
            rationale=None) -> Dict[str, Any]:
    # pylint: disable=arguments-differ
    generator_dict = self._generator(document, query, label)
    mask = generator_dict["mask"]

    assert "a" in generator_dict
    assert "b" in generator_dict

    a, b = generator_dict['a'], generator_dict['b']
    a = a.clamp(1e-6, 100.)  # extreme values could result in NaNs
    b = b.clamp(1e-6, 100.)  # extreme values could result in NaNs

    output_dict = {}

    sampler = HardKuma([a, b], support=[
        self.support[0].to(a.device),
        self.support[1].to(b.device)
    ])

    generator_dict['predicted_rationale'] = (sampler.mean() > 0.5).long() * mask

    if self.prediction_mode or not self.training:
        if self._rationale_extractor is None:
            # We constrain rationales to be strictly 0 or 1. See Pruthi et al.
            # for pathologies when this is not the case.
            sample_z = (sampler.mean() > 0.5).long() * mask
        else:
            prob_z = sampler.mean()
            sample_z = self._rationale_extractor.extract_rationale(
                prob_z, metadata, as_one_hot=True)
            output_dict["rationale"] = self._rationale_extractor.extract_rationale(
                prob_z, metadata, as_one_hot=False)
            sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
    else:
        sample_z = sampler.sample()

    sample_z = sample_z * mask

    # Because BERT is BERT.
    wordpiece_to_token = generator_dict['wordpiece-to-token']
    wtt0 = torch.where(wordpiece_to_token == -1,
                       torch.tensor([0]).to(wordpiece_to_token.device),
                       wordpiece_to_token)
    wordpiece_sample = util.batched_index_select(sample_z.unsqueeze(-1), wtt0)
    wordpiece_sample[wordpiece_to_token.unsqueeze(-1) == -1] = 1.0

    def scale_embeddings(module, input, output):
        output = output * wordpiece_sample
        return output

    hook = self._encoder.embedding_layers[0].register_forward_hook(scale_embeddings)
    encoder_dict = self._encoder(
        document=document,
        query=query,
        label=label,
        metadata=metadata,
    )
    hook.remove()

    loss = 0.0

    if label is not None:
        assert "loss" in encoder_dict
        base_loss = F.cross_entropy(encoder_dict["logits"], label)

        # (B,)
        lasso_loss = ((1 - sampler.pdf(0.)) * mask).sum(1)
        lengths = mask.sum(1)
        lasso_loss = lasso_loss / (lengths + 1e-9)
        # `lasso_loss` is already length-normalized above, so compare it to the
        # desired length directly. (The original divided by `lengths` a second
        # time here, which looks unintended.)
        censored_lasso_loss = F.relu(lasso_loss - self._desired_length)
        censored_lasso_loss = censored_lasso_loss.mean()

        # diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
        # mask_last = mask[:, :-1]
        # fused_lasso_loss = diff.sum(-1) / mask_last.sum(-1)

        self._loss_tracks["_lasso_loss"](lasso_loss.mean().item())
        self._loss_tracks["_censored_lasso_loss"](censored_lasso_loss.mean().item())
        # self._loss_tracks["_fused_lasso_loss"](fused_lasso_loss.mean().item())
        self._loss_tracks["_base_loss"](base_loss.mean().item())

        generator_loss = self._reg_loss_lambda * censored_lasso_loss
        self._loss_tracks["_generator_loss"](generator_loss.mean().item())
        loss += (base_loss + generator_loss).mean()

    output_dict["probs"] = encoder_dict["probs"]
    output_dict["predicted_labels"] = encoder_dict["predicted_labels"]
    output_dict["loss"] = loss
    output_dict["gold_labels"] = label
    output_dict["metadata"] = metadata
    output_dict["predicted_rationale"] = generator_dict["predicted_rationale"]
    self._loss_tracks["_rat_length"](util.masked_mean(
        generator_dict["predicted_rationale"], mask, dim=-1).mean().item())

    self._call_metrics(output_dict)

    return output_dict
def forward(
    self,
    input_ids,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    start_positions=None,
    end_positions=None,
    len_q=None,
    bs_seperator_index=None,
    s_first=True,
    max_q_length=30,
    max_s_length=200,
    max_b_length=400,
    sp_relevance=None,
    sp_tp_polarity=None,
    tp_relevance=None,
    object1_label=None,
    object2_label=None,
    SP_Object1_label=None,
    SP_Object2_label=None,
    SP_Back_label=None,
    TP_Back_label=None,
):
    outputs = self.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
    )
    torch.set_printoptions(precision=8, sci_mode=False)

    # Step 1, Encoder: split background, situation, and question out of the
    # whole contextual representation, then pad each to a fixed length.
    sequence_output = outputs[0]
    batch_size = sequence_output.size(0)
    max_seq_length = sequence_output.size(1)
    hidden_size = sequence_output.size(2)
    device = outputs[0].device

    padded_HQ = torch.zeros([batch_size, max_q_length, hidden_size]).to(device)
    padded_mask_HQ = torch.zeros([batch_size, max_q_length]).to(device)
    padded_HS = torch.zeros([batch_size, max_s_length, hidden_size]).to(device)
    padded_mask_HS = torch.zeros([batch_size, max_s_length]).to(device)
    padded_HB = torch.zeros([batch_size, max_b_length, hidden_size]).to(device)
    padded_mask_HB = torch.zeros([batch_size, max_b_length]).to(device)

    s_inds = []
    for ind in range(batch_size):
        try:
            mask_ind = (attention_mask[ind] == 0).nonzero()[0][0] - 1
        except Exception:
            mask_ind = max_seq_length - 1
        HQ, padded_h_q, padded_mask_q = pad_hiddenstate(
            sequence_output[ind, 1:1 + len_q[ind], :], max_q_length)
        if s_first:
            HS, padded_h_s, padded_mask_s = pad_hiddenstate(
                sequence_output[ind, len_q[ind] + 3:bs_seperator_index[ind] + 1, :],
                max_s_length)
            HB, padded_h_b, padded_mask_b = pad_hiddenstate(
                sequence_output[ind, bs_seperator_index[ind] + 1:mask_ind, :],
                max_b_length)
        else:
            HB, padded_h_b, padded_mask_b = pad_hiddenstate(
                sequence_output[ind, len_q[ind] + 3:bs_seperator_index[ind] + 1, :],
                max_s_length)
            HS, padded_h_s, padded_mask_s = pad_hiddenstate(
                sequence_output[ind, bs_seperator_index[ind] + 1:mask_ind, :],
                max_b_length)
        s_inds.append(bs_seperator_index[ind] - 2 - len_q[ind])
        padded_HQ[ind, :, :] = padded_h_q
        padded_mask_HQ[ind, :] = padded_mask_q
        padded_HS[ind, :, :] = padded_h_s
        padded_mask_HS[ind, :] = padded_mask_s
        padded_HB[ind, :, :] = padded_h_b
        padded_mask_HB[ind, :] = padded_mask_b

    # The auxiliary labels also need padding.
    padded_O1 = torch.zeros([batch_size, max_s_length]).to(device)
    padded_O2 = torch.zeros([batch_size, max_s_length]).to(device)
    padded_SP_o1 = torch.zeros([batch_size, max_s_length]).to(device)
    padded_SP_o2 = torch.zeros([batch_size, max_s_length]).to(device)
    padded_SP = torch.zeros([batch_size, max_b_length]).to(device)
    padded_TP = torch.zeros([batch_size, max_b_length]).to(device)

    for ind in range(batch_size):
        try:
            mask_ind = (attention_mask[ind] == 0).nonzero()[0][0]
        except Exception:
            mask_ind = max_seq_length - 1
        if s_first:
            _, padded_o1 = pad_supervison_label(
                object1_label[ind, len_q[ind] + 3:bs_seperator_index[ind] + 1],
                max_s_length) if object1_label is not None else [None, None]
            _, padded_o2 = pad_supervison_label(
                object2_label[ind, len_q[ind] + 3:bs_seperator_index[ind] + 1],
                max_s_length) if object2_label is not None else [None, None]
            _, padded_sp_o1 = pad_supervison_label(
                SP_Object1_label[ind, len_q[ind] + 3:bs_seperator_index[ind] + 1],
                max_s_length) if SP_Object1_label is not None else [None, None]
            _, padded_sp_o2 = pad_supervison_label(
                SP_Object2_label[ind, len_q[ind] + 3:bs_seperator_index[ind] + 1],
                max_s_length) if SP_Object2_label is not None else [None, None]
            _, padded_sp = pad_supervison_label(
                SP_Back_label[ind, bs_seperator_index[ind] + 1:mask_ind],
                max_b_length) if SP_Back_label is not None else [None, None]
            _, padded_tp = pad_supervison_label(
                TP_Back_label[ind, bs_seperator_index[ind] + 1:mask_ind],
                max_b_length) if TP_Back_label is not None else [None, None]
            padded_O1[ind, :] = padded_o1 if padded_o1 is not None else padded_O1[ind, :]
            padded_O2[ind, :] = padded_o2 if padded_o2 is not None else padded_O2[ind, :]
            padded_SP_o1[ind, :] = padded_sp_o1 if padded_sp_o1 is not None else padded_SP_o1[ind, :]
            padded_SP_o2[ind, :] = padded_sp_o2 if padded_sp_o2 is not None else padded_SP_o2[ind, :]
            padded_SP[ind, :] = padded_sp if padded_sp is not None else padded_SP[ind, :]
            padded_TP[ind, :] = padded_tp if padded_tp is not None else padded_TP[ind, :]

    # **************************** STEP 2: Find OBJECT/World ****************************
    # [b,n,d] -> [b,n,1]
    ps_object1 = self.find_object1(padded_HS)
    ps_object2 = self.find_object2(padded_HS)
    ps_object1 = allenutil.masked_softmax(ps_object1.squeeze(), padded_mask_HS,
                                          memory_efficient=True)
    ps_object2 = allenutil.masked_softmax(ps_object2.squeeze(), padded_mask_HS,
                                          memory_efficient=True)

    # **************************** STEP 3: Find TP/Effect in B ****************************
    # [b,m,d] -> [b,m,1]
    pb_TP = self.find_TP(padded_HB)
    pb_TP = allenutil.masked_softmax(pb_TP.squeeze(), padded_mask_HB,
                                     memory_efficient=True)

    # **************************** STEP 4: Relocate TP/Effect to SP/Cause ****************************
    mean_HS = allenutil.masked_mean(padded_HS, padded_mask_HS.unsqueeze(-1), dim=1)
    relocate_bb_similarity_matrix = self.bb_matrix_attention(
        torch.add(mean_HS.unsqueeze(1), padded_HB), padded_HB)
    b2b_attention_matrix = allenutil.masked_softmax(
        relocate_bb_similarity_matrix, padded_mask_HB,
        memory_efficient=True, dim=-1)
    pb_SP = torch.sum(torch.mul(pb_TP.unsqueeze(-1), b2b_attention_matrix), dim=1)

    # If we don't have labels, these two lines can be commented out.
    padded_TP_normal = torch.nn.functional.normalize(padded_TP, p=1, dim=-1)
    pb_SP_gold = torch.sum(torch.mul(padded_TP_normal.unsqueeze(-1),
                                     b2b_attention_matrix), dim=1)

    # ************************* Step 5: Find SP/cause for object/world 1 and object/world 2 *************************
    # Explained in the Comparison module; treat the two worlds as masks.
    s2b_similarity_matrix = self.bs_bilinear_imilairty(padded_HS, padded_HB)
    s2b_similarity_attention = allenutil.masked_softmax(
        s2b_similarity_matrix, padded_mask_HB, memory_efficient=True, dim=1)
    b2s_similarity_matrix = torch.transpose(s2b_similarity_matrix, 1, 2)
    ps_guided_SP = torch.sum(torch.mul(pb_SP.unsqueeze(-1),
                                       b2s_similarity_matrix), dim=1)

    mask_score_object1 = ps_object1
    mask_score_object2 = ps_object2
    ps_SP_object1 = torch.mul(mask_score_object1, ps_guided_SP)
    ps_SP_object1 = allenutil.masked_softmax(ps_SP_object1, padded_mask_HS,
                                             memory_efficient=True, dim=-1)
    ps_SP_object2 = torch.mul(mask_score_object2, ps_guided_SP)
    ps_SP_object2 = allenutil.masked_softmax(ps_SP_object2, padded_mask_HS,
                                             memory_efficient=True, dim=-1)

    # Gold labels, if we have them.
    padded_SP_normal = torch.nn.functional.normalize(padded_SP, p=1, dim=-1)
    ps_guided_SP_gold = torch.sum(torch.mul(padded_SP_normal.unsqueeze(-1),
                                            b2s_similarity_matrix), dim=1)
    mask_score_object1_gold = torch.nn.functional.normalize(padded_O1 + 0.01, p=1, dim=-1)
    mask_score_object2_gold = torch.nn.functional.normalize(padded_O2 + 0.01, p=1, dim=-1)
    ps_SP_object1_gold = torch.mul(mask_score_object1_gold, ps_guided_SP_gold)
    ps_SP_object1_gold = allenutil.masked_softmax(ps_SP_object1_gold, padded_mask_HS,
                                                  memory_efficient=True, dim=-1)
    ps_SP_object2_gold = torch.mul(mask_score_object2_gold, ps_guided_SP_gold)
    ps_SP_object2_gold = allenutil.masked_softmax(ps_SP_object2_gold, padded_mask_HS,
                                                  memory_efficient=True, dim=-1)

    # **************************** Step 6: relevance/comparison check ****************************
    summed_HB_weighted_pb_SP = torch.matmul(pb_SP.unsqueeze(1), padded_HB)             # 1XMXMXD => [B,1,D]
    summed_HS_weighted_ps_SP_o1 = torch.matmul(ps_SP_object1.unsqueeze(1), padded_HS)  # 1XNXNXD => [B,1,D]
    summed_HS_weighted_ps_SP_o2 = torch.matmul(ps_SP_object2.unsqueeze(1), padded_HS)  # 1XNXNXD => [B,1,D]
    p_relevance_logits = self.rel_SPo1_SPo2(summed_HB_weighted_pb_SP,
                                            summed_HS_weighted_ps_SP_o1,
                                            summed_HS_weighted_ps_SP_o2)
    normal_p_relevance_logits = torch.nn.functional.normalize(p_relevance_logits, p=1)
    p_relevance = torch.softmax(normal_p_relevance_logits, dim=-1)

    # Gold labels, if we have them.
    padded_SP_o1_normal = torch.nn.functional.normalize(padded_SP_o1, p=1, dim=-1)
    padded_SP_o2_normal = torch.nn.functional.normalize(padded_SP_o2, p=1, dim=-1)
    GOLD_summed_HB_weighted_pb_SP = torch.matmul(
        padded_SP_normal.unsqueeze(1).type(dtype=torch.float), padded_HB)     # 1XMXMXD => [B,1,D]
    GOLD_summed_HS_weighted_ps_SP_o1 = torch.matmul(
        padded_SP_o1_normal.unsqueeze(1).type(dtype=torch.float), padded_HS)  # 1XNXNXD => [B,1,D]
    GOLD_summed_HS_weighted_ps_SP_o2 = torch.matmul(
        padded_SP_o2_normal.unsqueeze(1).type(dtype=torch.float), padded_HS)  # 1XNXNXD => [B,1,D]
    p_relevance_logits_gold = self.rel_SPo1_SPo2(
        GOLD_summed_HB_weighted_pb_SP,
        GOLD_summed_HS_weighted_ps_SP_o1,
        GOLD_summed_HS_weighted_ps_SP_o2)
    normal_p_relevance_logits_gold = torch.nn.functional.normalize(
        p_relevance_logits_gold, p=1)
    p_relevance_gold = torch.softmax(normal_p_relevance_logits_gold, dim=-1)

    # **************************** Step 7: relation classification/polarity ****************************
    summed_HB_weighted_pb_TP = torch.matmul(pb_TP.unsqueeze(1), padded_HB)
    summed_HB_weighted_TP_SP = torch.cat(
        (summed_HB_weighted_pb_SP, summed_HB_weighted_pb_TP), dim=-1).squeeze(1)
    p_polarity_logits = self.pol_TP_SP(summed_HB_weighted_TP_SP)
    p_polarity = torch.softmax(p_polarity_logits, dim=-1)
    p_polarity_negative = p_polarity[:, 0]
    p_polarity_positive = p_polarity[:, 1]

    # Gold labels, if we have them.
    summed_HB_weighted_pb_TP_gold = torch.matmul(
        padded_TP_normal.unsqueeze(1), padded_HB)
    summed_HB_weighted_TP_SP_gold = torch.cat(
        (GOLD_summed_HB_weighted_pb_SP, summed_HB_weighted_pb_TP_gold),
        dim=-1).squeeze(1)
    p_polarity_logits_gold = self.pol_TP_SP(summed_HB_weighted_TP_SP_gold)
    p_polarity_gold = torch.softmax(p_polarity_logits_gold, dim=-1)
    p_polarity_negative_gold = p_polarity_gold[:, 0]
    p_polarity_positive_gold = p_polarity_gold[:, 1]

    # **************************** Step 8: Reasoning ****************************
    object1 = p_relevance[:, 0]
    object2 = p_relevance[:, 1]
    p_TP_object1 = p_polarity_positive * object1 + p_polarity_negative * object2
    p_TP_object2 = p_polarity_negative * object1 + p_polarity_positive * object2
    p_TP_objects = torch.stack((p_TP_object1, p_TP_object2), dim=1).squeeze()

    # Gold labels, if we have them.
    object1_gold = p_relevance_gold[:, 0]
    object2_gold = p_relevance_gold[:, 1]
    p_TP_object1_gold = (p_polarity_positive_gold * object1_gold
                         + p_polarity_negative_gold * object2_gold)
    p_TP_object2_gold = (p_polarity_negative_gold * object1_gold
                         + p_polarity_positive_gold * object2_gold)
    p_TP_objects_gold = torch.stack((p_TP_object1_gold, p_TP_object2_gold),
                                    dim=1).squeeze()

    try:
        assert torch.sum(padded_O1) != 0
        assert torch.sum(padded_O2) != 0
        assert torch.sum(padded_SP_o1) != 0
        assert torch.sum(padded_SP_o2) != 0
        assert torch.sum(padded_SP) != 0
        assert torch.sum(padded_TP) != 0
    except AssertionError:
        pass

    loss_o1 = compute_loss(ps_object1, padded_O1, "find_object1") if object1_label is not None else 0.0
    loss_o2 = compute_loss(ps_object2, padded_O2, "find_object2") if object2_label is not None else 0.0
    loss_TP = compute_loss(pb_TP, padded_TP, "find_TP") if TP_Back_label is not None else 0.0
    loss_SP = compute_loss(pb_SP, padded_SP, "find_SP") if SP_Back_label is not None else 0.0
    loss_SP_o1 = compute_loss(ps_SP_object1, padded_SP_o1, "find_SP_object1") if SP_Object1_label is not None else 0.0
    loss_SP_o2 = compute_loss(ps_SP_object2, padded_SP_o2, "find_SP_object2") if SP_Object2_label is not None else 0.0
    loss_rel = compute_loss(normal_p_relevance_logits, sp_relevance, "relevance") if sp_relevance is not None else 0.0
    loss_pol = compute_loss(p_polarity_logits, sp_tp_polarity, "polarity") if sp_tp_polarity is not None else 0.0
    loss_on_TP = compute_loss(p_TP_objects, tp_relevance, "TP_relevance") if tp_relevance is not None else 0.0

    loss2_SP = compute_loss(pb_SP_gold, padded_SP, "find_SP") if SP_Back_label is not None else 0.0
    loss2_SP_o1 = compute_loss(ps_SP_object1_gold, padded_SP_o1, "find_SP_object1") if SP_Object1_label is not None else 0.0
    loss2_SP_o2 = compute_loss(ps_SP_object2_gold, padded_SP_o2, "find_SP_object2") if SP_Object2_label is not None else 0.0
    loss2_rel = compute_loss(normal_p_relevance_logits_gold, sp_relevance, "relevance") if sp_relevance is not None else 0.0
    loss2_pol = compute_loss(p_polarity_logits_gold, sp_tp_polarity, "polarity") if sp_tp_polarity is not None else 0.0
    loss2_on_TP = compute_loss(p_TP_objects, tp_relevance, "TP_relevance") if tp_relevance is not None else 0.0

    out = {
        "object1": loss_o1.tolist(),
        "object2": loss_o2.tolist(),
        "TP": loss_TP.tolist(),
        "SP": loss_SP.tolist(),
        "loss_SP_o1": loss_SP_o1.tolist(),
        "loss_SP_o2": loss_SP_o2.tolist(),
        "loss_rel": loss_rel.tolist(),
        "loss_pol": loss_pol.tolist(),
        "loss_on_TP": loss_on_TP.tolist(),
    }

    # Loss function; play around with the weights.
    loss = (0.05 * loss_o1 + 0.05 * loss_o2 + 0.05 * loss_SP + 0.05 * loss_TP
            + 0.05 * loss_SP_o1 + 0.05 * loss_SP_o2 + 0.2 * loss_pol
            + 0.2 * loss_rel + 0.3 * loss_on_TP)

    # The following part returns the numbers needed to predict the
    # intermediate output of each module.
    object1_ind = find_max_window(ps_object1, padded_mask_HS, offset=len_q + 3)
    object2_ind = find_max_window(ps_object2, padded_mask_HS, offset=len_q + 3)
    TP_ind = find_max_window(pb_TP, padded_mask_HB, offset=bs_seperator_index + 1)
    try:
        SP_object1_ind = find_max_window(ps_SP_object1, padded_mask_HS,
                                         offset=len_q + 3)
        SP_object2_ind = find_max_window(ps_SP_object2, padded_mask_HS,
                                         offset=len_q + 3)
        SP_ind = find_max_window(pb_SP, padded_mask_HB,
                                 offset=bs_seperator_index + 1)
    except Exception:
        SP_object1_ind = [1, 1]
        SP_object2_ind = [1, 1]
        SP_ind = [1, 1]

    predict = {
        "p_o1": ps_object1.tolist(),
        "p_o2": ps_object2.tolist(),
        "p_TP": pb_TP.tolist(),
        "p_SP": pb_SP.tolist(),
        "p_sp_o1": ps_SP_object1.tolist(),
        "p_sp_o2": ps_SP_object2.tolist(),
        "object1": object1_ind,
        "object2": object2_ind,
        "TP": TP_ind,
        "SP": SP_ind,
        "SP_o1": SP_object1_ind,
        "SP_o2": SP_object2_ind,
        "relevance": p_relevance,
        "polarity": p_polarity,
        "tp_relevance": p_TP_objects,
    }

    output = [loss, out, predict]
    return output
def forward(self, document, query=None, label=None, metadata=None) -> Dict[str, Any]:
    generator_dict = self._generator(document)
    mask = util.get_text_field_mask(document)

    assert "a" in generator_dict
    assert "b" in generator_dict

    a, b = generator_dict['a'], generator_dict['b']
    a = a.clamp(1e-6, 100.)  # extreme values could result in NaNs
    b = b.clamp(1e-6, 100.)  # extreme values could result in NaNs

    output_dict = {}

    sampler = HardKuma([a, b], support=[
        self.support[0].to(a.device),
        self.support[1].to(b.device)
    ])

    generator_dict['predicted_rationale'] = (sampler.mean() > 0.5).long() * mask

    if self.prediction_mode or not self.training:
        if self._rationale_extractor is None:
            sample_z = (sampler.mean() > 0.5).long() * mask
        else:
            prob_z = sampler.mean()
            sample_z = self._rationale_extractor.extract_rationale(
                prob_z, metadata, as_one_hot=True)
            output_dict["rationale"] = self._rationale_extractor.extract_rationale(
                prob_z, metadata, as_one_hot=False)
            sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
    else:
        sample_z = sampler.sample()

    sample_z = sample_z * mask

    wordpiece_to_token = document['bert']['wordpiece-to-token']
    wtt0 = torch.where(wordpiece_to_token == -1,
                       torch.tensor([0]).to(wordpiece_to_token.device),
                       wordpiece_to_token)
    wordpiece_sample = util.batched_index_select(sample_z.unsqueeze(-1), wtt0)
    wordpiece_sample[wordpiece_to_token.unsqueeze(-1) == -1] = 1.0

    def scale_embeddings(module, input, output):
        output = output * wordpiece_sample
        return output

    hook = self._encoder._embedding_layer.register_forward_hook(scale_embeddings)
    encoder_dict = self._encoder(
        document=document,
        query=query,
        label=label,
        metadata=metadata,
    )
    hook.remove()

    loss = 0.0

    if label is not None:
        assert "loss" in encoder_dict
        base_loss = F.cross_entropy(encoder_dict["logits"], label)

        # (B,)
        lasso = ((1 - sampler.pdf(0.)) * mask).sum(1)
        lengths = mask.sum(1)
        sparsity_loss = lasso / (lengths + 1e-9) - self._desired_length
        sparsity_loss = sparsity_loss.mean()
        self._loss_tracks["_lasso_loss"](sparsity_loss.item())

        # # moving average of the constraint
        # self.sparsity_ma = self.lagrange_alpha * self.sparsity_ma + (1 - self.lagrange_alpha) * sparsity_loss.item()
        # # update lambda
        # self.lambda0 = self.lambda0 * torch.exp(self.lagrange_lr * self.sparsity_ma.detach())

        self._loss_tracks["_base_loss"](base_loss.item())
        # self._loss_tracks["_fused_lasso_loss"](self.lambda0.item())
        # loss += base_loss + min(max(self.lambda0.detach().item(), 0.01), 1.0) * sparsity_loss
        loss += base_loss + self._reg_loss_lambda * sparsity_loss

    output_dict["probs"] = encoder_dict["probs"]
    output_dict["predicted_labels"] = encoder_dict["predicted_labels"]
    output_dict["loss"] = loss
    output_dict["gold_labels"] = label
    output_dict["metadata"] = metadata
    output_dict["predicted_rationale"] = generator_dict["predicted_rationale"]
    self._loss_tracks["_rat_length"](util.masked_mean(
        generator_dict["predicted_rationale"], mask, dim=-1).mean().item())

    self._call_metrics(output_dict)

    return output_dict
def forward(self,  # type: ignore
            inputs: torch.FloatTensor,
            mask: torch.FloatTensor):
    """
    Parameters
    ----------
    inputs : ``torch.FloatTensor``
        A tensor of shape (batch_size, seq_len, hidden_size)
    mask : ``torch.FloatTensor``
        A tensor of shape (batch_size, seq_len)

    Returns
    -------
    hidden : ``torch.FloatTensor``
        A tensor of shape (batch_size, seq_len, hidden_size)
    """
    batch_size, _, hidden_size = inputs.size()

    # filters for attention
    mask = mask.unsqueeze(-1)

    ############################################################################
    # Init states
    ############################################################################
    # randomly initialize the states
    hidden = torch.rand_like(inputs) - 0.5
    cell = torch.rand_like(inputs) - 0.5
    global_hidden = masked_mean(hidden, mask, dim=1)
    global_cell = masked_mean(cell, mask, dim=1)

    for _ in range(self.num_layers):
        #############################
        # update global node states #
        #############################
        hidden_avg = masked_mean(hidden, mask, dim=1)
        projected_input = self.g_input_linearity(global_hidden)
        projected_hiddens = self.g_hidden_linearity(hidden)
        projected_avg = self.g_avg_linearity(hidden_avg)

        input_gate = torch.sigmoid(self.layer_norms[0](
            projected_input[:, 0 * hidden_size:1 * hidden_size] +
            projected_avg[:, 0 * hidden_size:1 * hidden_size]))
        hidden_gates = torch.sigmoid(self.layer_norms[1](
            projected_input[:, 1 * hidden_size:2 * hidden_size]
            .unsqueeze(1).expand_as(hidden) + projected_hiddens))
        output_gate = torch.sigmoid(self.layer_norms[2](
            projected_input[:, 2 * hidden_size:3 * hidden_size] +
            projected_avg[:, 1 * hidden_size:2 * hidden_size]))

        masked_hidden_gates = hidden_gates.masked_fill((1 - mask).byte(), -1e32)
        all_gates = torch.cat([input_gate.unsqueeze(1), masked_hidden_gates], dim=1)
        gates_normalized = torch.nn.functional.softmax(all_gates, dim=1)
        input_gate_normalized = gates_normalized[:, 0, :]
        hidden_gates_normalized = gates_normalized[:, 1:, :]

        # new global states
        global_cell = (hidden_gates_normalized * cell).sum(1) + \
            global_cell * input_gate_normalized
        global_hidden = output_gate * torch.tanh(global_cell)

        #############################
        # update hidden node states #
        #############################
        # Note: add <bos> and <eos> beforehand in case valid words at the
        # boundaries would otherwise be omitted!
        hidden_l = torch.cat([
            hidden.new_zeros(batch_size, 1, hidden_size), hidden[:, :-1, :]
        ], dim=1)
        hidden_r = torch.cat([
            hidden[:, 1:, :], hidden.new_zeros(batch_size, 1, hidden_size)
        ], dim=1)
        cell_l = torch.cat(
            [cell.new_zeros(batch_size, 1, hidden_size), cell[:, :-1, :]], dim=1)
        cell_r = torch.cat(
            [cell[:, 1:, :], cell.new_zeros(batch_size, 1, hidden_size)], dim=1)

        # concat with neighbors
        contexts = torch.cat([hidden_l, hidden_r], dim=-1)
        projected_contexts = self.h_context_linearity(contexts)
        projected_current = self.h_current_linearity(hidden)
        projected_input = self.h_input_linearity(inputs)
        projected_global = self.h_global_linearity(global_hidden)

        gates = []
        for offset in range(6):
            gates.append(torch.sigmoid(self.layer_norms[offset + 3](
                projected_contexts[..., offset * hidden_size:(offset + 1) * hidden_size] +
                projected_current[..., offset * hidden_size:(offset + 1) * hidden_size] +
                projected_input[..., offset * hidden_size:(offset + 1) * hidden_size] +
                projected_global[..., offset * hidden_size:(offset + 1) * hidden_size]
                .unsqueeze(1).expand_as(inputs))))
        memory_init = torch.tanh(self.layer_norms[-1](
            projected_contexts[..., 6 * hidden_size:7 * hidden_size] +
            projected_current[..., 6 * hidden_size:7 * hidden_size] +
            projected_input[..., 6 * hidden_size:7 * hidden_size] +
            projected_global[..., 6 * hidden_size:7 * hidden_size]
            .unsqueeze(1).expand_as(inputs)))

        # gate: batch x seq_len x hidden_size
        gates_normalized = F.softmax(torch.stack(gates[:-1]), dim=0)
        input_gate = gates_normalized[0, ...]
        left_gate = gates_normalized[1, ...]
        right_gate = gates_normalized[2, ...]
        forget_gate = gates_normalized[3, ...]
        global_gate = gates_normalized[4, ...]
        output_gate = gates[-1]

        cell = left_gate * cell_l + \
            right_gate * cell_r + \
            forget_gate * cell + \
            input_gate * memory_init + \
            global_gate * global_cell.unsqueeze(1).expand_as(global_gate)
        hidden = output_gate * torch.tanh(cell)

        hidden = hidden * mask
        cell = cell * mask

    return hidden
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            store_metrics: bool = True,
            valid_output_mask: torch.LongTensor = None,
            sent_targets: torch.Tensor = None,
            stance: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes that this passage contains the
        answer to the question, and predicts the beginning and ending positions
        of the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. This is one of the things we are trying to
        predict - the beginning position of the answer within the passage. This
        is an `inclusive` token index. If this is given, we will compute a loss
        that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. This is one of the things we are trying to
        predict - the ending position of the answer within the passage. This is
        an `inclusive` token index. If this is given, we will compute a loss
        that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question ID, original passage text,
        and token offsets into the passage for each instance in the batch. We
        use this for computing official metrics using the official SQuAD
        evaluation script. The length of this list should be the batch size,
        and each dictionary should have the keys ``id``, ``original_passage``,
        and ``token_offsets``. If you only want the best span string and don't
        care about official metrics, you can omit the ``id`` key.
    store_metrics : bool
        If true, stores metrics (if applicable) within the model metric
        tracker. If false, returns resulting metrics immediately, without
        updating the model metric tracker.
    valid_output_mask : ``torch.LongTensor``, optional
        The locations for a valid answer. Used to limit the model's output space.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span. Shape is
        ``(batch_size, 2)`` and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we
        also return the string from the original passage that the model thinks
        is the best answer to the question.
""" embedded_question = self._highway_layer( self._text_field_embedder(question)) embedded_passage = self._highway_layer( self._text_field_embedder(passage)) batch_size = embedded_question.size(0) passage_length = embedded_passage.size(1) question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() question_lstm_mask = question_mask if self._mask_lstms else None passage_lstm_mask = passage_mask if self._mask_lstms else None encoded_question = self._dropout( self._phrase_layer(embedded_question, question_lstm_mask)) encoded_passage = self._dropout( self._phrase_layer(embedded_passage, passage_lstm_mask)) encoding_dim = encoded_question.size(-1) # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention( encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = util.masked_softmax( passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum( encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. masked_similarity = util.replace_masked_values( passage_question_similarity, question_mask.unsqueeze(1), -1e7) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max( dim=-1)[0].squeeze(-1) # Shape: (batch_size, passage_length) question_passage_attention = util.masked_softmax( question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) question_passage_vector = util.weighted_sum( encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.unsqueeze( 1).expand(batch_size, passage_length, encoding_dim) # Shape: (batch_size, passage_length, encoding_dim * 4) final_merged_passage = torch.cat([ encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * tiled_question_passage_vector ], dim=-1) # Debate: Conditioning on whose turn it is (A/B) if not self.is_judge: turn_film_params = self._turn_film_gen( stance.to(final_merged_passage).unsqueeze(1)) turn_gammas, turn_betas = torch.split( turn_film_params, self._modeling_layer.get_input_dim(), dim=-1) final_merged_passage_mask = ( final_merged_passage != 0).float() # NOTE: Using heuristic to get mask final_merged_passage = self._film( final_merged_passage, 1. 
+ turn_gammas, turn_betas) * final_merged_passage_mask modeled_passage = self._dropout( self._modeling_layer(final_merged_passage, passage_lstm_mask)) modeling_dim = modeled_passage.size(-1) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)) span_start_input_full = torch.cat( [final_merged_passage, modeled_passage], dim=-1) span_start_input = self._dropout(span_start_input_full) if not self.is_judge: value_head_input = span_start_input_full.detach( ) if self._detach_value_head else span_start_input_full # Shape: (batch_size) tokenwise_values = self._value_head(value_head_input).squeeze(-1) value, value_loc = util.replace_masked_values( tokenwise_values, passage_mask, -1e7).max(-1) # Shape: (batch_size, passage_length) span_start_logits = self._span_start_predictor( span_start_input).squeeze(-1) valid_output_mask = passage_mask if valid_output_mask is None else valid_output_mask # Shape: (batch_size, passage_length) span_start_probs = util.masked_softmax(span_start_logits, valid_output_mask) # Shape: (batch_size, modeling_dim) span_start_representation = util.weighted_sum(modeled_passage, span_start_probs) # Shape: (batch_size, passage_length, modeling_dim) tiled_start_representation = span_start_representation.unsqueeze( 1).expand(batch_size, passage_length, modeling_dim) # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3) span_end_representation = torch.cat([ final_merged_passage, modeled_passage, tiled_start_representation, modeled_passage * tiled_start_representation ], dim=-1) # Shape: (batch_size, passage_length, encoding_dim) encoded_span_end = self._dropout( self._span_end_encoder(span_end_representation, passage_lstm_mask)) # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim) span_end_input = self._dropout( torch.cat([final_merged_passage, encoded_span_end], dim=-1)) span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) span_end_probs = util.masked_softmax(span_end_logits, valid_output_mask) span_start_logits = util.replace_masked_values(span_start_logits, valid_output_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, valid_output_mask, -1e7) best_span = self.get_best_span(span_start_logits, span_end_logits) output_dict = { "passage_question_attention": passage_question_attention, "span_start_logits": span_start_logits, "span_start_probs": span_start_probs, "span_end_logits": span_end_logits, "span_end_probs": span_end_probs, "best_span": best_span, "value": value if not self.is_judge else None, "prob": torch.tensor([ span_start_probs[i, span_start[i]] if span_start[i] < span_start_probs.size(1) else 0. for i in range(batch_size) ]) if self.is_judge else None, # prob(true answer) "prob_dist": span_start_probs, } # Compute the loss for training. if (span_start is not None) and self.is_judge: span_start[span_start >= passage_mask.size( 1)] = -100 # NB: Hacky. Don't add to loss if span not in input loss = nll_loss( util.masked_log_softmax(span_start_logits, valid_output_mask), span_start.squeeze(-1)) if store_metrics: self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) span_end[span_end >= passage_mask.size( 1)] = -100 # NB: Hacky. 
Don't add to loss if span not in input loss += nll_loss( util.masked_log_softmax(span_end_logits, valid_output_mask), span_end.squeeze(-1)) if store_metrics: self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) output_dict["loss"] = loss elif not self.is_judge: # Debate SL if self.reward_method == 'sl': # sent_targets should be a vector of target indices output_dict["loss"] = nll_loss( util.masked_log_softmax(span_start_logits, valid_output_mask), sent_targets.squeeze(-1)) if store_metrics: self._span_start_accuracy(span_start_logits, sent_targets.squeeze(-1)) elif self.reward_method.startswith('sl-sents'): # sent_targets should be a matrix of target values (non-zero only in EOS indices) sent_targets = util.replace_masked_values( sent_targets, valid_output_mask, -1e7) output_dict["loss"] = util.masked_mean( ((span_start_logits - sent_targets)**2), valid_output_mask, 1) if store_metrics: self._span_start_accuracy(span_start_logits, sent_targets.max(-1)[1]) # Compute the EM and F1 on SQuAD and add the tokenized input to the output. batch_ems = [] batch_f1s = [] if metadata is not None: output_dict['best_span_str'] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) passage_str = metadata[i]['original_passage'] offsets = metadata[i]['token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_span_string = passage_str[start_offset:end_offset] output_dict['best_span_str'].append(best_span_string) answer_texts = metadata[i].get('answer_texts', []) if answer_texts: self._squad_metrics(best_span_string, answer_texts) sample_squad_metrics = SquadEmAndF1() sample_squad_metrics(best_span_string, answer_texts) sample_em, sample_f1 = sample_squad_metrics.get_metric( reset=True) batch_ems.append(sample_em) batch_f1s.append(sample_f1) output_dict['question_tokens'] = question_tokens output_dict['passage_tokens'] = passage_tokens output_dict['em'] = torch.tensor(batch_ems) output_dict['f1'] = torch.tensor(batch_f1s) return output_dict
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)

    # shape (batch_size, num_spans, 1)
    # These span widths are off by 1, because the span ends are `inclusive`.
    span_widths = span_ends - span_starts

    # We need to know the maximum span width so we can
    # generate indices to extract the spans from the sequence tensor.
    # These indices will then get masked below, such that if the length
    # of a given span is smaller than the max, the rest of the values
    # are masked.
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(
        max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    #
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths
    # rather than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(
        span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices,
                                                flat_span_indices)

    # Shape: (batch_size, num_spans, embedding_dim)
    # span_embeddings = util.masked_max(span_embeddings, span_mask.unsqueeze(-1), dim=2)
    span_embeddings = util.masked_mean(span_embeddings, span_mask.unsqueeze(-1),
                                       dim=2)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        span_width_embeddings = self._span_width_embedding(span_widths.squeeze(-1))
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    return span_embeddings
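# A small worked trace (toy numbers) of the span-mask construction above: indices are
# generated backwards from each inclusive span end and masked where they overshoot the
# span start or fall below zero.
import torch

span_starts = torch.tensor([[[1]]])            # (batch=1, num_spans=1, 1)
span_ends = torch.tensor([[[3]]])              # inclusive end
span_widths = span_ends - span_starts          # 2 (off by one; ends are inclusive)
max_width = span_widths.max().item() + 1       # 3
rng = torch.arange(max_width).view(1, 1, -1)   # [[[0, 1, 2]]]
span_mask = (rng <= span_widths).float()       # [[[1., 1., 1.]]]
raw = span_ends - rng                          # [[[3, 2, 1]]] -> token positions 3, 2, 1
span_mask = span_mask * (raw >= 0).float()     # still all ones here; clips shorter spans
span_indices = torch.relu(raw.float()).long()  # negatives (if any) clamped to 0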
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            choices_list: Dict[str, torch.LongTensor],
            choice_kb: Dict[str, torch.LongTensor],
            answer_text: Dict[str, torch.LongTensor],
            fact: Dict[str, torch.LongTensor],
            answer_spans: torch.IntTensor,
            relations: torch.IntTensor = None,
            relation_label: torch.IntTensor = None,
            answer_id: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # B X C X Ct X D
    embedded_choice, choice_mask = get_embedding(choices_list, 1,
                                                 self._text_field_embedder,
                                                 self._encoder,
                                                 self._var_dropout)
    # B X C X D
    # agg_choice, agg_choice_mask = get_agg_rep(embedded_choice, choice_mask, 1,
    #                                           self._encoder, self._aggregate)
    num_choices = embedded_choice.size()[1]
    batch_size = embedded_choice.size()[0]

    # B X Qt X D
    embedded_question, question_mask = get_embedding(
        question, 0, self._text_field_embedder, self._encoder,
        self._var_dropout)
    # B X D
    agg_question, agg_question_mask = get_agg_rep(embedded_question,
                                                  question_mask, 0,
                                                  self._encoder,
                                                  self._aggregate)
    # B X Ft X D
    embedded_fact, fact_mask = get_embedding(fact, 0,
                                             self._text_field_embedder,
                                             self._encoder, self._var_dropout)
    # B X D
    agg_fact, agg_fact_mask = get_agg_rep(embedded_fact, fact_mask, 0,
                                          self._encoder, self._aggregate)

    # ==============================================
    # Interaction between fact and question
    # ==============================================
    # B x Ft x Qt
    fact_question_att = self._attention(embedded_fact, embedded_question)
    fact_question_mask = self.add_dimension(question_mask, 1,
                                            fact_question_att.shape[1])
    masked_fact_question_att = replace_masked_values(fact_question_att,
                                                     fact_question_mask, -1e7)
    # B X Ft
    fact_question_att_max = masked_fact_question_att.max(
        dim=-1)[0].squeeze(-1)
    fact_question_att_softmax = masked_softmax(fact_question_att_max,
                                               fact_mask)
    # B X D
    fact_question_att_rep = weighted_sum(embedded_fact,
                                         fact_question_att_softmax)
    # B*C X D
    cmerged_fact_question_att_rep = self.merge_dimensions(
        self.add_dimension(fact_question_att_rep, 1, num_choices))

    # ==============================================
    # Interaction between fact and answer choices
    # ==============================================
    # B*C X Ft X D
    cmerged_embedded_fact = self.merge_dimensions(
        self.add_dimension(embedded_fact, 1, num_choices))
    cmerged_fact_mask = self.merge_dimensions(
        self.add_dimension(fact_mask, 1, num_choices))
    # B*C X Ct X D
    cmerged_embedded_choice = self.merge_dimensions(embedded_choice)
    cmerged_choice_mask = self.merge_dimensions(choice_mask)
    # B*C X Ft X Ct
    cmerged_fact_choice_att = self._attention(cmerged_embedded_fact,
                                              cmerged_embedded_choice)
    cmerged_fact_choice_mask = self.add_dimension(
        cmerged_choice_mask, 1, cmerged_fact_choice_att.shape[1])
    masked_cmerged_fact_choice_att = replace_masked_values(
        cmerged_fact_choice_att, cmerged_fact_choice_mask, -1e7)
    # B*C X Ft
    cmerged_fact_choice_att_max = masked_cmerged_fact_choice_att.max(
        dim=-1)[0].squeeze(-1)
    cmerged_fact_choice_att_softmax = masked_softmax(
        cmerged_fact_choice_att_max, cmerged_fact_mask)
    # B*C X D
    cmerged_fact_choice_att_rep = weighted_sum(
        cmerged_embedded_fact, cmerged_fact_choice_att_softmax)

    # ==============================================
    # Combined fact + choice + question + span rep
    # ==============================================
    if not self._ignore_spans and not self._ignore_ann:
        # B X A
        per_span_mask = (answer_spans >= 0).long()[:, :, 0]
        # B X A X D
        per_span_rep = self._span_extractor(embedded_fact, answer_spans,
                                            fact_mask, per_span_mask)
        # expanded_span_mask = per_span_mask.unsqueeze(-1).expand_as(per_span_rep)
        # B X D
        answer_span_rep = per_span_rep[:, 0, :]
        # B*C X D
        cmerged_span_rep = self.merge_dimensions(
            self.add_dimension(answer_span_rep, 1, num_choices))
        fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                    cmerged_fact_question_att_rep +
                                    cmerged_span_rep) / 3
    else:
        fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                    cmerged_fact_question_att_rep) / 2

    # B*C X D
    cmerged_fact_rep = masked_mean(
        cmerged_embedded_fact,
        cmerged_fact_mask.unsqueeze(-1).expand_as(cmerged_embedded_fact), 1)
    # B*C X D
    fact_question_combined_rep = combine_tensors(
        self._coverage_combination,
        [fact_choice_question_rep, cmerged_fact_rep])
    # B X C X D
    new_size = [batch_size, num_choices, -1]
    fact_question_combined_rep = fact_question_combined_rep.contiguous().view(
        *new_size)
    # B X C
    coverage_score = self._coverage_ff(fact_question_combined_rep).squeeze(-1)
    logger.info("coverage_score" + str(coverage_score.shape))

    # ==============================================
    # Interaction between spans+choices and KB
    # ==============================================
    # B X C X K X Kt x D
    embedded_choice_kb, choice_kb_mask = get_embedding(
        choice_kb, 2, self._text_field_embedder, self._encoder,
        self._var_dropout)
    num_kb = embedded_choice_kb.size()[2]
    # B X A X At X D
    embedded_answer, answer_mask = get_embedding(answer_text, 1,
                                                 self._text_field_embedder,
                                                 self._encoder,
                                                 self._var_dropout)
    # B X At X D
    embedded_answer = embedded_answer[:, 0, :, :]
    answer_mask = answer_mask[:, 0, :]

    # B*C*K X Kt X D
    ckmerged_embedded_choice_kb = self.merge_dimensions(
        self.merge_dimensions(embedded_choice_kb))
    ckmerged_choice_kb_mask = self.merge_dimensions(
        self.merge_dimensions(choice_kb_mask))
    # B*C X At X D
    cmerged_embedded_answer = self.merge_dimensions(
        self.add_dimension(embedded_answer, 1, num_choices))
    cmerged_answer_mask = self.merge_dimensions(
        self.add_dimension(answer_mask, 1, num_choices))
    # B*C*K X At X D
    ckmerged_embedded_answer = self.merge_dimensions(
        self.add_dimension(cmerged_embedded_answer, 1, num_kb))
    ckmerged_answer_mask = self.merge_dimensions(
        self.add_dimension(cmerged_answer_mask, 1, num_kb))
    # B*C*K X Ct X D
    ckmerged_embedded_choice = self.merge_dimensions(
        self.add_dimension(cmerged_embedded_choice, 1, num_kb))
    ckmerged_choice_mask = self.merge_dimensions(
        self.add_dimension(cmerged_choice_mask, 1, num_kb))
    logger.info("ckmerged_choice_mask" + str(ckmerged_choice_mask.shape))

    # == KB rep based on answer span ==
    if self._ignore_ann:
        # B*C*K X Ft X D
        ckmerged_embedded_fact = self.merge_dimensions(
            self.add_dimension(cmerged_embedded_fact, 1, num_kb))
        ckmerged_fact_mask = self.merge_dimensions(
            self.add_dimension(cmerged_fact_mask, 1, num_kb))
        # B*C*K X Kt x Ft
        ckmerged_kb_fact_att = self._attention(ckmerged_embedded_choice_kb,
                                               ckmerged_embedded_fact)
        ckmerged_kb_fact_mask = self.add_dimension(
            ckmerged_fact_mask, 1, ckmerged_kb_fact_att.shape[1])
        masked_ckmerged_kb_fact_att = replace_masked_values(
            ckmerged_kb_fact_att, ckmerged_kb_fact_mask, -1e7)
        # B*C*K X Kt
        ckmerged_kb_answer_att_max = masked_ckmerged_kb_fact_att.max(
            dim=-1)[0].squeeze(-1)
    else:
        # B*C*K X Kt x At
        ckmerged_kb_answer_att = self._attention(ckmerged_embedded_choice_kb,
                                                 ckmerged_embedded_answer)
        ckmerged_kb_answer_mask = self.add_dimension(
            ckmerged_answer_mask, 1, ckmerged_kb_answer_att.shape[1])
        masked_ckmerged_kb_answer_att = replace_masked_values(
            ckmerged_kb_answer_att, ckmerged_kb_answer_mask, -1e7)
        # B*C*K X Kt
        ckmerged_kb_answer_att_max = masked_ckmerged_kb_answer_att.max(
            dim=-1)[0].squeeze(-1)

    ckmerged_kb_answer_att_softmax = masked_softmax(
        ckmerged_kb_answer_att_max, ckmerged_choice_kb_mask)
    # B*C*K X D
    kb_answer_att_rep = weighted_sum(ckmerged_embedded_choice_kb,
                                     ckmerged_kb_answer_att_softmax)

    # == KB rep based on answer choice ==
    # B*C*K X Kt x Ct
    ckmerged_kb_choice_att = self._attention(ckmerged_embedded_choice_kb,
                                             ckmerged_embedded_choice)
    ckmerged_kb_choice_mask = self.add_dimension(
        ckmerged_choice_mask, 1, ckmerged_kb_choice_att.shape[1])
    masked_ckmerged_kb_choice_att = replace_masked_values(
        ckmerged_kb_choice_att, ckmerged_kb_choice_mask, -1e7)
    # B*C*K X Kt
    ckmerged_kb_choice_att_max = masked_ckmerged_kb_choice_att.max(
        dim=-1)[0].squeeze(-1)
    ckmerged_kb_choice_att_softmax = masked_softmax(
        ckmerged_kb_choice_att_max, ckmerged_choice_kb_mask)
    # B*C*K X D
    kb_choice_att_rep = weighted_sum(ckmerged_embedded_choice_kb,
                                     ckmerged_kb_choice_att_softmax)

    # B*C*K X D
    answer_choice_kb_combined_rep = combine_tensors(
        self._answer_choice_combination,
        [kb_answer_att_rep, kb_choice_att_rep])
    logger.info("answer_choice_kb_combined_rep" +
                str(answer_choice_kb_combined_rep.shape))

    # ==============================================
    # Relation Predictions
    # ==============================================
    # B*C*K x R
    choice_kb_relation_rep = self._relation_predictor(
        answer_choice_kb_combined_rep)
    new_choice_kb_size = [batch_size * num_choices, num_kb, -1]
    # B*C*K
    merged_choice_kb_mask = (torch.sum(ckmerged_choice_kb_mask, dim=-1) >
                             0).float()

    if self._num_relations and not self._ignore_ann:
        if self._relation_projector:
            choice_kb_relation_pred = self._relation_projector(
                choice_kb_relation_rep)
        else:
            choice_kb_relation_pred = choice_kb_relation_rep
        # Aggregate the predictions
        # B*C*K
        choice_kb_relation_mask = self.add_dimension(
            merged_choice_kb_mask, -1, choice_kb_relation_pred.shape[-1])
        choice_kb_relation_pred_masked = replace_masked_values(
            choice_kb_relation_pred, choice_kb_relation_mask, -1e7)
        # B*C X K X R
        relation_pred_perkb = choice_kb_relation_pred_masked.contiguous(
        ).view(*new_choice_kb_size)
        # B*C X R
        relation_pred_max = relation_pred_perkb.max(dim=1)[0].squeeze(1)
        # B X C X R
        choice_relation_size = [batch_size, num_choices, -1]
        relation_label_logits = relation_pred_max.contiguous().view(
            *choice_relation_size)
        relation_label_probs = softmax(relation_label_logits, dim=-1)
        # B X C
        add_relation_predictions(self.vocab, relation_label_probs, metadata)
        # B X C X K X R
        choice_kb_relation_size = [batch_size, num_choices, num_kb, -1]
        relation_predictions = choice_kb_relation_rep.contiguous().view(
            *choice_kb_relation_size)
        add_tuple_predictions(relation_predictions, metadata)
        logger.info("relation_predictions" + str(relation_predictions.shape))
    else:
        relation_label_logits = None
        relation_label_probs = None

    if not self._ignore_relns:
        # B X C X D
        expanded_size = [batch_size, num_choices, -1]
        # Aggregate the relation representation
        if self._relation_projector or self._num_relations == 0 or self._ignore_ann:
            # B*C X K X D
            relation_rep_perkb = choice_kb_relation_rep.contiguous().view(
                *new_choice_kb_size)
            # B*C*K X D
            merged_relation_rep_mask = self.add_dimension(
                merged_choice_kb_mask, -1, relation_rep_perkb.shape[-1])
            # B*C X K X D
            relation_rep_perkb_mask = merged_relation_rep_mask.contiguous(
            ).view(*relation_rep_perkb.size())
            # B*C X D
            agg_relation_rep = masked_mean(relation_rep_perkb,
                                           relation_rep_perkb_mask, dim=1)
            # B X C X D
            expanded_relation_rep = agg_relation_rep.contiguous().view(
                *expanded_size)
        else:
            expanded_relation_rep = relation_label_logits

        expanded_question_rep = agg_question.unsqueeze(1).expand(
            expanded_size)
        expanded_fact_rep = agg_fact.unsqueeze(1).expand(expanded_size)
        question_fact_rep = combine_tensors(
            self._combination, [expanded_question_rep, expanded_fact_rep])
        relation_score_rep = torch.cat(
            [question_fact_rep, expanded_relation_rep], dim=-1)
        relation_score = self._reln_ff(relation_score_rep).squeeze(-1)
        choice_label_logits = (coverage_score + relation_score) / 2
    else:
        choice_label_logits = coverage_score

    logger.info("choice_label_logits" + str(choice_label_logits.shape))
    choice_label_probs = softmax(choice_label_logits, dim=-1)
    output_dict = {
        "label_logits": choice_label_logits,
        "label_probs": choice_label_probs,
        "metadata": metadata
    }
    if relation_label_logits is not None:
        output_dict["relation_label_logits"] = relation_label_logits
        output_dict["relation_label_probs"] = relation_label_probs
    if answer_id is not None or relation_label is not None:
        self.compute_loss_and_accuracy(answer_id, relation_label,
                                       relation_label_logits,
                                       choice_label_logits, output_dict)
    return output_dict
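# The model above repeatedly expands a tensor along a new choice (or KB)
# axis and folds that axis into the batch. Below is a hypothetical sketch of
# what `add_dimension` / `merge_dimensions` plausibly do, assuming expansion
# by repetition; the real helpers may differ.
import torch

def add_dimension_sketch(tensor: torch.Tensor, dim: int,
                         size: int) -> torch.Tensor:
    # Insert a new axis of length `size` by repetition, e.g. B x D -> B x C x D.
    return tensor.unsqueeze(dim).expand(
        *tensor.shape[:dim], size, *tensor.shape[dim:])

def merge_dimensions_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Fold the first two axes together, e.g. B x C x D -> B*C x D.
    return tensor.reshape(-1, *tensor.shape[2:])

fact_rep = torch.randn(4, 16)  # B x D
per_choice = merge_dimensions_sketch(add_dimension_sketch(fact_rep, 1, 5))
assert per_choice.shape == (20, 16)  # B*C x D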
def forward(self,
            context_1: torch.Tensor,
            mask_1: torch.Tensor,
            context_2: torch.Tensor,
            mask_2: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Given the forward (or backward) representations of sentence1 and sentence2,
    apply four bilateral matching functions between them in one direction.

    Parameters
    ----------
    context_1 : ``torch.Tensor``
        Tensor of shape (batch_size, seq_len1, hidden_dim) representing the
        encoding of the first sentence.
    mask_1 : ``torch.Tensor``
        Binary Tensor of shape (batch_size, seq_len1), indicating which
        positions in the first sentence are padding (0) and which are not (1).
    context_2 : ``torch.Tensor``
        Tensor of shape (batch_size, seq_len2, hidden_dim) representing the
        encoding of the second sentence.
    mask_2 : ``torch.Tensor``
        Binary Tensor of shape (batch_size, seq_len2), indicating which
        positions in the second sentence are padding (0) and which are not (1).

    Returns
    -------
    A tuple of matching vectors for the two sentences. Each of which is a list
    of matching vectors of shape (batch, seq_len, num_perspectives or 1).
    """
    assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
    assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

    # (batch,)
    len_1 = get_lengths_from_binary_sequence_mask(mask_1)
    len_2 = get_lengths_from_binary_sequence_mask(mask_2)

    # (batch, seq_len*)
    mask_1, mask_2 = mask_1.float(), mask_2.float()

    # explicitly set masked weights to zero
    # (batch_size, seq_len*, hidden_dim)
    context_1 = context_1 * mask_1.unsqueeze(-1)
    context_2 = context_2 * mask_2.unsqueeze(-1)

    # array to keep the matching vectors for the two sentences
    matching_vector_1: List[torch.Tensor] = []
    matching_vector_2: List[torch.Tensor] = []

    # Step 0. unweighted cosine
    # First calculate the cosine similarities between each forward
    # (or backward) contextual embedding and every forward (or backward)
    # contextual embedding of the other sentence.
    # (batch, seq_len1, seq_len2)
    cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                     context_2.unsqueeze(-3), dim=3)

    # (batch, seq_len*, 1)
    cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2),
                              dim=2, keepdim=True)
    cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2),
                                dim=2, keepdim=True)
    cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1),
                              mask_1.unsqueeze(-2), dim=2, keepdim=True)
    cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1),
                                mask_1.unsqueeze(-2), dim=2, keepdim=True)

    matching_vector_1.extend([cosine_max_1, cosine_mean_1])
    matching_vector_2.extend([cosine_max_2, cosine_mean_2])

    # Step 1. Full-Matching
    # Each time step of forward (or backward) contextual embedding of one
    # sentence is compared with the last time step of the forward (or
    # backward) contextual embedding of the other sentence.
    if self.with_full_match:
        # (batch, 1, hidden_dim)
        if self.is_forward:
            last_position_1 = (len_1 - 1).clamp(min=0)
            last_position_1 = last_position_1.view(-1, 1, 1).expand(
                -1, 1, self.hidden_dim)
            last_position_2 = (len_2 - 1).clamp(min=0)
            last_position_2 = last_position_2.view(-1, 1, 1).expand(
                -1, 1, self.hidden_dim)
            context_1_last = context_1.gather(1, last_position_1)
            context_2_last = context_2.gather(1, last_position_2)
        else:
            context_1_last = context_1[:, 0:1, :]
            context_2_last = context_2[:, 0:1, :]

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_full = multi_perspective_match(
            context_1, context_2_last, self.full_match_weights)
        matching_vector_2_full = multi_perspective_match(
            context_2, context_1_last, self.full_match_weights_reversed)
        matching_vector_1.extend(matching_vector_1_full)
        matching_vector_2.extend(matching_vector_2_full)

    # Step 2. Maxpooling-Matching
    # Each time step of forward (or backward) contextual embedding of one
    # sentence is compared with every time step of the forward (or backward)
    # contextual embedding of the other sentence, and only the max value of
    # each dimension is retained.
    if self.with_maxpool_match:
        # (batch, seq_len1, seq_len2, num_perspectives)
        matching_vector_max = multi_perspective_match_pairwise(
            context_1, context_2, self.maxpool_match_weights)

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_max = masked_max(
            matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
        matching_vector_1_mean = masked_mean(
            matching_vector_max, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
        matching_vector_2_max = masked_max(
            matching_vector_max.permute(0, 2, 1, 3),
            mask_1.unsqueeze(-2).unsqueeze(-1), dim=2)
        matching_vector_2_mean = masked_mean(
            matching_vector_max.permute(0, 2, 1, 3),
            mask_1.unsqueeze(-2).unsqueeze(-1), dim=2)

        matching_vector_1.extend(
            [matching_vector_1_max, matching_vector_1_mean])
        matching_vector_2.extend(
            [matching_vector_2_max, matching_vector_2_mean])

    # Step 3. Attentive-Matching
    # Each forward (or backward) similarity is taken as the weight
    # of the forward (or backward) contextual embedding, and calculate an
    # attentive vector for the sentence by weighted summing all its
    # contextual embeddings.
    # Finally match each forward (or backward) contextual embedding
    # with its corresponding attentive vector.

    # (batch, seq_len1, seq_len2, hidden_dim)
    att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)
    # (batch, seq_len1, seq_len2, hidden_dim)
    att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

    if self.with_attentive_match:
        # (batch, seq_len*, hidden_dim)
        att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
        att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_att_mean = multi_perspective_match(
            context_1, att_mean_2, self.attentive_match_weights)
        matching_vector_2_att_mean = multi_perspective_match(
            context_2, att_mean_1, self.attentive_match_weights_reversed)
        matching_vector_1.extend(matching_vector_1_att_mean)
        matching_vector_2.extend(matching_vector_2_att_mean)

    # Step 4. Max-Attentive-Matching
    # Pick the contextual embeddings with the highest cosine similarity as
    # the attentive vector, and match each forward (or backward) contextual
    # embedding with its corresponding attentive vector.
    if self.with_max_attentive_match:
        # (batch, seq_len*, hidden_dim)
        att_max_2 = masked_max(att_2, mask_2.unsqueeze(-2).unsqueeze(-1),
                               dim=2)
        att_max_1 = masked_max(att_1.permute(0, 2, 1, 3),
                               mask_1.unsqueeze(-2).unsqueeze(-1), dim=2)

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_att_max = multi_perspective_match(
            context_1, att_max_2, self.max_attentive_match_weights)
        matching_vector_2_att_max = multi_perspective_match(
            context_2, att_max_1, self.max_attentive_match_weights_reversed)
        matching_vector_1.extend(matching_vector_1_att_max)
        matching_vector_2.extend(matching_vector_2_att_max)

    return matching_vector_1, matching_vector_2
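# Step 0 above in isolation, with the masked max/mean written out in plain
# torch. Shapes and the -1e7 fill value follow the code above; everything
# else here is a self-contained toy.
import torch
import torch.nn.functional as F

batch, len1, len2, dim = 2, 3, 4, 8
context_1 = torch.randn(batch, len1, dim)
context_2 = torch.randn(batch, len2, dim)
mask_2 = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).float()

# (batch, len1, len2)
cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2),
                                 context_2.unsqueeze(-3), dim=3)

masked = cosine_sim.masked_fill(mask_2.unsqueeze(1) == 0, -1e7)
cosine_max_1 = masked.max(dim=2, keepdim=True)[0]           # masked max
sums = (cosine_sim * mask_2.unsqueeze(1)).sum(dim=2, keepdim=True)
cosine_mean_1 = sums / mask_2.sum(-1).view(batch, 1, 1)     # masked mean
assert cosine_max_1.shape == cosine_mean_1.shape == (batch, len1, 1)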
def forward(
        self,
        messages: Dict[str, torch.Tensor],
        # (batch_size, n_turns, n_facts, n_words)
        facts: Dict[str, torch.Tensor],
        # (batch_size, n_turns)
        senders: torch.Tensor,
        # (batch_size, n_turns, n_acts)
        dialog_acts: torch.Tensor,
        # (batch_size, n_turns)
        dialog_acts_mask: torch.Tensor,
        # (batch_size, n_entities)
        known_entities: Dict[str, torch.Tensor],
        # (batch_size, 1)
        focus_entity: Dict[str, torch.Tensor],
        # (batch_size, n_turns, n_facts)
        fact_labels: Optional[torch.Tensor] = None,
        # (batch_size, n_turns, 2)
        likes: Optional[torch.Tensor] = None,
        metadata: Optional[Dict] = None,
):
    output = {}
    # Take care of the easy stuff first

    # (batch_size, n_entities)
    known_entities_mask = get_text_field_mask(known_entities)
    # (batch_size, n_turns, sender_emb_size)
    sender_emb = self._sender_emb(senders)

    known_emb = self._mention_embedder(known_entities)
    # TODO: This could instead of averaged, be attended
    known_vec = self._known_net(
        masked_mean(known_emb, known_entities_mask.unsqueeze(-1), dim=1))
    # There is always exactly one entity
    focus_emb = self._focus_net(
        self._mention_embedder(focus_entity)[:, 0, :])

    if self._use_bert:
        # (batch_size, n_turns, n_words, emb_dim)
        context, utter_mask = self._bert_encoder(messages)
        context = self._dropout(context)
    else:
        # (batch_size, n_turns)
        # This is the mask since not all dialogs have same number of turns
        utter_mask = get_text_field_mask(messages)
        # (batch_size, n_turns, n_words)
        # Mask since not all utterances have same number of words
        # Wrapping dim skips over n_messages dim
        text_mask = get_text_field_mask(messages, num_wrapping_dims=1)
        # (batch_size, n_turns, n_words, emb_dim)
        embed = self._dropout(self._utter_embedder(messages))
        # (batch_size, n_turns, hidden_dim)
        context = self._dist_utter_context(embed, text_mask)

    # (batch_size, n_turns, act_emb_size)
    act_emb = self._act_embedder(dialog_acts.float())
    act_emb = self._clamp_dialog_acts(act_emb)

    # (batch_size, n_turns, hidden_dim + known_dim + focus_dim + sender_dim + act_dim)
    n_turns = context.shape[1]
    full_context = torch.cat(
        (
            context,
            sender_emb,
            act_emb,
            focus_emb[:, None, :].repeat_interleave(n_turns, 1),
            known_vec[:, None, :].repeat_interleave(n_turns, 1),
        ),
        dim=-1,
    )

    # (batch_size, n_turns, hidden_dim)
    # This assumes dialog_context does not peek into future
    dialog_context = self._dialog_context(full_context, utter_mask)

    # shift context one right, pad with zeros at front
    # This makes it so that utter_t is paired with context_t-1
    # which is what we want
    # This is useful in a few different places, so compute it here once
    shape = dialog_context.shape
    shifted_context = torch.cat(
        (
            dialog_context.new_zeros([shape[0], 1, shape[2]]),
            dialog_context[:, :-1, :],
        ),
        dim=1,
    )

    has_loss = False

    if self._disable_dialog_acts:
        da_loss = 0
        policy_loss = 0
    else:
        # Dialog act per utter loss
        has_loss = True
        da_loss = self._compute_da_loss(
            output,
            context,
            shifted_context,
            utter_mask,
            dialog_acts,
            dialog_acts_mask,
        )
        # Policy loss
        policy_loss = self._compute_policy_loss(output, shifted_context,
                                                utter_mask, dialog_acts,
                                                dialog_acts_mask)

    if self._disable_facts:
        # If facts are disabled, don't output anything related to them
        fact_loss = 0
    else:
        if self._use_bert:
            # (batch_size, n_turns, n_words, emb_dim)
            fact_repr, fact_mask = self._bert_encoder(facts)
            fact_repr = self._dropout(fact_repr)
            fact_mask[:, ::2] = 0
        else:
            # (batch_size, n_turns, n_facts)
            # Wrapping dim skips over n_messages
            fact_mask = get_text_field_mask(facts, num_wrapping_dims=1)
            # In addition to masking padded facts, also explicitly mask
            # user turns just in case
            fact_mask[:, ::2] = 0

            # (batch_size, n_turns, n_facts, n_words)
            # Wrapping dim skips over n_turns and n_facts
            fact_text_mask = get_text_field_mask(facts, num_wrapping_dims=2)
            # (batch_size, n_turns, n_facts, n_words, emb_dim)
            # Share encoder with utter encoder
            # Again, stupid dimensions
            fact_embed = self._dropout(self._utter_embedder(facts))
            shape = fact_embed.shape
            word_dim = shape[-2]
            emb_dim = shape[-1]
            reshaped_facts = fact_embed.view(-1, word_dim, emb_dim)
            reshaped_fact_text_mask = fact_text_mask.view(-1, word_dim)
            reshaped_fact_repr = self._utter_context(
                reshaped_facts, reshaped_fact_text_mask)
            # No more emb dimension or word/seq dim
            fact_repr = reshaped_fact_repr.view(shape[:-2] + (-1, ))

        fact_logits = self._fact_ranker(shifted_context, fact_repr)
        output["fact_logits"] = fact_logits
        if fact_labels is not None:
            has_loss = True
            fact_loss = self._compute_fact_loss(fact_logits, fact_labels,
                                                fact_mask)
            self._fact_loss_metric(fact_loss.item())
            self._fact_mrr(fact_logits, fact_labels, mask=fact_mask)
        else:
            fact_loss = 0

    if self._disable_likes:
        like_loss = 0
    else:
        has_loss = True
        # (batch_size, n_turns, 2)
        like_logits = self._like_classifier(dialog_context)
        output["like_logits"] = like_logits

        # There are several masks here to get the loss/metrics correct
        # - utter_mask: mask out positions that do not have an utterance
        # - user_mask: mask out positions that have a user utterance,
        #   since user turns are never liked
        # Using new_ones() preserves the type of the tensor
        user_mask = utter_mask.new_ones(utter_mask.shape)
        # Since the user is always even, this masks out user positions
        user_mask[:, ::2] = 0
        final_mask = utter_mask * user_mask
        if likes is not None:
            has_loss = True
            # Only mask the labels once we know they exist.
            masked_likes = likes * final_mask
            like_loss = sequence_cross_entropy_with_logits(
                like_logits, masked_likes, final_mask)
            self._like_accuracy(like_logits, masked_likes, final_mask)
            self._like_loss_metric(like_loss.item())
        else:
            like_loss = 0

    if has_loss:
        output["loss"] = (self._fact_loss_weight * fact_loss + like_loss +
                          da_loss + policy_loss)

    return output
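# The shift-right trick above, in isolation: turn t of the dialog is paired
# with the encoded context up to turn t-1, with zeros padding the front.
import torch

dialog_context = torch.arange(1., 7.).view(1, 3, 2)  # (batch=1, n_turns=3, h=2)
shape = dialog_context.shape
shifted_context = torch.cat(
    (
        dialog_context.new_zeros([shape[0], 1, shape[2]]),  # zeros for turn 0
        dialog_context[:, :-1, :],                          # drop the last turn
    ),
    dim=1,
)
assert shifted_context[0].tolist() == [[0., 0.], [1., 2.], [3., 4.]]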
def forward(self, document, kept_tokens, rationale=None, label=None,
            metadata=None) -> Dict[str, Any]:
    generator_dict = self._generator(document, rationale)
    mask = util.get_text_field_mask(document)

    assert "prob_z" in generator_dict
    prob_z = generator_dict["prob_z"]
    assert len(prob_z.shape) == 2

    prob_z = kept_tokens.float() + prob_z * (1 - kept_tokens)
    sampler = D.bernoulli.Bernoulli(probs=prob_z)
    sample_z = sampler.sample() * mask.float()

    encoder_dict = self._encoder(sample_z=sample_z, label=label,
                                 metadata=metadata)

    loss = 0.0
    if label is not None:
        assert "loss" in encoder_dict
        loss_sample = encoder_dict["loss"]  # (B,)
        loss += loss_sample.mean()

        lasso_loss = util.masked_mean(sample_z, mask, dim=-1)  # (B,)
        masked_sum = mask[:, :-1].sum(-1).clamp(1e-5)
        diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
        masked_diff = (diff * mask[:, :-1]).sum(-1)
        fused_lasso_loss = masked_diff / masked_sum

        self._loss_tracks["lasso_loss"](lasso_loss.mean().item())
        self._loss_tracks["fused_lasso_loss"](
            fused_lasso_loss.mean().item())
        self._loss_tracks["base_loss"](loss_sample.mean().item())

        log_prob_z = torch.log(
            1 + torch.exp(sampler.log_prob(sample_z)))  # (B, L)
        log_prob_z_sum = (mask * log_prob_z).mean(-1)  # (B,)

        generator_loss = (
            loss_sample.detach() + lasso_loss * self._reg_loss_lambda +
            fused_lasso_loss *
            (self._reg_loss_mu * self._reg_loss_lambda)) * log_prob_z_sum
        loss += self._reinforce_loss_weight * generator_loss.mean()

    output_dict = generator_dict
    loss += self._rationale_supervision_loss_weight * generator_dict.get(
        "rationale_supervision_loss", 0.0)

    output_dict["logits"] = encoder_dict["logits"]
    output_dict["probs"] = encoder_dict["probs"]
    output_dict["class_probs"] = encoder_dict["class_probs"]
    output_dict["predicted_labels"] = encoder_dict["predicted_labels"]
    output_dict["gold_labels"] = encoder_dict["gold_labels"]
    output_dict["loss"] = loss
    output_dict["metadata"] = metadata
    output_dict["mask"] = mask

    self._call_metrics(output_dict)
    return output_dict
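# The REINFORCE step above, pared down: the encoder loss acts as a (detached)
# reward weighting the log-probability of the sampled rationale, so gradients
# reach the generator even though sampling itself is not differentiable.
# `logits` and the stand-in losses below are ours, for illustration only.
import torch
import torch.distributions as D

torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)
sampler = D.Bernoulli(probs=torch.sigmoid(logits))
sample_z = sampler.sample()                          # (B, L), non-differentiable
log_prob_z_sum = sampler.log_prob(sample_z).sum(-1)  # (B,)

loss_sample = torch.tensor([0.7, 1.3])               # stand-in encoder losses
generator_loss = (loss_sample.detach() * log_prob_z_sum).mean()
generator_loss.backward()
assert logits.grad is not None                       # generator gets a gradient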
def forward(self, document, rationale=None, kept_tokens=None, query=None,
            label=None, metadata=None) -> Dict[str, Any]:
    generator_dict = self._generator(document, rationale)
    mask = util.get_text_field_mask(document)

    assert "probs" in generator_dict
    prob_z = generator_dict["probs"]
    assert len(prob_z.shape) == 2

    output_dict = {}
    sampler = D.bernoulli.Bernoulli(probs=prob_z)

    if self.prediction_mode or not self.training:
        if self._rationale_extractor is None:
            sample_z = generator_dict["predicted_rationale"].float()
        else:
            sample_z = self._rationale_extractor.extract_rationale(
                prob_z, metadata, as_one_hot=True)
            output_dict[
                "rationale"] = self._rationale_extractor.extract_rationale(
                    prob_z, metadata, as_one_hot=False)
            sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
    else:
        sample_z = sampler.sample()

    sample_z = sample_z * mask

    reduced_document = self.regenerate_tokens(metadata, sample_z)
    encoder_dict = self._encoder(
        document=reduced_document,
        query=query,
        label=label,
        metadata=metadata,
    )

    loss = generator_dict["loss"]

    if label is not None:
        assert "loss" in encoder_dict
        log_prob_z = sampler.log_prob(sample_z)  # (B, L)
        log_prob_z_sum = (mask * log_prob_z).sum(-1)  # (B,)

        loss_sample = F.cross_entropy(encoder_dict["logits"], label,
                                      reduction="none")  # (B,)

        sparsity = util.masked_mean(sample_z, mask, dim=-1)
        censored_lasso_loss = F.relu(sparsity - self._desired_length)

        diff = (sample_z[:, 1:] - sample_z[:, :-1]).abs()
        mask_last = mask[:, :-1]
        fused_lasso_loss = diff.sum(-1) / mask_last.sum(-1)

        self._loss_tracks["_lasso_loss"](sparsity.mean().item())
        self._loss_tracks["_fused_lasso_loss"](
            fused_lasso_loss.mean().item())
        self._loss_tracks["_base_loss"](loss_sample.mean().item())

        base_loss = loss_sample
        generator_loss = (
            loss_sample.detach() +
            censored_lasso_loss * self._reg_loss_lambda +
            fused_lasso_loss *
            (self._reg_loss_mu * self._reg_loss_lambda)) * log_prob_z_sum
        loss += (base_loss + generator_loss).mean()

    output_dict["probs"] = encoder_dict["probs"]
    output_dict["predicted_labels"] = encoder_dict["predicted_labels"]
    output_dict["loss"] = loss
    output_dict["gold_labels"] = label
    output_dict["metadata"] = metadata
    output_dict["prob_z"] = generator_dict["prob_z"]
    output_dict["predicted_rationale"] = generator_dict[
        "predicted_rationale"]

    self._loss_tracks["_rat_length"](util.masked_mean(
        generator_dict["predicted_rationale"], mask,
        dim=-1).mean().item())

    self._call_metrics(output_dict)
    return output_dict
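# The two regularisers above on a toy binary rationale: `sparsity` is the
# fraction of tokens kept (penalised only beyond the budget), and the
# fused-lasso term counts 0/1 transitions, discouraging fragmented rationales.
import torch

sample_z = torch.tensor([[1., 1., 0., 0., 1.]])
mask = torch.ones_like(sample_z)
desired_length = 0.4  # keep at most 40% of tokens

sparsity = (sample_z * mask).sum(-1) / mask.sum(-1)       # 3/5 = 0.6
censored_lasso = torch.relu(sparsity - desired_length)    # only the excess: 0.2
transitions = (sample_z[:, 1:] - sample_z[:, :-1]).abs()  # flips: 0,1,0,1
fused_lasso = transitions.sum(-1) / mask[:, :-1].sum(-1)  # 2 / 4 = 0.5
assert abs(censored_lasso.item() - 0.2) < 1e-6 and fused_lasso.item() == 0.5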
def forward(self, document, query=None, label=None, metadata=None,
            rationale=None) -> Dict[str, Any]:
    # pylint: disable=arguments-differ
    generator_dict = self._generator(document, query, label)
    mask = generator_dict["mask"]

    assert "a" in generator_dict
    assert "b" in generator_dict
    a, b = generator_dict["a"], generator_dict["b"]
    a = a.clamp(1e-6, 100.0)  # extreme values could result in NaNs
    b = b.clamp(1e-6, 100.0)  # extreme values could result in NaNs

    output_dict = {}
    sampler = HardKuma([a, b],
                       support=[
                           self.support[0].to(a.device),
                           self.support[1].to(b.device)
                       ])
    generator_dict["predicted_rationale"] = (sampler.mean() >
                                             0.5).long() * mask

    if self.prediction_mode or not self.training:
        if self._rationale_extractor is None:
            # We constrain rationales to be 0 or 1 strictly. See Pruthi et al.
            # for pathologies when this is not the case.
            sample_z = (sampler.mean() > 0.5).long() * mask
        else:
            prob_z = sampler.mean()
            sample_z = self._rationale_extractor.extract_rationale(
                prob_z, document, as_one_hot=True)
            output_dict[
                "rationale"] = self._rationale_extractor.extract_rationale(
                    prob_z, document, as_one_hot=False)
            sample_z = torch.Tensor(sample_z).to(prob_z.device).float()
    else:
        sample_z = sampler.sample()

    sample_z = sample_z * mask

    # Because BERT is BERT
    wordpiece_to_token = generator_dict["wordpiece-to-token"]
    wtt0 = torch.where(wordpiece_to_token == -1,
                       torch.tensor([0]).to(wordpiece_to_token.device),
                       wordpiece_to_token)
    wordpiece_sample = util.batched_index_select(sample_z.unsqueeze(-1),
                                                 wtt0)
    wordpiece_sample[wordpiece_to_token.unsqueeze(-1) == -1] = 1.0

    def scale_embeddings(module, input, output):
        output = output * wordpiece_sample
        return output

    hook = self._encoder.embedding_layers[0].register_forward_hook(
        scale_embeddings)
    encoder_dict = self._encoder(
        document=document,
        query=query,
        label=label,
        metadata=metadata,
    )
    hook.remove()

    loss = 0.0
    if label is not None:
        assert "loss" in encoder_dict
        base_loss = F.cross_entropy(encoder_dict["logits"], label)  # (B,)
        loss += base_loss

        pdf0 = sampler.pdf(0.0) * mask
        pdf_nonzero = (1 - pdf0) * mask
        lasso_loss = pdf_nonzero.sum(1)
        lengths = mask.sum(1)
        lasso_loss = lasso_loss / (lengths + 1e-9)
        lasso_loss = lasso_loss.mean()

        c0_hat = F.relu(lasso_loss - self._desired_length)
        if self.training:
            self.c0_ma = self.lagrange_alpha * self.c0_ma + (
                1 - self.lagrange_alpha) * c0_hat.item()
        c0 = c0_hat + (self.c0_ma.detach() - c0_hat.detach())
        if self.training:
            self.lambda0 = self.lambda0 * torch.exp(
                self.lagrange_lr * c0.detach())
            self.lambda0 = self.lambda0.clamp(self.lambda_min,
                                              self.lambda_max)

        self._loss_tracks["_lasso_loss"](lasso_loss.item())
        self._loss_tracks["_base_loss"](base_loss.item())
        self._loss_tracks["_lambda0"](self.lambda0[0].item())
        self._loss_tracks["_c0_ma"](self.c0_ma[0].item())
        self._loss_tracks["_c0"](c0_hat.item())

        regulariser_loss = (self.lambda0.detach() * c0)[0]
        loss += regulariser_loss

    output_dict["probs"] = encoder_dict["probs"]
    output_dict["predicted_labels"] = encoder_dict["predicted_labels"]
    output_dict["loss"] = loss
    output_dict["gold_labels"] = label
    output_dict["metadata"] = metadata
    output_dict["predicted_rationale"] = generator_dict[
        "predicted_rationale"]

    self._loss_tracks["_rat_length"](util.masked_mean(
        generator_dict["predicted_rationale"], mask == 1,
        dim=-1).mean().item())

    self._call_metrics(output_dict)
    return output_dict
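# A toy run of the Lagrangian controller above: while the (moving-averaged)
# constraint violation c0 stays positive, the multiplier lambda0 grows
# exponentially, so the length penalty tightens over training. The constant
# violation used here is a stand-in for relu(lasso_loss - desired_length).
import torch

lambda0 = torch.tensor([1.0])
c0_ma = torch.tensor([0.0])
lagrange_alpha, lagrange_lr = 0.9, 0.05
lambda_min, lambda_max = 1e-6, 100.0

for step in range(3):
    c0_hat = torch.tensor(0.1)                      # stand-in violation
    c0_ma = lagrange_alpha * c0_ma + (1 - lagrange_alpha) * c0_hat.item()
    # Use the smoothed value, but keep gradients flowing through c0_hat.
    c0 = c0_hat + (c0_ma.detach() - c0_hat.detach())
    lambda0 = (lambda0 * torch.exp(lagrange_lr * c0.detach())).clamp(
        lambda_min, lambda_max)
    print(step, round(lambda0.item(), 4))           # lambda0 creeps upward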