Пример #1
0
    def test_loss(self):
        """Evaluate the model on ``self.test_dl`` and store the mean test metrics.

        Runs ``self.run_model`` on every batch of ``self.test_dl`` under
        ``torch.no_grad()`` and averages the loss (element 0 of its result)
        over the number of batches. In ``'supervised_encoder'`` mode the model
        output (element 1) is additionally scored with ``get_accuracy``.

        Side effects:
            * ``self.testLoss``: mean test loss, converted via ``.cpu().numpy()``.
            * encoder mode only: ``self.test_accuracy`` when
              ``self.encoder_target_type == 'joint'``, otherwise
              ``self.test_depth_accuracy``, ``self.test_black_accuracy`` and
              ``self.test_white_accuracy``.
            * prints one summary line per metric group.

        Raises:
            ValueError: if ``self.testing_method`` is neither
                ``'supervised_encoder'`` nor ``'supervised_decoder'``, or if
                ``self.test_dl`` yields no batches.
        """
        if self.testing_method not in ('supervised_encoder', 'supervised_decoder'):
            # Previously this surfaced as an UnboundLocalError on testLoss_list.
            raise ValueError('Unknown testing_method: {!r}'.format(self.testing_method))

        print("Calculating test loss")
        testLoss = 0.0
        test_accuracy = 0.0
        depth_accuracy = 0.0
        black_accuracy = 0.0
        white_accuracy = 0.0
        cnt = 0

        with torch.no_grad():
            for sample in self.test_dl:
                x = sample['x'].to(self.device)
                y = sample['y'].to(self.device)

                if self.testing_method == 'supervised_encoder':
                    testLoss_list = self.run_model(self.testing_method, x, y, self.l2_loss)
                    final_out = testLoss_list[1]
                    if self.encoder_target_type == 'joint':
                        test_accuracy += get_accuracy(final_out, y, self.encoder_target_type)
                    else:
                        accuracy_list = get_accuracy(final_out, y, self.encoder_target_type)
                        depth_accuracy += accuracy_list[0]
                        black_accuracy += accuracy_list[1]
                        white_accuracy += accuracy_list[2]
                else:  # 'supervised_decoder'
                    # The decoder path feeds a float tensor to the model.
                    x = x.type(torch.FloatTensor).to(self.device)
                    testLoss_list = self.run_model(self.testing_method, x, y, self.l2_loss)

                testLoss += testLoss_list[0]
                cnt += 1

        if cnt == 0:
            # Without this guard the float accumulator reaches .div() and fails
            # with an opaque AttributeError.
            raise ValueError('test_dl yielded no batches; cannot average the test loss')

        testLoss = testLoss.div(cnt)
        self.testLoss = testLoss.cpu().numpy()
        print('[{}] test_Loss:{:.3f}'.format(self.global_iter, self.testLoss))

        # Accuracy is only meaningful for the encoder; the decoder path used to
        # report (and store) a constant 0 here.
        if self.testing_method != 'supervised_encoder':
            return

        if self.encoder_target_type == 'joint':
            self.test_accuracy = test_accuracy / cnt
            print('[{}] test accuracy:{:.3f}'.format(self.global_iter, self.test_accuracy))
        else:
            self.test_depth_accuracy = depth_accuracy / cnt
            self.test_black_accuracy = black_accuracy / cnt
            self.test_white_accuracy = white_accuracy / cnt
            print('[{}] test_depth_accuracy:{:.3f}, test_black_accuracy:{:.3f}, test_white_accuracy:{:.3f}'.format(self.global_iter, self.test_depth_accuracy, self.test_black_accuracy, self.test_white_accuracy))
