Example #1
    def forward(
        self,
        x,
        y,
        z_given_x,
        n_samples=1,
        block_grad=False,
    ):
        """
        z_given_x inference network is trained separately with elbo objective
        """

        # for k particles of z, repeat x k times
        x = repeat_newdim(x, n_samples, -2)
        y = repeat_newdim(y, n_samples, -2)
        z, _, _, log_posterior_z_given_x = z_given_x(x)

        x_reduced = self.x_encoder(x)
        z_reduced_xz = self.z_encoder_xz(z)
        xz_reduced = torch.cat((x_reduced, z_reduced_xz), dim=-1)
        v, _, _, log_posterior_v_given_xz = self.v_given_xz(xz_reduced)

        log_prior_z = log_normal_prior(z)
        log_prior_v = log_normal_prior(v)

        z_reduced_zv = self.z_encoder_zv(z)
        v_reduced = self.v_encoder(v)
        zv_reduced = torch.cat((z_reduced_zv, v_reduced), dim=-1)
        x_recon = self.x_given_zv(zv_reduced)
        y_recon = self.y_given_z(z)

        img_log_likelihood_x = log_prob_bernoulli(x_recon, x)
        img_log_likelihood_y = log_prob_bernoulli(y_recon, y)

        # for elbo, average the log-weights over particles
        # for iwae, take log-mean-exp over particles

        img_log_likelihood_x.squeeze_(dim=-1)
        img_log_likelihood_y.squeeze_(dim=-1)
        log_prior_z.squeeze_(dim=-1)
        log_prior_v.squeeze_(dim=-1)
        log_posterior_v_given_xz.squeeze_(dim=-1)
        log_posterior_z_given_x.squeeze_(dim=-1)

        w = img_log_likelihood_x + img_log_likelihood_y + log_prior_z + log_prior_v \
            - log_posterior_z_given_x - log_posterior_v_given_xz

        # elbo: average the log-weights; iwae: log-mean-exp of the weights.
        # both reductions are well-defined for a single particle, so no shape
        # guard is needed (a guard would leave elbo/iw_elbo undefined when
        # n_samples == 1)
        elbo = w.mean(dim=-1, keepdim=True)
        iw_elbo = log_mean_exp(w, dim=-1, detach=True)

        elbo.squeeze_(-1)
        iw_elbo.squeeze_(-1)

        return iw_elbo, elbo
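
The helpers used throughout these examples (`repeat_newdim`, `log_mean_exp`, `log_normal_prior`, `log_prob_bernoulli`) are not shown. A minimal sketch of plausible implementations, assuming a standard-normal prior, Bernoulli likelihoods over decoder logits, and reductions that keep the last dim for the `squeeze_(dim=-1)` calls above:

import math

import torch
import torch.nn.functional as F


def repeat_newdim(x, n, dim):
    # insert a new axis at `dim` and repeat x n times along it,
    # e.g. (B, D) -> (B, n, D) for dim=-2
    x = x.unsqueeze(dim)
    sizes = [1] * x.dim()
    sizes[dim] = n
    return x.repeat(*sizes)


def log_mean_exp(w, dim=-1, detach=False):
    # numerically stable log(mean(exp(w))) along `dim`, keeping the dim;
    # detach=True is read here as detaching the max-shift (an assumption)
    m = w.max(dim=dim, keepdim=True)[0]
    if detach:
        m = m.detach()
    return m + (w - m).exp().mean(dim=dim, keepdim=True).log()


def log_normal_prior(z):
    # log N(z; 0, I), summed over the feature dim
    return (-0.5 * (z ** 2 + math.log(2 * math.pi))).sum(dim=-1, keepdim=True)


def log_prob_bernoulli(logits, target):
    # Bernoulli log-likelihood summed over the feature dim; decoder outputs
    # are assumed to be logits (cf. output_act=None in FullELBO below)
    return -F.binary_cross_entropy_with_logits(
        logits, target, reduction='none').sum(dim=-1, keepdim=True)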
Example #2
    def forward(
        self,
        x,
        y,
        n_samples_z=1,
        n_samples_v=1,
        inner_method='iwae',
        outer_method='elbo',
        block_grad=False,
    ):
        """
        For Full training Method (ELBO, IWAE), n_samples_v should be 1, and inner_method doesnt matter. outer_method decides the type of bound

        For repeating random samples, I repeat the input data. I don't know if this is faster
        However, I need to repeat input data to compute likelihood regardless

        What to repeat in the inner bound: v,z and x
        """

        # for k particles of z, repeat x k times
        x_outer = repeat_newdim(x, n_samples_z, -2)
        z_outer, _, _, log_posterior_z_given_x = self.z_given_x(x_outer)

        # for n_samples_v particles of v, repeat z_outer and x_outer
        # n_samples_v times; y only needs n_samples_z repeats since it is
        # scored against z_outer
        # TODO: very awkward, better way?
        y_outer = repeat_newdim(y, n_samples_z, -2)
        x_inner = repeat_newdim(x_outer, n_samples_v, -2)
        z_inner = repeat_newdim(z_outer, n_samples_v, -2)

        ##############################
        # tentative
        if block_grad:
            z_inner = z_inner.detach()
            assert not z_inner.requires_grad

        # thus, for each (x,y) pair, there are {n_samples_z * n_samples_v} v particles
        # z_outer and z_inner exist separately so that log q(z_outer|x) in the
        # outer sum can be used directly
        x_reduced = self.x_encoder(x_inner)
        z_reduced_xz = self.z_encoder_xz(z_inner)
        xz_reduced = torch.cat((x_reduced, z_reduced_xz), dim=-1)
        v_inner, _, _, log_posterior_v_given_xz = self.v_given_xz(xz_reduced)

        log_prior_z = log_normal_prior(z_outer)
        log_prior_v = log_normal_prior(v_inner)

        z_reduced_zv = self.z_encoder_zv(z_inner)
        v_reduced = self.v_encoder(v_inner)
        zv_reduced = torch.cat((z_reduced_zv, v_reduced), dim=-1)
        x_recon = self.x_given_zv(zv_reduced)
        y_recon = self.y_given_z(z_outer)

        img_log_likelihood_x = log_prob_bernoulli(x_recon, x_inner)
        img_log_likelihood_y = log_prob_bernoulli(y_recon, y_outer)

        # for elbo, average the log-weights over particles
        # for iwae, take log-mean-exp over particles

        # inner bound: iwae between q(v|x,z) and p(x,v|z) = p(x|z,v)p(v)
        # joint factorization: p(x,y,z,v) = p(y|z)p(x|z,v)p(z)p(v)

        img_log_likelihood_x.squeeze_(dim=-1)
        img_log_likelihood_y.squeeze_(dim=-1)
        log_prior_z.squeeze_(dim=-1)
        log_prior_v.squeeze_(dim=-1)
        log_posterior_v_given_xz.squeeze_(dim=-1)
        log_posterior_z_given_x.squeeze_(dim=-1)

        inner_lowerbound = img_log_likelihood_x + log_prior_v - log_posterior_v_given_xz

        if inner_method == 'hybrid' and outer_method == 'hybrid':
            assert n_samples_v == n_samples_z
            w = img_log_likelihood_x + img_log_likelihood_y + log_prior_z + log_prior_v \
                - log_posterior_z_given_x - log_posterior_v_given_xz

            # both reductions are well-defined even for a single particle
            elbo = w.mean(dim=-1, keepdim=True)
            iw_elbo = log_mean_exp(w, dim=-1, detach=True)

            elbo.squeeze_(-1)
            iw_elbo.squeeze_(-1)
            # return inside the branch so the nested-bound path below is reachable
            return iw_elbo, elbo

        if inner_lowerbound.shape[-1] != 1:
            if inner_method == 'iwae':
                inner_lowerbound = log_mean_exp(inner_lowerbound,
                                                dim=-1,
                                                detach=True)
                inner_lowerbound = inner_lowerbound.sum(dim=-1, keepdim=True)
            elif inner_method == 'elbo':
                inner_lowerbound = inner_lowerbound.mean(dim=-1, keepdim=True)
        inner_lowerbound.squeeze_(-1)

        log_p_x_lowerbound = inner_lowerbound + log_prior_z - log_posterior_z_given_x
        outer_lowerbound = log_p_x_lowerbound + img_log_likelihood_y
        # outer_lowerbound = inner_lowerbound - log_posterior_z_given_x
        if outer_lowerbound.shape[-1] != 1:
            if outer_method == 'iwae':
                outer_lowerbound = log_mean_exp(outer_lowerbound,
                                                dim=-1,
                                                detach=True)
                outer_lowerbound = outer_lowerbound.sum(dim=-1, keepdim=True)
            elif outer_method == 'elbo':
                outer_lowerbound = outer_lowerbound.mean(dim=-1, keepdim=True)
        outer_lowerbound.squeeze_(-1)

        return outer_lowerbound
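
For reference, the nested bound computed by the non-hybrid path above can be written as follows (read off from the code; the inner reduction runs over the $v$ particles, the outer one over the $z$ particles):

$$
\log p(x,y) \;\ge\; \mathbb{E}_{q(z\mid x)}\!\left[\log p(y\mid z) + \log p(z) - \log q(z\mid x) + \mathcal{L}_v(x,z)\right],
$$
$$
\mathcal{L}_v(x,z) \;=\; \mathbb{E}_{v_{1:K}\sim q(v\mid x,z)}\!\left[\log \frac{1}{K}\sum_{k=1}^{K} \frac{p(x\mid z,v_k)\,p(v_k)}{q(v_k\mid x,z)}\right] \;\le\; \log p(x\mid z).
$$

Replacing a log-mean-exp with a plain average of the log-weights recovers the corresponding ELBO variant, matching the inner_method/outer_method switches.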
Example #3
    def forward(
        self,
        x,
        y,
        n_samples_z=1,
        n_samples_v=1,
        inner_method='iwae',
        outer_method='elbo',
        block_grad=False,
    ):
        """
        For Full training Method (ELBO, IWAE), n_samples_v should be 1, and inner_method doesnt matter. outer_method decides the type of bound

        For repeating random samples, I repeat the input data. I don't know if this is faster
        However, I need to repeat input data to compute likelihood regardless

        What to repeat in the inner bound: v,z and x
        """

        # for k particles of z, repeat x k times
        x_outer = repeat_newdim(x, n_samples_z, -2)
        z_outer, _, _, log_posterior_z_given_x = self.z_given_x(x_outer)

        # for n_samples_v particles of v, repeat z_outer and x_outer
        # n_samples_v times; y only needs n_samples_z repeats since it is
        # scored against z_outer
        # TODO: very awkward, better way?
        y_outer = repeat_newdim(y, n_samples_z, -2)
        x_inner = repeat_newdim(x_outer, n_samples_v, -2)
        z_inner = repeat_newdim(z_outer, n_samples_v, -2)

        ##############################
        # tentative
        if block_grad:
            z_inner = z_inner.detach()
            assert not z_inner.requires_grad

        # thus, for each (x,y) pair, there are {n_samples_z * n_samples_v} v particles
        # z_outer and z_inner exist separately so that log q(z_outer|x) in the
        # outer sum can be used directly
        x_reduced = self.x_encoder(x_inner)
        z_reduced_xz = self.z_encoder_xz(z_inner)
        xz_reduced = torch.cat((x_reduced, z_reduced_xz), dim=-1)
        v_inner, _, _, log_posterior_v_given_xz = self.v_given_xz(xz_reduced)

        log_prior_z = log_normal_prior(z_outer)
        log_prior_v = log_normal_prior(v_inner)

        z_reduced_zv = self.z_encoder_zv(z_inner)
        v_reduced = self.v_encoder(v_inner)
        zv_reduced = torch.cat((z_reduced_zv, v_reduced), dim=-1)
        x_recon = self.x_given_zv(zv_reduced)
        y_recon = self.y_given_z(z_outer)

        img_log_likelihood_x = log_prob_bernoulli(x_recon, x_inner)
        img_log_likelihood_y = log_prob_bernoulli(y_recon, y_outer)

        # for elbo, average the log-weights over particles
        # for iwae, take log-mean-exp over particles

        # inner bound: iwae between q(v|x,z) and p(x,v|z) = p(x|z,v)p(v)
        # joint factorization: p(x,y,z,v) = p(y|z)p(x|z,v)p(z)p(v)

        img_log_likelihood_x.squeeze_(dim=-1)
        img_log_likelihood_y.squeeze_(dim=-1)
        log_prior_z.squeeze_(dim=-1)
        log_prior_v.squeeze_(dim=-1)
        log_posterior_v_given_xz.squeeze_(dim=-1)
        log_posterior_z_given_x.squeeze_(dim=-1)

        inner_lowerbound = img_log_likelihood_x + log_prior_v - log_posterior_v_given_xz
        # inner_lowerbound = img_log_likelihood_x + self.reg * img_log_likelihood_y + log_prior_z + log_prior_v - log_posterior_v_given_xz
        if inner_lowerbound.shape[-1] != 1:
            if inner_method == 'iwae':
                inner_lowerbound = log_mean_exp(inner_lowerbound,
                                                dim=-1,
                                                detach=True)
                inner_lowerbound = inner_lowerbound.sum(dim=-1, keepdim=True)
            elif inner_method == 'elbo':
                inner_lowerbound = inner_lowerbound.mean(dim=-1, keepdim=True)
        inner_lowerbound.squeeze_(-1)

        log_p_x_lowerbound = inner_lowerbound + log_prior_z - log_posterior_z_given_x
        outer_lowerbound = log_p_x_lowerbound + img_log_likelihood_y
        # outer_lowerbound = inner_lowerbound - log_posterior_z_given_x
        if outer_lowerbound.shape[-1] != 1:
            if outer_method == 'iwae':
                outer_lowerbound = log_mean_exp(outer_lowerbound,
                                                dim=-1,
                                                detach=True)
                outer_lowerbound = outer_lowerbound.sum(dim=-1, keepdim=True)
            elif outer_method == 'elbo':
                outer_lowerbound = outer_lowerbound.mean(dim=-1, keepdim=True)
        outer_lowerbound.squeeze_(-1)

        return outer_lowerbound
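
A usage sketch for the nested estimator above; `model`, the shapes, and the particle counts are illustrative, not from the original code (see the commented-out FullELBO below for one plausible module layout):

import torch

# hypothetical batch: 64 binarized 28x28 images and 10-dim binary labels
x = torch.rand(64, 784).bernoulli()
y = torch.rand(64, 10).bernoulli()

# nested bound: IWAE over 5 v-particles inside, ELBO over 5 z-particles outside
bound = model(x, y, n_samples_z=5, n_samples_v=5,
              inner_method='iwae', outer_method='elbo')
loss = -bound.mean()  # maximize the lower bound
loss.backward()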


# class FullELBO(nn.Module):
#     def __str__(self):
#         return "Full_ELBO"

#     def __init__(self, X_DIM, Y_DIM, Z_DIM, V_DIM,
#                 ):
#         """
#         #TODO: need to move batch enlargement in here.

#         :arg reg: increase weight of \log p(y|z). This option is tentative
#         """
#         super().__init__()
#         self.X_DIM = X_DIM
#         self.Y_DIM = Y_DIM
#         self.Z_DIM = Z_DIM
#         self.V_DIM = V_DIM
#         self.x_encoder = nets.LinearMap(X_DIM, [200, 200], 50)
#         self.z_encoder_xz = nets.LinearMap(Z_DIM, [200, 200], 50)
#         self.z_encoder_zv = nets.LinearMap(Z_DIM, [200, 200], 50)
#         self.v_encoder = nets.LinearMap(V_DIM, [200, 200], 50)

#         # inference networks
#         self.z_given_x = nets.GaussianStochasticEncoder(X_DIM, Z_DIM, hidden_dims=[200, 200], hidden_acts=nn.ReLU())

#         xz_reduced_dim = self.x_encoder.get_output_dim() + self.z_encoder_xz.get_output_dim()
#         self.v_given_xz = nets.GaussianStochasticEncoder(xz_reduced_dim, V_DIM, hidden_dims=[200, 200], hidden_acts=nn.ReLU())

#         # generative networks
#         self.y_given_z = nets.StochasticDecoder(Z_DIM, Y_DIM, hidden_dims=[200, 200], hidden_acts=nn.ReLU(), output_act=None)

#         zv_reduced_dim = self.z_encoder_zv.get_output_dim() + self.v_encoder.get_output_dim()
#         self.x_given_zv = nets.StochasticDecoder(zv_reduced_dim, X_DIM, hidden_dims=[200, 200], hidden_acts=nn.ReLU(), output_act=None)

#     def forward(self,
#                 x,
#                 y,
#                 n_samples=1,
#                 method='elbo',
#                 ):
#         """
#         first estimator: $VAE\left(q(z|x), IWAE(q(v|x,y,z),p(x,y,z,v))\right)$

#         To draw repeated random samples, the input data is repeated. It is
#         unclear whether this is faster, but the input must be repeated to
#         compute the likelihood regardless.

#         What to repeat for the inner bound: v, z, and x.
#         """

#         # for k particles of z, repeat x k times
#         x = repeat_newdim(x, n_samples, -2)
#         y = repeat_newdim(y, n_samples, -2)
#         z, _, _, log_posterior_z_given_x = self.z_given_x(x)

#         # sample v from q(v|z,x)
#         x_reduced = self.x_encoder(x)
#         z_reduced_xz = self.z_encoder_xz(z)
#         xz_reduced = torch.cat((x_reduced, z_reduced_xz), dim=-1)
#         v, _, _, log_posterior_v_given_xz = self.v_given_xz(xz_reduced)

#         log_prior_z = log_normal_prior(z)
#         log_prior_v = log_normal_prior(v)

#         z_reduced_zv = self.z_encoder_zv(z)
#         v_reduced = self.v_encoder(v)
#         zv_reduced = torch.cat((z_reduced_zv, v_reduced), dim=-1)
#         x_recon = self.x_given_zv(zv_reduced)
#         y_recon = self.y_given_z(z)

#         img_log_likelihood_x = log_prob_bernoulli(x_recon, x)
#         img_log_likelihood_y = log_prob_bernoulli(y_recon, y)

#         # for elbo, average the log-weights over particles
#         # for iwae, take log-mean-exp over particles

#         # compute the iwae bound between q(v|x,z) and p(x,y,z,v), where
#         # p(x,y,z,v) = p(y|z)p(x|z,v)p(z)p(v)

#         img_log_likelihood_x.squeeze_(dim=-1)
#         img_log_likelihood_y.squeeze_(dim=-1)
#         log_prior_z.squeeze_(dim=-1)
#         log_prior_v.squeeze_(dim=-1)
#         log_posterior_v_given_xz.squeeze_(dim=-1)
#         log_posterior_z_given_x.squeeze_(dim=-1)

#         lowerbound = img_log_likelihood_x + img_log_likelihood_y + log_prior_z + log_prior_v - log_posterior_z_given_x - log_posterior_v_given_xz

#         if method == 'iwae':

#             lowerbound = get_importance_bound(lowerbound)
#             lowerbound = lowerbound.sum(dim=-1, keepdim=True)

#         if method == 'elbo':
#             lowerbound = lowerbound.mean(dim=-1, keepdim=True)

#         lowerbound.squeeze_(dim=-1)
#         return lowerbound
Example #4
    def forward(
        self,
        x,
        y,
        n_samples_z=1,
        n_samples_v=1,
        inner_method='iwae',
        outer_method='elbo',
    ):
        """
        first estimator: $VAE\left(q(z|x), IWAE(q(v|x,y,z),p(x,y,z,v))\right)$


        For repeating random samples, I repeat the input data. I don't know if this is faster
        However, I need to repeat input data to compute likelihood regardless
        """

        # for k particles of z, repeat x k times
        x = repeat_newdim(x, n_samples_z, -2)

        z_outer, mu_z_given_x, var_z_given_x, log_posterior_z_given_x = self.z_given_x(
            x)

        # for n_samples_v particles of v, repeat z_outer and x n_samples_v
        # times, and y n_samples_z * n_samples_v times
        # TODO: very awkward, better way?
        x = repeat_newdim(x, n_samples_v, -2)
        z_inner = repeat_newdim(z_outer, n_samples_v, -2)
        y = repeat_newdim(y, n_samples_z, -2)
        y = repeat_newdim(y, n_samples_v, -2)

        # thus, for each (x,y) pair, there are {n_samples_z * n_samples_v} v particles
        # z_outer and z_inner exist separately so that log q(z_outer|x) in the
        # outer sum can be used directly
        x_reduced = self.x_encoder(x)
        y_reduced = self.y_encoder(y)
        z_reduced = self.z_encoder(z_inner)
        xyz_reduced = torch.cat((x_reduced, y_reduced, z_reduced), dim=-1)
        v, mu, var, log_posterior_v_given_xyz = self.v_given_xyz(xyz_reduced)
        log_prior_z = log_normal_prior(z_inner)
        log_prior_v = log_normal_prior(v)

        v_reduced = self.v_encoder(v)
        zv_reduced = torch.cat((z_reduced, v_reduced), dim=-1)
        x_recon = self.x_given_zv(zv_reduced)
        y_recon = self.y_given_z(z_inner)

        img_log_likelihood_x = log_prob_bernoulli(x_recon, x)
        img_log_likelihood_y = log_prob_bernoulli(y_recon, y)

        # for elbo, average the log-weights over particles
        # for iwae, take log-mean-exp over particles

        # compute the iwae bound between q(v|x,y,z) and p(x,y,z,v), where
        # p(x,y,z,v) = p(y|z)p(x|z,v)p(z)p(v)

        img_log_likelihood_x.squeeze_(dim=-1)
        img_log_likelihood_y.squeeze_(dim=-1)
        log_prior_z.squeeze_(dim=-1)
        log_prior_v.squeeze_(dim=-1)
        log_posterior_v_given_xyz.squeeze_(dim=-1)
        log_posterior_z_given_x.squeeze_(dim=-1)

        inner_lowerbound = img_log_likelihood_x + self.reg * img_log_likelihood_y + log_prior_z + log_prior_v - log_posterior_v_given_xyz
        if inner_method == 'iwae':
            inner_lowerbound = get_importance_bound(inner_lowerbound)
            inner_lowerbound = inner_lowerbound.sum(dim=-1, keepdim=True)
        elif inner_method == 'elbo':
            inner_lowerbound = inner_lowerbound.mean(dim=-1, keepdim=True)
        inner_lowerbound.squeeze_(-1)

        outer_lowerbound = inner_lowerbound - log_posterior_z_given_x
        if outer_method == 'iwae':
            outer_lowerbound = get_importance_bound(outer_lowerbound)
            outer_lowerbound = outer_lowerbound.sum(dim=-1, keepdim=True)
        elif outer_method == 'elbo':
            outer_lowerbound = outer_lowerbound.mean(dim=-1, keepdim=True)
        outer_lowerbound.squeeze_(-1)

        return outer_lowerbound
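
`get_importance_bound` is also not shown. A plausible sketch, assuming it returns the IWAE log-mean-exp value with the standard self-normalized gradient estimator (the detach trick here is an assumption, not necessarily the original helper):

import math

import torch


def get_importance_bound(log_w, dim=-1):
    # value: log(mean(exp(log_w))) over the particle dim, keeping the dim
    lme = torch.logsumexp(log_w, dim=dim, keepdim=True) - math.log(log_w.shape[dim])
    # gradient: detached self-normalized importance weights times d(log_w)
    surrogate = (torch.softmax(log_w, dim=dim).detach() * log_w).sum(
        dim=dim, keepdim=True)
    # forward value equals lme; backward follows the surrogate
    return surrogate + (lme - surrogate).detach()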
Example #5
    def forward(
        self,
        x,
        y,
        n_samples_z=1,
        n_samples_v=1,
        inner_method='iwae',
        outer_method='elbo',
    ):
        """
        """

        # for k particles of z, repeat x, y k times
        x = repeat_newdim(x, n_samples_z, -2)
        y = repeat_newdim(y, n_samples_z, -2)

        _, mu_z_given_x, var_z_given_x, log_posterior_z_given_x = self.z_given_x(
            x)
        _, mu_z_given_y, var_z_given_y, log_posterior_z_given_y = self.z_given_y(
            y)
        mu_z_given_xy, var_z_given_xy = product_of_diag_normals(
            [mu_z_given_x, mu_z_given_y], [var_z_given_x, var_z_given_y])

        # reparameterized sample from q(z|x,y); randn_like matches device/dtype
        z_outer = mu_z_given_xy + torch.randn_like(
            mu_z_given_xy) * var_z_given_xy.sqrt()
        log_posterior_z_given_xy = log_prob_normal(z_outer, mu_z_given_xy,
                                                   var_z_given_xy)

        # for n_samples_v particles of v, repeat z_outer, x, and y
        # n_samples_v times each
        # TODO: very awkward, better way?
        z_inner = repeat_newdim(z_outer, n_samples_v, -2)
        x = repeat_newdim(x, n_samples_v, -2)
        y = repeat_newdim(y, n_samples_v, -2)

        # thus, for each (x,y) pair, there are {n_samples_z * n_samples_v} v particles
        # z_outer and z_inner exist separately so that log q(z_outer|x,y) in
        # the outer sum can be used directly

        x_reduced = self.x_encoder(x)
        y_reduced = self.y_encoder(y)

        xyz_reduced = torch.cat((x_reduced, y_reduced, z_inner), dim=-1)
        v, mu, var, log_posterior_v_given_xyz = self.v_given_xyz(xyz_reduced)
        log_prior_z = log_normal_prior(z_inner)
        log_prior_v = log_normal_prior(v)

        zv = torch.cat((z_inner, v), dim=-1)
        x_recon = self.x_given_zv(zv)
        y_recon = self.y_given_z(z_inner)

        img_log_likelihood_x = log_prob_bernoulli(x_recon, x)
        img_log_likelihood_y = log_prob_bernoulli(y_recon, y)

        # for elbo, average the log-weights over particles
        # for iwae, take log-mean-exp over particles

        # compute the iwae bound between q(v|x,y,z) and p(x,y,z,v), where
        # p(x,y,z,v) = p(y|z)p(x|z,v)p(z)p(v)

        img_log_likelihood_x.squeeze_(dim=-1)
        img_log_likelihood_y.squeeze_(dim=-1)
        log_prior_z.squeeze_(dim=-1)
        log_prior_v.squeeze_(dim=-1)
        log_posterior_v_given_xyz.squeeze_(dim=-1)
        log_posterior_z_given_xy.squeeze_(dim=-1)

        inner_lowerbound = img_log_likelihood_x + img_log_likelihood_y + log_prior_z + log_prior_v - log_posterior_v_given_xyz
        if inner_method == 'iwae':
            inner_lowerbound = get_importance_bound(inner_lowerbound)
            inner_lowerbound = inner_lowerbound.sum(dim=-1, keepdim=True)
        elif inner_method == 'elbo':
            inner_lowerbound = inner_lowerbound.mean(dim=-1, keepdim=True)
        inner_lowerbound.squeeze_(-1)

        outer_lowerbound = inner_lowerbound - log_posterior_z_given_xy
        if outer_method == 'iwae':
            outer_lowerbound = get_importance_bound(outer_lowerbound)
            outer_lowerbound = outer_lowerbound.sum(dim=-1, keepdim=True)
        elif outer_method == 'elbo':
            outer_lowerbound = outer_lowerbound.mean(dim=-1, keepdim=True)
        outer_lowerbound.squeeze_(-1)

        return outer_lowerbound
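
`product_of_diag_normals` presumably computes the precision-weighted product of diagonal Gaussians, as in product-of-experts posteriors. A minimal sketch under that assumption (whether the original also folds in a unit-Gaussian prior expert is not known from this code):

def product_of_diag_normals(mus, variances):
    # the product of diagonal Gaussians N(mu_i, var_i) is, up to normalization,
    # a diagonal Gaussian with precision = sum of the precisions and
    # mean = precision-weighted average of the means
    precisions = [1.0 / v for v in variances]
    var = 1.0 / sum(precisions)
    mu = var * sum(m * p for m, p in zip(mus, precisions))
    return mu, var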