Example #1
    def _gen_sample(self, loc: Tensor, scale_tril: Tensor, time: IntTensor) -> Tensor:
        next_obs = self._transition(loc, scale_tril, time)
        if not self.horizon:
            return next_obs

        # Mask the sampled transition:
        # we're in an absorbing state if the current timestep is the horizon
        return nt.where(time.eq(self.horizon), pack_obs(loc, time), next_obs)
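The absorbing-state logic above can be illustrated without the named-tensor helpers. A minimal sketch using plain torch.where, with hypothetical stand-ins for nt.where and pack_obs (neither is shown in the snippet):

import torch

def gen_sample_sketch(next_obs: torch.Tensor, frozen_obs: torch.Tensor,
                      time: torch.Tensor, horizon: int) -> torch.Tensor:
    # Where a trajectory has reached the horizon, keep the frozen
    # (absorbing) observation; elsewhere, take the sampled transition.
    at_horizon = time.eq(horizon)
    return torch.where(at_horizon.unsqueeze(-1), frozen_obs, next_obs)

# A batch of 3 states at times [2, 5, 5] with horizon 5: rows 1 and 2 freeze.
time = torch.tensor([2, 5, 5])
print(gen_sample_sketch(torch.randn(3, 4), torch.zeros(3, 4), time, horizon=5))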
Example #2
def _trans_logp(
    loc: Tensor,
    scale_tril: Tensor,
    cur_time: IntTensor,
    state: Tensor,
    time: IntTensor,
) -> Tensor:
    loc, scale_tril = nt.unnamed(loc, scale_tril)
    dist = torch.distributions.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    trans_logp: Tensor = dist.log_prob(nt.unnamed(state))
    trans_logp = nt.where(
        # The log-probability is only defined at the next timestep
        time.eq(cur_time + 1),
        trans_logp,
        torch.full(time.shape, fill_value=float("nan")),
    )
    # We assume time is a named scalar tensor
    return trans_logp.refine_names(*time.names)
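The same masking idea works with unnamed tensors: score states under a Gaussian transition, then blank out entries whose queried timestep is not cur_time + 1. A minimal sketch using plain torch in place of the nt helpers:

import torch

def trans_logp_sketch(loc, scale_tril, cur_time, state, time):
    dist = torch.distributions.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    logp = dist.log_prob(state)
    # The log-probability is only defined at the next timestep; NaN elsewhere.
    nan = torch.full_like(logp, float("nan"))
    return torch.where(time.eq(cur_time + 1), logp, nan)

loc = torch.zeros(2, 3)
scale_tril = torch.eye(3).expand(2, 3, 3)
print(trans_logp_sketch(loc, scale_tril,
                        cur_time=torch.tensor([4, 4]),
                        state=torch.randn(2, 3),
                        time=torch.tensor([5, 6])))  # [logp, nan]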
Example #3
    def forward(self,
                response: Dict[str, torch.LongTensor],
                label: Optional[torch.IntTensor] = None,
                original_post: Optional[Dict[str, torch.LongTensor]] = None,
                weakpoints: Optional[torch.IntTensor] = None,
                fake_data: bool = False,
                idxs: Optional[torch.LongTensor] = None,
                op_features: Optional[list] = None,
                response_features: Optional[list] = None) -> Dict[str, torch.Tensor]:

        response_only_output = self._response_only_predictor(
            response,
            label,
            weakpoints=weakpoints,
            fake_data=fake_data,
            idxs=idxs,
            response_features=response_features)
        op_response_output = self._op_response_predictor(
            response, label, original_post, weakpoints, fake_data, idxs,
            op_features, response_features)
        combined_input = torch.cat([
            response_only_output['representation'],
            op_response_output['representation']
        ], dim=-1)

        # The sentence representations are of shape B x S x D (assuming both
        # encoders use the same dimension)
        batch_size, max_sentences, _ = response_only_output[
            'encoded_response'].shape
        # Penalize overlap between the two encoders' sentence representations:
        # the masked, absolute dot product pushes them toward orthogonality
        orthogonality_loss = response_only_output['response_mask'].float() * (
            response_only_output['encoded_response'] *
            op_response_output['encoded_response']).sum(
                dim=-1).abs().squeeze(-1)
        orthogonality_loss_avg = torch.sum(
            orthogonality_loss, dim=1) / torch.sum(
                response_only_output['response_mask'].float())

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)

        predictions = label_probs > 0.5

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input
        }

        # Weight positive examples by the negative-to-positive class ratio
        # (1 when the batch has no positive labels)
        num_pos = (label == 1).sum().float()
        true_weight = 1 if num_pos == 0 else (label == 0).sum().float() / num_pos

        weight = label.eq(0).float() + label.eq(1).float() * true_weight
        loss = self._loss(label_logits, label.float(), weight=weight)

        if fake_data:
            self._fake_accuracy(predictions, label.byte())
            self._fake_fscore(
                torch.stack([1 - predictions, predictions], dim=1), label)
        else:
            self._accuracy(predictions, label.byte())
            self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                         label)

        output_dict["loss"] = loss + orthogonality_loss_avg.mean()

        return output_dict
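The orthogonality penalty reduces to a masked, absolute dot product between the two encoders' sentence representations, averaged over real sentences. A self-contained sketch with dummy shapes (the names are illustrative, not from the model):

import torch

def orthogonality_penalty(enc_a, enc_b, mask):
    # enc_a, enc_b: (batch, sentences, dim); mask: (batch, sentences).
    # Absolute per-sentence dot product, zeroed on padding sentences.
    per_sentence = mask.float() * (enc_a * enc_b).sum(dim=-1).abs()
    # Normalize by the total number of real sentences, as above.
    return per_sentence.sum(dim=1) / mask.float().sum()

enc_a, enc_b = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
print(orthogonality_penalty(enc_a, enc_b, mask))  # shape (2,)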
Example #4
    def forward(
        self,
        response: Dict[str, torch.LongTensor],
        label: Optional[torch.IntTensor] = None,
        original_post: Optional[Dict[str, torch.LongTensor]] = None,
        weakpoints: Optional[torch.IntTensor] = None,
        fake_data: bool = False,
        idxs: Optional[torch.LongTensor] = None,
        op_features: Optional[list] = None,
        response_features: Optional[list] = None,
        compress_response: bool = False,
        op_doc_features: Optional[list] = None,
        response_doc_features: Optional[list] = None,
    ) -> Dict[str, torch.Tensor]:

        if idxs is not None and compress_response:
            response = extract(response, idxs)

        embedded_response = self._response_embedder(response,
                                                    num_wrapping_dims=1)
        response_mask = get_text_field_mask(
            {i: j
             for i, j in response.items() if i != 'mask'},
            num_wrapping_dims=1).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_response = self.rnn_input_dropout(embedded_response)

        # encode the response at the sentence level
        batch_size, max_response_sentences, max_response_words, response_dim = embedded_response.shape
        embedded_response = embedded_response.view(
            batch_size * max_response_sentences, max_response_words, -1)
        response_mask = response_mask.view(
            batch_size * max_response_sentences, max_response_words)
        sentence_encoded_response = self._response_word_attention(
            embedded_response,
            response_mask).view(batch_size, max_response_sentences, -1)

        # before encoding sentences, append features to each sentence
        sentence_features = [sentence_encoded_response]
        if response_features is not None:
            sentence_features.append(response_features)
        if self._feature_feedforward is not None:
            # TODO: do the response features need to be compressed as well?
            # Padding appears to mask the shape mismatch for now.
            sentence_encoded_response = self._feature_feedforward(
                torch.cat(sentence_features, dim=-1))

        response_mask = response_mask.view(
            batch_size, max_response_sentences, -1).sum(dim=-1) > 0
        encoded_response = self._response_encoder(
            sentence_encoded_response, response_mask)

        if original_post is not None:
            if idxs is not None and not compress_response:
                original_post = extract(original_post, idxs)

            embedded_op = self._op_embedder(original_post,
                                            num_wrapping_dims=1)
            op_mask = get_text_field_mask(
                {i: j
                 for i, j in original_post.items() if i != 'mask'},
                num_wrapping_dims=1).float()

            # apply dropout for LSTM
            if self.rnn_input_dropout:
                embedded_op = self.rnn_input_dropout(embedded_op)

            batch_size, max_op_sentences, max_op_words, op_dim = embedded_op.shape
            embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                           max_op_words, -1)
            op_mask = op_mask.view(batch_size * max_op_sentences,
                                   max_op_words)
            sentence_encoded_op = self._op_word_attention(
                embedded_op, op_mask).view(batch_size, max_op_sentences, -1)

            sentence_features = [sentence_encoded_op]
            if op_features is not None:
                sentence_features.append(op_features)
            if self._feature_feedforward is not None:
                sentence_encoded_op = self._feature_feedforward(
                    torch.cat(sentence_features, dim=-1))

            op_mask = op_mask.view(batch_size, max_op_sentences,
                                   -1).sum(dim=-1) > 0
            encoded_op = self._op_encoder(sentence_encoded_op, op_mask)

        else:
            op_mask = None
            encoded_op = None

        combined_input = self._response_sentence_attention(
            encoded_response, response_mask, encoded_op, op_mask)

        # combined_input is now batch_size x n_dim
        if self.dropout:
            combined_input = self.dropout(combined_input)

        combined_input = [combined_input]

        if op_doc_features is not None:
            combined_input.append(op_doc_features)
        if response_doc_features is not None:
            combined_input.append(response_doc_features)

        combined_input = torch.cat(combined_input, dim=-1)

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)
        predictions = label_probs > 0.5

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input,
            "encoded_response": encoded_response,
            "response_mask": response_mask
        }

        # Weight positive examples by the negative-to-positive class ratio
        # (1 when the batch has no positive labels)
        num_pos = (label == 1).sum().float()
        true_weight = 1 if num_pos == 0 else (label == 0).sum().float() / num_pos

        weight = label.eq(0).float() + label.eq(1).float() * true_weight
        loss = self._loss(label_logits, label.float(), weight=weight)

        if fake_data:
            self._fake_accuracy(predictions, label.byte())
            self._fake_fscore(
                torch.stack([1 - predictions, predictions], dim=1), label)
        else:
            self._accuracy(predictions, label.byte())
            self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                         label)

        output_dict["loss"] = loss

        return output_dict
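The class-reweighting used in these forward passes can be exercised in isolation. A minimal sketch, assuming self._loss behaves like binary cross-entropy with logits (the actual loss module is not shown):

import torch
import torch.nn.functional as F

def weighted_bce(label_logits, label):
    # Up-weight positives by the negative/positive count ratio,
    # falling back to 1 when the batch has no positive labels.
    num_pos = (label == 1).sum().float()
    true_weight = 1.0 if num_pos == 0 else (label == 0).sum().float() / num_pos
    weight = label.eq(0).float() + label.eq(1).float() * true_weight
    return F.binary_cross_entropy_with_logits(label_logits, label.float(),
                                              weight=weight)

logits = torch.tensor([2.0, -1.0, 0.5, -0.3])
labels = torch.tensor([1, 0, 0, 0])
print(weighted_bce(logits, labels))  # the single positive is weighted 3x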
Example #5
    def forward(self,
                response: Dict[str, torch.LongTensor],
                label: Optional[torch.IntTensor] = None,
                original_post: Optional[Dict[str, torch.LongTensor]] = None,
                weakpoints: Optional[torch.IntTensor] = None,
                fake_data: bool = False) -> Dict[str, torch.Tensor]:

        embedded_response = self._response_embedder(response,
                                                    num_wrapping_dims=1)

        batch_size, max_response_sentences, max_response_words, response_dim = embedded_response.shape

        response_mask = get_text_field_mask(response,
                                            num_wrapping_dims=1).float()

        # get a weighted average of the words in each sentence
        embedded_response = embedded_response.view(
            batch_size * max_response_sentences, max_response_words, -1)
        response_mask = response_mask.view(batch_size * max_response_sentences,
                                           max_response_words)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_response = self.rnn_input_dropout(embedded_response)

        response_attention = self._response_word_attention(
            embedded_response, response_mask)
        embedded_response = weighted_sum(embedded_response,
                                         response_attention).view(
                                             batch_size,
                                             max_response_sentences, -1)

        response_mask = response_mask.view(batch_size, max_response_sentences,
                                           -1).sum(dim=-1) > 0

        # encode OP and response at sentence level
        encoded_response = self._response_encoder(embedded_response,
                                                  response_mask)

        if original_post is not None:
            embedded_op = self._op_embedder(original_post, num_wrapping_dims=1)
            _, max_op_sentences, max_op_words, op_dim = embedded_op.shape
            op_mask = get_text_field_mask(original_post,
                                          num_wrapping_dims=1).float()

            embedded_op = embedded_op.view(batch_size * max_op_sentences,
                                           max_op_words, -1)
            op_mask = op_mask.view(batch_size * max_op_sentences, max_op_words)

            # apply dropout for LSTM
            if self.rnn_input_dropout:
                embedded_op = self.rnn_input_dropout(embedded_op)

            op_attention = self._op_word_attention(embedded_op, op_mask)
            embedded_op = weighted_sum(embedded_op, op_attention).view(
                batch_size, max_op_sentences, -1)

            op_mask = op_mask.view(batch_size, max_op_sentences,
                                   -1).sum(dim=-1) > 0

            encoded_op = self._op_encoder(embedded_op, op_mask)

            combined_input = self._response_sentence_attention(
                encoded_op, encoded_response, op_mask, response_mask,
                self._op_sentence_attention)

        else:
            attn = self._op_sentence_attention(encoded_response, response_mask)
            combined_input = weighted_sum(encoded_response, attn)

        # combined_input is now batch_size x n_dim

        if self.dropout:
            combined_input = self.dropout(combined_input)

        label_logits = self._output_feedforward(combined_input).squeeze(-1)
        label_probs = torch.sigmoid(label_logits)
        predictions = label_probs > 0.5

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "representation": combined_input
        }

        # Weight positive examples by the negative-to-positive class ratio;
        # guard against batches with no positive labels
        num_pos = (label == 1).sum().float()
        true_weight = 1 if num_pos == 0 else (label == 0).sum().float() / num_pos

        weight = label.eq(0).float() + label.eq(1).float() * true_weight
        loss = self._loss(label_logits, label.float(), weight=weight)

        if fake_data:
            self._fake_accuracy(predictions, label.byte())
            self._fake_fscore(
                torch.stack([1 - predictions, predictions], dim=1), label)
        else:
            self._accuracy(predictions, label.byte())
            self._fscore(torch.stack([1 - predictions, predictions], dim=1),
                         label)

        output_dict["loss"] = loss

        return output_dict
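All three models share the same word-to-sentence pooling: flatten sentences into the batch dimension, attend over words, take a weighted sum, and reshape back. A sketch with a plain dot-product attention standing in for the model's attention modules (illustrative names only):

import torch

def word_attention_pooling(embedded, mask, query):
    # embedded: (batch, sentences, words, dim); mask: (batch, sentences, words).
    batch, n_sent, n_words, dim = embedded.shape
    flat = embedded.view(batch * n_sent, n_words, dim)
    flat_mask = mask.view(batch * n_sent, n_words)
    # Masked softmax over words against a learned query vector.
    scores = flat.matmul(query).masked_fill(flat_mask == 0, float('-inf'))
    attn = torch.softmax(scores, dim=-1).nan_to_num()  # all-pad rows -> 0
    pooled = (attn.unsqueeze(-1) * flat).sum(dim=1)    # weighted sum over words
    return pooled.view(batch, n_sent, dim)

embedded = torch.randn(2, 3, 5, 8)
mask = torch.randint(0, 2, (2, 3, 5))
print(word_attention_pooling(embedded, mask, torch.randn(8)).shape)  # (2, 3, 8)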