Code example #1
 def forward(self, x):
     # Encode x, then parameterize a diagonal Gaussian q(z|x).
     h = self.encoder(x)
     mu = self.dense_mu(h)
     # softplus keeps the variance positive; the small constant avoids
     # a degenerate (zero) variance.
     var = F.softplus(self.dense_var(h)) + 1e-5
     # Reparameterization trick: z = mu + eps * sigma with eps ~ N(0, I).
     z = mu + torch.randn(mu.shape).cuda() * var.sqrt()
     log_q_z_given_x = log_prob_normal(z, mu, var)
     return z, mu, var, log_q_z_given_x
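
These snippets all lean on a log_prob_normal helper that is not shown. A minimal sketch of what it presumably does, assuming a diagonal Gaussian log-density reduced over the last dimension, with the variance defaulting to 1 so that log_prob_normal(z, 0) in code example #3 evaluates a standard-normal prior (both the signature and the default are assumptions, not the project's actual definition):

import math
import torch

def log_prob_normal(z, mu, var=1.0):
    # log N(z; mu, diag(var)), summed over the event (last) dimension.
    # keepdim=True matches the squeeze_(dim=-1) calls in code example #4.
    if not torch.is_tensor(var):
        var = torch.full_like(z, var)
    log_density = -0.5 * (math.log(2 * math.pi) + var.log()
                          + (z - mu) ** 2 / var)
    return log_density.sum(dim=-1, keepdim=True)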
Code example #2
File: models.py  Project: drproduck/codebase
 def log_posterior(self, z, mu, var):
     # Evaluate log q(z|x) at z for the given posterior parameters.
     return log_prob_normal(z, mu, var)
Code example #3
File: models.py  Project: drproduck/codebase
 def log_prior(self, z):
     # Standard-normal prior; log_prob_normal presumably defaults var to 1.
     return log_prob_normal(z, 0)
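
Code example #4 below fuses the two unimodal posteriors q(z|x) and q(z|y) with product_of_diag_normals. A minimal sketch of that step, assuming it is the standard product of diagonal Gaussians, where precisions add and the mean is the precision-weighted average of the component means:

def product_of_diag_normals(mus, vars_):
    # Product of diagonal Gaussians N(mu_i, var_i):
    #   1/var = sum_i 1/var_i,  mu = var * sum_i mu_i / var_i
    precisions = [1.0 / v for v in vars_]
    var = 1.0 / sum(precisions)
    mu = var * sum(m * p for m, p in zip(mus, precisions))
    return mu, var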
Code example #4
    def forward(
        self,
        x,
        y,
        n_samples_z=1,
        n_samples_v=1,
        inner_method='iwae',
        outer_method='elbo',
    ):
        """
        """

        # For n_samples_z particles of z, repeat x and y n_samples_z times.
        x = repeat_newdim(x, n_samples_z, -2)
        y = repeat_newdim(y, n_samples_z, -2)

        _, mu_z_given_x, var_z_given_x, log_posterior_z_given_x = self.z_given_x(
            x)
        _, mu_z_given_y, var_z_given_y, log_posterior_z_given_y = self.z_given_y(
            y)
        mu_z_given_xy, var_z_given_xy = product_of_diag_normals(
            [mu_z_given_x, mu_z_given_y], [var_z_given_x, var_z_given_y])

        # Sample z from the product posterior via the reparameterization trick.
        z_outer = mu_z_given_xy + torch.randn(
            mu_z_given_xy.shape).cuda() * var_z_given_xy.sqrt()
        log_posterior_z_given_xy = log_prob_normal(z_outer, mu_z_given_xy,
                                                   var_z_given_xy)

        # For n_samples_v particles of v, repeat z_outer n_samples_v times,
        # and x, y another n_samples_v times.
        # TODO: very awkward, is there a better way?
        z_inner = repeat_newdim(z_outer, n_samples_v, -2)
        x = repeat_newdim(x, n_samples_v, -2)
        y = repeat_newdim(y, n_samples_v, -2)

        # Thus each (x, y) pair gets n_samples_z * n_samples_v particles of v.
        # z_outer and z_inner are kept separate so that log_q(z_outer|x,y)
        # can be used directly in the outer bound.

        x_reduced = self.x_encoder(x)
        y_reduced = self.y_encoder(y)

        xyz_reduced = torch.cat((x_reduced, y_reduced, z_inner), dim=-1)
        v, mu, var, log_posterior_v_given_xyz = self.v_given_xyz(xyz_reduced)
        log_prior_z = log_normal_prior(z_inner)
        log_prior_v = log_normal_prior(v)

        zv = torch.cat((z_inner, v), dim=-1)
        x_recon = self.x_given_zv(zv)
        y_recon = self.y_given_z(z_inner)

        img_log_likelihood_x = log_prob_bernoulli(x_recon, x)
        img_log_likelihood_y = log_prob_bernoulli(y_recon, y)

        # For elbo, average over particles.
        # For iwae, sum over self-normalized importance-weighted particles.

        # Inner bound: proposal q(v|x,y,z) against the joint
        # p(x,y,z,v) = p(y|z) p(x|z,v) p(z) p(v).

        # Drop the trailing singleton dim kept by the log-prob helpers.
        img_log_likelihood_x.squeeze_(dim=-1)
        img_log_likelihood_y.squeeze_(dim=-1)
        log_prior_z.squeeze_(dim=-1)
        log_prior_v.squeeze_(dim=-1)
        log_posterior_v_given_xyz.squeeze_(dim=-1)
        log_posterior_z_given_xy.squeeze_(dim=-1)

        inner_lowerbound = (img_log_likelihood_x + img_log_likelihood_y +
                            log_prior_z + log_prior_v -
                            log_posterior_v_given_xyz)
        if inner_method == 'iwae':
            inner_lowerbound = get_importance_bound(inner_lowerbound)
            inner_lowerbound = inner_lowerbound.sum(dim=-1, keepdim=True)
        elif inner_method == 'elbo':
            inner_lowerbound = inner_lowerbound.mean(dim=-1, keepdim=True)
        inner_lowerbound.squeeze_(-1)

        outer_lowerbound = inner_lowerbound - log_posterior_z_given_xy
        if outer_method == 'iwae':
            outer_lowerbound = get_importance_bound(outer_lowerbound)
            outer_lowerbound = outer_lowerbound.sum(dim=-1, keepdim=True)
        elif outer_method == 'elbo':
            outer_lowerbound = outer_lowerbound.mean(dim=-1, keepdim=True)
        outer_lowerbound.squeeze_(-1)

        return outer_lowerbound
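
Two more helpers used above are not shown: repeat_newdim and get_importance_bound (log_normal_prior presumably just evaluates the standard-normal prior, i.e. log_prob_normal(z, 0)). Minimal sketches, assuming repeat_newdim tiles along a newly inserted axis and get_importance_bound returns the usual self-normalized IWAE surrogate; these are plausible readings consistent with how forward() calls them, not the project's actual definitions:

import torch
import torch.nn.functional as F

def repeat_newdim(x, k, dim):
    # Insert a new axis at `dim` and repeat the tensor k times along it,
    # e.g. repeat_newdim(x, k, -2) maps shape (..., D) to (..., k, D).
    x = x.unsqueeze(dim)
    sizes = [1] * x.dim()
    sizes[dim] = k
    return x.repeat(*sizes)

def get_importance_bound(log_w):
    # Self-normalized importance weights, detached so gradients flow only
    # through log_w. Summing the result over the particle dim (as forward()
    # does) gives a surrogate whose gradient matches the importance-weighted
    # (IWAE) bound.
    w_tilde = F.softmax(log_w, dim=-1).detach()
    return w_tilde * log_w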