class Solver(object):
    """FactorVAE-style fairness trainer.

    Trains a VAE plus a total-correlation discriminator, together with
    several auxiliary probe classifiers that read the two 30-dim halves of
    the latent code (first half ~ target "heavy_makeup", second half ~
    protected attribute "male").  ``train_cls`` then fits a downstream
    target classifier on the (dimension-filtered) latent, and ``val``
    reports accuracy plus demographic-parity / equal-opportunity gaps on
    the evaluation split.

    NOTE(review): latent split sizes (30/30), classifier widths (59) and
    the excluded dimension list ``[24]`` are hard-coded experiment
    choices; they assume ``z_dim == 60`` — confirm against the config.
    """

    def __init__(self, args):
        # ---------------- Misc ----------------
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0        # VAE training step counter
        self.global_iter_cls = 0    # downstream-classifier step counter
        self.pbar = tqdm(total=self.max_iter)
        self.pbar_cls = tqdm(total=self.max_iter)

        # ---------------- Data ----------------
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.data_loader = return_data(args, 0)       # training split
        self.data_loader_eval = return_data(args, 2)  # evaluation split

        # ---------------- Networks & Optimizers ----------------
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.beta = args.beta       # (was assigned twice; kept once)
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D
        self.alpha = args.alpha
        self.grl = args.grl
        self.lr_cls = args.lr_cls
        # NOTE(review): classifier betas deliberately reuse the
        # discriminator betas from args (there are no args.beta*_cls).
        self.beta1_cls = args.beta1_D
        self.beta2_cls = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        # Probe / downstream classifiers.  Use .to(self.device) instead of
        # the original unconditional .cuda() so the CPU fallback selected
        # above actually works.
        self.pacls = classifier(30, 2).to(self.device)      # z[30:] -> male
        self.revcls = classifier(30, 2).to(self.device)     # GRL(z[30:]) -> makeup
        self.tcls = classifier(30, 2).to(self.device)       # z[:30] -> makeup
        self.trevcls = classifier(30, 2).to(self.device)    # GRL(z[:30]) -> male
        self.targetcls = classifier(59, 2).to(self.device)  # filtered z -> makeup
        self.pa_target = classifier(30, 2).to(self.device)  # z[30:] -> makeup
        self.target_pa = paclassifier(1, 1).to(self.device) # z[0:1] -> male
        self.pa_pa = classifier(30, 2).to(self.device)      # z[30:] -> male

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.optim_pacls = optim.Adam(self.pacls.parameters(), lr=self.lr_D)
        self.optim_revcls = optim.Adam(self.revcls.parameters(), lr=self.lr_D)
        self.optim_tcls = optim.Adam(self.tcls.parameters(), lr=self.lr_D)
        self.optim_trevcls = optim.Adam(self.trevcls.parameters(), lr=self.lr_D)
        self.optim_cls = optim.Adam(self.targetcls.parameters(), lr=self.lr_cls)
        self.optim_pa_target = optim.Adam(self.pa_target.parameters(),
                                          lr=self.lr_cls)
        self.optim_target_pa = optim.Adam(self.target_pa.parameters(),
                                          lr=self.lr_cls)
        self.optim_pa_pa = optim.Adam(self.pa_pa.parameters(), lr=self.lr_cls)

        # Networks toggled together by net_mode().
        self.nets = [
            self.VAE, self.D, self.pacls, self.targetcls, self.revcls,
            self.pa_target, self.tcls, self.trevcls
        ]

        # ---------------- Visdom ----------------
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z',
                           recon='win_recon',
                           kld='win_kld',
                           acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm',
                                      'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            if not self.viz.win_exists(env=self.name + '/lines',
                                       win=self.win_id['D_z']):
                self.viz_init()

        # ---------------- Checkpoint ----------------
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir + "/cls")
        mkdirs(self.ckpt_dir + "/vae")
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # ---------------- Output (latent traverse GIF) ----------------
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)

    def train(self):
        """Main VAE training loop; chains into train_cls() when done."""
        self.net_mode(train=True)
        ones = torch.ones(self.batch_size, dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long,
                            device=self.device)
        for i_num in range(self.max_iter - self.global_iter):
            # Per-epoch running counters for the probe accuracies.
            total_pa_num = 0
            total_pa_correct_num = 0
            total_male_num = 0
            total_male_correct = 0
            total_female_num = 0
            total_female_correct = 0
            total_rev_num = 0
            total_rev_correct_num = 0
            total_t_num = 0
            total_t_correct_num = 0
            total_t_rev_num = 0
            total_t_rev_correct_num = 0
            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):
                heavy_makeup = heavy_makeup.to(self.device)
                male = male.to(self.device)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                # Total-correlation surrogate: logit("real") - logit("perm").
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # Probes on the protected half z[30:] (plain + gradient-
                # reversed) and on the target half z[:30].
                z_reverse = grad_reverse(z.split(30, 1)[-1])
                reverse_output = self.revcls(z_reverse)
                output = self.pacls(z.split(30, 1)[-1])
                z_t_reverse = grad_reverse(z.split(30, 1)[0])
                t_reverse_output = self.trevcls(z_t_reverse)
                t_output = self.tcls(z.split(30, 1)[0])

                # Batch-level correctness counts for each probe.
                rev_correct = (
                    reverse_output.argmax(1) == heavy_makeup).sum().float()
                rev_num = heavy_makeup.size(0)
                pa_correct = (output.argmax(1) == male).sum().float()
                pa_num = male.size(0)
                t_correct = (t_output.argmax(1) == heavy_makeup).sum().float()
                t_num = heavy_makeup.size(0)
                t_rev_correct = (
                    t_reverse_output.argmax(1) == male).sum().float()
                t_rev_num = male.size(0)

                total_pa_correct_num += pa_correct
                total_pa_num += pa_num
                total_rev_correct_num += rev_correct
                total_rev_num += rev_num
                total_t_correct_num += t_correct
                total_t_num += t_num
                total_t_rev_correct_num += t_rev_correct
                total_t_rev_num += t_rev_num
                total_male_num += (male == 1).sum()
                total_female_num += (male == 0).sum()
                total_male_correct += ((output.argmax(1) == male) *
                                       (male == 1)).sum()
                total_female_correct += ((output.argmax(1) == male) *
                                         (male == 0)).sum()

                pa_cls = F.cross_entropy(output, male)
                rev_cls = F.cross_entropy(reverse_output, heavy_makeup)
                t_pa_cls = F.cross_entropy(t_output, heavy_makeup)
                t_rev_cls = F.cross_entropy(t_reverse_output, male)

                # NOTE(review): only recon + beta*KLD is backpropagated
                # here; the probe losses above (pa_cls/rev_cls/...) are
                # never part of the backward graph, so the four probe
                # optimizers below effectively step with no gradients.
                # Confirm this matches the intended ablation (the GRL term
                # is left commented out, as in the original).
                vae_loss = vae_recon_loss + self.beta * vae_kld  # + self.grl*rev_cls

                self.optim_VAE.zero_grad()
                self.optim_pacls.zero_grad()
                self.optim_revcls.zero_grad()
                self.optim_tcls.zero_grad()
                self.optim_trevcls.zero_grad()
                vae_loss.backward(retain_graph=True)
                self.optim_VAE.step()
                self.optim_pacls.step()
                self.optim_revcls.step()
                self.optim_tcls.step()
                self.optim_trevcls.step()

                # Discriminator pass on a permuted second batch.
                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                D_tc_loss = (F.cross_entropy(D_z, zeros) +
                             F.cross_entropy(D_z_pperm, ones))
                self.optim_D.zero_grad()
                # NOTE(review): D_tc_loss.backward() is commented out in
                # the original, so this step() is a no-op and D is never
                # trained — preserved as-is, but verify it is intentional.
                #D_tc_loss.backward()
                self.optim_D.step()

                self.pbar.update(1)
                self.global_iter += 1

                # Running (epoch-so-far) probe accuracies.
                pa_acc = float(total_pa_correct_num) / float(total_pa_num)
                rev_acc = float(total_rev_correct_num) / float(total_rev_num)
                t_acc = float(total_t_correct_num) / float(total_t_num)
                t_rev_acc = float(total_t_rev_correct_num) / float(
                    total_t_rev_num)
                male_acc = float(total_male_correct) / float(total_male_num)
                female_acc = float(total_female_correct) / float(
                    total_female_num)

                if self.global_iter % self.print_iter == 0:
                    self.pbar.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f} pa_cls_loss:{:.3f} pa_acc:{:.3f} m_acc:{:.3f} f_acc:{:.3f} rev_acc:{:.3f} t_acc:{:.3f} t_rev_acc:{:.3f}'
                        .format(self.global_iter, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                D_tc_loss.item(), pa_cls.item(), pa_acc,
                                male_acc, female_acc, rev_acc, t_acc,
                                t_rev_acc))

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter % self.viz_ll_iter == 0):
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() +
                             (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2 * self.batch_size
                    self.line_gather.insert(
                        iter=self.global_iter,
                        soft_D_z=soft_D_z.mean().item(),
                        soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                        recon=vae_recon_loss.item(),
                        kld=vae_kld.item(),
                        acc=D_acc.item())

                if self.viz_on and (self.global_iter % self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter % self.viz_ra_iter == 0):
                    self.image_gather.insert(
                        true=x_true1.data.cpu(),
                        recon=torch.sigmoid(x_recon).data.cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter % self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        self.visualize_traverse(limit=3, inter=2 / 3)

        self.pbar.write("[Training Finished]")
        self.pbar.close()
        self.train_cls()

    def train_cls(self):
        """Train the downstream target classifier on the frozen VAE latent.

        The latent is filtered by dropping the dimensions in ``dim_list``
        (hard-coded to [24]) before feeding ``targetcls``.  Runs 80 epochs,
        validating after each.
        """
        self.net_mode(train=True)
        # Freeze the VAE once up front (loop-invariant; the original froze
        # it redundantly on every batch).
        for name, param in self.VAE.named_parameters():
            param.requires_grad = False
        for i_num in range(80):
            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):
                male = male.to(self.device)
                heavy_makeup = heavy_makeup.to(self.device)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # Drop the dimensions in dim_list from z.
                dim_list = [24]
                target_z = z.split(1, 1)[0]
                for dim in range(1, len(z[0])):
                    if dim not in dim_list:
                        target_z = torch.cat([target_z, z.split(1, 1)[dim]],
                                             1)
                target = self.targetcls(target_z)
                pa_target = self.pa_target(z.split(30, 1)[-1])
                target_pa = self.target_pa(z.split(1, 1)[0])
                pa_pa = self.pa_pa(z.split(30, 1)[-1])

                # Only the target head is trained; the other heads are
                # evaluated for logging (their losses stay commented out
                # in the original).
                target_cls = F.cross_entropy(target, heavy_makeup)
                target_loss = target_cls

                self.optim_cls.zero_grad()
                target_loss.backward()
                self.optim_cls.step()

                self.global_iter_cls += 1
                self.pbar_cls.update(1)
                if self.global_iter_cls % self.print_iter == 0:
                    acc = ((target.argmax(1) == heavy_makeup).sum().float() /
                           len(x_true1)).item()
                    pa_acc = ((pa_target.argmax(1) ==
                               heavy_makeup).sum().float() /
                              len(x_true1)).item()
                    self.pbar_cls.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                        .format(self.global_iter_cls, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                target_loss.item(), acc))
            # NOTE(review): original indentation was ambiguous in the
            # collapsed source; validation is placed per-epoch here, which
            # matches the loop structure — confirm against upstream.
            self.val()
        self.pbar_cls.write("[Classifier Training Finished]")
        self.pbar_cls.close()

    def val(self):
        """Evaluate the downstream classifier and fairness metrics.

        Prints per-group accuracies, demographic parity (DP) and equal-
        opportunity gaps, plus the four probe accuracies, on the
        evaluation split.
        """
        total_true = 0
        total_num = 0
        total_male_heavy = 0
        total_male_nonheavy = 0
        total_female_heavy = 0
        total_female_nonheavy = 0
        total_male_heavy_num = 0
        total_male_nonheavy_num = 0
        total_female_heavy_num = 0
        total_female_nonheavy_num = 0
        total_pa_num = 0
        total_pa_true = 0
        total_target_pa_num = 0
        total_target_pa_true = 0
        total_pa_pa_num = 0
        total_pa_pa_true = 0
        total_male = 0
        total_female = 0
        total_male_pred = 0
        total_female_pred = 0
        num_batches = 0  # renamed from `iter` (shadowed the builtin)
        # Mean-image accumulators; consumers (matplotlib dumps) were
        # commented out upstream but the accumulation is kept unchanged.
        recon_tsum = np.zeros((self.batch_size, 64, 64, 3))
        recon_psum = np.zeros((self.batch_size, 64, 64, 3))
        recon_sum = np.zeros((self.batch_size, 64, 64, 3))
        origin_sum = np.zeros((self.batch_size, 64, 64, 3))
        for i, (x_true1, x_true2, heavy_makeup,
                male) in enumerate(self.data_loader_eval):
            male = male.to(self.device)
            heavy_makeup = heavy_makeup.to(self.device)
            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)

            # Same dimension filtering as train_cls.
            dim_list = [24]
            target_z = z.split(1, 1)[0]
            for dim in range(1, len(z[0])):
                if dim not in dim_list:
                    target_z = torch.cat([target_z, z.split(1, 1)[dim]], 1)
            target = self.targetcls(target_z)
            pa_target = self.pa_target(z.split(30, 1)[-1])
            target_pa = self.target_pa(z.split(1, 1)[0])
            pa_pa = self.pa_pa(z.split(30, 1)[-1])

            # Decode each latent half alone (other half zeroed) to inspect
            # what information it carries.
            z_t = z.split(30, 1)[0].unsqueeze(2).unsqueeze(2)
            z_p = z.split(30, 1)[-1].unsqueeze(2).unsqueeze(2)
            noise = torch.zeros(z.size(0), 30, 1, 1, device=self.device)
            z_t = torch.cat([z_t, noise], 1)
            z_p = torch.cat([z_p, noise], 1)
            recon_t = torch.sigmoid(self.VAE.decode(z_t))
            recon_p = torch.sigmoid(self.VAE.decode(z_p))
            recon_tsum += recon_t.transpose(1, 2).transpose(
                2, 3).cpu().detach().numpy()
            recon_psum += recon_p.transpose(1, 2).transpose(
                2, 3).cpu().detach().numpy()
            origin_sum += x_true1.transpose(1, 2).transpose(
                2, 3).cpu().detach().numpy()
            recon_sum += torch.sigmoid(x_recon).transpose(1, 2).transpose(
                2, 3).cpu().detach().numpy()
            num_batches += 1

            # Confusion counts split by protected attribute.
            male_heavy = (target.argmax(1) == 1) * (heavy_makeup == 1) * (
                male == 1)
            male_heavy = male_heavy.sum()
            male_heavy_num = ((heavy_makeup == 1) * (male == 1)).sum()
            male_nonheavy = (target.argmax(1) == 0) * (heavy_makeup == 0) * (
                male == 1)
            male_nonheavy = male_nonheavy.sum()
            male_nonheavy_num = ((heavy_makeup == 0) * (male == 1)).sum()
            female_heavy = (target.argmax(1) == 1) * (heavy_makeup == 1) * (
                male == 0)
            female_heavy = female_heavy.sum()
            female_heavy_num = ((heavy_makeup == 1) * (male == 0)).sum()
            female_nonheavy = (target.argmax(1) == 0) * (
                heavy_makeup == 0) * (male == 0)
            female_nonheavy = female_nonheavy.sum()
            female_nonheavy_num = ((heavy_makeup == 0) * (male == 0)).sum()
            total_male_heavy += male_heavy
            total_male_nonheavy += male_nonheavy
            total_female_heavy += female_heavy
            total_female_nonheavy += female_nonheavy
            total_male_heavy_num += male_heavy_num
            total_male_nonheavy_num += male_nonheavy_num
            total_female_heavy_num += female_heavy_num
            total_female_nonheavy_num += female_nonheavy_num

            total_pa_true += (
                pa_target.argmax(1) == heavy_makeup).sum().float()
            total_pa_num += len(heavy_makeup)
            # NOTE(review): target_pa has a single output unit, so
            # argmax(1) is always 0 and this "accuracy" just measures the
            # male==0 rate — preserved as-is; the BCE-based variant was
            # commented out upstream.
            total_target_pa_true += (target_pa.argmax(1) ==
                                     male).sum().float()
            total_target_pa_num += len(male)
            total_pa_pa_true += (pa_pa.argmax(1) == male).sum().float()
            total_pa_pa_num += len(male)
            total_true += (target.argmax(1) == heavy_makeup).sum().float()
            total_num += len(x_true1)
            total_male += (male == 1).sum()
            total_female += (male == 0).sum()
            total_male_pred += ((target.argmax(1) == 1) * (male == 1)).sum()
            total_female_pred += ((target.argmax(1) == 1) *
                                  (male == 0)).sum()

        male_heavy_acc = total_male_heavy.float() / total_male_heavy_num.float()
        male_nonheavy_acc = total_male_nonheavy.float(
        ) / total_male_nonheavy_num.float()
        female_heavy_acc = total_female_heavy.float(
        ) / total_female_heavy_num.float()
        female_nonheavy_acc = total_female_nonheavy.float(
        ) / total_female_nonheavy_num.float()

        print(total_male_heavy_num.item(), total_male_nonheavy_num.item(),
              total_female_heavy_num.item(),
              total_female_nonheavy_num.item())
        print("\nmale_heavy: ", male_heavy_acc.item(), "\tfemale_heavy: ",
              female_heavy_acc.item())
        print("male_nonheavy: ", male_nonheavy_acc.item(),
              "\tfemale_nonheavy: ", female_nonheavy_acc.item())
        print("Male_prob:", float(total_male_pred) / float(total_male))
        print("feMale_prob:", float(total_female_pred) / float(total_female))
        print("DP:", (float(total_male_pred) / float(total_male) -
                      float(total_female_pred) / float(total_female)))
        print("eoo(1):", male_heavy_acc.item() - female_heavy_acc.item())
        print("eoo(0):",
              male_nonheavy_acc.item() - female_nonheavy_acc.item())

        total_acc = total_true / total_num
        total_pa_acc = total_pa_true / total_pa_num
        total_target_pa_acc = total_target_pa_true / total_target_pa_num
        total_pa_pa_acc = total_pa_pa_true / total_pa_pa_num
        print("target->target Accuracy: ", total_acc.item())
        print("PA->target Accuracy: ", total_pa_acc.item())
        print("target->PA Accuracy: ", total_target_pa_acc.item())
        print("PA->PA Accuracy: ", total_pa_pa_acc.item())

    def visualize_recon(self):
        """Push the latest true/reconstructed image pair grid to visdom."""
        data = self.image_gather.data
        true_image = data['true'][0]
        recon_image = data['recon'][0]
        true_image = make_grid(true_image)
        recon_image = make_grid(recon_image)
        sample = torch.stack([true_image, recon_image], dim=0)
        self.viz.images(sample,
                        env=self.name + '/recon_image',
                        opts=dict(title=str(self.global_iter)))

    def visualize_line(self):
        """Append gathered scalar curves to the visdom line windows."""
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kld = torch.Tensor(data['kld'])
        D_acc = torch.Tensor(data['acc'])
        soft_D_z = torch.Tensor(data['soft_D_z'])
        soft_D_z_pperm = torch.Tensor(data['soft_D_z_pperm'])
        soft_D_zs = torch.stack([soft_D_z, soft_D_z_pperm], -1)
        self.viz.line(X=iters,
                      Y=soft_D_zs,
                      env=self.name + '/lines',
                      win=self.win_id['D_z'],
                      update='append',
                      opts=dict(xlabel='iteration',
                                ylabel='D(.)',
                                legend=['D(z)', 'D(z_perm)']))
        self.viz.line(X=iters,
                      Y=recon,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      update='append',
                      opts=dict(
                          xlabel='iteration',
                          ylabel='reconstruction loss',
                      ))
        self.viz.line(X=iters,
                      Y=D_acc,
                      env=self.name + '/lines',
                      win=self.win_id['acc'],
                      update='append',
                      opts=dict(
                          xlabel='iteration',
                          ylabel='discriminator accuracy',
                      ))
        self.viz.line(X=iters,
                      Y=kld,
                      env=self.name + '/lines',
                      win=self.win_id['kld'],
                      update='append',
                      opts=dict(
                          xlabel='iteration',
                          ylabel='kl divergence',
                      ))

    def visualize_traverse(self, limit=3, inter=2 / 3, loc=-1):
        """Render latent traversals to visdom (and optionally GIFs).

        Each latent dimension of a few fixed/random images is swept over
        [-limit, limit] in steps of `inter`; `loc` restricts the sweep to
        a single dimension (-1 = all).
        """
        self.net_mode(train=False)
        decoder = self.VAE.decode
        encoder = self.VAE.encode
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        random_img = self.data_loader.dataset.__getitem__(0)[1]
        random_img = random_img.to(self.device).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]

        if self.dataset.lower() == 'dsprites':
            fixed_idx1 = 87040  # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart
            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
            Z = {
                'fixed_square': fixed_img_z1,
                'fixed_ellipse': fixed_img_z2,
                'fixed_heart': fixed_img_z3,
                'random_img': random_img_z
            }
        elif self.dataset.lower() == 'celeba':
            fixed_idx1 = 70000  # 'CelebA/img_align_celeba/191282.jpg'
            fixed_idx2 = 143307  # 'CelebA/img_align_celeba/143308.jpg'
            fixed_idx3 = 101535  # 'CelebA/img_align_celeba/101536.jpg'
            fixed_idx4 = 70059  # 'CelebA/img_align_celeba/070060.jpg'
            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
            fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0]
            fixed_img4 = fixed_img4.to(self.device).unsqueeze(0)
            fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim]
            Z = {
                'fixed_1': fixed_img_z1,
                'fixed_2': fixed_img_z2,
                'fixed_3': fixed_img_z3,
                'fixed_4': fixed_img_z4,
                'random': random_img_z
            }
        elif self.dataset.lower() == '3dchairs':
            fixed_idx1 = 40919  # 3DChairs/images/4682_image_052_p030_t232_r096.png
            fixed_idx2 = 5172  # 3DChairs/images/14657_image_020_p020_t232_r096.png
            fixed_idx3 = 22330  # 3DChairs/images/30099_image_052_p030_t232_r096.png
            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
            Z = {
                'fixed_1': fixed_img_z1,
                'fixed_2': fixed_img_z2,
                'fixed_3': fixed_img_z3,
                'random': random_img_z
            }
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)[0]
            fixed_img = fixed_img.to(self.device).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]
            random_z = torch.rand(1, self.z_dim, 1, 1, device=self.device)
            Z = {
                'fixed_img': fixed_img_z,
                'random_img': random_img_z,
                'random_z': random_z
            }

        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(
                key, self.global_iter)
            self.viz.images(samples,
                            env=self.name + '/traverse',
                            opts=dict(title=title),
                            nrow=len(interpolation))

        if self.output_save:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            mkdirs(output_dir)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc,
                             64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(
                                   output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim,
                               pad_value=1)
                grid2gif(str(os.path.join(output_dir, key + '*.jpg')),
                         str(os.path.join(output_dir, key + '.gif')),
                         delay=10)
        self.net_mode(train=True)

    def viz_init(self):
        """Create the four empty visdom line windows used by the trainer."""
        zero_init = torch.zeros([1])
        self.viz.line(X=zero_init,
                      Y=torch.stack([zero_init, zero_init], -1),
                      env=self.name + '/lines',
                      win=self.win_id['D_z'],
                      opts=dict(xlabel='iteration',
                                ylabel='D(.)',
                                legend=['D(z)', 'D(z_perm)']))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      opts=dict(
                          xlabel='iteration',
                          ylabel='reconstruction loss',
                      ))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name + '/lines',
                      win=self.win_id['acc'],
                      opts=dict(
                          xlabel='iteration',
                          ylabel='discriminator accuracy',
                      ))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name + '/lines',
                      win=self.win_id['kld'],
                      opts=dict(
                          xlabel='iteration',
                          ylabel='kl divergence',
                      ))

    def net_mode(self, train):
        """Switch every tracked network to train() or eval() mode."""
        if not isinstance(train, bool):
            raise ValueError('Only bool type is supported. True|False')
        for net in self.nets:
            if train:
                net.train()
            else:
                net.eval()

    def save_checkpoint(self, ckptname='last', verbose=True):
        """Save VAE-stage model/optimizer states under ckpt_dir/vae."""
        model_states = {
            'D': self.D.state_dict(),
            'VAE': self.VAE.state_dict(),
            'PACLS': self.pacls.state_dict(),
            'REVCLS': self.revcls.state_dict(),
            'T_CLS': self.tcls.state_dict(),
            'T_REVCLS': self.trevcls.state_dict()
        }
        optim_states = {
            'optim_D': self.optim_D.state_dict(),
            'optim_VAE': self.optim_VAE.state_dict(),
            'optim_PACLS': self.optim_pacls.state_dict(),
            'optim_REVCLS': self.optim_revcls.state_dict(),
            'optim_TCLS': self.optim_tcls.state_dict(),
            'optim_TREVCLS': self.optim_trevcls.state_dict()
        }
        states = {
            'iter': self.global_iter,
            'model_states': model_states,
            'optim_states': optim_states
        }
        filepath = os.path.join(self.ckpt_dir + "/vae", str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(
                filepath, self.global_iter))

    def save_checkpoint_cls(self, ckptname='last', verbose=True):
        """Save classifier-stage model/optimizer states under ckpt_dir/cls."""
        model_states = {
            'D': self.D.state_dict(),
            'VAE': self.VAE.state_dict(),
            'PACLS': self.pacls.state_dict(),
            'REVCLS': self.revcls.state_dict(),
            'TCLS': self.targetcls.state_dict(),
            'VALCLS': self.pa_target.state_dict(),
            'T_CLS': self.tcls.state_dict(),
            'T_REVCLS': self.trevcls.state_dict()
        }
        optim_states = {
            'optim_D': self.optim_D.state_dict(),
            'optim_VAE': self.optim_VAE.state_dict(),
            'optim_Tcls': self.optim_cls.state_dict(),
            'optim_PACLS': self.optim_pacls.state_dict(),
            'optim_REVCLS': self.optim_revcls.state_dict(),
            'optim_TCLS': self.optim_tcls.state_dict(),
            'optim_TREVCLS': self.optim_trevcls.state_dict(),
            'optim_VALCLS': self.optim_pa_target.state_dict()
        }
        states = {
            'iter': self.global_iter_cls,
            'model_states': model_states,
            'optim_states': optim_states
        }
        filepath = os.path.join(self.ckpt_dir + "/cls", str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(
                filepath, self.global_iter_cls))

    def load_checkpoint(self, ckptname='last', verbose=True):
        """Restore VAE-stage states; 'last' picks the newest numbered file."""
        if ckptname == 'last':
            ckpts = os.listdir(self.ckpt_dir + '/vae')
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return
            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])
        filepath = os.path.join(self.ckpt_dir + '/vae', ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)
            self.global_iter = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.pacls.load_state_dict(checkpoint['model_states']['PACLS'])
            self.revcls.load_state_dict(checkpoint['model_states']['REVCLS'])
            self.tcls.load_state_dict(checkpoint['model_states']['T_CLS'])
            self.trevcls.load_state_dict(
                checkpoint['model_states']['T_REVCLS'])
            self.optim_VAE.load_state_dict(
                checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.optim_pacls.load_state_dict(
                checkpoint['optim_states']['optim_PACLS'])
            self.optim_revcls.load_state_dict(
                checkpoint['optim_states']['optim_REVCLS'])
            self.optim_tcls.load_state_dict(
                checkpoint['optim_states']['optim_TCLS'])
            self.optim_trevcls.load_state_dict(
                checkpoint['optim_states']['optim_TREVCLS'])
            self.pbar.update(self.global_iter)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(
                    filepath, self.global_iter))
        else:
            if verbose:
                self.pbar.write(
                    "=> no checkpoint found at '{}'".format(filepath))

    def load_checkpoint_cls(self, ckptname='last', verbose=True):
        """Restore classifier-stage states; 'last' picks the newest file."""
        if ckptname == 'last':
            ckpts = os.listdir(self.ckpt_dir + "/cls")
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return
            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])
        filepath = os.path.join(self.ckpt_dir + '/cls', ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)
            self.global_iter_cls = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.pacls.load_state_dict(checkpoint['model_states']['PACLS'])
            self.revcls.load_state_dict(checkpoint['model_states']['REVCLS'])
            self.targetcls.load_state_dict(checkpoint['model_states']['TCLS'])
            self.pa_target.load_state_dict(
                checkpoint['model_states']['VALCLS'])
            self.tcls.load_state_dict(checkpoint['model_states']['T_CLS'])
            self.trevcls.load_state_dict(
                checkpoint['model_states']['T_REVCLS'])
            self.optim_VAE.load_state_dict(
                checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.optim_pacls.load_state_dict(
                checkpoint['optim_states']['optim_PACLS'])
            self.optim_revcls.load_state_dict(
                checkpoint['optim_states']['optim_REVCLS'])
            self.optim_pa_target.load_state_dict(
                checkpoint['optim_states']['optim_VALCLS'])
            self.optim_tcls.load_state_dict(
                checkpoint['optim_states']['optim_TCLS'])
            self.optim_trevcls.load_state_dict(
                checkpoint['optim_states']['optim_TREVCLS'])
            self.pbar.update(self.global_iter_cls)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(
                    filepath, self.global_iter_cls))
        else:
            if verbose:
                self.pbar.write(
                    "=> no checkpoint found at '{}'".format(filepath))
