Example #1
    def forward(self, utterances, utterance_length, input_conversation_length, target_utterances,
                decode=False):
        """
        Forward of VHRED
        :param utterances: [num_utterances, max_utter_len]
        :param utterance_length: [num_utterances]
        :param input_conversation_length: [batch_size]
        :param target_utterances: [num_utterances, seq_len]
        :param decode: True or False
        :return: decoder_outputs
        """
        batch_size = input_conversation_length.size(0)
        num_utterances = utterances.size(0)
        max_conv_len = input_conversation_length.data.max().item()

        encoder_outputs, encoder_hidden = self.encoder(utterances, utterance_length)
        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(num_utterances, -1)
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)
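        # e.g. input_conversation_length = [3, 2]: each conversation contributes
        # l + 1 encoded utterances, so start = cumsum([0, 3 + 1]) = [0, 4] marks
        # each conversation's first row in the flat encoder_hidden tensor.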

        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_conv_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([encoder_hidden_inference[i, :l, :]
                                                   for i, l in enumerate(input_conversation_length.data)])

        encoder_hidden_input = encoder_hidden[:, :-1, :]

        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input, input_conversation_length)
        context_outputs = torch.cat([context_outputs[i, :l, :]
                                     for i, l in enumerate(input_conversation_length.data)])

        mu_prior, var_prior = self.prior(context_outputs)
        eps = to_var(torch.randn((num_utterances - batch_size, self.config.z_utter_size)))
        if not decode:
            mu_posterior, var_posterior = self.posterior(context_outputs, encoder_hidden_inference_flat)
            z_sent = mu_posterior + torch.sqrt(var_posterior) * eps

            kl_div = normal_kl_div(mu_posterior, var_posterior, mu_prior, var_prior)
            kl_div = torch.sum(kl_div)
        else:
            z_sent = mu_prior + torch.sqrt(var_prior) * eps
            kl_div = None

        self.z_sent = z_sent
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers, self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        if not decode:
            decoder_outputs = self.decoder(target_utterances, init_h=decoder_init, decode=decode)
            return decoder_outputs, kl_div
        else:
            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
            return prediction, kl_div
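The snippet relies on two small helpers that are not shown here: to_var, which moves a tensor to the GPU and wraps it for autograd, and pad, which zero-pads a conversation's utterance states to a common length so they can be stacked. A minimal sketch of compatible implementations, assuming the legacy Variable-era PyTorch style used above; the originals in the source project may differ:

import torch
from torch.autograd import Variable

def to_var(x, volatile=False):
    # Move to GPU when available and wrap as a Variable
    # (pre-0.4 PyTorch idiom, matching the code above).
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)

def pad(tensor, length):
    # Zero-pad `tensor` along dim 0 to `length` rows so that
    # per-conversation sequences can be stacked into one batch.
    if tensor.size(0) < length:
        pad_size = [length - tensor.size(0)] + list(tensor.size()[1:])
        return torch.cat([tensor, to_var(torch.zeros(*pad_size))], 0)
    return tensor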
Example #2
    def forward(self,
                sentences,
                sentence_length,
                input_conversation_length,
                target_sentences,
                decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(
            num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(
            torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                       input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([
            pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
            for s, l in zip(start.data.tolist(),
                            input_conversation_length.data.tolist())
        ], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([
            encoder_hidden_inference[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(
            encoder_hidden_input, input_conversation_length)
        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([
            context_outputs[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        mu_prior, var_prior = self.prior(context_outputs)
        eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))
        if not decode:
            mu_posterior, var_posterior = self.posterior(
                context_outputs, encoder_hidden_inference_flat)
            z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
            log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()

            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            # kl_div: [num_sentences]
            kl_div = normal_kl_div(mu_posterior, var_posterior, mu_prior,
                                   var_prior)
            kl_div = torch.sum(kl_div)
        else:
            z_sent = mu_prior + torch.sqrt(var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            log_q_zx = None

        self.z_sent = z_sent
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: decoder_outputs [batch_size, seq_len, vocab_size]
        # eval: prediction [batch_size, beam_size, max_unroll]
        if not decode:
            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)
            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)

            return prediction, kl_div, log_p_z, log_q_zx
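The reparameterized sample z_sent = mu + sqrt(var) * eps and the normal_kl_div / normal_logpdf calls above follow the standard closed forms for diagonal Gaussians. A sketch of compatible implementations, assuming var holds per-dimension variances rather than log-variances (the actual helpers in the project may differ):

import math
import torch

def normal_kl_div(mu1, var1, mu2, var2):
    # KL( N(mu1, diag(var1)) || N(mu2, diag(var2)) ),
    # summed over latent dimensions -> shape [num_sentences].
    return 0.5 * torch.sum(
        torch.log(var2) - torch.log(var1)
        + (var1 + (mu1 - mu2).pow(2)) / var2 - 1.0, dim=1)

def normal_logpdf(x, mu, var):
    # log N(x; mu, diag(var)), summed over latent
    # dimensions -> shape [num_sentences].
    return -0.5 * torch.sum(
        math.log(2.0 * math.pi) + torch.log(var)
        + (x - mu).pow(2) / var, dim=1)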
Example #3
    def forward(self,
                sentences,
                sentence_length,
                input_conversation_length,
                target_sentences,
                decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(
            num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(
            torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                       input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([
            pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
            for s, l in zip(start.data.tolist(),
                            input_conversation_length.data.tolist())
        ], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([
            encoder_hidden_inference[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # Standard Gaussian prior
        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        conv_mu_prior, conv_var_prior = self.conv_prior()

        if not decode:
            if self.config.sentence_drop > 0.0:
                indices = np.where(
                    np.random.rand(max_len) < self.config.sentence_drop)[0]
                if len(indices) > 0:
                    encoder_hidden_input[:, indices, :] = self.unk_sent

            # context_inference_outputs: [batch_size, max_len, num_directions * context_size]
            # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size]
            context_inference_outputs, context_inference_hidden = self.context_inference(
                encoder_hidden, input_conversation_length + 1)

            # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size]
            context_inference_hidden = context_inference_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)
            conv_mu_posterior, conv_var_posterior = self.conv_posterior(
                context_inference_hidden)
            z_conv = conv_mu_posterior + torch.sqrt(
                conv_var_posterior) * conv_eps
            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior,
                                          conv_var_posterior).sum()

            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior,
                                         conv_var_prior).sum()
            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                        conv_mu_prior, conv_var_prior).sum()

            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)

            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            sent_mu_posterior, sent_var_posterior = self.sent_posterior(
                context_outputs, encoder_hidden_inference_flat, z_conv_flat)
            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior,
                                          sent_var_posterior).sum()

            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior,
                                         sent_var_prior).sum()
            # kl_div: [num_sentences]
            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                        sent_mu_prior, sent_var_prior).sum()

            kl_div = kl_div_conv + kl_div_sent
            log_q_zx = log_q_zx_conv + log_q_zx_sent
            log_p_z = log_p_z_conv + log_p_z_sent
        else:
            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            # context_outputs: [batch_size, max_len, context_size]
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)
            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, sent_mu_prior,
                                    sent_var_prior).sum()
            log_p_z += normal_logpdf(z_conv, conv_mu_prior,
                                     conv_var_prior).sum()
            log_q_zx = None

        # expand z_conv to all associated sentences
        z_conv = torch.cat([
            z.view(1, -1).expand(m.item(), self.config.z_conv_size)
            for z, m in zip(z_conv, input_conversation_length)
        ])

        # latent_context: [num_sentences, context_size + z_sent_size +
        # z_conv_size]
        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: decoder_outputs [batch_size, seq_len, vocab_size]
        # eval: prediction [batch_size, beam_size, max_unroll]
        if not decode:
            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)
            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)
            return prediction, kl_div, log_p_z, log_q_zx
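In a training loop, the four returned values would typically be combined into a variational objective: a reconstruction loss from decoder_outputs plus the (often annealed) KL term, with log_p_z and log_q_zx available for importance-weighted ELBO estimates. A hypothetical training-step sketch; model, optimizer, kl_weight, and masked_cross_entropy are illustrative names, not taken from the source above:

decoder_outputs, kl_div, log_p_z, log_q_zx = model(
    sentences, sentence_length, input_conversation_length,
    target_sentences, decode=False)

# Assumed helper: token-level cross entropy masked by target lengths.
recon_loss = masked_cross_entropy(decoder_outputs, target_sentences,
                                  target_sentence_length)
loss = recon_loss + kl_weight * kl_div  # negative ELBO with KL annealing

optimizer.zero_grad()
loss.backward()
optimizer.step()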
Example #4
File: vhcr.py  Project: NoSyu/VHUCM
    def forward(self,
                utterances,
                utterance_length,
                input_conversation_length,
                target_utterances,
                decode=False):
        """
        Forward of VHRED
        :param utterances: [num_utterances, max_utter_len]
        :param utterance_length: [num_utterances]
        :param input_conversation_length: [batch_size]
        :param target_utterances: [num_utterances, seq_len]
        :param decode: True or False
        :return: decoder_outputs
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = utterances.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        encoder_outputs, encoder_hidden = self.encoder(utterances,
                                                       utterance_length)

        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(
            num_sentences + batch_size, -1)

        start = torch.cumsum(
            torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                       input_conversation_length[:-1] + 1)), 0)
        encoder_hidden = torch.stack([
            pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
            for s, l in zip(start.data.tolist(),
                            input_conversation_length.data.tolist())
        ], 0)

        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([
            encoder_hidden_inference[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        encoder_hidden_input = encoder_hidden[:, :-1, :]

        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        conv_mu_prior, conv_var_prior = self.conv_prior()

        if not decode:
            if self.config.sentence_drop > 0.0:
                indices = np.where(
                    np.random.rand(max_len) < self.config.sentence_drop)[0]
                if len(indices) > 0:
                    encoder_hidden_input[:, indices, :] = self.unk_sent

            context_inference_outputs, context_inference_hidden = self.context_inference(
                encoder_hidden, input_conversation_length + 1)

            context_inference_hidden = context_inference_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)
            conv_mu_posterior, conv_var_posterior = self.conv_posterior(
                context_inference_hidden)
            z_conv = conv_mu_posterior + torch.sqrt(
                conv_var_posterior) * conv_eps
            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior,
                                          conv_var_posterior).sum()

            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior,
                                         conv_var_prior).sum()
            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                        conv_mu_prior, conv_var_prior).sum()

            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)

            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn(
                (num_sentences, self.config.z_utter_size)))

            sent_mu_posterior, sent_var_posterior = self.sent_posterior(
                context_outputs, encoder_hidden_inference_flat, z_conv_flat)
            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior,
                                          sent_var_posterior).sum()

            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior,
                                         sent_var_prior).sum()
            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                        sent_mu_prior, sent_var_prior).sum()

            kl_div = kl_div_conv + kl_div_sent
            log_q_zx = log_q_zx_conv + log_q_zx_sent
            log_p_z = log_p_z_conv + log_p_z_sent
        else:
            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)
            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn(
                (num_sentences, self.config.z_utter_size)))

            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, sent_mu_prior,
                                    sent_var_prior).sum()
            log_p_z += normal_logpdf(z_conv, conv_mu_prior,
                                     conv_var_prior).sum()
            log_q_zx = None

        z_conv = torch.cat([
            z.view(1, -1).expand(m.item(), self.config.z_conv_size)
            for z, m in zip(z_conv, input_conversation_length)
        ])

        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        if not decode:
            decoder_outputs = self.decoder(target_utterances,
                                           init_h=decoder_init,
                                           decode=decode)
            # return decoder_outputs, kl_div, log_p_z, log_q_zx
            return decoder_outputs, kl_div
        else:
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)
            # return prediction, kl_div, log_p_z, log_q_zx
            return prediction, kl_div
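At inference time, the same forward is called with decode=True, in which case kl_div comes back as None and prediction holds the beam-search hypotheses. A hypothetical usage sketch; the variable names are illustrative:

model.eval()
prediction, kl_div = model(utterances, utterance_length,
                           input_conversation_length,
                           target_utterances, decode=True)
# prediction: [batch_size, beam_size, max_unroll] token ids from beam search
# kl_div: None in decode mode (see the else branch above)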