def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    out = False
    while not out:
        for x_true1, x_true2 in self.data_loader:
            self.global_iter += 1
            self.pbar.update(1)

            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kld = kl_divergence(mu, logvar)

            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            self.optim_VAE.step()

            x_true2 = x_true2.to(self.device)
            z_prime = self.VAE(x_true2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(z_pperm)
            D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            self.optim_D.zero_grad()
            D_tc_loss.backward()
            self.optim_D.step()

            if self.global_iter%self.print_iter == 0:
                self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                    self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

            if self.global_iter%self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)

            if self.global_iter >= self.max_iter:
                out = True
                break

    self.pbar.write("[Training Finished]")
    self.pbar.close()
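# All of the training loops in this file assume a few helpers (recon_loss,
# kl_divergence, permute_dims) that are defined elsewhere. The following is a
# minimal sketch, assuming the standard FactorVAE formulation (Bernoulli
# reconstruction on logits, diagonal-Gaussian KL, per-dimension permutation
# across the batch); the real implementations may differ.
import torch
import torch.nn.functional as F


def recon_loss(x, x_recon):
    # Bernoulli negative log-likelihood on logits, averaged over the batch.
    n = x.size(0)
    return F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum') / n


def kl_divergence(mu, logvar):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, averaged over the batch.
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)).mean()


def permute_dims(z):
    # Shuffle each latent dimension independently across the batch so the
    # result follows the product of marginals q(z_1)...q(z_d).
    assert z.dim() == 2
    B, d = z.size()
    return torch.stack([z[torch.randperm(B, device=z.device), j] for j in range(d)], dim=1)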
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    out = False
    while not out:
        for x_true1, x_true2 in self.data_loader:
            self.global_iter += 1
            self.pbar.update(1)

            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kld = kl_divergence(mu, logvar, self.r)
            H_r = entropy(self.r)

            D_z = self.D(self.r*z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            vae_loss = (vae_recon_loss + vae_kld + self.gamma*vae_tc_loss
                        + self.etaS*self.r.abs().sum() + self.etaH*H_r)

            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            self.optim_VAE.step()

            self.optim_r.zero_grad()
            vae_loss.backward(retain_graph=True)
            self.optim_r.step()

            x_true2 = x_true2.to(self.device)
            z_prime = self.VAE(x_true2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(self.r*z_pperm)
            D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            self.optim_D.zero_grad()
            D_tc_loss.backward()
            self.optim_D.step()

            if self.global_iter%self.print_iter == 0:
                self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                    self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

            if self.global_iter%self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)

            if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
                soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                D_acc /= 2*self.batch_size
                self.line_gather.insert(iter=self.global_iter,
                                        soft_D_z=soft_D_z.mean().item(),
                                        soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                        recon=vae_recon_loss.item(),
                                        kld=vae_kld.item(),
                                        acc=D_acc.item(),
                                        r_distribute=self.r.data.cpu())

            if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                self.image_gather.insert(true=x_true1.data.cpu(),
                                         recon=F.sigmoid(x_recon).data.cpu())
                self.visualize_recon()
                self.image_gather.flush()

            if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                if self.dataset.lower() == '3dchairs':
                    self.visualize_traverse(limit=2, inter=0.5)
                else:
                    self.visualize_traverse(limit=3, inter=2/3)

            if self.global_iter >= self.max_iter:
                out = True
                break

    self.pbar.write("[Training Finished]")
    self.pbar.close()
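# The variant above weights each latent dimension by a learnable relevance vector
# self.r (optimised separately through self.optim_r) and regularises it with an L1
# term and an entropy term H_r. The snippet below is only a hypothetical sketch of
# how r, its optimiser, and entropy() could be set up; the actual definitions are
# not shown in this file and may differ.
import torch
import torch.nn.functional as F

z_dim = 10    # assumed latent dimensionality
lr_r = 1e-4   # assumed learning rate for the relevance vector

r = torch.nn.Parameter(torch.ones(z_dim))   # self.r: one weight per latent dimension
optim_r = torch.optim.Adam([r], lr=lr_r)    # self.optim_r


def entropy(r):
    # Entropy of the softmax-normalised absolute relevance weights; together with
    # the L1 term it controls how many dimensions stay "switched on".
    p = F.softmax(r.abs(), dim=0)
    return -(p * (p + 1e-12).log()).sum()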
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    metrics = []
    out = False
    while not out:
        for x_true1, x_true2 in self.data_loader:
            self.global_iter += 1
            self.pbar.update(1)

            # The optimiser steps are applied at the top of the iteration, so the VAE
            # and the discriminator are updated with the gradients accumulated during
            # the previous iteration (the in-place .step() calls below are commented out).
            self.optim_VAE.step()
            self.optim_D.step()

            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kld = kl_divergence(mu, logvar)

            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            #self.optim_VAE.step()

            x_true2 = x_true2.to(self.device)
            z_prime = self.VAE(x_true2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(z_pperm)
            D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            self.optim_D.zero_grad()
            D_tc_loss.backward()
            #self.optim_D.step()

            # Saving the training metrics
            if self.global_iter % 100 == 0:
                metrics.append({'its': self.global_iter,
                                'vae_loss': vae_loss.detach().cpu().item(),
                                'D_loss': D_tc_loss.detach().cpu().item(),
                                'recon_loss': vae_recon_loss.detach().cpu().item(),
                                'tc_loss': vae_tc_loss.detach().cpu().item()})

            # Saving the disentanglement metric results
            if self.global_iter % 1500 == 0:
                score = self.disentanglement_metric()
                metrics.append({'its': self.global_iter, 'metric_score': score})
                self.net_mode(train=True)  # switch back to training mode

            if self.global_iter%self.print_iter == 0:
                self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                    self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

            if self.global_iter%self.ckpt_save_iter == 0:
                self.save_checkpoint(str(self.global_iter)+".pth")
                self.save_metrics(metrics)
                metrics = []

            if self.global_iter >= self.max_iter:
                out = True
                break

    self.pbar.write("[Training Finished]")
    self.pbar.close()
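# self.save_metrics() and self.disentanglement_metric() are defined elsewhere in the
# trainer class. A minimal, purely hypothetical sketch of save_metrics that appends
# the collected dictionaries to a JSON file in the (assumed) checkpoint directory
# self.ckpt_dir; the real method may store them differently.
import json
import os


def save_metrics(self, metrics, filename='metrics.json'):
    path = os.path.join(self.ckpt_dir, filename)
    history = []
    if os.path.isfile(path):
        with open(path, 'r') as f:
            history = json.load(f)
    history.extend(metrics)
    with open(path, 'w') as f:
        json.dump(history, f, indent=2)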
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

    epochs = int(np.ceil(self.steps / len(self.dataloader)))
    print("number of epochs {}".format(epochs))

    step = 0
    for e in range(epochs):
        for x_true1, x_true2 in self.dataloader:
            #if step == 1: break   # debug early-exit, disabled
            step += 1

            # VAE
            x_true1 = x_true1.unsqueeze(1).to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)

            # Reconstruction and KL
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kl = kl_div(mu, logvar)

            # Total Correlation
            D_z = self.D(z)
            tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            # Synergy term: D_syn predicts which latent dimensions to keep per sample
            best_ai = self.D_syn(mu, logvar)
            best_ai_labels = torch.bernoulli(best_ai)

            # Copy to separate tensors so mu/logvar themselves are not modified in place
            mu_syn = mu.clone()
            logvar_syn = logvar.clone()
            mu_syn[best_ai_labels == 0] = 0
            logvar_syn[best_ai_labels == 0] = 0

            # TODO: per-sample KL restricted to the selected (non-zero) dimensions
            for i in range(self.batch_size):
                mu_syn_s = mu_syn[i][mu_syn[i] != 0]

            if len(mu_syn.size()) == 1:
                syn_loss = kl_div_uni_dim(mu_syn, logvar_syn).mean()
            else:
                syn_loss = kl_div(mu_syn, logvar_syn)

            # VAE loss
            vae_loss = vae_recon_loss + vae_kl + self.gamma * tc_loss + self.alpha * syn_loss

            # Optimise VAE
            self.optim_VAE.zero_grad()             # zero the gradient buffers
            vae_loss.backward(retain_graph=True)   # populate parameter gradients
            self.optim_VAE.step()                  # apply the update

            # TODO Check the best greedy policy
            # Discriminator Syn: regress the greedy selection as the target sequence
            real_seq = greedy_policy_Smax_discount(self.z_dim, mu, logvar, 0.8).detach()
            d_syn_loss = recon_loss(real_seq, best_ai)

            # Optimise Discriminator Syn
            self.optim_D_syn.zero_grad()
            d_syn_loss.backward(retain_graph=True)
            self.optim_D_syn.step()

            # Discriminator TC
            x_true2 = x_true2.unsqueeze(1).to(self.device)
            z_prime = self.VAE(x_true2, decode=False)[3]
            z_perm = permute_dims(z_prime).detach()   # no gradient flows back through the permuted codes
            D_z_perm = self.D(z_perm)

            # Discriminator TC loss
            d_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_perm, ones))

            # Optimise Discriminator TC
            self.optim_D.zero_grad()
            d_loss.backward()
            self.optim_D.step()

            # Logging
            if step % self.args.log_interval == 0:
                print("Step {}".format(step))
                print("Recons. Loss = {:.4f}".format(vae_recon_loss))
                print("KL Loss = {:.4f}".format(vae_kl))
                print("TC Loss = {:.4f}".format(tc_loss))
                print("Syn Loss = {:.4f}".format(syn_loss))
                print("Factor VAE Loss = {:.4f}".format(vae_loss))
                print("D loss = {:.4f}".format(d_loss))
                print("best_ai {}".format(best_ai))

            # Saving
            if not step % self.args.save_interval:
                filename = 'traversal_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                traverse(self.net_mode, self.VAE, self.test_imgs, filepath)
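# The synergy variants use kl_div() and kl_div_uni_dim(), which are not defined in
# this file. A minimal sketch, assuming the usual diagonal-Gaussian KL against a
# standard normal prior, either kept per dimension (kl_div_uni_dim) or summed over
# dimensions and averaged over the batch (kl_div); the actual helpers may differ.
import torch


def kl_div_uni_dim(mu, logvar):
    # Element-wise KL(q(z_j|x) || N(0,1)) for each sample and each dimension.
    return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())


def kl_div(mu, logvar):
    # Sum the per-dimension KL over latent dimensions, then average over the batch.
    return kl_div_uni_dim(mu, logvar).sum(1).mean()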
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

    epochs = int(np.ceil(self.steps / len(self.dataloader)))
    print("number of epochs {}".format(epochs))

    step = 0

    # Bookkeeping for optionally tracking whether the gradients of selected weights
    # change between the VAE, synergy, and discriminator updates (see the commented
    # check below).
    weights_names = ['encoder.2.weight', 'encoder.10.weight',
                     'decoder.0.weight', 'decoder.7.weight', 'net.4.weight']
    dict_VAE = defaultdict(list)
    dict_weight = {a: [] for a in weights_names}

    for e in range(epochs):
        for x_true1, x_true2 in self.dataloader:
            #if step == 1: break
            step += 1

            # VAE
            x_true1 = x_true1.unsqueeze(1).to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)

            # Reconstruction and KL
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kl = kl_div(mu, logvar)

            # Total Correlation
            D_z = self.D(z)
            tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            # VAE loss
            vae_loss = vae_recon_loss + vae_kl + self.gamma * tc_loss

            # Optimise VAE
            self.optim_VAE.zero_grad()             # zero the gradient buffers
            vae_loss.backward(retain_graph=True)   # populate parameter gradients

            # (Debug) gradient-change tracking, run after each backward/step call
            # for the weights listed in weights_names:
            #for name, params in self.VAE.named_parameters():
            #    if name in weights_names:
            #        if np.array_equal(dict_VAE[name], params.grad.numpy()):
            #            print("No change in gradients {}".format(name))
            #        else:
            #            print("Change in gradients {}".format(name))
            #            dict_VAE[name] = params.grad.numpy().copy()

            self.optim_VAE.step()                  # apply the update

            ##################
            # Synergy Max

            # Step 1: compute the argmax of D_kl(q(a_i | x_i) || p)
            best_ai = greedy_policy_Smax_discount(self.z_dim, mu, logvar, alpha=self.omega)

            # Step 2: compute I_max over the selected dimensions
            mu_syn = mu[:, best_ai]
            logvar_syn = logvar[:, best_ai]

            if len(mu_syn.size()) == 1:
                I_max = kl_div_uni_dim(mu_syn, logvar_syn).mean()
            else:
                I_max = kl_div(mu_syn, logvar_syn)

            # Step 3: use it in the loss
            syn_loss = self.alpha * I_max

            # Step 4: optimise the Syn term (a second VAE update)
            self.optim_VAE.zero_grad()
            syn_loss.backward(retain_graph=True)
            self.optim_VAE.step()

            # Discriminator
            x_true2 = x_true2.unsqueeze(1).to(self.device)
            z_prime = self.VAE(x_true2, decode=False)[3]
            z_perm = permute_dims(z_prime).detach()   # no gradient flows back through the permuted codes
            D_z_perm = self.D(z_perm)

            # Discriminator loss
            d_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_perm, ones))

            # Optimise Discriminator
            self.optim_D.zero_grad()
            d_loss.backward()
            self.optim_D.step()

            # Logging
            if step % self.args.log_interval == 0:
                print("Step {}".format(step))
                print("Recons. Loss = {:.4f}".format(vae_recon_loss))
                print("KL Loss = {:.4f}".format(vae_kl))
                print("TC Loss = {:.4f}".format(tc_loss))
                print("Factor VAE Loss = {:.4f}".format(vae_loss))
                print("D loss = {:.4f}".format(d_loss))
                print("best_ai {}".format(best_ai))
                print("I_max {}".format(I_max))
                print("Syn loss {:.4f}".format(syn_loss))

            # Saving
            if not step % self.args.save_interval:
                filename = 'alpha_' + str(self.alpha) + '_traversal_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                traverse(self.net_mode, self.VAE, self.test_imgs, filepath)
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

    epochs = int(np.ceil(self.steps / len(self.dataloader)))
    print("number of epochs {}".format(epochs))

    step = 0
    for e in range(epochs):
        for x_true1, x_true2 in self.dataloader:
            #if step == 50: break
            step += 1

            # VAE
            x_true1 = x_true1.unsqueeze(1).to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)

            # Reconstruction and KL
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kl = kl_div(mu, logvar)

            # Total Correlation
            D_z = self.D(z)
            tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            # VAE loss
            vae_loss = vae_recon_loss + vae_kl + self.gamma * tc_loss

            # Optimise VAE
            self.optim_VAE.zero_grad()             # zero the gradient buffers
            vae_loss.backward(retain_graph=True)
            self.optim_VAE.step()                  # apply the update

            # Discriminator
            x_true2 = x_true2.unsqueeze(1).to(self.device)
            z_prime = self.VAE(x_true2, decode=False)[3]
            z_perm = permute_dims(z_prime).detach()   # no gradient is backpropagated through the permuted codes
            D_z_perm = self.D(z_perm)

            # Discriminator loss
            d_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_perm, ones))

            # Optimise Discriminator
            self.optim_D.zero_grad()
            d_loss.backward()
            self.optim_D.step()

            # Logging
            if step % self.args.log_interval == 0:
                print("Step {}".format(step))
                print("Recons. Loss = {:.4f}".format(vae_recon_loss))
                print("KL Loss = {:.4f}".format(vae_kl))
                print("TC Loss = {:.4f}".format(tc_loss))
                print("Factor VAE Loss = {:.4f}".format(vae_loss))
                print("D loss = {:.4f}".format(d_loss))

            # Saving traversals
            if not step % self.args.save_interval:
                filename = 'traversal_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                traverse(self.net_mode, self.VAE, self.test_imgs, filepath)

            # Saving plot of ground truth vs predicted
            if not step % self.args.gt_interval:
                filename = 'gt_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                plot_gt_shapes(self.net_mode, self.VAE, self.dataloader_gt, filepath)
def train(self):
    gcam = GradCamDissen(self.VAE, self.D, target_layer='encode.1', cuda=True)
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    mu_avg, logvar_avg = 0, 1
    metrics = []
    out = False
    while not out:
        for batch_idx, (x1, x2) in enumerate(self.data_loader):
            self.global_iter += 1
            self.pbar.update(1)

            # As in the variant above, the optimiser steps are applied at the top of
            # the iteration with the gradients from the previous one.
            self.optim_VAE.step()
            self.optim_D.step()

            x1 = x1.to(self.device)
            x1_rec, mu, logvar, z = gcam.forward(x1)

            # Standard FactorVAE loss
            vae_recon_loss = recon_loss(x1, x1_rec)
            vae_kld = kl_divergence(mu, logvar)
            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
            factorVae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss

            # Attention disentanglement loss
            gcam.backward(mu, logvar, mu_avg, logvar_avg)
            att_loss = 0
            with torch.no_grad():
                gcam_maps = gcam.generate()
            selected = self.select_attention_maps(gcam_maps)
            for (sel1, sel2) in selected:
                att_loss += attention_disentanglement(sel1, sel2)
            att_loss /= len(selected)   # average the loss across all pairs of maps

            vae_loss = factorVae_loss + self.lambdaa * att_loss

            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            #self.optim_VAE.step()

            x2 = x2.to(self.device)
            z_prime = self.VAE(x2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(z_pperm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            self.optim_D.zero_grad()
            D_tc_loss.backward()
            #self.optim_D.step()

            # Saving the training metrics
            if self.global_iter % 100 == 0:
                metrics.append({'its': self.global_iter,
                                'vae_loss': vae_loss.detach().cpu().item(),
                                'D_loss': D_tc_loss.detach().cpu().item(),
                                'recon_loss': vae_recon_loss.detach().cpu().item(),
                                'tc_loss': vae_tc_loss.detach().cpu().item()})

            # Saving the disentanglement metric results
            if self.global_iter % 1500 == 0:
                score = self.disentanglement_metric()
                metrics.append({'its': self.global_iter, 'metric_score': score})
                self.net_mode(train=True)  # switch back to training mode

            if self.global_iter % self.print_iter == 0:
                self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                    self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

            if self.global_iter % self.ckpt_save_iter == 0:
                self.save_checkpoint(str(self.global_iter) + ".pth")
                self.save_metrics(metrics)
                metrics = []

            if self.global_iter >= self.max_iter:
                out = True
                break

    self.pbar.write("[Training Finished]")
    self.pbar.close()
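# attention_disentanglement() used above is not defined in this file. One plausible
# definition (an assumption, not necessarily the author's) penalises overlap between
# two Grad-CAM attention maps with a soft-Dice score, pushing different latent
# dimensions to attend to different image regions:
import torch


def attention_disentanglement(map1, map2, eps=1e-8):
    # Soft-Dice overlap between two non-negative attention maps; 0 when disjoint,
    # so minimising it encourages the maps to cover different regions.
    inter = (map1 * map2).sum()
    return 2.0 * inter / (map1.sum() + map2.sum() + eps)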
def main(args):
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pbar = tqdm(total=args.epochs)
    image_gather = DataGather('true', 'recon')

    dataset = get_celeba_selected_dataset()
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True)

    lr_vae = args.lr_vae
    lr_D = args.lr_D
    vae = CelebaFactorVAE(args.z_dim, args.num_labels).to(device)
    optim_vae = torch.optim.Adam(vae.parameters(), lr=args.lr_vae)
    D = Discriminator(args.z_dim, args.num_labels).to(device)
    optim_D = torch.optim.Adam(D.parameters(), lr=args.lr_D, betas=(0.5, 0.9))

    # Checkpoint
    ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    mkdirs(ckpt_dir)
    start_epoch = 0
    if args.ckpt_load:
        load_checkpoint(pbar, ckpt_dir, D, vae, optim_D, optim_vae, lr_vae, lr_D)
        #optim_D.param_groups[0]['lr'] = 0.00001    #lr_D
        #optim_vae.param_groups[0]['lr'] = 0.00001  #lr_vae
        print("confirming lr after loading checkpoint: ", optim_vae.param_groups[0]['lr'])

    # Output
    output_dir = os.path.join(args.output_dir, args.name)
    mkdirs(output_dir)

    ones = torch.ones(args.batch_size, dtype=torch.long, device=device)
    zeros = torch.zeros(args.batch_size, dtype=torch.long, device=device)

    for epoch in range(start_epoch, args.epochs):
        pbar.update(1)
        for iteration, (x, y, x2, y2) in enumerate(data_loader):
            x, y, x2, y2 = x.to(device), y.to(device), x2.to(device), y2.to(device)
            recon_x, mean, log_var, z = vae(x, y)
            if z.shape[0] != args.batch_size:
                print("passed a batch in epoch {}, iteration {}!".format(epoch, iteration))
                continue

            D_z = D(z)
            vae_recon_loss = recon_loss(x, recon_x) * args.recon_weight
            vae_kld = kl_divergence(mean, log_var)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean() * args.gamma
            vae_loss = vae_recon_loss + vae_tc_loss  #+ vae_kld

            optim_vae.zero_grad()
            vae_loss.backward(retain_graph=True)

            z_prime = vae(x2, y2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = D(z_pperm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            optim_D.zero_grad()
            D_tc_loss.backward()

            optim_vae.step()
            optim_D.step()

            if iteration % args.print_iter == 0:
                pbar.write('[epoch {}/{}, iter {}/{}] vae_recon_loss:{:.4f} vae_kld:{:.4f} vae_tc_loss:{:.4f} D_tc_loss:{:.4f}'.format(
                    epoch, args.epochs, iteration, len(data_loader) - 1,
                    vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

            if iteration % args.output_iter == 0 and iteration != 0:
                output_dir = os.path.join(args.output_dir, args.name)  #, "{}.{}".format(epoch, iteration))
                mkdirs(output_dir)

                # reconstruction
                #image_gather.insert(true=x.data.cpu(), recon=torch.sigmoid(recon_x).data.cpu())
                #data = image_gather.data
                #true_image = data['true'][0]
                #recon_image = data['recon'][0]
                #true_image = make_grid(true_image)
                #recon_image = make_grid(recon_image)
                #sample = torch.stack([true_image, recon_image], dim=0)
                #save_image(tensor=sample.cpu(), fp=os.path.join(output_dir, "recon.jpg"))
                #image_gather.flush()

                # inference given num_labels = 10
                c = torch.randint(low=0, high=2, size=(1, 10))   # populated with 0s and 1s
                for i in range(9):
                    c = torch.cat((c, torch.randint(low=0, high=2, size=(1, 10))), 0)
                c = c.to(device)
                z_inf = torch.rand([c.size(0), args.z_dim]).to(device)
                #c = c.reshape(-1, args.num_labels, 1, 1)
                # cast the integer labels to float so they can be concatenated with the sampled code
                z_inf = torch.cat((z_inf, c.float()), dim=1)
                z_inf = z_inf.reshape(-1, args.num_labels + args.z_dim, 1, 1)
                x = vae.decode(z_inf)

                plt.figure()
                plt.figure(figsize=(10, 20))
                for p in range(args.num_labels):
                    plt.subplot(5, 2, p + 1)   # rows, cols, index starting from 1
                    plt.text(0, 0, "c={}".format(c[p]), color='black',
                             backgroundcolor='white', fontsize=10)
                    p = x[p].view(3, 218, 178)
                    image = torch.transpose(p, 0, 2)
                    image = torch.transpose(image, 0, 1)
                    plt.imshow((image.cpu().data.numpy() * 255).astype(np.uint8))
                    plt.axis('off')
                plt.savefig(os.path.join(output_dir, "E{:d}||{:d}.png".format(epoch, iteration)), dpi=300)
                plt.clf()
                plt.close('all')

        if epoch % 8 == 0:
            optim_vae.param_groups[0]['lr'] /= 10
            optim_D.param_groups[0]['lr'] /= 10
            print("\nnew learning rate at epoch {} is {}!".format(epoch, optim_vae.param_groups[0]['lr']))

        if epoch % args.ckpt_iter_epoch == 0:
            save_checkpoint(pbar, epoch, D, vae, optim_D, optim_vae, ckpt_dir, epoch)

    pbar.write("[Training Finished]")
    pbar.close()
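# A purely hypothetical entry point showing the argparse flags that main() reads
# (seed, epochs, batch_size, lr_vae, lr_D, z_dim, num_labels, gamma, recon_weight,
# ckpt_dir, ckpt_load, ckpt_iter_epoch, output_dir, output_iter, print_iter, name);
# the defaults below are illustrative only, not taken from the original code.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='FactorVAE on CelebA with label conditioning')
    parser.add_argument('--name', default='celeba_factorvae')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr_vae', type=float, default=1e-4)
    parser.add_argument('--lr_D', type=float, default=1e-4)
    parser.add_argument('--z_dim', type=int, default=10)
    parser.add_argument('--num_labels', type=int, default=10)
    parser.add_argument('--gamma', type=float, default=6.4)
    parser.add_argument('--recon_weight', type=float, default=1.0)
    parser.add_argument('--ckpt_dir', default='checkpoints')
    parser.add_argument('--ckpt_load', action='store_true')
    parser.add_argument('--ckpt_iter_epoch', type=int, default=1)
    parser.add_argument('--output_dir', default='outputs')
    parser.add_argument('--output_iter', type=int, default=500)
    parser.add_argument('--print_iter', type=int, default=100)
    args = parser.parse_args()
    main(args)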
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    out = False
    while not out:
        for x_true1, x_true2 in self.data_loader:
            self.global_iter += 1
            self.pbar.update(1)

            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_ad_loss = self.get_ad_loss(z)
            vae_kld = kl_divergence(mu, logvar)

            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

            vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss + self.lamb * vae_ad_loss

            x_true2 = x_true2.to(self.device)
            z_prime = self.VAE(x_true2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(z_pperm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            self.optim_D.zero_grad()
            D_tc_loss.backward()
            self.optim_VAE.step()
            self.optim_D.step()

            if self.global_iter % self.print_iter == 0:
                if self.dis_score:
                    dis_score = disentanglement_score(self.VAE.eval(), self.device, self.dataset,
                                                      self.z_dim, self.L, self.vote_count,
                                                      self.dis_batch_size)
                    self.VAE.train()
                else:
                    dis_score = torch.tensor(0)

                self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} ad_loss:{:.3f} D_tc_loss:{:.3f} dis_score:{:.3f}'.format(
                    self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(),
                    vae_ad_loss.item(), D_tc_loss.item(), dis_score.item()))

                if self.results_save:
                    self.outputs['vae_recon_loss'].append(vae_recon_loss.item())
                    self.outputs['vae_kld'].append(vae_kld.item())
                    self.outputs['vae_tc_loss'].append(vae_tc_loss.item())
                    self.outputs['D_tc_loss'].append(D_tc_loss.item())
                    self.outputs['ad_loss'].append(vae_ad_loss.item())
                    self.outputs['dis_score'].append(dis_score.item())
                    self.outputs['iteration'].append(self.global_iter)

            if self.global_iter % self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)

            if self.global_iter >= self.max_iter:
                out = True
                break

    self.pbar.write("[Training Finished]")
    self.pbar.close()

    if self.results_save:
        save_args_outputs(self.results_dir, self.args, self.outputs)
def train(self):
    self.net_mode(train=True)
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    out = False
    while not out:
        for x_true1, x_true2 in self.data_loader:  # is this where the data gets loaded?
            self.global_iter += 1
            self.pbar.update(1)

            if self.dataset == 'mnist':
                x_true1 = x_true1.view(x_true1.shape[0], -1)

            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)

            x = x_true1.view(x_true1.shape[0], -1)
            #custom
            #vae_recon_loss = self.custom_loss(x) / self.batch_size
            #custom
            vae_recon_loss = recon_loss(x, x_recon)   # reconstruction error (cross-entropy)
            vae_kld = kl_divergence(mu, logvar)

            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()   # presumably the discriminator-based TC loss

            vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss
            #vae_loss = vae_recon_loss + self.gamma*vae_tc_loss

            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            self.optim_VAE.step()

            x_true2 = x_true2.to(self.device)
            #x_true2 = x_true2.view(x_true2.shape[0], -1)
            z_prime = self.VAE(x_true2, no_dec=True)   # with no_dec=True we only get the codes mapped into the latent space?
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(z_pperm)
            # Like a GAN discriminator: real vs. permuted ("fake") codes, hence the
            # 0/1 targets in the zeros and ones tensors.
            D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

            self.optim_D.zero_grad()
            D_tc_loss.backward()
            self.optim_D.step()

            #if self.global_iter%self.print_iter == 0:
            #    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
            #        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

            if self.test_count % 547 == 0:
                self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                    self.global_iter, vae_recon_loss.item(), vae_tc_loss.item(), D_tc_loss.item()))
                self.test_count = 0

            if self.global_iter%self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)

            if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
                soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                D_acc /= 2*self.batch_size
                self.line_gather.insert(iter=self.global_iter,
                                        soft_D_z=soft_D_z.mean().item(),
                                        soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                        recon=vae_recon_loss.item(),
                                        #kld=vae_kld.item(),
                                        acc=D_acc.item())

            if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                self.image_gather.insert(true=x_true1.data.cpu(),
                                         recon=F.sigmoid(x_recon).data.cpu())
                self.visualize_recon()
                self.image_gather.flush()

            if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                if self.dataset.lower() == '3dchairs':
                    self.visualize_traverse(limit=2, inter=0.5)
                else:
                    #self.visualize_traverse(limit=3, inter=2/3)
                    print("ignore")

            if self.global_iter >= self.max_iter:
                out = True
                break

            self.test_count += 1

    self.pbar.write("[Training Finished]")
    torch.save(self.VAE.state_dict(), "model1/0531_128_2_gamma2.pth")
    self.pbar.close()