def save_traverseA(self, iters, z_A, z_B, z_S, loc=-1):
    """Save latent-traversal visualizations for modality A.

    For a fixed set of hand-picked training samples, sweeps each private
    continuous latent dimension of modality A across an interpolation grid
    (keeping the shared code fixed), then decodes each one-hot setting of
    the shared discrete latent, and writes one image grid per interpolation
    step plus an animated gif of all steps.

    Args:
        iters: current training iteration; used to name the output directory.
        z_A, z_B, z_S: collections of latent statistics; only their min/max
            are printed here for reference.
        loc: if >= 0, traverse only that single dimension of zA; -1 means
            traverse all dimensions.
    """
    self.set_mode(train=False)

    encoderA = self.encoderA
    encoderB = self.encoderB
    decoderA = self.decoderA
    decoderB = self.decoderB

    # NOTE(review): all three grids use self.zS_dim as the *number of steps*
    # of the sweep; only interpolationA is actually consumed below.
    interpolationA = torch.tensor(np.linspace(-3, 3, self.zS_dim))
    interpolationB = torch.tensor(np.linspace(-3, 3, self.zS_dim))
    interpolationS = torch.tensor(np.linspace(-3, 3, self.zS_dim))

    print('------------ traverse interpolation ------------')
    print('interpolationA: ', np.min(np.array(z_A)), np.max(np.array(z_A)))
    print('interpolationB: ', np.min(np.array(z_B)), np.max(np.array(z_B)))
    print('interpolationS: ', np.min(np.array(z_S)), np.max(np.array(z_S)))

    if self.record_file:
        ####
        # Hand-picked dataset indices used as the fixed visualization batch.
        fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

        fixed_XA = [0] * len(fixed_idxs)
        fixed_XB = [0] * len(fixed_idxs)

        for i, idx in enumerate(fixed_idxs):
            fixed_XA[i], fixed_XB[i] = \
                self.data_loader.dataset.__getitem__(idx)[0:2]
            if self.use_cuda:
                fixed_XA[i] = fixed_XA[i].cuda()
                fixed_XB[i] = fixed_XB[i].cuda()
            # Add a leading batch dimension so the samples can be concatenated.
            fixed_XA[i] = fixed_XA[i].unsqueeze(0)
            fixed_XB[i] = fixed_XB[i].unsqueeze(0)

        fixed_XA = torch.cat(fixed_XA, dim=0)
        fixed_XB = torch.cat(fixed_XB, dim=0)

        # zA, zS = encA(xA)
        fixed_zmuA, _, _, cate_prob_infA = encoderA(fixed_XA)

        # zB, zS = encB(xB)
        fixed_zmuB, _, _, cate_prob_infB = encoderB(fixed_XB)

        # zS = encAB(xA,xB) via POE: uniform (1/10) prior times both inferred
        # categorical distributions, combined in log space.
        fixed_cate_probS = torch.exp(
            torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

        # NOTE(review): fixed_cate_probS above is computed but unused — the
        # shared code is sampled from modality A's inference network only.
        # fixed_zS = sample_gumbel_softmax(self.use_cuda, fixed_cate_probS, train=False)
        fixed_zS = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)

        # Shape of the whole batch stacked along one axis; used to size the
        # all-white separator tensor WS below.
        saving_shape = torch.cat([fixed_XA[i] for i in range(fixed_XA.shape[0])], dim=1).shape

    ####
    # NOTE(review): if self.record_file is falsy, saving_shape (and the
    # fixed_* tensors) are never defined and the code below raises NameError;
    # presumably record_file is always set when this method is called — verify.
    WS = torch.ones(saving_shape)
    if self.use_cuda:
        WS = WS.cuda()

    # do traversal and collect generated images
    gifs = []

    zA_ori, zB_ori, zS_ori = fixed_zmuA, fixed_zmuB, fixed_zS

    tempA = []  # zA_dim + zS_dim , num_trv, 1, 32*num_samples, 32

    # Sweep each continuous dimension of zA while holding the shared code fixed.
    for row in range(self.zA_dim):
        if loc != -1 and row != loc:
            continue
        zA = zA_ori.clone()

        temp = []
        for val in interpolationA:
            zA[:, row] = val
            sampleA = torch.sigmoid(decoderA(zA, zS_ori)).data
            # Stack the decoded batch along one axis into a single strip image.
            temp.append((torch.cat([sampleA[i] for i in range(sampleA.shape[0])], dim=1)).unsqueeze(0))

        tempA.append(torch.cat(temp, dim=0).unsqueeze(0))  # torch.cat(temp, dim=0) = num_trv, 1, 32*num_samples, 32

    # Traverse the shared discrete latent: decode zA_ori with each one-hot code.
    temp = []
    for i in range(self.zS_dim):
        zS = np.zeros((1, self.zS_dim))
        zS[0, i % self.zS_dim] = 1.
        zS = torch.Tensor(zS)
        zS = torch.cat([zS] * len(fixed_idxs), dim=0)

        if self.use_cuda:
            zS = zS.cuda()

        sampleA = torch.sigmoid(decoderA(zA_ori, zS)).data
        temp.append((torch.cat([sampleA[i] for i in range(sampleA.shape[0])], dim=1)).unsqueeze(0))
    tempA.append(torch.cat(temp, dim=0).unsqueeze(0))

    gifs = torch.cat(tempA, dim=0)  # torch.Size([11, 10, 1, 384, 32])

    # save the generated files, also the animated gifs
    out_dir = os.path.join(self.output_dir_trvsl, str(iters), 'train')
    mkdirs(self.output_dir_trvsl)
    mkdirs(out_dir)

    # One image per interpolation step j, collecting all traversed dimensions.
    for j, val in enumerate(interpolationA):
        # I = torch.cat([IMG[key], gifs[:][j]], dim=0)
        I = gifs[:, j]
        save_image(
            tensor=I.cpu(),
            filename=os.path.join(out_dir, '%03d.jpg' % (j)),
            nrow=1 + self.zA_dim + 1 + 1 + 1 + self.zB_dim,
            pad_value=1)

    # make animated gif
    grid2gif2(
        out_dir, str(os.path.join(out_dir, 'mnist_traverse' + '.gif')),
        delay=10
    )

    self.set_mode(train=True)