class Trainer(object):
    """Trainer for a Wasserstein Auto-Encoder (MMD penalty variant).

    Minimizes ``recon_loss + lambda * MMD(q(z), p(z))`` where p(z) is an
    isotropic Gaussian with variance ``z_var``.  Logs losses and posterior
    statistics to visdom and periodically saves checkpoints.

    Fixes vs. previous revision:
      * removed a stray ``tqdm`` instance in :meth:`train` that was
        immediately shadowed by the ``with tqdm(...)`` context manager
        (it printed a duplicate, never-closed progress bar);
      * the ``win_mu`` *append* branch in :meth:`viz_lines` plotted the
        variances instead of the means (copy-paste bug).
    """

    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_epoch = args.max_epoch
        self.global_epoch = 0
        self.global_iter = 0

        # model / objective hyperparameters
        self.z_dim = args.z_dim
        self.z_var = args.z_var
        self.z_sigma = math.sqrt(args.z_var)  # std of the Gaussian prior
        self._lambda = args.reg_weight        # MMD penalty weight
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        # epoch -> divisor applied to the base lr at that epoch
        self.lr_schedules = {30: 2, 50: 5, 100: 10}

        if args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        net = WAE
        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.gather = DataGather()

        # visdom
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name + '_lines',
                                     port=self.viz_port)
        # window handles are (re)stored by save/load_checkpoint even when
        # visdom is off, so they are initialized unconditionally
        self.win_recon = None
        self.win_mmd = None
        self.win_mu = None
        self.win_var = None

        # checkpointing
        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        # outputs
        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

    def train(self):
        """Run the WAE-MMD optimization loop for ``max_epoch`` epochs."""
        self.net.train()

        iters_per_epoch = len(self.data_loader)
        max_iter = self.max_epoch * iters_per_epoch

        with tqdm(total=max_iter) as pbar:
            pbar.update(self.global_iter)
            out = False
            while not out:
                for x in self.data_loader:
                    pbar.update(1)
                    self.global_iter += 1
                    if self.global_iter % iters_per_epoch == 0:
                        self.global_epoch += 1
                        # step the lr schedule at epoch boundaries
                        self.optim = multistep_lr_decay(
                            self.optim, self.global_epoch, self.lr_schedules)

                    x = Variable(cuda(x, self.use_cuda))
                    x_recon, z_tilde = self.net(x)
                    # draw prior samples shaped like the encoded batch
                    z = self.sample_z(template=z_tilde, sigma=self.z_sigma)

                    recon_loss = F.mse_loss(
                        x_recon, x, size_average=False).div(self.batch_size)
                    mmd_loss = mmd(z_tilde, z, z_var=self.z_var)
                    total_loss = recon_loss + self._lambda * mmd_loss

                    self.optim.zero_grad()
                    total_loss.backward()
                    self.optim.step()

                    if self.global_iter % 1000 == 0:
                        self.gather.insert(
                            iter=self.global_iter,
                            mu=z.mean(0).data,
                            var=z.var(0).data,
                            recon_loss=recon_loss.data,
                            mmd_loss=mmd_loss.data,
                        )

                    if self.global_iter % 5000 == 0:
                        self.gather.insert(images=x.data)
                        self.gather.insert(images=x_recon.data)
                        self.viz_reconstruction()
                        self.viz_lines()
                        self.sample_x_from_z(n_sample=100)
                        self.gather.flush()
                        self.save_checkpoint('last')
                        pbar.write(
                            '[{}] total_loss:{:.3f} recon_loss:{:.3f} mmd_loss:{:.3f}'
                            .format(self.global_iter, total_loss.data[0],
                                    recon_loss.data[0], mmd_loss.data[0]))

                    if self.global_iter % 20000 == 0:
                        self.save_checkpoint(str(self.global_iter))

                    if self.global_iter >= max_iter:
                        out = True
                        break

            pbar.write("[Training Finished]")

    def viz_reconstruction(self):
        """Push a 10x10 grid of inputs next to their reconstructions to visdom."""
        self.net.eval()
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True, nrow=10)
        x_recon = F.sigmoid(self.gather.data['images'][1][:100])
        x_recon = make_grid(x_recon, normalize=True, nrow=10)
        images = torch.stack([x, x_recon], dim=0).cpu()
        self.viz.images(images,
                        env=self.viz_name + '_reconstruction',
                        opts=dict(title=str(self.global_iter)),
                        nrow=2)
        self.net.train()

    def viz_lines(self):
        """Append gathered loss / posterior-statistics curves to visdom."""
        self.net.eval()

        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        mmd_losses = torch.stack(self.gather.data['mmd_loss']).cpu()
        mus = torch.stack(self.gather.data['mu']).cpu()
        variances = torch.stack(self.gather.data['var']).cpu()
        iters = torch.Tensor(self.gather.data['iter'])

        legend = ['z_{}'.format(z_j) for z_j in range(self.z_dim)]

        if self.win_recon is None:
            self.win_recon = self.viz.line(
                X=iters, Y=recon_losses,
                env=self.viz_name + '_lines',
                opts=dict(width=400, height=400, xlabel='iteration',
                          title='reconsturction loss',))
        else:
            self.win_recon = self.viz.line(
                X=iters, Y=recon_losses,
                env=self.viz_name + '_lines',
                win=self.win_recon, update='append',
                opts=dict(width=400, height=400, xlabel='iteration',
                          title='reconsturction loss',))

        if self.win_mmd is None:
            self.win_mmd = self.viz.line(
                X=iters, Y=mmd_losses,
                env=self.viz_name + '_lines',
                opts=dict(width=400, height=400, xlabel='iteration',
                          title='maximum mean discrepancy',))
        else:
            self.win_mmd = self.viz.line(
                X=iters, Y=mmd_losses,
                env=self.viz_name + '_lines',
                win=self.win_mmd, update='append',
                opts=dict(width=400, height=400, xlabel='iteration',
                          title='maximum mean discrepancy',))

        if self.win_mu is None:
            self.win_mu = self.viz.line(
                X=iters, Y=mus,
                env=self.viz_name + '_lines',
                opts=dict(width=400, height=400, legend=legend,
                          xlabel='iteration', title='posterior mean',))
        else:
            # BUGFIX: was Y=variances, which plotted the variance curve in
            # the "posterior mean" window on every append
            self.win_mu = self.viz.line(
                X=iters, Y=mus,
                env=self.viz_name + '_lines',
                win=self.win_mu, update='append',
                opts=dict(width=400, height=400, legend=legend,
                          xlabel='iteration', title='posterior mean',))

        if self.win_var is None:
            self.win_var = self.viz.line(
                X=iters, Y=variances,
                env=self.viz_name + '_lines',
                opts=dict(width=400, height=400, legend=legend,
                          xlabel='iteration', title='posterior variance',))
        else:
            self.win_var = self.viz.line(
                X=iters, Y=variances,
                env=self.viz_name + '_lines',
                win=self.win_var, update='append',
                opts=dict(width=400, height=400, legend=legend,
                          xlabel='iteration', title='posterior variance',))

        self.net.train()

    def sample_z(self, n_sample=None, dim=None, sigma=None, template=None):
        """Sample from the Gaussian prior N(0, sigma^2 I).

        If ``template`` is given, the sample matches its size/device;
        otherwise an (n_sample x dim) tensor is drawn and moved to the
        configured device.
        """
        if n_sample is None:
            n_sample = self.batch_size
        if dim is None:
            dim = self.z_dim
        if sigma is None:
            sigma = self.z_sigma

        if template is not None:
            z = sigma * Variable(template.data.new(template.size()).normal_())
        else:
            z = sigma * torch.randn(n_sample, dim)
            z = Variable(cuda(z, self.use_cuda))

        return z

    def sample_x_from_z(self, n_sample):
        """Decode prior samples and push the image grid to visdom."""
        self.net.eval()
        z = self.sample_z(n_sample=n_sample, sigma=self.z_sigma)
        x_gen = F.sigmoid(self.net._decode(z)[:100]).data.cpu()
        x_gen = make_grid(x_gen, normalize=True, nrow=10)
        self.viz.images(x_gen,
                        env=self.viz_name + '_sampling_from_random_z',
                        opts=dict(title=str(self.global_iter)))
        self.net.train()

    def save_checkpoint(self, filename, silent=True):
        """Serialize net/optimizer state and visdom window ids to ckpt_dir."""
        model_states = {'net': self.net.state_dict(), }
        optim_states = {'optim': self.optim.state_dict(), }
        win_states = {
            'recon': self.win_recon,
            'mmd': self.win_mmd,
            'mu': self.win_mu,
            'var': self.win_var,
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'win_states': win_states,
            'model_states': model_states,
            'optim_states': optim_states
        }

        file_path = self.ckpt_dir.joinpath(filename)
        torch.save(states, file_path.open('wb+'))
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(
                file_path, self.global_iter))

    def load_checkpoint(self, filename, silent=False):
        """Restore the state written by :meth:`save_checkpoint`."""
        file_path = self.ckpt_dir.joinpath(filename)
        if file_path.is_file():
            checkpoint = torch.load(file_path.open('rb'))
            self.global_iter = checkpoint['iter']
            self.global_epoch = checkpoint['epoch']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_mmd = checkpoint['win_states']['mmd']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            if not silent:
                print("=> loaded checkpoint '{} (iter {})'".format(
                    file_path, self.global_iter))
        else:
            if not silent:
                print("=> no checkpoint found at '{}'".format(file_path))
