def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    """Siamese ESIM encoder with a contrastive-distance head.

    Premise and hypothesis are embedded and encoded with shared modules,
    cross-attended (the ESIM "enhancement"), pooled (max + avg), and each
    side is projected through the output feed-forward.  A pair is predicted
    as a match when the pairwise distance between the two projections falls
    below half the margin.  When ``label`` is given, a contrastive loss
    (Hadsell, Chopra & LeCun 2006) is added to the output dict.

    Returns a dict with ``distance``, ``prediction`` and optionally ``loss``.
    """
    # Shape: (batch_size, seq_length, embedding_dim)
    embedded_p = self._text_field_embedder(premise)
    embedded_h = self._text_field_embedder(hypothesis)
    mask_p = get_text_field_mask(premise).float()
    mask_h = get_text_field_mask(hypothesis).float()

    # apply dropout for LSTM
    if self.rnn_input_dropout:
        embedded_p = self.rnn_input_dropout(embedded_p)
        embedded_h = self.rnn_input_dropout(embedded_h)

    # encode p and h
    # Shape: (batch_size, seq_length, encoding_direction_num * encoding_hidden_dim)
    encoded_p = self._encoder(embedded_p, mask_p)
    encoded_h = self._encoder(embedded_h, mask_h)

    # Token-by-token similarity between the two sequences.
    # Shape: (batch_size, p_length, h_length)
    similarity_matrix = self._matrix_attention(encoded_p, encoded_h)

    # For each premise token, a distribution over hypothesis tokens.
    # Shape: (batch_size, p_length, h_length)
    p2h_attention = last_dim_softmax(similarity_matrix, mask_h)
    # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim)
    attended_h = weighted_sum(encoded_h, p2h_attention)

    # And the reverse direction, over premise tokens.
    # Shape: (batch_size, h_length, p_length)
    h2p_attention = last_dim_softmax(
        similarity_matrix.transpose(1, 2).contiguous(), mask_p)
    # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim)
    attended_p = weighted_sum(encoded_p, h2p_attention)

    # the "enhancement" layer: [a; b; a - b; a * b]
    # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim * 4)
    enhanced_p = torch.cat([
        encoded_p, attended_h, encoded_p - attended_h, encoded_p * attended_h
    ], dim=-1)
    # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim * 4)
    enhanced_h = torch.cat([
        encoded_h, attended_p, encoded_h - attended_p, encoded_h * attended_p
    ], dim=-1)

    # The projection layer down to the model dimension.  Dropout is not
    # applied before projection.
    # Shape: (batch_size, seq_length, projection_hidden_dim)
    projected_enhanced_p = self._projection_feedforward(enhanced_p)
    projected_enhanced_h = self._projection_feedforward(enhanced_h)

    # Run the inference layer
    if self.rnn_input_dropout:
        projected_enhanced_p = self.rnn_input_dropout(projected_enhanced_p)
        projected_enhanced_h = self.rnn_input_dropout(projected_enhanced_h)
    # Shape: (batch_size, seq_length, inference_direction_num * inference_hidden_dim)
    inferenced_p = self._inference_encoder(projected_enhanced_p, mask_p)
    inferenced_h = self._inference_encoder(projected_enhanced_h, mask_h)

    # The pooling layer -- max and avg pooling.  Padding positions are
    # filled with -1e7 so they never win the max.
    # Shape: (batch_size, inference_direction_num * inference_hidden_dim)
    pooled_p_max, _ = replace_masked_values(inferenced_p, mask_p.unsqueeze(-1),
                                            -1e7).max(dim=1)
    pooled_h_max, _ = replace_masked_values(inferenced_h, mask_h.unsqueeze(-1),
                                            -1e7).max(dim=1)
    pooled_p_avg = torch.sum(inferenced_p * mask_p.unsqueeze(-1),
                             dim=1) / torch.sum(mask_p, 1, keepdim=True)
    pooled_h_avg = torch.sum(inferenced_h * mask_h.unsqueeze(-1),
                             dim=1) / torch.sum(mask_h, 1, keepdim=True)

    # Now concat
    # Shape: (batch_size, inference_direction_num * inference_hidden_dim * 2)
    pooled_p_all = torch.cat([pooled_p_avg, pooled_p_max], dim=1)
    pooled_h_all = torch.cat([pooled_h_avg, pooled_h_max], dim=1)

    # the final MLP -- apply dropout to input, and MLP applies to output & hidden
    if self.dropout:
        pooled_p_all = self.dropout(pooled_p_all)
        pooled_h_all = self.dropout(pooled_h_all)

    # Shape: (batch_size, output_feedforward_hidden_dim)
    output_p, output_h = self._output_feedforward(pooled_p_all, pooled_h_all)

    # Predict "same pair" when the two projections are closer than half
    # the contrastive margin.
    distance = F.pairwise_distance(output_p, output_h)
    prediction = distance < (self._margin / 2.0)
    output_dict = {'distance': distance, "prediction": prediction}

    if label is not None:
        """
        Contrastive loss function.
        Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
        """
        # Similar pairs (y == 1) are pulled together quadratically;
        # dissimilar pairs are pushed apart up to the margin.
        y = label.float()
        l1 = y * torch.pow(distance, 2) / 2.0
        l2 = (1 - y) * torch.pow(
            torch.clamp(self._margin - distance, min=0.0), 2) / 2.0
        loss = torch.mean(l1 + l2)
        self._accuracy(prediction, label.byte())
        output_dict["loss"] = loss

    return output_dict