def save_recon(self, iters):
    """Save reconstruction grids for both modalities.

    Loads a fixed set of samples (6 consecutive indices around each of 10
    hand-picked anchors, 60 in total), reconstructs each modality from its
    own inference network and from the POE (joint) shared code, and writes
    two image files interleaving input / single-modal recon / POE recon /
    white separator columns.

    Args:
        iters: current training iteration; used in the output filenames.
    """
    self.set_mode(train=False)

    mkdirs(self.output_dir_recon)

    fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

    # Expand each anchor index into 6 consecutive dataset indices.
    fixed_idxs60 = []
    for idx in fixed_idxs:
        for i in range(6):
            fixed_idxs60.append(idx + i)

    XA = [0] * len(fixed_idxs60)
    XB = [0] * len(fixed_idxs60)

    for i, idx in enumerate(fixed_idxs60):
        XA[i], XB[i] = \
            self.data_loader.dataset.__getitem__(idx)[0:2]
        if self.use_cuda:
            XA[i] = XA[i].cuda()
            XB[i] = XB[i].cuda()

    XA = torch.stack(XA)
    XB = torch.stack(XB)

    # zA, zS = encA(xA)
    muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

    # zB, zS = encB(xB)
    muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

    # zS = encAB(xA,xB) via POE: uniform (1/10) prior times both inferred
    # categorical distributions, combined in log space.
    cate_prob_POE = torch.exp(
        torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

    # encoder samples (for training)
    ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
    ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
    ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE, train=False)

    # encoder samples (for cross-modal prediction)
    ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
    ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

    # reconstructed samples (given joint modal observation)
    XA_POE_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_POE))
    XB_POE_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_POE))

    # reconstructed samples (given single modal observation)
    XA_infA_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_infA))
    XB_infB_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_infB))

    # All-white separator block with the same shape as the input batch.
    WS = torch.ones(XA.shape)
    if self.use_cuda:
        WS = WS.cuda()

    n = XA.shape[0]

    # Permutation that interleaves the 4 stacked groups sample-by-sample:
    # row order becomes [x_0, recon_0, poe_0, ws_0, x_1, recon_1, ...].
    perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
    perm = perm.contiguous().view(-1)

    ## img
    merged = torch.cat(
        [XA, XA_infA_recon, XA_POE_recon, WS], dim=0
    )
    merged = merged[perm, :].cpu()

    # save the results as image
    fname = os.path.join(self.output_dir_recon, 'reconA_%s.jpg' % iters)
    mkdirs(self.output_dir_recon)
    save_image(
        tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
        pad_value=1
    )

    WS = torch.ones(XB.shape)
    if self.use_cuda:
        WS = WS.cuda()

    n = XB.shape[0]
    perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
    perm = perm.contiguous().view(-1)

    ## ingr
    merged = torch.cat(
        [XB, XB_infB_recon, XB_POE_recon, WS], dim=0
    )
    merged = merged[perm, :].cpu()

    # save the results as image
    fname = os.path.join(self.output_dir_recon, 'reconB_%s.jpg' % iters)
    mkdirs(self.output_dir_recon)
    save_image(
        tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
        pad_value=1
    )

    self.set_mode(train=True)