class Solver(object): #### def __init__(self, args): self.args = args self.name = ( '%s_gamma_%s_zDim_%s' + \ '_lrVAE_%s_lrD_%s_rseed_%s' ) % \ ( args.dataset, args.gamma, args.z_dim, args.lr_VAE, args.lr_D, args.rseed ) # to be appended by run_id self.use_cuda = args.cuda and torch.cuda.is_available() self.max_iter = int(args.max_iter) # do it every specified iters self.print_iter = args.print_iter self.ckpt_save_iter = args.ckpt_save_iter self.output_save_iter = args.output_save_iter # data info self.dset_dir = args.dset_dir self.dataset = args.dataset if args.dataset.endswith('dsprites'): self.nc = 1 elif args.dataset == '3dfaces': self.nc = 1 else: self.nc = 3 # groundtruth factor labels (only available for "dsprites") if self.dataset == 'dsprites': # latent factor = (color, shape, scale, orient, pos-x, pos-y) # color = {1} (1) # shape = {1=square, 2=oval, 3=heart} (3) # scale = {0.5, 0.6, ..., 1.0} (6) # orient = {2*pi*(k/39)}_{k=0}^39 (40) # pos-x = {k/31}_{k=0}^31 (32) # pos-y = {k/31}_{k=0}^31 (32) # (number of variations = 1*3*6*40*32*32 = 737280) latent_values = np.load(os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_values.npy'), encoding='latin1') self.latent_values = latent_values[:, [1, 2, 3, 4, 5]] # latent values (actual values);(737280 x 5) latent_classes = np.load(os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_classes.npy'), encoding='latin1') self.latent_classes = latent_classes[:, [1, 2, 3, 4, 5]] # classes ({0,1,...,K}-valued); (737280 x 5) self.latent_sizes = np.array([3, 6, 40, 32, 32]) self.N = self.latent_values.shape[0] if args.eval_metrics: self.eval_metrics = True self.eval_metrics_iter = args.eval_metrics_iter # groundtruth factor labels elif self.dataset == 'oval_dsprites': latent_classes = np.load(os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_classes.npy'), encoding='latin1') idx = np.where(latent_classes[:, 1] == 1)[0] # "oval" shape only self.latent_classes = latent_classes[idx, :] self.latent_classes = 
self.latent_classes[:, [2, 3, 4, 5]] # classes ({0,1,...,K}-valued); (245760 x 4) latent_values = np.load(os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_values.npy'), encoding='latin1') self.latent_values = latent_values[idx, :] self.latent_values = self.latent_values[:, [2, 3, 4, 5]] # latent values (actual values);(245760 x 4) self.latent_sizes = np.array([6, 40, 32, 32]) self.N = self.latent_values.shape[0] if args.eval_metrics: self.eval_metrics = True self.eval_metrics_iter = args.eval_metrics_iter # groundtruth factor labels elif self.dataset == '3dfaces': # latent factor = (id, azimuth, elevation, lighting) # id = {0,1,...,49} (50) # azimuth = {-1.0,-0.9,...,0.9,1.0} (21) # elevation = {-1.0,0.8,...,0.8,1.0} (11) # lighting = {-1.0,0.8,...,0.8,1.0} (11) # (number of variations = 50*21*11*11 = 127050) latent_classes, latent_values = np.load( os.path.join(self.dset_dir, '3d_faces/rtqichen/gt_factor_labels.npy')) self.latent_values = latent_values # latent values (actual values);(127050 x 4) self.latent_classes = latent_classes # classes ({0,1,...,K}-valued); (127050 x 4) self.latent_sizes = np.array([50, 21, 11, 11]) self.N = self.latent_values.shape[0] if args.eval_metrics: self.eval_metrics = True self.eval_metrics_iter = args.eval_metrics_iter elif self.dataset == 'celeba': self.N = 202599 self.eval_metrics = False elif self.dataset == 'edinburgh_teapots': # latent factor = (azimuth, elevation, R, G, B) # azimuth = [0, 2*pi] # elevation = [0, pi/2] # R, G, B = [0,1] # # "latent_values" = original (real) factor values # "latent_classes" = equal binning into K=10 classes # # (refer to "data/edinburgh_teapots/my_make_split_data.py") K = 10 val_ranges = [2 * np.pi, np.pi / 2, 1, 1, 1] bins = [] for j in range(5): bins.append(np.linspace(0, val_ranges[j], K + 1)) latent_values = np.load( os.path.join(self.dset_dir, 'edinburgh_teapots', 'gtfs_tr.npz'))['data'] latent_values = np.concatenate( (latent_values, np.load( os.path.join(self.dset_dir, 
'edinburgh_teapots', 'gtfs_va.npz'))['data']), axis=0) latent_values = np.concatenate( (latent_values, np.load( os.path.join(self.dset_dir, 'edinburgh_teapots', 'gtfs_te.npz'))['data']), axis=0) self.latent_values = latent_values latent_classes = np.zeros(latent_values.shape) for j in range(5): latent_classes[:, j] = np.digitize(latent_values[:, j], bins[j]) self.latent_classes = latent_classes - 1 # {0,...,K-1}-valued self.latent_sizes = K * np.ones(5, 'int64') self.N = self.latent_values.shape[0] if args.eval_metrics: self.eval_metrics = True self.eval_metrics_iter = args.eval_metrics_iter # networks and optimizers self.batch_size = args.batch_size self.z_dim = args.z_dim self.gamma = args.gamma self.lr_VAE = args.lr_VAE self.beta1_VAE = args.beta1_VAE self.beta2_VAE = args.beta2_VAE self.lr_D = args.lr_D self.beta1_D = args.beta1_D self.beta2_D = args.beta2_D # visdom setup self.viz_on = args.viz_on if self.viz_on: self.win_id = dict(DZ='win_DZ', recon='win_recon', kl='win_kl', kl_alpha='win_kl_alpha') self.line_gather = DataGather('iter', 'p_DZ', 'p_DZ_perm', 'recon', 'kl', 'kl_alpha') if self.eval_metrics: self.win_id['metrics'] = 'win_metrics' import visdom self.viz_port = args.viz_port # port number, eg, 8097 self.viz = visdom.Visdom(port=self.viz_port) self.viz_ll_iter = args.viz_ll_iter self.viz_la_iter = args.viz_la_iter self.viz_init() # create dirs: "records", "ckpts", "outputs" (if not exist) mkdirs("records") mkdirs("ckpts") mkdirs("outputs") # set run id if args.run_id < 0: # create a new id k = 0 rfname = os.path.join("records", self.name + '_run_0.txt') while os.path.exists(rfname): k += 1 rfname = os.path.join("records", self.name + '_run_%d.txt' % k) self.run_id = k else: # user-provided id self.run_id = args.run_id # finalize name self.name = self.name + '_run_' + str(self.run_id) # records (text file to store console outputs) self.record_file = 'records/%s.txt' % self.name # checkpoints self.ckpt_dir = os.path.join("ckpts", self.name) # outputs 
self.output_dir_recon = os.path.join("outputs", self.name + '_recon') # dir for reconstructed images self.output_dir_synth = os.path.join("outputs", self.name + '_synth') # dir for synthesized images self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl') # dir for latent traversed images #### create a new model or load a previously saved model self.ckpt_load_iter = args.ckpt_load_iter if self.ckpt_load_iter == 0: # create a new model # create a vae model if args.dataset.endswith('dsprites'): self.encoder = Encoder1(self.z_dim) self.decoder = Decoder1(self.z_dim) elif args.dataset == '3dfaces': self.encoder = Encoder3(self.z_dim) self.decoder = Decoder3(self.z_dim) elif args.dataset == 'celeba': self.encoder = Encoder4(self.z_dim) self.decoder = Decoder4(self.z_dim) elif args.dataset.endswith('teapots'): # self.encoder = Encoder4(self.z_dim) # self.decoder = Decoder4(self.z_dim) self.encoder = Encoder_ResNet(self.z_dim) self.decoder = Decoder_ResNet(self.z_dim) else: pass #self.VAE = FactorVAE2(self.z_dim) # create a prior alpha model self.prior_alpha = PriorAlphaParams(self.z_dim) # create a posterior alpha model self.post_alpha = PostAlphaParams(self.z_dim) # create a discriminator model self.D = Discriminator(self.z_dim) else: # load a previously saved model print('Loading saved models (iter: %d)...' 
% self.ckpt_load_iter) self.load_checkpoint() print('...done') if self.use_cuda: print('Models moved to GPU...') self.encoder = self.encoder.cuda() self.decoder = self.decoder.cuda() self.prior_alpha = self.prior_alpha.cuda() self.post_alpha = self.post_alpha.cuda() self.D = self.D.cuda() print('...done') # get VAE parameters vae_params = list(self.encoder.parameters()) + \ list(self.decoder.parameters()) + \ list(self.prior_alpha.parameters()) + \ list(self.post_alpha.parameters()) # get discriminator parameters dis_params = list(self.D.parameters()) # create optimizers self.optim_vae = optim.Adam(vae_params, lr=self.lr_VAE, betas=[self.beta1_VAE, self.beta2_VAE]) self.optim_dis = optim.Adam(dis_params, lr=self.lr_D, betas=[self.beta1_D, self.beta2_D]) #### def train(self): self.set_mode(train=True) ones = torch.ones(self.batch_size, dtype=torch.long) zeros = torch.zeros(self.batch_size, dtype=torch.long) if self.use_cuda: ones = ones.cuda() zeros = zeros.cuda() # prepare dataloader (iterable) print('Start loading data...') self.data_loader = create_dataloader(self.args) print('...done') # iterators from dataloader iterator1 = iter(self.data_loader) iterator2 = iter(self.data_loader) iter_per_epoch = min(len(iterator1), len(iterator2)) start_iter = self.ckpt_load_iter + 1 epoch = int(start_iter / iter_per_epoch) for iteration in range(start_iter, self.max_iter + 1): # reset data iterators for each epoch if iteration % iter_per_epoch == 0: print('==== epoch %d done ====' % epoch) epoch += 1 iterator1 = iter(self.data_loader) iterator2 = iter(self.data_loader) #============================================ # TRAIN THE VAE (ENC & DEC) #============================================ # sample a mini-batch X, ids = next(iterator1) # (n x C x H x W) if self.use_cuda: X = X.cuda() # enc(X) mu, std, logvar = self.encoder(X) # prior alpha params a, b = self.prior_alpha() # posterior alpha params ah, bh = self.post_alpha() # kl loss kls = 0.5 * ( \ (ah/bh)*(mu**2+std**2) - 1.0 
+ \ bh.log() - ah.digamma() - logvar ) # (n x z_dim) loss_kl = kls.sum(1).mean() # kl loss on alpha kls_alpha = ( \ (ah-a)*ah.digamma() - ah.lgamma() + a.lgamma() + \ a*(bh.log()-b.log()) + (ah/bh)*(b-bh) ) # z_dim-dim loss_kl_alpha = kls_alpha.sum() / self.N # reparam'ed samples if self.use_cuda: Eps = torch.cuda.FloatTensor(mu.shape).normal_() else: Eps = torch.randn(mu.shape) Z = mu + Eps * std # dec(Z) X_recon = self.decoder(Z) # recon loss loss_recon = F.binary_cross_entropy_with_logits( X_recon, X, reduction='sum').div(X.size(0)) # dis(Z) DZ = self.D(Z) # tc loss loss_tc = (DZ[:, 0] - DZ[:, 1]).mean() # total loss for vae vae_loss = loss_recon + loss_kl + loss_kl_alpha + \ self.gamma*loss_tc # update vae self.optim_vae.zero_grad() vae_loss.backward() self.optim_vae.step() #============================================ # TRAIN THE DISCRIMINATOR #============================================ # sample a mini-batch X2, ids = next(iterator2) # (n x C x H x W) if self.use_cuda: X2 = X2.cuda() # enc(X2) mu, std, _ = self.encoder(X2) # reparam'ed samples if self.use_cuda: Eps = torch.cuda.FloatTensor(mu.shape).normal_() else: Eps = torch.randn(mu.shape) Z = mu + Eps * std # dis(Z) DZ = self.D(Z) # dim-wise permutated Z over the mini-batch perm_Z = [] for zj in Z.split(1, 1): idx = torch.randperm(Z.size(0)) perm_zj = zj[idx] perm_Z.append(perm_zj) Z_perm = torch.cat(perm_Z, 1) Z_perm = Z_perm.detach() # dis(Z_perm) DZ_perm = self.D(Z_perm) # discriminator loss dis_loss = 0.5 * (F.cross_entropy(DZ, zeros) + F.cross_entropy(DZ_perm, ones)) # update discriminator self.optim_dis.zero_grad() dis_loss.backward() self.optim_dis.step() ########################################## # print the losses if iteration % self.print_iter == 0: prn_str = ( '[iter %d (epoch %d)] vae_loss: %.3f | ' + \ 'dis_loss: %.3f\n ' + \ '(recon: %.3f, kl: %.3f, kl_alpha: %.3f, tc: %.3f)' \ ) % \ ( iteration, epoch, vae_loss.item(), dis_loss.item(), loss_recon.item(), loss_kl.item(), 
loss_kl_alpha.item(), loss_tc.item() ) prn_str += '\n a = {}'.format( a.detach().cpu().numpy().round(2)) prn_str += '\n b = {}'.format( b.detach().cpu().numpy().round(2)) prn_str += '\n ah = {}'.format( ah.detach().cpu().numpy().round(2)) prn_str += '\n bh = {}'.format( bh.detach().cpu().numpy().round(2)) print(prn_str) if self.record_file: record = open(self.record_file, 'a') record.write('%s\n' % (prn_str, )) record.close() # save model parameters if iteration % self.ckpt_save_iter == 0: self.save_checkpoint(iteration) # save output images (recon, synth, etc.) if iteration % self.output_save_iter == 0: # 1) save the recon images self.save_recon(iteration, X, torch.sigmoid(X_recon).data) # 2) save the synth images self.save_synth(iteration, howmany=100) # 3) save the latent traversed images if self.dataset.lower() == '3dchairs': self.save_traverse(iteration, limb=-2, limu=2, inter=0.5) else: self.save_traverse(iteration, limb=-3, limu=3, inter=0.1) # (visdom) insert current line stats if self.viz_on and (iteration % self.viz_ll_iter == 0): # compute discriminator accuracy p_DZ = F.softmax(DZ, 1)[:, 0].detach() p_DZ_perm = F.softmax(DZ_perm, 1)[:, 0].detach() # insert line stats self.line_gather.insert(iter=iteration, p_DZ=p_DZ.mean().item(), p_DZ_perm=p_DZ_perm.mean().item(), recon=loss_recon.item(), kl=loss_kl.item(), kl_alpha=loss_kl_alpha.item()) # (visdom) visualize line stats (then flush out) if self.viz_on and (iteration % self.viz_la_iter == 0): self.visualize_line() self.line_gather.flush() # evaluate metrics if self.eval_metrics and (iteration % self.eval_metrics_iter == 0): metric1, _ = self.eval_disentangle_metric1() metric2, _ = self.eval_disentangle_metric2() prn_str = ( '********\n[iter %d (epoch %d)] ' + \ 'metric1 = %.4f, metric2 = %.4f\n********' ) % \ (iteration, epoch, metric1, metric2) print(prn_str) if self.record_file: record = open(self.record_file, 'a') record.write('%s\n' % (prn_str, )) record.close() # (visdom) visulaize metrics if 
self.viz_on: self.visualize_line_metrics(iteration, metric1, metric2) #### def eval_disentangle_metric1(self): # some hyperparams num_pairs = 800 # # data pairs (d,y) for majority vote classification bs = 50 # batch size nsamps_per_factor = 100 # samples per factor nsamps_agn_factor = 5000 # factor-agnostic samples self.set_mode(train=False) # 1) estimate variances of latent points factor agnostic dl = DataLoader(self.data_loader.dataset, batch_size=bs, shuffle=True, num_workers=self.args.num_workers, pin_memory=True) iterator = iter(dl) M = [] for ib in range(int(nsamps_agn_factor / bs)): # sample a mini-batch Xb, _ = next(iterator) # (bs x C x H x W) if self.use_cuda: Xb = Xb.cuda() # enc(Xb) mub, _, _ = self.encoder(Xb) # (bs x z_dim) M.append(mub.cpu().detach().numpy()) M = np.concatenate(M, 0) # estimate sample vairance and mean of latent points for each dim vars_agn_factor = np.var(M, 0) # 2) estimatet dim-wise vars of latent points with "one factor fixed" factor_ids = range(0, len(self.latent_sizes)) # true factor ids vars_per_factor = np.zeros([num_pairs, self.z_dim]) true_factor_ids = np.zeros(num_pairs, np.int) # true factor ids # prepare data pairs for majority-vote classification i = 0 for j in factor_ids: # for each factor # repeat num_paris/num_factors times for r in range(int(num_pairs / len(factor_ids))): # a true factor (id and class value) to fix fac_id = j fac_class = np.random.randint(self.latent_sizes[fac_id]) # randomly select images (with the fixed factor) indices = np.where(self.latent_classes[:, fac_id] == fac_class)[0] np.random.shuffle(indices) idx = indices[:nsamps_per_factor] M = [] for ib in range(int(nsamps_per_factor / bs)): Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]] if Xb.shape[0] < 1: # no more samples continue if self.use_cuda: Xb = Xb.cuda() mub, _, _ = self.encoder(Xb) # (bs x z_dim) M.append(mub.cpu().detach().numpy()) M = np.concatenate(M, 0) # estimate sample var and mean of latent points for each dim if M.shape[0] >= 
2: vars_per_factor[i, :] = np.var(M, 0) else: # not enough samples to estimate variance vars_per_factor[i, :] = 0.0 # true factor id (will become the class label) true_factor_ids[i] = fac_id i += 1 # 3) evaluate majority vote classification accuracy # inputs in the paired data for classification smallest_var_dims = np.argmin(vars_per_factor / (vars_agn_factor + 1e-20), axis=1) # contingency table C = np.zeros([self.z_dim, len(factor_ids)]) for i in range(num_pairs): C[smallest_var_dims[i], true_factor_ids[i]] += 1 num_errs = 0 # # misclassifying errors of majority vote classifier for k in range(self.z_dim): num_errs += np.sum(C[k, :]) - np.max(C[k, :]) metric1 = (num_pairs - num_errs) / num_pairs # metric = accuracy self.set_mode(train=True) return metric1, C #### def eval_disentangle_metric2(self): # some hyperparams num_pairs = 800 # # data pairs (d,y) for majority vote classification bs = 50 # batch size nsamps_per_factor = 100 # samples per factor nsamps_agn_factor = 5000 # factor-agnostic samples self.set_mode(train=False) # 1) estimate variances of latent points factor agnostic dl = DataLoader(self.data_loader.dataset, batch_size=bs, shuffle=True, num_workers=self.args.num_workers, pin_memory=True) iterator = iter(dl) M = [] for ib in range(int(nsamps_agn_factor / bs)): # sample a mini-batch Xb, _ = next(iterator) # (bs x C x H x W) if self.use_cuda: Xb = Xb.cuda() # enc(Xb) mub, _, _ = self.encoder(Xb) # (bs x z_dim) M.append(mub.cpu().detach().numpy()) M = np.concatenate(M, 0) # estimate sample vairance and mean of latent points for each dim vars_agn_factor = np.var(M, 0) # 2) estimatet dim-wise vars of latent points with "one factor varied" factor_ids = range(0, len(self.latent_sizes)) # true factor ids vars_per_factor = np.zeros([num_pairs, self.z_dim]) true_factor_ids = np.zeros(num_pairs, np.int) # true factor ids # prepare data pairs for majority-vote classification i = 0 for j in factor_ids: # for each factor # repeat num_paris/num_factors times for r 
in range(int(num_pairs / len(factor_ids))): # randomly choose true factors (id's and class values) to fix fac_ids = list(np.setdiff1d(factor_ids, j)) fac_classes = \ [ np.random.randint(self.latent_sizes[k]) for k in fac_ids ] # randomly select images (with the other factors fixed) if len(fac_ids) > 1: indices = np.where( np.sum(self.latent_classes[:, fac_ids] == fac_classes, 1) == len(fac_ids))[0] else: indices = np.where( self.latent_classes[:, fac_ids] == fac_classes)[0] np.random.shuffle(indices) idx = indices[:nsamps_per_factor] M = [] for ib in range(int(nsamps_per_factor / bs)): Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]] if Xb.shape[0] < 1: # no more samples continue if self.use_cuda: Xb = Xb.cuda() mub, _, _ = self.encoder(Xb) # (bs x z_dim) M.append(mub.cpu().detach().numpy()) M = np.concatenate(M, 0) # estimate sample var and mean of latent points for each dim if M.shape[0] >= 2: vars_per_factor[i, :] = np.var(M, 0) else: # not enough samples to estimate variance vars_per_factor[i, :] = 0.0 # true factor id (will become the class label) true_factor_ids[i] = j i += 1 # 3) evaluate majority vote classification accuracy # inputs in the paired data for classification largest_var_dims = np.argmax(vars_per_factor / (vars_agn_factor + 1e-20), axis=1) # contingency table C = np.zeros([self.z_dim, len(factor_ids)]) for i in range(num_pairs): C[largest_var_dims[i], true_factor_ids[i]] += 1 num_errs = 0 # # misclassifying errors of majority vote classifier for k in range(self.z_dim): num_errs += np.sum(C[k, :]) - np.max(C[k, :]) metric2 = (num_pairs - num_errs) / num_pairs # metric = accuracy self.set_mode(train=True) return metric2, C #### def save_recon(self, iters, true_images, recon_images): # make a merge of true and recon, eg, # merged[0,...] = true[0,...], # merged[1,...] = recon[0,...], # merged[2,...] = true[1,...], # merged[3,...] = recon[1,...], ... 
n = true_images.shape[0] perm = torch.arange(0, 2 * n).view(2, n).transpose(1, 0) perm = perm.contiguous().view(-1) merged = torch.cat([true_images, recon_images], dim=0) merged = merged[perm, :].cpu() # save the results as image fname = os.path.join(self.output_dir_recon, 'recon_%s.jpg' % iters) mkdirs(self.output_dir_recon) save_image(tensor=merged, filename=fname, nrow=2 * int(np.sqrt(n)), pad_value=1) #### def save_synth(self, iters, howmany=100): self.set_mode(train=False) decoder = self.decoder Z = torch.randn(howmany, self.z_dim) if self.use_cuda: Z = Z.cuda() # do synthesis X = torch.sigmoid(decoder(Z)).data.cpu() # save the results as image fname = os.path.join(self.output_dir_synth, 'synth_%s.jpg' % iters) mkdirs(self.output_dir_synth) save_image(tensor=X, filename=fname, nrow=int(np.sqrt(howmany)), pad_value=1) self.set_mode(train=True) #### def save_traverse(self, iters, limb=-3, limu=3, inter=2 / 3, loc=-1): self.set_mode(train=False) encoder = self.encoder decoder = self.decoder interpolation = torch.arange(limb, limu + 0.001, inter) i = np.random.randint(self.N) random_img = self.data_loader.dataset.__getitem__(i)[0] if self.use_cuda: random_img = random_img.cuda() random_img = random_img.unsqueeze(0) random_img_zmu, _, _ = encoder(random_img) if self.dataset.lower() == 'dsprites': fixed_idx1 = 87040 # square fixed_idx2 = 332800 # ellipse fixed_idx3 = 578560 # heart fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] if self.use_cuda: fixed_img1 = fixed_img1.cuda() fixed_img1 = fixed_img1.unsqueeze(0) fixed_img_zmu1, _, _ = encoder(fixed_img1) fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] if self.use_cuda: fixed_img2 = fixed_img2.cuda() fixed_img2 = fixed_img2.unsqueeze(0) fixed_img_zmu2, _, _ = encoder(fixed_img2) fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] if self.use_cuda: fixed_img3 = fixed_img3.cuda() fixed_img3 = fixed_img3.unsqueeze(0) fixed_img_zmu3, _, _ = encoder(fixed_img3) IMG = { 
'fixed_square': fixed_img1, 'fixed_ellipse': fixed_img2, 'fixed_heart': fixed_img3, 'random_img': random_img } Z = { 'fixed_square': fixed_img_zmu1, 'fixed_ellipse': fixed_img_zmu2, 'fixed_heart': fixed_img_zmu3, 'random_img': random_img_zmu } elif self.dataset.lower() == 'oval_dsprites': fixed_idx1 = 87040 # oval1 fixed_idx2 = 220045 # oval2 fixed_idx3 = 178560 # oval3 fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] if self.use_cuda: fixed_img1 = fixed_img1.cuda() fixed_img1 = fixed_img1.unsqueeze(0) fixed_img_zmu1, _, _ = encoder(fixed_img1) fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] if self.use_cuda: fixed_img2 = fixed_img2.cuda() fixed_img2 = fixed_img2.unsqueeze(0) fixed_img_zmu2, _, _ = encoder(fixed_img2) fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] if self.use_cuda: fixed_img3 = fixed_img3.cuda() fixed_img3 = fixed_img3.unsqueeze(0) fixed_img_zmu3, _, _ = encoder(fixed_img3) IMG = { 'fixed1': fixed_img1, 'fixed2': fixed_img2, 'fixed3': fixed_img3, 'random_img': random_img } Z = { 'fixed1': fixed_img_zmu1, 'fixed2': fixed_img_zmu2, 'fixed3': fixed_img_zmu3, 'random_img': random_img_zmu } elif self.dataset.lower() == '3dfaces': fixed_idx1 = 6245 fixed_idx2 = 10205 fixed_idx3 = 68560 fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] if self.use_cuda: fixed_img1 = fixed_img1.cuda() fixed_img1 = fixed_img1.unsqueeze(0) fixed_img_zmu1, _, _ = encoder(fixed_img1) fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] if self.use_cuda: fixed_img2 = fixed_img2.cuda() fixed_img2 = fixed_img2.unsqueeze(0) fixed_img_zmu2, _, _ = encoder(fixed_img2) fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] if self.use_cuda: fixed_img3 = fixed_img3.cuda() fixed_img3 = fixed_img3.unsqueeze(0) fixed_img_zmu3, _, _ = encoder(fixed_img3) IMG = { 'fixed1': fixed_img1, 'fixed2': fixed_img2, 'fixed3': fixed_img3, 'random_img': random_img } Z = { 'fixed1': fixed_img_zmu1, 'fixed2': 
fixed_img_zmu2, 'fixed3': fixed_img_zmu3, 'random_img': random_img_zmu } elif self.dataset.lower() == 'celeba': fixed_idx1 = 191281 fixed_idx2 = 143307 fixed_idx3 = 101535 fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] if self.use_cuda: fixed_img1 = fixed_img1.cuda() fixed_img1 = fixed_img1.unsqueeze(0) fixed_img_zmu1, _, _ = encoder(fixed_img1) fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] if self.use_cuda: fixed_img2 = fixed_img2.cuda() fixed_img2 = fixed_img2.unsqueeze(0) fixed_img_zmu2, _, _ = encoder(fixed_img2) fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] if self.use_cuda: fixed_img3 = fixed_img3.cuda() fixed_img3 = fixed_img3.unsqueeze(0) fixed_img_zmu3, _, _ = encoder(fixed_img3) IMG = { 'fixed1': fixed_img1, 'fixed2': fixed_img2, 'fixed3': fixed_img3, 'random_img': random_img } Z = { 'fixed1': fixed_img_zmu1, 'fixed2': fixed_img_zmu2, 'fixed3': fixed_img_zmu3, 'random_img': random_img_zmu } elif self.dataset.lower() == 'edinburgh_teapots': fixed_idx1 = 7040 fixed_idx2 = 32800 fixed_idx3 = 78560 fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] if self.use_cuda: fixed_img1 = fixed_img1.cuda() fixed_img1 = fixed_img1.unsqueeze(0) fixed_img_zmu1, _, _ = encoder(fixed_img1) fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] if self.use_cuda: fixed_img2 = fixed_img2.cuda() fixed_img2 = fixed_img2.unsqueeze(0) fixed_img_zmu2, _, _ = encoder(fixed_img2) fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] if self.use_cuda: fixed_img3 = fixed_img3.cuda() fixed_img3 = fixed_img3.unsqueeze(0) fixed_img_zmu3, _, _ = encoder(fixed_img3) IMG = { 'fixed1': fixed_img1, 'fixed2': fixed_img2, 'fixed3': fixed_img3, 'random_img': random_img } Z = { 'fixed1': fixed_img_zmu1, 'fixed2': fixed_img_zmu2, 'fixed3': fixed_img_zmu3, 'random_img': random_img_zmu } # elif self.dataset.lower() == '3dchairs': # # fixed_idx1 = 40919 # 3DChairs/images/4682_image_052_p030_t232_r096.png # 
# fixed_idx2 = 5172 # 3DChairs/images/14657_image_020_p020_t232_r096.png
        # fixed_idx3 = 22330 # 3DChairs/images/30099_image_052_p030_t232_r096.png
        #
        # fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
        # fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
        # fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
        #
        # fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
        # fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
        # fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
        #
        # fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
        # fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
        # fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
        #
        # Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
        #      'fixed_3':fixed_img_z3, 'random':random_img_zmu}
        #
        else:
            raise NotImplementedError

        # do traversal and collect generated images
        gifs = []
        for key in Z:
            z_ori = Z[key]
            for row in range(self.z_dim):
                # when a single latent location is requested, skip every other row
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    gifs.append(sample)

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters))
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        # reshape flat sample list to (#keys, z_dim, #interp, C, H, W), then swap
        # so the interpolation step indexes the second axis
        gifs = gifs.view(len(Z), self.z_dim, len(interpolation),
                         self.nc, 64, 64).transpose(1, 2)
        for i, key in enumerate(Z.keys()):
            for j, val in enumerate(interpolation):
                # prepend the source image to its traversal row before saving
                I = torch.cat([IMG[key], gifs[i][j]], dim=0)
                save_image(tensor=I.cpu(),
                           filename=os.path.join(out_dir, '%s_%03d.jpg' % (key, j)),
                           nrow=1 + self.z_dim, pad_value=1)
            # make animated gif
            grid2gif(out_dir, key,
                     str(os.path.join(out_dir, key + '.gif')), delay=10)

        self.set_mode(train=True)

    ####
    def viz_init(self):
        """Close any stale visdom line windows for this run's environment."""
        self.viz.close(env=self.name + '/lines', win=self.win_id['DZ'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl_alpha'])
        if self.eval_metrics:
            self.viz.close(env=self.name + '/lines', win=self.win_id['metrics'])

    ####
    def visualize_line(self):
        """Append the gathered scalar curves (recon, KLs, D(z)) to visdom."""
        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kl = torch.Tensor(data['kl'])
        kl_alpha = torch.Tensor(data['kl_alpha'])
        p_DZ = torch.Tensor(data['p_DZ'])
        p_DZ_perm = torch.Tensor(data['p_DZ_perm'])
        # plot real and permuted discriminator outputs on the same window
        p_DZs = torch.stack([p_DZ, p_DZ_perm], -1)  # (#items x 2)

        self.viz.line(X=iters, Y=p_DZs,
                      env=self.name + '/lines',
                      win=self.win_id['DZ'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='D(z)',
                                title='Discriminator-Z',
                                legend=['D(z)', 'D(z_perm)', ]))

        self.viz.line(X=iters, Y=recon,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='recon loss',
                                title='Reconstruction'))

        self.viz.line(X=iters, Y=kl,
                      env=self.name + '/lines',
                      win=self.win_id['kl'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='E_q(alpha)E_x[kl(q(z|x)||p(z|alpha)]',
                                title='KL divergence'))

        self.viz.line(X=iters, Y=kl_alpha,
                      env=self.name + '/lines',
                      win=self.win_id['kl_alpha'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='KL(q(alpha)||p(alpha)) / N',
                                title='KL divergence on alpha'))

    ####
    def visualize_line_metrics(self, iters, metric1, metric2):
        """Append a pair of disentanglement-metric scalars at `iters` to visdom."""
        # prepare data to plot
        iters = torch.tensor([iters], dtype=torch.int64).detach()
        metric1 = torch.tensor([metric1])
        metric2 = torch.tensor([metric2])
        metrics = torch.stack([metric1.detach(), metric2.detach()], -1)

        self.viz.line(X=iters, Y=metrics,
                      env=self.name + '/lines',
                      win=self.win_id['metrics'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='metrics',
                                title='Disentanglement metrics',
                                legend=['metric1', 'metric2']))

    ####
    def set_mode(self, train=True):
        """Put encoder, decoder and discriminator into train or eval mode."""
        if train:
            self.encoder.train()
            self.decoder.train()
            self.D.train()
        else:
            self.encoder.eval()
            self.decoder.eval()
            self.D.eval()

    ####
    def save_checkpoint(self, iteration):
        """Pickle each module (whole object, not state_dict) under ckpt_dir,
        tagged with the given iteration number."""
        encoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_encoder.pt' % iteration)
        decoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_decoder.pt' % iteration)
        prior_alpha_path = os.path.join(self.ckpt_dir,
                                        'iter_%s_prior_alpha.pt' % iteration)
        post_alpha_path = os.path.join(self.ckpt_dir,
                                       'iter_%s_post_alpha.pt' % iteration)
        D_path = os.path.join(self.ckpt_dir, 'iter_%s_D.pt' % iteration)

        mkdirs(self.ckpt_dir)

        # NOTE: saves entire module objects; loading requires the same class code
        torch.save(self.encoder, encoder_path)
        torch.save(self.decoder, decoder_path)
        torch.save(self.prior_alpha, prior_alpha_path)
        torch.save(self.post_alpha, post_alpha_path)
        torch.save(self.D, D_path)

    ####
    def load_checkpoint(self):
        """Restore all modules saved at iteration `self.ckpt_load_iter`,
        mapping to CPU when CUDA is unavailable."""
        encoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_encoder.pt' % self.ckpt_load_iter)
        decoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_decoder.pt' % self.ckpt_load_iter)
        prior_alpha_path = os.path.join(
            self.ckpt_dir, 'iter_%s_prior_alpha.pt' % self.ckpt_load_iter)
        post_alpha_path = os.path.join(
            self.ckpt_dir, 'iter_%s_post_alpha.pt' % self.ckpt_load_iter)
        D_path = os.path.join(self.ckpt_dir,
                              'iter_%s_D.pt' % self.ckpt_load_iter)

        if self.use_cuda:
            self.encoder = torch.load(encoder_path)
            self.decoder = torch.load(decoder_path)
            self.prior_alpha = torch.load(prior_alpha_path)
            self.post_alpha = torch.load(post_alpha_path)
            self.D = torch.load(D_path)
        else:
            self.encoder = torch.load(encoder_path, map_location='cpu')
            self.decoder = torch.load(decoder_path, map_location='cpu')
            self.prior_alpha = torch.load(prior_alpha_path, map_location='cpu')
            self.post_alpha = torch.load(post_alpha_path, map_location='cpu')
            self.D = torch.load(D_path, map_location='cpu')
class Solver(object):
    """Trainer for a U-Net autoencoder (`sg_unet`) over local occupancy maps.

    Builds (or reloads) the network, its Adam optimizer and train/val data
    loaders, then `train()` runs a plain reconstruction loop with optional
    visdom logging; `test()` evaluates, `make_feat()` dumps encoder features.
    """

    ####
    def __init__(self, args):
        self.args = args

        # run name encodes the main hyper-parameters of this experiment
        self.name = '%s_lr_%s_a_%s_r_%s_k_%s' % \
            (args.dataset_name, args.lr_VAE, args.alpha, args.gamma, args.k_fold)

        self.device = args.device
        self.temp = 0.66
        self.dt = 0.4
        self.eps = 1e-9
        self.alpha = args.alpha
        self.gamma = args.gamma

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info: fold index is appended to the dataset directory
        args.dataset_dir = os.path.join(args.dataset_dir, str(args.k_fold))
        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)

        # set run id
        self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')

        #### create a new model or load a previously saved model
        self.ckpt_load_iter = args.ckpt_load_iter
        self.obs_len = args.obs_len
        self.pred_len = args.pred_len

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(
                map_loss='win_map_loss', test_map_loss='win_test_map_loss'
            )
            self.line_gather = DataGather(
                'iter', 'loss', 'test_loss'
            )

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port, env=self.name)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records"); mkdirs("ckpts"); mkdirs("outputs")

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            # self.encoder = Encoder(
            #     fc_hidden_dim=args.hidden_dim,
            #     output_dim=args.latent_dim,
            #     drop_out=args.dropout_map).to(self.device)
            #
            # self.decoder = Decoder(
            #     fc_hidden_dim=args.hidden_dim,
            #     input_dim=args.latent_dim).to(self.device)
            num_filters = [32, 32, 32, 64, 64, 32, 32]
            # input = env + 8 past + lg / output = env + sg(including lg)
            self.sg_unet = Unet(input_channels=1, num_classes=1,
                                num_filters=num_filters,
                                apply_last_layer=True, padding=True).to(self.device)
        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        # get VAE parameters
        # vae_params = \
        #     list(self.encoder.parameters()) + \
        #     list(self.decoder.parameters())
        vae_params = \
            list(self.sg_unet.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(
            vae_params,
            lr=self.lr_VAE,
            betas=[self.beta1_VAE, self.beta2_VAE]
        )

        # prepare dataloader (iterable)
        print('Start loading data...')
        if self.ckpt_load_iter != self.max_iter:
            print("Initializing train dataset")
            _, self.train_loader = data_loader(self.args, args.dataset_dir, 'train', shuffle=True)
            print("Initializing val dataset")
            # validation loader is built with batch size 1, then the original
            # batch size is restored on the shared args object
            self.args.batch_size = 1
            _, self.val_loader = data_loader(self.args, args.dataset_dir, 'test', shuffle=False)
            self.args.batch_size = args.batch_size
            print(
                'There are {} iterations per epoch'.format(len(self.train_loader.dataset) / args.batch_size)
            )
        print('...done')

    def preprocess_map(self, local_map, aug=False):
        """Convert a numpy batch of local maps to float tensors on the device.

        When `aug` is True, each map is rotated by a random multiple of 90
        degrees (0 / 90 / 180 / -90) as data augmentation.
        """
        local_map = torch.from_numpy(local_map).float().to(self.device)
        if aug:
            all_heatmaps = []
            for h in local_map:
                # NOTE(review): h is already a tensor on the device; this
                # re-wrap looks redundant — confirm before removing
                h = torch.tensor(h).float().to(self.device)
                degree = np.random.choice([0, 90, 180, -90])
                all_heatmaps.append(
                    transforms.Compose([
                        transforms.RandomRotation(degrees=(degree, degree))
                    ])(h)
                )
            all_heatmaps = torch.stack(all_heatmaps)
        else:
            all_heatmaps = local_map
        return all_heatmaps

    ####
    def train(self):
        """Main optimization loop: reconstruct sampled local maps with sg_unet."""
        self.set_mode(train=True)
        torch.autograd.set_detect_anomaly(True)

        data_loader = self.train_loader
        self.N = len(data_loader.dataset)
        iterator = iter(data_loader)
        iter_per_epoch = len(iterator)

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator = iter(data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================
            (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
             obs_frames, pred_frames, map_path, inv_h_t,
             local_map, local_ic, local_homo) = next(iterator)

            # sample at most 2 local maps per trajectory sequence
            sampled_local_map = []
            for s, e in seq_start_end:
                rng = list(range(s, e))
                random.shuffle(rng)
                sampled_local_map.append(local_map[rng[:2]])
            sampled_local_map = np.concatenate(sampled_local_map)

            batch_size = sampled_local_map.shape[0]
            local_map = self.preprocess_map(sampled_local_map, aug=True)

            recon_local_map = self.sg_unet.forward(local_map)
            recon_local_map = F.sigmoid(recon_local_map)

            # NOTE(review): named "focal" but computed as batch-averaged MSE
            focal_loss = F.mse_loss(recon_local_map, local_map).sum().div(batch_size)

            self.optim_vae.zero_grad()
            focal_loss.backward()
            self.optim_vae.step()

            # save model parameters
            if (iteration % (iter_per_epoch * 10) == 0):
                self.save_checkpoint(epoch)

            # (visdom) insert current line stats
            if iteration == iter_per_epoch or (self.viz_on and (iteration % (iter_per_epoch * 10) == 0)):
                test_recon_map_loss = self.test()
                self.line_gather.insert(iter=epoch,
                                        loss=focal_loss.item(),
                                        test_loss=test_recon_map_loss.item(),
                                        )
                prn_str = ('[iter_%d (epoch_%d)] loss: %.3f \n'
                           ) % \
                          (iteration, epoch, focal_loss.item())
                print(prn_str)

                self.visualize_line()
                self.line_gather.flush()

    def test(self):
        """Return the average reconstruction loss over the validation loader."""
        self.set_mode(train=False)
        loss = 0
        b = 0
        with torch.no_grad():
            for abatch in self.val_loader:
                b += 1
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, pred_frames, map_path, inv_h_t,
                 local_map, local_ic, local_homo) = abatch

                batch_size = obs_traj.size(1)  # =sum(seq_start_end[:,1] - seq_start_end[:,0])

                local_map = self.preprocess_map(local_map, aug=False)

                recon_local_map = self.sg_unet.forward(local_map)
                recon_local_map = F.sigmoid(recon_local_map)
                focal_loss = F.mse_loss(recon_local_map, local_map).sum().div(batch_size)

                loss += focal_loss

        self.set_mode(train=True)
        return loss.div(b)

    ####
    def make_feat(self, test_loader):
        """Extract sg_unet encoder features (~1000 samples) from `test_loader`
        and dump them, tagged with scenario ids, to an .npy file for t-SNE."""
        from sklearn.manifold import TSNE

        # from data.trajectories import seq_collate
        # from data.macro_trajectories import TrajectoryDataset
        # from torch.utils.data import DataLoader

        # test_dset = TrajectoryDataset('../datasets/large_real/Trajectories', data_split='test', device=self.device)
        # test_loader = DataLoader(dataset=test_dset, batch_size=1,
        #                          shuffle=True, num_workers=0)

        self.set_mode(train=False)
        with torch.no_grad():
            test_enc_feat = []
            total_scenario = []
            b = 0
            for batch in test_loader:
                b += 1
                # stop once roughly 1000 feature vectors have been collected
                if len(test_enc_feat) > 0 and np.concatenate(test_enc_feat).shape[0] > 1000:
                    break
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, fut_frames, map_path, inv_h_t,
                 local_map, local_ic, local_homo) = batch

                # random subset of up to 32 maps from this batch
                rng = list(range(len(local_map)))
                random.shuffle(rng)
                sampling_idx = rng[:32]
                local_map1 = local_map[sampling_idx]
                local_map1 = self.preprocess_map(local_map1, aug=False)

                # forward pass populates self.sg_unet.enc_feat as a side effect
                self.sg_unet.forward(local_map1)
                test_enc_feat.append(self.sg_unet.enc_feat.view(len(local_map1), -1).detach().cpu().numpy())

                # scenario id is parsed from the map file name
                for m in map_path[sampling_idx]:
                    total_scenario.append(int(m.split('/')[-1].split('.')[0]))

            import matplotlib.pyplot as plt

            test_enc_feat = np.concatenate(test_enc_feat)
            print(test_enc_feat.shape)

            # tsne = TSNE(n_components=2, random_state=0)
            # tsne_feat = tsne.fit_transform(test_enc_feat)

            # last column carries the scenario id alongside the features
            all_feat = np.concatenate([test_enc_feat, np.expand_dims(np.array(total_scenario), 1)], 1)
            np.save('large_tsne_r10_k0_tr.npy', all_feat)
            print('done')

        # dead code kept for reference: offline t-SNE plotting of saved features
        '''
        import pandas as pd
        df = pd.read_csv('C:\dataset\large_real/large_5_bs1.csv')
        data = np.array(df)

        # all_feat = np.load('large_tsne_ae1_tr.npy')
        all_feat_tr = np.load('large_tsne_lg_k0_tr.npy')
        all_feat_te = np.load('large_tsne_lg_k0_te.npy')
        # tsne_faet = np.concatenate([all_feat[:,:2], all_feat_te[:,:2]])
        all_feat = np.concatenate([all_feat_tr[:,:-3], all_feat_te[:,:-3]])
        tsne = TSNE(n_components=2, random_state=0, perplexity=30)
        tsne_feat = tsne.fit_transform(all_feat)

        # tsne_faet = all_feat_tr[:,:-3]
        # obst_ratio = all_feat_tr[:,-3]
        # curv = all_feat_tr[:,-2]
        # scenario = all_feat_tr[:,-1]

        tsne_faet = all_feat_tr[:,:-3]
        obst_ratio = all_feat_tr[:,-3]
        curv = all_feat_tr[:,-2]
        scenario = np.concatenate([all_feat_tr[:,-1], all_feat_te[:,-1]])

        labels = scenario //10
        labels = obst_ratio*100 //10
        # labels = curv*100 //10
        target_names = ['Training', 'Test']
        colors = np.array(['blue', 'red'])

        labels= np.array(df['0.5']) // 10
        labels= np.array(df['# agent']) //10
        labels= np.array(df['curvature'])*100 //10
        labels= np.array(df['map ratio'])*100 //10

        ## k fold labels
        k=0
        labels = scenario //10
        for i in range(len(labels)):
            if labels[i] in range(k*3,(k+1)*3):
                labels[i] = 0
            else:
                labels[i] = 1

        # colors = ['red', 'magenta', 'lightgreen', 'slateblue', 'blue', 'darkgreen', 'darkorange',
        #           'gray', 'purple', 'turquoise', 'midnightblue', 'olive', 'black', 'pink', 'burlywood',
        #           'yellow']
        colors = np.array(['gray','pink', 'orange', 'magenta', 'darkgreen', 'cyan', 'blue', 'red',
                           'lightgreen', 'olive', 'burlywood', 'purple'])
        target_names = np.unique(labels)

        fig = plt.figure(figsize=(5,4))
        fig.tight_layout()

        # labels = np.concatenate([np.zeros(len(all_feat_tr)), np.ones(len(all_feat_te))])
        target_names = ['Training', 'Test']
        colors = np.array(['blue', 'red'])
        for color, i, target_name in zip(colors, np.unique(labels), target_names):
            plt.scatter(tsne_feat[labels == i, 0], tsne_feat[labels == i, 1],
                        alpha=.5, color=color, label=str(target_name), s=10)
        fig.axes[0]._get_axis_list()[0].set_visible(False)
        fig.axes[0]._get_axis_list()[1].set_visible(False)
        plt.legend(loc=0, shadow=False, scatterpoints=1)
        '''

    ####
    def viz_init(self):
        """Close any stale visdom windows for this run's environment."""
        self.viz.close(env=self.name, win=self.win_id['test_map_loss'])
        self.viz.close(env=self.name, win=self.win_id['map_loss'])

    ####
    def visualize_line(self):
        """Append gathered train/test reconstruction losses to visdom lines."""
        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        test_map_loss = torch.Tensor(data['test_loss'])
        map_loss = torch.Tensor(data['loss'])

        self.viz.line(
            X=iters, Y=map_loss, env=self.name,
            win=self.win_id['map_loss'], update='append',
            opts=dict(xlabel='iter', ylabel='loss',
                      title='Recon. map loss')
        )

        self.viz.line(
            X=iters, Y=test_map_loss, env=self.name,
            win=self.win_id['test_map_loss'], update='append',
            opts=dict(xlabel='iter', ylabel='test_loss',
                      title='Recon. map loss - Test'),
        )

    # #
    # # def set_mode(self, train=True):
    # #
    #     if train:
    #         self.encoder.train()
    #         self.decoder.train()
    #     else:
    #         self.encoder.eval()
    #         self.decoder.eval()
    #
    # ####
    # def save_checkpoint(self, iteration):
    #
    #     encoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_encoder.pt' % iteration
    #     )
    #     decoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_decoder.pt' % iteration
    #     )
    #
    #
    #     mkdirs(self.ckpt_dir)
    #
    #     torch.save(self.encoder, encoder_path)
    #     torch.save(self.decoder, decoder_path)

    ####
    def set_mode(self, train=True):
        """Put sg_unet into train or eval mode."""
        if train:
            self.sg_unet.train()
        else:
            self.sg_unet.eval()

    ####
    def save_checkpoint(self, iteration):
        """Pickle the whole sg_unet module under ckpt_dir, tagged with `iteration`."""
        sg_unet_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_sg_unet.pt' % iteration
        )
        mkdirs(self.ckpt_dir)
        torch.save(self.sg_unet, sg_unet_path)

    ####
    def load_checkpoint(self):
        """Reload sg_unet from disk.

        NOTE(review): the path computed from ckpt_dir/ckpt_load_iter is
        immediately overwritten by hard-coded experiment paths below
        (apparent debug leftovers) — confirm before relying on this method
        to honor `ckpt_load_iter`.
        """
        sg_unet_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_sg_unet.pt' % self.ckpt_load_iter
        )

        if self.device == 'cuda':
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_run_8/iter_100_sg_unet.pt'
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_0_run_9/iter_100_sg_unet.pt'
            print('>>>>>>>>>>> load: ', sg_unet_path)
            self.sg_unet = torch.load(sg_unet_path)
        else:
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_run_8/iter_100_sg_unet.pt'
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_1_run_10/iter_200_sg_unet.pt'
            # sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_0_run_10/iter_200_sg_unet.pt'
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_0_run_9/iter_40_sg_unet.pt'
            # sg_unet_path = 'd:\crowd\mcrowd\ckpts\mapae.path_lr_0.001_a_0.25_r_2.0_run_2/iter_3360_sg_unet.pt'
            self.sg_unet = torch.load(sg_unet_path, map_location='cpu')

    ####
    # #
    # def load_checkpoint(self):
    #
    #     encoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_encoder.pt' % self.ckpt_load_iter
    #     )
    #     decoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_decoder.pt' % self.ckpt_load_iter
    #     )
    #
    #     if self.device == 'cuda':
    #         self.encoder = torch.load(encoder_path)
    #         self.decoder = torch.load(decoder_path)
    #     else:
    #         self.encoder = torch.load(encoder_path, map_location='cpu')
    #         self.decoder = torch.load(decoder_path, map_location='cpu')
    #
    # def load_map_weights(self, map_path):
    #     if self.device == 'cuda':
    #         loaded_map_w = torch.load(map_path)
    #     else:
    #         loaded_map_w = torch.load(map_path, map_location='cpu')
    #     self.encoder.conv1.weight = loaded_map_w.map_net.conv1.weight
    #     self.encoder.conv2.weight = loaded_map_w.map_net.conv2.weight
    #     self.encoder.conv3.weight = loaded_map_w.map_net.conv3.weight
