Example #1
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        # Shape: (batch_size, seq_length, embedding_dim)
        embedded_p = self._text_field_embedder(premise)
        embedded_h = self._text_field_embedder(hypothesis)

        mask_p = get_text_field_mask(premise).float()
        mask_h = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_p = self.rnn_input_dropout(embedded_p)
            embedded_h = self.rnn_input_dropout(embedded_h)

        # encode p and h
        # Shape: (batch_size, seq_length, encoding_direction_num * encoding_hidden_dim)
        encoded_p = self._encoder(embedded_p, mask_p)
        encoded_h = self._encoder(embedded_h, mask_h)

        # Shape: (batch_size, p_length, h_length)
        similarity_matrix = self._matrix_attention(encoded_p, encoded_h)

        # Shape: (batch_size, p_length, h_length)
        p2h_attention = last_dim_softmax(similarity_matrix, mask_h)
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim)
        attended_h = weighted_sum(encoded_h, p2h_attention)

        # Shape: (batch_size, h_length, p_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), mask_p)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim)
        attended_p = weighted_sum(encoded_p, h2p_attention)

        # the "enhancement" layer
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim * 4)
        enhanced_p = torch.cat([
            encoded_p, attended_h, encoded_p - attended_h,
            encoded_p * attended_h
        ], dim=-1)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim * 4)
        enhanced_h = torch.cat([
            encoded_h, attended_p, encoded_h - attended_p,
            encoded_h * attended_p
        ], dim=-1)

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        # Shape: (batch_size, seq_length, projection_hidden_dim)
        projected_enhanced_p = self._projection_feedforward(enhanced_p)
        projected_enhanced_h = self._projection_feedforward(enhanced_h)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_p = self.rnn_input_dropout(projected_enhanced_p)
            projected_enhanced_h = self.rnn_input_dropout(projected_enhanced_h)

        # Shape: (batch_size, seq_length, inference_direction_num * inference_hidden_dim)
        inferenced_p = self._inference_encoder(projected_enhanced_p, mask_p)
        inferenced_h = self._inference_encoder(projected_enhanced_h, mask_h)

        # The pooling layer -- max and avg pooling.
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim)
        pooled_p_max, _ = replace_masked_values(inferenced_p,
                                                mask_p.unsqueeze(-1),
                                                -1e7).max(dim=1)
        pooled_h_max, _ = replace_masked_values(inferenced_h,
                                                mask_h.unsqueeze(-1),
                                                -1e7).max(dim=1)

        pooled_p_avg = torch.sum(inferenced_p * mask_p.unsqueeze(-1),
                                 dim=1) / torch.sum(mask_p, 1, keepdim=True)
        pooled_h_avg = torch.sum(inferenced_h * mask_h.unsqueeze(-1),
                                 dim=1) / torch.sum(mask_h, 1, keepdim=True)

        # Now concat
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim * 2)
        pooled_p_all = torch.cat([pooled_p_avg, pooled_p_max], dim=1)
        pooled_h_all = torch.cat([pooled_h_avg, pooled_h_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            pooled_p_all = self.dropout(pooled_p_all)
            pooled_h_all = self.dropout(pooled_h_all)

        # Shape: (batch_size, output_feedforward_hidden_dim)
        output_p, output_h = self._output_feedforward(pooled_p_all,
                                                      pooled_h_all)

        distance = F.pairwise_distance(output_p, output_h)
        prediction = distance < (self._margin / 2.0)
        output_dict = {'distance': distance, "prediction": prediction}

        if label is not None:
            """
            Contrastive loss function.
            Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            """
            y = label.float()
            l1 = y * torch.pow(distance, 2) / 2.0
            l2 = (1 - y) * torch.pow(
                torch.clamp(self._margin - distance, min=0.0), 2) / 2.0
            loss = torch.mean(l1 + l2)

            self._accuracy(prediction, label.byte())

            output_dict["loss"] = loss

        return output_dict
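
The margin-based prediction and contrastive loss at the end of this forward pass can be exercised on their own. Below is a minimal sketch, assuming made-up tensor sizes and a margin of 1.0 (the model reads its margin from self._margin); none of the concrete values come from the snippet above.

import torch
import torch.nn.functional as F

margin = 1.0                      # assumed value; the model uses self._margin
output_p = torch.randn(4, 128)    # premise representations
output_h = torch.randn(4, 128)    # hypothesis representations
label = torch.tensor([1., 0., 1., 0.])  # 1 = matching pair, 0 = non-matching pair

# pairs are predicted as matching when they fall within half the margin
distance = F.pairwise_distance(output_p, output_h)
prediction = distance < (margin / 2.0)

# contrastive loss: pull matching pairs together, push others beyond the margin
l1 = label * distance.pow(2) / 2.0
l2 = (1 - label) * torch.clamp(margin - distance, min=0.0).pow(2) / 2.0
loss = torch.mean(l1 + l2)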
Example #2
    def forward(self,
                response: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                original_post: Optional[Dict[str, torch.LongTensor]] = None,
                weakpoints: Optional[torch.IntTensor] = None,
                fake_data: bool = False,
                idxs: Optional[torch.LongTensor] = None,
                op_features: list = None,
                response_features: list = None) -> Dict[str, torch.Tensor]:

        response_only_output = self._response_only_predictor(
            response,
            label,
            weakpoints=weakpoints,
            fake_data=fake_data,
            idxs=idxs,
            response_features=response_features)
        op_response_output = self._op_response_predictor(
            response, label, original_post, weakpoints, fake_data, idxs,
            op_features, response_features)
        combined_input = torch.cat([
            response_only_output['representation'],
            op_response_output['representation']
        ], dim=-1)

        # the sentence representations are of shape (batch_size, max_sentences, dim),
        # assuming both predictors use the same dimension
        batch_size, max_sentences, _ = response_only_output[
            'encoded_response'].shape

        # orthogonality penalty: the masked, absolute dot product between the two
        # predictors' sentence encodings, averaged over unmasked sentences
        orthogonality_loss = response_only_output['response_mask'].float() * (
            response_only_output['encoded_response'] *
            op_response_output['encoded_response']).sum(
                dim=-1).abs().squeeze(-1)
        orthogonality_loss_avg = torch.sum(
            orthogonality_loss, dim=1) / torch.sum(
                response_only_output['response_mask'].float())

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)

        predictions = label_probs > 0.5

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input
        }

        # up-weight the positive class by the negative/positive ratio so that an
        # imbalanced batch does not let the negative class dominate the loss
        true_weight = (1 if (label == 1).sum().float() == 0 else
                       (label == 0).sum().float() / (label == 1).sum().float())

        weight = label.eq(0).float() + label.eq(1).float() * true_weight
        loss = self._loss(label_logits, label.float(), weight=weight)

        if fake_data:
            self._fake_accuracy(predictions, label.byte())
            self._fake_fscore(
                torch.stack([1 - predictions, predictions], dim=1), label)
        else:
            self._accuracy(predictions, label.byte())
            self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                         label)

        output_dict["loss"] = loss + orthogonality_loss_avg.mean()

        return output_dict
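
For reference, the orthogonality penalty added to the loss above can be reproduced in isolation. This is a minimal sketch with invented shapes and random inputs standing in for the two predictors' sentence encodings; only the arithmetic mirrors the snippet.

import torch

batch_size, max_sentences, dim = 2, 5, 16
encoded_response_only = torch.randn(batch_size, max_sentences, dim)
encoded_op_response = torch.randn(batch_size, max_sentences, dim)
response_mask = torch.ones(batch_size, max_sentences)  # 1 for real sentences, 0 for padding

# per-sentence |a . b| between the two views, zeroed at padded positions
orthogonality_loss = response_mask * (
    encoded_response_only * encoded_op_response).sum(dim=-1).abs()

# averaged over unmasked sentences, as in the snippet, then added to the loss
orthogonality_loss_avg = torch.sum(orthogonality_loss, dim=1) / torch.sum(response_mask)
penalty = orthogonality_loss_avg.mean()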
Example #3
    def forward(
        self,
        response: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        original_post: Optional[Dict[str, torch.LongTensor]] = None,
        weakpoints: Optional[torch.IntTensor] = None,
        fake_data: bool = False,
        idxs: Optional[torch.LongTensor] = None,
        op_features: list = None,
        response_features: list = None,
        compress_response: bool = False,
        op_doc_features: list = None,
        response_doc_features: list = None,
    ) -> Dict[str, torch.Tensor]:

        if idxs is not None and compress_response:
            response = extract(response, idxs)

        embedded_response = self._response_embedder(response,
                                                    num_wrapping_dims=1)
        response_mask = get_text_field_mask(
            {i: j for i, j in response.items() if i != 'mask'},
            num_wrapping_dims=1).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_response = self.rnn_input_dropout(embedded_response)

        # encode the response at the sentence level: flatten the sentence axis,
        # pool words within each sentence, then restore the sentence axis
        batch_size, max_response_sentences, max_response_words, response_dim = embedded_response.shape
        embedded_response = embedded_response.view(
            batch_size * max_response_sentences, max_response_words, -1)
        response_mask = response_mask.view(
            batch_size * max_response_sentences, max_response_words)
        sentence_encoded_response = self._response_word_attention(
            embedded_response,
            response_mask).view(batch_size, max_response_sentences, -1)

        # before encoding sentences, append any per-sentence features
        sentence_features = [sentence_encoded_response]
        if response_features is not None:
            sentence_features.append(response_features)
        if self._feature_feedforward is not None:
            # TODO: do the response features need to be compressed as well?
            # padding may be why this does not currently raise an error
            sentence_encoded_response = self._feature_feedforward(
                torch.cat(sentence_features, dim=-1))

        # a sentence is unmasked if it contains at least one unmasked word
        response_mask = response_mask.view(
            batch_size, max_response_sentences, -1).sum(dim=-1) > 0
        encoded_response = self._response_encoder(
            sentence_encoded_response, response_mask)

        if original_post is not None:
            if idxs is not None and not compress_response:
                original_post = extract(original_post, idxs)

            embedded_op = self._op_embedder(original_post,
                                            num_wrapping_dims=1)
            op_mask = get_text_field_mask(
                {i: j for i, j in original_post.items() if i != 'mask'},
                num_wrapping_dims=1).float()

            # apply dropout for LSTM
            if self.rnn_input_dropout:
                embedded_op = self.rnn_input_dropout(embedded_op)

            batch_size, max_op_sentences, max_op_words, op_dim = embedded_op.shape
            embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                           max_op_words, -1)
            op_mask = op_mask.view(batch_size * max_op_sentences,
                                   max_op_words)
            sentence_encoded_op = self._op_word_attention(
                embedded_op, op_mask).view(batch_size, max_op_sentences, -1)

            sentence_features = [sentence_encoded_op]
            if op_features is not None:
                sentence_features.append(op_features)
            if self._feature_feedforward is not None:
                sentence_encoded_op = self._feature_feedforward(
                    torch.cat(sentence_features, dim=-1))

            op_mask = op_mask.view(batch_size, max_op_sentences,
                                   -1).sum(dim=-1) > 0
            encoded_op = self._op_encoder(sentence_encoded_op, op_mask)

        else:
            op_mask = None
            encoded_op = None

        combined_input = self._response_sentence_attention(
            encoded_response, response_mask, encoded_op, op_mask)

        # combined_input is now of shape (batch_size, n_dim)
        if self.dropout:
            combined_input = self.dropout(combined_input)

        combined_input = [combined_input]

        if op_doc_features is not None:
            combined_input.append(op_doc_features)
        if response_doc_features is not None:
            combined_input.append(response_doc_features)

        combined_input = torch.cat(combined_input, dim=-1)

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)
        predictions = label_probs > 0.5

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input,
            "encoded_response": encoded_response,
            "response_mask": response_mask
        }

        # up-weight the positive class by the negative/positive ratio so that an
        # imbalanced batch does not let the negative class dominate the loss
        true_weight = (1 if (label == 1).sum().float() == 0 else
                       (label == 0).sum().float() / (label == 1).sum().float())

        weight = label.eq(0).float() + label.eq(1).float() * true_weight
        loss = self._loss(label_logits, label.float(), weight=weight)

        if fake_data:
            self._fake_accuracy(predictions, label.byte())
            self._fake_fscore(
                torch.stack([1 - predictions, predictions], dim=1), label)
        else:
            self._accuracy(predictions, label.byte())
            self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                         label)

        output_dict["loss"] = loss

        return output_dict
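
The hierarchical encoding in this example relies on one reshaping trick: word-level tensors of shape (batch, sentences, words, dim) are flattened to (batch * sentences, words, dim) so a word-level module can run over every sentence at once, then folded back. Below is a small sketch of that pattern, with invented sizes and a masked mean standing in for the word-level attention; it illustrates the reshaping only, not the model's actual attention module.

import torch

batch_size, max_sentences, max_words, dim = 2, 3, 7, 8
embedded = torch.randn(batch_size, max_sentences, max_words, dim)
word_mask = torch.ones(batch_size, max_sentences, max_words)

# flatten the sentence axis so every sentence is its own "batch" element
flat_embedded = embedded.view(batch_size * max_sentences, max_words, -1)
flat_mask = word_mask.view(batch_size * max_sentences, max_words)

# stand-in for the word-level attention: a masked mean over words
summed = (flat_embedded * flat_mask.unsqueeze(-1)).sum(dim=1)
counts = flat_mask.sum(dim=1, keepdim=True).clamp(min=1)
sentence_encoded = (summed / counts).view(batch_size, max_sentences, -1)

# a sentence counts as real if it contains at least one unmasked word
sentence_mask = word_mask.view(batch_size, max_sentences, -1).sum(dim=-1) > 0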
Example #4
    def forward(self,
                response: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                original_post: Optional[Dict[str, torch.LongTensor]] = None,
                weakpoints: Optional[torch.IntTensor] = None,
                fake_data: bool = False) -> Dict[str, torch.Tensor]:

        embedded_response = self._response_embedder(response,
                                                    num_wrapping_dims=1)

        batch_size, max_response_sentences, max_response_words, response_dim = embedded_response.shape

        response_mask = get_text_field_mask(response,
                                            num_wrapping_dims=1).float()

        # get a weighted average of the words in each sentence
        embedded_response = embedded_response.view(
            batch_size * max_response_sentences, max_response_words, -1)
        response_mask = response_mask.view(batch_size * max_response_sentences,
                                           max_response_words)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_response = self.rnn_input_dropout(embedded_response)

        response_attention = self._response_word_attention(
            embedded_response, response_mask)
        embedded_response = weighted_sum(embedded_response,
                                         response_attention).view(
                                             batch_size,
                                             max_response_sentences, -1)

        response_mask = response_mask.view(batch_size, max_response_sentences,
                                           -1).sum(dim=-1) > 0

        # encode the response at the sentence level
        encoded_response = self._response_encoder(embedded_response,
                                                  response_mask)

        if original_post is not None:
            embedded_op = self._op_embedder(original_post, num_wrapping_dims=1)
            _, max_op_sentences, max_op_words, op_dim = embedded_op.shape
            op_mask = get_text_field_mask(original_post,
                                          num_wrapping_dims=1).float()

            embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                           max_op_words, -1)
            op_mask = op_mask.view(batch_size * max_op_sentences, max_op_words)

            # apply dropout for LSTM
            if self.rnn_input_dropout:
                embedded_op = self.rnn_input_dropout(embedded_op)

            op_attention = self._op_word_attention(embedded_op, op_mask)
            embedded_op = weighted_sum(embedded_op, op_attention).view(
                batch_size, max_op_sentences, -1)

            op_mask = op_mask.view(batch_size, max_op_sentences,
                                   -1).sum(dim=-1) > 0

            encoded_op = self._op_encoder(embedded_op, op_mask)

            combined_input = self._response_sentence_attention(
                encoded_op, encoded_response, op_mask, response_mask,
                self._op_sentence_attention)

        else:
            attn = self._op_sentence_attention(encoded_response, response_mask)
            combined_input = weighted_sum(encoded_response, attn)

        # combined_input is now of shape (batch_size, n_dim)

        if self.dropout:
            combined_input = self.dropout(combined_input)

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)
        predictions = label_probs > 0.5

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input
        }

        # up-weight the positive class by the negative/positive ratio; fall back
        # to 1 when the batch contains no positive examples
        true_weight = (1 if (label == 1).sum().float() == 0 else
                       (label == 0).sum().float() / (label == 1).sum().float())

        weight = label.eq(0).float() + label.eq(1).float() * true_weight
        loss = self._loss(label_logits, label.float(), weight=weight)

        if fake_data:
            self._fake_accuracy(predictions, label.byte())
            self._fake_fscore(
                torch.stack([1 - predictions, predictions], dim=1), label)
        else:
            self._accuracy(predictions, label.byte())
            self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                         label)

        output_dict["loss"] = loss

        return output_dict
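
All three response models above weight the binary loss by the class ratio. The sketch below shows that weighting with a plain binary_cross_entropy_with_logits call standing in for self._loss (an assumption, since the snippets do not show how self._loss is defined), and with made-up logits and labels.

import torch
import torch.nn.functional as F

label_logits = torch.tensor([0.3, -1.2, 2.0, -0.5])
label = torch.tensor([1, 0, 1, 0])

# up-weight positives by the negative/positive ratio; fall back to 1 if the
# batch happens to contain no positive examples
num_pos = (label == 1).sum().float()
num_neg = (label == 0).sum().float()
true_weight = 1.0 if num_pos == 0 else num_neg / num_pos

# per-example weights: 1 for negatives, true_weight for positives
weight = label.eq(0).float() + label.eq(1).float() * true_weight

# binary_cross_entropy_with_logits stands in for self._loss here
loss = F.binary_cross_entropy_with_logits(label_logits, label.float(), weight=weight)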