Пример #2
0
    def gnrl_loss(self):
        """Evaluate the model on the generalisation set ``self.gnrl_dl``.

        Runs ``self.run_model`` on every batch under ``torch.no_grad()`` and
        averages, over the number of batches:

        * the overall loss (element 0 of ``run_model``'s result) -- returned;
        * the L2-regularisation term (element 3 for the encoder, the last
          element for the decoder) -> ``self.gnrl_l2_loss``;
        * the "last iteration" total loss (per-term loss 0 plus L2 for the
          encoder; last per-term loss plus L2 for the decoder)
          -> ``self.gnrl_total_last_iter_loss``;
        * for the encoder, per-term losses and accuracies via ``get_accuracy``,
          stored on ``self`` according to ``self.encoder_target_type`` and
          ``self.n_digits`` (see the reporting section below).

        Returns:
            float: the mean generalisation loss.

        Raises:
            ValueError: if ``self.testing_method`` is neither
                ``'supervised_encoder'`` nor ``'supervised_decoder'``, or if
                ``self.gnrl_dl`` yields no batches.
        """
        if self.testing_method not in ('supervised_encoder', 'supervised_decoder'):
            # Previously this surfaced as an UnboundLocalError on the loss list.
            raise ValueError('Unknown testing_method: {!r}'.format(self.testing_method))

        print("Calculating generalisation loss")
        gnrlLoss = 0.0
        gnrl_accuracy = 0.0
        depth_accuracy = 0.0
        black_accuracy = 0.0
        white_accuracy = 0.0
        gnrl_back_accuracy = 0.0
        gnrl_mid_accuracy = 0.0
        gnrl_front_accuracy = 0.0
        gnrl_total_last_iter_loss = 0.0
        gnrl_back_loss = 0.0
        gnrl_front_loss = 0.0
        gnrl_xy_loss = 0.0
        gnrl_mid_loss = 0.0
        gnrl_l2_loss = 0.0
        cnt = 0

        ordered_targets = ("depth_ordered_one_hot", "depth_ordered_one_hot_xy")
        dbw_targets = ("depth_black_white", "depth_black_white_xy_xy")

        with torch.no_grad():
            for sample in self.gnrl_dl:
                x = sample['x'].to(self.device)
                y = sample['y'].to(self.device)

                if self.testing_method == 'supervised_encoder':
                    gnrlLoss_list = self.run_model(self.testing_method, x, y, self.l2_loss)
                    # Per-term losses of the final iteration; judging by how they
                    # are accumulated below: 0=total, 1=back, 2=mid, 3=front, 4=xy.
                    final_encoder_losses = [term.item() for term in gnrlLoss_list[1]]
                    gnrl_back_loss += final_encoder_losses[1]
                    gnrl_front_loss += final_encoder_losses[3]
                    gnrl_xy_loss += final_encoder_losses[4]
                    gnrl_l2_loss += gnrlLoss_list[3]

                    gnrl_total_last_iter_loss += final_encoder_losses[0]
                    gnrl_total_last_iter_loss += gnrlLoss_list[3]

                    if self.n_digits == 3:
                        gnrl_mid_loss += final_encoder_losses[2]

                    final_out = gnrlLoss_list[2]

                    if self.encoder_target_type == 'joint':
                        # Was `test_accuracy += ...` (undefined name -> crash).
                        gnrl_accuracy += get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                    elif self.encoder_target_type in dbw_targets:
                        accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                        depth_accuracy += accuracy_list[0]
                        black_accuracy += accuracy_list[1]
                        white_accuracy += accuracy_list[2]
                    elif self.encoder_target_type in ordered_targets:
                        if self.n_digits == 2:
                            accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                            gnrl_back_accuracy += accuracy_list[0]
                            gnrl_front_accuracy += accuracy_list[1]
                        elif self.n_digits == 3:
                            accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                            gnrl_back_accuracy += accuracy_list[0]
                            gnrl_mid_accuracy += accuracy_list[1]
                            gnrl_front_accuracy += accuracy_list[2]
                else:  # 'supervised_decoder'
                    # The decoder path feeds a float tensor to the model.
                    x = x.type(torch.FloatTensor).to(self.device)
                    gnrlLoss_list = self.run_model(self.testing_method, x, y, self.l2_loss)
                    decoder_losses = [term.item() for term in gnrlLoss_list[1]]
                    gnrl_total_last_iter_loss = gnrl_total_last_iter_loss + decoder_losses[-1]
                    gnrl_total_last_iter_loss = gnrl_total_last_iter_loss + gnrlLoss_list[-1]
                    gnrl_l2_loss += gnrlLoss_list[-1]

                gnrlLoss += gnrlLoss_list[0]
                cnt += 1

        if cnt == 0:
            # Without this guard the float accumulator reaches .div() and fails
            # with an opaque AttributeError.
            raise ValueError('gnrl_dl yielded no batches; cannot average the generalisation loss')

        gnrlLoss = gnrlLoss.div(cnt).cpu().item()
        self.gnrl_l2_loss = gnrl_l2_loss / cnt
        print('[{}] all iters gnrl_Loss:{:.3f}, l2_loss{:.3f}'.format(self.global_iter, gnrlLoss, self.gnrl_l2_loss))

        if self.testing_method == 'supervised_decoder':
            self.gnrl_total_last_iter_loss = gnrl_total_last_iter_loss / cnt
            return gnrlLoss

        if self.encoder_target_type == 'joint':
            self.gnrl_accuracy = gnrl_accuracy / cnt
            print('[{}] gnrl accuracy:{:.3f}'.format(self.global_iter, self.gnrl_accuracy))
        elif self.encoder_target_type in dbw_targets:
            self.gnrl_depth_accuracy = depth_accuracy / cnt
            self.gnrl_black_accuracy = black_accuracy / cnt
            self.gnrl_white_accuracy = white_accuracy / cnt
            print('[{}] gnrl_depth_accuracy:{:.3f}, gnrl_black_accuracy:{:.3f}, gnrl_white_accuracy:{:.3f}'.format(self.global_iter, self.gnrl_depth_accuracy, self.gnrl_black_accuracy, self.gnrl_white_accuracy))
        elif self.encoder_target_type in ordered_targets:
            self.gnrl_back_loss = gnrl_back_loss / cnt
            self.gnrl_front_loss = gnrl_front_loss / cnt
            self.gnrl_xy_loss = gnrl_xy_loss / cnt
            self.gnrl_total_last_iter_loss = gnrl_total_last_iter_loss / cnt

            if self.n_digits == 2:
                self.gnrl_back_accuracy = gnrl_back_accuracy / cnt
                self.gnrl_front_accuracy = gnrl_front_accuracy / cnt
                print('[{}] tot_last_iter_gnrl_loss{:.3f}, last_iter_gnrl_back_loss:{:.3f}, last_iter_gnrl_front_loss:{:.3f}, last_iter_gnrl_xy_loss{:.3f}'.format(
                    self.global_iter, self.gnrl_total_last_iter_loss,
                    self.gnrl_back_loss, self.gnrl_front_loss, self.gnrl_xy_loss))
                print('[{}] gnrl_back_accuracy:{:.3f}, gnrl_front_accuracy:{:.3f}'.format(self.global_iter, self.gnrl_back_accuracy, self.gnrl_front_accuracy))
            elif self.n_digits == 3:
                self.gnrl_mid_loss = gnrl_mid_loss / cnt
                self.gnrl_back_accuracy = gnrl_back_accuracy / cnt
                self.gnrl_mid_accuracy = gnrl_mid_accuracy / cnt
                self.gnrl_front_accuracy = gnrl_front_accuracy / cnt
                print('[{}] tot_last_iter_gnrl_loss{:.3f}, last_iter_gnrl_back_loss:{:.3f},last_iter_gnrl_mid_loss:{:.3f}, last_iter_gnrl_front_loss:{:.3f}, last_iter_gnrl_xy_loss{:.3f}'.format(
                    self.global_iter, self.gnrl_total_last_iter_loss,
                    self.gnrl_back_loss, self.gnrl_mid_loss, self.gnrl_front_loss, self.gnrl_xy_loss))
                print('[{}] gnrl_back_accuracy:{:.3f}, gnrl_mid_accuracy:{:.3f}, gnrl_front_accuracy:{:.3f}'.format(self.global_iter, self.gnrl_back_accuracy, self.gnrl_mid_accuracy, self.gnrl_front_accuracy))

        return gnrlLoss
