Example #1
def logpx(zs, logdet):
    # log p(z) under a standard-normal prior (mean 0, log-variance 0),
    # summed over the spatial and channel axes of each latent tensor
    logpz = tf.add_n([tf.reduce_sum(utils.normal_logpdf(z, 0., 0.), axis=[1, 2, 3]) for z in zs])
    # average over the batch
    ave_logpz = tf.reduce_mean(logpz)
    ave_logdet = tf.reduce_mean(logdet)
    # change of variables: log p(x) = log p(z) + log |det dz/dx|
    total_logprob = ave_logpz + ave_logdet
    tf.summary.scalar("logdet", ave_logdet)
    tf.summary.scalar("logp", ave_logpz)
    return total_logprob
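Both this example and Example #7 call `utils.normal_logpdf(z, 0., 0.)`, which only makes sense if the helper takes a mean and a log-variance and returns the elementwise Gaussian log-density (mean 0, log-variance 0 is then a standard normal). A minimal sketch of such a helper under that assumed signature; the body is an assumption, not the repository's actual code:

import numpy as np
import tensorflow as tf

def normal_logpdf(x, mean, logvar):
    # Elementwise log N(x; mean, exp(logvar)); with mean=0 and logvar=0
    # this reduces to the standard-normal log-density used above.
    return -0.5 * (np.log(2. * np.pi) + logvar
                   + tf.square(x - mean) / tf.exp(logvar))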
Example #2
    def predict(self, x, y, compute_ml=False):

        # Discriminative prediction q(y | x)
        ypred_logits = self.phi_y(x)
        cross_entropy_loss = F.cross_entropy(ypred_logits, y.argmax(1))
        acc = (ypred_logits.argmax(1) == y.argmax(1)).float().mean()
        ypred = F.softmax(ypred_logits, 1)

        # Reconstruction from the most likely predicted label (in the
        # original snippet px_mu was undefined when compute_ml=False)
        y_map = F.one_hot(ypred.argmax(1), num_classes=self.dim_y).float()
        qz_param = self.phi_z(torch.cat([x, y_map], dim=1))
        z_map = self._draw_sample(qz_param[:, :self.dim_z], qz_param[:, self.dim_z:])
        px_mu = torch.sigmoid(self.theta_g(torch.cat([z_map, y_map], dim=1)))

        if compute_ml:
            # Importance-sampling estimate of the marginal likelihood p(x),
            # with q(y | x) q(z | x, y) as the proposal distribution
            S = 64
            ypredsamples = torch.multinomial(ypred, S, replacement=True).to(self.device)

            ml = 0.  # accumulates per-example importance weights
            for i in range(S):
                y_sample = F.one_hot(ypredsamples[:, i].reshape(-1,), num_classes=self.dim_y).float()

                # distribution for q(z | x, y)
                qz_param = self.phi_z(torch.cat([x, y_sample], dim=1))
                qz_mu = qz_param[:, :self.dim_z]
                qz_log_sigma_sq = qz_param[:, self.dim_z:]

                z_sample = self._draw_sample(qz_mu, qz_log_sigma_sq)

                # distribution p(x | y, z)
                px_param = self.theta_g(torch.cat([z_sample, y_sample], dim=1))
                px_mu_sample = torch.sigmoid(px_param)

                # per-example log terms: sum over feature dimensions only,
                # not over the batch
                log_lik = utils.bernoulli_logpdf(x, px_mu_sample).sum(1)
                log_prior_z = utils.stdnormal_logpdf(z_sample).sum(1)
                log_posterior_z = utils.normal_logpdf(z_sample, qz_mu, qz_log_sigma_sq).sum(1)
                # cross_entropy expects logits, not softmax probabilities
                log_posterior_y = -F.cross_entropy(ypred_logits, ypredsamples[:, i], reduction='none')

                prior_y = 1 / self.dim_y
                # importance weight, combined in log space before exponentiating
                ml += prior_y * torch.exp(log_lik + log_prior_z - log_posterior_z - log_posterior_y)

            # Monte Carlo average over the S samples
            ml = ml / S

        if not compute_ml:
            return cross_entropy_loss, acc, torch.bernoulli(px_mu)
        else:
            return cross_entropy_loss, acc, ml, torch.bernoulli(px_mu)
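This snippet leans on several helpers whose definitions are not shown: `utils.bernoulli_logpdf`, `utils.stdnormal_logpdf`, `utils.normal_logpdf`, and `self._draw_sample`. Plausible sketches, assuming the same log-variance convention as the `qz_log_sigma_sq` split above; the names mirror the calls in the snippet, but the bodies are assumptions:

import math
import torch

def bernoulli_logpdf(x, mu, eps=1e-8):
    # Elementwise log Bernoulli(x; mu) for binary x
    return x * torch.log(mu + eps) + (1. - x) * torch.log(1. - mu + eps)

def stdnormal_logpdf(x):
    # Elementwise log N(x; 0, 1)
    return -0.5 * (math.log(2. * math.pi) + x ** 2)

def normal_logpdf(x, mu, log_sigma_sq):
    # Elementwise log N(x; mu, exp(log_sigma_sq))
    return -0.5 * (math.log(2. * math.pi) + log_sigma_sq
                   + (x - mu) ** 2 / torch.exp(log_sigma_sq))

def _draw_sample(mu, log_sigma_sq):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * log_sigma_sq) * eps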
Example #3
    def forward(self,
                sentences,
                sentence_length,
                input_conversation_length,
                target_sentences,
                decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(
            num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(
            torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                       input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([
            pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
            for s, l in zip(start.data.tolist(),
                            input_conversation_length.data.tolist())
        ], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([
            encoder_hidden_inference[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # Standard Gaussian prior
        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        conv_mu_prior, conv_var_prior = self.conv_prior()

        if not decode:
            if self.config.sentence_drop > 0.0:
                indices = np.where(
                    np.random.rand(max_len) < self.config.sentence_drop)[0]
                if len(indices) > 0:
                    encoder_hidden_input[:, indices, :] = self.unk_sent

            # context_inference_outputs: [batch_size, max_len, num_directions * context_size]
            # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size]
            context_inference_outputs, context_inference_hidden = self.context_inference(
                encoder_hidden, input_conversation_length + 1)

            # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size]
            context_inference_hidden = context_inference_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)
            conv_mu_posterior, conv_var_posterior = self.conv_posterior(
                context_inference_hidden)
            z_conv = conv_mu_posterior + torch.sqrt(
                conv_var_posterior) * conv_eps
            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior,
                                          conv_var_posterior).sum()

            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior,
                                         conv_var_prior).sum()
            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                        conv_mu_prior, conv_var_prior).sum()

            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)

            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            sent_mu_posterior, sent_var_posterior = self.sent_posterior(
                context_outputs, encoder_hidden_inference_flat, z_conv_flat)
            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior,
                                          sent_var_posterior).sum()

            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior,
                                         sent_var_prior).sum()
            # kl_div: [num_sentences]
            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                        sent_mu_prior, sent_var_prior).sum()

            kl_div = kl_div_conv + kl_div_sent
            log_q_zx = log_q_zx_conv + log_q_zx_sent
            log_p_z = log_p_z_conv + log_p_z_sent
        else:
            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            # context_outputs: [batch_size, max_len, context_size]
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)
            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, sent_mu_prior,
                                    sent_var_prior).sum()
            log_p_z += normal_logpdf(z_conv, conv_mu_prior,
                                     conv_var_prior).sum()
            log_q_zx = None

        # expand z_conv to all associated sentences
        z_conv = torch.cat([
            z.view(1, -1).expand(m.item(), self.config.z_conv_size)
            for z, m in zip(z_conv, input_conversation_length)
        ])

        # latent_context: [num_sentences, context_size + z_sent_size +
        # z_conv_size]
        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:
            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)
            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)
            return prediction, kl_div, log_p_z, log_q_zx
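Examples #3, #4, and #6 use a different convention from Example #2: here `normal_logpdf` and `normal_kl_div` take a variance rather than a log-variance, since the posterior sample is formed as mu + sqrt(var) * eps. Sketches of both helpers under that assumed convention, reducing over the last (latent) dimension so that `normal_kl_div` yields one value per sentence before the final `.sum()`:

import math
import torch

def normal_logpdf(x, mu, var, eps=1e-8):
    # log N(x; mu, var) for a diagonal Gaussian, summed over the last dim
    return (-0.5 * (math.log(2. * math.pi) + torch.log(var + eps)
                    + (x - mu) ** 2 / (var + eps))).sum(dim=-1)

def normal_kl_div(mu_q, var_q, mu_p, var_p, eps=1e-8):
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ), summed over the last dim
    return (0.5 * (torch.log(var_p + eps) - torch.log(var_q + eps)
                   + (var_q + (mu_q - mu_p) ** 2) / (var_p + eps)
                   - 1.)).sum(dim=-1)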
Example #4
    def forward(self,
                sentences,
                sentence_length,
                input_conversation_length,
                target_sentences,
                decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(
            num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(
            torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                       input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([
            pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
            for s, l in zip(start.data.tolist(),
                            input_conversation_length.data.tolist())
        ], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([
            encoder_hidden_inference[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(
            encoder_hidden_input, input_conversation_length)
        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([
            context_outputs[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        mu_prior, var_prior = self.prior(context_outputs)
        eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))
        if not decode:
            mu_posterior, var_posterior = self.posterior(
                context_outputs, encoder_hidden_inference_flat)
            z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
            log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()

            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            # kl_div: [num_sentences]
            kl_div = normal_kl_div(mu_posterior, var_posterior, mu_prior,
                                   var_prior)
            kl_div = torch.sum(kl_div)
        else:
            z_sent = mu_prior + torch.sqrt(var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            log_q_zx = None

        self.z_sent = z_sent
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:

            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)

            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)

            return prediction, kl_div, log_p_z, log_q_zx
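The conversation packing in Examples #3, #4, and #6 also relies on `pad` and `to_var` utilities. Hedged sketches consistent with how they are called above: `pad(tensor, length)` zero-pads a [steps, hidden] tensor along dim 0, and `to_var` is a relic of the old Variable API that today only needs to handle device placement.

import torch
import torch.nn.functional as F

def pad(tensor, length):
    # Zero-pad a 2-D tensor along dim 0 up to `length` rows
    if tensor.size(0) < length:
        return F.pad(tensor, (0, 0, 0, length - tensor.size(0)))
    return tensor

def to_var(tensor):
    # Originally wrapped in torch.autograd.Variable; on modern PyTorch
    # moving to the right device is all that is still needed.
    return tensor.cuda() if torch.cuda.is_available() else tensor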
Example #5
    def forward(self, x_l_p, y_l, x_u_p):
        """
        x_l_p, x_u_p: pixel intensities normalized to [0, 1]
        y_l: one-hot label
        """
        
        ## LABELLED PART ##

        # Binarize labelled images
        with torch.no_grad():
            x_l = torch.bernoulli(x_l_p)

        # Compute posterior of z:  q(z | x, y)
        qz_l_param = self.phi_z(torch.cat([x_l, y_l], dim=1))
        qz_l_mu = qz_l_param[:, :self.dim_z]
        qz_l_log_sigma_sq = qz_l_param[:, self.dim_z:]

        # Sample from z posterior
        z_l_sample = self._draw_sample(qz_l_mu, qz_l_log_sigma_sq)

        # Compute p(x | y, z)
        px_l_param = self.theta_g(torch.cat([z_l_sample, y_l], dim=1))
        px_l_mu = torch.sigmoid(px_l_param)

        # Compute L_l(x, y) in Eq (2)
        L_l = self._L(x_l, px_l_mu, y_l, z_l_sample, qz_l_mu, qz_l_log_sigma_sq)
        # print("Labelled")
        L_l = L_l.sum()

        # Discriminator term
        ypred_l_discriminative_logits = self.theta_d(x_l)

        D_l = -F.cross_entropy(ypred_l_discriminative_logits, y_l.argmax(1), reduction='none')
        D_l = D_l.sum()

        ## PRIOR PART ##
        
        L_prior = 0.


        # Gaussian prior on the discriminator weights, centred at zero with
        # precision lambda_d
        for i, weight in enumerate(self.theta_d.epsilon_list):
            if 'Linear' in str(weight.type):  # loose check for nn.Linear modules
                alpha = 0.001
                lambda_d = (1 - alpha) / alpha
                flattened_weight = torch.cat([weight.weight.reshape((-1,)), weight.bias.reshape((-1,))], dim=0)
                flattened_mean = torch.zeros(flattened_weight.size()).to(self.device)
                flattened_sigma_sq = (torch.ones(flattened_mean.size()) / lambda_d).to(self.device)
                flattened_log_sigma_sq = torch.log(flattened_sigma_sq)
                L_prior += utils.normal_logpdf(flattened_weight, flattened_mean, flattened_log_sigma_sq).sum()
        
        # Standard-normal priors on the generator and inference-network weights
        for i, weight in enumerate(self.theta_g):
            if 'Linear' in str(weight.type):
                flattened_weight = torch.cat([weight.weight.reshape((-1,)), weight.bias.reshape((-1,))], dim=0)
                L_prior += utils.stdnormal_logpdf(flattened_weight).sum()

        for i, weight in enumerate(self.phi_z):
            if 'Linear' in str(weight.type):
                flattened_weight = torch.cat([weight.weight.reshape((-1,)), weight.bias.reshape((-1,))], dim=0)
                L_prior += utils.stdnormal_logpdf(flattened_weight).sum()

        for i, weight in enumerate(self.phi_y):
            if 'Linear' in str(weight.type):
                flattened_weight = torch.cat([weight.weight.reshape((-1,)), weight.bias.reshape((-1,))], dim=0)
                L_prior += utils.stdnormal_logpdf(flattened_weight).sum()
        
        
        ## UNLABELLED PART ##

        # Binarize unlabelled images
        with torch.no_grad():
            x_u = torch.bernoulli(x_u_p)

        # Estimate logits of q(y | x)
        y_u_logits = self.phi_y(x_u)
        y_u = F.softmax(y_u_logits, dim=1)

        # Compute L_l(x, yhat) for every candidate label (Eq (3));
        # the per-label columns are accumulated in the loop below.

        for label in range(self.dim_y):

            yhat = torch.eye(self.dim_y)[label].unsqueeze(0).to(self.device)
            yhat = yhat.expand(x_u.size(0), -1)

            # Compute posterior of z:  q(z | x, yhat)
            qz_u_param = self.phi_z(torch.cat([x_u, yhat], dim=1))
            qz_u_mu = qz_u_param[:, :self.dim_z]
            qz_u_log_sigma_sq = qz_u_param[:, self.dim_z:]
            
            # Sample from z posterior
            z_u_sample = self._draw_sample(qz_u_mu, qz_u_log_sigma_sq)

            # Compute p(x | yhat, z)
            px_u_param = self.theta_g(torch.cat([z_u_sample, yhat], dim=1))
            px_u_mu = torch.sigmoid(px_u_param)

            # Compute L_l(x, yhat) in Eq (2)
            _L_lhat = self._L(x_u, px_u_mu, yhat, z_u_sample, qz_u_mu, qz_u_log_sigma_sq)
            _L_lhat = _L_lhat.unsqueeze(1)

            if label == 0:
                L_lhat = _L_lhat
            else:
                L_lhat = torch.cat([L_lhat, _L_lhat], dim=1)

        # Compute L_U(x) in Eq (3)
        assert L_lhat.size() == y_u.size()
        L_u = y_u * (L_lhat - torch.log(y_u))
        L_u = L_u.sum(1).sum()


        ## TOTAL L ##

        # Compute L(x, y) in Eq (4); the weight prior is added once per
        # epoch, so the per-batch objective is rescaled by the number of
        # batches before normalizing.
        L_tot = L_l + D_l + L_u
        L_tot = L_tot * self.num_batches + L_prior
        loss = -L_tot / (self.batch_size * self.num_batches)

        return loss
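The bound `self._L(x, px_mu, y, z, qz_mu, qz_log_sigma_sq)` referenced here is Eq (2) of the semi-supervised VAE (Kingma et al., 2014): log p(x|y,z) + log p(y) + log p(z) - log q(z|x,y), per example. A sketch under the helper conventions assumed after Example #2; the signature mirrors the call in this snippet, but the body is an assumption:

import math

def _L(self, x, px_mu, y, z, qz_mu, qz_log_sigma_sq):
    # log p(x | y, z): Bernoulli likelihood, summed over pixels
    log_lik = utils.bernoulli_logpdf(x, px_mu).sum(1)
    # log p(y): uniform categorical prior over self.dim_y classes
    log_prior_y = -math.log(self.dim_y)
    # log p(z): standard-normal prior on the latent code
    log_prior_z = utils.stdnormal_logpdf(z).sum(1)
    # log q(z | x, y): Gaussian posterior, log-variance parameterization
    log_posterior_z = utils.normal_logpdf(z, qz_mu, qz_log_sigma_sq).sum(1)
    return log_lik + log_prior_y + log_prior_z - log_posterior_z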
Example #6
File: vhcr.py Project: NoSyu/VHUCM
    def forward(self,
                utterances,
                utterance_length,
                input_conversation_length,
                target_utterances,
                decode=False):
        """
        Forward pass of the conversation model
        :param utterances: [num_utterances, max_utter_len]
        :param utterance_length: [num_utterances]
        :param input_conversation_length: [batch_size]
        :param target_utterances: [num_utterances, seq_len]
        :param decode: True or False
        :return: decoder_outputs
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = utterances.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        encoder_outputs, encoder_hidden = self.encoder(utterances,
                                                       utterance_length)

        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(
            num_sentences + batch_size, -1)

        start = torch.cumsum(
            torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                       input_conversation_length[:-1] + 1)), 0)
        encoder_hidden = torch.stack([
            pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
            for s, l in zip(start.data.tolist(),
                            input_conversation_length.data.tolist())
        ], 0)

        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat([
            encoder_hidden_inference[i, :l, :]
            for i, l in enumerate(input_conversation_length.data)
        ])

        encoder_hidden_input = encoder_hidden[:, :-1, :]

        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        conv_mu_prior, conv_var_prior = self.conv_prior()

        if not decode:
            if self.config.sentence_drop > 0.0:
                indices = np.where(
                    np.random.rand(max_len) < self.config.sentence_drop)[0]
                if len(indices) > 0:
                    encoder_hidden_input[:, indices, :] = self.unk_sent

            context_inference_outputs, context_inference_hidden = self.context_inference(
                encoder_hidden, input_conversation_length + 1)

            context_inference_hidden = context_inference_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)
            conv_mu_posterior, conv_var_posterior = self.conv_posterior(
                context_inference_hidden)
            z_conv = conv_mu_posterior + torch.sqrt(
                conv_var_posterior) * conv_eps
            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior,
                                          conv_var_posterior).sum()

            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior,
                                         conv_var_prior).sum()
            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                        conv_mu_prior, conv_var_prior).sum()

            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)

            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn(
                (num_sentences, self.config.z_utter_size)))

            sent_mu_posterior, sent_var_posterior = self.sent_posterior(
                context_outputs, encoder_hidden_inference_flat, z_conv_flat)
            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior,
                                          sent_var_posterior).sum()

            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior,
                                         sent_var_prior).sum()
            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                        sent_mu_prior, sent_var_prior).sum()

            kl_div = kl_div_conv + kl_div_sent
            log_q_zx = log_q_zx_conv + log_q_zx_sent
            log_p_z = log_p_z_conv + log_p_z_sent
        else:
            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1,
                                        z_conv.size(1)).expand(
                                            z_conv.size(0), max_len,
                                            z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)
            context_outputs = torch.cat([
                context_outputs[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])

            z_conv_flat = torch.cat([
                z_conv_expand[i, :l, :]
                for i, l in enumerate(input_conversation_length.data)
            ])
            sent_mu_prior, sent_var_prior = self.sent_prior(
                context_outputs, z_conv_flat)
            eps = to_var(torch.randn(
                (num_sentences, self.config.z_utter_size)))

            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, sent_mu_prior,
                                    sent_var_prior).sum()
            log_p_z += normal_logpdf(z_conv, conv_mu_prior,
                                     conv_var_prior).sum()
            log_q_zx = None

        z_conv = torch.cat([
            z.view(1, -1).expand(m.item(), self.config.z_conv_size)
            for z, m in zip(z_conv, input_conversation_length)
        ])

        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1, self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        if not decode:
            decoder_outputs = self.decoder(target_utterances,
                                           init_h=decoder_init,
                                           decode=decode)
            return decoder_outputs, kl_div
        else:
            prediction, final_score, length = self.decoder.beam_decode(
                init_h=decoder_init)
            return prediction, kl_div
Example #7
    # build graph
    z_init, _ = net(x, "net", args.n_levels, args.depth, width=args.width, init=True)
    z, logdet = net(x, "net", args.n_levels, args.depth, width=args.width)

    # train to optimize logp(x)
    if args.finetune == 0:
        print("Training for generation")
        # input reconstructions
        x_recons, _ = net(z, "net", args.n_levels, args.depth, width=args.width, backward=True)
        # samples
        z_samp = [tf.random_normal(tf.shape(_z)) for _z in z]
        x_samp, _ = net(z_samp, "net", args.n_levels, args.depth, width=args.width, backward=True)

        # compute objective logp(x)
        logpz = tf.add_n([tf.reduce_sum(utils.normal_logpdf(z_l, 0., 0.), axis=[1, 2, 3]) for z_l in z])
        logpx = logpz + logdet
        objective = tf.reduce_mean(logpx) - np.log(args.n_bins_x) * np.prod(gs(x)[1:])
        loss = -objective

        # create optimizer
        lr = tf.placeholder(tf.float32, [], name="lr")
        optim = tf.train.AdamOptimizer(lr)
        opt = optim.minimize(loss)

        # summaries and visualizations
        x_recons = utils.postprocess(x_recons, n_bits_x=args.n_bits_x)
        x_samp = utils.postprocess(x_samp, n_bits_x=args.n_bits_x)
        recons_error = tf.reduce_mean(tf.square(tf.to_float(utils.postprocess(x, n_bits_x=args.n_bits_x) - x_recons)))
        tf.summary.image("x_sample", x_samp)
        tf.summary.image("x_recons", x_recons)