def forward(self,
            response: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            original_post: Optional[Dict[str, torch.LongTensor]] = None,
            weakpoints: Optional[torch.IntTensor] = None,
            fake_data: bool = False,
            idxs: Optional[torch.LongTensor] = None,
            op_features: list = None,
            response_features: list = None) -> Dict[str, torch.Tensor]:
    """Ensemble forward pass combining a response-only predictor and an
    OP+response predictor.

    Each sub-predictor returns a pooled ``representation`` plus its
    per-sentence ``encoded_response`` and ``response_mask``.  The two
    representations are concatenated and classified by the output
    feed-forward; an auxiliary orthogonality penalty (absolute dot product
    of the two sentence encodings) pushes the two views apart.

    NOTE(review): ``label`` defaults to None but the loss below is computed
    unconditionally, so calling this without a label will fail — confirm
    callers always supply it.

    Returns a dict with ``label_logits``, ``label_probs``,
    ``representation`` and ``loss``.
    """
    response_only_output = self._response_only_predictor(
        response,
        label,
        weakpoints=weakpoints,
        fake_data=fake_data,
        idxs=idxs,
        response_features=response_features)
    op_response_output = self._op_response_predictor(
        response, label, original_post, weakpoints, fake_data, idxs,
        op_features, response_features)

    # Concatenate the two pooled document representations.
    combined_input = torch.cat([
        response_only_output['representation'],
        op_response_output['representation']
    ], dim=-1)

    # Orthogonality penalty per sentence: |dot product| between the two
    # predictors' sentence encodings, masked to real sentences.
    # (Sentence encodings assumed BxSxD with matching D — TODO confirm.)
    orthogonality_loss = response_only_output['response_mask'].float() * (
        response_only_output['encoded_response'] *
        op_response_output['encoded_response']).sum(
            dim=-1).abs().squeeze(-1)
    # NOTE(review): the denominator is the mask total over the WHOLE batch,
    # so each example's sentence-sum is normalized by the global sentence
    # count rather than its own — verify this is intended (a per-example
    # denominator would use torch.sum(..., dim=1) here as well).
    orthogonality_loss_avg = torch.sum(
        orthogonality_loss, dim=1) / torch.sum(
            response_only_output['response_mask'].float())

    label_logits = self._output_feedforward(combined_input).squeeze(-1)
    label_probs = torch.sigmoid(label_logits)
    predictions = label_probs > 0.5

    output_dict = {
        "label_logits": label_logits,
        "label_probs": label_probs,
        "representation": combined_input
    }

    # Re-weight positive examples so both classes contribute equally,
    # guarding against batches that contain no positives.
    true_weight = 1 if ((label == 1).sum().float() == 0) else (
        (label == 0).sum().float() / (label == 1).sum().float())
    weight = label.eq(0).float() + label.eq(1).float() * true_weight
    loss = self._loss(label_logits, label.float(), weight=weight)

    if fake_data:
        self._fake_accuracy(predictions, label.byte())
        self._fake_fscore(
            torch.stack([1 - predictions, predictions], dim=1), label)
    else:
        self._accuracy(predictions, label.byte())
        self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                     label)

    # Total loss = weighted BCE + mean orthogonality penalty.
    output_dict["loss"] = loss + orthogonality_loss_avg.mean()
    return output_dict
