Example 1
    def decode(self, decoder_input, latent_z):
        """
        decode state into word indices

        :param decoder_input: list of lists of indices
        :param latent_z: sequence context with shape of [batch_size, latent_z_size]

        :return: unnormalized logits of sentense words distribution probabilities
                     with shape of [batch_size, seq_len, word_vocab_size]

        """
        padded, lengths = pad_sentences(decoder_input,
                                        pad_idx=self.pad_idx,
                                        lpad=self.start_idx)
        embeddings = self.embedder(padded)
        embeddings = self.word_dropout(embeddings)
        [batch_size, seq_len, _] = embeddings.size()
        # the decoder RNN is conditioned on the context by concatenating z to
        # every input token, i.e. an additional bias W_cond * z at each step
        latent_z = t.cat([latent_z] * seq_len, 1).view(batch_size, seq_len, -1)
        embeddings = t.cat([embeddings, latent_z], 2)
        rnn = self.ae_decoder
        rnn_out, _ = rnn(embeddings)
        rnn_out = rnn_out.contiguous().view(batch_size * seq_len,
                                            self.lstm_hidden_size)
        result = self.fc(rnn_out)
        result = result.view(batch_size, seq_len, self.vocab_size)
        return result
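pad_sentences itself is not shown in any of these snippets. Below is a minimal sketch consistent with how it is called here (left-padding with a start token) and in the loss examples further down (right-padding with an end token); the exact signature and the (padded tensor, lengths) return contract are assumptions inferred from the call sites:

import torch

def pad_sentences(sentences, pad_idx=0, lpad=None, rpad=None):
    # Hypothetical reconstruction: optionally prepend lpad (e.g. a start
    # token) and/or append rpad (e.g. an end token) to each sequence, then
    # right-pad everything with pad_idx up to the longest length.
    seqs = []
    for s in sentences:
        s = list(s)
        if lpad is not None:
            s = [lpad] + s
        if rpad is not None:
            s = s + [rpad]
        seqs.append(s)
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    max_len = int(lengths.max())
    padded = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = torch.tensor(s, dtype=torch.long)
    return padded, lengths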
Example 2
    def forward(self, input, encoder):
        """
        Encode an input into a vector representation
        params:
            input : word indices
            encoder: [pt1|pt2|v1|v2]
        """
        if encoder == 'v2':
            return self.hierarchical_forward(input)

        batch_size = len(input)
        padded, lengths = pad_sentences(input, pad_idx=self.pad_idx)
        embeddings = self.embedder(padded)
        embeddings = self.word_dropout(embeddings)
        lengths, perm_idx = lengths.sort(0, descending=True)
        embeddings = embeddings[perm_idx]
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embeddings, lengths, batch_first=True)
        rnn = self.get_encoder(encoder)
        _, (_, final_state) = rnn(packed, None)  # keep the final cell state c_n
        _, unperm_idx = perm_idx.sort(0)
        final_state = final_state[:, unperm_idx]
        # take the last layer's forward and backward states and concatenate them
        final_state = final_state.view(self.num_layers, 2, batch_size, self.lstm_hidden_size)[-1] \
            .transpose(0, 1).contiguous() \
            .view(batch_size, 2 * self.lstm_hidden_size)
        return final_state
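The sort, pack, unsort dance above is the standard recipe for feeding variable-length batches to a cuDNN RNN. Here is a self-contained sketch of the same pattern; the sizes and the LSTM module are illustrative, not the original model's:

import torch
import torch.nn as nn

emb_size, hidden_size = 32, 16
rnn = nn.LSTM(emb_size, hidden_size, num_layers=1,
              bidirectional=True, batch_first=True)

embeddings = torch.randn(4, 7, emb_size)   # [batch, max_seq_len, emb_size]
lengths = torch.tensor([7, 3, 5, 2])       # true lengths before padding

# sort by length (required unless enforce_sorted=False is used), then pack
lengths_sorted, perm_idx = lengths.sort(0, descending=True)
packed = nn.utils.rnn.pack_padded_sequence(
    embeddings[perm_idx], lengths_sorted, batch_first=True)
_, (h_n, c_n) = rnn(packed)

# invert the permutation to restore the original batch order
_, unperm_idx = perm_idx.sort(0)
h_n = h_n[:, unperm_idx]                   # [num_layers * 2, batch, hidden]

On PyTorch 1.1 and later, passing enforce_sorted=False to pack_padded_sequence lets the library handle the sorting and unsorting internally.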
Example 3
    def reconst_loss(self, gnd_utts, reconst):
        """
        gnd_utts is a list of lists of indices (the outer list should be a minibatch)
        reconst is a tensor with the logits from a decoder [batchsize][seqlen][vocabsize]
        """
        batch_size, seq_len, vocab_size = reconst.size()
        loss = 0
        padded, lengths = pad_sentences(gnd_utts,
                                        pad_idx=self.pad_idx,
                                        rpad=self.end_idx)
        batch_size = len(lengths)
        crit = nn.CrossEntropyLoss()
        loss += crit(reconst.view(batch_size * seq_len, vocab_size),
                     padded.view(batch_size * seq_len))
        _, argmax = reconst.max(dim=-1)
        correct = (argmax == padded)
        acc = correct.float().mean().item()
        return loss, acc
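Note that acc above averages over every position, padding included, so heavily padded batches inflate the score. A small sketch of a masked variant; the indices below are made up for illustration:

import torch

pad_idx = 0                                 # illustrative pad index
padded = torch.tensor([[5, 8, 2, 1],
                       [7, 4, 1, pad_idx]]) # targets, right-padded
argmax = torch.tensor([[5, 8, 9, 1],
                       [7, 4, 1, pad_idx]]) # predicted indices

mask = padded != pad_idx                    # ignore pad positions
masked_acc = (argmax == padded)[mask].float().mean().item()  # 6/7 here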
Example 4
    def reconst_loss(self, gnd_utts, reconst):
        """
        gnd_utts is a list of lists of indices (the outer list should be a minibatch)
        reconst is a tensor with the logits from a decoder [batchsize][seqlen][vocabsize]
        (should not have passed through softmax)

        reconst should be one token longer than the inputs in gnd_utts; the
        additional token to be predicted is the end_idx token
        """
        batch_size, seq_len, vocab_size = reconst.size()
        loss = 0
        # this pad_sentences call will add token self.end_idx at the end of each sequence
        padded, lengths = pad_sentences(gnd_utts, pad_idx=self.pad_idx, rpad=self.end_idx)
        batch_size = len(lengths)
        crit = nn.CrossEntropyLoss()
        reconst_flat = reconst.view(batch_size * seq_len, vocab_size)
        padded_flat = padded.view(batch_size * seq_len)
        loss += crit(reconst_flat, padded_flat)
        _, argmax = reconst.max(dim=-1)
        correct = (argmax == padded)
        acc = correct.float().mean().item()
        return loss, acc
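A quick, self-contained check of the shape contract the docstring describes: each target row gains one end_idx token, so the decoder logits must cover that extra step. All sizes and indices below are dummies, not values from the original code:

import torch
import torch.nn as nn

batch_size, seq_len, vocab_size = 2, 4, 10  # dummy sizes
pad_idx, end_idx = 0, 1                     # illustrative indices

# what pad_sentences(gnd_utts, pad_idx=pad_idx, rpad=end_idx) would produce
# for gnd_utts = [[5, 8, 2], [7, 4]]: the longest row is 3 + 1 = 4 tokens
padded = torch.tensor([[5, 8, 2, end_idx],
                       [7, 4, end_idx, pad_idx]])
reconst = torch.randn(batch_size, seq_len, vocab_size)  # decoder logits

crit = nn.CrossEntropyLoss()
loss = crit(reconst.view(batch_size * seq_len, vocab_size),
            padded.view(batch_size * seq_len))
acc = (reconst.argmax(dim=-1) == padded).float().mean().item()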
Example 5
        conversations = list(
            filter(lambda x: len(x) >= min_turn, conversations))

        # sentences: padded sentences
        # [n_conversations, conversation_length (varies), max_sentence_length]

        # sentence_length: length of each sentence
        # [n_conversations, conversation_length (varies)]

        conversation_length = [
            min(len(conversation), max_conv_len)
            for conversation in conversations
        ]

        sentences, sentence_length = pad_sentences(
            conversations,
            max_sentence_length=max_sent_len,
            max_conversation_length=max_conv_len)

        print('Saving preprocessed data at', split_data_dir)
        to_pickle(conversation_length,
                  split_data_dir.joinpath('conversation_length.pkl'))
        to_pickle(sentences, split_data_dir.joinpath('sentences.pkl'))
        to_pickle(sentence_length,
                  split_data_dir.joinpath('sentence_length.pkl'))

        if split_type == 'train':
            print('Saving vocabulary...')
            vocab = Vocab(tokenizer)
            vocab.add_dataframe(conversations)
            vocab.update(max_size=max_vocab_size, min_freq=min_freq)
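to_pickle is another helper that is not shown; here is a one-liner sketch of what it presumably does, with the name and behavior inferred from the calls above:

import pickle

def to_pickle(obj, path):
    # Hypothetical helper: serialize obj to the given path (a pathlib.Path
    # in the snippet above) using the pickle module.
    with open(str(path), 'wb') as f:
        pickle.dump(obj, f)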