def train(self):
    """Main training loop: jointly trains the VAE, the TC discriminator D,
    and four adversarial/probe classifiers on the latent code.

    Per batch:
      1. Encode x_true1; compute reconstruction + KL losses and the
         total-correlation (TC) surrogate from the discriminator.
      2. Feed latent halves (z.split(30, 1): first 30 dims / remaining dims)
         to the protected-attribute and target classifiers, both directly and
         through a gradient-reversal layer, and track their accuracies.
      3. Back-prop vae_loss and step all five optimizers.
      4. Compute the discriminator TC loss on a permuted second batch.
      5. Periodically log, checkpoint, and push visualizations.

    Runs until self.global_iter reaches self.max_iter, then hands off to
    self.train_cls().
    """
    self.net_mode(train=True)
    # Constant class labels for the density-ratio discriminator:
    # zeros = "real joint q(z)", ones = "permuted/factorized q(z)-bar.
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    out = False  # NOTE(review): never read in this method
    for i_num in range(self.max_iter - self.global_iter):
        # Per-epoch accuracy accumulators.
        # pa_*   : protected-attribute classifier (male) on latent tail
        # rev_*  : gradient-reversed makeup classifier on latent tail
        # t_*    : target (heavy_makeup) classifier on latent head
        # t_rev_*: gradient-reversed male classifier on latent head
        total_pa_num = 0
        total_pa_correct_num = 0
        total_male_num = 0
        total_male_correct = 0
        total_female_num = 0
        total_female_correct = 0
        total_rev_num = 0
        total_rev_correct_num = 0
        total_t_num = 0
        total_t_correct_num = 0
        total_t_rev_num = 0
        total_t_rev_correct_num = 0
        # Loader yields two image batches (x_true2 is an independent batch
        # used only for the permuted-dims discriminator pass) plus two
        # binary labels — presumably CelebA attributes; verify against loader.
        for i, (x_true1, x_true2, heavy_makeup, male) in enumerate(self.data_loader):
            #from PIL import Image
            #from torchvision import transforms
            #import pdb;pdb.set_trace()
            #import pdb;pdb.set_trace()
            heavy_makeup = heavy_makeup.to(self.device)
            male = male.to(self.device)
            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kld = kl_divergence(mu, logvar)
            D_z = self.D(z)
            # FactorVAE TC surrogate: logit("real") - logit("permuted").
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
            # Latent split: z.split(30, 1)[0] = first 30 dims ("target" part),
            # z.split(30, 1)[-1] = remaining dims ("protected" part).
            # grad_reverse flips gradients so the encoder UNlearns the label.
            z_reverse = grad_reverse(z.split(30, 1)[-1])
            #z_reverse=z.split(10,1)[-1]
            reverse_output = self.revcls(z_reverse)
            output = self.pacls(z.split(30, 1)[-1])
            z_t_reverse = grad_reverse(z.split(30, 1)[0])
            #z_reverse=z.split(10,1)[-1]
            t_reverse_output = self.trevcls(z_t_reverse)
            t_output = self.tcls(z.split(30, 1)[0])
            #if i==0:
            #    print(output.argmax(1))
            #
            #    print(t_reverse_output.argmax(1))
            # Batch accuracy bookkeeping for each classifier head.
            rev_correct = (reverse_output.argmax(1) == heavy_makeup).sum().float()
            rev_num = heavy_makeup.size(0)
            pa_correct = (output.argmax(1) == male).sum().float()
            pa_num = male.size(0)
            t_correct = (t_output.argmax(1) == heavy_makeup).sum().float()
            t_num = heavy_makeup.size(0)
            t_rev_correct = (t_reverse_output.argmax(1) == male).sum().float()
            t_rev_num = male.size(0)
            total_pa_correct_num += pa_correct
            total_pa_num += pa_num
            total_rev_correct_num += rev_correct
            total_rev_num += rev_num
            total_t_correct_num += t_correct
            total_t_num += t_num
            total_t_rev_correct_num += t_rev_correct
            total_t_rev_num += t_rev_num
            # Per-group (male/female) accuracy of the protected-attribute head.
            total_male_num += (male == 1).sum()
            total_female_num += (male == 0).sum()
            #import pdb;pdb.set_trace()
            total_male_correct += ((output.argmax(1) == male) * (male == 1)).sum()
            total_female_correct += ((output.argmax(1) == male) * (male == 0)).sum()
            ''' pa_correct=(output.argm ax(1)==male).sum() pa_num=male.size(0) total_pa_correct_num+=pa_correct total_pa_num+=pa_num total_male_num+=(male==1).sum() total_male_correct+=((output.argmax(1)==male)*(male==1)).sum() total_female_num+=(male==0).sum() total_female_correct+=((output.argmax(1)==male)*(male==0)).sum() '''
            #weight=torch.tensor([3.5,1.0]).cuda()
            # Classifier losses (computed for logging; see NOTE below).
            pa_cls = F.cross_entropy(output, male)
            #pa_cls=F.cross_entropy(output,male)
            rev_cls = F.cross_entropy(reverse_output, heavy_makeup)
            t_pa_cls = F.cross_entropy(t_output, heavy_makeup)
            t_rev_cls = F.cross_entropy(t_reverse_output, male)
            vae_loss = vae_recon_loss + self.beta * vae_kld  # + self.grl*rev_cls
            self.optim_VAE.zero_grad()
            self.optim_pacls.zero_grad()
            self.optim_revcls.zero_grad()
            self.optim_tcls.zero_grad()
            self.optim_trevcls.zero_grad()
            # NOTE(review): only vae_loss is back-propagated here, and it does
            # not depend on the classifier heads, so the four classifier
            # optimizer .step() calls below receive no gradients (no-ops for
            # most optimizers). The classifier losses above are currently
            # unused for training — confirm whether this ablation is intended.
            vae_loss.backward(retain_graph=True)
            self.optim_VAE.step()
            self.optim_pacls.step()
            self.optim_revcls.step()
            self.optim_tcls.step()
            self.optim_trevcls.step()
            # Discriminator pass on an independent batch with permuted latent
            # dimensions (approximates the factorized marginal).
            x_true2 = x_true2.to(self.device)
            z_prime = self.VAE(x_true2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = self.D(z_pperm)
            #D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))
            D_tc_loss = (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))
            self.optim_D.zero_grad()
            # NOTE(review): D_tc_loss.backward() is commented out, so the
            # discriminator is never updated (step() after zero_grad() does
            # nothing) — confirm this is a deliberate experiment.
            #D_tc_loss.backward()
            self.optim_D.step()
            self.pbar.update(1)
            self.global_iter += 1
            # Epoch-to-date accuracies (accumulators reset each outer epoch).
            pa_acc = float(total_pa_correct_num) / float(total_pa_num)
            rev_acc = float(total_rev_correct_num) / float(total_rev_num)
            t_acc = float(total_t_correct_num) / float(total_t_num)
            t_rev_acc = float(total_t_rev_correct_num) / float(total_t_rev_num)
            male_acc = float(total_male_correct) / float(total_male_num)
            female_acc = float(total_female_correct) / float(total_female_num)
            if self.global_iter % self.print_iter == 0:
                self.pbar.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f} pa_cls_loss:{:.3f} pa_acc:{:.3f} m_acc:{:.3f} f_acc:{:.3f} rev_acc:{:.3f} t_acc:{:.3f} t_rev_acc:{:.3f}'
                    .format(self.global_iter, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            D_tc_loss.item(), pa_cls.item(), pa_acc,
                            male_acc, female_acc, rev_acc, t_acc, t_rev_acc))
            if self.global_iter % self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)
            #self.ckpt_save_iter+=1
            # Visdom-style logging: discriminator soft outputs and accuracy.
            if self.viz_on and (self.global_iter % self.viz_ll_iter == 0):
                soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                D_acc /= 2 * self.batch_size
                self.line_gather.insert(
                    iter=self.global_iter,
                    soft_D_z=soft_D_z.mean().item(),
                    soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                    recon=vae_recon_loss.item(),
                    kld=vae_kld.item(),
                    acc=D_acc.item())
            #viz_ll_iter+=1
            if self.viz_on and (self.global_iter % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()
            #viz_la_iter+=1
            if self.viz_on and (self.global_iter % self.viz_ra_iter == 0):
                self.image_gather.insert(true=x_true1.data.cpu(),
                                         recon=F.sigmoid(x_recon).data.cpu())
                self.visualize_recon()
                self.image_gather.flush()
            #viz_ra_iter+=1
            if self.viz_on and (self.global_iter % self.viz_ta_iter == 0):
                if self.dataset.lower() == '3dchairs':
                    self.visualize_traverse(limit=2, inter=0.5)
                else:
                    self.visualize_traverse(limit=3, inter=2 / 3)
    self.pbar.write("[Training Finished]")
    self.pbar.close()
    # Chain directly into classifier training on the frozen representation.
    self.train_cls()
def train_cls(self):
    """Train the target (heavy_makeup) classifier on top of the frozen VAE.

    The VAE encoder weights are frozen (requires_grad=False) and only the
    target classifier (self.targetcls / optim_cls) is updated, using every
    latent dimension EXCEPT those listed in dim_list. Runs 80 epochs over
    self.data_loader, validating once per epoch via self.val().

    NOTE(review): this definition is shadowed by a second `train_cls`
    defined later in the file — at runtime the later one wins.
    """
    self.net_mode(train=True)
    # NOTE(review): ones/zeros/out are never used in this method.
    ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
    out = False
    for i_num in range(80):
        for i, (x_true1, x_true2, heavy_makeup, male) in enumerate(self.data_loader):
            # Freeze every VAE parameter so only the classifier learns.
            # (Re-done every batch; idempotent.)
            for name, param in self.VAE.named_parameters():
                #if name=='encode.0.weight':
                #    print(param[0])
                param.requires_grad = False
            male = male.to(self.device)
            heavy_makeup = heavy_makeup.to(self.device)
            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            # Reconstruction/KL/TC terms are computed for logging only;
            # they are not back-propagated in this method.
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kld = kl_divergence(mu, logvar)
            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
            #weight=torch.tensor([1.0,3.0]).cuda()
            #target=self.targetcls(z)
            # Latent dims to EXCLUDE from the target classifier's input.
            dim_list = [24]
            #dim_list=[18,47]
            # Rebuild z column-by-column, skipping the dims in dim_list.
            # (Starts from dim 0; assumes 0 is never in dim_list.)
            target_z = z.split(1, 1)[0]
            for dim in range(1, len(z[0])):
                if dim not in dim_list:
                    target_z = torch.cat([target_z, z.split(1, 1)[dim]], 1)
            target = self.targetcls(target_z)
            # Auxiliary probe heads (outputs used for pa_acc logging only;
            # their losses are commented out below).
            pa_target = self.pa_target(z.split(30, 1)[-1])
            target_pa = self.target_pa(z.split(1, 1)[0])
            pa_pa = self.pa_pa(z.split(30, 1)[-1])
            #weight=torch.tensor([1.0,3.0]).cuda()
            target_cls = F.cross_entropy(target, heavy_makeup)
            #import pdb;pdb.set_trace()
            #pa_target_cls=F.cross_entropy(pa_target,heavy_makeup)
            #target_pa_cls=F.cross_entropy(target_pa,male)
            #pa_pa_cls=F.cross_entropy(pa_pa,male)
            # NOTE(review): vae_loss is computed but never used.
            vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss
            target_loss = target_cls
            self.optim_cls.zero_grad()
            #self.optim_pa_target.zero_grad()
            #self.optim_target_pa.zero_grad()
            #self.optim_pa_pa.zero_grad()
            target_loss.backward()
            #self.optim_pa_target.step()
            self.optim_cls.step()
            #self.optim_target_pa.step()
            #self.optim_pa_pa.step()
            self.global_iter_cls += 1
            self.pbar_cls.update(1)
            if self.global_iter_cls % self.print_iter == 0:
                # Batch accuracies of the target head and the protected-
                # attribute probe (pa_acc is computed but not printed).
                acc = ((target.argmax(1) == heavy_makeup).sum().float() / len(x_true1)).item()
                pa_acc = ((pa_target.argmax(1) == heavy_makeup).sum().float() / len(x_true1)).item()
                self.pbar_cls.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                    .format(self.global_iter_cls, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            target_loss.item(), acc))
        #if self.global_iter_cls%self.ckpt_save_iter == 0:
        #    self.save_checkpoint_cls(self.global_iter_cls)
        # Validate once per epoch.
        self.val()
    self.pbar_cls.write("[Classifier Training Finished]")
    self.pbar_cls.close()