def forward(self,
            response: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            original_post: Optional[Dict[str, torch.LongTensor]] = None,
            weakpoints: Optional[torch.IntTensor] = None,
            fake_data: bool = False,
            idxs: Optional[torch.LongTensor] = None,
            op_features: list = None,
            response_features: list = None,
            compress_response: bool = False,
            op_doc_features: list = None,
            response_doc_features: list = None,
            ) -> Dict[str, torch.Tensor]:
    """Hierarchical (word -> sentence -> document) classifier over a
    response and, optionally, the original post (OP).

    Words in each sentence are pooled by a word-level attention, optional
    per-sentence features are merged in via ``self._feature_feedforward``,
    sentences are run through a sentence-level encoder, and a sentence-level
    attention over the response (conditioned on the OP when present)
    produces one vector for binary classification.  Optional document-level
    features are concatenated onto that vector.

    ``weakpoints`` is accepted for interface compatibility but unused here.
    NOTE(review): ``label`` defaults to None but the loss below is computed
    unconditionally — confirm callers always supply it.

    Returns a dict with ``label_logits``, ``label_probs``,
    ``representation``, ``encoded_response``, ``response_mask`` and ``loss``.
    """
    # Optionally keep only the sentences selected by ``idxs``.
    if idxs is not None and compress_response:
        response = extract(response, idxs)

    # Embedded text assumed (batch, sentences, words, dim) — TODO confirm.
    embedded_response = self._response_embedder(response, num_wrapping_dims=1)
    # The 'mask' entry is bookkeeping, not a token tensor, so exclude it.
    response_mask = get_text_field_mask(
        {i: j for i, j in response.items() if i != 'mask'},
        num_wrapping_dims=1).float()

    # apply dropout for LSTM
    if self.rnn_input_dropout:
        embedded_response = self.rnn_input_dropout(embedded_response)

    # Collapse (batch, sentences) so the word-level attention runs over
    # each sentence independently.
    batch_size, max_response_sentences, max_response_words, response_dim = \
        embedded_response.shape
    embedded_response = embedded_response.view(
        batch_size * max_response_sentences, max_response_words, -1)
    response_mask = response_mask.view(
        batch_size * max_response_sentences, max_response_words)
    sentence_encoded_response = self._response_word_attention(
        embedded_response,
        response_mask).view(batch_size, max_response_sentences, -1)

    # Before the sentence encoder, append per-sentence features.
    sentence_features = [sentence_encoded_response]
    if response_features is not None:
        sentence_features.append(response_features)
    if self._feature_feedforward is not None:
        # TODO: need to compress response features also? why doesn't this
        # cause an error - padding
        sentence_encoded_response = self._feature_feedforward(
            torch.cat(sentence_features, dim=-1))

    # A sentence is real iff it contains at least one unmasked word.
    response_mask = response_mask.view(
        batch_size, max_response_sentences, -1).sum(dim=-1) > 0
    encoded_response = self._response_encoder(sentence_encoded_response,
                                              response_mask)

    if original_post is not None:
        # Same word -> sentence pipeline for the OP side.
        if idxs is not None and not compress_response:
            original_post = extract(original_post, idxs)
        embedded_op = self._op_embedder(original_post, num_wrapping_dims=1)
        op_mask = get_text_field_mask(
            {i: j for i, j in original_post.items() if i != 'mask'},
            num_wrapping_dims=1).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_op = self.rnn_input_dropout(embedded_op)

        batch_size, max_op_sentences, max_op_words, op_dim = \
            embedded_op.shape
        embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                       max_op_words, -1)
        op_mask = op_mask.view(batch_size * max_op_sentences, max_op_words)
        sentence_encoded_op = self._op_word_attention(
            embedded_op, op_mask).view(batch_size, max_op_sentences, -1)

        sentence_features = [sentence_encoded_op]
        if op_features is not None:
            sentence_features.append(op_features)
        if self._feature_feedforward is not None:
            sentence_encoded_op = self._feature_feedforward(
                torch.cat(sentence_features, dim=-1))

        op_mask = op_mask.view(batch_size, max_op_sentences,
                               -1).sum(dim=-1) > 0
        encoded_op = self._op_encoder(sentence_encoded_op, op_mask)
    else:
        op_mask = None
        encoded_op = None

    # Sentence-level attention over the response, conditioned on the OP
    # (the OP side may be None); yields batch_size x n_dim.
    combined_input = self._response_sentence_attention(
        encoded_response, response_mask, encoded_op, op_mask)

    if self.dropout:
        combined_input = self.dropout(combined_input)

    # Optional document-level features are concatenated after attention.
    combined_input = [combined_input]
    if op_doc_features is not None:
        combined_input.append(op_doc_features)
    if response_doc_features is not None:
        combined_input.append(response_doc_features)
    combined_input = torch.cat(combined_input, dim=-1)

    label_logits = self._output_feedforward(combined_input).squeeze(-1)
    label_probs = torch.sigmoid(label_logits)
    predictions = label_probs > 0.5

    output_dict = {
        "label_logits": label_logits,
        "label_probs": label_probs,
        "representation": combined_input,
        "encoded_response": encoded_response,
        "response_mask": response_mask
    }

    # Re-weight positives so both classes contribute equally, guarding
    # against batches with no positive labels.
    true_weight = 1 if ((label == 1).sum().float() == 0) else (
        (label == 0).sum().float() / (label == 1).sum().float())
    weight = label.eq(0).float() + label.eq(1).float() * true_weight
    loss = self._loss(label_logits, label.float(), weight=weight)

    if fake_data:
        self._fake_accuracy(predictions, label.byte())
        self._fake_fscore(
            torch.stack([1 - predictions, predictions], dim=1), label)
    else:
        self._accuracy(predictions, label.byte())
        self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                     label)

    output_dict["loss"] = loss
    return output_dict
