Example #1
    def forward(self, imgs=None):
        q = probtorch.Trace()
        self.guide(q, imgs)

        p = probtorch.Trace()
        _, _, _, reconstruction = self.model(p, q, imgs)

        return p, q, reconstruction
Example #2
    def test_normal(self):
        S = 3  # sample size
        B = 5  # batch size
        D = 2  # hidden dim
        mu = Variable(torch.randn(S, B, D))
        sigma = torch.exp(Variable(torch.randn(S, B, D)))
        q = probtorch.Trace()
        q.normal(mu=mu, sigma=sigma, name='z')
        z = q['z']
        value = z.value

        log_joint, log_prod_mar = q.log_batch_marginal(sample_dim=0,
                                                       batch_dim=1,
                                                       nodes=['z'])

        # compare result
        log_probs = Variable(torch.zeros(B, S, B, D))
        for b1 in range(B):
            for s in range(S):
                for b2 in range(B):
                    d = Normal(mu[s, b2], sigma[s, b2])
                    log_probs[b1, s, b2] = d.log_prob(value[s, b1])
        log_joint_2 = log_mean_exp(log_probs.sum(3), 2).transpose(0, 1)
        log_prod_mar_2 = log_mean_exp(log_probs, 2).sum(2).transpose(0, 1)

        self.assertEqual(log_joint, log_joint_2)
        self.assertEqual(log_prod_mar, log_prod_mar_2)
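These tests build their reference values with `log_mean_exp`, which is not defined in the snippets shown here (in the probtorch repository it is a small utility function). A minimal sketch of such a helper, assuming it reduces the given dimension without keeping it, which is consistent with the reshapes above:

    import math
    import torch

    def log_mean_exp(value, dim):
        # Numerically stable log(mean(exp(value))) along `dim`:
        # logsumexp minus the log of the number of elements reduced.
        return torch.logsumexp(value, dim) - math.log(value.size(dim))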
Example #3
    def test_concrete(self):
        S = 3  # sample size
        B = 5  # batch size
        D = 2  # hidden dim
        K = 4  # event dim
        log_w = Variable(torch.randn(S, B, D, K))
        q = probtorch.Trace()
        q.concrete(log_w, 0.66, name='y')
        y = q['y']
        value = y.value

        log_joint, log_prod_mar = q.log_batch_marginal(sample_dim=0,
                                                       batch_dim=1,
                                                       nodes=['y'])

        # compare result
        log_probs = Variable(torch.zeros(B, S, B, D))
        for b1 in range(B):
            for s in range(S):
                for b2 in range(B):
                    d = Concrete(log_w[s, b2], 0.66)
                    log_probs[b1, s, b2] = d.log_prob(value[s, b1])
        log_joint_2 = log_mean_exp(log_probs.sum(3), 2).transpose(0, 1)
        log_prod_mar_2 = log_mean_exp(log_probs, 2).sum(2).transpose(0, 1)

        self.assertEqual(log_joint, log_joint_2)
        self.assertEqual(log_prod_mar, log_prod_mar_2)
Example #4
    def forward(self, trace, params, times=None, guide=probtorch.Trace()):
        if times is None:
            times = (0, self._num_times)

        weight_params = utils.unsqueeze_and_expand_vardict(
            params['weights'],
            len(params['weights']['mu'].shape) - 1, times[1] - times[0], True)

        weights = trace.normal(
            weight_params['mu'],
            torch.exp(weight_params['log_sigma']),
            value=guide['Weights%dt%d-%d' % (self.block, times[0], times[1])],
            name='Weights%dt%d-%d' % (self.block, times[0], times[1]))

        factor_centers = trace.normal(
            params['factor_centers']['mu'],
            torch.exp(params['factor_centers']['log_sigma']),
            value=guide['FactorCenters' + str(self.block)],
            name='FactorCenters' + str(self.block))
        factor_log_widths = trace.normal(
            params['factor_log_widths']['mu'],
            torch.exp(params['factor_log_widths']['log_sigma']),
            value=guide['FactorLogWidths' + str(self.block)],
            name='FactorLogWidths' + str(self.block))

        return weights, factor_centers, factor_log_widths
Example #5
    def forward(self, images, shared, q=None, p=None, num_samples=None):
        shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
        shared_std = torch.ones_like(q['sharedA'].dist.scale)

        p = probtorch.Trace()

        for shared_from in shared.keys():
            # prior for z_shared_attr
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=q[shared[shared_from]],
                               name=shared[shared_from])

            # h = self.dec_hidden(zShared.squeeze(0))
            h = F.relu(self.fc1(zShared.squeeze(0)))
            h = F.relu(self.fc2(h))
            h = F.relu(self.fc3(h))
            images_mean = self.dec_image(h)
            # define reconstruction loss (log prob of bernoulli dist)
            p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x + torch.log(
                1 - x_hat + EPS) * (1 - x)).sum(-1),
                   images_mean,
                   images,
                   name='images_' + shared_from)
        return p
Example #6
    def forward(self,
                labels,
                shared,
                q=None,
                p=None,
                num_samples=None,
                train=True,
                CUDA=False):
        shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
        shared_std = torch.ones_like(q['sharedB'].dist.scale)
        pred = {}

        p = probtorch.Trace()

        for shared_from in shared.keys():
            # prior for z_shared_attr
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=q[shared[shared_from]],
                               name=shared[shared_from])

            # h = self.dec_hidden(zShared.squeeze(0))
            pred_labels = F.log_softmax(zShared.squeeze(0) + EPS, dim=1)
            p.loss(lambda y_pred, target: -(target * y_pred).sum(-1), \
                   pred_labels, labels, name='label_' + shared_from)

            pred.update({shared_from: pred_labels})

        if train:
            predicted_attr = pred['own']
        else:
            predicted_attr = pred['cross']
        return p, predicted_attr
Example #7
    def forward(self, images, shared, q=None, p=None, num_samples=None):
        private_mean = torch.zeros_like(q['privateA'].dist.loc)
        private_std = torch.ones_like(q['privateA'].dist.scale)
        shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
        shared_std = torch.ones_like(q['sharedA'].dist.scale)

        p = probtorch.Trace()

        # prior for z_private
        zPrivate = p.normal(private_mean,
                            private_std,
                            value=q['privateA'],
                            name='privateA')

        for shared_name in shared.keys():
            # prior for z_shared
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=shared[shared_name],
                               name=shared_name)

            hiddens = self.dec_hidden(torch.cat([zPrivate, zShared], -1))

            hiddens = hiddens.view(-1, 64, 7, 7)
            images_mean = self.dec_image(hiddens)

            images_mean = images_mean.view(images_mean.size(0), -1)
            images = images.view(images.size(0), -1)
            # define reconstruction loss (log prob of bernoulli dist)
            p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x + torch.log(
                1 - x_hat + EPS) * (1 - x)).sum(-1),
                   images_mean,
                   images,
                   name='images_' + shared_name)
        return p
Example #8
    def forward(self, images, labels=None, num_samples=None):
        q = probtorch.Trace()
        hiddens = self.enc_hidden(images)
        q.normal(self.z_mean(hiddens),
                 self.z_log_std(hiddens).exp(),
                 name='z')
        return q
Example #9
    def forward(self, images, shared, q=None, p=None, num_samples=None):
        # prior is the Concrete distribution with uniform weights (all logits zero)
        digit_log_weights = torch.zeros_like(q['sharedA'].dist.logits)
        style_mean = torch.zeros_like(q['privateA'].dist.loc)
        style_std = torch.ones_like(q['privateA'].dist.scale)

        p = probtorch.Trace()

        # prior for z_private
        zPrivate = p.normal(style_mean,
                        style_std,
                        value=q['privateA'],
                        name='privateA')
        # z_private is a node shared by sharedA (infA), sharedB (crossA), and sharedPOE,
        # so all of them must be generated from the same single z_private sample
        for shared_name in shared.keys():
            # prior for z_shared
            zShared = p.concrete(logits=digit_log_weights,
                                temperature=self.digit_temp,
                                value=shared[shared_name],
                                name=shared_name)

            if 'poe' in shared_name:
                hiddens = self.dec_hidden(torch.cat([zPrivate, torch.pow(zShared + EPS, 1/3)], -1))
            else:
                hiddens = self.dec_hidden(torch.cat([zPrivate, zShared], -1))

            images_mean = self.dec_image(hiddens)

            # define reconstruction loss (log prob of bernoulli dist)
            p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x +
                                      torch.log(1 - x_hat + EPS) * (1-x)).sum(-1),
                   images_mean, images, name= 'images_' + shared_name)
        return p
Example #10
    def forward(self,
                images,
                q=None,
                num_samples=NUM_SAMPLES,
                batch_size=NUM_BATCH):
        p = probtorch.Trace()
        digit_log_weights = torch.zeros(num_samples, batch_size,
                                        self.num_digits)
        style_mean = torch.zeros(num_samples, batch_size, self.num_style)
        style_std = torch.ones(num_samples, batch_size, self.num_style)

        if CUDA:
            digit_log_weights = digit_log_weights.cuda()
            style_mean = style_mean.cuda()
            style_std = style_std.cuda()

        digits = p.concrete(logits=digit_log_weights,
                            temperature=self.digit_temp,
                            value=q['y'],
                            name='y')

        styles = p.normal(loc=style_mean,
                          scale=style_std,
                          value=q['z'],
                          name='z')

        hiddens = self.dec_hidden(torch.cat([digits, styles], -1))
        images_mean = self.dec_images(hiddens)
        p.loss(binary_cross_entropy, images_mean, images, name='images')
        return p
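For context, an encoder (guide) trace and a decoder (model) trace like the ones in these examples are typically scored together by a probtorch objective. A hedged usage sketch, assuming encoder/decoder modules shaped like the ones above and that `probtorch.objectives.montecarlo.elbo` accepts `sample_dim`/`batch_dim` keyword arguments; the names `enc`, `dec`, and `optimizer` are illustrative, not taken from the original source:

    # q: guide trace from the encoder, p: model trace from the decoder
    q = enc(images, num_samples=NUM_SAMPLES)
    p = dec(images, q=q, num_samples=NUM_SAMPLES, batch_size=NUM_BATCH)

    # maximize the Monte Carlo ELBO estimated over the sample dimension
    loss = -probtorch.objectives.montecarlo.elbo(q, p, sample_dim=0, batch_dim=1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()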
Example #11
    def forward(self, images, shared, q=None, p=None, num_samples=None):
        private_mean = torch.zeros_like(q['privateB'].dist.loc)
        private_std = torch.ones_like(q['privateB'].dist.scale)
        shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
        shared_std = torch.ones_like(q['sharedB'].dist.scale)

        p = probtorch.Trace()

        # prior for z_private
        zPrivate = p.normal(private_mean,
                            private_std,
                            value=q['privateB'],
                            name='privateB')
        # z_private is a node shared by sharedA (infA), sharedB (crossA), and sharedPOE,
        # so all of them must be generated from the same single z_private sample
        for shared_name in shared.keys():
            # prior for z_shared
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=shared[shared_name],
                               name=shared_name)

            hiddens = self.dec_hidden(torch.cat([zPrivate, zShared], -1))
            images_mean = self.dec_image(hiddens).squeeze(0)

            # define reconstruction loss (log prob of bernoulli dist)
            p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x +
                                      torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
                   images_mean, images, name='images2_' + shared_name)
        return p
Example #12
 def forward(self, ob, prior_ng, sampled=True, tau_old=None, mu_old=None):
     q = probtorch.Trace()
     (prior_alpha, prior_beta, prior_mu, prior_nu) = prior_ng
     q_alpha, q_beta, q_mu, q_nu = posterior_eta(self.ob(ob),
                                                 self.gamma(ob),
                                                 prior_alpha, prior_beta,
                                                 prior_mu, prior_nu)
     if sampled:  ## used in forward transition kernel where we need to sample
         tau = Gamma(q_alpha, q_beta).sample()
         q.gamma(q_alpha, q_beta, value=tau, name='precisions')
         mu = Normal(q_mu,
                     1. / (q_nu * q['precisions'].value).sqrt()).sample()
         q.normal(
             q_mu,
             1. /
             (q_nu *
              q['precisions'].value).sqrt(),  # std = 1 / sqrt(nu * tau)
             value=mu,
             name='means')
     else:  ## used in backward transition kernel where samples were given from last sweep
         q.gamma(q_alpha, q_beta, value=tau_old, name='precisions')
         q.normal(q_mu,
                  1. / (q_nu * q['precisions'].value).sqrt(),
                  value=mu_old,
                  name='means')
     return q
Example #13
    def forward(self,
                ob,
                z,
                prior_ng,
                sampled=True,
                tau_old=None,
                mu_old=None):
        q = probtorch.Trace()
        (prior_alpha, prior_beta, prior_mu, prior_nu) = prior_ng
        ob_z = torch.cat(
            (ob, z), -1)  # concatenate observations and cluster assignments
        q_alpha, q_beta, q_mu, q_nu = posterior_eta(self.ob(ob_z),
                                                    self.gamma(ob_z),
                                                    prior_alpha, prior_beta,
                                                    prior_mu, prior_nu)

        if sampled:
            tau = Gamma(q_alpha, q_beta).sample()
            q.gamma(q_alpha, q_beta, value=tau, name='precisions')
            mu = Normal(q_mu,
                        1. / (q_nu * q['precisions'].value).sqrt()).sample()
            q.normal(q_mu,
                     1. / (q_nu * q['precisions'].value).sqrt(),
                     value=mu,
                     name='means')
        else:
            q.gamma(q_alpha, q_beta, value=tau_old, name='precisions')
            q.normal(q_mu,
                     1. / (q_nu * q['precisions'].value).sqrt(),
                     value=mu_old,
                     name='means')
        return q
Example #14
    def forward(self, x, num_samples=None, q=None):
        if q is None:
            q = probtorch.Trace()

        hiddens = self.enc_hidden(x)
        hiddens = hiddens.view(hiddens.size(0), -1)
        stats = self.fc(hiddens)
        stats = stats.unsqueeze(0)

        muPrivate = stats[:, :, :self.zPrivate_dim]
        logvarPrivate = stats[:, :, self.zPrivate_dim:(2 * self.zPrivate_dim)]
        stdPrivate = torch.sqrt(torch.exp(logvarPrivate) + EPS)

        muShared = stats[:, :, (2 * self.zPrivate_dim):(2 * self.zPrivate_dim +
                                                        self.zShared_dim)]
        logvarShared = stats[:, :, (2 * self.zPrivate_dim + self.zShared_dim):]
        stdShared = torch.sqrt(torch.exp(logvarShared) + EPS)

        q.normal(loc=muPrivate, scale=stdPrivate, name='privateA')

        # print('muSharedA: ', muShared)
        # print('logvarSharedA: ', logvarShared)
        # print('stdSharedA: ', stdShared)
        # print('----------------------------')

        # attributes
        q.normal(loc=muShared, scale=stdShared, name='sharedA')
        return q
Example #15
 def z_prior(self, q):
     p = probtorch.Trace()
     _ = p.variable(cat,
                    probs=self.prior_pi,
                    value=q['states'],
                    name='states')
     return p
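`cat` and `self.prior_pi` come from outside this snippet; `cat` is presumably a categorical distribution constructor such as `torch.distributions.OneHotCategorical`. A minimal sketch of registering such a node on a fresh trace under that assumption, mirroring the sample-then-record pattern used in Example #24, with illustrative shapes:

    import torch
    import probtorch
    from torch.distributions import OneHotCategorical as cat  # assumption: 'cat' aliases a one-hot categorical

    p = probtorch.Trace()
    prior_pi = torch.ones(3, 5, 4) / 4                 # illustrative S x B x K uniform prior over K = 4 states
    z = cat(probs=prior_pi).sample()                   # draw one-hot states externally
    _ = p.variable(cat, probs=prior_pi, value=z, name='states')  # record the node on the trace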
Example #16
    def forward(self, images, shared, q=None, p=None, num_samples=None):

        shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
        shared_std = torch.ones_like(q['sharedA'].dist.scale)

        p = probtorch.Trace()

        # z_private is a node shared by sharedA (infA), sharedB (crossA), and sharedPOE,
        # so all of them must be generated from the same single z_private sample
        for shared_name in shared.keys():
            # prior for z_shared
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=shared[shared_name],
                               name=shared_name)

            hiddens = self.dec_hidden(zShared)
            hiddens = hiddens.view(-1, 64, 7, 7)
            images_mean = self.dec_image(hiddens)
            images_mean = images_mean.view(images_mean.size(0), -1)
            images = images.view(images.size(0), -1)
            # define reconstruction loss (log prob of bernoulli dist)
            p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x + torch.log(
                1 - x_hat + EPS) * (1 - x)).sum(-1),
                   images_mean,
                   images,
                   name='images_' + shared_name)
        return p
Example #17
 def forward(self,
             ob,
             z,
             beta,
             K,
             priors,
             sampled=True,
             mu_old=None,
             EPS=1e-8):
     q = probtorch.Trace()
     S, B, N, D = ob.shape
     (prior_mu, prior_sigma) = priors
     nss1 = self.nss1(torch.cat(
         (ob, z), -1)).unsqueeze(-1).repeat(1, 1, 1, 1,
                                            K).transpose(-1, -2)
     nss2 = self.nss2(torch.cat(
         (ob, z), -1)).unsqueeze(-1).repeat(1, 1, 1, 1, nss1.shape[-1])
     nss = (nss1 * nss2).sum(2) / (nss2.sum(2) + EPS)
     nss_prior = torch.cat(
         (nss, prior_mu.repeat(S, B, K, 1), prior_sigma.repeat(S, B, K, 1)),
         -1)
     q_mu_mu = self.mean_mu(nss_prior)
     q_mu_sigma = self.mean_log_sigma(nss_prior).exp()
     if sampled:
         mu = Normal(q_mu_mu, q_mu_sigma).sample()
         q.normal(q_mu_mu, q_mu_sigma, value=mu, name='means')
     else:
         q.normal(q_mu_mu, q_mu_sigma, value=mu_old, name='means')
     return q
Example #18
    def forward(self, x, num_samples=None, q=None):
        if q is None:
            q = probtorch.Trace()
        hiddens = self.resnet(x)
        hiddens = hiddens.view(hiddens.size(0), -1)
        stats = self.fc(hiddens)
        stats = stats.unsqueeze(0)
        muPrivate = stats[:, :, :self.zPrivate_dim]
        logvarPrivate = stats[:, :, self.zPrivate_dim:(2 * self.zPrivate_dim)]
        stdPrivate = torch.sqrt(torch.exp(logvarPrivate) + EPS)

        shared_attr_logit = stats[:, :, (2 * self.zPrivate_dim):(
            2 * self.zPrivate_dim + sum(self.zSharedAttr_dim))]
        shared_label_logit = stats[:, :, (2 * self.zPrivate_dim +
                                          sum(self.zSharedAttr_dim)):]

        q.normal(loc=muPrivate, scale=stdPrivate, name='privateA')

        # attributes
        start_dim = 0
        i = 0
        for this_attr_dim in self.zSharedAttr_dim:
            q.concrete(logits=shared_attr_logit[:, :, start_dim:start_dim +
                                                this_attr_dim],
                       temperature=self.digit_temp,
                       name='sharedA_attr' + str(i))
            start_dim += this_attr_dim
            i += 1

        # label
        q.concrete(logits=shared_label_logit,
                   temperature=self.digit_temp,
                   name='sharedA_label')
        # self.q = q
        return q
Example #19
    def forward(self, images, q=None, num_samples=None):
        p = probtorch.Trace()

        attrs = []

        for i in range(self.num_attr):
            attrs.append(
                p.concrete(logits=self.attr_log_weights,
                           temperature=self.digit_temp,
                           value=q['attr' + str(i)],
                           name='attr' + str(i)))

        styles = p.normal(self.style_mean,
                          self.style_std,
                          value=q['styles'],
                          name='styles')
        attrs.append(styles)

        hiddens = self.dec_hidden(torch.cat(attrs, -1))
        hiddens = hiddens.view(-1, 256, 5, 5)
        images_mean = self.dec_image(hiddens)

        images_mean = images_mean.view(images_mean.size(0), -1).unsqueeze(0)
        images = images.view(images.size(0), -1).unsqueeze(0)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x + torch.log(
            1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean,
               images,
               name='images')
        return p
Example #20
    def forward(self,
                labels,
                shared,
                q=None,
                p=None,
                num_samples=None,
                train=True):
        shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
        shared_std = torch.ones_like(q['sharedB'].dist.scale)

        p = probtorch.Trace()
        for shared_name in shared.keys():
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=shared[shared_name],
                               name=shared_name)
            hiddens = self.dec_hidden(zShared)
            pred_labels = self.dec_label(hiddens)
            # define reconstruction loss (log prob of bernoulli dist)
            pred_labels = F.log_softmax(pred_labels + EPS, dim=2)
            if train:
                p.loss(lambda y_pred, target: -(target * y_pred).sum(-1), \
                       pred_labels, labels.unsqueeze(0), name='labels_' + shared_name)
                p.loss(lambda y_pred, target: (1 - (target == y_pred).float()), \
                       pred_labels.max(-1)[1], labels.max(-1)[1], name='labels_acc_' + shared_name)
            else:
                p.loss(lambda y_pred, target: (1 - (target == y_pred).float()), \
                       pred_labels.max(-1)[1], labels.max(-1)[1], name='labels_' + shared_name)

        return p
Example #21
    def forward(self, attributes, num_samples=None, q=None):
        if q is None:
            q = probtorch.Trace()
        # attributes = attributes.view(attributes.size(0), -1)
        hiddens = self.enc_hidden(attributes)
        stats = self.fc(hiddens)

        shared_attr_logit = stats[:, :, :sum(self.zSharedAttr_dim)]
        shared_label_logit = stats[:, :, sum(self.zSharedAttr_dim):]

        # attribute
        start_dim = 0
        i = 0
        for this_attr_dim in self.zSharedAttr_dim:
            q.concrete(logits=shared_attr_logit[:, :, start_dim:start_dim +
                                                this_attr_dim],
                       temperature=self.digit_temp,
                       name='sharedB_attr' + str(i))
            start_dim += this_attr_dim
            i += 1

        # label
        q.concrete(logits=shared_label_logit,
                   temperature=self.digit_temp,
                   name='sharedB_label')
        return q
Example #22
    def forward(self, images, shared, q=None, p=None, num_samples=None):
        priv_mean = torch.zeros_like(q['privateA'].dist.loc)
        priv_std = torch.ones_like(q['privateA'].dist.scale)
        shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
        shared_std = torch.ones_like(q['sharedA'].dist.scale)
        p = probtorch.Trace()

        # prior for z_private
        zPrivate = p.normal(priv_mean,
                            priv_std,
                            value=q['privateA'],
                            name='privateA')

        recon_img = {}
        for shared_from in shared.keys():
            latents = [zPrivate]
            # prior for z_shared_attr
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=q[shared[shared_from]],
                               name=shared[shared_from])
            latents.append(zShared)

            # hiddens = self.dec_hidden(torch.cat(latents, -1))
            pred_imgs = self.dec_image(torch.cat(latents, -1))
            pred_imgs = pred_imgs.squeeze(0)
            # pred_imgs = F.sigmoid(pred_imgs)

            p.loss(
                lambda y_pred, target: F.binary_cross_entropy_with_logits(y_pred, target, reduction='none').sum(dim=1), \
                torch.log(pred_imgs + EPS), images, name='images_' + shared_from)

            recon_img.update({shared_from: pred_imgs})
        return p, recon_img
Example #23
 def forward(self, images, labels=None, num_samples=None):
     q = probtorch.Trace()
     hidden = self.enc_hidden(images)
     styles_mean = self.style_mean(hidden)
     styles_std = torch.exp(self.style_log_std(hidden))
     q.normal(styles_mean, styles_std, name='z')
     return q
Example #24
 def forward(self, ob, mu, K, sampled=True, z_old=None, beta_old=None):
     q = probtorch.Trace()
     S, B, N, D = ob.shape
     ob_mu = ob.unsqueeze(2).repeat(
         1, 1, K, 1, 1) - mu.unsqueeze(-2).repeat(1, 1, 1, N, 1)
     q_probs = F.softmax(
         self.pi_log_prob(ob_mu).squeeze(-1).transpose(-1, -2), -1)
     if sampled:
         z = cat(q_probs).sample()
         _ = q.variable(cat, probs=q_probs, value=z, name='states')
         mu_expand = torch.gather(
             mu, -2,
             z.argmax(-1).unsqueeze(-1).repeat(1, 1, 1, D))
         q_angle_con1 = self.angle_log_con1(ob - mu_expand).exp()
         q_angle_con0 = self.angle_log_con0(ob - mu_expand).exp()
         beta = Beta(q_angle_con1, q_angle_con0).sample()
         q.beta(q_angle_con1, q_angle_con0, value=beta, name='angles')
     else:
         _ = q.variable(cat, probs=q_probs, value=z_old, name='states')
         mu_expand = torch.gather(
             mu, -2,
             z_old.argmax(-1).unsqueeze(-1).repeat(1, 1, 1, D))
         q_angle_con1 = self.angle_log_con1(ob - mu_expand).exp()
         q_angle_con0 = self.angle_log_con0(ob - mu_expand).exp()
         q.beta(q_angle_con1, q_angle_con0, value=beta_old, name='angles')
     return q
Example #25
    def forward(self, x, num_samples=None, q=None):
        if q is None:
            q = probtorch.Trace()

        hiddens = self.enc_hidden(x)
        stats = self.fc(hiddens)

        muPrivate = stats[:, :, :self.zPrivate_dim]
        logvarPrivate = stats[:, :, self.zPrivate_dim:(2 * self.zPrivate_dim)]
        stdPrivate = torch.sqrt(torch.exp(logvarPrivate) + EPS)

        muShared = stats[:, :, (2 * self.zPrivate_dim):(2 * self.zPrivate_dim + self.zShared_dim)]
        logvarShared = stats[:, :, (2 * self.zPrivate_dim + self.zShared_dim):]
        stdShared = torch.sqrt(torch.exp(logvarShared) + EPS)


        q.normal(loc=muPrivate,
                 scale=stdPrivate,
                 name='privateA')

        q.normal(loc=muShared,
                 scale=stdShared,
                 name='sharedA')

        return q
Example #26
    def forward(self, images, shared, q=None, p=None, num_samples=None):
        priv_mean = torch.zeros_like(q['privateA'].dist.loc)
        priv_std = torch.ones_like(q['privateA'].dist.scale)
        shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
        shared_std = torch.ones_like(q['sharedA'].dist.scale)

        p = probtorch.Trace()

        # prior for z_private
        zPrivate = p.normal(priv_mean,
                            priv_std,
                            value=q['privateA'],
                            name='privateA')

        for shared_from in shared.keys():
            latents = [zPrivate]
            # prior for z_shared_attr
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=q[shared[shared_from]],
                               name=shared[shared_from])
            latents.append(zShared)
            hiddens = self.fc(torch.cat(latents, -1))
            hiddens = hiddens.view(-1, 256, 5, 5)
            images_mean = self.hallucinate(hiddens)
            images_mean = images_mean.view(images_mean.size(0), -1)
            images = images.view(images.size(0), -1)
            # define reconstruction loss (log prob of bernoulli dist)
            p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x + torch.log(
                1 - x_hat + EPS) * (1 - x)).sum(-1),
                   images_mean,
                   images,
                   name='images_' + shared_from)
        return p
Example #27
    def forward(self,
                labels,
                shared,
                q=None,
                p=None,
                num_samples=None,
                train=True):
        shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
        shared_std = torch.ones_like(q['sharedB'].dist.scale)

        p = probtorch.Trace()
        # z_private is a node shared by sharedA (infA), sharedB (crossA), and sharedPOE,
        # so all of them must be generated from the same single z_private sample
        for shared_name in shared.keys():
            # prior for z_shared
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=shared[shared_name],
                               name=shared_name)
            hiddens = self.dec_hidden(zShared)

            pred_labels = self.dec_label(hiddens)
            # define reconstruction loss (log prob of bernoulli dist)
            pred_labels = F.log_softmax(pred_labels + EPS, dim=2)
            if train:
                p.loss(lambda y_pred, target: -(target * y_pred).sum(-1), \
                       pred_labels, labels.unsqueeze(0), name='labels_' + shared_name)
            else:
                p.loss(lambda y_pred, target: (1 - (target == y_pred).float()), \
                       pred_labels.max(-1)[1], labels.max(-1)[1], name='labels_' + shared_name)

        return p
Example #28
    def forward(self,
                attributes,
                shared,
                q=None,
                p=None,
                num_samples=None,
                train=True,
                CUDA=False):
        shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
        shared_std = torch.ones_like(q['sharedB'].dist.scale)
        pred = {}

        p = probtorch.Trace()

        for shared_from in shared.keys():
            # prior for z_shared_attr
            zShared = p.normal(shared_mean,
                               shared_std,
                               value=q[shared[shared_from]],
                               name=shared[shared_from])
            hiddens = self.dec_hidden(zShared)
            pred_labels = self.dec_label(hiddens)
            pred_labels = pred_labels.squeeze(0)

            pred_labels = F.logsigmoid(pred_labels + EPS)

            p.loss(
                lambda y_pred, target: F.binary_cross_entropy_with_logits(y_pred, target, reduction='none').sum(dim=1), \
                pred_labels, attributes, name='attr_' + shared_from)
            pred.update({shared_from: pred_labels})
        return p, pred['cross']
Example #29
    def forward(self,
                trace,
                times=None,
                guide=probtorch.Trace(),
                blocks=None,
                observations=[],
                weights_params=None):
        if blocks is None:
            blocks = list(range(self._num_blocks))
        params = self._hyperparams.state_vardict()

        template = self._template_prior(trace, params, guide=guide)
        weights, centers, log_widths = self._subject_prior(
            trace,
            params,
            template,
            times=times,
            blocks=blocks,
            guide=guide,
            weights_params=weights_params)

        return [
            self.likelihoods[b](trace,
                                weights[i],
                                centers[i],
                                log_widths[i],
                                params,
                                times=times,
                                observations=observations[i])
            for (i, b) in enumerate(blocks)
        ]
Example #30
    def test_normal(self):
        N = 100 # Number of training data
        S = 3  # sample size
        B = 5  # batch size
        D = 10  # hidden dim
        mu1 = Variable(torch.randn(S, B, D))
        mu2 = Variable(torch.randn(S, B, D))
        sigma1 = torch.exp(Variable(torch.randn(S, B, D)))
        sigma2 = torch.exp(Variable(torch.randn(S, B, D)))
        q = probtorch.Trace()
        q.normal(mu1, sigma1, name='z1')
        q.normal(mu2, sigma2, name='z2')
        z1 = q['z1']
        z2 = q['z2']
        value1 = z1.value
        value2 = z2.value

        bias = (N - 1.0) / (B - 1.0)

        log_joint, log_mar, log_prod_mar = q.log_batch_marginal(sample_dim=0,
                                                                batch_dim=1,
                                                                nodes=['z1', 'z2'],
                                                                bias=bias)
        # compare result
        log_probs1 = Variable(torch.zeros(B, B, S, D))
        log_probs2 = Variable(torch.zeros(B, B, S, D))
        for b1 in range(B):
            for s in range(S):
                for b2 in range(B):
                    d1 = Normal(mu1[s, b2], sigma1[s, b2])
                    d2 = Normal(mu2[s, b2], sigma2[s, b2])
                    log_probs1[b1, b2, s] = d1.log_prob(value1[s, b1])
                    log_probs2[b1, b2, s] = d2.log_prob(value2[s, b1])


        log_sum_1 = log_probs1.sum(3)
        log_sum_2 = log_probs2.sum(3)

        log_joint_2 = log_sum_1 + log_sum_2
        log_joint_2[range(B), range(B)] -= math.log(bias)
        log_joint_2 = log_mean_exp(log_joint_2, 1).transpose(0, 1)

        log_sum_1[range(B), range(B)] -= math.log(bias)
        log_sum_2[range(B), range(B)] -= math.log(bias)
        log_mar_z1 = log_mean_exp(log_sum_1, 1).transpose(0, 1) 
        log_mar_z2 = log_mean_exp(log_sum_2, 1).transpose(0, 1) 
        log_mar_2 = log_mar_z1 + log_mar_z2

        log_probs1[range(B), range(B)] -= math.log(bias)
        log_probs2[range(B), range(B)] -= math.log(bias)
        log_prod_mar_z1 = (log_mean_exp(log_probs1, 1)).sum(2).transpose(0, 1)
        log_prod_mar_z2 = (log_mean_exp(log_probs2, 1)).sum(2).transpose(0, 1)
        log_prod_mar_2 = log_prod_mar_z1 + log_prod_mar_z2

        self.assertEqual(log_mar, log_mar_2)
        self.assertEqual(log_prod_mar, log_prod_mar_2)
        self.assertEqual(log_joint, log_joint_2)