def viz_traverse(self, limit=3, inter=2/3, loc=-1):
    """Traverse each latent dimension over [-limit, limit] in steps of `inter`
    and visualize/save the decoded images. `loc != -1` restricts the traversal
    to a single dimension."""
    self.net_mode(train=False)
    import random

    decoder = self.net.decoder
    encoder = self.net.encoder
    interpolation = torch.arange(-limit, limit + 0.1, inter)

    n_dsets = len(self.data_loader.dataset)
    rand_idx = random.randint(1, n_dsets - 1)

    # Note: Variable(..., volatile=True) is the legacy (pre-0.4) PyTorch way
    # of disabling gradient tracking during inference.
    random_img = self.data_loader.dataset.__getitem__(rand_idx)
    random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
    random_img_z = encoder(random_img)[:, :self.z_dim]

    random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)

    if self.dataset == 'dsprites':
        fixed_idx1 = 87040   # square
        fixed_idx2 = 332800  # ellipse
        fixed_idx3 = 578560  # heart

        fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
        fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

        fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
        fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

        fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
        fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

        Z = {'fixed_square': fixed_img_z1, 'fixed_ellipse': fixed_img_z2,
             'fixed_heart': fixed_img_z3, 'random_img': random_img_z}
    else:
        fixed_idx = 0
        fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
        fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

        Z = {'fixed_img': fixed_img_z, 'random_img': random_img_z, 'random_z': random_z}

    gifs = []
    for key in Z.keys():
        z_ori = Z[key]
        samples = []
        for row in range(self.z_dim):
            if loc != -1 and row != loc:
                continue
            # Sweep one latent dimension while holding the others fixed.
            z = z_ori.clone()
            for val in interpolation:
                z[:, row] = val
                sample = F.sigmoid(decoder(z)).data
                samples.append(sample)
                gifs.append(sample)
        samples = torch.cat(samples, dim=0).cpu()
        title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
        if self.viz_on:
            self.viz.images(samples, env=self.viz_name + '_traverse',
                            opts=dict(title=title), nrow=len(interpolation))

    if self.save_output:
        output_dir = os.path.join(self.output_dir, str(self.global_iter))
        os.makedirs(output_dir, exist_ok=True)
        gifs = torch.cat(gifs)
        gifs = gifs.view(len(Z), self.z_dim, len(interpolation),
                         self.nc, 64, 64).transpose(1, 2)
        for i, key in enumerate(Z.keys()):
            for j, val in enumerate(interpolation):
                save_image(tensor=gifs[i][j].cpu(),
                           filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                           nrow=self.z_dim, pad_value=1)
            grid2gif(os.path.join(output_dir, key + '*.jpg'),
                     os.path.join(output_dir, key + '.gif'), delay=10)

    self.net_mode(train=True)
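
# The training loop below swaps individual latent factors between encoded
# images. A minimal, self-contained sketch of that swap (illustrative only:
# the chunk sizes here are made up; the real boundaries come from the
# self.z_*_start_dim config attributes):
#
#     import torch
#     dims = [20, 20, 20, 20, 20]   # content, size, font color, back color, style
#     z_a = torch.rand(1, sum(dims))
#     z_b = torch.rand(1, sum(dims))
#     chunks = list(torch.split(z_b, dims, dim=1))
#     chunks[0] = torch.split(z_a, dims, dim=1)[0]  # give z_b the content of z_a
#     z_swap = torch.cat(chunks, dim=1)             # decode this to transfer content
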
def train(self):
    # self.net_mode(train=True)
    out = False

    # Start training from scratch or resume training.
    self.global_iter = 0
    if self.resume_iters:
        self.global_iter = self.resume_iters
        self.restore_model(self.resume_iters)

    pbar = tqdm(total=self.max_iter)
    pbar.update(self.global_iter)

    ''' Factor reference: 1 content, 2 size, 3 font color, 4 background color, 5 style '''
    def split_z(z):
        # Split a latent code into its five factor chunks; the boundaries
        # come from the config-provided *_start_dim attributes.
        return (z[:, :self.z_size_start_dim],
                z[:, self.z_size_start_dim:self.z_font_color_start_dim],
                z[:, self.z_font_color_start_dim:self.z_back_color_start_dim],
                z[:, self.z_back_color_start_dim:self.z_style_start_dim],
                z[:, self.z_style_start_dim:])

    def decode_latent(z_combined):
        # Map a (recombined) latent code through the FC decoder, reshape to
        # the 256x8x8 feature map, and decode back to image space.
        mid = self.Autoencoder.fc_decoder(z_combined)
        mid = mid.view(z_combined.shape[0], 256, 8, 8)
        return self.Autoencoder.decoder(mid)

    while not out:
        for sup_package in self.data_loader:
            A_img = sup_package['A']
            B_img = sup_package['B']
            C_img = sup_package['C']
            D_img = sup_package['D']
            E_img = sup_package['E']
            F_img = sup_package['F']

            self.global_iter += 1
            pbar.update(1)

            A_img = Variable(cuda(A_img, self.use_cuda))
            B_img = Variable(cuda(B_img, self.use_cuda))
            C_img = Variable(cuda(C_img, self.use_cuda))
            D_img = Variable(cuda(D_img, self.use_cuda))
            E_img = Variable(cuda(E_img, self.use_cuda))
            F_img = Variable(cuda(F_img, self.use_cuda))

            ## 1. Encode A..F separately.
            A_recon, A_z = self.Autoencoder(A_img)
            B_recon, B_z = self.Autoencoder(B_img)
            C_recon, C_z = self.Autoencoder(C_img)
            D_recon, D_z = self.Autoencoder(D_img)
            E_recon, E_z = self.Autoencoder(E_img)
            F_recon, F_z = self.Autoencoder(F_img)

            A_z_1, A_z_2, A_z_3, A_z_4, A_z_5 = split_z(A_z)
            B_z_1, B_z_2, B_z_3, B_z_4, B_z_5 = split_z(B_z)
            C_z_1, C_z_2, C_z_3, C_z_4, C_z_5 = split_z(C_z)
            D_z_1, D_z_2, D_z_3, D_z_4, D_z_5 = split_z(D_z)
            E_z_1, E_z_2, E_z_3, E_z_4, E_z_5 = split_z(E_z)
            F_z_1, F_z_2, F_z_3, F_z_4, F_z_5 = split_z(F_z)
            ## 2. Combine with strong supervision.
            # Naming: A1Co_2C = factor 1 from A plus the other factors from C,
            # decoded; the suffix _2C means its reconstruction target is C.

            # C and A share content (factor 1).
            A1Co_combine_2C = torch.cat((A_z_1, C_z_2, C_z_3, C_z_4, C_z_5), dim=1)
            A1Co_2C = decode_latent(A1Co_combine_2C)
            AoC1_combine_2A = torch.cat((C_z_1, A_z_2, A_z_3, A_z_4, A_z_5), dim=1)
            AoC1_2A = decode_latent(AoC1_combine_2A)

            # C and B share size (factor 2).
            B2Co_combine_2C = torch.cat((C_z_1, B_z_2, C_z_3, C_z_4, C_z_5), dim=1)
            B2Co_2C = decode_latent(B2Co_combine_2C)
            BoC2_combine_2B = torch.cat((B_z_1, C_z_2, B_z_3, B_z_4, B_z_5), dim=1)
            BoC2_2B = decode_latent(BoC2_combine_2B)

            # C and D share font color (factor 3).
            D3Co_combine_2C = torch.cat((C_z_1, C_z_2, D_z_3, C_z_4, C_z_5), dim=1)
            D3Co_2C = decode_latent(D3Co_combine_2C)
            DoC3_combine_2D = torch.cat((D_z_1, D_z_2, C_z_3, D_z_4, D_z_5), dim=1)
            DoC3_2D = decode_latent(DoC3_combine_2D)

            # C and E share background color (factor 4).
            E4Co_combine_2C = torch.cat((C_z_1, C_z_2, C_z_3, E_z_4, C_z_5), dim=1)
            E4Co_2C = decode_latent(E4Co_combine_2C)
            EoC4_combine_2E = torch.cat((E_z_1, E_z_2, E_z_3, C_z_4, E_z_5), dim=1)
            EoC4_2E = decode_latent(EoC4_combine_2E)

            # C and F share style (factor 5).
            F5Co_combine_2C = torch.cat((C_z_1, C_z_2, C_z_3, C_z_4, F_z_5), dim=1)
            F5Co_2C = decode_latent(F5Co_combine_2C)
            FoC5_combine_2F = torch.cat((F_z_1, F_z_2, F_z_3, F_z_4, C_z_5), dim=1)
            FoC5_2F = decode_latent(FoC5_combine_2F)

            # Every factor taken from its matching partner; target is C.
            A1B2D3E4F5_combine_2C = torch.cat((A_z_1, B_z_2, D_z_3, E_z_4, F_z_5), dim=1)
            A1B2D3E4F5_2C = decode_latent(A1B2D3E4F5_combine_2C)

            # Unsupervised combination: no ground-truth image exists for this mix.
            A2B3D4E5F1_combine_2N = torch.cat((F_z_1, A_z_2, B_z_3, D_z_4, E_z_5), dim=1)
            A2B3D4E5F1_2N = decode_latent(A2B3D4E5F1_combine_2N)
            ''' Optimize the autoencoder. '''
            # 1. Reconstruction loss (L1).
            A_recon_loss = torch.mean(torch.abs(A_img - A_recon))
            B_recon_loss = torch.mean(torch.abs(B_img - B_recon))
            C_recon_loss = torch.mean(torch.abs(C_img - C_recon))
            D_recon_loss = torch.mean(torch.abs(D_img - D_recon))
            E_recon_loss = torch.mean(torch.abs(E_img - E_recon))
            F_recon_loss = torch.mean(torch.abs(F_img - F_recon))
            recon_loss = (A_recon_loss + B_recon_loss + C_recon_loss
                          + D_recon_loss + E_recon_loss + F_recon_loss)

            # 2. Supervised combination loss: each swap must reproduce its target image.
            A1Co_2C_loss = torch.mean(torch.abs(C_img - A1Co_2C))
            AoC1_2A_loss = torch.mean(torch.abs(A_img - AoC1_2A))
            B2Co_2C_loss = torch.mean(torch.abs(C_img - B2Co_2C))
            BoC2_2B_loss = torch.mean(torch.abs(B_img - BoC2_2B))
            D3Co_2C_loss = torch.mean(torch.abs(C_img - D3Co_2C))
            DoC3_2D_loss = torch.mean(torch.abs(D_img - DoC3_2D))
            E4Co_2C_loss = torch.mean(torch.abs(C_img - E4Co_2C))
            EoC4_2E_loss = torch.mean(torch.abs(E_img - EoC4_2E))
            F5Co_2C_loss = torch.mean(torch.abs(C_img - F5Co_2C))
            FoC5_2F_loss = torch.mean(torch.abs(F_img - FoC5_2F))
            A1B2D3E4F5_2C_loss = torch.mean(torch.abs(C_img - A1B2D3E4F5_2C))
            combine_sup_loss = (A1Co_2C_loss + AoC1_2A_loss + B2Co_2C_loss + BoC2_2B_loss
                                + D3Co_2C_loss + DoC3_2D_loss + E4Co_2C_loss + EoC4_2E_loss
                                + F5Co_2C_loss + FoC5_2F_loss + A1B2D3E4F5_2C_loss)

            # 3. Unsupervised combination loss: re-encode the unsupervised mix and
            # match each latent chunk against the factor it was assembled from.
            _, A2B3D4E5F1_z = self.Autoencoder(A2B3D4E5F1_2N)
            re_z_1, re_z_2, re_z_3, re_z_4, re_z_5 = split_z(A2B3D4E5F1_z)
            combine_unsup_loss = (torch.mean(torch.abs(F_z_1 - re_z_1))
                                  + torch.mean(torch.abs(A_z_2 - re_z_2))
                                  + torch.mean(torch.abs(B_z_3 - re_z_3))
                                  + torch.mean(torch.abs(D_z_4 - re_z_4))
                                  + torch.mean(torch.abs(E_z_5 - re_z_5)))

            # Total loss.
            vae_unsup_loss = (recon_loss
                              + self.lambda_combine * combine_sup_loss
                              + self.lambda_unsup * combine_unsup_loss)

            self.auto_optim.zero_grad()
            vae_unsup_loss.backward()
            self.auto_optim.step()

            # Append to the log file.
            with open(self.log_dir + '/log.txt', 'a') as f:
                f.writelines(['\n', '[{}] recon_loss:{:.3f} combine_sup_loss:{:.3f} combine_unsup_loss:{:.3f}'.format(
                    self.global_iter, recon_loss.data, combine_sup_loss.data, combine_unsup_loss.data)])
            print('\n', self.global_iter,
                  'recon:', recon_loss.data,
                  'combine sup:', combine_sup_loss.data,
                  'combine unsup:', combine_unsup_loss.data)

            if self.viz_on and self.global_iter % self.gather_step == 0:
                self.gather.insert(iter=self.global_iter,
                                   recon_loss=recon_loss.data,
                                   combine_sup_loss=combine_sup_loss.data,
                                   combine_unsup_loss=combine_unsup_loss.data)

            if self.global_iter % self.display_step == 0:
                pbar.write('[{}] recon_loss:{:.3f} combine_sup_loss:{:.3f} combine_unsup_loss:{:.3f}'.format(
                    self.global_iter, recon_loss.data, combine_sup_loss.data, combine_unsup_loss.data))

                if self.viz_on:
                    self.gather.insert(images=A_img.data)
                    self.gather.insert(images=B_img.data)
                    self.gather.insert(images=C_img.data)
                    self.gather.insert(images=D_img.data)
                    self.gather.insert(images=E_img.data)
                    self.gather.insert(images=F_img.data)
                    self.gather.insert(images=F.sigmoid(A_recon).data)
                    self.viz_reconstruction()
                    self.viz_lines()

                    ''' Show the combined reconstructions. '''
                    self.gather.insert(combine_supimages=F.sigmoid(AoC1_2A).data)
                    self.gather.insert(combine_supimages=F.sigmoid(BoC2_2B).data)
                    self.gather.insert(combine_supimages=F.sigmoid(D3Co_2C).data)
                    self.gather.insert(combine_supimages=F.sigmoid(DoC3_2D).data)
                    self.gather.insert(combine_supimages=F.sigmoid(EoC4_2E).data)
                    self.gather.insert(combine_supimages=F.sigmoid(FoC5_2F).data)
                    self.viz_combine_recon()

                    self.gather.insert(combine_unsupimages=F.sigmoid(A1B2D3E4F5_2C).data)
                    self.gather.insert(combine_unsupimages=F.sigmoid(A2B3D4E5F1_2N).data)
                    self.viz_combine_unsuprecon()
                    # self.viz_combine(x)
                    self.gather.flush()

            # Save model checkpoints.
            if self.global_iter % self.save_step == 0:
                Auto_path = os.path.join(self.model_save_dir, self.viz_name,
                                         '{}-Auto.ckpt'.format(self.global_iter))
                torch.save(self.Autoencoder.state_dict(), Auto_path)
                print('Saved model checkpoints into {}/{}...'.format(self.model_save_dir, self.viz_name))

            if self.global_iter >= self.max_iter:
                out = True
                break

    pbar.write("[Training Finished]")
    pbar.close()
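
# Hypothetical usage sketch (the `Solver`/`config` names are assumptions; the
# class containing train()/viz_traverse() is constructed elsewhere with a
# parsed config):
#
#     solver = Solver(config)                    # wires data_loader, Autoencoder, optimizer
#     solver.train()                             # recon + supervised/unsupervised combine losses
#     solver.viz_traverse(limit=3, inter=2/3)    # per-dimension latent traversals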