Example #1
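This train() method appears to fit a FactorVAE-style model with fairness probes on the male and heavy_makeup attributes: the VAE is optimised on reconstruction plus a beta-weighted KL term, a discriminator D scores total correlation against dimension-wise permuted latents, and four small classifiers (two of them behind gradient-reversal layers) probe the two 30-dimensional halves of the latent code. Per-epoch accuracies are accumulated and periodically logged and visualised.
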
    def train(self):
        self.net_mode(train=True)

        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        out = False

        for i_num in range(self.max_iter - self.global_iter):
            total_pa_num = 0
            total_pa_correct_num = 0
            total_male_num = 0
            total_male_correct = 0
            total_female_num = 0
            total_female_correct = 0

            total_rev_num = 0
            total_rev_correct_num = 0

            total_t_num = 0
            total_t_correct_num = 0
            total_t_rev_num = 0
            total_t_rev_correct_num = 0

            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):

                heavy_makeup = heavy_makeup.to(self.device)
                male = male.to(self.device)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)

                # FactorVAE-style total-correlation estimate from the
                # discriminator logits.
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # Probes on the protected-attribute half of z: revcls sees it
                # through a gradient-reversal layer (heavy_makeup), pacls sees
                # it directly (male).
                z_reverse = grad_reverse(z.split(30, 1)[-1])
                reverse_output = self.revcls(z_reverse)
                output = self.pacls(z.split(30, 1)[-1])

                # Same structure on the target half of z: trevcls through
                # gradient reversal (male), tcls directly (heavy_makeup).
                z_t_reverse = grad_reverse(z.split(30, 1)[0])
                t_reverse_output = self.trevcls(z_t_reverse)
                t_output = self.tcls(z.split(30, 1)[0])

                rev_correct = (
                    reverse_output.argmax(1) == heavy_makeup).sum().float()
                rev_num = heavy_makeup.size(0)

                pa_correct = (output.argmax(1) == male).sum().float()
                pa_num = male.size(0)

                t_correct = (t_output.argmax(1) == heavy_makeup).sum().float()
                t_num = heavy_makeup.size(0)

                t_rev_correct = (
                    t_reverse_output.argmax(1) == male).sum().float()
                t_rev_num = male.size(0)

                total_pa_correct_num += pa_correct
                total_pa_num += pa_num

                total_rev_correct_num += rev_correct
                total_rev_num += rev_num

                total_t_correct_num += t_correct
                total_t_num += t_num

                total_t_rev_correct_num += t_rev_correct
                total_t_rev_num += t_rev_num

                total_male_num += (male == 1).sum()
                total_female_num += (male == 0).sum()

                total_male_correct += ((output.argmax(1) == male) *
                                       (male == 1)).sum()

                total_female_correct += ((output.argmax(1) == male) *
                                         (male == 0)).sum()
                pa_cls = F.cross_entropy(output, male)
                rev_cls = F.cross_entropy(reverse_output, heavy_makeup)

                t_pa_cls = F.cross_entropy(t_output, heavy_makeup)
                t_rev_cls = F.cross_entropy(t_reverse_output, male)

                # Only reconstruction + KL are backpropagated in this update;
                # the gradient-reversal term (self.grl * rev_cls) is disabled
                # and the classifier losses above are not included here.
                vae_loss = vae_recon_loss + self.beta * vae_kld

                self.optim_VAE.zero_grad()
                self.optim_pacls.zero_grad()
                self.optim_revcls.zero_grad()
                self.optim_tcls.zero_grad()
                self.optim_trevcls.zero_grad()

                vae_loss.backward(retain_graph=True)

                self.optim_VAE.step()
                self.optim_pacls.step()
                self.optim_revcls.step()
                self.optim_tcls.step()
                self.optim_trevcls.step()

                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()

                D_z_pperm = self.D(z_pperm)
                # Discriminator loss: distinguish joint latent samples (label 0)
                # from dimension-wise permuted ones (label 1).
                D_tc_loss = (F.cross_entropy(D_z, zeros) +
                             F.cross_entropy(D_z_pperm, ones))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

            self.pbar.update(1)
            self.global_iter += 1
            pa_acc = float(total_pa_correct_num) / float(total_pa_num)
            rev_acc = float(total_rev_correct_num) / float(total_rev_num)

            t_acc = float(total_t_correct_num) / float(total_t_num)
            t_rev_acc = float(total_t_rev_correct_num) / float(total_t_rev_num)

            male_acc = float(total_male_correct) / float(total_male_num)
            female_acc = float(total_female_correct) / float(total_female_num)

            if self.global_iter % self.print_iter == 0:
                self.pbar.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f} pa_cls_loss:{:.3f} pa_acc:{:.3f} m_acc:{:.3f}  f_acc:{:.3f} rev_acc:{:.3f} t_acc:{:.3f} t_rev_acc:{:.3f}'
                    .format(self.global_iter, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            D_tc_loss.item(), pa_cls.item(), pa_acc, male_acc,
                            female_acc, rev_acc, t_acc, t_rev_acc))

            if self.global_iter % self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)

            if self.viz_on and (self.global_iter % self.viz_ll_iter == 0):
                soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                D_acc = ((soft_D_z >= 0.5).sum() +
                         (soft_D_z_pperm < 0.5).sum()).float()
                D_acc /= 2 * self.batch_size
                self.line_gather.insert(
                    iter=self.global_iter,
                    soft_D_z=soft_D_z.mean().item(),
                    soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                    recon=vae_recon_loss.item(),
                    kld=vae_kld.item(),
                    acc=D_acc.item())
            if self.viz_on and (self.global_iter % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            if self.viz_on and (self.global_iter % self.viz_ra_iter == 0):
                self.image_gather.insert(true=x_true1.data.cpu(),
                                         recon=torch.sigmoid(x_recon).data.cpu())
                self.visualize_recon()
                self.image_gather.flush()

            if self.viz_on and (self.global_iter % self.viz_ta_iter == 0):
                if self.dataset.lower() == '3dchairs':
                    self.visualize_traverse(limit=2, inter=0.5)
                else:
                    self.visualize_traverse(limit=3, inter=2 / 3)

        self.pbar.write("[Training Finished]")
        self.pbar.close()

        self.train_cls()
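
The helpers grad_reverse and permute_dims are not shown on this page. A minimal sketch of what they typically look like (a DANN-style gradient-reversal layer and FactorVAE's dimension-wise permutation), assuming the standard definitions rather than the exact ones used by this code:

import torch


class GradReverse(torch.autograd.Function):
    # Identity in the forward pass; flips the sign of the gradient in the
    # backward pass, so downstream classifiers train normally while the
    # encoder is pushed in the opposite direction.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output


def grad_reverse(x):
    # Hypothetical stand-in for the grad_reverse used above.
    return GradReverse.apply(x)


def permute_dims(z):
    # FactorVAE trick: permute each latent dimension independently across
    # the batch so the result has (approximately) independent marginals.
    assert z.dim() == 2
    B = z.size(0)
    perm_z = []
    for z_j in z.split(1, 1):
        perm = torch.randperm(B, device=z.device)
        perm_z.append(z_j[perm])
    return torch.cat(perm_z, 1)

In Example 1, permute_dims is applied to latents from the second batch (x_true2) and the result is detached, so the discriminator update does not propagate gradients back into the encoder through the permuted samples.
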
Example #2
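This train_cls() variant freezes the VAE and trains a single classifier head (targetcls) to predict heavy_makeup from the latent code with the dimensions listed in dim_list (here index 24) removed; the auxiliary probe heads (pa_target, target_pa, pa_pa) are evaluated but not trained in this variant.
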
    def train_cls(self):
        self.net_mode(train=True)

        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        out = False
        for i_num in range(80):

            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):

                # Freeze the VAE; only the classifier heads are updated here.
                for param in self.VAE.parameters():
                    param.requires_grad = False

                male = male.to(self.device)
                heavy_makeup = heavy_makeup.to(self.device)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # Build the classifier input from every latent dimension except
                # those listed in dim_list.
                dim_list = [24]
                keep_dims = [d for d in range(z.size(1)) if d not in dim_list]
                target_z = z[:, keep_dims]

                target = self.targetcls(target_z)
                pa_target = self.pa_target(z.split(30, 1)[-1])

                target_pa = self.target_pa(z.split(1, 1)[0])
                pa_pa = self.pa_pa(z.split(30, 1)[-1])

                target_cls = F.cross_entropy(target, heavy_makeup)


                # The VAE is frozen, so vae_loss is never backpropagated here;
                # its components are only reported in the log below.
                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss

                target_loss = target_cls

                self.optim_cls.zero_grad()
                target_loss.backward()
                self.optim_cls.step()

            self.global_iter_cls += 1
            self.pbar_cls.update(1)

            if self.global_iter_cls % self.print_iter == 0:
                acc = ((target.argmax(1) == heavy_makeup).sum().float() /
                       len(x_true1)).item()
                pa_acc = ((pa_target.argmax(1) == heavy_makeup).sum().float() /
                          len(x_true1)).item()
                self.pbar_cls.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                    .format(self.global_iter_cls, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            target_loss.item(), acc))

            #if self.global_iter_cls%self.ckpt_save_iter == 0:
            #    self.save_checkpoint_cls(self.global_iter_cls)

            self.val()

        self.pbar_cls.write("[Classifier Training Finished]")
        self.pbar_cls.close()
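
The recon_loss and kl_divergence helpers are likewise not shown. A minimal sketch under the usual assumptions (Bernoulli decoder trained with binary cross-entropy on logits, diagonal-Gaussian posterior), which may differ from the exact implementation behind these examples:

import torch
import torch.nn.functional as F


def recon_loss(x, x_recon):
    # Pixel-wise binary cross-entropy between the input x and the decoder
    # logits x_recon, summed per sample and averaged over the batch.
    batch_size = x.size(0)
    return F.binary_cross_entropy_with_logits(
        x_recon, x, reduction='sum') / batch_size


def kl_divergence(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions and
    # averaged over the batch.
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)
    return kld.mean()

The sigmoid applied to x_recon before visualisation in Example 1 is consistent with the decoder producing logits rather than probabilities.
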
Example #3
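This train_cls() variant also freezes the VAE but trains four probe heads jointly: heavy_makeup and male are each predicted from both the first and the last 30-dimensional half of the latent code, with a class-weighted loss on the target head, and a t-SNE projection of the two latent halves is plotted inside the loop for inspection.
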
    def train_cls(self):
        self.net_mode(train=True)

        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        out = False
        for i_num in range(40):

            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):

                # Freeze the VAE; only the classifier heads are updated here.
                for param in self.VAE.parameters():
                    param.requires_grad = False

                male = male.to(self.device)
                heavy_makeup = heavy_makeup.to(self.device)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # t-SNE projection of the two 30-dim latent halves (assumes
                # `from sklearn.manifold import TSNE`, `import matplotlib.pyplot
                # as plt`, and a batch size of 256 for the slicing below).
                model = TSNE(learning_rate=100, n_iter=1000)
                z_tsne = torch.cat((z.split(30, 1)[0], z.split(30, 1)[1]), 0)
                transformed = model.fit_transform(z_tsne.cpu().detach().numpy())

                for k in range(2):
                    xs = transformed[256 * k:256 * (k + 1), 0]
                    ys = transformed[256 * k:256 * (k + 1), 1]
                    plt.scatter(xs, ys, label=str(k), s=100)
                plt.legend()
                plt.show()

                target = self.targetcls(z.split(30, 1)[0])
                pa_target = self.pa_target(z.split(30, 1)[-1])

                target_pa = self.target_pa(z.split(30, 1)[0])
                pa_pa = self.pa_pa(z.split(30, 1)[-1])

                # Class weights for the heavy_makeup target (presumably to
                # compensate for class imbalance).
                weight = torch.tensor([1.0, 3.0], device=self.device)
                target_cls = F.cross_entropy(target, heavy_makeup, weight=weight)

                pa_target_cls = F.cross_entropy(pa_target, heavy_makeup)
                target_pa_cls = F.cross_entropy(target_pa, male)
                pa_pa_cls = F.cross_entropy(pa_pa, male)

                # The VAE is frozen, so vae_loss is never backpropagated here;
                # its components are only reported in the log below.
                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss

                target_loss = pa_target_cls + target_cls + target_pa_cls + pa_pa_cls

                self.optim_cls.zero_grad()
                self.optim_pa_target.zero_grad()
                self.optim_target_pa.zero_grad()
                self.optim_pa_pa.zero_grad()

                target_loss.backward()
                self.optim_pa_target.step()
                self.optim_cls.step()
                self.optim_target_pa.step()
                self.optim_pa_pa.step()
               
                
            self.global_iter_cls += 1
            self.pbar_cls.update(1)

            if self.global_iter_cls % self.print_iter == 0:
                acc = ((target.argmax(1) == heavy_makeup).sum().float() /
                       len(x_true1)).item()
                pa_acc = ((pa_target.argmax(1) == heavy_makeup).sum().float() /
                          len(x_true1)).item()
                self.pbar_cls.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                    .format(self.global_iter_cls, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            target_loss.item(), acc))

            #if self.global_iter_cls%self.ckpt_save_iter == 0:
            #    self.save_checkpoint_cls(self.global_iter_cls)

            self.val()
            
        self.pbar_cls.write("[Classifier Training Finished]")
        self.pbar_cls.close()
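
All three examples carve the latent code into a "target" half and a "protected-attribute" half with z.split(30, 1). For a 60-dimensional latent (the size is inferred from the splits, not stated on this page) this simply returns two 30-dimensional chunks along the feature dimension:

import torch

z = torch.randn(4, 60)              # batch of 4, 60 latent dims (assumed)
target_half, pa_half = z.split(30, 1)
print(target_half.shape)            # torch.Size([4, 30]) == z.split(30, 1)[0]
print(pa_half.shape)                # torch.Size([4, 30]) == z.split(30, 1)[-1]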