Example #1
 def run_model(self, x, y):
     # Flatten each image to a vector for the network, then restore the
     # original shape before computing the reconstruction loss.
     shape = x.shape
     x = x.reshape(shape[0], -1)  # e.g. (200, 1024) for a batch of 200 32x32 inputs
     x_recon_list, p_dist, mu, logvar, z = self.net(x, train=True)
     x = x.reshape(shape)
     x_recon = x_recon_list[-1]  # use the final reconstruction in the list
     x_recon = x_recon.reshape(shape)
     recon_loss = reconstruction_loss(y, x_recon)
     if self.z_dim_bern == 0:
         # Gaussian latents only
         total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(mu, logvar)
         KL_loss = self.beta * total_kld
     elif self.z_dim_gauss == 0:
         # Bernoulli latents only, or a plain autoencoder with no KL term
         if not self.AE:
             total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
             KL_loss = self.gamma * total_kld
         else:
             KL_loss = torch.tensor(0.0)
     elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
         # mixed Bernoulli and Gaussian latents
         total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(p_dist)
         total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(mu, logvar)
         KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss
     return [recon_loss, KL_loss]
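The `kl_divergence_gaussian` helper called above (and throughout the examples below) is not shown on this page. A minimal sketch, assuming the standard closed-form KL between the diagonal posterior N(mu, diag(exp(logvar))) and the unit prior N(0, I); the helper's name and its three return values (total, per-dimension, and batch-mean KL) are inferred from the call sites:

import torch

def kl_divergence_gaussian(mu, logvar):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) per latent dimension,
    # with sigma^2 = exp(logvar): 0.5 * (mu^2 + sigma^2 - logvar - 1).
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)  # sum over dims, average over batch
    dim_wise_kld = klds.mean(0)            # per-dimension average over batch
    mean_kld = klds.mean(1).mean(0, True)  # average over dims and batch
    return total_kld, dim_wise_kld, mean_kld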
Example #2
 def run_model(self, x, y):
     x_recon, p_dist, mu, logvar = self.net(x, train=True)
     recon_loss = reconstruction_loss(y, x_recon)
     if self.z_dim_bern == 0:
         total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(
             mu, logvar)
         KL_loss = self.beta * total_kld
     elif self.z_dim_gauss == 0:
         total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
         KL_loss = self.gamma * total_kld
     elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
         total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(
             p_dist)
         total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(
             mu, logvar)
         KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss
    return [recon_loss, KL_loss]
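`kl_divergence_bernoulli` is likewise not shown. One plausible sketch, assuming a KL between the predicted Bernoulli activations `p_dist` and a uniform Bernoulli(0.5) prior; the prior choice and the clamping constant are assumptions, and the return signature mirrors the Gaussian helper:

import torch

def kl_divergence_bernoulli(p_dist, eps=1e-8):
    # Assumed: KL(Bernoulli(p) || Bernoulli(0.5)) per latent dimension,
    # i.e. p*log(2p) + (1-p)*log(2(1-p)), clamped for numerical stability.
    p = p_dist.clamp(eps, 1.0 - eps)
    klds = p * torch.log(2 * p) + (1 - p) * torch.log(2 * (1 - p))
    total_kld = klds.sum(1).mean(0, True)
    dim_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)
    return total_kld, dim_wise_kld, mean_kld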
Example #3
    def train(self):
        #self.net(train=True)
        iters_per_epoch = len(self.train_dl)
        print(iters_per_epoch, 'iters per epoch')
        max_iter = self.max_epoch * iters_per_epoch
        batch_size = self.train_dl.batch_size
        current_idxs = 0
        current_flip_idx = []
        count = 0

        out = False
        pbar = tqdm(total=max_iter)
        pbar.update(self.global_iter)

        while not out:
            for sample in self.train_dl:
                self.global_iter += 1
                pbar.update(1)

                if self.flip:
                    if count % iters_per_epoch == 0:
                        print("RESETTING COUNTER")
                        count = 0
                    current_idxs = range(count * batch_size,
                                         (count + 1) * batch_size)
                    current_flip_idx = [
                        x for x in self.flip_idx if x in current_idxs
                    ]
                    if not current_flip_idx:
                        current_flip_idx_norm = None
                    else:
                        current_flip_idx_norm = []
                        current_flip_idx_norm[:] = [
                            i - count * batch_size for i in current_flip_idx
                        ]
                else:
                    current_flip_idx_norm = None

                x = sample['x'].to(self.device)
                y = sample['y'].to(self.device)

                x_recon, p_dist, mu, logvar = self.net(x, train=True)

                recon_loss = reconstruction_loss(y, x_recon)

                if self.z_dim_bern == 0:
                    total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(
                        mu, logvar)
                    KL_loss = self.beta * total_kld
                elif self.z_dim_gauss == 0:
                    total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(
                        p_dist)
                    KL_loss = self.gamma * total_kld
                elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                    total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(
                        p_dist)
                    total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(
                        mu, logvar)
                    KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss

                loss = recon_loss + KL_loss

                self.adjust_learning_rate(self.optim,
                                          (count / iters_per_epoch))
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                count += 1

                if self.global_iter % self.gather_step == 0:
                    self.test_loss()
                    if self.gnrl_dl != 0:
                        self.gnrl_loss()
                        with open("{}/LOGBOOK.txt".format(self.output_dir),
                                  "a") as myfile:
                            myfile.write(
                                '\n[{}] train_loss:{:.3f}, train_recon_loss:{:.3f}, train_KL_loss:{:.3f}, test_loss:{:.3f}, test_recon_loss:{:.3f}, test_KL_loss:{:.3f}, gnrl_loss:{:.3f}, gnrl_recon_loss:{:.3f}, gnrl_KL_loss:{:.3f}'
                                .format(self.global_iter, float(loss.data),
                                        float(recon_loss.data),
                                        float(KL_loss.data), self.testLoss,
                                        self.test_recon_loss,
                                        self.test_kl_loss, self.gnrlLoss,
                                        self.gnrl_recon_loss,
                                        self.gnrl_kl_loss))

                        self.gather.insert(
                            iter=self.global_iter,
                            trainLoss=float(loss.data),
                            train_recon_loss=float(recon_loss.data),
                            train_KL_loss=float(KL_loss.data),
                            testLoss=self.testLoss,
                            test_recon_loss=self.test_recon_loss,
                            test_kl_loss=self.test_kl_loss,
                            gnrlLoss=self.gnrlLoss,
                            gnrl_recon_loss=self.gnrl_recon_loss,
                            gnrl_kl_loss=self.gnrl_kl_loss)
                    else:
                        with open("{}/LOGBOOK.txt".format(self.output_dir),
                                  "a") as myfile:
                            myfile.write(
                                '\n[{}] train_loss:{:.3f}, train_recon_loss:{:.3f}, train_KL_loss:{:.3f}, test_loss:{:.3f}, test_recon_loss:{:.3f}, test_KL_loss:{:.3f}'
                                .format(
                                    self.global_iter,
                                    float(loss.data),
                                    float(recon_loss.data),
                                    float(KL_loss.data),
                                    self.testLoss,
                                    self.test_recon_loss,
                                    self.test_kl_loss,
                                ))

                        self.gather.insert(
                            iter=self.global_iter,
                            trainLoss=loss.data.cpu().numpy(),
                            train_recon_loss=recon_loss.data.cpu().numpy(),
                            train_KL_loss=KL_loss.data.cpu().numpy(),
                            testLoss=self.testLoss,
                            test_recon_loss=self.test_recon_loss,
                            test_kl_loss=self.test_kl_loss)

                if self.global_iter % self.display_step == 0:
                    if self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                        pbar.write(
                            '[{}] recon_loss:{:.3f} total_kld_gauss:{:.3f} mean_kld_gauss:{:.3f} total_kld_bern:{:.3f} mean_kld_bern:{:.3f}'
                            .format(self.global_iter, recon_loss.item(),
                                    total_kld_gauss.item(),
                                    mean_kld_gauss.item(),
                                    total_kld_bern.item(),
                                    mean_kld_bern.item()))
                    else:
                        pbar.write(
                            '[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'
                            .format(self.global_iter, recon_loss.item(),
                                    total_kld.item(), mean_kld.item()))

                    if self.z_dim_gauss != 0:  # logvar parameterises the Gaussian latents
                        var = logvar.exp().mean(0).data
                        var_str = ''
                        for j, var_j in enumerate(var):
                            var_str += 'var{}:{:.4f} '.format(j + 1, var_j)
                        pbar.write(var_str)

                if self.global_iter % self.save_step == 0:
                    self.save_checkpoint('last')
                    oldtestLoss = self.testLoss
                    self.test_loss()
                    if self.gnrl_dl != 0:
                        self.gnrl_loss()
                    print('old test loss', oldtestLoss, 'current test loss',
                          self.testLoss)
                    if self.testLoss < oldtestLoss:
                        self.save_checkpoint('best')
                        pbar.write('Saved best checkpoint(iter:{})'.format(
                            self.global_iter))

                    self.test_plots()
                    self.gather.save_data(self.global_iter, self.output_dir,
                                          'last')

                if self.global_iter % 500 == 0:
                    self.save_checkpoint(str(self.global_iter))
                    self.gather.save_data(self.global_iter, self.output_dir,
                                          None)

                if self.global_iter >= max_iter:
                    out = True
                    break

        pbar.write("[Training Finished]")
        pbar.close()
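Every example also calls a shared `reconstruction_loss(y, x_recon)` helper that is not shown. A minimal sketch, assuming a binary-cross-entropy term summed over pixels and averaged over the batch, a common choice for VAEs on binarised images; treating `x_recon` as logits is an assumption:

import torch.nn.functional as F

def reconstruction_loss(y, x_recon):
    # Assumed: batch-averaged binary cross-entropy, summed over pixels.
    # y is the target batch; x_recon is treated as raw logits.
    batch_size = y.size(0)
    loss = F.binary_cross_entropy_with_logits(x_recon, y, reduction='sum')
    return loss / batch_size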
Example #4
    def train(self):
        #self.net(train=True)
        iters_per_epoch = len(self.train_dl)
        print(iters_per_epoch, 'iters per epoch')
        max_iter = self.max_epoch * iters_per_epoch
        batch_size = self.train_dl.batch_size
        current_idxs = 0
        current_flip_idx = []
        count = 0

        out = False
        pbar = tqdm(total=max_iter)
        pbar.update(self.global_iter)
        
        while not out:
            for sample in self.train_dl:
                self.global_iter += 1
                pbar.update(1)
               
                if self.flip:
                    if count % iters_per_epoch == 0:
                        print("RESETTING COUNTER")
                        count = 0
                    current_idxs = range(count * batch_size, (count + 1) * batch_size)
                    current_flip_idx = [x for x in self.flip_idx if x in current_idxs]
                    if not current_flip_idx:
                        current_flip_idx_norm = None
                    else:
                        current_flip_idx_norm = [i - count * batch_size for i in current_flip_idx]
                else:
                    current_flip_idx_norm = None

                x = sample['x'].to(self.device)
                y = sample['y'].to(self.device)
                shape = x.shape
                x = x.reshape(shape[0], -1)  # flatten each image to a vector, e.g. (100, 1024)

                x_recon_list, p_dist, mu, logvar, z = self.net(x, train=True)
                x_recon = x_recon_list[-1]  # use the final reconstruction in the list
                x_recon = x_recon.reshape(shape)
                x = x.reshape(shape)
                self.z_prev = z
                # an alternative sums reconstruction_loss over every element of x_recon_list
                recon_loss = reconstruction_loss(y, x_recon)

                if self.z_dim_bern == 0:
                    total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(mu, logvar)
                    KL_loss = self.beta * total_kld
                elif self.z_dim_gauss == 0:
                    if not self.AE:
                        total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
                        KL_loss = self.gamma * total_kld
                    else:
                        total_kld = torch.tensor(0.0)
                        mean_kld = torch.tensor(0.0)
                        KL_loss = torch.tensor(0.0)
                elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                    total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(p_dist)
                    total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(mu, logvar)
                    KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss

                loss = recon_loss + KL_loss  # KL_loss is zero for the plain AE

                self.adjust_learning_rate(self.optim, (count / iters_per_epoch))
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                count += 1
                
                if self.global_iter == 1:
                    # keep a CPU copy of the untrained network for offline analysis
                    net_copy = deepcopy(self.net)
                    net_copy.to('cpu')
                    filename = self.output_dir
                    #z_anal(net_copy, self.gnrl_data, filename, self.global_iter)

                if self.global_iter % self.gather_step == 0:  # save a snapshot for plotting
                    if self.gnrl_dl != 0:
                        gnrlLoss = self.gnrl_loss()
                        print(loss.item(), recon_loss.item(), KL_loss.item(),
                              gnrlLoss, self.gnrl_recon_loss, self.gnrl_kl_loss)
                        self.gather.insert(iter=self.global_iter,
                                           trainLoss=loss.item(),
                                           train_recon_loss=recon_loss.item(),
                                           train_KL_loss=KL_loss.item(),
                                           gnrlLoss=gnrlLoss,
                                           gnrl_recon_loss=self.gnrl_recon_loss,
                                           gnrl_kl_loss=self.gnrl_kl_loss)

                    net_copy = deepcopy(self.net)
                    net_copy.to('cpu')
                    filename = self.output_dir
                    #draw_iterative_img(net_copy, self.gnrl_data, self.global_iter, filename, 15)
                
                if self.global_iter % self.display_step == 0:
                    if self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                        pbar.write('[{}] recon_loss:{:.3f} total_kld_gauss:{:.3f} mean_kld_gauss:{:.3f} total_kld_bern:{:.3f} mean_kld_bern:{:.3f}'.format(
                            self.global_iter, recon_loss.item(),
                            total_kld_gauss.item(), mean_kld_gauss.item(),
                            total_kld_bern.item(), mean_kld_bern.item()))
                    else:
                        pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                            self.global_iter, recon_loss.item(), total_kld.item(), mean_kld.item()))
                                    
                    if not self.AE:
                        if self.z_dim_gauss != 0:  # logvar parameterises the Gaussian latents
                            var = logvar.exp().mean(0).data
                            var_str = ''
                            for j, var_j in enumerate(var):
                                var_str += 'var{}:{:.4f} '.format(j + 1, var_j)
                            pbar.write(var_str)
                      
                if self.global_iter % self.save_step == 0:  # this step makes the reconstruction plots for 30 images
                    self.save_checkpoint('last')
                    if self.gnrl_dl != 0:
                        oldgnrlLoss = self.gnrlLoss
                        self.gnrl_loss()
                        print('old gnrl loss', oldgnrlLoss, 'current gnrl loss', self.gnrlLoss)
                        if self.gnrlLoss < oldgnrlLoss:
                            self.save_checkpoint('best_gnrl')
                            pbar.write('Saved best GNRL checkpoint(iter:{})'.format(self.global_iter))

                    self.test_plots()
                    self.gather.save_data(self.global_iter, self.output_dir, 'last')
                    
                if self.global_iter % 5000 == 0:
                    self.save_checkpoint(str(self.global_iter))
                    self.gather.save_data(self.global_iter, self.output_dir, None)

                if self.global_iter >= max_iter:
                    net_copy = deepcopy(self.net)
                    net_copy.to('cpu')
                    #find_top_N(net_copy, self.gnrl_data, self.output_dir, 50, 30)
                    out = True
                    filename = self.output_dir
                    plotsave_tests_loss_list(net_copy, self.gnrl_data, filename, 1000)
                    #z_anal(net_copy, self.gnrl_data, filename, self.global_iter)
                    break

        pbar.write("[Training Finished]")
        pbar.close()
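A note on the flatten-and-restore pattern in Examples 1 and 4: a reshape with a hardcoded batch size (such as reshape(100, 1024)) breaks on the last batch of an epoch, which may be smaller, so the examples infer the batch dimension with x.reshape(shape[0], -1). A self-contained sketch of the round trip:

import torch

x = torch.randn(7, 1, 32, 32)     # e.g. a ragged final batch of 7 images
shape = x.shape
flat = x.reshape(shape[0], -1)    # (7, 1024): batch size inferred, not hardcoded
restored = flat.reshape(shape)    # restores the original (7, 1, 32, 32)
assert torch.equal(restored, x)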