Example #1
    def sample(self, mu, logvar, k=1):

        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        eps = torch.FloatTensor(self.B, self.z_size).normal_().cuda()  # [B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mu  # [B,Z]

        logpz = lognormal(z,
                          torch.zeros(self.B, self.z_size).cuda(),
                          torch.zeros(self.B, self.z_size).cuda())  #[B]
        logqz = lognormal(z, mu.detach(), logvar.detach())
        # z: [B,Z], logpz: [B], logqz: [B]
        return z, logpz, logqz
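
The sampler above draws z with the reparameterization trick and scores it under the standard normal prior and under the (detached) approximate posterior. The `lognormal` helper it calls is not shown; a minimal sketch of what it is assumed to compute, namely the log density of a diagonal Gaussian summed over the latent dimension (the project's actual helper may differ):

    import math
    import torch

    def lognormal(z, mean, logvar):
        # Assumed behaviour: log N(z; mean, diag(exp(logvar))), summed over the
        # latent dimension, so [B,Z] inputs give a [B] output.
        return torch.sum(
            -0.5 * (math.log(2 * math.pi) + logvar
                    + (z - mean) ** 2 / torch.exp(logvar)),
            dim=1)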
Example #2
    def sample(self, mu, logvar, k=1):

        B = mu.shape[0]

        eps = torch.FloatTensor(B, self.z_size).normal_().cuda() #[B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mu  #[B,Z]

        logpz = lognormal(z, torch.zeros(B, self.z_size).cuda(), 
                            torch.zeros(B, self.z_size).cuda()) #[B]
        logqz = lognormal(z, mu.detach(), logvar.detach())
        # z: [B,Z], logpz: [B], logqz: [B]
        return z, logpz, logqz
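
This variant is identical except that the batch size is read from `mu` rather than from `self.B`. A hedged usage sketch: `model`, its encoder, and `decoder_log_likelihood` below are placeholders, and the batch mean of logpx + logpz - logqz is the single-sample ELBO estimate that `forward` in the next example assembles in the same way:

    # Hypothetical usage; encoder names follow the examples, decoder_log_likelihood
    # stands in for whatever reconstruction term the model uses.
    mu, logvar = model.inference_net_x(model.image_encoder2(x))  # [B,Z] each
    z, logpz, logqz = model.sample(mu, logvar)                   # [B,Z], [B], [B]
    logpx = decoder_log_likelihood(x, z)                         # [B]
    loss = -torch.mean(logpx + logpz - logqz)                    # negative single-sample ELBO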
Example #3
    def forward(self,
                x=None,
                q=None,
                warmup=1.,
                generate=False,
                inf_type=1,
                dec_type=0):  #, k=1): #, marginf_type=0):
        # x: [B,3,112,112] images
        # q: [B,L] token ids
        # inf_type: 0 = infer from both x and y, 1 = x only, 2 = y only
        # dec_type: 0 = decode both x and y, 1 = x only, 2 = y only

        outputs = {}

        if inf_type in [0, 2] or dec_type in [0, 2]:
            embed = self.encoder_embed(q)

        if inf_type == 0:
            # Infer z jointly from the image and the attributes
            x_enc = self.image_encoder(x)
            y_enc = self.encode_attributes(embed)
            mu, logvar = self.inference_net(x_enc, y_enc)
            z, logpz, logqz = self.sample(mu, logvar)

        elif inf_type == 1:
            # Infer z from the image only
            # if self.joint_inf:
            x_enc = self.image_encoder2(x)
            mu, logvar = self.inference_net_x(x_enc)
            # else:
            #     if dec_type ==0:
            #         x_enc = self.image_encoder(x)
            #         mu, logvar = self.inference_net(x_enc)
            #     else:
            #         x_enc = self.image_encoder2(x)
            #         mu, logvar = self.inference_net_x(x_enc)
            z, logpz, logqz = self.sample(mu, logvar)

        elif inf_type == 2:
            # Infer z from the attributes only
            y_enc = self.encode_attributes2(embed)
            mu, logvar = self.inference_net_y(y_enc)
            if self.flow_int:
                z, logpz, logqz = self.flow.sample(mu, logvar)
            else:
                z, logpz, logqz = self.sample(mu, logvar)

        # z_prior = torch.FloatTensor(self.B, self.z_size).normal_().cuda()
        # loss, acc = self.discrim.discrim_loss(z, z_prior)
        # Adversarial regularizer: want this prediction to be small, since the
        # discriminator assigns 0 to samples from the prior.
        pred = self.discrim.predict(z).mean()

        z_dec = self.z_to_enc(z)

        B = z_dec.shape[0]

        if dec_type == 0:
            # Decode Image
            x_hat = self.image_decoder(z_dec)
            alpha = torch.sigmoid(x_hat)

            beta = Beta(alpha * self.beta_scale,
                        (1. - alpha) * self.beta_scale)
            logpx = beta.log_prob(x)  # [B,3,112,112]
            logpx = torch.sum(logpx.view(B, -1), 1)  # [B]

            # Decode text (teacher forcing against the target tokens q)
            word_preds, logpy = self.text_generator.teacher_force(
                z_dec, embed, q)

            logpx = logpx * self.w_logpx
            logpy = logpy * self.w_logpy

            # Cross-entropy term: log q(z|y) evaluated at the sampled z
            if inf_type == 1:
                embed = self.encoder_embed(q)
            y_enc = self.encode_attributes2(embed)
            mu_y, logvar_y = self.inference_net_y(y_enc)
            if self.flow_int:
                logqzy = self.flow.logprob(z.detach(), mu_y, logvar_y)
                # logqzy = self.flow.logprob(z, mu_y, logvar_y)
            else:
                logqzy = lognormal(z, mu_y, logvar_y)
            logqzy = logqzy * self.w_logqy

            log_ws = logpx + logpy + logpz - logqz  #+ logqzy
            elbo = torch.mean(log_ws)
            # warmed_elbo = torch.mean(logpx + logpy + logqzy - logqz + warmup*( logpz - logqz))
            # warmed_elbo = torch.mean(logpx + logpy + logqzy + warmup*( logpz - logqz))
            warmed_elbo = torch.mean(logpx + logpy + logqzy + warmup * (pred))
            # warmed_elbo = torch.mean(-torch.log(pred))
            # warmed_elbo = pred
            # warmed_elbo = torch.mean(logpz - logqz)

            outputs['logpx'] = torch.mean(logpx)
            outputs['x_recon'] = alpha
            outputs['logpy'] = torch.mean(logpy)
            outputs['logqzy'] = torch.mean(logqzy)

        elif dec_type == 1:
            # Decode the image only
            x_hat = self.image_decoder(z_dec)
            alpha = torch.sigmoid(x_hat)

            beta = Beta(alpha * self.beta_scale,
                        (1. - alpha) * self.beta_scale)
            logpx = beta.log_prob(x)  # [B,3,112,112]
            logpx = torch.sum(logpx.view(B, -1), 1)  # [B]
            logpx = logpx * self.w_logpx

            log_ws = logpx + logpz - logqz

            elbo = torch.mean(log_ws)
            warmed_elbo = torch.mean(logpx + warmup * (logpz - logqz))

            outputs['logpx'] = torch.mean(logpx)
            outputs['x_recon'] = alpha

        elif dec_type == 2:
            # Decode the text only
            word_preds, logpy = self.text_generator.teacher_force(
                z_dec, embed, q)
            logpy = logpy * self.w_logpy

            log_ws = logpy + logpz - logqz
            elbo = torch.mean(log_ws)
            warmed_elbo = torch.mean(logpy + warmup * (logpz - logqz))

            outputs['logpy'] = torch.mean(logpy)

        outputs['welbo'] = warmed_elbo
        outputs['elbo'] = elbo
        outputs['logws'] = log_ws
        outputs['z'] = z
        outputs['logpz'] = torch.mean(logpz)
        outputs['logqz'] = torch.mean(logqz)
        outputs['logvar'] = logvar

        if generate:

            word_preds, sampled_words = self.text_generator.teacher_force(
                z_dec, generate=generate, embeder=self.encoder_embed)
            if dec_type == 2:
                alpha = torch.sigmoid(self.image_decoder(z_dec))

            return outputs, alpha, word_preds, sampled_words

        return outputs
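
Both image-decoding branches score the reconstruction with a Beta likelihood: the sigmoid of the decoder output gives a per-pixel mean and `self.beta_scale` acts as a concentration. A standalone sketch of that step, assuming the pixel values in `x` lie strictly inside (0, 1) so that `Beta.log_prob` stays finite:

    import torch
    from torch.distributions import Beta

    def beta_recon_logprob(x, x_hat, beta_scale):
        # x:     [B,3,H,W] target image, values strictly in (0, 1)
        # x_hat: [B,3,H,W] raw decoder output (pre-sigmoid)
        alpha = torch.sigmoid(x_hat)                     # per-pixel mean in (0, 1)
        beta = Beta(alpha * beta_scale, (1. - alpha) * beta_scale)
        logpx = beta.log_prob(x)                         # [B,3,H,W]
        return torch.sum(logpx.view(x.shape[0], -1), 1)  # [B], one value per image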
Example #4
    def forward_qy2(self,
                    qy,
                    x=None,
                    q=None,
                    warmup=1.,
                    generate=False,
                    inf_type=1,
                    dec_type=0):  #, k=1): #, marginf_type=0):
        # x: [B,3,112,112] images
        # q: [B,L] token ids
        # inf_type: 0 = infer from both x and y, 1 = x only, 2 = y only
        # dec_type: 0 = decode both x and y, 1 = x only, 2 = y only

        outputs = {}

        # Y inference
        embed = self.encoder_embed(q)
        y_enc = qy.encode_attributes2_2(embed)
        mu_y, logvar_y = qy.inference_net_y_2(y_enc)
        z_y, logpz_y, logqz_y = qy.flow_cond_2.sample(mu_y,
                                                      logvar_y)  #, y_enc)

        # X inference
        x_enc = self.image_encoder2(x)
        mu_x, logvar_x = self.inference_net_x(x_enc)
        z_x, logpz_x, logqz_x = self.sample(mu_x, logvar_x)

        # Log-probability of z_x under q(z|y)
        logqzy = qy.flow_cond_2.logprob(z_x.detach(), mu_y,
                                        logvar_y)  #, y_enc)

        # Recon y using y inference
        z_dec = self.z_to_enc(z_y)
        word_preds, logpy = self.text_generator.teacher_force(z_dec, embed,
                                                              q)  #.eval()
        logpy = logpy * self.w_logpy

        # Recon x using x inference
        alpha = torch.sigmoid(self.image_decoder(self.z_to_enc(z_x)))
        beta = Beta(alpha * self.beta_scale, (1. - alpha) * self.beta_scale)
        logpx = beta.log_prob(x)  # [B,3,112,112]
        B = x.shape[0]
        logpx = torch.sum(logpx.view(B, -1), 1)  # [B]
        logpx = logpx * self.w_logpx

        # Score z_x under the standard normal prior
        logpz = lognormal(z_x.detach(),
                          torch.zeros(B, self.z_size).cuda(),
                          torch.zeros(B, self.z_size).cuda())

        outputs['logqzy'] = torch.mean(logqzy)
        outputs['logpz_x_under_pz'] = torch.mean(logpz)
        outputs['welbo'] = torch.mean(logpy + logpz_y - logqz_y)
        outputs['logpx_give_y'] = torch.mean(logpx + logqzy - logqz_x)

        if generate:

            z, logpz, logqz = qy.flow_cond_2.sample(mu_y, logvar_y)  #, y_enc)
            z_dec = self.z_to_enc(z)

            word_preds, sampled_words = self.text_generator.teacher_force(
                z_dec, generate=generate, embeder=self.encoder_embed)
            alpha = torch.sigmoid(self.image_decoder(z_dec))

            return outputs, alpha, word_preds, sampled_words

        return outputs
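
A hedged training-step sketch tying the examples together; the data loader, optimizer, and warm-up schedule are placeholders, and the loss simply negates the warmed ELBO that `forward` reports under outputs['welbo']:

    # Hypothetical training loop; only forward() and the outputs keys come from
    # the examples above, everything else is a placeholder.
    for step, (x, q) in enumerate(loader):   # x: [B,3,112,112], q: [B,L]
        x, q = x.cuda(), q.cuda()
        warmup = min(1., step / 10000.)      # assumed linear warm-up schedule
        outputs = model.forward(x=x, q=q, warmup=warmup, inf_type=1, dec_type=0)
        loss = -outputs['welbo']             # maximize the warmed ELBO
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()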