def save_synth_cross_modal(self, iters, z_A_stat, z_B_stat, train=True, howmany=3):
    """Save cross-modal synthesis grids (A→B and B→A).

    For a fixed batch, infers each modality's shared categorical code, then
    generates `howmany` samples of the *other* modality by decoding that
    shared code together with a random private latent (a standard normal
    shifted by the mean of the supplied latent statistics).

    Args:
        iters: current training iteration; used in the output filenames.
        z_A_stat, z_B_stat: collections of private-latent statistics; their
            per-dimension mean is used to center the random private codes.
        train: if True, sample from the training set; otherwise from the
            test set (and prefix output files with 'eval_').
        howmany: number of synthesized variants per input sample.
    """
    self.set_mode(train=False)

    if train:
        data_loader = self.data_loader
        fixed_idxs = [3246, 7001, 14308, 19000, 27447, 33103, 38002, 45232, 51000, 55125]
    else:
        data_loader = self.test_data_loader
        fixed_idxs = [2, 982, 2300, 3400, 4500, 5500, 6500, 7500, 8500, 9500]

    fixed_XA = [0] * len(fixed_idxs)
    fixed_XB = [0] * len(fixed_idxs)

    for i, idx in enumerate(fixed_idxs):
        fixed_XA[i], fixed_XB[i] = \
            data_loader.dataset.__getitem__(idx)[0:2]
        if self.use_cuda:
            fixed_XA[i] = fixed_XA[i].cuda()
            fixed_XB[i] = fixed_XB[i].cuda()

    fixed_XA = torch.stack(fixed_XA)
    fixed_XB = torch.stack(fixed_XB)

    # zA, zS = encA(xA)
    _, _, _, cate_prob_infA = self.encoderA(fixed_XA)

    # zB, zS = encB(xB)
    _, _, _, cate_prob_infB = self.encoderB(fixed_XB)

    # Shared discrete codes inferred from each single modality.
    ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
    ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

    if self.use_cuda:
        ZS_infA = ZS_infA.cuda()
        ZS_infB = ZS_infB.cuda()

    decoderA = self.decoderA
    decoderB = self.decoderB

    # mkdirs(os.path.join(self.output_dir_synth, str(iters)))

    # Replicate modality A's single channel to 3 channels so it can sit in
    # the same grid as the (presumably 3-channel) modality-B images —
    # TODO confirm fixed_XB is 3-channel.
    fixed_XA_3ch = []
    for i in range(len(fixed_XA)):
        each_XA = fixed_XA[i].clone().squeeze()
        fixed_XA_3ch.append(torch.stack([each_XA, each_XA, each_XA]))
    fixed_XA_3ch = torch.stack(fixed_XA_3ch)

    # All-white separator block.
    WS = torch.ones(fixed_XA_3ch.shape)
    if self.use_cuda:
        WS = WS.cuda()

    n = len(fixed_idxs)

    # Permutation interleaving the (howmany + 2) stacked groups per sample:
    # [input_i, synth_i_1..synth_i_howmany, ws_i] for each i.
    perm = torch.arange(0, (howmany + 2) * n).view(howmany + 2, n).transpose(1, 0)
    perm = perm.contiguous().view(-1)

    ######## 1) generate xB from given xA (A2B) ########

    merged = torch.cat([fixed_XA_3ch], dim=0)
    for k in range(howmany):
        # Random private code for B, centered at the mean of z_B_stat.
        ZB = torch.randn(n, self.zB_dim)

        z_B_stat = np.array(z_B_stat)
        z_B_stat_mean = np.mean(z_B_stat, 0)
        ZB = ZB + torch.Tensor(z_B_stat_mean)

        if self.use_cuda:
            ZB = ZB.cuda()

        XB_synth = torch.sigmoid(decoderB(ZB, ZS_infA))  # given XA
        merged = torch.cat([merged, XB_synth], dim=0)
    merged = torch.cat([merged, WS], dim=0)
    merged = merged[perm, :].cpu()

    # save the results as image
    if train:
        fname = os.path.join(
            self.output_dir_synth,
            'synth_cross_modal_A2B_%s.jpg' % iters
        )
    else:
        fname = os.path.join(
            self.output_dir_synth,
            'eval_synth_cross_modal_A2B_%s.jpg' % iters
        )
    mkdirs(self.output_dir_synth)
    save_image(
        tensor=merged, filename=fname, nrow=(howmany + 2) * int(np.sqrt(n)),
        pad_value=1
    )

    ######## 2) generate xA from given xB (B2A) ########

    merged = torch.cat([fixed_XB], dim=0)
    for k in range(howmany):
        # Random private code for A, centered at the mean of z_A_stat.
        ZA = torch.randn(n, self.zA_dim)

        z_A_stat = np.array(z_A_stat)
        z_A_stat_mean = np.mean(z_A_stat, 0)
        ZA = ZA + torch.Tensor(z_A_stat_mean)

        if self.use_cuda:
            ZA = ZA.cuda()

        XA_synth = torch.sigmoid(decoderA(ZA, ZS_infB))  # given XB

        # Replicate synthesized A to 3 channels to match the grid.
        XA_synth_3ch = []
        for i in range(len(XA_synth)):
            each_XA = XA_synth[i].clone().squeeze()
            XA_synth_3ch.append(torch.stack([each_XA, each_XA, each_XA]))

        merged = torch.cat([merged, torch.stack(XA_synth_3ch)], dim=0)
    merged = torch.cat([merged, WS], dim=0)
    merged = merged[perm, :].cpu()

    # save the results as image
    if train:
        fname = os.path.join(
            self.output_dir_synth,
            'synth_cross_modal_B2A_%s.jpg' % iters
        )
    else:
        fname = os.path.join(
            self.output_dir_synth,
            'eval_synth_cross_modal_B2A_%s.jpg' % iters
        )
    mkdirs(self.output_dir_synth)
    save_image(
        tensor=merged, filename=fname, nrow=(howmany + 2) * int(np.sqrt(n)),
        pad_value=1
    )

    self.set_mode(train=True)
