Exemple #1
0
 def get_kl(self, z):
     """Estimate the KL term between the posterior and the mixture prior.

     The estimate is ``log q(z) + flow_log_prob - log p(z)`` averaged over
     every element of the result, where ``q`` is evaluated with
     ``self.z_mean`` / ``self.z_sigma`` and ``p`` is the mixture whose
     parameters are derived from ``self.z_pre``.

     Args:
         z: Latent samples at which both log-densities are evaluated.

     Returns:
         The tensor produced by ``torch.mean`` over the estimate.
     """
     # Split the stored prior parameters along dim 1 (presumably into
     # per-component means and variances -- confirm against ``utils``).
     mix_mean, mix_var = utils.gaussian_parameters(self.z_pre, dim=1)
     log_posterior = utils.log_normal(z, self.z_mean, self.z_sigma)
     # flow_log_prob is only read here; it must already be set on the
     # instance by the time this method is called.
     log_posterior_flow = log_posterior + self.flow_log_prob
     log_prior = utils.log_normal_mixture(z, mix_mean, mix_var)
     return torch.mean(log_posterior_flow - log_prior)
Exemple #2
0
 def sample_z(self, batch):
     """Draw ``batch`` latent samples from the mixture prior.

     A mixture component is picked for every sample from a categorical
     distribution built from ``self.pi``; a Gaussian sample is then drawn
     with that component's parameters. When ``self.n_flows > 0`` the
     sample is passed through ``flow.inverse`` for every flow in
     ``self.nf``, visited in reverse order.

     Args:
         batch: Number of samples to draw.

     Returns:
         The sampled (and, if flows are enabled, inverse-transformed)
         tensor.
     """
     # Presumably per-component means and variances -- confirm in ``utils``.
     mix_mean, mix_var = utils.gaussian_parameters(self.z_pre.squeeze(0),
                                                   dim=0)
     component_dist = torch.distributions.categorical.Categorical(self.pi)
     picks = component_dist.sample((batch, ))
     picked_mean = mix_mean[picks]
     picked_var = mix_var[picks]
     sample = utils.sample_gaussian(picked_mean, picked_var)
     if self.n_flows > 0:
         # Walk the flow stack back to front, applying each inverse.
         for flow in self.nf[::-1]:
             sample, _ = flow.inverse(sample)
     return sample
Exemple #3
0
 def sample_point(self, batch):
     """Sample ``batch`` outputs by decoding draws from the mixture prior.

     A mixture component is picked per sample from a categorical
     distribution over ``self.pi``, a Gaussian latent is drawn with that
     component's parameters, and the result is handed to
     ``self.sample_x_given``. When ``self.use_encoding_in_decoder`` is
     truthy, the selected component's first parameter is concatenated to
     the latent along the last dimension before decoding.

     Args:
         batch: Number of samples to draw.

     Returns:
         Whatever ``self.sample_x_given`` returns for the decoder input.
     """
     # Presumably per-component means and variances -- confirm in ``ut``.
     mix_mean, mix_var = ut.gaussian_parameters(self.z_pre.squeeze(0), dim=0)
     component_dist = torch.distributions.categorical.Categorical(self.pi)
     picks = component_dist.sample((batch, ))
     picked_mean = mix_mean[picks]
     picked_var = mix_var[picks]
     latent = ut.sample_gaussian(picked_mean, picked_var)
     if self.use_encoding_in_decoder:
         # BUGBUG (original author's note): ideally the encodings from
         # before the mu/sigma heads should be concatenated here.
         decoder_input = torch.cat((latent, picked_mean), dim=-1)
     else:
         decoder_input = latent
     return self.sample_x_given(decoder_input)
def eval(dataloader):
    """Average the Chamfer reconstruction loss of the VAE and AE paths.

    Uses the module-level names ``dict_model`` (mapping of network name to
    network), ``ut``, ``chamfer_dist`` and ``device``. Every network in
    ``dict_model`` is switched to eval mode and left in that mode.

    Args:
        dataloader: Iterable of point-set batches. Each batch is permuted
            with ``(0, 2, 1)`` before being fed to ``'PointNet'``.

    Returns:
        Tuple ``(avg_loss_vae, avg_loss_ae)``: the mean over all
        per-sample losses gathered across batches, as Python numbers.
    """
    for net in dict_model.values():
        net.eval()

    def per_sample_chamfer(reference, reconstruction):
        # Sum of the two directional distances, each averaged over axis 1.
        dist_fwd, dist_bwd = chamfer_dist(reference, reconstruction)
        return torch.mean(dist_fwd, axis=1) + torch.mean(dist_bwd, axis=1)

    vae_losses = []
    ae_losses = []
    with torch.no_grad():
        for X_batch in dataloader:
            n_sets = X_batch.shape[0]

            # Set representation of the batch (out: batch, 1024 per the
            # original note; set_rep: batch, dim_rep).
            out, _, _ = dict_model['PointNet'](X_batch.permute(0, 2, 1))
            set_rep = dict_model['FeatureCompression'](out)

            # VAE path: encode, sample z, map z back to a representation
            # distribution, sample it, and decode.
            encoded = dict_model['EncoderVAE'](set_rep)
            qm, qv = ut.gaussian_parameters(encoded, dim=1)
            z = ut.sample_gaussian(qm, qv, device=device)
            rep_m, rep_v = ut.gaussian_parameters(dict_model['LatentToRep'](z))
            rep_sample = ut.sample_gaussian(rep_m, rep_v, device=device)
            X_rec = dict_model['DecoderAE'](rep_sample).reshape(n_sets, -1, 3)

            # AE path: decode the set representation directly.
            X_rec_ae = dict_model['DecoderAE'](set_rep).reshape(n_sets, -1, 3)

            vae_losses.append(per_sample_chamfer(X_batch, X_rec))
            ae_losses.append(per_sample_chamfer(X_batch, X_rec_ae))

    avg_loss_vae = torch.cat(vae_losses).mean().item()
    avg_loss_ae = torch.cat(ae_losses).mean().item()
    return avg_loss_vae, avg_loss_ae