class Solver(object):
    """Trainer for a relevance-weighted FactorVAE (RF-VAE).

    A learnable relevance vector `self.r` (one weight per latent dimension)
    scales the latent code fed to the total-correlation discriminator; the
    VAE loss additionally penalizes `r` with an L1 term and an entropy term.
    """

    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.etaS = args.etaS  # weight of the L1 sparsity penalty on r
        self.etaH = args.etaH  # weight of the entropy penalty on r

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        self.lr_r = args.lr_r
        self.beta1_r = args.beta1_r
        self.beta2_r = args.beta2_r

        # build the initial relevance vector (all entries 0.5); translated
        # from the original comment: "first create a Tensor of custom
        # weights; for convenience all weights are set to the same value"
        ones = torch.Tensor(np.ones([self.z_dim]) * 0.5).to(self.device)
        self.r = torch.nn.Parameter(ones)

        if args.dataset == 'dsprites':
            self.VAE = RF_VAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = RF_VAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        # separate optimizer for the relevance vector alone
        self.optim_r = optim.Adam([self.r], lr=self.lr_r,
                                  betas=(self.beta1_r, self.beta2_r))

        self.nets = [self.VAE, self.D]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld',
                           acc='win_acc', r_distribute='r_distribute')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm',
                                      'recon', 'kld', 'acc', 'r_distribute')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)

    def train(self):
        """Alternate VAE/r updates with discriminator updates until max_iter."""
        self.net_mode(train=True)

        # fixed class targets for the TC discriminator's cross-entropy
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar, self.r)
                H_r = entropy(self.r)

                # discriminator sees the relevance-weighted latent code
                D_z = self.D(self.r * z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss + \
                    self.etaS * self.r.abs().sum() + self.etaH * H_r

                self.optim_VAE.zero_grad()
                vae_loss.backward(retain_graph=True)
                self.optim_VAE.step()

                # NOTE(review): a second backward of the same loss after
                # optim_VAE.step() — gradients for r come from the pre-step
                # graph and VAE grads accumulate twice; confirm this is
                # intentional before restructuring
                self.optim_r.zero_grad()
                vae_loss.backward(retain_graph=True)
                self.optim_r.step()

                # discriminator update on permuted (dimension-shuffled) latents
                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(self.r * z_pperm)
                D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                if self.global_iter % self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter % self.viz_ll_iter == 0):
                    # discriminator accuracy: real z classified as class 0,
                    # permuted z as class 1
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2 * self.batch_size
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            kld=vae_kld.item(),
                                            acc=D_acc.item(),
                                            r_distribute=self.r.data.cpu())

                if self.viz_on and (self.global_iter % self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter % self.viz_ra_iter == 0):
                    self.image_gather.insert(true=x_true1.data.cpu(),
                                             recon=F.sigmoid(x_recon).data.cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter % self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        self.visualize_traverse(limit=3, inter=2 / 3)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()

    def visualize_recon(self):
        """Push a (true, reconstruction) image-grid pair to visdom."""
        data = self.image_gather.data
        true_image = data['true'][0]
        recon_image = data['recon'][0]

        true_image = make_grid(true_image)
        recon_image = make_grid(recon_image)
        sample = torch.stack([true_image, recon_image], dim=0)
        self.viz.images(sample, env=self.name + '/recon_image',
                        opts=dict(title=str(self.global_iter)))

    def visualize_line(self):
        """Append gathered curves to visdom, creating each window on first use;
        the relevance vector is redrawn as a bar chart every call."""
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kld = torch.Tensor(data['kld'])
        D_acc = torch.Tensor(data['acc'])
        soft_D_z = torch.Tensor(data['soft_D_z'])
        soft_D_z_pperm = torch.Tensor(data['soft_D_z_pperm'])
        # only the latest relevance vector is plotted
        r_distribute = data['r_distribute'][-1]
        soft_D_zs = torch.stack([soft_D_z, soft_D_z_pperm], -1)

        if not self.viz.win_exists(env=self.name + '/lines', win=self.win_id['D_z']):
            self.viz.line(X=iters, Y=soft_D_zs, env=self.name + '/lines',
                          win=self.win_id['D_z'],
                          opts=dict(
                              xlabel='iteration',
                              ylabel='D(.)',
                              legend=['D(z)', 'D(z_perm)']))
        else:
            self.viz.line(X=iters, Y=soft_D_zs, env=self.name + '/lines',
                          win=self.win_id['D_z'], update='append',
                          opts=dict(
                              xlabel='iteration',
                              ylabel='D(.)',
                              legend=['D(z)', 'D(z_perm)']))

        if not self.viz.win_exists(env=self.name + '/lines', win=self.win_id['recon']):
            self.viz.line(X=iters, Y=recon, env=self.name + '/lines',
                          win=self.win_id['recon'],
                          opts=dict(
                              xlabel='iteration',
                              ylabel='reconstruction loss', ))
        else:
            self.viz.line(X=iters, Y=recon, env=self.name + '/lines',
                          win=self.win_id['recon'], update='append',
                          opts=dict(
                              xlabel='iteration',
                              ylabel='reconstruction loss', ))

        if not self.viz.win_exists(env=self.name + '/lines', win=self.win_id['acc']):
            self.viz.line(X=iters, Y=D_acc, env=self.name + '/lines',
                          win=self.win_id['acc'],
                          opts=dict(
                              xlabel='iteration',
                              ylabel='discriminator accuracy', ))
        else:
            self.viz.line(X=iters, Y=D_acc, env=self.name + '/lines',
                          win=self.win_id['acc'], update='append',
                          opts=dict(
                              xlabel='iteration',
                              ylabel='discriminator accuracy', ))

        if not self.viz.win_exists(env=self.name + '/lines', win=self.win_id['kld']):
            self.viz.line(X=iters, Y=kld, env=self.name + '/lines',
                          win=self.win_id['kld'],
                          opts=dict(
                              xlabel='iteration',
                              ylabel='kl divergence', ))
        else:
            self.viz.line(X=iters, Y=kld, env=self.name + '/lines',
                          win=self.win_id['kld'], update='append',
                          opts=dict(
                              xlabel='iteration',
                              ylabel='kl divergence', ))

        # bar charts cannot be appended, so close and redraw each time
        if self.viz.win_exists(env=self.name + '/lines', win=self.win_id['r_distribute']):
            self.viz.close(win=self.win_id['r_distribute'], env=self.name + '/lines')
        self.viz.bar(X=r_distribute, env=self.name + '/lines',
                     win=self.win_id['r_distribute'],
                     opts=dict(
                         xlabel='dimention',
                         ylabel='relevance score', ))

    def visualize_traverse(self, limit=3, inter=2 / 3, loc=-1):
        """Render latent traversals over [-limit, limit] in steps of `inter`.

        Each latent dimension (or only `loc` when loc != -1) is swept while
        the others are held fixed; grids go to visdom and, when
        `output_save` is set, to jpg/gif files under output_dir.
        """
        self.net_mode(train=False)

        decoder = self.VAE.decode
        encoder = self.VAE.encode
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        random_img = self.data_loader.dataset.__getitem__(0)[1]
        random_img = random_img.to(self.device).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]

        if self.dataset.lower() == 'dsprites':
            fixed_idx1 = 87040  # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_square': fixed_img_z1, 'fixed_ellipse': fixed_img_z2,
                 'fixed_heart': fixed_img_z3, 'random_img': random_img_z}

        elif self.dataset.lower() == 'celeba':
            fixed_idx1 = 191281  # 'CelebA/img_align_celeba/191282.jpg'
            fixed_idx2 = 143307  # 'CelebA/img_align_celeba/143308.jpg'
            fixed_idx3 = 101535  # 'CelebA/img_align_celeba/101536.jpg'
            fixed_idx4 = 70059   # 'CelebA/img_align_celeba/070060.jpg'

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0]
            fixed_img4 = fixed_img4.to(self.device).unsqueeze(0)
            fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim]

            Z = {'fixed_1': fixed_img_z1, 'fixed_2': fixed_img_z2,
                 'fixed_3': fixed_img_z3, 'fixed_4': fixed_img_z4,
                 'random': random_img_z}

        elif self.dataset.lower() == '3dchairs':
            fixed_idx1 = 40919  # 3DChairs/images/4682_image_052_p030_t232_r096.png
            fixed_idx2 = 5172   # 3DChairs/images/14657_image_020_p020_t232_r096.png
            fixed_idx3 = 22330  # 3DChairs/images/30099_image_052_p030_t232_r096.png

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_1': fixed_img_z1, 'fixed_2': fixed_img_z2,
                 'fixed_3': fixed_img_z3, 'random': random_img_z}

        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)[0]
            fixed_img = fixed_img.to(self.device).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

            random_z = torch.rand(1, self.z_dim, 1, 1, device=self.device)

            Z = {'fixed_img': fixed_img_z, 'random_img': random_img_z,
                 'random_z': random_z}

        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = F.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
            self.viz.images(samples, env=self.name + '/traverse',
                            opts=dict(title=title), nrow=len(interpolation))

        if self.output_save:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            mkdirs(output_dir)
            gifs = torch.cat(gifs)
            # (#keys, z_dim, #interp, C, H, W) -> interp step on second axis
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation),
                             self.nc, 64, 64).transpose(1, 2)
            name_str = ''
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)
                    name_str = name_str + '{}_{}.jpg '.format(key, j)
                grid2gif(name_str, str(os.path.join(output_dir, key + '.gif')),
                         output_dir, delay=10)

        self.net_mode(train=True)

    def net_mode(self, train):
        """Put every network in self.nets into train or eval mode.

        Raises ValueError when `train` is not a bool.
        """
        if not isinstance(train, bool):
            raise ValueError('Only bool type is supported. True|False')

        for net in self.nets:
            if train:
                net.train()
            else:
                net.eval()

    def save_checkpoint(self, ckptname='last', verbose=True):
        """Save model/optimizer state dicts (plus r and the iteration counter)
        into a single file named `ckptname` under ckpt_dir."""
        model_states = {'D': self.D.state_dict(),
                        'VAE': self.VAE.state_dict(),
                        'r': self.r}
        optim_states = {'optim_D': self.optim_D.state_dict(),
                        'optim_VAE': self.optim_VAE.state_dict(),
                        'optim_r': self.optim_r.state_dict()}
        states = {'iter': self.global_iter,
                  'model_states': model_states,
                  'optim_states': optim_states}

        filepath = os.path.join(self.ckpt_dir, str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(filepath, self.global_iter))

    def load_checkpoint(self, ckptname='last', verbose=True):
        """Restore training state from `ckptname`; 'last' picks the checkpoint
        with the highest numeric filename in ckpt_dir."""
        if ckptname == 'last':
            ckpts = os.listdir(self.ckpt_dir)
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return

            # filenames are iteration numbers; newest = largest
            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])

        filepath = os.path.join(self.ckpt_dir, ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)

            self.global_iter = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.r = checkpoint['model_states']['r']
            self.optim_VAE.load_state_dict(checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.optim_r.load_state_dict(checkpoint['optim_states']['optim_r'])
            self.pbar.update(self.global_iter)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(filepath, self.global_iter))
        else:
            if verbose:
                self.pbar.write("=> no checkpoint found at '{}'".format(filepath))
