def get_paired_data(paired_cnt, seed):
    data = torch.utils.data.DataLoader(DIGIT('./data', train=True),
                                       batch_size=args.batch_size,
                                       shuffle=False)
    tr_labels = data.dataset.label

    # paired_cnt must be divisible by 10 so each digit class gets an equal share
    cnt = int(paired_cnt / 10)
    assert cnt == paired_cnt / 10

    # collect the dataset indices belonging to each digit class
    label_idx = {}
    for i in range(10):
        label_idx.update({i: []})
    for idx in range(len(tr_labels)):
        label = int(tr_labels[idx])
        label_idx[label].append(idx)

    # draw `cnt` indices per class, deterministically given the seed
    total_random_idx = []
    for i in range(10):
        random.seed(seed)
        per_label_random_idx = random.sample(label_idx[i], cnt)
        total_random_idx.extend(per_label_random_idx)
    random.seed(seed)
    random.shuffle(total_random_idx)

    imgs = []
    labels = []
    for idx in total_random_idx:
        img, label = data.dataset.__getitem__(idx)
        imgs.append(img)
        labels.append(torch.tensor(label))
    imgs = torch.stack(imgs, dim=0)
    labels = torch.stack(labels, dim=0)
    return imgs, labels
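# Illustrative usage sketch (hypothetical helper, not part of the original
# pipeline; assumes `args`, `DIGIT`, and get_paired_data above are in scope).
# It draws 100 paired examples -- 10 per digit class -- and checks that the
# same seed reproduces the same subset, since get_paired_data reseeds `random`
# before every sampling step.
def _demo_get_paired_data():
    imgs, labels = get_paired_data(paired_cnt=100, seed=0)
    print(imgs.shape, labels.shape)        # e.g. (100, C, H, W) and (100,)
    imgs2, labels2 = get_paired_data(paired_cnt=100, seed=0)
    assert torch.equal(labels, labels2)    # deterministic given the seed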
def train(self):
    self.set_mode(train=True)

    # prepare dataloader (iterable)
    print('Start loading data...')
    dset = DIGIT('./data', train=True)
    self.data_loader = torch.utils.data.DataLoader(
        dset, batch_size=self.batch_size, shuffle=True)
    test_dset = DIGIT('./data', train=False)
    self.test_data_loader = torch.utils.data.DataLoader(
        test_dset, batch_size=self.batch_size, shuffle=True)
    print('test: ', len(test_dset))
    self.N = len(self.data_loader.dataset)
    print('...done')

    # iterators from dataloader
    iterator1 = iter(self.data_loader)
    iterator2 = iter(self.data_loader)

    iter_per_epoch = min(len(iterator1), len(iterator2))

    start_iter = self.ckpt_load_iter + 1
    epoch = int(start_iter / iter_per_epoch)

    for iteration in range(start_iter, self.max_iter + 1):

        # reset data iterators for each epoch
        if iteration % iter_per_epoch == 0:
            print('==== epoch %d done ====' % epoch)
            epoch += 1
            iterator1 = iter(self.data_loader)
            iterator2 = iter(self.data_loader)

        # ============================================
        #         TRAIN THE VAE (ENC & DEC)
        # ============================================

        # sample a mini-batch
        XA, XB, index = next(iterator1)  # (n x C x H x W)
        index = index.cpu().detach().numpy()
        if self.use_cuda:
            XA = XA.cuda()
            XB = XB.cuda()

        # zA, zS = encA(xA)
        muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

        # zB, zS = encB(xB)
        muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

        # zS = encAB(xA, xB) via POE: uniform prior * expert A * expert B
        cate_prob_POE = torch.exp(
            torch.log(torch.tensor(1 / 10)) +
            torch.log(cate_prob_infA) +
            torch.log(cate_prob_infB))

        # kl losses
        # A
        latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
        (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA,
         disc_capacity_loss_infA) = kl_loss_function(
            self.use_cuda, iteration, latent_dist_infA)
        loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
        capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

        # B
        latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
        (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB,
         disc_capacity_loss_infB) = kl_loss_function(
            self.use_cuda, iteration, latent_dist_infB,
            cont_capacity=[0.0, 5.0, 50000, 100.0],
            disc_capacity=[0.0, 10.0, 50000, 100.0])
        loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
        capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB

        loss_capa = capacity_loss_infB

        # encoder samples (for training)
        ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
        ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
        ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

        # encoder samples (for cross-modal prediction)
        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

        # reconstructed samples (given joint modal observation)
        XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
        XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

        # reconstructed samples (given single modal observation)
        XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
        XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

        loss_recon_infA = reconstruction_loss(
            XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
        loss_recon_infB = reconstruction_loss(
            XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
        loss_recon_POE = \
            F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
            F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))
        loss_recon = loss_recon_infB

        # total loss for vae
        vae_loss = loss_recon + loss_capa

        # update vae
        self.optim_vae.zero_grad()
        vae_loss.backward()
        self.optim_vae.step()

        # print the losses
        if iteration % self.print_iter == 0:
            prn_str = (
                '[iter %d (epoch %d)] vae_loss: %.3f '
                '(recon: %.3f, capa: %.3f)\n'
                '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n'
                '    kl_infA = %.3f, kl_infB = %.3f\n'
                '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n'
                '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
            ) % (
                iteration, epoch,
                vae_loss.item(), loss_recon.item(), loss_capa.item(),
                loss_recon_infA.item(), loss_recon_infB.item(), loss_recon_POE.item(),
                loss_kl_infA.item(), loss_kl_infB.item(),
                cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
            )
            print(prn_str)
            if self.record_file:
                record = open(self.record_file, 'a')
                record.write('%s\n' % (prn_str,))
                record.close()

        # save model parameters
        if iteration % self.ckpt_save_iter == 0:
            self.save_checkpoint(iteration)

        # save output images (recon, synth, etc.)
        if iteration % self.output_save_iter == 0:
            # 1) save the recon images
            self.save_recon(iteration)

            # 2) save the latent traversed images
            z_A, z_B, z_S = self.get_stat()
            self.save_traverseB(iteration, z_A, z_B, z_S)

        if iteration % self.eval_metrics_iter == 0:
            # recompute latent statistics here, since z_A / z_B are otherwise
            # only defined when this iteration also hit output_save_iter
            z_A, z_B, z_S = self.get_stat()
            self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

        # (visdom) insert current line stats
        if self.viz_on and (iteration % self.viz_ll_iter == 0):
            self.line_gather.insert(
                iter=iteration,
                recon_both=loss_recon_POE.item(),
                recon_A=loss_recon_infA.item(),
                recon_B=loss_recon_infB.item(),
                kl_A=loss_kl_infA.item(),
                kl_B=loss_kl_infB.item(),
                cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                disc_capacity_loss_infB=disc_capacity_loss_infB.item())

        # (visdom) visualize line stats (then flush out)
        if self.viz_on and (iteration % self.viz_la_iter == 0):
            self.visualize_line()
            self.line_gather.flush()
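# Sketch (an assumption for illustration, not taken from the original code):
# the cate_prob_POE line in train() forms an unnormalized product of experts
# over the 10-way shared latent,
#     p(zS | xA, xB)  propto  prior(zS) * qA(zS | xA) * qB(zS | xB),
# with a uniform 1/10 prior. A numerically safer, explicitly normalized
# variant works in log space and renormalizes at the end:
def poe_categorical(cate_prob_infA, cate_prob_infB, eps=1e-8):
    log_prior = -torch.log(torch.tensor(10.0))        # uniform prior over 10 classes
    logits = (log_prior
              + torch.log(cate_prob_infA + eps)       # eps guards log(0)
              + torch.log(cate_prob_infB + eps))
    return torch.softmax(logits, dim=-1)              # renormalize per sample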
if args.viz_on:
    WIN_ID = dict(
        llA='win_llA', llB='win_llB',
        test_acc='win_test_acc', total_losses='win_total_losses'
    )
    LINE_GATHER = probtorch.util.DataGather(
        'epoch', 'recon_A', 'recon_B',
        'recon_poeA', 'recon_poeB', 'recon_crA', 'recon_crB',
        'total_loss', 'test_total_loss', 'test_acc'
    )
    VIZ = visdom.Visdom(port=args.viz_port)
    viz_init()

train_data = torch.utils.data.DataLoader(DIGIT('./data', train=True),
                                         batch_size=args.batch_size,
                                         shuffle=False)
test_data = torch.utils.data.DataLoader(DIGIT('./data', train=False),
                                        batch_size=args.batch_size,
                                        shuffle=False)

# bias terms use the number of samples (not batches), matching BIAS_TEST
train_data_size = len(train_data.dataset)
BIAS_TRAIN = (train_data_size - 1) / (args.batch_size - 1)
BIAS_TEST = (len(test_data.dataset) - 1) / (args.batch_size - 1)


def cuda_tensors(obj):
    # move every tensor attribute of `obj` onto the GPU, in place
    for attr in dir(obj):
        value = getattr(obj, attr)
        if isinstance(value, torch.Tensor):
            setattr(obj, attr, value.cuda())
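# Minimal sketch of how cuda_tensors behaves (hypothetical class and function
# names, for illustration only): tensor attributes are moved to the GPU in
# place, while non-tensor attributes are left untouched.
class _Stats:
    def __init__(self):
        self.mu = torch.zeros(10)    # tensor attribute -> moved to GPU
        self.name = 'stats'          # non-tensor attribute -> untouched


def _demo_cuda_tensors():
    obj = _Stats()
    if torch.cuda.is_available():    # guard: only meaningful with a GPU
        cuda_tensors(obj)
        assert obj.mu.is_cuda and obj.name == 'stats'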