Exemple #5
0
 def forward(self, g, h, r, norm):
     """Encode the graph into latent samples, optionally pushed through flows.

     The input passes through ``input_layer``, ``rconv_layer_1`` and
     ``rconv_layer_2``; the result is split into Gaussian parameters and a
     sample is drawn. When ``self.n_flows > 0`` the sample is transformed
     by every flow in ``self.nf`` and the per-flow log-determinants are
     summed.

     Side effects: stores ``self.node_id`` (the squeezed input ``h``),
     ``self.z_mean`` and ``self.z_sigma``; when flows are enabled it also
     stores ``self.flow_log_prob`` (mean of the summed log-determinants).
     ``self.flow_log_prob`` is not touched when ``self.n_flows`` is 0.

     Args:
         g: Graph object forwarded unchanged to every layer.
         h: Initial node input (presumably node ids -- confirm with caller).
         r: Forwarded unchanged to every layer.
         norm: Forwarded unchanged to every layer.

     Returns:
         The latent sample, flow-transformed when flows are enabled.
     """
     self.node_id = h.squeeze()
     h = self.input_layer(g, h, r, norm)
     h = self.rconv_layer_1(g, h, r, norm)
     h = self.rconv_layer_2(g, h, r, norm)
     self.z_mean, self.z_sigma = utils.gaussian_parameters(h)
     z = utils.sample_gaussian(self.z_mean, self.z_sigma)
     # flow transformed samples
     if self.n_flows > 0:
         log_det_sum = torch.zeros((z.shape[0], ))
         for flow in self.nf:
             z, log_det = flow.forward(z)
             # Follow log_det to whichever device it lives on. A bare
             # .cuda() would target the current default GPU, which is
             # wrong for non-default devices and other accelerators.
             # .to() is a no-op when the device already matches.
             log_det_sum = log_det_sum.to(log_det.device) + log_det.view(-1)
         self.flow_log_prob = torch.mean(log_det_sum.view(-1, 1))
     return z