class Solver(object):
    """Two-modality VAE trainer (modal A and B with a shared discrete latent).

    Trains EncoderA/EncoderB + DecoderA/DecoderB on paired data, combining
    the two modality-specific discrete posteriors via a product-of-experts
    (POE), and provides reconstruction / cross-modal synthesis / latent
    traversal visualization plus two majority-vote disentanglement metrics.
    """

    ####
    def __init__(self, args):
        self.args = args

        # run name encodes the main hyperparameters; '_run_<id>' appended below
        self.name = '%s_lamkl_%s_zA_%s_zB_%s_zS_%s_HYPER_beta1_%s_beta2_%s_beta3_%s' % \
            (args.dataset, args.lamkl, args.zA_dim, args.zB_dim, args.zS_dim,
             args.beta1, args.beta2, args.beta3)

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.nc = 3
        self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zA_dim = args.zA_dim
        self.zB_dim = args.zB_dim
        self.zS_dim = args.zS_dim
        self.lamkl = args.lamkl
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.beta3 = args.beta3
        self.is_mss = args.is_mss

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(recon='win_recon', kl='win_kl', capa='win_capa')
            self.line_gather = DataGather(
                'iter', 'recon_both', 'recon_A', 'recon_B', 'kl_A', 'kl_B',
                'cont_capacity_loss_infA', 'disc_capacity_loss_infA',
                'cont_capacity_loss_infB', 'disc_capacity_loss_infB'
            )

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.n_pts = args.n_pts
        self.n_data = args.n_data

        if self.ckpt_load_iter == 0:  # create a new model
            self.encoderA = EncoderA(self.zA_dim, self.zS_dim)
            self.encoderB = EncoderA(self.zB_dim, self.zS_dim)
            self.decoderA = DecoderA(self.zA_dim, self.zS_dim)
            self.decoderB = DecoderA(self.zB_dim, self.zS_dim)
        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoderA = self.encoderA.cuda()
            self.encoderB = self.encoderB.cuda()
            self.decoderA = self.decoderA.cuda()
            self.decoderB = self.decoderB.cuda()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderA.parameters()) + \
            list(self.encoderB.parameters()) + \
            list(self.decoderA.parameters()) + \
            list(self.decoderB.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(
            vae_params,
            lr=self.lr_VAE,
            betas=[self.beta1_VAE, self.beta2_VAE]
        )

    ####
    def train(self):
        """Main optimization loop over self.max_iter iterations."""
        self.set_mode(train=True)

        # prepare dataloader (iterable)
        print('Start loading data...')
        dset = DIGIT('./data', train=True)
        self.data_loader = torch.utils.data.DataLoader(
            dset, batch_size=self.batch_size, shuffle=True)
        test_dset = DIGIT('./data', train=False)
        self.test_data_loader = torch.utils.data.DataLoader(
            test_dset, batch_size=self.batch_size, shuffle=True)
        print('test: ', len(test_dset))
        self.N = len(self.data_loader.dataset)
        print('...done')

        # iterators from dataloader
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch
            XA, XB, index = next(iterator1)  # (n x C x H x W)
            index = index.cpu().detach().numpy()
            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()

            # zA, zS = encA(xA)
            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

            # zS = encAB(xA,xB) via POE: uniform prior (1/10) times both
            # categorical posteriors, combined in log space
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) +
                torch.log(cate_prob_infA) +
                torch.log(cate_prob_infB))

            # kl losses
            # A
            latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA,
             disc_capacity_loss_infA) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infA)

            loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
            capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

            # B (explicit capacity schedules for modality B)
            latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
            (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB,
             disc_capacity_loss_infB) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infB,
                cont_capacity=[0.0, 5.0, 50000, 100.0],
                disc_capacity=[0.0, 10.0, 50000, 100.0])

            loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
            capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB

            # only modality B's capacity loss enters the total objective
            loss_capa = capacity_loss_infB

            # encoder samples (for training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
            XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

            # reconstructed samples (given single modal observation)
            XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
            XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

            loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
            loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
            loss_recon_POE = \
                F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
                F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))
            # only modality B's reconstruction enters the total objective
            loss_recon = loss_recon_infB

            # total loss for vae
            vae_loss = loss_recon + loss_capa

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            # print the losses
            if iteration % self.print_iter == 0:
                # NOTE(review): the 'rec_POE' slot is fed loss_recon.item()
                # (== loss_recon_infB), not loss_recon_POE — confirm intent.
                prn_str = ( \
                    '[iter %d (epoch %d)] vae_loss: %.3f ' + \
                    '(recon: %.3f, capa: %.3f)\n' + \
                    ' rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n' + \
                    ' kl_infA = %.3f, kl_infB = %.3f' + \
                    ' cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n' + \
                    ' cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
                ) % \
                    (iteration, epoch, vae_loss.item(),
                     loss_recon.item(), loss_capa.item(),
                     loss_recon_infA.item(), loss_recon_infB.item(), loss_recon.item(),
                     loss_kl_infA.item(), loss_kl_infB.item(),
                     cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                     cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
                     )
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str,))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # 1) save the recon images
                self.save_recon(iteration)

                z_A, z_B, z_S = self.get_stat()

                # 2) save the latent traversed images
                self.save_traverseB(iteration, z_A, z_B, z_S)

            if iteration % self.eval_metrics_iter == 0:
                # NOTE(review): z_A/z_B are only bound by the output_save
                # branch above; if eval_metrics_iter fires before any
                # output_save_iter multiple, this raises NameError — verify
                # the two intervals' relationship.
                self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                self.line_gather.insert(iter=iteration,
                                        recon_both=loss_recon_POE.item(),
                                        recon_A=loss_recon_infA.item(),
                                        recon_B=loss_recon_infB.item(),
                                        kl_A=loss_kl_infA.item(),
                                        kl_B=loss_kl_infB.item(),
                                        cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                        disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                        cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                        disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                        )

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

    ####
    def eval_disentangle_metric1(self):
        """Disentanglement metric 1: majority-vote classifier accuracy using,
        for each pair, the latent dim with the SMALLEST normalized variance
        when one ground-truth factor is held fixed.

        Returns (accuracy, contingency_table).
        NOTE(review): relies on self.latent_sizes / self.latent_classes /
        self.z_dim which the visible __init__ does not define, and unpacks
        the encoders as 3-tuples whereas train() unpacks 4 — confirm the
        dataset/encoder variants this was written for.
        """
        # some hyperparams
        num_pairs = 800  # data pairs (d,y) for majority vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic
        dl = DataLoader(
            self.data_loader.dataset, batch_size=bs,
            shuffle=True, num_workers=self.args.num_workers, pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):
            # sample a mini-batch
            XAb, XBb, _, _, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                XAb = XAb.cuda()
                XBb = XBb.cuda()

            # z = encA(xA)
            mu_infA, _, logvar_infA = self.encoderA(XAb)

            # z = encB(xB)
            mu_infB, _, logvar_infB = self.encoderB(XBb)

            # z = encAB(xA,xB) via POE
            mu_POE, _, _ = apply_poe(
                self.use_cuda, mu_infA, logvar_infA, mu_infB, logvar_infB,
            )

            mub = mu_POE

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate sample variance and mean of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor fixed"
        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        # FIX: np.int was removed in NumPy 1.24; builtin int is equivalent here
        true_factor_ids = np.zeros(num_pairs, dtype=int)  # true factor ids

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor
            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):
                # a true factor (id and class value) to fix
                fac_id = j
                fac_class = np.random.randint(self.latent_sizes[fac_id])

                # randomly select images (with the fixed factor)
                indices = np.where(
                    self.latent_classes[:, fac_id] == fac_class)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]

                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    XAb, XBb, _, _, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if XAb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        XAb = XAb.cuda()
                        XBb = XBb.cuda()
                    mu_infA, _, logvar_infA = self.encoderA(XAb)
                    mu_infB, _, logvar_infB = self.encoderB(XBb)
                    mu_POE, _, _ = apply_poe(
                        self.use_cuda, mu_infA, logvar_infA, mu_infB, logvar_infB,
                    )
                    mub = mu_POE
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = fac_id

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        smallest_var_dims = np.argmin(
            vars_per_factor / (vars_agn_factor + 1e-20), axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[smallest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # misclassifying errors of majority vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric1 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric1, C

    ####
    def eval_disentangle_metric2(self):
        """Disentanglement metric 2: like metric 1, but all OTHER factors are
        fixed and the latent dim with the LARGEST normalized variance votes.

        Returns (accuracy, contingency_table).
        NOTE(review): same undefined-attribute / encoder-arity caveats as
        eval_disentangle_metric1.
        """
        # some hyperparams
        num_pairs = 800  # data pairs (d,y) for majority vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic
        dl = DataLoader(
            self.data_loader.dataset, batch_size=bs,
            shuffle=True, num_workers=self.args.num_workers, pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):
            # sample a mini-batch
            XAb, XBb, _, _, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                XAb = XAb.cuda()
                XBb = XBb.cuda()

            # z = encA(xA)
            mu_infA, _, logvar_infA = self.encoderA(XAb)

            # z = encB(xB)
            mu_infB, _, logvar_infB = self.encoderB(XBb)

            # z = encAB(xA,xB) via POE
            mu_POE, _, _ = apply_poe(
                self.use_cuda, mu_infA, logvar_infA, mu_infB, logvar_infB,
            )

            mub = mu_POE

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate sample variance and mean of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor varied"
        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        # FIX: np.int was removed in NumPy 1.24; builtin int is equivalent here
        true_factor_ids = np.zeros(num_pairs, dtype=int)  # true factor ids

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor
            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):
                # randomly choose true factors (id's and class values) to fix
                fac_ids = list(np.setdiff1d(factor_ids, j))
                fac_classes = \
                    [np.random.randint(self.latent_sizes[k]) for k in fac_ids]

                # randomly select images (with the other factors fixed)
                if len(fac_ids) > 1:
                    indices = np.where(
                        np.sum(self.latent_classes[:, fac_ids] == fac_classes, 1)
                        == len(fac_ids)
                    )[0]
                else:
                    indices = np.where(
                        self.latent_classes[:, fac_ids] == fac_classes
                    )[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]

                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    XAb, XBb, _, _, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if XAb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        XAb = XAb.cuda()
                        XBb = XBb.cuda()
                    mu_infA, _, logvar_infA = self.encoderA(XAb)
                    mu_infB, _, logvar_infB = self.encoderB(XBb)
                    mu_POE, _, _ = apply_poe(
                        self.use_cuda, mu_infA, logvar_infA, mu_infB, logvar_infB,
                    )
                    mub = mu_POE
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = j

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        largest_var_dims = np.argmax(
            vars_per_factor / (vars_agn_factor + 1e-20), axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[largest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # misclassifying errors of majority vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric2 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric2, C

    def save_recon(self, iters):
        """Save reconstruction grids for both modalities at iteration `iters`.

        For 60 fixed dataset indices, writes reconA_<iters>.jpg and
        reconB_<iters>.jpg interleaving input / single-modal recon / POE
        recon / white separator rows.
        """
        self.set_mode(train=False)

        mkdirs(self.output_dir_recon)

        fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

        # expand each anchor index into 6 consecutive indices
        fixed_idxs60 = []
        for idx in fixed_idxs:
            for i in range(6):
                fixed_idxs60.append(idx + i)

        XA = [0] * len(fixed_idxs60)
        XB = [0] * len(fixed_idxs60)

        for i, idx in enumerate(fixed_idxs60):
            XA[i], XB[i] = \
                self.data_loader.dataset.__getitem__(idx)[0:2]
            if self.use_cuda:
                XA[i] = XA[i].cuda()
                XB[i] = XB[i].cuda()

        XA = torch.stack(XA)
        XB = torch.stack(XB)

        muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

        # zB, zS = encB(xB)
        muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

        # zS = encAB(xA,xB) via POE
        cate_prob_POE = torch.exp(
            torch.log(torch.tensor(1 / 10)) +
            torch.log(cate_prob_infA) +
            torch.log(cate_prob_infB))

        # encoder samples (deterministic: train=False)
        ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
        ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
        ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE, train=False)

        # encoder samples (for cross-modal prediction)
        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

        # reconstructed samples (given joint modal observation)
        XA_POE_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_POE))
        XB_POE_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_POE))

        # reconstructed samples (given single modal observation)
        XA_infA_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_infA))
        XB_infB_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_infB))

        # white separator rows
        WS = torch.ones(XA.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XA.shape[0]
        # interleave the 4 row-groups sample-by-sample
        perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ## modality A
        merged = torch.cat(
            [XA, XA_infA_recon, XA_POE_recon, WS], dim=0
        )
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'reconA_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(
            tensor=merged, filename=fname,
            nrow=4 * int(np.sqrt(n)), pad_value=1
        )

        WS = torch.ones(XB.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XB.shape[0]
        perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ## modality B
        merged = torch.cat(
            [XB, XB_infB_recon, XB_POE_recon, WS], dim=0
        )
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'reconB_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(
            tensor=merged, filename=fname,
            nrow=4 * int(np.sqrt(n)), pad_value=1
        )

        self.set_mode(train=True)

    ####
    def save_synth_pure(self, iters, howmany=100):
        """Decode `howmany` pure prior samples with both decoders and save a grid.

        NOTE(review): uses self.z_dim and single-argument decoder calls,
        which do not match this class's two-argument decoders — this method
        appears to be carried over from a single-latent variant; verify
        before calling.
        """
        self.set_mode(train=False)

        decoderA = self.decoderA
        decoderB = self.decoderB

        Z = torch.randn(howmany, self.z_dim)
        if self.use_cuda:
            Z = Z.cuda()

        # do synthesis
        XA = torch.sigmoid(decoderA(Z)).data
        XB = torch.sigmoid(decoderB(Z)).data

        WS = torch.ones(XA.shape)
        if self.use_cuda:
            WS = WS.cuda()

        perm = torch.arange(0, 3 * howmany).view(3, howmany).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        merged = torch.cat([XA, XB, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(
            self.output_dir_synth, 'synth_pure_%s.jpg' % iters
        )
        mkdirs(self.output_dir_synth)
        save_image(
            tensor=merged, filename=fname,
            nrow=3 * int(np.sqrt(howmany)), pad_value=1
        )

        self.set_mode(train=True)

    ####
    def save_synth_cross_modal(self, iters, z_A_stat, z_B_stat, train=True, howmany=3):
        """Cross-modal synthesis: generate B from A's shared code and vice versa.

        z_A_stat / z_B_stat are per-sample latent means (from get_stat) whose
        average recenters the random modality-specific codes.
        """
        self.set_mode(train=False)

        if train:
            data_loader = self.data_loader
            fixed_idxs = [3246, 7001, 14308, 19000, 27447, 33103, 38002, 45232, 51000, 55125]
        else:
            data_loader = self.test_data_loader
            fixed_idxs = [2, 982, 2300, 3400, 4500, 5500, 6500, 7500, 8500, 9500]

        fixed_XA = [0] * len(fixed_idxs)
        fixed_XB = [0] * len(fixed_idxs)
        for i, idx in enumerate(fixed_idxs):
            fixed_XA[i], fixed_XB[i] = \
                data_loader.dataset.__getitem__(idx)[0:2]
            if self.use_cuda:
                fixed_XA[i] = fixed_XA[i].cuda()
                fixed_XB[i] = fixed_XB[i].cuda()
        fixed_XA = torch.stack(fixed_XA)
        fixed_XB = torch.stack(fixed_XB)

        _, _, _, cate_prob_infA = self.encoderA(fixed_XA)

        # zB, zS = encB(xB)
        _, _, _, cate_prob_infB = self.encoderB(fixed_XB)

        # deterministic shared codes from each single modality
        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)
        if self.use_cuda:
            ZS_infA = ZS_infA.cuda()
            ZS_infB = ZS_infB.cuda()

        decoderA = self.decoderA
        decoderB = self.decoderB

        # replicate the 1-channel A images to 3 channels for display
        fixed_XA_3ch = []
        for i in range(len(fixed_XA)):
            each_XA = fixed_XA[i].clone().squeeze()
            fixed_XA_3ch.append(torch.stack([each_XA, each_XA, each_XA]))
        fixed_XA_3ch = torch.stack(fixed_XA_3ch)

        WS = torch.ones(fixed_XA_3ch.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = len(fixed_idxs)
        perm = torch.arange(0, (howmany + 2) * n).view(howmany + 2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ######## 1) generate xB from given xA (A2B) ########

        # mean of the B-latent statistics is loop-invariant; compute once
        z_B_stat = np.array(z_B_stat)
        z_B_stat_mean = np.mean(z_B_stat, 0)

        merged = torch.cat([fixed_XA_3ch], dim=0)
        for k in range(howmany):
            ZB = torch.randn(n, self.zB_dim)
            ZB = ZB + torch.Tensor(z_B_stat_mean)
            if self.use_cuda:
                ZB = ZB.cuda()
            XB_synth = torch.sigmoid(decoderB(ZB, ZS_infA))  # given XA
            merged = torch.cat([merged, XB_synth], dim=0)
        merged = torch.cat([merged, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        if train:
            fname = os.path.join(
                self.output_dir_synth, 'synth_cross_modal_A2B_%s.jpg' % iters
            )
        else:
            fname = os.path.join(
                self.output_dir_synth, 'eval_synth_cross_modal_A2B_%s.jpg' % iters
            )
        mkdirs(self.output_dir_synth)
        save_image(
            tensor=merged, filename=fname,
            nrow=(howmany + 2) * int(np.sqrt(n)), pad_value=1
        )

        ######## 2) generate xA from given xB (B2A) ########

        z_A_stat = np.array(z_A_stat)
        z_A_stat_mean = np.mean(z_A_stat, 0)

        merged = torch.cat([fixed_XB], dim=0)
        for k in range(howmany):
            ZA = torch.randn(n, self.zA_dim)
            ZA = ZA + torch.Tensor(z_A_stat_mean)
            if self.use_cuda:
                ZA = ZA.cuda()
            XA_synth = torch.sigmoid(decoderA(ZA, ZS_infB))  # given XB

            XA_synth_3ch = []
            for i in range(len(XA_synth)):
                each_XA = XA_synth[i].clone().squeeze()
                XA_synth_3ch.append(torch.stack([each_XA, each_XA, each_XA]))

            merged = torch.cat([merged, torch.stack(XA_synth_3ch)], dim=0)
        merged = torch.cat([merged, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        if train:
            fname = os.path.join(
                self.output_dir_synth, 'synth_cross_modal_B2A_%s.jpg' % iters
            )
        else:
            fname = os.path.join(
                self.output_dir_synth, 'eval_synth_cross_modal_B2A_%s.jpg' % iters
            )
        mkdirs(self.output_dir_synth)
        save_image(
            tensor=merged, filename=fname,
            nrow=(howmany + 2) * int(np.sqrt(n)), pad_value=1
        )

        self.set_mode(train=True)

    def get_stat(self):
        """Sample 10000 random training items and return their latent stats:
        (list of muA, list of muB, list of POE categorical probs)."""
        z_A, z_B, z_S = [], [], []
        for _ in range(10000):
            rand_i = np.random.randint(self.N)
            random_XA, random_XB = self.data_loader.dataset.__getitem__(rand_i)[0:2]
            if self.use_cuda:
                random_XA = random_XA.cuda()
                random_XB = random_XB.cuda()
            random_XA = random_XA.unsqueeze(0)
            random_XB = random_XB.unsqueeze(0)

            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(random_XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(random_XB)

            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) +
                torch.log(cate_prob_infA) +
                torch.log(cate_prob_infB))

            z_A.append(muA_infA.cpu().detach().numpy()[0])
            z_B.append(muB_infB.cpu().detach().numpy()[0])
            z_S.append(cate_prob_POE.cpu().detach().numpy()[0])
        return z_A, z_B, z_S

    def save_traverseA(self, iters, z_A, z_B, z_S, loc=-1):
        """Traverse modality-A continuous dims (and the shared one-hot dims)
        for fixed images, saving per-step grids and an animated gif."""
        self.set_mode(train=False)

        encoderA = self.encoderA
        encoderB = self.encoderB
        decoderA = self.decoderA

        interpolationA = torch.tensor(np.linspace(-3, 3, self.zS_dim))

        print('------------ traverse interpolation ------------')
        print('interpolationA: ', np.min(np.array(z_A)), np.max(np.array(z_A)))
        print('interpolationB: ', np.min(np.array(z_B)), np.max(np.array(z_B)))
        print('interpolationS: ', np.min(np.array(z_S)), np.max(np.array(z_S)))

        if self.record_file:
            ####
            fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

            fixed_XA = [0] * len(fixed_idxs)
            fixed_XB = [0] * len(fixed_idxs)

            for i, idx in enumerate(fixed_idxs):
                fixed_XA[i], fixed_XB[i] = \
                    self.data_loader.dataset.__getitem__(idx)[0:2]
                if self.use_cuda:
                    fixed_XA[i] = fixed_XA[i].cuda()
                    fixed_XB[i] = fixed_XB[i].cuda()
                fixed_XA[i] = fixed_XA[i].unsqueeze(0)
                fixed_XB[i] = fixed_XB[i].unsqueeze(0)

            fixed_XA = torch.cat(fixed_XA, dim=0)
            fixed_XB = torch.cat(fixed_XB, dim=0)

            fixed_zmuA, _, _, cate_prob_infA = encoderA(fixed_XA)

            # zB, zS = encB(xB)
            fixed_zmuB, _, _, cate_prob_infB = encoderB(fixed_XB)

            # zS = encAB(xA,xB) via POE
            fixed_cate_probS = torch.exp(
                torch.log(torch.tensor(1 / 10)) +
                torch.log(cate_prob_infA) +
                torch.log(cate_prob_infB))

            # shared code taken from modality A's posterior (deterministic)
            fixed_zS = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)

            saving_shape = torch.cat([fixed_XA[i] for i in range(fixed_XA.shape[0])], dim=1).shape

        ####
        WS = torch.ones(saving_shape)
        if self.use_cuda:
            WS = WS.cuda()

        # do traversal and collect generated images
        zA_ori, zB_ori, zS_ori = fixed_zmuA, fixed_zmuB, fixed_zS

        tempA = []  # zA_dim + zS_dim, num_trv, 1, 32*num_samples, 32
        for row in range(self.zA_dim):
            if loc != -1 and row != loc:
                continue
            zA = zA_ori.clone()

            temp = []
            for val in interpolationA:
                zA[:, row] = val
                sampleA = torch.sigmoid(decoderA(zA, zS_ori)).data
                temp.append((torch.cat([sampleA[i] for i in range(sampleA.shape[0])], dim=1)).unsqueeze(0))

            tempA.append(torch.cat(temp, dim=0).unsqueeze(0))  # num_trv, 1, 32*num_samples, 32

        # traverse the shared discrete code: one-hot per class
        temp = []
        for i in range(self.zS_dim):
            zS = np.zeros((1, self.zS_dim))
            zS[0, i % self.zS_dim] = 1.
            zS = torch.Tensor(zS)
            zS = torch.cat([zS] * len(fixed_idxs), dim=0)
            if self.use_cuda:
                zS = zS.cuda()
            sampleA = torch.sigmoid(decoderA(zA_ori, zS)).data
            temp.append((torch.cat([sampleA[i] for i in range(sampleA.shape[0])], dim=1)).unsqueeze(0))
        tempA.append(torch.cat(temp, dim=0).unsqueeze(0))

        gifs = torch.cat(tempA, dim=0)  # torch.Size([11, 10, 1, 384, 32])

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters), 'train')
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)

        for j, val in enumerate(interpolationA):
            I = gifs[:, j]
            save_image(
                tensor=I.cpu(),
                filename=os.path.join(out_dir, '%03d.jpg' % (j)),
                nrow=1 + self.zA_dim + 1 + 1 + 1 + self.zB_dim,
                pad_value=1)

        # make animated gif
        grid2gif2(
            out_dir, str(os.path.join(out_dir, 'mnist_traverse' + '.gif')),
            delay=10
        )

        self.set_mode(train=True)

    ###
    def save_traverseB(self, iters, z_A, z_B, z_S, loc=-1):
        """Traverse modality-B continuous dims (and the shared one-hot dims)
        for fixed images, saving per-step grids and an animated gif."""
        self.set_mode(train=False)

        encoderA = self.encoderA
        encoderB = self.encoderB
        decoderB = self.decoderB

        interpolationA = torch.tensor(np.linspace(-3, 3, self.zS_dim))

        print('------------ traverse interpolation ------------')
        print('interpolationA: ', np.min(np.array(z_A)), np.max(np.array(z_A)))
        print('interpolationB: ', np.min(np.array(z_B)), np.max(np.array(z_B)))
        print('interpolationS: ', np.min(np.array(z_S)), np.max(np.array(z_S)))

        if self.record_file:
            ####
            fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

            fixed_XA = [0] * len(fixed_idxs)
            fixed_XB = [0] * len(fixed_idxs)

            for i, idx in enumerate(fixed_idxs):
                fixed_XA[i], fixed_XB[i] = \
                    self.data_loader.dataset.__getitem__(idx)[0:2]
                if self.use_cuda:
                    fixed_XA[i] = fixed_XA[i].cuda()
                    fixed_XB[i] = fixed_XB[i].cuda()
                fixed_XA[i] = fixed_XA[i].unsqueeze(0)
                fixed_XB[i] = fixed_XB[i].unsqueeze(0)

            fixed_XA = torch.cat(fixed_XA, dim=0)
            fixed_XB = torch.cat(fixed_XB, dim=0)

            fixed_zmuA, _, _, cate_prob_infA = encoderA(fixed_XA)

            # zB, zS = encB(xB)
            fixed_zmuB, _, _, cate_prob_infB = encoderB(fixed_XB)

            # shared code taken from modality B's posterior (deterministic)
            fixed_zS = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

            saving_shape = torch.cat([fixed_XA[i] for i in range(fixed_XA.shape[0])], dim=1).shape

        ####
        WS = torch.ones(saving_shape)
        if self.use_cuda:
            WS = WS.cuda()

        # do traversal and collect generated images
        gifs = []
        zA_ori, zB_ori, zS_ori = fixed_zmuA, fixed_zmuB, fixed_zS

        tempB = []  # zA_dim + zS_dim, num_trv, 1, 32*num_samples, 32
        for row in range(self.zB_dim):
            if loc != -1 and row != loc:
                continue
            zB = zB_ori.clone()

            temp = []
            for val in interpolationA:
                zB[:, row] = val
                sampleB = torch.sigmoid(decoderB(zB, zS_ori)).data
                temp.append((torch.cat([sampleB[i] for i in range(sampleB.shape[0])], dim=1)).unsqueeze(0))

            tempB.append(torch.cat(temp, dim=0).unsqueeze(0))  # num_trv, 1, 32*num_samples, 32

        # traverse the shared discrete code: one-hot per class
        temp = []
        for i in range(self.zS_dim):
            zS = np.zeros((1, self.zS_dim))
            zS[0, i % self.zS_dim] = 1.
            zS = torch.Tensor(zS)
            zS = torch.cat([zS] * len(fixed_idxs), dim=0)
            if self.use_cuda:
                zS = zS.cuda()
            sampleB = torch.sigmoid(decoderB(zB_ori, zS)).data
            temp.append((torch.cat([sampleB[i] for i in range(sampleB.shape[0])], dim=1)).unsqueeze(0))
        tempB.append(torch.cat(temp, dim=0).unsqueeze(0))

        gifs = torch.cat(tempB, dim=0)  # torch.Size([11, 10, 1, 384, 32])

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters), 'train')
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)

        for j, val in enumerate(interpolationA):
            I = gifs[:, j]
            save_image(
                tensor=I.cpu(),
                filename=os.path.join(out_dir, '%03d.jpg' % (j)),
                nrow=1 + self.zA_dim + 1 + 1 + 1 + self.zB_dim,
                pad_value=1)

        # make animated gif
        grid2gif2(
            out_dir, str(os.path.join(out_dir, 'fmnist_traverse' + '.gif')),
            delay=10
        )

        self.set_mode(train=True)

    ####
    def viz_init(self):
        """Close any stale visdom line windows for this run."""
        self.viz.close(env=self.name + '/lines', win=self.win_id['recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['capa'])

    ####
    def visualize_line(self):
        """Push the gathered loss curves to the visdom line windows."""
        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon_both = torch.Tensor(data['recon_both'])
        recon_A = torch.Tensor(data['recon_A'])
        recon_B = torch.Tensor(data['recon_B'])
        kl_A = torch.Tensor(data['kl_A'])
        kl_B = torch.Tensor(data['kl_B'])
        cont_capacity_loss_infA = torch.Tensor(data['cont_capacity_loss_infA'])
        disc_capacity_loss_infA = torch.Tensor(data['disc_capacity_loss_infA'])
        cont_capacity_loss_infB = torch.Tensor(data['cont_capacity_loss_infB'])
        disc_capacity_loss_infB = torch.Tensor(data['disc_capacity_loss_infB'])

        recons = torch.stack(
            [recon_both.detach(), recon_A.detach(), recon_B.detach()], -1
        )
        kls = torch.stack(
            [kl_A.detach(), kl_B.detach()], -1
        )
        each_capa = torch.stack(
            [cont_capacity_loss_infA.detach(), disc_capacity_loss_infA.detach(),
             cont_capacity_loss_infB.detach(), disc_capacity_loss_infB.detach()], -1
        )

        self.viz.line(
            X=iters, Y=recons, env=self.name + '/lines',
            win=self.win_id['recon'], update='append',
            opts=dict(xlabel='iter', ylabel='recon losses',
                      title='Recon Losses', legend=['both', 'A', 'B'])
        )

        self.viz.line(
            X=iters, Y=kls, env=self.name + '/lines',
            win=self.win_id['kl'], update='append',
            opts=dict(xlabel='iter', ylabel='kl losses',
                      title='KL Losses', legend=['A', 'B']),
        )

        self.viz.line(
            X=iters, Y=each_capa, env=self.name + '/lines',
            win=self.win_id['capa'], update='append',
            opts=dict(xlabel='iter', ylabel='logalpha',
                      title='Capacity loss',
                      legend=['cont_capaA', 'disc_capaA', 'cont_capaB', 'disc_capaB']),
        )

    ####
    def visualize_line_metrics(self, iters, metric1, metric2):
        """Append the two disentanglement metrics to the visdom metrics window.

        NOTE(review): self.win_id['metrics'] is never registered in
        __init__ (the registration is commented out there) — confirm the
        window exists before calling.
        """
        # prepare data to plot
        iters = torch.tensor([iters], dtype=torch.int64).detach()
        metric1 = torch.tensor([metric1])
        metric2 = torch.tensor([metric2])
        metrics = torch.stack([metric1.detach(), metric2.detach()], -1)

        self.viz.line(
            X=iters, Y=metrics, env=self.name + '/lines',
            win=self.win_id['metrics'], update='append',
            opts=dict(xlabel='iter', ylabel='metrics',
                      title='Disentanglement metrics',
                      legend=['metric1', 'metric2'])
        )

    def set_mode(self, train=True):
        """Switch all four networks between train() and eval() mode."""
        if train:
            self.encoderA.train()
            self.encoderB.train()
            self.decoderA.train()
            self.decoderB.train()
        else:
            self.encoderA.eval()
            self.encoderB.eval()
            self.decoderA.eval()
            self.decoderB.eval()

    ####
    def save_checkpoint(self, iteration):
        """Save the four networks as whole pickled modules under ckpt_dir."""
        encoderA_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderA.pt' % iteration
        )
        encoderB_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderB.pt' % iteration
        )
        decoderA_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderA.pt' % iteration
        )
        decoderB_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderB.pt' % iteration
        )

        mkdirs(self.ckpt_dir)

        # NOTE: saves entire module objects (not state_dicts) — ties the
        # checkpoint to the current class/module layout
        torch.save(self.encoderA, encoderA_path)
        torch.save(self.encoderB, encoderB_path)
        torch.save(self.decoderA, decoderA_path)
        torch.save(self.decoderB, decoderB_path)

    ####
    def load_checkpoint(self):
        """Load the four networks saved by save_checkpoint at ckpt_load_iter.

        SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        encoderA_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderA.pt' % self.ckpt_load_iter
        )
        encoderB_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderB.pt' % self.ckpt_load_iter
        )
        decoderA_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderA.pt' % self.ckpt_load_iter
        )
        decoderB_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderB.pt' % self.ckpt_load_iter
        )

        if self.use_cuda:
            self.encoderA = torch.load(encoderA_path)
            self.encoderB = torch.load(encoderB_path)
            self.decoderA = torch.load(decoderA_path)
            self.decoderB = torch.load(decoderB_path)
        else:
            self.encoderA = torch.load(encoderA_path, map_location='cpu')
            self.encoderB = torch.load(encoderB_path, map_location='cpu')
            self.decoderA = torch.load(decoderA_path, map_location='cpu')
            self.decoderB = torch.load(decoderB_path, map_location='cpu')
class Trainer(object):
    """WAE-GAN trainer: an auto-encoder (self.net) whose aggregate latent
    posterior is pushed toward an isotropic Gaussian prior by an adversary
    (self.D), with visdom logging and pickled-state checkpointing.

    NOTE(review): several calls here target old PyTorch (`F.sigmoid`,
    `size_average=False`, `tensor.data[0]`); on modern PyTorch these are
    deprecated or raise — confirm the pinned torch version before upgrading.
    """

    def __init__(self, args):
        # Device / loop bookkeeping.
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_epoch = args.max_epoch
        self.global_epoch = 0
        self.global_iter = 0

        # Latent prior N(0, z_var * I); z_sigma is the per-dimension std.
        self.z_dim = args.z_dim
        self.z_var = args.z_var
        self.z_sigma = math.sqrt(args.z_var)
        self.prior_dist = torch.distributions.Normal(
            torch.zeros(self.z_dim),
            torch.ones(self.z_dim) * self.z_sigma)
        self._lambda = args.reg_weight  # weight of the adversarial D/Q terms
        self.lr = args.lr
        self.lr_D = args.lr_D
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        # epoch -> divisor consumed by multistep_lr_decay
        self.lr_schedules = {30: 2, 50: 5, 100: 10}

        if args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            self.nc = 1
            self.decoder_dist = 'gaussian'
            # raise NotImplementedError

        # Auto-encoder + latent discriminator and their Adam optimizers.
        self.net = cuda(WAE(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))
        self.D = cuda(Adversary(self.z_dim), self.use_cuda)
        self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                                  betas=(self.beta1, self.beta2))

        self.gather = DataGather()

        # Visdom windows start as None and are created lazily by viz_lines().
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name + '_lines',
                                     port=self.viz_port)
            self.win_recon = None
            self.win_QD = None
            self.win_D = None
            self.win_mu = None
            self.win_var = None
        else:
            self.viz = None
            self.win_recon = None
            self.win_QD = None
            self.win_D = None
            self.win_mu = None
            self.win_var = None

        # Checkpoint directory; resume if a checkpoint name was given.
        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        # Output directory for saved samples/reconstructions.
        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # Data.
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

    def train(self):
        """Alternate discriminator and auto-encoder updates until max_iter.

        Per iteration: sample z from the prior, get z_tilde from the encoder,
        train D to tell them apart (logit shifted by log p(z)), then train the
        auto-encoder on reconstruction + lambda * fooling loss.
        """
        self.net.train()
        # Target labels for the logistic adversarial losses.
        ones = Variable(cuda(torch.ones(self.batch_size, 1), self.use_cuda))
        zeros = Variable(cuda(torch.zeros(self.batch_size, 1), self.use_cuda))

        iters_per_epoch = len(self.data_loader)
        max_iter = self.max_epoch * iters_per_epoch
        # NOTE(review): this first bar is immediately shadowed by the `with`
        # bar below and never closed — looks like leftover code.
        pbar = tqdm(total=max_iter)
        with tqdm(total=max_iter) as pbar:
            pbar.update(self.global_iter)
            out = False
            while not out:
                for x in self.data_loader:
                    #x,label = x
                    pbar.update(1)
                    self.global_iter += 1
                    if self.global_iter % iters_per_epoch == 0:
                        self.global_epoch += 1
                        # Step-decay the AE learning rate on schedule epochs.
                        self.optim = multistep_lr_decay(self.optim,
                                                        self.global_epoch,
                                                        self.lr_schedules)

                    x = Variable(cuda(x, self.use_cuda))
                    x_recon, z_tilde = self.net(x)
                    # z ~ prior, shaped like the encoder output.
                    z = self.sample_z(template=z_tilde, sigma=self.z_sigma)
                    log_p_z = log_density_igaussian(z, self.z_var).view(-1, 1)

                    #D_z = self.D(z) + log_p_z.view(-1, 1)
                    #D_z_tilde = self.D(z_tilde) + log_p_z.view(-1, 1)
                    D_z = self.D(z)
                    D_z_tilde = self.D(z_tilde)

                    # Discriminator: prior samples -> 1, encoded samples -> 0.
                    D_loss = F.binary_cross_entropy_with_logits(D_z + log_p_z, ones) + \
                        F.binary_cross_entropy_with_logits(D_z_tilde + log_p_z, zeros)
                    total_D_loss = self._lambda * D_loss
                    self.optim_D.zero_grad()
                    # retain_graph: z_tilde's graph is reused by the AE loss below.
                    total_D_loss.backward(retain_graph=True)
                    self.optim_D.step()

                    # Auto-encoder: reconstruction + fool-the-discriminator term.
                    # NOTE(review): size_average is deprecated; reduction='sum'
                    # is the modern equivalent.
                    recon_loss = F.mse_loss(
                        x_recon, x, size_average=False).div(self.batch_size)
                    Q_loss = F.binary_cross_entropy_with_logits(
                        D_z_tilde + log_p_z, ones)
                    total_AE_loss = recon_loss + self._lambda * Q_loss
                    self.optim.zero_grad()
                    total_AE_loss.backward()
                    self.optim.step()

                    if self.global_iter % 10 == 0:
                        self.gather.insert(
                            iter=self.global_iter,
                            D_z=F.sigmoid(D_z).mean().detach().data,
                            D_z_tilde=F.sigmoid(D_z_tilde).mean().detach().data,
                            mu=z.mean(0).data,
                            var=z.var(0).data,
                            recon_loss=recon_loss.data,
                            Q_loss=Q_loss.data,
                            D_loss=D_loss.data)

                    if self.global_iter % 50 == 0:
                        self.save_reconstruction()
                        if self.viz:
                            self.gather.insert(images=x.data)
                            self.gather.insert(images=x_recon.data)
                            self.viz_reconstruction()
                            self.viz_lines()
                            self.sample_x_from_z(n_sample=100)
                        self.gather.flush()
                        self.save_checkpoint('last')
                        # NOTE(review): `.data[0]` indexing errors on modern
                        # PyTorch; `.item()` is the replacement.
                        pbar.write(
                            '[{}] recon_loss:{:.3f} Q_loss:{:.3f} D_loss:{:.3f}'
                            .format(self.global_iter, recon_loss.data[0],
                                    Q_loss.data[0], D_loss.data[0]))
                        pbar.write('D_z:{:.3f} D_z_tilde:{:.3f}'.format(
                            F.sigmoid(D_z).mean().detach().data[0],
                            F.sigmoid(D_z_tilde).mean().detach().data[0]))

                    if self.global_iter % 2000 == 0:
                        self.save_checkpoint(str(self.global_iter))

                    if self.global_iter >= max_iter:
                        out = True
                        break

        pbar.write("[Training Finished]")

    def viz_reconstruction(self):
        """Show the last gathered input/reconstruction pair grids in visdom."""
        self.net.eval()
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True, nrow=10)
        x_recon = F.sigmoid(self.gather.data['images'][1][:100])
        x_recon = make_grid(x_recon, normalize=True, nrow=10)
        images = torch.stack([x, x_recon], dim=0).cpu()
        if self.viz:
            self.viz.images(images, env=self.viz_name + '_reconstruction',
                            opts=dict(title=str(self.global_iter)), nrow=2)
        self.net.train()

    def save_reconstruction(self):
        """Dump 5 inputs + reconstructions from the first batch to
        'reconstruction.npy' (cwd)."""
        self.net.eval()
        import numpy as np
        for item in self.data_loader:
            x = Variable(cuda(item, self.use_cuda))
            x_recon, z_tilde = self.net(x)
            x_recon = x_recon.data[:5]
            x = x.data[:5]
            #x_grid = make_grid(x, normalize=True, nrow=10)
            #x_recon = F.sigmoid(x_recon)
            #x_grid_recon = make_grid(x_recon, normalize=True, nrow=10)
            #images = torch.stack([x_grid, x_grid_recon], dim=0).cpu()
            images = torch.stack([x, x_recon], dim=0).cpu()
            np.save('reconstruction.npy', images.numpy())
            break  # first batch only
        self.net.train()

    def viz_lines(self):
        """Append all gathered scalar curves to visdom (creating each window
        on first use, then updating with update='append')."""
        self.net.eval()
        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        Q_losses = torch.stack(self.gather.data['Q_loss']).cpu()
        D_losses = torch.stack(self.gather.data['D_loss']).cpu()
        QD_losses = torch.cat([Q_losses, D_losses], 1)
        D_zs = torch.stack(self.gather.data['D_z']).cpu()
        D_z_tildes = torch.stack(self.gather.data['D_z_tilde']).cpu()
        Ds = torch.cat([D_zs, D_z_tildes], 1)
        mus = torch.stack(self.gather.data['mu']).cpu()
        vars = torch.stack(self.gather.data['var']).cpu()
        iters = torch.Tensor(self.gather.data['iter'])
        legend_z = []
        for z_j in range(self.z_dim):
            legend_z.append('z_{}'.format(z_j))
        legend_QD = ['Q_loss', 'D_loss']
        legend_D = ['D(z)', 'D(z_tilde)']

        if self.win_recon is None:
            self.win_recon = self.viz.line(X=iters, Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           opts=dict(
                                               width=400, height=400,
                                               xlabel='iteration',
                                               title='reconsturction loss',
                                           ))
        else:
            self.win_recon = self.viz.line(X=iters, Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           win=self.win_recon,
                                           update='append',
                                           opts=dict(
                                               width=400, height=400,
                                               xlabel='iteration',
                                               title='reconsturction loss',
                                           ))

        if self.win_QD is None:
            self.win_QD = self.viz.line(X=iters, Y=QD_losses,
                                        env=self.viz_name + '_lines',
                                        opts=dict(
                                            width=400, height=400,
                                            legend=legend_QD,
                                            xlabel='iteration',
                                            title='Q&D Losses',
                                        ))
        else:
            self.win_QD = self.viz.line(X=iters, Y=QD_losses,
                                        env=self.viz_name + '_lines',
                                        win=self.win_QD,
                                        update='append',
                                        opts=dict(
                                            width=400, height=400,
                                            legend=legend_QD,
                                            xlabel='iteration',
                                            title='Q&D Losses',
                                        ))

        if self.win_D is None:
            self.win_D = self.viz.line(X=iters, Y=Ds,
                                       env=self.viz_name + '_lines',
                                       opts=dict(
                                           width=400, height=400,
                                           legend=legend_D,
                                           xlabel='iteration',
                                           title='D(.)',
                                       ))
        else:
            self.win_D = self.viz.line(X=iters, Y=Ds,
                                       env=self.viz_name + '_lines',
                                       win=self.win_D,
                                       update='append',
                                       opts=dict(
                                           width=400, height=400,
                                           legend=legend_D,
                                           xlabel='iteration',
                                           title='D(.)',
                                       ))

        if self.win_mu is None:
            self.win_mu = self.viz.line(X=iters, Y=mus,
                                        env=self.viz_name + '_lines',
                                        opts=dict(
                                            width=400, height=400,
                                            legend=legend_z,
                                            xlabel='iteration',
                                            title='posterior mean',
                                        ))
        else:
            # NOTE(review): BUG — the append branch plots `vars` into the
            # 'posterior mean' window; `Y=mus` looks intended (compare the
            # create branch above and the win_var pair below).
            self.win_mu = self.viz.line(X=iters, Y=vars,
                                        env=self.viz_name + '_lines',
                                        win=self.win_mu,
                                        update='append',
                                        opts=dict(
                                            width=400, height=400,
                                            legend=legend_z,
                                            xlabel='iteration',
                                            title='posterior mean',
                                        ))

        if self.win_var is None:
            self.win_var = self.viz.line(X=iters, Y=vars,
                                         env=self.viz_name + '_lines',
                                         opts=dict(
                                             width=400, height=400,
                                             legend=legend_z,
                                             xlabel='iteration',
                                             title='posterior variance',
                                         ))
        else:
            self.win_var = self.viz.line(X=iters, Y=vars,
                                         env=self.viz_name + '_lines',
                                         win=self.win_var,
                                         update='append',
                                         opts=dict(
                                             width=400, height=400,
                                             legend=legend_z,
                                             xlabel='iteration',
                                             title='posterior variance',
                                         ))
        self.net.train()

    def sample_z(self, n_sample=None, dim=None, sigma=None, template=None):
        """Draw z ~ N(0, sigma^2 I); with `template`, match its shape/device
        instead of using (n_sample, dim)."""
        if n_sample is None:
            n_sample = self.batch_size
        if dim is None:
            dim = self.z_dim
        if sigma is None:
            sigma = self.z_sigma
        if template is not None:
            z = sigma * Variable(template.data.new(template.size()).normal_())
        else:
            z = sigma * torch.randn(n_sample, dim)
            z = Variable(cuda(z, self.use_cuda))
        return z

    def sample_x_from_z(self, n_sample):
        """Decode prior samples and show the image grid in visdom."""
        self.net.eval()
        z = self.sample_z(n_sample=n_sample, sigma=self.z_sigma)
        x_gen = F.sigmoid(self.net._decode(z)[:100]).data.cpu()
        x_gen = make_grid(x_gen, normalize=True, nrow=10)
        self.viz.images(x_gen, env=self.viz_name + '_sampling_from_random_z',
                        opts=dict(title=str(self.global_iter)))
        self.net.train()

    def save_checkpoint(self, filename, silent=True):
        """Save state_dicts, optimizer states, visdom window ids and loop
        counters to ckpt_dir/filename."""
        model_states = {
            'net': self.net.state_dict(),
            'D': self.D.state_dict(),
        }
        optim_states = {
            'optim': self.optim.state_dict(),
            'optim_D': self.optim_D.state_dict()
        }
        win_states = {
            'recon': self.win_recon,
            'QD': self.win_QD,
            'D': self.win_D,
            'mu': self.win_mu,
            'var': self.win_var,
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'win_states': win_states,
            'model_states': model_states,
            'optim_states': optim_states
        }
        file_path = self.ckpt_dir.joinpath(filename)
        torch.save(states, file_path.open('wb+'))
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(
                file_path, self.global_iter))

    def load_checkpoint(self, filename, silent=False):
        """Restore everything saved by save_checkpoint; no-op (with message)
        when the file does not exist."""
        file_path = self.ckpt_dir.joinpath(filename)
        print(file_path)
        if file_path.is_file():
            checkpoint = torch.load(file_path.open('rb'))
            self.global_iter = checkpoint['iter']
            self.global_epoch = checkpoint['epoch']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_QD = checkpoint['win_states']['QD']
            self.win_D = checkpoint['win_states']['D']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            if not silent:
                print("=> loaded checkpoint '{} (iter {})'".format(
                    file_path, self.global_iter))
        else:
            if not silent:
                print("=> no checkpoint found at '{}'".format(file_path))
class Solver(object):
    """Trajectory-forecasting CVAE solver: encoder on observed tracks
    (encoderMx), posterior encoder on future tracks (encoderMy), and a
    goal-conditioned decoder (decoderMy), trained with an ELBO plus an
    optional pairwise collision penalty; ADE/FDE evaluation and visdom
    logging included.
    """

    ####
    def __init__(self, args):
        self.args = args
        args.num_sg = args.load_e
        # Experiment name encodes the hyper-parameters; '_run_<id>' appended below.
        self.name = '%s_bs%s_zD_%s_dr_mlp_%s_dr_rnn_%s_enc_hD_%s_dec_hD_%s_mlpD_%s_lr_%s_klw_%s_ll_prior_w_%s_zfb_%s_scale_%s_num_sg_%s' \
                    'ctxtD_%s_coll_th_%s_w_coll_%s_beta_%s_lr_e_%s_k_%s' % \
                    (args.dataset_name, args.batch_size, args.zS_dim,
                     args.dropout_mlp, args.dropout_rnn, args.encoder_h_dim,
                     args.decoder_h_dim, args.mlp_dim, args.lr_VAE,
                     args.kl_weight, args.ll_prior_w, args.fb, args.scale,
                     args.num_sg, args.context_dim, args.coll_th, args.w_coll,
                     args.beta, args.lr_e, args.k_fold)
        # to be appended by run_id

        # self.use_cuda = args.cuda and torch.cuda.is_available()
        self.device = args.device
        self.temp = 1.99
        self.dt = 0.4  # timestep (seconds) used when integrating velocities
        self.eps = 1e-9
        self.ll_prior_w = args.ll_prior_w
        # Sub-goal time indices: num_sg evenly spaced steps out of the 12
        # future steps (flipped so the final step is included).
        self.sg_idx = np.array(range(12))
        self.sg_idx = np.flip(11 - self.sg_idx[::(12 // args.num_sg)])
        self.coll_th = args.coll_th  # distance threshold of the collision sigmoid
        self.beta = args.beta  # sharpness of the collision sigmoid
        self.context_dim = args.context_dim
        self.w_coll = args.w_coll  # weight of the collision penalty (0 disables it)
        self.z_fb = args.fb  # KL free-bits floor
        self.scale = args.scale

        self.kl_weight = args.kl_weight
        self.lg_kl_weight = args.lg_kl_weight

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        args.dataset_dir = os.path.join(args.dataset_dir, str(args.k_fold))
        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zS_dim = args.zS_dim
        self.w_dim = args.w_dim
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(recon='win_recon', loss_kl='win_loss_kl',
                               loss_recon='win_loss_recon',
                               ade_min='win_ade_min', fde_min='win_fde_min',
                               ade_avg='win_ade_avg', fde_avg='win_fde_avg',
                               ade_std='win_ade_std', fde_std='win_fde_std',
                               test_loss_recon='win_test_loss_recon',
                               test_loss_kl='win_test_loss_kl',
                               loss_recon_prior='win_loss_recon_prior',
                               loss_coll='win_loss_coll',
                               test_loss_coll='win_test_loss_coll',
                               test_total_coll='win_test_total_coll',
                               total_coll='win_total_coll')
            self.line_gather = DataGather(
                'iter', 'loss_recon', 'loss_kl', 'loss_recon_prior',
                'ade_min', 'fde_min', 'ade_avg', 'fde_avg',
                'ade_std', 'fde_std', 'test_loss_recon', 'test_loss_kl',
                'test_loss_coll', 'loss_coll', 'test_total_coll', 'total_coll')
            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port, env=self.name)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_init()

        #### create a new model or load a previously saved model
        self.ckpt_load_iter = args.ckpt_load_iter
        self.obs_len = 8
        self.pred_len = 12
        self.num_layers = args.num_layers
        self.decoder_h_dim = args.decoder_h_dim

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':
            # create a new model; also load a pretrained long-term-goal CVAE
            # (hard-coded checkpoint path) on CUDA.
            lg_cvae_path = 'large.lgcvae_enc_block_1_fcomb_block_2_wD_10_lr_0.0001_lg_klw_1.0_a_0.25_r_2.0_fb_5.0_anneal_e_10_load_e_3_run_4'
            lg_cvae_path = os.path.join('ckpts', lg_cvae_path, 'iter_150_lg_cvae.pt')
            # NOTE(review): on CPU self.lg_cvae is never set — temmp() would
            # fail with AttributeError there.
            if self.device == 'cuda':
                self.lg_cvae = torch.load(lg_cvae_path)
            self.encoderMx = EncoderX(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      map_mlp_dim=args.map_mlp_dim,
                                      map_feat_dim=args.map_feat_dim,
                                      num_layers=args.num_layers,
                                      dropout_mlp=args.dropout_mlp,
                                      dropout_rnn=args.dropout_rnn,
                                      device=self.device).to(self.device)
            self.encoderMy = EncoderY(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      num_layers=args.num_layers,
                                      dropout_mlp=args.dropout_mlp,
                                      dropout_rnn=args.dropout_rnn,
                                      device=self.device).to(self.device)
            self.decoderMy = Decoder(args.pred_len,
                                     dec_h_dim=self.decoder_h_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     z_dim=args.zS_dim,
                                     num_layers=args.num_layers,
                                     device=args.device,
                                     dropout_rnn=args.dropout_rnn,
                                     scale=args.scale,
                                     dt=self.dt,
                                     context_dim=args.context_dim).to(self.device)
        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderMx.parameters()) + \
            list(self.encoderMy.parameters()) + \
            list(self.decoderMy.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params, lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])
        # Exponential decay: lr multiplier is lr_e ** epoch.
        self.scheduler = optim.lr_scheduler.LambdaLR(
            optimizer=self.optim_vae,
            lr_lambda=lambda epoch: args.lr_e ** epoch)

        print('Start loading data...')
        if self.ckpt_load_iter != self.max_iter:
            print("Initializing train dataset")
            _, self.train_loader = data_loader(self.args, args.dataset_dir,
                                               'train', shuffle=True)
            print("Initializing val dataset")
            _, self.val_loader = data_loader(self.args, args.dataset_dir,
                                             'val', shuffle=True)
            print('There are {} iterations per epoch'.format(
                len(self.train_loader.dataset) / args.batch_size))
        print('...done')

    def temmp(self):
        # Scratch helper: push a dummy map batch through the LG-CVAE UNet
        # encoder (presumably for shape/warm-up checks — TODO confirm).
        aa = torch.zeros((120, 2, 256, 256)).to(self.device)
        self.lg_cvae.unet.down_forward(aa)

    ####
    def train(self):
        """Main optimization loop: iterate batches up to max_iter, maximizing
        the ELBO (teacher-forced posterior + weighted prior log-likelihood,
        free-bits-clamped KL) plus w_coll * pairwise collision penalty."""
        self.set_mode(train=True)
        data_loader = self.train_loader
        self.N = len(data_loader.dataset)
        iterator = iter(data_loader)
        iter_per_epoch = len(iterator)
        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch) + 1
        e_coll_loss = 0
        e_total_coll = 0
        for iteration in range(start_iter, self.max_iter + 1):
            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                # print(iteration)
                print('==== epoch %d done ====' % epoch)
                if epoch % 10 == 0:
                    # Decay lr every 10 epochs until it reaches the 5e-4 floor.
                    if self.optim_vae.param_groups[0]['lr'] > 5e-4:
                        self.scheduler.step()
                    else:
                        self.optim_vae.param_groups[0]['lr'] = 5e-4
                print("lr: ", self.optim_vae.param_groups[0]['lr'],
                      ' // w_coll: ', self.w_coll)
                print('e_coll_loss: ', e_coll_loss,
                      ' // e_total_coll: ', e_total_coll)
                epoch += 1
                iterator = iter(data_loader)
                # Keep last epoch's collision stats for logging below.
                # NOTE(review): prev_e_coll_loss/prev_e_total_coll are only
                # bound here — the logging block can hit a NameError if it
                # runs before the first epoch boundary.
                prev_e_coll_loss = e_coll_loss
                prev_e_total_coll = e_total_coll
                e_coll_loss = 0
                e_total_coll = 0

            # ============================================
            # TRAIN THE VAE (ENC & DEC)
            # ============================================
            (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
             obs_frames, fut_frames, map_path, inv_h_t,
             local_map, local_ic, local_homo) = next(iterator)
            batch_size = obs_traj.size(1)  # =sum(seq_start_end[:,1] - seq_start_end[:,0])

            #-------- trajectories --------
            (hx, mux, log_varx) \
                = self.encoderMx(obs_traj_st, seq_start_end, train=True)
            (muy, log_vary) \
                = self.encoderMy(obs_traj_st[-1], fut_vel_st, seq_start_end,
                                 hx, train=True)
            # std = sqrt(exp(log_var)), clamped away from zero.
            p_dist = Normal(
                mux, torch.clamp(torch.sqrt(torch.exp(log_varx)), min=1e-8))
            q_dist = Normal(
                muy, torch.clamp(torch.sqrt(torch.exp(log_vary)), min=1e-8))

            # TF, goals, z~posterior
            fut_rel_pos_dist_tf_post = self.decoderMy(
                seq_start_end,
                obs_traj_st[-1],
                obs_traj[-1, :, :2],
                hx,
                q_dist.rsample(),
                fut_traj[list(self.sg_idx), :, :2].permute(1, 0, 2),  # goal
                self.sg_idx,
                fut_vel_st,  # TF
                train=True)

            # NO TF, predicted goals, z~prior
            fut_rel_pos_dist_prior = self.decoderMy(
                seq_start_end,
                obs_traj_st[-1],
                obs_traj[-1, :, :2],
                hx,
                p_dist.rsample(),
                fut_traj[list(self.sg_idx), :, :2].permute(1, 0, 2),  # goal
                self.sg_idx,
                train=True)

            ll_tf_post = fut_rel_pos_dist_tf_post.log_prob(
                fut_vel_st).sum().div(batch_size)
            ll_prior = fut_rel_pos_dist_prior.log_prob(fut_vel_st).sum().div(
                batch_size)

            loss_kl = kl_divergence(q_dist, p_dist)
            # Free bits: per-dim KL floored at z_fb before summing.
            loss_kl = torch.clamp(loss_kl, min=self.z_fb).sum().div(batch_size)
            # print('log_likelihood:', loglikelihood.item(), ' kl:', loss_kl.item())
            loglikelihood = ll_tf_post + self.ll_prior_w * ll_prior
            traj_elbo = loglikelihood - self.kl_weight * loss_kl

            # Soft collision penalty between every agent pair in each scene,
            # on both prior- and posterior-decoded trajectories.
            coll_loss = torch.tensor(0.0).to(self.device)
            total_coll = 0
            n_scene = 0
            if self.w_coll > 0:
                pred_fut_traj = integrate_samples(
                    fut_rel_pos_dist_prior.rsample() * self.scale,
                    obs_traj[-1, :, :2], dt=self.dt)
                pred_fut_traj_post = integrate_samples(
                    fut_rel_pos_dist_tf_post.rsample() * self.scale,
                    obs_traj[-1, :, :2], dt=self.dt)
                for s, e in seq_start_end:
                    n_scene += 1
                    num_ped = e - s
                    if num_ped == 1:
                        continue
                    for t in range(self.pred_len):
                        ## prior
                        # All ordered pairs via tile/repeat; dist>0 drops self-pairs,
                        # so each unordered pair is counted twice (hence the /2).
                        curr1 = pred_fut_traj[t, s:e].repeat(num_ped, 1)
                        curr2 = self.repeat(pred_fut_traj[t, s:e], num_ped)
                        dist = torch.norm(curr1 - curr2, dim=1)
                        dist = dist.reshape(num_ped, num_ped)
                        diff_agent_dist = dist[torch.where(dist > 0)]
                        coll_loss += (torch.sigmoid(
                            -(diff_agent_dist - self.coll_th) * self.beta)).sum()
                        total_coll += (
                            len(torch.where(diff_agent_dist < 0.5)[0]) / 2)
                        ## posterior
                        curr1_post = pred_fut_traj_post[t, s:e].repeat(
                            num_ped, 1)
                        curr2_post = self.repeat(pred_fut_traj_post[t, s:e],
                                                 num_ped)
                        dist_post = torch.norm(curr1_post - curr2_post, dim=1)
                        dist_post = dist_post.reshape(num_ped, num_ped)
                        diff_agent_dist_post = dist_post[torch.where(
                            dist_post > 0)]
                        coll_loss += (torch.sigmoid(
                            -(diff_agent_dist_post - self.coll_th) * self.beta)).sum()
                        total_coll += (
                            len(torch.where(diff_agent_dist_post < 0.5)[0]) / 2)
                coll_loss = coll_loss.div(batch_size)
                total_coll = total_coll / batch_size

            loss = -traj_elbo + self.w_coll * coll_loss
            e_coll_loss += coll_loss.item()
            e_total_coll += total_coll

            self.optim_vae.zero_grad()
            loss.backward()
            self.optim_vae.step()

            # save model parameters
            if epoch > 100 and (iteration % (iter_per_epoch * 20) == 0):
                self.save_checkpoint(epoch)

            # (visdom) insert current line stats
            if epoch > 100:
                if iteration == iter_per_epoch or (
                        self.viz_on and
                        (iteration % (iter_per_epoch * 20) == 0)):
                    ade_min, fde_min, \
                    ade_avg, fde_avg, \
                    ade_std, fde_std, \
                    test_loss_recon, test_loss_kl, test_loss_coll, test_total_coll = self.evaluate_dist(self.val_loader, loss=True)
                    self.line_gather.insert(
                        iter=epoch,
                        ade_min=ade_min, fde_min=fde_min,
                        ade_avg=ade_avg, fde_avg=fde_avg,
                        ade_std=ade_std, fde_std=fde_std,
                        loss_recon=-ll_tf_post.item(),
                        loss_recon_prior=-ll_prior.item(),
                        loss_kl=loss_kl.item(),
                        loss_coll=prev_e_coll_loss,
                        total_coll=prev_e_total_coll,
                        test_loss_recon=test_loss_recon.item(),
                        test_loss_kl=test_loss_kl.item(),
                        test_loss_coll=test_loss_coll.item(),
                        test_total_coll=test_total_coll)
                    prn_str = ('[iter_%d (epoch_%d)] vae_loss: %.3f ' + \
                               '(recon: %.3f, kl: %.3f)\n' + \
                               'ADE min: %.2f, FDE min: %.2f, ADE avg: %.2f, FDE avg: %.2f\n'
                               ) % \
                              (iteration, epoch, loss.item(),
                               -loglikelihood.item(), loss_kl.item(),
                               ade_min, fde_min, ade_avg, fde_avg
                               )
                    print(prn_str)
                    self.visualize_line()
                    self.line_gather.flush()

    def repeat(self, tensor, num_reps):
        """
        Inputs:
        -tensor: 2D tensor of any shape
        -num_reps: Number of times to repeat each row
        Outpus:
        -repeat_tensor: Repeat each row such that: R1, R1, R2, R2
        """
        col_len = tensor.size(1)
        tensor = tensor.unsqueeze(dim=1).repeat(1, num_reps, 1)
        tensor = tensor.view(-1, col_len)
        return tensor

    def evaluate_dist(self, data_loader, loss=False):
        """Evaluate ADE/FDE over 4 prior samples per batch; with loss=True
        also accumulate recon/KL/collision statistics and return them.

        Returns (ade_min, fde_min, ade_avg, fde_avg, ade_std, fde_std) plus,
        when loss=True, (loss_recon/b, loss_kl/b, coll_loss/b, total_coll).
        """
        self.set_mode(train=False)
        total_traj = 0
        loss_recon = loss_kl = 0
        coll_loss = 0
        total_coll = 0
        n_scene = 0
        all_ade = []
        all_fde = []
        with torch.no_grad():
            b = 0
            for batch in data_loader:
                b += 1
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, fut_frames, map_path, inv_h_t,
                 local_map, local_ic, local_homo) = batch
                batch_size = fut_traj.size(1)
                total_traj += fut_traj.size(1)

                # -------- trajectories --------
                (hx, mux, log_varx) \
                    = self.encoderMx(obs_traj_st, seq_start_end)
                p_dist = Normal(
                    mux,
                    torch.clamp(torch.sqrt(torch.exp(log_varx)), min=1e-8))

                # Draw 4 decoded futures from the prior.
                fut_rel_pos_dist20 = []
                for _ in range(4):
                    # NO TF, pred_goals, z~prior
                    fut_rel_pos_dist_prior = self.decoderMy(
                        seq_start_end,
                        obs_traj_st[-1],
                        obs_traj[-1, :, :2],
                        hx,
                        p_dist.rsample(),
                        fut_traj[list(self.sg_idx), :, :2].permute(1, 0, 2),  # goal
                        self.sg_idx,
                    )
                    fut_rel_pos_dist20.append(fut_rel_pos_dist_prior)

                if loss:
                    (muy, log_vary) \
                        = self.encoderMy(obs_traj_st[-1], fut_vel_st,
                                         seq_start_end, hx, train=False)
                    q_dist = Normal(muy, torch.sqrt(torch.exp(log_vary)))
                    # Recon loss uses the last of the 4 prior samples only.
                    loss_recon -= fut_rel_pos_dist_prior.log_prob(
                        fut_vel_st).sum().div(batch_size)
                    kld = kl_divergence(q_dist, p_dist).sum().div(batch_size)
                    loss_kl += kld

                    pred_fut_traj = integrate_samples(
                        fut_rel_pos_dist_prior.rsample() * self.scale,
                        obs_traj[-1, :, :2], dt=self.dt)
                    for s, e in seq_start_end:
                        n_scene += 1
                        num_ped = e - s
                        if num_ped == 1:
                            continue
                        seq_traj = pred_fut_traj[:, s:e]
                        for i in range(len(seq_traj)):
                            curr1 = seq_traj[i].repeat(num_ped, 1)
                            curr2 = self.repeat(seq_traj[i], num_ped)
                            dist = torch.norm(curr1 - curr2, dim=1)
                            dist = dist.reshape(num_ped, num_ped)
                            diff_agent_dist = dist[torch.where(dist > 0)]
                            if len(diff_agent_dist) > 0:
                                # diff_agent_dist[torch.where(diff_agent_dist > self.coll_th)] += self.beta
                                coll_loss += (torch.sigmoid(
                                    -(diff_agent_dist - self.coll_th) *
                                    self.beta)).sum().div(batch_size)
                                total_coll += (len(
                                    torch.where(diff_agent_dist < 0.5)[0]) /
                                    2) / batch_size

                # ADE/FDE per prior sample.
                ade, fde = [], []
                for dist in fut_rel_pos_dist20:
                    pred_fut_traj = integrate_samples(dist.rsample() *
                                                      self.scale,
                                                      obs_traj[-1, :, :2],
                                                      dt=self.dt)
                    ade.append(
                        displacement_error(pred_fut_traj, fut_traj[:, :, :2],
                                           mode='raw'))
                    fde.append(
                        final_displacement_error(pred_fut_traj[-1],
                                                 fut_traj[-1, :, :2],
                                                 mode='raw'))
                all_ade.append(torch.stack(ade))
                all_fde.append(torch.stack(fde))

            # min/avg/std over the 4 samples, averaged over all trajectories.
            all_ade = torch.cat(all_ade, dim=1).cpu().numpy()
            all_fde = torch.cat(all_fde, dim=1).cpu().numpy()
            ade_min = np.min(all_ade, axis=0).mean() / self.pred_len
            fde_min = np.min(all_fde, axis=0).mean()
            ade_avg = np.mean(all_ade, axis=0).mean() / self.pred_len
            fde_avg = np.mean(all_fde, axis=0).mean()
            ade_std = np.std(all_ade, axis=0).mean() / self.pred_len
            fde_std = np.std(all_fde, axis=0).mean()
        self.set_mode(train=True)
        if loss:
            return ade_min, fde_min, \
                   ade_avg, fde_avg, \
                   ade_std, fde_std, \
                   loss_recon/b, loss_kl/b, coll_loss/b, total_coll
        else:
            return ade_min, fde_min, \
                   ade_avg, fde_avg, \
                   ade_std, fde_std,

    def collision_stat(self, data_loader):
        """Print ground-truth pairwise-distance statistics: collision counts
        at several thresholds, scene/agent counts, average/minimum distance."""
        self.set_mode(train=False)
        total_coll1 = 0
        total_coll2 = 0
        total_coll3 = 0
        total_coll4 = 0
        total_coll5 = 0
        total_coll6 = 0
        n_scene = 0
        total_ped = []
        e_ped = []
        avg_dist = 0
        min_dist = 10000
        n_agent = 0
        with torch.no_grad():
            b = 0
            for batch in data_loader:
                b += 1
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, fut_frames, map_path, inv_h_t,
                 local_map, local_ic, local_homo) = batch
                for s, e in seq_start_end:
                    n_scene += 1
                    num_ped = e - s
                    total_ped.append(num_ped)
                    if num_ped == 1:
                        continue
                    e_ped.append(num_ped)
                    seq_traj = fut_traj[:, s:e, :2]
                    for i in range(len(seq_traj)):
                        curr1 = seq_traj[i].repeat(num_ped, 1)
                        curr2 = self.repeat(seq_traj[i], num_ped)
                        dist = torch.sqrt(
                            torch.pow(curr1 - curr2, 2).sum(1)).cpu().numpy()
                        dist = dist.reshape(num_ped, num_ped)
                        # Upper triangle: each unordered pair exactly once.
                        diff_agent_idx = np.triu_indices(num_ped, k=1)
                        diff_agent_dist = dist[diff_agent_idx]
                        avg_dist += diff_agent_dist.sum()
                        min_dist = min(min_dist, diff_agent_dist.min())
                        n_agent += len(diff_agent_dist)
                        total_coll1 += (diff_agent_dist < 0.05).sum()
                        total_coll2 += (diff_agent_dist < 0.1).sum()
                        total_coll3 += (diff_agent_dist < 0.2).sum()
                        total_coll4 += (diff_agent_dist < 0.3).sum()
                        total_coll5 += (diff_agent_dist < 0.4).sum()
                        total_coll6 += (diff_agent_dist < 0.5).sum()

        print('total_coll1: ', total_coll1)
        print('total_coll2: ', total_coll2)
        print('total_coll3: ', total_coll3)
        print('total_coll4: ', total_coll4)
        print('total_coll5: ', total_coll5)
        print('total_coll6: ', total_coll6)
        print('n_scene: ', n_scene)
        print('e_ped:', len(e_ped))
        print('total_ped:', len(total_ped))
        print('avg_dist:', avg_dist / n_agent)
        # NOTE(review): BUG — prints avg_dist / n_agent under the 'min_dist'
        # label; `min_dist` looks intended.
        print('min_dist:', avg_dist / n_agent)
        print('e_ped:', np.array(e_ped).mean())
        print('total_ped:', np.array(total_ped).mean())

    ####
    def viz_init(self):
        """Close stale visdom windows for this experiment's environment."""
        self.viz.close(env=self.name, win=self.win_id['loss_recon'])
        self.viz.close(env=self.name, win=self.win_id['loss_recon_prior'])
        self.viz.close(env=self.name, win=self.win_id['loss_kl'])
        self.viz.close(env=self.name, win=self.win_id['test_loss_recon'])
        self.viz.close(env=self.name, win=self.win_id['test_loss_kl'])
        self.viz.close(env=self.name, win=self.win_id['ade_min'])
        self.viz.close(env=self.name, win=self.win_id['fde_min'])
        self.viz.close(env=self.name, win=self.win_id['ade_avg'])
        self.viz.close(env=self.name, win=self.win_id['fde_avg'])
        self.viz.close(env=self.name, win=self.win_id['ade_std'])
        self.viz.close(env=self.name, win=self.win_id['fde_std'])

    ####
    def visualize_line(self):
        """Append every gathered training/validation curve to its visdom window."""
        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        loss_recon = torch.Tensor(data['loss_recon'])
        loss_recon_prior = torch.Tensor(data['loss_recon_prior'])
        loss_kl = torch.Tensor(data['loss_kl'])
        ade_min = torch.Tensor(data['ade_min'])
        fde_min = torch.Tensor(data['fde_min'])
        ade_avg = torch.Tensor(data['ade_avg'])
        fde_avg = torch.Tensor(data['fde_avg'])
        ade_std = torch.Tensor(data['ade_std'])
        fde_std = torch.Tensor(data['fde_std'])
        test_loss_recon = torch.Tensor(data['test_loss_recon'])
        test_loss_kl = torch.Tensor(data['test_loss_kl'])
        test_loss_coll = torch.Tensor(data['test_loss_coll'])
        loss_coll = torch.Tensor(data['loss_coll'])
        total_coll = torch.Tensor(data['total_coll'])
        test_total_coll = torch.Tensor(data['test_total_coll'])

        self.viz.line(X=iters, Y=total_coll,
                      env=self.name,
                      win=self.win_id['total_coll'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='total_coll',
                                title='total_coll'))
        self.viz.line(X=iters, Y=test_total_coll,
                      env=self.name,
                      win=self.win_id['test_total_coll'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='test_total_coll',
                                title='test_total_coll'))
        self.viz.line(X=iters, Y=test_loss_coll,
                      env=self.name,
                      win=self.win_id['test_loss_coll'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='test_loss_coll',
                                title='test_loss_coll'))
        self.viz.line(X=iters, Y=loss_coll,
                      env=self.name,
                      win=self.win_id['loss_coll'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='loss_coll',
                                title='loss_coll'))
        self.viz.line(X=iters, Y=loss_recon,
                      env=self.name,
                      win=self.win_id['loss_recon'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='-loglikelihood',
                                title='Recon. loss of predicted future traj'))
        self.viz.line(X=iters, Y=loss_recon_prior,
                      env=self.name,
                      win=self.win_id['loss_recon_prior'],
                      update='append',
                      opts=dict(xlabel='iter', ylabel='-loglikelihood',
                                title='Recon. loss - prior'))
        self.viz.line(
            X=iters, Y=loss_kl,
            env=self.name,
            win=self.win_id['loss_kl'],
            update='append',
            opts=dict(xlabel='iter', ylabel='kl divergence',
                      title='KL div. btw posterior and c. prior'),
        )
        self.viz.line(X=iters, Y=test_loss_recon,
                      env=self.name,
                      win=self.win_id['test_loss_recon'],
                      update='append',
                      opts=dict(
                          xlabel='iter', ylabel='-loglikelihood',
                          title='Test Recon. loss of predicted future traj'))
        self.viz.line(
            X=iters, Y=test_loss_kl,
            env=self.name,
            win=self.win_id['test_loss_kl'],
            update='append',
            opts=dict(xlabel='iter', ylabel='kl divergence',
                      title='Test KL div. btw posterior and c. prior'),
        )
        self.viz.line(
            X=iters, Y=ade_min, env=self.name,
            win=self.win_id['ade_min'], update='append',
            opts=dict(xlabel='iter', ylabel='ade', title='ADE min'),
        )
        self.viz.line(
            X=iters, Y=fde_min, env=self.name,
            win=self.win_id['fde_min'], update='append',
            opts=dict(xlabel='iter', ylabel='fde', title='FDE min'),
        )
        self.viz.line(
            X=iters, Y=ade_avg, env=self.name,
            win=self.win_id['ade_avg'], update='append',
            opts=dict(xlabel='iter', ylabel='ade', title='ADE avg'),
        )
        self.viz.line(
            X=iters, Y=fde_avg, env=self.name,
            win=self.win_id['fde_avg'], update='append',
            opts=dict(xlabel='iter', ylabel='fde', title='FDE avg'),
        )
        self.viz.line(
            X=iters, Y=ade_std, env=self.name,
            win=self.win_id['ade_std'], update='append',
            opts=dict(xlabel='iter', ylabel='ade std', title='ADE std'),
        )
        self.viz.line(
            X=iters, Y=fde_std, env=self.name,
            win=self.win_id['fde_std'], update='append',
            opts=dict(xlabel='iter', ylabel='fde std', title='FDE std'),
        )

    def set_mode(self, train=True):
        """Switch the three sub-networks between train and eval mode."""
        if train:
            self.encoderMx.train()
            self.encoderMy.train()
            self.decoderMy.train()
        else:
            self.encoderMx.eval()
            self.encoderMy.eval()
            self.decoderMy.eval()

    ####
    def save_checkpoint(self, iteration):
        """Serialize the three whole modules to iteration-stamped files.

        NOTE(review): lg_cvae_path / sg_unet_path are built but never used —
        the LG-CVAE / SG-UNet are not saved here.
        """
        encoderMx_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_encoderMx.pt' % iteration)
        encoderMy_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_encoderMy.pt' % iteration)
        decoderMy_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_decoderMy.pt' % iteration)
        lg_cvae_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_lg_cvae.pt' % iteration)
        sg_unet_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_sg_unet.pt' % iteration)
        mkdirs(self.ckpt_dir)
        torch.save(self.encoderMx, encoderMx_path)
        torch.save(self.encoderMy, encoderMy_path)
        torch.save(self.decoderMy, decoderMy_path)

    ####
    def load_checkpoint(self):
        """Load the three whole modules saved by save_checkpoint at
        self.ckpt_load_iter, mapping to CPU when not on CUDA.

        NOTE(review): lg_cvae_path / sg_unet_path are built but never loaded.
        """
        encoderMx_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderMx.pt' % self.ckpt_load_iter)
        encoderMy_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderMy.pt' % self.ckpt_load_iter)
        decoderMy_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderMy.pt' % self.ckpt_load_iter)
        lg_cvae_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_lg_cvae.pt' % self.ckpt_load_iter)
        sg_unet_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_sg_unet.pt' % self.ckpt_load_iter)
        if self.device == 'cuda':
            self.encoderMx = torch.load(encoderMx_path)
            self.encoderMy = torch.load(encoderMy_path)
            self.decoderMy = torch.load(decoderMy_path)
        else:
            self.encoderMx = torch.load(encoderMx_path, map_location='cpu')
            self.encoderMy = torch.load(encoderMy_path, map_location='cpu')
            self.decoderMy = torch.load(decoderMy_path, map_location='cpu')
