def forward(self, imgs=None):
    # run the recognition network to populate the guide trace q
    q = probtorch.Trace()
    self.guide(q, imgs)
    # run the generative model conditioned on the guide trace q
    p = probtorch.Trace()
    _, _, _, reconstruction = self.model(p, q, imgs)
    return p, q, reconstruction
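# A minimal usage sketch (not from the original source): the (p, q, reconstruction)
# triple returned above is assumed to feed a probtorch Monte Carlo ELBO objective,
# with sample dimension 0 and batch dimension 1. The model, optimizer, and
# train_step names below are hypothetical.
import probtorch
from probtorch.objectives.montecarlo import elbo

def train_step(model, optimizer, imgs):
    optimizer.zero_grad()
    p, q, _ = model(imgs)
    # maximize the ELBO, i.e. minimize its negative
    loss = -elbo(q, p, sample_dim=0, batch_dim=1)
    loss.backward()
    optimizer.step()
    return loss.item()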
def test_normal(self):
    S = 3  # sample size
    B = 5  # batch size
    D = 2  # hidden dim
    mu = Variable(torch.randn(S, B, D))
    sigma = torch.exp(Variable(torch.randn(S, B, D)))
    q = probtorch.Trace()
    q.normal(mu=mu, sigma=sigma, name='z')
    z = q['z']
    value = z.value
    log_joint, log_prod_mar = q.log_batch_marginal(sample_dim=0,
                                                   batch_dim=1,
                                                   nodes=['z'])
    # compare result
    log_probs = Variable(torch.zeros(B, S, B, D))
    for b1 in range(B):
        for s in range(S):
            for b2 in range(B):
                d = Normal(mu[s, b2], sigma[s, b2])
                log_probs[b1, s, b2] = d.log_prob(value[s, b1])
    log_joint_2 = log_mean_exp(log_probs.sum(3), 2).transpose(0, 1)
    log_prod_mar_2 = log_mean_exp(log_probs, 2).sum(2).transpose(0, 1)
    self.assertEqual(log_joint, log_joint_2)
    self.assertEqual(log_prod_mar, log_prod_mar_2)
def test_concrete(self):
    S = 3  # sample size
    B = 5  # batch size
    D = 2  # hidden dim
    K = 4  # event dim
    log_w = Variable(torch.randn(S, B, D, K))
    q = probtorch.Trace()
    q.concrete(log_w, 0.66, name='y')
    y = q['y']
    value = y.value
    log_joint, log_prod_mar = q.log_batch_marginal(sample_dim=0,
                                                   batch_dim=1,
                                                   nodes=['y'])
    # compare result
    log_probs = Variable(torch.zeros(B, S, B, D))
    for b1 in range(B):
        for s in range(S):
            for b2 in range(B):
                d = Concrete(log_w[s, b2], 0.66)
                log_probs[b1, s, b2] = d.log_prob(value[s, b1])
    log_joint_2 = log_mean_exp(log_probs.sum(3), 2).transpose(0, 1)
    log_prod_mar_2 = log_mean_exp(log_probs, 2).sum(2).transpose(0, 1)
    self.assertEqual(log_joint, log_joint_2)
    self.assertEqual(log_prod_mar, log_prod_mar_2)
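# For reference, a minimal sketch of the log_mean_exp helper the two tests above
# rely on. This is an assumption about its behavior inferred from usage, not the
# library's own implementation: a numerically stable log of the mean of
# exponentials along one dimension.
import math
import torch

def log_mean_exp(value, dim):
    # logsumexp over `dim`, then subtract the log of the number of averaged elements
    return torch.logsumexp(value, dim) - math.log(value.size(dim))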
def forward(self, trace, params, times=None, guide=probtorch.Trace()):
    if times is None:
        times = (0, self._num_times)
    weight_params = utils.unsqueeze_and_expand_vardict(
        params['weights'],
        len(params['weights']['mu'].shape) - 1,
        times[1] - times[0],
        True)
    weights = trace.normal(
        weight_params['mu'],
        torch.exp(weight_params['log_sigma']),
        value=guide['Weights%dt%d-%d' % (self.block, times[0], times[1])],
        name='Weights%dt%d-%d' % (self.block, times[0], times[1]))
    factor_centers = trace.normal(
        params['factor_centers']['mu'],
        torch.exp(params['factor_centers']['log_sigma']),
        value=guide['FactorCenters' + str(self.block)],
        name='FactorCenters' + str(self.block))
    factor_log_widths = trace.normal(
        params['factor_log_widths']['mu'],
        torch.exp(params['factor_log_widths']['log_sigma']),
        value=guide['FactorLogWidths' + str(self.block)],
        name='FactorLogWidths' + str(self.block))
    return weights, factor_centers, factor_log_widths
def forward(self, images, shared, q=None, p=None, num_samples=None):
    shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
    shared_std = torch.ones_like(q['sharedA'].dist.scale)
    p = probtorch.Trace()
    for shared_from in shared.keys():
        # prior for z_shared_attr
        zShared = p.normal(shared_mean, shared_std,
                           value=q[shared[shared_from]],
                           name=shared[shared_from])
        # h = self.dec_hidden(zShared.squeeze(0))
        h = F.relu(self.fc1(zShared.squeeze(0)))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        images_mean = self.dec_image(h)
        # reconstruction loss (log prob of Bernoulli dist)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                                  + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean, images, name='images_' + shared_from)
    return p
def forward(self, labels, shared, q=None, p=None, num_samples=None,
            train=True, CUDA=False):
    shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
    shared_std = torch.ones_like(q['sharedB'].dist.scale)
    pred = {}
    p = probtorch.Trace()
    for shared_from in shared.keys():
        # prior for z_shared_attr
        zShared = p.normal(shared_mean, shared_std,
                           value=q[shared[shared_from]],
                           name=shared[shared_from])
        # h = self.dec_hidden(zShared.squeeze(0))
        pred_labels = F.log_softmax(zShared.squeeze(0) + EPS, dim=1)
        p.loss(lambda y_pred, target: -(target * y_pred).sum(-1),
               pred_labels, labels, name='label_' + shared_from)
        pred.update({shared_from: pred_labels})
    if train:
        predicted_attr = pred['own']
    else:
        predicted_attr = pred['cross']
    return p, predicted_attr
def forward(self, images, shared, q=None, p=None, num_samples=None):
    private_mean = torch.zeros_like(q['privateA'].dist.loc)
    private_std = torch.ones_like(q['privateA'].dist.scale)
    shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
    shared_std = torch.ones_like(q['sharedA'].dist.scale)
    p = probtorch.Trace()
    # prior for z_private
    zPrivate = p.normal(private_mean, private_std,
                        value=q['privateA'], name='privateA')
    for shared_name in shared.keys():
        # prior for z_shared
        zShared = p.normal(shared_mean, shared_std,
                           value=shared[shared_name], name=shared_name)
        hiddens = self.dec_hidden(torch.cat([zPrivate, zShared], -1))
        hiddens = hiddens.view(-1, 64, 7, 7)
        images_mean = self.dec_image(hiddens)
        images_mean = images_mean.view(images_mean.size(0), -1)
        images = images.view(images.size(0), -1)
        # reconstruction loss (log prob of Bernoulli dist)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                                  + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean, images, name='images_' + shared_name)
    return p
def forward(self, images, labels=None, num_samples=None):
    q = probtorch.Trace()
    hiddens = self.enc_hidden(images)
    q.normal(self.z_mean(hiddens),
             self.z_log_std(hiddens).exp(),
             name='z')
    return q
def forward(self, images, shared, q=None, p=None, num_samples=None):
    # prior for z_shared is a Concrete dist. approximating a uniform
    # categorical (all logits equal)
    digit_log_weights = torch.zeros_like(q['sharedA'].dist.logits)
    style_mean = torch.zeros_like(q['privateA'].dist.loc)
    style_std = torch.ones_like(q['privateA'].dist.scale)
    p = probtorch.Trace()
    # prior for z_private
    zPrivate = p.normal(style_mean, style_std,
                        value=q['privateA'], name='privateA')
    # z_private is the node common to sharedA (infA), sharedB (crossA), and
    # sharedPOE, so all of them must be generated from a single z_private sample
    for shared_name in shared.keys():
        # prior for z_shared
        zShared = p.concrete(logits=digit_log_weights,
                             temperature=self.digit_temp,
                             value=shared[shared_name],
                             name=shared_name)
        if 'poe' in shared_name:
            hiddens = self.dec_hidden(
                torch.cat([zPrivate, torch.pow(zShared + EPS, 1 / 3)], -1))
        else:
            hiddens = self.dec_hidden(torch.cat([zPrivate, zShared], -1))
        images_mean = self.dec_image(hiddens)
        # reconstruction loss (log prob of Bernoulli dist)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                                  + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean, images, name='images_' + shared_name)
    return p
def forward(self, images, q=None, num_samples=NUM_SAMPLES, batch_size=NUM_BATCH):
    p = probtorch.Trace()
    digit_log_weights = torch.zeros(num_samples, batch_size, self.num_digits)
    style_mean = torch.zeros(num_samples, batch_size, self.num_style)
    style_std = torch.ones(num_samples, batch_size, self.num_style)
    if CUDA:
        digit_log_weights = digit_log_weights.cuda()
        style_mean = style_mean.cuda()
        style_std = style_std.cuda()
    digits = p.concrete(logits=digit_log_weights,
                        temperature=self.digit_temp,
                        value=q['y'], name='y')
    styles = p.normal(loc=style_mean, scale=style_std,
                      value=q['z'], name='z')
    hiddens = self.dec_hidden(torch.cat([digits, styles], -1))
    images_mean = self.dec_images(hiddens)
    p.loss(binary_cross_entropy, images_mean, images, name='images')
    return p
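# The decoder above passes a module-level binary_cross_entropy callable to p.loss.
# A plausible sketch of that helper, consistent with the inline Bernoulli
# log-likelihood lambdas used by the other decoders in this section; its exact
# definition (and the EPS value) is an assumption, not the original source.
import torch

def binary_cross_entropy(x_mean, x, EPS=1e-9):
    # negative Bernoulli log-likelihood, summed over the feature dimension
    return -(torch.log(x_mean + EPS) * x
             + torch.log(1 - x_mean + EPS) * (1 - x)).sum(-1)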
def forward(self, images, shared, q=None, p=None, num_samples=None):
    private_mean = torch.zeros_like(q['privateB'].dist.loc)
    private_std = torch.ones_like(q['privateB'].dist.scale)
    shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
    shared_std = torch.ones_like(q['sharedB'].dist.scale)
    p = probtorch.Trace()
    # prior for z_private
    zPrivate = p.normal(private_mean, private_std,
                        value=q['privateB'], name='privateB')
    # z_private is the node common to sharedA (infA), sharedB (crossA), and
    # sharedPOE, so all of them must be generated from a single z_private sample
    for shared_name in shared.keys():
        # prior for z_shared
        zShared = p.normal(shared_mean, shared_std,
                           value=shared[shared_name], name=shared_name)
        hiddens = self.dec_hidden(torch.cat([zPrivate, zShared], -1))
        images_mean = self.dec_image(hiddens).squeeze(0)
        # reconstruction loss (log prob of Bernoulli dist)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                                  + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean, images, name='images2_' + shared_name)
    return p
def forward(self, ob, prior_ng, sampled=True, tau_old=None, mu_old=None):
    q = probtorch.Trace()
    (prior_alpha, prior_beta, prior_mu, prior_nu) = prior_ng
    q_alpha, q_beta, q_mu, q_nu = posterior_eta(self.ob(ob), self.gamma(ob),
                                                prior_alpha, prior_beta,
                                                prior_mu, prior_nu)
    if sampled:
        ## used in forward transition kernel where we need to sample
        tau = Gamma(q_alpha, q_beta).sample()
        q.gamma(q_alpha, q_beta, value=tau, name='precisions')
        mu = Normal(q_mu, 1. / (q_nu * q['precisions'].value).sqrt()).sample()
        q.normal(q_mu,
                 1. / (q_nu * q['precisions'].value).sqrt(),  # std = 1 / sqrt(nu * tau)
                 value=mu, name='means')
    else:
        ## used in backward transition kernel where samples were given from last sweep
        q.gamma(q_alpha, q_beta, value=tau_old, name='precisions')
        q.normal(q_mu, 1. / (q_nu * q['precisions'].value).sqrt(),
                 value=mu_old, name='means')
    return q
def forward(self, ob, z, prior_ng, sampled=True, tau_old=None, mu_old=None):
    q = probtorch.Trace()
    (prior_alpha, prior_beta, prior_mu, prior_nu) = prior_ng
    # concatenate observations and cluster assignments
    ob_z = torch.cat((ob, z), -1)
    q_alpha, q_beta, q_mu, q_nu = posterior_eta(self.ob(ob_z), self.gamma(ob_z),
                                                prior_alpha, prior_beta,
                                                prior_mu, prior_nu)
    if sampled:
        tau = Gamma(q_alpha, q_beta).sample()
        q.gamma(q_alpha, q_beta, value=tau, name='precisions')
        mu = Normal(q_mu, 1. / (q_nu * q['precisions'].value).sqrt()).sample()
        q.normal(q_mu, 1. / (q_nu * q['precisions'].value).sqrt(),
                 value=mu, name='means')
    else:
        q.gamma(q_alpha, q_beta, value=tau_old, name='precisions')
        q.normal(q_mu, 1. / (q_nu * q['precisions'].value).sqrt(),
                 value=mu_old, name='means')
    return q
def forward(self, x, num_samples=None, q=None):
    if q is None:
        q = probtorch.Trace()
    hiddens = self.enc_hidden(x)
    hiddens = hiddens.view(hiddens.size(0), -1)
    stats = self.fc(hiddens)
    stats = stats.unsqueeze(0)
    muPrivate = stats[:, :, :self.zPrivate_dim]
    logvarPrivate = stats[:, :, self.zPrivate_dim:(2 * self.zPrivate_dim)]
    stdPrivate = torch.sqrt(torch.exp(logvarPrivate) + EPS)
    muShared = stats[:, :, (2 * self.zPrivate_dim):(2 * self.zPrivate_dim + self.zShared_dim)]
    logvarShared = stats[:, :, (2 * self.zPrivate_dim + self.zShared_dim):]
    stdShared = torch.sqrt(torch.exp(logvarShared) + EPS)
    q.normal(loc=muPrivate, scale=stdPrivate, name='privateA')
    # print('muSharedA: ', muShared)
    # print('logvarSharedA: ', logvarShared)
    # print('stdSharedA: ', stdShared)
    # print('----------------------------')
    # attributes
    q.normal(loc=muShared, scale=stdShared, name='sharedA')
    return q
def z_prior(self, q):
    p = probtorch.Trace()
    _ = p.variable(cat, probs=self.prior_pi,
                   value=q['states'], name='states')
    return p
def forward(self, images, shared, q=None, p=None, num_samples=None):
    shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
    shared_std = torch.ones_like(q['sharedA'].dist.scale)
    p = probtorch.Trace()
    # z_private is the node common to sharedA (infA), sharedB (crossA), and
    # sharedPOE, so all of them must be generated from a single z_private sample
    for shared_name in shared.keys():
        # prior for z_shared
        zShared = p.normal(shared_mean, shared_std,
                           value=shared[shared_name], name=shared_name)
        hiddens = self.dec_hidden(zShared)
        hiddens = hiddens.view(-1, 64, 7, 7)
        images_mean = self.dec_image(hiddens)
        images_mean = images_mean.view(images_mean.size(0), -1)
        images = images.view(images.size(0), -1)
        # reconstruction loss (log prob of Bernoulli dist)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                                  + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean, images, name='images_' + shared_name)
    return p
def forward(self, ob, z, beta, K, priors, sampled=True, mu_old=None, EPS=1e-8):
    q = probtorch.Trace()
    S, B, N, D = ob.shape
    (prior_mu, prior_sigma) = priors
    nss1 = self.nss1(torch.cat((ob, z), -1)).unsqueeze(-1).repeat(
        1, 1, 1, 1, K).transpose(-1, -2)
    nss2 = self.nss2(torch.cat((ob, z), -1)).unsqueeze(-1).repeat(
        1, 1, 1, 1, nss1.shape[-1])
    nss = (nss1 * nss2).sum(2) / (nss2.sum(2) + EPS)
    nss_prior = torch.cat(
        (nss, prior_mu.repeat(S, B, K, 1), prior_sigma.repeat(S, B, K, 1)), -1)
    q_mu_mu = self.mean_mu(nss_prior)
    q_mu_sigma = self.mean_log_sigma(nss_prior).exp()
    if sampled:
        mu = Normal(q_mu_mu, q_mu_sigma).sample()
        q.normal(q_mu_mu, q_mu_sigma, value=mu, name='means')
    else:
        q.normal(q_mu_mu, q_mu_sigma, value=mu_old, name='means')
    return q
def forward(self, x, num_samples=None, q=None):
    if q is None:
        q = probtorch.Trace()
    hiddens = self.resnet(x)
    hiddens = hiddens.view(hiddens.size(0), -1)
    stats = self.fc(hiddens)
    stats = stats.unsqueeze(0)
    muPrivate = stats[:, :, :self.zPrivate_dim]
    logvarPrivate = stats[:, :, self.zPrivate_dim:(2 * self.zPrivate_dim)]
    stdPrivate = torch.sqrt(torch.exp(logvarPrivate) + EPS)
    shared_attr_logit = stats[:, :, (2 * self.zPrivate_dim):(
        2 * self.zPrivate_dim + sum(self.zSharedAttr_dim))]
    shared_label_logit = stats[:, :, (2 * self.zPrivate_dim + sum(self.zSharedAttr_dim)):]
    q.normal(loc=muPrivate, scale=stdPrivate, name='privateA')
    # attributes
    start_dim = 0
    i = 0
    for this_attr_dim in self.zSharedAttr_dim:
        q.concrete(logits=shared_attr_logit[:, :, start_dim:start_dim + this_attr_dim],
                   temperature=self.digit_temp,
                   name='sharedA_attr' + str(i))
        start_dim += this_attr_dim
        i += 1
    # label
    q.concrete(logits=shared_label_logit,
               temperature=self.digit_temp,
               name='sharedA_label')
    # self.q = q
    return q
def forward(self, images, q=None, num_samples=None):
    p = probtorch.Trace()
    attrs = []
    for i in range(self.num_attr):
        attrs.append(
            p.concrete(logits=self.attr_log_weights,
                       temperature=self.digit_temp,
                       value=q['attr' + str(i)],
                       name='attr' + str(i)))
    styles = p.normal(self.style_mean, self.style_std,
                      value=q['styles'], name='styles')
    attrs.append(styles)
    hiddens = self.dec_hidden(torch.cat(attrs, -1))
    hiddens = hiddens.view(-1, 256, 5, 5)
    images_mean = self.dec_image(hiddens)
    images_mean = images_mean.view(images_mean.size(0), -1).unsqueeze(0)
    images = images.view(images.size(0), -1).unsqueeze(0)
    p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                              + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
           images_mean, images, name='images')
    return p
def forward(self, labels, shared, q=None, p=None, num_samples=None, train=True):
    shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
    shared_std = torch.ones_like(q['sharedB'].dist.scale)
    p = probtorch.Trace()
    for shared_name in shared.keys():
        zShared = p.normal(shared_mean, shared_std,
                           value=shared[shared_name], name=shared_name)
        hiddens = self.dec_hidden(zShared)
        pred_labels = self.dec_label(hiddens)
        # reconstruction loss (cross entropy on log-softmax predictions)
        pred_labels = F.log_softmax(pred_labels + EPS, dim=2)
        if train:
            p.loss(lambda y_pred, target: -(target * y_pred).sum(-1),
                   pred_labels, labels.unsqueeze(0),
                   name='labels_' + shared_name)
            p.loss(lambda y_pred, target: (1 - (target == y_pred).float()),
                   pred_labels.max(-1)[1], labels.max(-1)[1],
                   name='labels_acc_' + shared_name)
        else:
            p.loss(lambda y_pred, target: (1 - (target == y_pred).float()),
                   pred_labels.max(-1)[1], labels.max(-1)[1],
                   name='labels_' + shared_name)
    return p
def forward(self, attributes, num_samples=None, q=None):
    if q is None:
        q = probtorch.Trace()
    # attributes = attributes.view(attributes.size(0), -1)
    hiddens = self.enc_hidden(attributes)
    stats = self.fc(hiddens)
    shared_attr_logit = stats[:, :, :sum(self.zSharedAttr_dim)]
    shared_label_logit = stats[:, :, sum(self.zSharedAttr_dim):]
    # attribute
    start_dim = 0
    i = 0
    for this_attr_dim in self.zSharedAttr_dim:
        q.concrete(logits=shared_attr_logit[:, :, start_dim:start_dim + this_attr_dim],
                   temperature=self.digit_temp,
                   name='sharedB_attr' + str(i))
        start_dim += this_attr_dim
        i += 1
    # label
    q.concrete(logits=shared_label_logit,
               temperature=self.digit_temp,
               name='sharedB_label')
    return q
def forward(self, images, shared, q=None, p=None, num_samples=None):
    priv_mean = torch.zeros_like(q['privateA'].dist.loc)
    priv_std = torch.ones_like(q['privateA'].dist.scale)
    shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
    shared_std = torch.ones_like(q['sharedA'].dist.scale)
    p = probtorch.Trace()
    # prior for z_private
    zPrivate = p.normal(priv_mean, priv_std, value=q['privateA'], name='privateA')
    recon_img = {}
    for shared_from in shared.keys():
        latents = [zPrivate]
        # prior for z_shared_attr
        zShared = p.normal(shared_mean, shared_std,
                           value=q[shared[shared_from]],
                           name=shared[shared_from])
        latents.append(zShared)
        # hiddens = self.dec_hidden(torch.cat(latents, -1))
        pred_imgs = self.dec_image(torch.cat(latents, -1))
        pred_imgs = pred_imgs.squeeze(0)
        # pred_imgs = F.sigmoid(pred_imgs)
        p.loss(lambda y_pred, target: F.binary_cross_entropy_with_logits(
                   y_pred, target, reduction='none').sum(dim=1),
               torch.log(pred_imgs + EPS), images,
               name='images_' + shared_from)
        recon_img.update({shared_from: pred_imgs})
    return p, recon_img
def forward(self, images, labels=None, num_samples=None):
    q = probtorch.Trace()
    hidden = self.enc_hidden(images)
    styles_mean = self.style_mean(hidden)
    styles_std = torch.exp(self.style_log_std(hidden))
    q.normal(styles_mean, styles_std, name='z')
    return q
def forward(self, ob, mu, K, sampled=True, z_old=None, beta_old=None):
    q = probtorch.Trace()
    S, B, N, D = ob.shape
    ob_mu = (ob.unsqueeze(2).repeat(1, 1, K, 1, 1)
             - mu.unsqueeze(-2).repeat(1, 1, 1, N, 1))
    q_probs = F.softmax(self.pi_log_prob(ob_mu).squeeze(-1).transpose(-1, -2), -1)
    if sampled:
        z = cat(q_probs).sample()
        _ = q.variable(cat, probs=q_probs, value=z, name='states')
        mu_expand = torch.gather(mu, -2,
                                 z.argmax(-1).unsqueeze(-1).repeat(1, 1, 1, D))
        q_angle_con1 = self.angle_log_con1(ob - mu_expand).exp()
        q_angle_con0 = self.angle_log_con0(ob - mu_expand).exp()
        beta = Beta(q_angle_con1, q_angle_con0).sample()
        q.beta(q_angle_con1, q_angle_con0, value=beta, name='angles')
    else:
        _ = q.variable(cat, probs=q_probs, value=z_old, name='states')
        mu_expand = torch.gather(mu, -2,
                                 z_old.argmax(-1).unsqueeze(-1).repeat(1, 1, 1, D))
        q_angle_con1 = self.angle_log_con1(ob - mu_expand).exp()
        q_angle_con0 = self.angle_log_con0(ob - mu_expand).exp()
        q.beta(q_angle_con1, q_angle_con0, value=beta_old, name='angles')
    return q
def forward(self, x, num_samples=None, q=None):
    if q is None:
        q = probtorch.Trace()
    hiddens = self.enc_hidden(x)
    stats = self.fc(hiddens)
    muPrivate = stats[:, :, :self.zPrivate_dim]
    logvarPrivate = stats[:, :, self.zPrivate_dim:(2 * self.zPrivate_dim)]
    stdPrivate = torch.sqrt(torch.exp(logvarPrivate) + EPS)
    muShared = stats[:, :, (2 * self.zPrivate_dim):(2 * self.zPrivate_dim + self.zShared_dim)]
    logvarShared = stats[:, :, (2 * self.zPrivate_dim + self.zShared_dim):]
    stdShared = torch.sqrt(torch.exp(logvarShared) + EPS)
    q.normal(loc=muPrivate, scale=stdPrivate, name='privateA')
    q.normal(loc=muShared, scale=stdShared, name='sharedA')
    return q
def forward(self, images, shared, q=None, p=None, num_samples=None):
    priv_mean = torch.zeros_like(q['privateA'].dist.loc)
    priv_std = torch.ones_like(q['privateA'].dist.scale)
    shared_mean = torch.zeros_like(q['sharedA'].dist.loc)
    shared_std = torch.ones_like(q['sharedA'].dist.scale)
    p = probtorch.Trace()
    # prior for z_private
    zPrivate = p.normal(priv_mean, priv_std, value=q['privateA'], name='privateA')
    for shared_from in shared.keys():
        latents = [zPrivate]
        # prior for z_shared_attr
        zShared = p.normal(shared_mean, shared_std,
                           value=q[shared[shared_from]],
                           name=shared[shared_from])
        latents.append(zShared)
        hiddens = self.fc(torch.cat(latents, -1))
        hiddens = hiddens.view(-1, 256, 5, 5)
        images_mean = self.hallucinate(hiddens)
        images_mean = images_mean.view(images_mean.size(0), -1)
        images = images.view(images.size(0), -1)
        # reconstruction loss (log prob of Bernoulli dist)
        p.loss(lambda x_hat, x: -(torch.log(x_hat + EPS) * x
                                  + torch.log(1 - x_hat + EPS) * (1 - x)).sum(-1),
               images_mean, images, name='images_' + shared_from)
    return p
def forward(self, labels, shared, q=None, p=None, num_samples=None, train=True):
    shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
    shared_std = torch.ones_like(q['sharedB'].dist.scale)
    p = probtorch.Trace()
    # z_private is the node common to sharedA (infA), sharedB (crossA), and
    # sharedPOE, so all of them must be generated from a single z_private sample
    for shared_name in shared.keys():
        # prior for z_shared: a standard normal
        zShared = p.normal(shared_mean, shared_std,
                           value=shared[shared_name], name=shared_name)
        hiddens = self.dec_hidden(zShared)
        pred_labels = self.dec_label(hiddens)
        # reconstruction loss (cross entropy on log-softmax predictions)
        pred_labels = F.log_softmax(pred_labels + EPS, dim=2)
        if train:
            p.loss(lambda y_pred, target: -(target * y_pred).sum(-1),
                   pred_labels, labels.unsqueeze(0),
                   name='labels_' + shared_name)
        else:
            p.loss(lambda y_pred, target: (1 - (target == y_pred).float()),
                   pred_labels.max(-1)[1], labels.max(-1)[1],
                   name='labels_' + shared_name)
    return p
def forward(self, attributes, shared, q=None, p=None, num_samples=None,
            train=True, CUDA=False):
    shared_mean = torch.zeros_like(q['sharedB'].dist.loc)
    shared_std = torch.ones_like(q['sharedB'].dist.scale)
    pred = {}
    p = probtorch.Trace()
    for shared_from in shared.keys():
        # prior for z_shared_attr
        zShared = p.normal(shared_mean, shared_std,
                           value=q[shared[shared_from]],
                           name=shared[shared_from])
        hiddens = self.dec_hidden(zShared)
        pred_labels = self.dec_label(hiddens)
        pred_labels = pred_labels.squeeze(0)
        pred_labels = F.logsigmoid(pred_labels + EPS)
        p.loss(lambda y_pred, target: F.binary_cross_entropy_with_logits(
                   y_pred, target, reduction='none').sum(dim=1),
               pred_labels, attributes, name='attr_' + shared_from)
        pred.update({shared_from: pred_labels})
    return p, pred['cross']
def forward(self, trace, times=None, guide=probtorch.Trace(), blocks=None,
            observations=[], weights_params=None):
    if blocks is None:
        blocks = list(range(self._num_blocks))
    params = self._hyperparams.state_vardict()
    template = self._template_prior(trace, params, guide=guide)
    weights, centers, log_widths = self._subject_prior(
        trace, params, template, times=times, blocks=blocks, guide=guide,
        weights_params=weights_params)
    return [
        self.likelihoods[b](trace, weights[i], centers[i], log_widths[i],
                            params, times=times,
                            observations=observations[i])
        for (i, b) in enumerate(blocks)
    ]
def test_normal(self):
    N = 100  # number of training data points
    S = 3    # sample size
    B = 5    # batch size
    D = 10   # hidden dim
    mu1 = Variable(torch.randn(S, B, D))
    mu2 = Variable(torch.randn(S, B, D))
    sigma1 = torch.exp(Variable(torch.randn(S, B, D)))
    sigma2 = torch.exp(Variable(torch.randn(S, B, D)))
    q = probtorch.Trace()
    q.normal(mu1, sigma1, name='z1')
    q.normal(mu2, sigma2, name='z2')
    z1 = q['z1']
    z2 = q['z2']
    value1 = z1.value
    value2 = z2.value
    bias = (N - 1.0) / (B - 1.0)
    log_joint, log_mar, log_prod_mar = q.log_batch_marginal(
        sample_dim=0, batch_dim=1, nodes=['z1', 'z2'], bias=bias)
    # compare result
    log_probs1 = Variable(torch.zeros(B, B, S, D))
    log_probs2 = Variable(torch.zeros(B, B, S, D))
    for b1 in range(B):
        for s in range(S):
            for b2 in range(B):
                d1 = Normal(mu1[s, b2], sigma1[s, b2])
                d2 = Normal(mu2[s, b2], sigma2[s, b2])
                log_probs1[b1, b2, s] = d1.log_prob(value1[s, b1])
                log_probs2[b1, b2, s] = d2.log_prob(value2[s, b1])
    log_sum_1 = log_probs1.sum(3)
    log_sum_2 = log_probs2.sum(3)
    log_joint_2 = log_sum_1 + log_sum_2
    log_joint_2[range(B), range(B)] -= math.log(bias)
    log_joint_2 = log_mean_exp(log_joint_2, 1).transpose(0, 1)
    log_sum_1[range(B), range(B)] -= math.log(bias)
    log_sum_2[range(B), range(B)] -= math.log(bias)
    log_mar_z1 = log_mean_exp(log_sum_1, 1).transpose(0, 1)
    log_mar_z2 = log_mean_exp(log_sum_2, 1).transpose(0, 1)
    log_mar_2 = log_mar_z1 + log_mar_z2
    log_probs1[range(B), range(B)] -= math.log(bias)
    log_probs2[range(B), range(B)] -= math.log(bias)
    log_prod_mar_z1 = log_mean_exp(log_probs1, 1).sum(2).transpose(0, 1)
    log_prod_mar_z2 = log_mean_exp(log_probs2, 1).sum(2).transpose(0, 1)
    log_prod_mar_2 = log_prod_mar_z1 + log_prod_mar_z2
    self.assertEqual(log_mar, log_mar_2)
    self.assertEqual(log_prod_mar, log_prod_mar_2)
    self.assertEqual(log_joint, log_joint_2)