Exemple #6
0
 def get_mmd(self, z):
     """Estimate the MMD between prior samples and posterior samples.

     Prior samples are drawn from the Gaussian parameters derived from
     ``self.z_pre`` (``200 // self.k`` repeats) and, when
     ``self.n_flows > 0``, pushed through every flow in ``self.nf``.
     Posterior samples are a random subset of the rows of ``z``.

     Args:
         z: Posterior latent samples, indexed along the first dimension.

     Returns:
         ``mean(k(prior, prior)) + mean(k(post, post))
         - 2 * mean(k(prior, post))`` using ``self.compute_kernel``.
     """
     m_mixture, s_mixture = utils.gaussian_parameters(self.z_pre, dim=1)
     num_sample = 200
     z_pri = utils.sample_gaussian(m_mixture,
                                   s_mixture,
                                   repeat=num_sample // self.k)
     if self.n_flows > 0:
         for flow in self.nf:
             z_pri, _ = flow.forward(z_pri)
     # (Monte Carlo) sample the z otherwise GPU will be out of memory.
     # Cap the subset at the number of rows available: random.sample
     # raises ValueError when asked for more items than the population.
     num_post = min(num_sample, z.shape[0])
     z_post_samples = z[random.sample(range(z.shape[0]), num_post)]
     prior_z_kernel = self.compute_kernel(z_pri, z_pri)
     posterior_z_kernel = self.compute_kernel(z_post_samples,
                                              z_post_samples)
     mixed_kernel = self.compute_kernel(z_pri, z_post_samples)
     mmd = (prior_z_kernel.mean() + posterior_z_kernel.mean()
            - 2 * mixed_kernel.mean())
     return mmd
Exemple #7
0
 def forward(self, inputs):
     """Compute the negative-ELBO terms and the latent classifier loss.

     Args:
         inputs: Mapping with keys ``'x'`` (fed to the encoder and used as
             the reconstruction target) and ``'y_class'`` (classifier
             labels).

     Returns:
         Dict with keys ``'nelbo'``, ``'kl_loss'``, ``'x_reconst'`` and
         ``'z_cl_loss'``. ``'nelbo'`` is ``x_reconst + kl_loss`` after
         both have been reduced with mean (when ``self.loss_sum_mean`` is
         ``"mean"``) or sum (otherwise).

     Raises:
         ValueError: If ``self.loss_type`` is not ``'chamfer'``, the only
             reconstruction loss implemented here.
     """
     x, y_class = inputs['x'], inputs['y_class']
     m, v = self.encoder(x)
     if self.use_deterministic_encoder:
         y = self.decoder(m)
         kl_loss = torch.zeros(1)
     else:
         # Compute the mixture of Gaussian prior; it is only needed for
         # the KL term of the stochastic path.
         z_prior_m, z_prior_v = ut.gaussian_parameters(self.z_pre, dim=1)
         z = ut.sample_gaussian(m, v)
         if self.use_encoding_in_decoder:
             # BUGBUG (original author's note): ideally the encodings
             # from before the mu/sigma heads should be concatenated here.
             decoder_input = torch.cat((z, m), dim=-1)
         else:
             decoder_input = z
         y = self.decoder(decoder_input)
         # Single-sample KL estimate: log q(z|x) - log p(z).
         kl_loss = ut.log_normal(z, m, v) - ut.log_normal_mixture(
             z, z_prior_m, z_prior_v)
     # Compute reconstruction loss. Compare by value: ``is`` against a
     # string literal tests identity and fails for non-interned strings.
     if self.loss_type != 'chamfer':
         raise ValueError(
             "Unsupported loss_type: {!r}".format(self.loss_type))
     x_reconst = CD_loss(y, x)
     # mean or sum
     if self.loss_sum_mean == "mean":
         x_reconst = x_reconst.mean()
         kl_loss = kl_loss.mean()
     else:
         x_reconst = x_reconst.sum()
         kl_loss = kl_loss.sum()
     nelbo = x_reconst + kl_loss
     ret = {'nelbo': nelbo, 'kl_loss': kl_loss, 'x_reconst': x_reconst}
     # Classifier network operating on the concatenated encoder outputs.
     mv = torch.cat((m, v), dim=1)
     y_logits = self.z_classifer(mv)
     z_cl_loss = self.z_classifer.cross_entropy_loss(y_logits, y_class)
     ret['z_cl_loss'] = z_cl_loss
     return ret
 def forward(self, x):
     """Map ``x`` to the two Gaussian parameter tensors.

     Pipeline: ``fc1 -> bn1 -> relu``, then ``fc2 -> dropout -> bn2 ->
     relu``, then ``fc3``; the output of ``fc3`` is split along dim 1 by
     ``ut.gaussian_parameters``.

     Args:
         x: Input tensor for ``fc1``.

     Returns:
         Tuple ``(m, v)`` as returned by ``ut.gaussian_parameters``
         (presumably mean and variance -- confirm in ``ut``).
     """
     hidden = self.fc1(x)
     hidden = F.relu(self.bn1(hidden))
     hidden = self.fc2(hidden)
     hidden = self.dropout(hidden)
     hidden = F.relu(self.bn2(hidden))
     params = self.fc3(hidden)
     mean, var = ut.gaussian_parameters(params, dim=1)
     return mean, var
             z_prior_v).to(device)
with tqdm.tqdm(total=iter_max) as pbar:
    for i in range(iter_max):
        optimizer.zero_grad()
        X_hold, _ = dataset_train[0]  #random sample each call
        X_hold = X_hold.squeeze(1)  #batch_size,N_hold,3
        #X_eval=X_eval.squeeze(1) #batch_size,N_eval,3

        #extract set representation from hold out set
        out, _, _ = dict_model['PointNet'](X_hold.permute(
            0, 2, 1))  #out: batch, 1024
        set_rep = dict_model['FeatureCompression'](
            out)  #set_rep: batch, dim_rep

        #encoding. dim: batch, dim_z
        qm, qv = ut.gaussian_parameters(dict_model['EncoderVAE'](set_rep),
                                        dim=1)
        #sample z
        z = ut.sample_gaussian(qm, qv, device=device)  #batch_size, dim_z
        #z to rep
        rep_m, rep_v = ut.gaussian_parameters(dict_model['LatentToRep'](z))

        log_likelihood = ut.log_normal(set_rep, rep_m, rep_v)  #dim: batch
        lb_1 = log_likelihood.mean()  #scalar

        #KL divergence
        m = z_prior_m.expand(P, dim_z)
        v = z_prior_v.expand(P, dim_z)
        lb_2 = -ut.kl_normal(qm, qv, m, v).mean()  #scalar

        loss = -1 * (lb_1 + lb_2)
        loss.backward()