def run_model(self, x, y):
    shape = x.shape
    # Flatten images to (batch, features) for the encoder. Using x.size(0)
    # instead of the original hard-coded 200 keeps this safe for partial batches.
    x = x.reshape(x.size(0), -1)
    x_recon_list, p_dist, mu, logvar, z = self.net(x, train=True)
    x = x.reshape(shape)
    # Score only the final reconstruction from the iterative decoder.
    x_recon = x_recon_list[-1].reshape(shape)
    recon_loss = reconstruction_loss(y, x_recon)
    if self.z_dim_bern == 0:
        total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(mu, logvar)
        KL_loss = self.beta * total_kld
    elif self.z_dim_gauss == 0:
        if not self.AE:
            total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
            KL_loss = self.gamma * total_kld
        else:
            # Plain autoencoder: no KL term.
            KL_loss = torch.tensor(0.0)
    elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
        total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(p_dist)
        total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(mu, logvar)
        KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss
    return [recon_loss, KL_loss]
def run_model(self, x, y):
    x_recon, p_dist, mu, logvar = self.net(x, train=True)
    recon_loss = reconstruction_loss(y, x_recon)
    if self.z_dim_bern == 0:
        total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(mu, logvar)
        KL_loss = self.beta * total_kld
    elif self.z_dim_gauss == 0:
        total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
        KL_loss = self.gamma * total_kld
    elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
        total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(p_dist)
        total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(mu, logvar)
        KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss
    return [recon_loss, KL_loss]
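# Both run_model variants above rely on reconstruction_loss,
# kl_divergence_gaussian and kl_divergence_bernoulli, which are defined
# elsewhere in the repo. Below is a minimal sketch of what they plausibly
# compute, assuming the usual closed-form KL against a standard normal for
# the Gaussian latents and a Bernoulli(0.5) prior for the binary latents;
# the exact loss and reductions in the real helpers may differ.
import torch
import torch.nn.functional as F


def reconstruction_loss_sketch(y, x_recon):
    # Batch-averaged binary cross-entropy; MSE would be the other common
    # choice for this helper.
    batch_size = y.size(0)
    return F.binary_cross_entropy(x_recon, y, reduction='sum') / batch_size


def kl_divergence_gaussian_sketch(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, 1)) per latent dimension, then reduced three
    # ways to match the (total, dim_wise, mean) return signature used above.
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dim_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)
    return total_kld, dim_wise_kld, mean_kld


def kl_divergence_bernoulli_sketch(p_dist, prior=0.5):
    # KL(Bernoulli(p) || Bernoulli(prior)), clamped for numerical safety.
    p = p_dist.clamp(1e-6, 1 - 1e-6)
    klds = p * torch.log(p / prior) + (1 - p) * torch.log((1 - p) / (1 - prior))
    total_kld = klds.sum(1).mean(0, True)
    dim_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)
    return total_kld, dim_wise_kld, mean_kld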
def train(self):
    iters_per_epoch = len(self.train_dl)
    print(iters_per_epoch, 'iters per epoch')
    max_iter = self.max_epoch * iters_per_epoch
    batch_size = self.train_dl.batch_size
    current_idxs = 0
    current_flip_idx = []
    count = 0
    out = False
    pbar = tqdm(total=max_iter)
    pbar.update(self.global_iter)
    while not out:
        for sample in self.train_dl:
            self.global_iter += 1
            pbar.update(1)
            if self.flip:
                # Reset the within-epoch counter and work out which indices
                # of the current batch are flagged for flipping.
                if count % iters_per_epoch == 0:
                    print("RESETTING COUNTER")
                    count = 0
                current_idxs = range(count * batch_size, (count + 1) * batch_size)
                current_flip_idx = [x for x in self.flip_idx if x in current_idxs]
                if not current_flip_idx:
                    current_flip_idx_norm = None
                else:
                    # Re-index the flip positions relative to the current batch.
                    current_flip_idx_norm = [i - count * batch_size for i in current_flip_idx]
            else:
                current_flip_idx_norm = None
            x = sample['x'].to(self.device)
            y = sample['y'].to(self.device)
            x_recon, p_dist, mu, logvar = self.net(x, train=True)
            recon_loss = reconstruction_loss(y, x_recon)
            if self.z_dim_bern == 0:
                total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(mu, logvar)
                KL_loss = self.beta * total_kld
            elif self.z_dim_gauss == 0:
                total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
                KL_loss = self.gamma * total_kld
            elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(p_dist)
                total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(mu, logvar)
                KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss
            loss = recon_loss + KL_loss
            self.adjust_learning_rate(self.optim, count / iters_per_epoch)
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            count += 1

            if self.global_iter % self.gather_step == 0:
                self.test_loss()
                if self.gnrl_dl != 0:
                    self.gnrl_loss()
                    with open("{}/LOGBOOK.txt".format(self.output_dir), "a") as myfile:
                        myfile.write(
                            '\n[{}] train_loss:{:.3f}, train_recon_loss:{:.3f}, '
                            'train_KL_loss:{:.3f}, test_loss:{:.3f}, '
                            'test_recon_loss:{:.3f}, test_KL_loss:{:.3f}, '
                            'gnrl_loss:{:.3f}, gnrl_recon_loss:{:.3f}, '
                            'gnrl_KL_loss:{:.3f}'.format(
                                self.global_iter, loss.item(), recon_loss.item(),
                                KL_loss.item(), self.testLoss, self.test_recon_loss,
                                self.test_kl_loss, self.gnrlLoss,
                                self.gnrl_recon_loss, self.gnrl_kl_loss))
                    self.gather.insert(
                        iter=self.global_iter,
                        trainLoss=loss.item(),
                        train_recon_loss=recon_loss.item(),
                        train_KL_loss=KL_loss.item(),
                        testLoss=self.testLoss,
                        test_recon_loss=self.test_recon_loss,
                        test_kl_loss=self.test_kl_loss,
                        gnrlLoss=self.gnrlLoss,
                        gnrl_recon_loss=self.gnrl_recon_loss,
                        gnrl_kl_loss=self.gnrl_kl_loss)
                else:
                    with open("{}/LOGBOOK.txt".format(self.output_dir), "a") as myfile:
                        myfile.write(
                            '\n[{}] train_loss:{:.3f}, train_recon_loss:{:.3f}, '
                            'train_KL_loss:{:.3f}, test_loss:{:.3f}, '
                            'test_recon_loss:{:.3f}, test_KL_loss:{:.3f}'.format(
                                self.global_iter, loss.item(), recon_loss.item(),
                                KL_loss.item(), self.testLoss,
                                self.test_recon_loss, self.test_kl_loss))
                    self.gather.insert(
                        iter=self.global_iter,
                        trainLoss=loss.item(),
                        train_recon_loss=recon_loss.item(),
                        train_KL_loss=KL_loss.item(),
                        testLoss=self.testLoss,
                        test_recon_loss=self.test_recon_loss,
                        test_kl_loss=self.test_kl_loss)

            if self.global_iter % self.display_step == 0:
                if self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                    pbar.write(
                        '[{}] recon_loss:{:.3f} total_kld_gauss:{:.3f} '
                        'mean_kld_gauss:{:.3f} total_kld_bern:{:.3f} '
                        'mean_kld_bern:{:.3f}'.format(
                            self.global_iter, recon_loss.item(),
                            total_kld_gauss.item(), mean_kld_gauss.item(),
                            total_kld_bern.item(), mean_kld_bern.item()))
                else:
                    pbar.write(
                        '[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                            self.global_iter, recon_loss.item(),
                            total_kld.item(), mean_kld.item()))
                # logvar only exists for the Gaussian latents, so guard on
                # z_dim_gauss (the original guard on z_dim_bern looked inverted).
                if self.z_dim_gauss != 0:
                    var = logvar.exp().mean(0).data
                    var_str = ''
                    for j, var_j in enumerate(var):
                        var_str += 'var{}:{:.4f} '.format(j + 1, var_j)
                    pbar.write(var_str)

            if self.global_iter % self.save_step == 0:
                self.save_checkpoint('last')
                oldtestLoss = self.testLoss
                self.test_loss()
                if self.gnrl_dl != 0:
                    self.gnrl_loss()
                print('old test loss', oldtestLoss, 'current test loss', self.testLoss)
                if self.testLoss < oldtestLoss:
                    self.save_checkpoint('best')
                    pbar.write('Saved best checkpoint(iter:{})'.format(self.global_iter))
                self.test_plots()
                self.gather.save_data(self.global_iter, self.output_dir, 'last')

            if self.global_iter % 500 == 0:
                self.save_checkpoint(str(self.global_iter))
                self.gather.save_data(self.global_iter, self.output_dir, None)

            if self.global_iter >= max_iter:
                out = True
                break

    pbar.write("[Training Finished]")
    pbar.close()
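# A worked example of the flip-index bookkeeping used in train() above.
# The numbers are illustrative, not taken from the repo.
def _flip_index_example():
    batch_size = 4
    flip_idx = [2, 5, 9]
    count = 2  # the third batch covers dataset indices 8..11
    current_idxs = range(count * batch_size, (count + 1) * batch_size)
    current_flip_idx = [i for i in flip_idx if i in current_idxs]
    # Re-indexed relative to the batch: dataset index 9 -> position 1.
    current_flip_idx_norm = [i - count * batch_size for i in current_flip_idx]
    assert current_flip_idx == [9]
    assert current_flip_idx_norm == [1]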
def train(self):
    iters_per_epoch = len(self.train_dl)
    print(iters_per_epoch, 'iters per epoch')
    max_iter = self.max_epoch * iters_per_epoch
    batch_size = self.train_dl.batch_size
    current_idxs = 0
    current_flip_idx = []
    count = 0
    out = False
    pbar = tqdm(total=max_iter)
    pbar.update(self.global_iter)
    while not out:
        for sample in self.train_dl:
            self.global_iter += 1
            pbar.update(1)
            if self.flip:
                if count % iters_per_epoch == 0:
                    print("RESETTING COUNTER")
                    count = 0
                current_idxs = range(count * batch_size, (count + 1) * batch_size)
                current_flip_idx = [x for x in self.flip_idx if x in current_idxs]
                if not current_flip_idx:
                    current_flip_idx_norm = None
                else:
                    current_flip_idx_norm = [i - count * batch_size for i in current_flip_idx]
            else:
                current_flip_idx_norm = None
            x = sample['x'].to(self.device)
            y = sample['y'].to(self.device)
            shape = x.shape
            # Flatten to (batch, features) for the encoder; x.size(0) replaces
            # the original hard-coded batch size of 100, which broke on
            # partial batches.
            x = x.reshape(x.size(0), -1)
            x_recon_list, p_dist, mu, logvar, z = self.net(x, train=True)
            # Score only the final reconstruction from the iterative decoder;
            # summing reconstruction_loss over every element of x_recon_list
            # is the alternative that was previously tried here.
            x_recon = x_recon_list[-1].reshape(shape)
            x = x.reshape(shape)
            self.z_prev = z
            recon_loss = reconstruction_loss(y, x_recon)
            if self.z_dim_bern == 0:
                total_kld, dim_wise_kld, mean_kld = kl_divergence_gaussian(mu, logvar)
                KL_loss = self.beta * total_kld
            elif self.z_dim_gauss == 0:
                if not self.AE:
                    total_kld, dim_wise_kld, mean_kld = kl_divergence_bernoulli(p_dist)
                    KL_loss = self.gamma * total_kld
                else:
                    # Plain autoencoder: zero out the KL terms.
                    total_kld = torch.tensor(0.0)
                    mean_kld = torch.tensor(0.0)
                    KL_loss = torch.tensor(0.0)
            elif self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                total_kld_bern, dim_wise_kld_bern, mean_kld_bern = kl_divergence_bernoulli(p_dist)
                total_kld_gauss, dim_wise_kld_gauss, mean_kld_gauss = kl_divergence_gaussian(mu, logvar)
                KL_loss = self.gamma * total_kld_bern + self.beta * total_kld_gauss
            loss = recon_loss + KL_loss  # KL_loss == 0 for the plain AE
            self.adjust_learning_rate(self.optim, count / iters_per_epoch)
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            count += 1

            if self.global_iter == 1:
                net_copy = deepcopy(self.net)
                net_copy.to('cpu')
                filename = self.output_dir
                # z_anal(net_copy, self.gnrl_data, filename, self.global_iter)

            if self.global_iter % self.gather_step == 0:
                # Save a snapshot of the losses for plotting.
                if self.gnrl_dl != 0:
                    gnrlLoss = self.gnrl_loss()
                    print(loss.item(), recon_loss.item(), KL_loss.item(),
                          gnrlLoss, self.gnrl_recon_loss, self.gnrl_kl_loss)
                    self.gather.insert(
                        iter=self.global_iter,
                        trainLoss=loss.item(),
                        train_recon_loss=recon_loss.item(),
                        train_KL_loss=KL_loss.item(),
                        gnrlLoss=gnrlLoss,
                        gnrl_recon_loss=self.gnrl_recon_loss,
                        gnrl_kl_loss=self.gnrl_kl_loss)
                net_copy = deepcopy(self.net)
                net_copy.to('cpu')
                filename = self.output_dir
                # draw_iterative_img(net_copy, self.gnrl_data, self.global_iter, filename, 15)

            if self.global_iter % self.display_step == 0:
                if self.z_dim_bern != 0 and self.z_dim_gauss != 0:
                    pbar.write(
                        '[{}] recon_loss:{:.3f} total_kld_gauss:{:.3f} '
                        'mean_kld_gauss:{:.3f} total_kld_bern:{:.3f} '
                        'mean_kld_bern:{:.3f}'.format(
                            self.global_iter, recon_loss.item(),
                            total_kld_gauss.item(), mean_kld_gauss.item(),
                            total_kld_bern.item(), mean_kld_bern.item()))
                else:
                    pbar.write(
                        '[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                            self.global_iter, recon_loss.item(),
                            total_kld.item(), mean_kld.item()))
                if not self.AE:
                    # logvar only exists for the Gaussian latents, so guard on
                    # z_dim_gauss (the original guard on z_dim_bern looked inverted).
                    if self.z_dim_gauss != 0:
                        var = logvar.exp().mean(0).data
                        var_str = ''
                        for j, var_j in enumerate(var):
                            var_str += 'var{}:{:.4f} '.format(j + 1, var_j)
                        pbar.write(var_str)

            if self.global_iter % self.save_step == 0:
                # Make the reconstruction plots and track the best checkpoint.
                self.save_checkpoint('last')
                if self.gnrl_dl != 0:
                    oldgnrlLoss = self.gnrlLoss
                    self.gnrl_loss()
                    print('old gnrl loss', oldgnrlLoss, 'current gnrl loss', self.gnrlLoss)
                    if self.gnrlLoss < oldgnrlLoss:
                        self.save_checkpoint('best_gnrl')
                        pbar.write('Saved best GNRL checkpoint(iter:{})'.format(self.global_iter))
                self.test_plots()
                self.gather.save_data(self.global_iter, self.output_dir, 'last')

            if self.global_iter % 5000 == 0:
                self.save_checkpoint(str(self.global_iter))
                self.gather.save_data(self.global_iter, self.output_dir, None)

            if self.global_iter >= max_iter:
                net_copy = deepcopy(self.net)
                net_copy.to('cpu')
                # find_top_N(net_copy, self.gnrl_data, self.output_dir, 50, 30)
                out = True
                filename = self.output_dir
                plotsave_tests_loss_list(net_copy, self.gnrl_data, filename, 1000)
                # z_anal(net_copy, self.gnrl_data, filename, self.global_iter)
                break

    pbar.write("[Training Finished]")
    pbar.close()
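# Both train() variants call self.adjust_learning_rate(self.optim, epoch_frac)
# with the fraction of the current epoch completed, but the method itself is
# not shown in this section. Below is a minimal sketch of such a scheduler;
# the decay rule and the base_lr value are assumptions, not the repo's actual
# schedule.
def adjust_learning_rate_sketch(optimizer, epoch_frac, base_lr=1e-4):
    # Hypothetical rule: decay the learning rate smoothly within each epoch,
    # reaching base_lr * 0.1 as epoch_frac approaches 1.
    lr = base_lr * (0.1 ** epoch_frac)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr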