def train_cls(self):
    """Train four probe classifiers on top of the frozen VAE latent code.

    NOTE(review): this is a DUPLICATE definition of ``train_cls`` — it
    shadows the earlier one in this file, so this version is the one that
    actually runs. Consider renaming one of the two.

    Freezes all VAE parameters, then for 40 epochs trains:
      * targetcls  : heavy_makeup from the first 30 latent dims
                     (class-weighted cross-entropy, weights [1.0, 3.0])
      * pa_target  : heavy_makeup from the remaining latent dims
      * target_pa  : male from the first 30 latent dims
      * pa_pa      : male from the remaining latent dims
    The four losses are summed and back-propagated through the four
    classifier optimizers. Validates via self.val() once per epoch.

    Fixes vs. original:
      * Removed a leftover ``import pdb; pdb.set_trace()`` that halted
        training on every single batch.
      * Removed the per-batch TSNE-fit + matplotlib scatter debug block
        that only existed to be inspected at that breakpoint (it was very
        expensive and never shown or saved).
      * Hoisted the loop-invariant class-weight tensor out of the batch
        loop and allocated it on self.device (was re-created on .cuda()
        every batch), matching the device handling used elsewhere.
      * Dropped unused locals (ones, zeros, out, vae_loss).
    """
    self.net_mode(train=True)
    # Class weights for the (imbalanced) heavy_makeup target loss —
    # loop-invariant, so build it once.
    weight = torch.tensor([1.0, 3]).to(self.device)
    for i_num in range(40):
        for i, (x_true1, x_true2, heavy_makeup, male) in enumerate(self.data_loader):
            # Freeze every VAE parameter so only the probes learn.
            # (Re-done every batch; idempotent.)
            for name, param in self.VAE.named_parameters():
                param.requires_grad = False
            male = male.to(self.device)
            heavy_makeup = heavy_makeup.to(self.device)
            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)
            # Reconstruction/KL/TC terms are computed for logging only;
            # they are not back-propagated in this method.
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kld = kl_divergence(mu, logvar)
            D_z = self.D(z)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
            # Latent split: [0] = first 30 dims ("target" part),
            # [-1] = remaining dims ("protected" part).
            target = self.targetcls(z.split(30, 1)[0])
            pa_target = self.pa_target(z.split(30, 1)[-1])
            target_pa = self.target_pa(z.split(30, 1)[0])
            pa_pa = self.pa_pa(z.split(30, 1)[-1])
            target_cls = F.cross_entropy(target, heavy_makeup, weight=weight)
            pa_target_cls = F.cross_entropy(pa_target, heavy_makeup)
            target_pa_cls = F.cross_entropy(target_pa, male)
            pa_pa_cls = F.cross_entropy(pa_pa, male)
            # Joint probe loss; VAE stays frozen so gradients only reach
            # the four classifier heads.
            target_loss = pa_target_cls + target_cls + target_pa_cls + pa_pa_cls
            self.optim_cls.zero_grad()
            self.optim_pa_target.zero_grad()
            self.optim_target_pa.zero_grad()
            self.optim_pa_pa.zero_grad()
            target_loss.backward()
            self.optim_pa_target.step()
            self.optim_cls.step()
            self.optim_target_pa.step()
            self.optim_pa_pa.step()
            self.global_iter_cls += 1
            self.pbar_cls.update(1)
            if self.global_iter_cls % self.print_iter == 0:
                # Batch accuracies (pa_acc is computed but not printed,
                # matching the original behavior).
                acc = ((target.argmax(1) == heavy_makeup).sum().float() / len(x_true1)).item()
                pa_acc = ((pa_target.argmax(1) == heavy_makeup).sum().float() / len(x_true1)).item()
                self.pbar_cls.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                    .format(self.global_iter_cls, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            target_loss.item(), acc))
            #if self.global_iter_cls%self.ckpt_save_iter == 0:
            #    self.save_checkpoint_cls(self.global_iter_cls)
        # Validate once per epoch.
        self.val()
    self.pbar_cls.write("[Classifier Training Finished]")
    self.pbar_cls.close()