def forward(self,
            response: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            original_post: Optional[Dict[str, torch.LongTensor]] = None,
            weakpoints: Optional[torch.IntTensor] = None,
            fake_data: bool = False) -> Dict[str, torch.Tensor]:
    """Hierarchical attention classifier: word-level attention pools each
    sentence, a sentence encoder contextualizes sentences, and a
    sentence-level attention (conditioned on the OP when present) yields a
    single vector for binary classification.

    ``weakpoints`` is accepted for interface compatibility but unused here.
    NOTE(review): ``label`` defaults to None but the loss below is computed
    unconditionally — confirm callers always supply it.

    Returns a dict with ``label_logits``, ``label_probs``,
    ``representation`` and ``loss``.
    """
    # Embedded text assumed (batch, sentences, words, dim) — TODO confirm.
    embedded_response = self._response_embedder(response, num_wrapping_dims=1)
    batch_size, max_response_sentences, max_response_words, response_dim = \
        embedded_response.shape
    response_mask = get_text_field_mask(response,
                                        num_wrapping_dims=1).float()

    # Get a weighted average of the words in each sentence: collapse
    # (batch, sentences) so the word attention runs per sentence.
    embedded_response = embedded_response.view(
        batch_size * max_response_sentences, max_response_words, -1)
    response_mask = response_mask.view(batch_size * max_response_sentences,
                                       max_response_words)

    # apply dropout for LSTM
    if self.rnn_input_dropout:
        embedded_response = self.rnn_input_dropout(embedded_response)

    response_attention = self._response_word_attention(
        embedded_response, response_mask)
    embedded_response = weighted_sum(
        embedded_response,
        response_attention).view(batch_size, max_response_sentences, -1)
    # A sentence is real iff it contains at least one unmasked word.
    response_mask = response_mask.view(batch_size, max_response_sentences,
                                       -1).sum(dim=-1) > 0

    # Encode the response at the sentence level.
    encoded_response = self._response_encoder(embedded_response,
                                              response_mask)

    if original_post is not None:
        # Same word -> sentence pipeline for the OP side.
        embedded_op = self._op_embedder(original_post, num_wrapping_dims=1)
        _, max_op_sentences, max_op_words, op_dim = embedded_op.shape
        op_mask = get_text_field_mask(original_post,
                                      num_wrapping_dims=1).float()
        embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                       max_op_words, -1)
        op_mask = op_mask.view(batch_size * max_op_sentences, max_op_words)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_op = self.rnn_input_dropout(embedded_op)

        op_attention = self._op_word_attention(embedded_op, op_mask)
        embedded_op = weighted_sum(embedded_op, op_attention).view(
            batch_size, max_op_sentences, -1)
        op_mask = op_mask.view(batch_size, max_op_sentences,
                               -1).sum(dim=-1) > 0
        encoded_op = self._op_encoder(embedded_op, op_mask)

        combined_input = self._response_sentence_attention(
            encoded_op, encoded_response, op_mask, response_mask,
            self._op_sentence_attention)
    else:
        # No OP: attend over the response sentences alone.
        attn = self._op_sentence_attention(encoded_response, response_mask)
        combined_input = weighted_sum(encoded_response, attn)

    # now batch_size x n_dim
    if self.dropout:
        combined_input = self.dropout(combined_input)

    label_logits = self._output_feedforward(combined_input).squeeze(-1)
    label_probs = torch.sigmoid(label_logits)
    predictions = label_probs > 0.5

    output_dict = {
        "label_logits": label_logits,
        "label_probs": label_probs,
        "representation": combined_input
    }

    # FIX: guard against batches with no positive labels — the previous
    # code divided by (label == 1).sum() unconditionally (inf/nan weights)
    # and left a stray debug print; now consistent with the sibling
    # forward implementations.
    true_weight = 1 if ((label == 1).sum().float() == 0) else (
        (label == 0).sum().float() / (label == 1).sum().float())
    weight = label.eq(0).float() + label.eq(1).float() * true_weight
    loss = self._loss(label_logits, label.float(), weight=weight)

    if fake_data:
        self._fake_accuracy(predictions, label.byte())
        self._fake_fscore(
            torch.stack([1 - predictions, predictions], dim=1), label)
    else:
        self._accuracy(predictions, label.byte())
        self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                     label)

    output_dict["loss"] = loss
    return output_dict