Пример #3
0
    def train(self):
        """Run the supervised training loop for ``self.max_epoch`` epochs.

        Each step moves a batch from ``self.train_dl`` to ``self.device``, runs
        ``self.run_model`` for ``self.testing_method``, adjusts the learning
        rate and takes one optimiser step. Periodically:

        * every ``self.gather_step`` iterations: evaluates ``self.gnrl_loss()``,
          records train/generalisation metrics with ``self.gather.insert`` and
          appends a summary line to ``<output_dir>/LOGBOOK.txt``;
        * every ``self.display_step`` iterations: prints train loss/accuracy;
        * every ``self.save_step`` iterations: saves the ``'last'`` checkpoint,
          re-evaluates the generalisation loss (when ``self.gnrl_dl != 0``) and
          saves ``'best_gnrl'`` if it beats the best value seen in this call,
          renders test images for the decoder and dumps the gathered data;
        * every 5000 iterations: saves an iteration-tagged checkpoint and data.

        Progress resumes from ``self.global_iter``; training stops once it
        reaches ``self.max_epoch * len(self.train_dl)``.

        Raises:
            ValueError: for an unknown ``self.testing_method`` or an empty
                ``self.train_dl`` (which would otherwise loop forever).
        """
        if self.testing_method not in ('supervised_encoder', 'supervised_decoder'):
            raise ValueError('Unknown testing_method: {!r}'.format(self.testing_method))

        iters_per_epoch = len(self.train_dl)
        print(iters_per_epoch, 'iters per epoch')
        if iters_per_epoch == 0:
            raise ValueError('train_dl is empty; training would never terminate')
        max_iter = self.max_epoch * iters_per_epoch
        # Best generalisation loss seen during this call.
        oldgnrlLoss = math.inf

        dbw_targets = ("depth_black_white", "depth_black_white_xy_xy")
        ordered_targets = ("depth_ordered_one_hot", "depth_ordered_one_hot_xy")
        logbook_path = "{}/LOGBOOK.txt".format(self.output_dir)

        count = 0
        out = False
        pbar = tqdm(total=max_iter)
        pbar.update(self.global_iter)

        while not out:
            for sample in self.train_dl:
                self.global_iter += 1
                pbar.update(1)

                x = sample['x'].to(self.device)
                y = sample['y'].to(self.device)

                if self.testing_method == 'supervised_encoder':
                    loss, final_loss_list, final_out, train_l2_reg_loss = self.run_model(self.testing_method, x, y, self.l2_loss)
                else:  # 'supervised_decoder'
                    x = x.type(torch.FloatTensor).to(self.device)
                    loss, loss_list, _recon, train_l2_reg_loss = self.run_model(self.testing_method, x, y, self.l2_loss)
                    loss_list = [round(term.item(), 3) for term in loss_list]

                # count / iters_per_epoch is the fractional epoch within this call.
                self.adjust_learning_rate(self.optim, (count / iters_per_epoch))

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                count += 1

                if self.global_iter % self.gather_step == 0:
                    gnrlLoss = self.gnrl_loss()
                    train_loss = loss.item()

                    if self.testing_method == 'supervised_encoder':
                        if self.encoder_target_type == 'joint':
                            # Computed here: the old code reused the display-step value,
                            # which is unbound if this step fires first (and stale otherwise),
                            # and read self.accuracy, which gnrl_loss never sets.
                            train_accuracy = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                            self.gather.insert(iter=self.global_iter, train_loss=train_loss,
                                               gnrl_loss=gnrlLoss,
                                               train_accuracy=train_accuracy,
                                               gnrl_accuracy=self.gnrl_accuracy)
                        elif self.encoder_target_type in dbw_targets:
                            accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                            self.gather.insert(iter=self.global_iter, train_loss=train_loss,
                                               gnrl_loss=gnrlLoss,
                                               train_depth_accuracy=accuracy_list[0],
                                               train_black_accuracy=accuracy_list[1],
                                               train_white_accuracy=accuracy_list[2],
                                               gnrl_depth_accuracy=self.gnrl_depth_accuracy,
                                               gnrl_black_accuracy=self.gnrl_black_accuracy,
                                               gnrl_white_accuracy=self.gnrl_white_accuracy,
                                               depth_loss=float(final_loss_list[1]),
                                               black_loss=float(final_loss_list[2]),
                                               white_loss=float(final_loss_list[3]),
                                               xy_loss=float(final_loss_list[4]))
                        elif self.encoder_target_type in ordered_targets:
                            # Store a plain float: the L2 term is built from the model
                            # parameters, and keeping that tensor in `gather` would keep
                            # its autograd graph alive across iterations.
                            l2_reg = float(train_l2_reg_loss)
                            if self.n_digits == 2:
                                accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                                self.gather.insert(iter=self.global_iter, train_loss=train_loss,
                                                   gnrl_loss=gnrlLoss,
                                                   l2_reg_loss=l2_reg,

                                                   train_back_accuracy=accuracy_list[0],
                                                   train_front_accuracy=accuracy_list[1],
                                                   gnrl_back_accuracy=self.gnrl_back_accuracy,
                                                   gnrl_front_accuracy=self.gnrl_front_accuracy,

                                                   train_tot_final_iter_loss=float(final_loss_list[0]) + l2_reg,
                                                   train_back_loss=float(final_loss_list[1]),
                                                   train_front_loss=float(final_loss_list[3]),
                                                   train_xy_loss=float(final_loss_list[4]),

                                                   gnrl_tot_final_iter_loss=self.gnrl_total_last_iter_loss,
                                                   gnrl_back_loss=self.gnrl_back_loss,
                                                   gnrl_front_loss=self.gnrl_front_loss,
                                                   gnrl_xy_loss=self.gnrl_xy_loss)
                            elif self.n_digits == 3:
                                accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                                self.gather.insert(iter=self.global_iter, train_loss=train_loss,
                                                   gnrl_loss=gnrlLoss,
                                                   l2_reg_loss=l2_reg,

                                                   train_back_accuracy=accuracy_list[0],
                                                   train_mid_accuracy=accuracy_list[1],
                                                   train_front_accuracy=accuracy_list[2],
                                                   gnrl_back_accuracy=self.gnrl_back_accuracy,
                                                   gnrl_mid_accuracy=self.gnrl_mid_accuracy,
                                                   gnrl_front_accuracy=self.gnrl_front_accuracy,

                                                   train_tot_final_iter_loss=float(final_loss_list[0]) + l2_reg,
                                                   train_back_loss=float(final_loss_list[1]),
                                                   train_mid_loss=float(final_loss_list[2]),
                                                   train_front_loss=float(final_loss_list[3]),
                                                   train_xy_loss=float(final_loss_list[4]),

                                                   gnrl_tot_final_iter_loss=self.gnrl_total_last_iter_loss,
                                                   gnrl_back_loss=self.gnrl_back_loss,
                                                   gnrl_mid_loss=self.gnrl_mid_loss,
                                                   gnrl_front_loss=self.gnrl_front_loss,
                                                   gnrl_xy_loss=self.gnrl_xy_loss)

                        # Every encoder configuration logs (previously only n_digits == 3 did).
                        with open(logbook_path, "a") as myfile:
                            myfile.write('\n[{}] train_loss:{:.3f}, gnrl_loss:{:.3f}'.format(self.global_iter, train_loss, gnrlLoss))

                    else:  # 'supervised_decoder'
                        if self.decoder == 'B':
                            self.gather.insert(iter=self.global_iter, train_recon_loss=train_loss, gnrl_recon_loss=gnrlLoss)
                            logbook_line = '\n[{}] train_recon_loss:{:.3f}, gnrl_recon_loss:{:.3f}'.format(self.global_iter, train_loss, gnrlLoss)
                        else:
                            self.gather.insert(iter=self.global_iter, train_recon_loss=train_loss, gnrl_recon_loss=gnrlLoss,
                                               train_recon_last_iter_loss=loss_list[-1],
                                               gnrl_total_last_iter_loss=self.gnrl_total_last_iter_loss)
                            logbook_line = '\n[{}] train_recon_loss:{:.3f}, gnrl_recon_loss:{:.3f}, {}'.format(self.global_iter, train_loss, gnrlLoss, loss_list)
                        with open(logbook_path, "a") as myfile:
                            myfile.write(logbook_line)

                if self.global_iter % self.display_step == 0:
                    print('[{}] train loss:{:.3f}'.format(self.global_iter, loss.item()))

                    if self.testing_method == 'supervised_encoder':
                        if self.encoder_target_type == 'joint':
                            train_accuracy = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                            print('[{}] train accuracy:{:.3f}'.format(self.global_iter, train_accuracy))
                        elif self.encoder_target_type in dbw_targets:
                            accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                            print('[{}], train_depth_accuracy:{:.3f}, train_black_accuracy:{:.3f}, train_white_accuracy:{:.3f}'.format(self.global_iter, accuracy_list[0], accuracy_list[1], accuracy_list[2]))
                        elif self.encoder_target_type in ordered_targets:
                            if self.n_digits == 2:
                                accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                                print('[{}], train_back_accuracy:{:.3f}, train_front_accuracy:{:.3f}'.format(self.global_iter, accuracy_list[0], accuracy_list[1]))
                            elif self.n_digits == 3:
                                accuracy_list = get_accuracy(final_out, y, self.encoder_target_type, self.n_digits)
                                print('[{}], train_back_accuracy:{:.3f}, train_mid_accuracy:{:.3f}, train_front_accuracy:{:.3f}'.format(self.global_iter, accuracy_list[0], accuracy_list[1], accuracy_list[2]))
                    elif self.decoder != 'B':
                        # loss_list was already rounded to 3 decimals above.
                        print(loss_list)

                if self.global_iter % self.save_step == 0:
                    self.save_checkpoint('last')

                    if self.gnrl_dl != 0:
                        gnrlLoss = self.gnrl_loss()
                        print('old gnrl loss', oldgnrlLoss, 'current gnrl loss', gnrlLoss)
                        if gnrlLoss < oldgnrlLoss:
                            oldgnrlLoss = gnrlLoss
                            self.save_checkpoint('best_gnrl')
                            pbar.write('Saved best GNRL checkpoint(iter:{})'.format(self.global_iter))

                    if self.testing_method == 'supervised_decoder':
                        self.test_images()

                    self.gather.save_data(self.global_iter, self.output_dir, 'last')

                if self.global_iter % 5000 == 0:
                    self.save_checkpoint(str(self.global_iter))
                    self.gather.save_data(self.global_iter, self.output_dir, None)

                if self.global_iter >= max_iter:
                    out = True
                    break

        pbar.write("[Training Finished]")
        pbar.close()