def train(self):
    """Run the main VAE training loop.

    Builds the train/test data loaders over the DIGIT dataset, then for each
    iteration: samples a mini-batch, encodes both modalities, forms the
    shared categorical code via a product-of-experts (POE) with a uniform
    prior, decodes reconstructions, and optimizes the VAE objective
    (modality B's reconstruction + capacity losses). Checkpoints, image
    dumps, and visdom line stats are emitted on their respective schedules.

    Fixes vs. previous revision:
      * z_A/z_B/z_S were only computed inside the output-save branch but
        consumed in the eval-metrics branch, raising NameError whenever the
        eval branch fired first; stats are now recomputed in that branch.
      * The 'rec_POE' log field previously printed loss_recon (== rec_infB)
        instead of loss_recon_POE.
    """
    self.set_mode(train=True)

    # prepare dataloader (iterable)
    print('Start loading data...')
    dset = DIGIT('./data', train=True)
    self.data_loader = torch.utils.data.DataLoader(dset, batch_size=self.batch_size, shuffle=True)
    test_dset = DIGIT('./data', train=False)
    self.test_data_loader = torch.utils.data.DataLoader(test_dset, batch_size=self.batch_size, shuffle=True)
    print('test: ', len(test_dset))
    self.N = len(self.data_loader.dataset)
    print('...done')

    # iterators from dataloader
    iterator1 = iter(self.data_loader)
    iterator2 = iter(self.data_loader)

    iter_per_epoch = min(len(iterator1), len(iterator2))

    start_iter = self.ckpt_load_iter + 1
    epoch = int(start_iter / iter_per_epoch)

    for iteration in range(start_iter, self.max_iter + 1):

        # reset data iterators for each epoch
        if iteration % iter_per_epoch == 0:
            print('==== epoch %d done ====' % epoch)
            epoch += 1
            iterator1 = iter(self.data_loader)
            iterator2 = iter(self.data_loader)

        # ============================================
        #          TRAIN THE VAE (ENC & DEC)
        # ============================================

        # sample a mini-batch
        XA, XB, index = next(iterator1)  # (n x C x H x W)
        # index kept (as numpy) for currently-disabled embedding dumps.
        index = index.cpu().detach().numpy()
        if self.use_cuda:
            XA = XA.cuda()
            XB = XB.cuda()

        # zA, zS = encA(xA)
        muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

        # zB, zS = encB(xB)
        muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

        # zS = encAB(xA,xB) via POE: uniform (1/10) prior times both inferred
        # categorical distributions, combined in log space.
        cate_prob_POE = torch.exp(
            torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

        # kl losses
        # A (default capacity schedule)
        latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
        (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA,
         disc_capacity_loss_infA) = kl_loss_function(
            self.use_cuda, iteration, latent_dist_infA)
        loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
        capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

        # B (explicit capacity schedules: [min, max, anneal iters, gamma])
        latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
        (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB,
         disc_capacity_loss_infB) = kl_loss_function(
            self.use_cuda, iteration, latent_dist_infB,
            cont_capacity=[0.0, 5.0, 50000, 100.0],
            disc_capacity=[0.0, 10.0, 50000, 100.0])
        loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
        capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB

        # NOTE: only modality B's capacity loss enters the objective;
        # A's is computed for logging only.
        loss_capa = capacity_loss_infB

        # encoder samples (for training)
        ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
        ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
        ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

        # encoder samples (for cross-modal prediction)
        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

        # reconstructed samples (given joint modal observation)
        XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
        XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

        # reconstructed samples (given single modal observation)
        XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
        XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

        loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")

        loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")

        # Joint (POE) reconstruction, logged but not optimized.
        loss_recon_POE = \
            F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
            F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))

        # NOTE: only modality B's reconstruction term is optimized here.
        loss_recon = loss_recon_infB

        # total loss for vae
        vae_loss = loss_recon + loss_capa

        # update vae
        self.optim_vae.zero_grad()
        vae_loss.backward()
        self.optim_vae.step()

        # print the losses
        if iteration % self.print_iter == 0:
            prn_str = ( \
                '[iter %d (epoch %d)] vae_loss: %.3f ' + \
                '(recon: %.3f, capa: %.3f)\n' + \
                '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n' + \
                '    kl_infA = %.3f, kl_infB = %.3f' + \
                '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n' + \
                '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
            ) % \
            (iteration, epoch,
             vae_loss.item(), loss_recon.item(), loss_capa.item(),
             # BUGFIX: rec_POE previously printed loss_recon (== rec_infB).
             loss_recon_infA.item(), loss_recon_infB.item(), loss_recon_POE.item(),
             loss_kl_infA.item(), loss_kl_infB.item(),
             cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
             cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
             )
            print(prn_str)
            if self.record_file:
                # Context manager guarantees the log file is closed.
                with open(self.record_file, 'a') as record:
                    record.write('%s\n' % (prn_str,))

        # save model parameters
        if iteration % self.ckpt_save_iter == 0:
            self.save_checkpoint(iteration)

        # save output images (recon, synth, etc.)
        if iteration % self.output_save_iter == 0:
            # 1) save the recon images
            self.save_recon(iteration)

            z_A, z_B, z_S = self.get_stat()

            # 2) save the latent traversed images
            self.save_traverseB(iteration, z_A, z_B, z_S)

        if iteration % self.eval_metrics_iter == 0:
            # BUGFIX: recompute latent stats here; the previous code reused
            # z_A/z_B from the output-save branch and raised NameError when
            # this branch fired before that one ever ran.
            z_A, z_B, z_S = self.get_stat()
            self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

        # (visdom) insert current line stats
        if self.viz_on and (iteration % self.viz_ll_iter == 0):
            self.line_gather.insert(iter=iteration,
                                    recon_both=loss_recon_POE.item(),
                                    recon_A=loss_recon_infA.item(),
                                    recon_B=loss_recon_infB.item(),
                                    kl_A=loss_kl_infA.item(),
                                    kl_B=loss_kl_infB.item(),
                                    cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                    disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                    cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                    disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                    )

        # (visdom) visualize line stats (then flush out)
        if self.viz_on and (iteration % self.viz_la_iter == 0):
            self.visualize_line()
            self.line_gather.flush()