class Solver(object): def __init__(self, args): # Misc use_cuda = args.cuda and torch.cuda.is_available() self.device = 'cuda' if use_cuda else 'cpu' self.name = args.name self.max_iter = int(args.max_iter) self.print_iter = args.print_iter self.global_iter = 0 self.test_count = 0 self.pbar = tqdm(total=self.max_iter) # Data self.dset_dir = args.dset_dir self.dataset = args.dataset self.batch_size = args.batch_size self.data_loader = return_data(args) # Networks & Optimizers self.z_dim = args.z_dim self.gamma = args.gamma self.lr_VAE = args.lr_VAE self.beta1_VAE = args.beta1_VAE self.beta2_VAE = args.beta2_VAE self.lr_D = args.lr_D self.beta1_D = args.beta1_D self.beta2_D = args.beta2_D if args.dataset == 'dsprites': self.VAE = FactorVAE1(self.z_dim).to(self.device) self.nc = 1 elif args.dataset == 'mnist': self.VAE = Custom_FactorVAE2(self.z_dim).to(self.device) self.nc = 3 elif args.dataset == 'load_mnist': self.VAE = Custom_FactorVAE2(self.z_dim).to(self.device) self.nc = 3 elif args.dataset == 'glove/numpy_vector/300d_wiki.npy': self.VAE = Glove_FactorVAE1(self.z_dim).to(self.device) self.nc = 3 else: self.VAE = FactorVAE2(self.z_dim).to(self.device) self.nc = 3 self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE, betas=(self.beta1_VAE, self.beta2_VAE)) self.D = Discriminator(self.z_dim).to(self.device) self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D, betas=(self.beta1_D, self.beta2_D)) self.nets = [self.VAE, self.D] # Visdom self.viz_on = args.viz_on self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld', acc='win_acc') self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc') self.image_gather = DataGather('true', 'recon') if self.viz_on: self.viz_port = args.viz_port self.viz = visdom.Visdom(port=self.viz_port) self.viz_ll_iter = args.viz_ll_iter self.viz_la_iter = args.viz_la_iter self.viz_ra_iter = args.viz_ra_iter self.viz_ta_iter = args.viz_ta_iter if not 
self.viz.win_exists(env=self.name+'/lines', win=self.win_id['D_z']): self.viz_init() # Checkpoint self.ckpt_dir = os.path.join(args.ckpt_dir, args.name) self.ckpt_save_iter = args.ckpt_save_iter mkdirs(self.ckpt_dir) if args.ckpt_load: self.load_checkpoint(args.ckpt_load) # Output(latent traverse GIF) self.output_dir = os.path.join(args.output_dir, args.name) self.output_save = args.output_save mkdirs(self.output_dir) def custom_loss(self, x): #lossは交差エントロピーを採用している, MSEの事例もある #https://tips-memo.com/vae-pytorch#i-7, http://aidiary.hatenablog.com/entry/20180228/1519828344のlossを参考 mean, var = self.VAE._encoder(x) #KL = -0.5 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var)) #オリジナル, mean意味わからんけど, あんまり値が変わらないか>ら #上手くいくんじゃないか #KL = 0.5 * torch.sum(torch.exp(var) + mean**2 - 1. - var) KL = -0.5 * torch.sum(1 + var - mean.pow(2) - var.exp()) # sumを行っているのは各次元ごとに算出しているため #print("KL: " + str(KL)) z = self.VAE._sample_z(mean, var) y = self.VAE._decoder(z) #delta = 1e-8 #reconstruction = torch.mean(torch.sum(x * torch.log(y + delta) + (1 - x) * torch.log(1 - y + delta))) #reconstruction = F.binary_cross_entropy(y, x.view(-1, 784), size_average=False) reconstruction = F.binary_cross_entropy(y, x, size_average=False) #交差エントロピー誤差を利用して, 対数尤度の最大化を行っている, 2つのみ=(1-x), (1-y)で算出可能 #http://aidiary.hatenablog.com/entry/20180228/1519828344(参考記事) #print("reconstruction: " + str(reconstruction)) #lower_bound = [-KL, reconstruction] #両方とも小さくしたい, クロスエントロピーは本来マイナス, KLは小さくしたいからプラスに変換 #returnで恐らくわかりやすくするために, 目的関数から誤差関数への変換をしている #return -sum(lower_bound) return KL + reconstruction def train(self): self.net_mode(train=True) ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device) zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device) out = False while not out: for x_true1, x_true2 in self.data_loader:#ここで読み込んでいる? 
self.global_iter += 1 self.pbar.update(1) if self.dataset == 'mnist': x_true1 = x_true1.view(x_true1.shape[0], -1) x_true1 = x_true1.to(self.device) x_recon, mu, logvar, z = self.VAE(x_true1) x = x_true1.view(x_true1.shape[0], -1) #custom #vae_recon_loss = self.custom_loss(x) / self.batch_size #custom vae_recon_loss = recon_loss(x, x_recon) #復元誤差, 交差エントロピー誤差 vae_kld = kl_divergence(mu, logvar) D_z = self.D(z) vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean() #恐らく, discriminatorのloss vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss #vae_loss = vae_recon_loss + self.gamma*vae_tc_loss self.optim_VAE.zero_grad() vae_loss.backward(retain_graph=True) self.optim_VAE.step() x_true2 = x_true2.to(self.device) #x_true2 = x_true2.view(x_true2.shape[0], -1) z_prime = self.VAE(x_true2, no_dec=True) #trueにすることで潜在空間に写像した状態のデータを獲得? z_pperm = permute_dims(z_prime).detach() D_z_pperm = self.D(z_pperm) D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones)) #GANのdiscriminatorっぽい?偽物と本物 #そのため誤差の部分が0と1になっているはず!zerosとonesの部分 self.optim_D.zero_grad() D_tc_loss.backward() self.optim_D.step() #if self.global_iter%self.print_iter == 0: # self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format( # self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item())) if self.test_count % 547 == 0: #self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format( #self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item())) self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format( self.global_iter, vae_recon_loss.item(), vae_tc_loss.item(), D_tc_loss.item())) self.test_count = 0 if self.global_iter%self.ckpt_save_iter == 0: self.save_checkpoint(self.global_iter) if self.viz_on and (self.global_iter%self.viz_ll_iter == 0): soft_D_z = F.softmax(D_z, 1)[:, :1].detach() soft_D_z_pperm = 
F.softmax(D_z_pperm, 1)[:, :1].detach() D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float() D_acc /= 2*self.batch_size self.line_gather.insert(iter=self.global_iter, soft_D_z=soft_D_z.mean().item(), soft_D_z_pperm=soft_D_z_pperm.mean().item(), recon=vae_recon_loss.item(), #kld=vae_kld.item(), acc=D_acc.item()) if self.viz_on and (self.global_iter%self.viz_la_iter == 0): self.visualize_line() self.line_gather.flush() if self.viz_on and (self.global_iter%self.viz_ra_iter == 0): self.image_gather.insert(true=x_true1.data.cpu(), recon=F.sigmoid(x_recon).data.cpu()) self.visualize_recon() self.image_gather.flush() if self.viz_on and (self.global_iter%self.viz_ta_iter == 0): if self.dataset.lower() == '3dchairs': self.visualize_traverse(limit=2, inter=0.5) else: #self.visualize_traverse(limit=3, inter=2/3) print("ignore") if self.global_iter >= self.max_iter: out = True break self.test_count += 1 self.pbar.write("[Training Finished]") torch.save(self.VAE.state_dict(), "model1/0531_128_2_gamma2.pth") self.pbar.close() def visualize_recon(self): data = self.image_gather.data true_image = data['true'][0] recon_image = data['recon'][0] true_image = make_grid(true_image) recon_image = make_grid(recon_image) sample = torch.stack([true_image, recon_image], dim=0) self.viz.images(sample, env=self.name+'/recon_image', opts=dict(title=str(self.global_iter))) def visualize_line(self): data = self.line_gather.data iters = torch.Tensor(data['iter']) recon = torch.Tensor(data['recon']) kld = torch.Tensor(data['kld']) D_acc = torch.Tensor(data['acc']) soft_D_z = torch.Tensor(data['soft_D_z']) soft_D_z_pperm = torch.Tensor(data['soft_D_z_pperm']) soft_D_zs = torch.stack([soft_D_z, soft_D_z_pperm], -1) self.viz.line(X=iters, Y=soft_D_zs, env=self.name+'/lines', win=self.win_id['D_z'], update='append', opts=dict( xlabel='iteration', ylabel='D(.)', legend=['D(z)', 'D(z_perm)'])) self.viz.line(X=iters, Y=recon, env=self.name+'/lines', win=self.win_id['recon'], 
update='append', opts=dict( xlabel='iteration', ylabel='reconstruction loss',)) self.viz.line(X=iters,Y=D_acc, env=self.name+'/lines', win=self.win_id['acc'], update='append', opts=dict( xlabel='iteration', ylabel='discriminator accuracy',)) ''' self.viz.line(X=iters, Y=kld, env=self.name+'/lines', win=self.win_id['kld'], update='append', opts=dict( xlabel='iteration', ylabel='kl divergence',)) ''' def visualize_traverse(self, limit=3, inter=2/3, loc=-1): self.net_mode(train=False) decoder = self.VAE.decode encoder = self.VAE.encode interpolation = torch.arange(-limit, limit+0.1, inter) random_img = self.data_loader.dataset.__getitem__(0)[1] random_img = random_img.to(self.device).unsqueeze(0) random_img_z = encoder(random_img)[:, :self.z_dim] if self.dataset.lower() == 'dsprites': fixed_idx1 = 87040 # square fixed_idx2 = 332800 # ellipse fixed_idx3 = 578560 # heart fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] fixed_img1 = fixed_img1.to(self.device).unsqueeze(0) fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim] fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] fixed_img2 = fixed_img2.to(self.device).unsqueeze(0) fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim] fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] fixed_img3 = fixed_img3.to(self.device).unsqueeze(0) fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim] Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2, 'fixed_heart':fixed_img_z3, 'random_img':random_img_z} elif self.dataset.lower() == 'celeba': fixed_idx1 = 191281 # 'CelebA/img_align_celeba/191282.jpg' fixed_idx2 = 143307 # 'CelebA/img_align_celeba/143308.jpg' fixed_idx3 = 101535 # 'CelebA/img_align_celeba/101536.jpg' fixed_idx4 = 70059 # 'CelebA/img_align_celeba/070060.jpg' fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] fixed_img1 = fixed_img1.to(self.device).unsqueeze(0) fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim] fixed_img2 = 
self.data_loader.dataset.__getitem__(fixed_idx2)[0] fixed_img2 = fixed_img2.to(self.device).unsqueeze(0) fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim] fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] fixed_img3 = fixed_img3.to(self.device).unsqueeze(0) fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim] fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0] fixed_img4 = fixed_img4.to(self.device).unsqueeze(0) fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim] Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2, 'fixed_3':fixed_img_z3, 'fixed_4':fixed_img_z4, 'random':random_img_z} elif self.dataset.lower() == '3dchairs': fixed_idx1 = 40919 # 3DChairs/images/4682_image_052_p030_t232_r096.png fixed_idx2 = 5172 # 3DChairs/images/14657_image_020_p020_t232_r096.png fixed_idx3 = 22330 # 3DChairs/images/30099_image_052_p030_t232_r096.png fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0] fixed_img1 = fixed_img1.to(self.device).unsqueeze(0) fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim] fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0] fixed_img2 = fixed_img2.to(self.device).unsqueeze(0) fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim] fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0] fixed_img3 = fixed_img3.to(self.device).unsqueeze(0) fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim] Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2, 'fixed_3':fixed_img_z3, 'random':random_img_z} else: fixed_idx = 0 fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)[0] fixed_img = fixed_img.to(self.device).unsqueeze(0) fixed_img_z = encoder(fixed_img)[:, :self.z_dim] random_z = torch.rand(1, self.z_dim, 1, 1, device=self.device) Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z} gifs = [] for key in Z: z_ori = Z[key] samples = [] for row in range(self.z_dim): if loc != -1 and row != loc: continue z = z_ori.clone() for val in interpolation: z[:, row] = val sample = 
F.sigmoid(decoder(z)).data samples.append(sample) gifs.append(sample) samples = torch.cat(samples, dim=0).cpu() title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter) self.viz.images(samples, env=self.name+'/traverse', opts=dict(title=title), nrow=len(interpolation)) if self.output_save: output_dir = os.path.join(self.output_dir, str(self.global_iter)) mkdirs(output_dir) gifs = torch.cat(gifs) gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2) for i, key in enumerate(Z.keys()): for j, val in enumerate(interpolation): save_image(tensor=gifs[i][j].cpu(), filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)), nrow=self.z_dim, pad_value=1) grid2gif(str(os.path.join(output_dir, key+'*.jpg')), str(os.path.join(output_dir, key+'.gif')), delay=10) self.net_mode(train=True) def viz_init(self): zero_init = torch.zeros([1]) self.viz.line(X=zero_init, Y=torch.stack([zero_init, zero_init], -1), env=self.name+'/lines', win=self.win_id['D_z'], opts=dict( xlabel='iteration', ylabel='D(.)', legend=['D(z)', 'D(z_perm)'])) self.viz.line(X=zero_init, Y=zero_init, env=self.name+'/lines', win=self.win_id['recon'], opts=dict( xlabel='iteration', ylabel='reconstruction loss',)) self.viz.line(X=zero_init, Y=zero_init, env=self.name+'/lines', win=self.win_id['acc'], opts=dict( xlabel='iteration', ylabel='discriminator accuracy',)) self.viz.line(X=zero_init, Y=zero_init, env=self.name+'/lines', win=self.win_id['kld'], opts=dict( xlabel='iteration', ylabel='kl divergence',)) def net_mode(self, train): if not isinstance(train, bool): raise ValueError('Only bool type is supported. 
True|False') for net in self.nets: if train: net.train() else: net.eval() def save_checkpoint(self, ckptname='last', verbose=True): model_states = {'D':self.D.state_dict(), 'VAE':self.VAE.state_dict()} optim_states = {'optim_D':self.optim_D.state_dict(), 'optim_VAE':self.optim_VAE.state_dict()} states = {'iter':self.global_iter, 'model_states':model_states, 'optim_states':optim_states} filepath = os.path.join(self.ckpt_dir, str(ckptname)) with open(filepath, 'wb+') as f: torch.save(states, f) if verbose: self.pbar.write("=> saved checkpoint '{}' (iter {})".format(filepath, self.global_iter)) def load_checkpoint(self, ckptname='last', verbose=True): if ckptname == 'last': ckpts = os.listdir(self.ckpt_dir) if not ckpts: if verbose: self.pbar.write("=> no checkpoint found") return ckpts = [int(ckpt) for ckpt in ckpts] ckpts.sort(reverse=True) ckptname = str(ckpts[0]) filepath = os.path.join(self.ckpt_dir, ckptname) if os.path.isfile(filepath): with open(filepath, 'rb') as f: checkpoint = torch.load(f) self.global_iter = checkpoint['iter'] self.VAE.load_state_dict(checkpoint['model_states']['VAE']) self.D.load_state_dict(checkpoint['model_states']['D']) self.optim_VAE.load_state_dict(checkpoint['optim_states']['optim_VAE']) self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D']) self.pbar.update(self.global_iter) if verbose: self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(filepath, self.global_iter)) else: if verbose: self.pbar.write("=> no checkpoint found at '{}'".format(filepath)) def senzai_view(self, z, label): plt.figure(figsize=(10, 10)) plt.scatter(z[:, 0], z[:, 1], marker='.', c=label, cmap=pylab.cm.jet) plt.colorbar() plt.grid() plt.title('oza_FVAE_2dimention') plt.savefig('FVAE0531_128_2_gamma2_senzai.png') def load_model(self): self.VAE.load_state_dict(torch.load("model1/0531_128_2_gamma2.pth", map_location=self.device)) for data, label in self.data_loader: data = data.to(self.device) data = data.view(data.shape[0], -1) label = 
label.detach().numpy() break n = 10 x_recon, mu, logvar, z = self.VAE(data) z = Variable(z, volatile=True).cpu().numpy() data = Variable(data, volatile=True).cpu().numpy() x_recon = Variable(x_recon, volatile=True).cpu().numpy() #以下, ラベルごとの分散算出 ''' sum = 0 for i in range(10): tmp = np.where(label == i) print(np.var(z[tmp])) sum += np.var(z[tmp]) sum /= 10 print("Ave var: " + str(sum)) quit() #ここまで, 通常は消すこと plt.figure(figsize=(10, 10)) plt.scatter(z[:, 0], z[:, 1], marker='.', c=label, cmap=pylab.cm.jet) plt.colorbar() plt.grid() plt.savefig('FVAE0528_128_2_senzai.png') ''' if self.z_dim == 2: self.senzai_view(z, label) plt.figure(figsize=(12, 6)) for i in range(n): ax = plt.subplot(3, n, i+1) if i == 1: plt.title('Original MNIST') plt.imshow(data[i].reshape(28, 28)) plt.gray() ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) ax = plt.subplot(3, n, i+1+n) if i == 1: plt.title('FVAE_Reconstruction MNIST(20dim)') plt.imshow(x_recon[i].reshape(28, 28)) plt.gray() ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) plt.savefig("FVAE0531_128_20_gamma2_recon.png") plt.show() plt.close()
class Solver(object): #### def __init__(self, args): self.args = args self.name = '%s_map_pred_len_%s_zS_%s_dr_mlp_%s_dr_rnn_%s_dr_map_%s_enc_h_dim_%s_dec_h_dim_%s_mlp_dim_%s_emb_dim_%s_lr_%s_klw_%s_map_%s' % \ (args.dataset_name, args.pred_len, args.zS_dim, args.dropout_mlp, args.dropout_rnn, args.dropout_map, args.encoder_h_dim, args.decoder_h_dim, args.mlp_dim, args.emb_dim, args.lr_VAE, args.kl_weight, args.map_size) # to be appended by run_id # self.use_cuda = args.cuda and torch.cuda.is_available() self.device = args.device self.temp = 0.66 self.eps = 1e-9 self.kl_weight = args.kl_weight self.max_iter = int(args.max_iter) # do it every specified iters self.print_iter = args.print_iter self.ckpt_save_iter = args.ckpt_save_iter self.output_save_iter = args.output_save_iter # data info self.dataset_dir = args.dataset_dir self.dataset_name = args.dataset_name # self.N = self.latent_values.shape[0] # self.eval_metrics_iter = args.eval_metrics_iter # networks and optimizers self.batch_size = args.batch_size self.zS_dim = args.zS_dim self.lr_VAE = args.lr_VAE self.beta1_VAE = args.beta1_VAE self.beta2_VAE = args.beta2_VAE print(args.desc) # visdom setup self.viz_on = args.viz_on if self.viz_on: self.win_id = dict(recon='win_recon', loss_kl='win_loss_kl', loss_recon='win_loss_recon', total_loss='win_total_loss', ade_min='win_ade_min', fde_min='win_fde_min', ade_avg='win_ade_avg', fde_avg='win_fde_avg', ade_std='win_ade_std', fde_std='win_fde_std', test_loss_recon='win_test_loss_recon', test_loss_kl='win_test_loss_kl', test_total_loss='win_test_total_loss') self.line_gather = DataGather('iter', 'loss_recon', 'loss_kl', 'total_loss', 'ade_min', 'fde_min', 'ade_avg', 'fde_avg', 'ade_std', 'fde_std', 'test_loss_recon', 'test_loss_kl', 'test_total_loss') import visdom self.viz_port = args.viz_port # port number, eg, 8097 self.viz = visdom.Visdom(port=self.viz_port) self.viz_ll_iter = args.viz_ll_iter self.viz_la_iter = args.viz_la_iter self.viz_init() # create dirs: 
"records", "ckpts", "outputs" (if not exist) mkdirs("records") mkdirs("ckpts") mkdirs("outputs") # set run id if args.run_id < 0: # create a new id k = 0 rfname = os.path.join("records", self.name + '_run_0.txt') while os.path.exists(rfname): k += 1 rfname = os.path.join("records", self.name + '_run_%d.txt' % k) self.run_id = k else: # user-provided id self.run_id = args.run_id # finalize name self.name = self.name + '_run_' + str(self.run_id) # records (text file to store console outputs) self.record_file = 'records/%s.txt' % self.name # checkpoints self.ckpt_dir = os.path.join("ckpts", self.name) # outputs self.output_dir_recon = os.path.join("outputs", self.name + '_recon') # dir for reconstructed images self.output_dir_synth = os.path.join("outputs", self.name + '_synth') # dir for synthesized images self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl') #### create a new model or load a previously saved model self.ckpt_load_iter = args.ckpt_load_iter self.obs_len = args.obs_len self.pred_len = args.pred_len self.num_layers = args.num_layers self.decoder_h_dim = args.decoder_h_dim if self.ckpt_load_iter == 0 or args.dataset_name == 'all': # create a new model self.encoderMx = Encoder(args.zS_dim, enc_h_dim=args.encoder_h_dim, mlp_dim=args.mlp_dim, emb_dim=args.emb_dim, map_size=args.map_size, batch_norm=args.batch_norm, num_layers=args.num_layers, dropout_mlp=args.dropout_mlp, dropout_rnn=args.dropout_rnn, dropout_map=args.dropout_map).to( self.device) self.encoderMy = EncoderY(args.zS_dim, enc_h_dim=args.encoder_h_dim, mlp_dim=args.mlp_dim, emb_dim=args.emb_dim, map_size=args.map_size, num_layers=args.num_layers, dropout_rnn=args.dropout_rnn, dropout_map=args.dropout_map, device=self.device).to(self.device) self.decoderMy = Decoder(args.pred_len, dec_h_dim=self.decoder_h_dim, enc_h_dim=args.encoder_h_dim, mlp_dim=args.mlp_dim, z_dim=args.zS_dim, num_layers=args.num_layers, device=args.device, dropout_rnn=args.dropout_rnn).to( self.device) else: 
# load a previously saved model print('Loading saved models (iter: %d)...' % self.ckpt_load_iter) self.load_checkpoint() print('...done') # get VAE parameters vae_params = \ list(self.encoderMx.parameters()) + \ list(self.encoderMy.parameters()) + \ list(self.decoderMy.parameters()) # create optimizers self.optim_vae = optim.Adam(vae_params, lr=self.lr_VAE, betas=[self.beta1_VAE, self.beta2_VAE]) ######## map # self.map = imageio.imread('D:\crowd\ewap_dataset\seq_' + self.dataset_name + '/map.png') # h = np.loadtxt('D:\crowd\ewap_dataset\seq_' + self.dataset_name + '\H.txt') # self.inv_h_t = np.linalg.pinv(np.transpose(h)) self.map_size = args.map_size ###################################### # prepare dataloader (iterable) print('Start loading data...') train_path = os.path.join(self.dataset_dir, self.dataset_name, 'train') val_path = os.path.join(self.dataset_dir, self.dataset_name, 'test') # long_dtype, float_dtype = get_dtypes(args) print("Initializing train dataset") if self.dataset_name == 'eth': self.args.pixel_distance = 5 # for hotel else: self.args.pixel_distance = 3 # for eth _, self.train_loader = data_loader(self.args, train_path) print("Initializing val dataset") if self.dataset_name == 'eth': self.args.pixel_distance = 3 else: self.args.pixel_distance = 5 _, self.val_loader = data_loader(self.args, val_path) # self.val_loader = self.train_loader print('There are {} iterations per epoch'.format( len(self.train_loader.dataset) / args.batch_size)) print('...done') #### def train(self): self.set_mode(train=True) data_loader = self.train_loader self.N = len(data_loader.dataset) # iterators from dataloader iterator = iter(data_loader) iter_per_epoch = len(iterator) start_iter = self.ckpt_load_iter + 1 epoch = int(start_iter / iter_per_epoch) for iteration in range(start_iter, self.max_iter + 1): # reset data iterators for each epoch if iteration % iter_per_epoch == 0: print('==== epoch %d done ====' % epoch) epoch += 1 iterator = iter(data_loader) # 
============================================ # TRAIN THE VAE (ENC & DEC) # ============================================ # sample a mini-batch (obs_traj, fut_traj, seq_start_end, obs_frames, fut_frames, past_obst, fut_obst) = next(iterator) batch = fut_traj.size(1) (last_past_map_feat, encX_h_feat, logitX) = self.encoderMx(past_obst, seq_start_end, train=True) (fut_map_emb, encY_h_feat, logitY) \ = self.encoderMy(past_obst[-1], fut_obst, seq_start_end, encX_h_feat, train=True) p_dist = discrete(logits=logitX) q_dist = discrete(logits=logitY) relaxed_q_dist = concrete(logits=logitY, temperature=self.temp) fut_map_mean = self.decoderMy(last_past_map_feat, encX_h_feat, relaxed_q_dist.rsample(), fut_map_emb) fut_map_mean = fut_map_mean.view(fut_obst.shape[0], fut_obst.shape[1], -1, fut_map_mean.shape[2], fut_map_mean.shape[3]) loglikelihood = (torch.log(fut_map_mean + self.eps) * fut_obst + torch.log(1 - fut_map_mean + self.eps) * (1 - fut_obst)).sum().div(batch) loss_kl = kl_divergence(q_dist, p_dist).sum().div(batch) loss_kl = torch.clamp(loss_kl, min=0.07) # print('log_likelihood:', loglikelihood.item(), ' kl:', loss_kl.item()) elbo = loglikelihood - self.kl_weight * loss_kl vae_loss = -elbo self.optim_vae.zero_grad() vae_loss.backward() self.optim_vae.step() # save model parameters if iteration % self.ckpt_save_iter == 0: self.save_checkpoint(iteration) # (visdom) insert current line stats if self.viz_on and (iteration % self.viz_ll_iter == 0): test_loss_recon, test_loss_kl, test_vae_loss = self.test() self.line_gather.insert( iter=iteration, loss_recon=-loglikelihood.item(), loss_kl=loss_kl.item(), total_loss=vae_loss.item(), test_loss_recon=-test_loss_recon.item(), test_loss_kl=test_loss_kl.item(), test_total_loss=test_vae_loss.item(), ) prn_str = ('[iter_%d (epoch_%d)] vae_loss: %.3f ' + \ '(recon: %.3f, kl: %.3f)\n' ) % \ (iteration, epoch, vae_loss.item(), -loglikelihood.item(), loss_kl.item() ) print(prn_str) # (visdom) visualize line stats (then flush out) if 
self.viz_on and (iteration % self.viz_la_iter == 0): self.visualize_line() self.line_gather.flush() if (iteration % self.output_save_iter == 0): self.recon(self.val_loader) def test(self): self.set_mode(train=False) all_loglikelihood = 0 all_loss_kl = 0 all_vae_loss = 0 b = 0 with torch.no_grad(): for abatch in self.val_loader: b += 1 # sample a mini-batch (obs_traj, fut_traj, seq_start_end, obs_frames, fut_frames, past_obst, fut_obst) = abatch batch = fut_traj.size(1) (last_past_map_feat, encX_h_feat, logitX) = self.encoderMx(past_obst, seq_start_end) (_, _, logitY) \ = self.encoderMy(past_obst[-1], fut_obst, seq_start_end, encX_h_feat) p_dist = discrete(logits=logitX) q_dist = discrete(logits=logitY) relaxed_p_dist = concrete(logits=logitX, temperature=self.temp) fut_map_mean = self.decoderMy(last_past_map_feat, encX_h_feat, relaxed_p_dist.rsample()) fut_map_mean = fut_map_mean.view(fut_obst.shape[0], fut_obst.shape[1], -1, fut_map_mean.shape[2], fut_map_mean.shape[3]) loglikelihood = ( torch.log(fut_map_mean + self.eps) * fut_obst + torch.log(1 - fut_map_mean + self.eps) * (1 - fut_obst)).sum().div(batch) loss_kl = kl_divergence(q_dist, p_dist).sum().div(batch) loss_kl = torch.clamp(loss_kl, min=0.07) elbo = loglikelihood - self.kl_weight * loss_kl vae_loss = -elbo all_loglikelihood += loglikelihood all_loss_kl += loss_kl all_vae_loss += vae_loss self.set_mode(train=True) return all_loglikelihood.div(b), all_loss_kl.div(b), all_vae_loss.div( b) def recon(self, data_loader): self.set_mode(train=False) with torch.no_grad(): fixed_idxs = range(5) from data.obstacles import seq_collate data = [] for i, idx in enumerate(fixed_idxs): data.append(data_loader.dataset.__getitem__(idx)) (obs_traj, fut_traj, seq_start_end, obs_frames, fut_frames, past_obst, fut_obst) = seq_collate(data) (last_past_map_feat, encX_h_feat, logitX) = self.encoderMx(past_obst, seq_start_end) (fut_map_emb, _, logitY) = self.encoderMy(past_obst[-1], fut_obst, seq_start_end, encX_h_feat) 
relaxed_p_dist = concrete(logits=logitX, temperature=self.temp) relaxed_q_dist = concrete(logits=logitY, temperature=self.temp) prior_fut_map_mean = self.decoderMy(last_past_map_feat, encX_h_feat, relaxed_p_dist.rsample()) posterior_fut_map_mean = self.decoderMy( last_past_map_feat, encX_h_feat, relaxed_q_dist.rsample(), fut_map_emb, ) prior_fut_map_mean = prior_fut_map_mean.view( fut_obst.shape[0], fut_obst.shape[1], -1, prior_fut_map_mean.shape[2], prior_fut_map_mean.shape[3]) posterior_fut_map_mean = posterior_fut_map_mean.view( fut_obst.shape[0], fut_obst.shape[1], -1, posterior_fut_map_mean.shape[2], posterior_fut_map_mean.shape[3]) out_dir = os.path.join('./output', self.name, str(self.ckpt_load_iter)) mkdirs(out_dir) for i in range(fut_obst.shape[1]): save_image(prior_fut_map_mean[:, i], str( os.path.join( out_dir, 'prior_recon_img' + str(i) + '.png')), nrow=self.pred_len, pad_value=1) save_image(posterior_fut_map_mean[:, i], str( os.path.join( out_dir, 'posterior_recon_img' + str(i) + '.png')), nrow=self.pred_len, pad_value=1) save_image(fut_obst[:, i], str( os.path.join(out_dir, 'gt_img' + str(i) + '.png')), nrow=self.pred_len, pad_value=1) self.set_mode(train=True) #### def viz_init(self): self.viz.close(env=self.name + '/lines', win=self.win_id['loss_recon']) self.viz.close(env=self.name + '/lines', win=self.win_id['loss_kl']) self.viz.close(env=self.name + '/lines', win=self.win_id['total_loss']) self.viz.close(env=self.name + '/lines', win=self.win_id['test_loss_recon']) self.viz.close(env=self.name + '/lines', win=self.win_id['test_loss_kl']) self.viz.close(env=self.name + '/lines', win=self.win_id['test_total_loss']) #### def visualize_line(self): # prepare data to plot data = self.line_gather.data iters = torch.Tensor(data['iter']) loss_recon = torch.Tensor(data['loss_recon']) loss_kl = torch.Tensor(data['loss_kl']) total_loss = torch.Tensor(data['total_loss']) test_loss_recon = torch.Tensor(data['test_loss_recon']) test_loss_kl = 
torch.Tensor(data['test_loss_kl']) test_total_loss = torch.Tensor(data['test_total_loss']) self.viz.line(X=iters, Y=loss_recon, env=self.name + '/lines', win=self.win_id['loss_recon'], update='append', opts=dict(xlabel='iter', ylabel='-loglikelihood', title='Recon. loss of predicted future traj')) self.viz.line( X=iters, Y=loss_kl, env=self.name + '/lines', win=self.win_id['loss_kl'], update='append', opts=dict(xlabel='iter', ylabel='kl divergence', title='KL div. btw posterior and c. prior'), ) self.viz.line( X=iters, Y=total_loss, env=self.name + '/lines', win=self.win_id['total_loss'], update='append', opts=dict(xlabel='iter', ylabel='vae loss', title='VAE loss'), ) self.viz.line(X=iters, Y=test_loss_recon, env=self.name + '/lines', win=self.win_id['test_loss_recon'], update='append', opts=dict( xlabel='iter', ylabel='-loglikelihood', title='Test Recon. loss of predicted future traj')) self.viz.line( X=iters, Y=test_loss_kl, env=self.name + '/lines', win=self.win_id['test_loss_kl'], update='append', opts=dict(xlabel='iter', ylabel='kl divergence', title='Test KL div. btw posterior and c. 
prior'), ) self.viz.line( X=iters, Y=test_total_loss, env=self.name + '/lines', win=self.win_id['test_total_loss'], update='append', opts=dict(xlabel='iter', ylabel='vae loss', title='Test VAE loss'), ) def set_mode(self, train=True): if train: self.encoderMx.train() self.encoderMy.train() self.decoderMy.train() else: self.encoderMx.eval() self.encoderMy.eval() self.decoderMy.eval() #### def save_checkpoint(self, iteration): encoderMx_path = os.path.join(self.ckpt_dir, 'iter_%s_encoderMx.pt' % iteration) encoderMy_path = os.path.join(self.ckpt_dir, 'iter_%s_encoderMy.pt' % iteration) decoderMy_path = os.path.join(self.ckpt_dir, 'iter_%s_decoderMy.pt' % iteration) mkdirs(self.ckpt_dir) torch.save(self.encoderMx, encoderMx_path) torch.save(self.encoderMy, encoderMy_path) torch.save(self.decoderMy, decoderMy_path) #### def load_checkpoint(self): encoderMx_path = os.path.join( self.ckpt_dir, 'iter_%s_encoderMx.pt' % self.ckpt_load_iter) encoderMy_path = os.path.join( self.ckpt_dir, 'iter_%s_encoderMy.pt' % self.ckpt_load_iter) decoderMy_path = os.path.join( self.ckpt_dir, 'iter_%s_decoderMy.pt' % self.ckpt_load_iter) if self.device == 'cuda': self.encoderMx = torch.load(encoderMx_path) self.encoderMy = torch.load(encoderMy_path) self.decoderMy = torch.load(decoderMy_path) else: self.encoderMx = torch.load(encoderMx_path, map_location='cpu') self.encoderMy = torch.load(encoderMy_path, map_location='cpu') self.decoderMy = torch.load(decoderMy_path, map_location='cpu')