def linear_readout_sup(net, max_epoch):
    """Fit a linear read-out on top of the frozen encoder of ``net``.

    The encoder (``net.net._encode``) runs under ``torch.no_grad()``. Only the
    read-out ``Lin_model(net.z_dim_bern + net.z_dim_gauss, z_out)`` is trained,
    with Adam (``net.lr``, ``net.beta1``, ``net.beta2``) on
    ``supervised_encoder_loss`` plus an L2 penalty on the read-out weights
    scaled by ``net.l2_loss``. The training batches come from
    ``return_data_sup_encoder(net.args)``; the test and generalisation loaders
    it also returns are not used here.

    Args:
        net: model wrapper exposing ``args``, ``encoder_target_type``,
            ``z_dim_bern``, ``z_dim_gauss``, ``lr``, ``beta1``, ``beta2``,
            ``device``, ``l2_loss`` and ``net`` (which must have ``_encode``).
        max_epoch: number of passes over the supervised training loader.

    Returns:
        The trained read-out module. It is wrapped in ``nn.DataParallel`` when
        more than one GPU is visible. The function previously returned ``None``.

    Raises:
        ValueError: for an ``encoder_target_type`` with no known output width,
            or an empty training loader (which would otherwise loop forever).
    """
    sup_train_dl, _sup_test_dl, _sup_gnrl_dl = return_data_sup_encoder(net.args)

    # Read-out output width for each supported target encoding.
    z_out_by_target = {
        'joint': 10,
        'black_white': 20,
        'depth_black_white': 21,
        'depth_black_white_xy_xy': 25,
    }
    if net.encoder_target_type not in z_out_by_target:
        # Previously z_out was left undefined and failed later with a NameError.
        raise ValueError('linear_readout_sup does not support encoder_target_type {!r}'.format(net.encoder_target_type))
    z_out = z_out_by_target[net.encoder_target_type]

    z_dim = net.z_dim_bern + net.z_dim_gauss
    lin_net = Lin_model(z_dim, z_out)
    optim_2 = optim.Adam(lin_net.parameters(), lr=net.lr, betas=(net.beta1, net.beta2))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        lin_net = nn.DataParallel(lin_net)
    lin_net = lin_net.to(net.device)

    iters_per_epoch = len(sup_train_dl)
    print(iters_per_epoch, 'iters per epoch')
    if iters_per_epoch == 0:
        raise ValueError('supervised training loader is empty; nothing to train on')
    max_iter = max_epoch * iters_per_epoch
    # Log about 500 times per run. The integer step replaces the old float
    # modulo, which skipped logs when max_iter % 500 != 0 and divided by zero
    # when max_iter == 0.
    log_every = max(1, max_iter // 500)

    out = False
    pbar = tqdm(total=max_iter)
    global_iter = 0

    while not out:
        for sample in sup_train_dl:
            global_iter += 1
            pbar.update(1)

            x = sample['x'].to(net.device)
            y = sample['y'].to(net.device)

            with torch.no_grad():
                output = net.net._encode(x)
                # Slice to Lin_model's *input* width (was sliced to z_out, the
                # output width). NOTE(review): assumes the latent code occupies
                # the leading columns of _encode's output -- confirm.
                output = output[:, :z_dim]

            # Was `lin_net(final_out)`, which referenced an undefined name.
            final_out = lin_net(output)
            loss = supervised_encoder_loss(final_out, y, net.encoder_target_type)

            # Penalise the read-out weights: the encoder is frozen and not in
            # optim_2, so penalising its parameters (as before) only inflated
            # the loss and accumulated stray gradients on net.net.
            l2 = 0
            for p in lin_net.parameters():
                l2 = l2 + p.pow(2).sum()
            loss = loss + net.l2_loss * l2

            optim_2.zero_grad()
            loss.backward()
            optim_2.step()

            if global_iter % log_every == 0:
                print('[{}] train loss:{:.3f}'.format(global_iter, loss.item()))
                if net.encoder_target_type == 'joint':
                    train_accuracy = get_accuracy(final_out, y, net.encoder_target_type)
                    print('[{}] train accuracy:{:.3f}'.format(global_iter, train_accuracy))
                else:
                    accuracy_list = get_accuracy(final_out, y, net.encoder_target_type)
                    print('[{}], train_depth_accuracy:{:.3f}, train_black_accuracy:{:.3f}, train_white_accuracy:{:.3f}'.format(global_iter, accuracy_list[0], accuracy_list[1], accuracy_list[2]))

            if global_iter >= max_iter:
                out = True
                break

    # Previously inside the while-loop, which closed the bar after epoch 1.
    pbar.write("[Training Finished]")
    pbar.close()
    return lin_net
Пример #5
0
    def train(self):
        """Run the supervised training loop for ``self.max_epoch`` epochs.

        Each step moves a batch from ``self.train_dl`` to ``self.device``, runs
        ``self.run_model`` for ``self.testing_method``, adjusts the learning
        rate and takes one optimiser step. Periodically:

        * every ``self.gather_step`` iterations: calls ``self.test_loss()`` and
          ``self.gnrl_loss()`` and records metrics with ``self.gather.insert``
          (the decoder also appends a line to ``<output_dir>/LOGBOOK.txt``);
        * every ``self.display_step`` iterations: prints train loss/accuracy;
        * every ``self.save_step`` iterations: saves the ``'last'`` checkpoint,
          re-evaluates the test loss and (when ``self.gnrl_dl != 0``) the
          generalisation loss, and saves ``'best_test'`` / ``'best_gnrl'``
          when the respective loss beats the best value seen in this call;
          renders test images for the decoder and dumps the gathered data;
        * every 5000 iterations: saves an iteration-tagged checkpoint and data.

        Progress resumes from ``self.global_iter``; training stops once it
        reaches ``self.max_epoch * len(self.train_dl)``.

        Raises:
            ValueError: for an unknown ``self.testing_method`` or an empty
                ``self.train_dl`` (which would otherwise loop forever).
        """
        if self.testing_method not in ('supervised_encoder', 'supervised_decoder'):
            raise ValueError('Unknown testing_method: {!r}'.format(self.testing_method))

        iters_per_epoch = len(self.train_dl)
        print(iters_per_epoch, 'iters per epoch')
        if iters_per_epoch == 0:
            raise ValueError('train_dl is empty; training would never terminate')
        max_iter = self.max_epoch * iters_per_epoch

        # Best losses seen during this call. The old code compared against the
        # *previous* evaluation: it could overwrite a 'best' checkpoint with a
        # worse model, and never saved when the gather step had just evaluated
        # the same model. It also raised NameError when gnrl_dl == 0.
        best_test_loss = float('inf')
        best_gnrl_loss = float('inf')

        count = 0
        out = False
        pbar = tqdm(total=max_iter)
        pbar.update(self.global_iter)

        while not out:
            for sample in self.train_dl:
                self.global_iter += 1
                pbar.update(1)

                x = sample['x'].to(self.device)
                y = sample['y'].to(self.device)

                if self.testing_method == 'supervised_encoder':
                    loss, final_out = self.run_model(self.testing_method, x, y, self.l2_loss)
                else:  # 'supervised_decoder'
                    x = x.type(torch.FloatTensor).to(self.device)
                    loss, _recon = self.run_model(self.testing_method, x, y, self.l2_loss)

                # count / iters_per_epoch is the fractional epoch within this call.
                self.adjust_learning_rate(self.optim, (count / iters_per_epoch))
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                count += 1

                if self.global_iter % self.gather_step == 0:
                    self.test_loss()
                    self.gnrl_loss()
                    # A plain float: storing the loss tensor (as torch.mean(loss)
                    # did) keeps its autograd graph alive inside `gather`.
                    train_loss = loss.item()

                    if self.testing_method == 'supervised_encoder':
                        if self.encoder_target_type == 'joint':
                            # Computed here: the display-step value was unbound if
                            # this step fired first, and stale otherwise.
                            train_accuracy = get_accuracy(final_out, y, self.encoder_target_type)
                            self.gather.insert(iter=self.global_iter, train_loss=train_loss,
                                               test_loss=self.testLoss, gnrl_loss=self.gnrlLoss,
                                               train_accuracy=train_accuracy, test_accuracy=self.test_accuracy,
                                               # NOTE(review): sibling code stores gnrl_accuracy on
                                               # self.gnrl_accuracy; confirm this version's gnrl_loss
                                               # really sets self.accuracy.
                                               gnrl_accuracy=self.accuracy)
                        else:
                            accuracy_list = get_accuracy(final_out, y, self.encoder_target_type)
                            self.gather.insert(iter=self.global_iter, train_loss=train_loss,
                                               test_loss=self.testLoss, gnrl_loss=self.gnrlLoss,
                                               train_depth_accuracy=accuracy_list[0],
                                               train_black_accuracy=accuracy_list[1],
                                               train_white_accuracy=accuracy_list[2],
                                               test_depth_accuracy=self.test_depth_accuracy,
                                               test_black_accuracy=self.test_black_accuracy,
                                               test_white_accuracy=self.test_white_accuracy,
                                               gnrl_depth_accuracy=self.gnrl_depth_accuracy,
                                               gnrl_black_accuracy=self.gnrl_black_accuracy,
                                               gnrl_white_accuracy=self.gnrl_white_accuracy)
                    else:  # 'supervised_decoder'
                        self.gather.insert(iter=self.global_iter, train_recon_loss=train_loss,
                                           test_recon_loss=self.testLoss, gnrl_recon_loss=self.gnrlLoss)

                        with open("{}/LOGBOOK.txt".format(self.output_dir), "a") as myfile:
                            myfile.write('\n[{}] train_recon_loss:{:.3f}, test_recon_loss:{:.3f} , gnrl_recon_loss:{:.3f}'.format(self.global_iter, train_loss, self.testLoss, self.gnrlLoss))

                if self.global_iter % self.display_step == 0:
                    print('[{}] train loss:{:.3f}'.format(self.global_iter, loss.item()))
                    if self.testing_method == 'supervised_encoder':
                        if self.encoder_target_type == 'joint':
                            train_accuracy = get_accuracy(final_out, y, self.encoder_target_type)
                            print('[{}] train accuracy:{:.3f}'.format(self.global_iter, train_accuracy))
                        else:
                            accuracy_list = get_accuracy(final_out, y, self.encoder_target_type)
                            print('[{}], train_depth_accuracy:{:.3f}, train_black_accuracy:{:.3f}, train_white_accuracy:{:.3f}'.format(self.global_iter, accuracy_list[0], accuracy_list[1], accuracy_list[2]))

                if self.global_iter % self.save_step == 0:
                    self.save_checkpoint('last')

                    self.test_loss()
                    print('old test loss', best_test_loss, 'current test loss', self.testLoss)

                    if self.gnrl_dl != 0:
                        self.gnrl_loss()
                        print('old gnrl loss', best_gnrl_loss, 'current gnrl loss', self.gnrlLoss)
                        if self.gnrlLoss < best_gnrl_loss:
                            best_gnrl_loss = self.gnrlLoss
                            self.save_checkpoint('best_gnrl')
                            pbar.write('Saved best GNRL checkpoint(iter:{})'.format(self.global_iter))

                    # Only the test loss decides 'best_test' (it was also saved
                    # whenever the gnrl loss improved).
                    if self.testLoss < best_test_loss:
                        best_test_loss = self.testLoss
                        self.save_checkpoint('best_test')
                        pbar.write('Saved best TEST checkpoint(iter:{})'.format(self.global_iter))

                    if self.testing_method == 'supervised_decoder':
                        self.test_images()

                    self.gather.save_data(self.global_iter, self.output_dir, 'last')

                if self.global_iter % 5000 == 0:
                    self.save_checkpoint(str(self.global_iter))
                    self.gather.save_data(self.global_iter, self.output_dir, None)

                if self.global_iter >= max_iter:
                    out = True
                    break

        pbar.write("[Training Finished]")
        pbar.close()