def pre_train_DE(self, out_z, out_x):
    """Pretrain only the decoder: reconstruct each x from its paired latent z.

    Iterates over (z, x) pairs (cycling through them until ``self.niter_preDE``
    steps have run) and minimizes the reconstruction loss with ``optim_DE``.

    Returns:
        list[float]: per-step reconstruction-loss history.
    """
    self.net_mode(True)
    pbar = tqdm(total=self.niter_preDE)
    step = 0
    pbar.update(step)
    loss_table = []
    finished = False
    while not finished:
        for z, x in zip(out_z, out_x):
            x = Variable(cuda(x.float(), self.use_cuda))
            z = Variable(cuda(z.float(), self.use_cuda))
            recon = self.net._decode(z)
            loss = reconstruction_loss(x, recon, self.decoder_dist)
            loss_table.append(loss.data.item())
            self.optim_DE.zero_grad()
            loss.backward()
            self.optim_DE.step()
            # Stop once the step budget is exhausted.
            if step >= self.niter_preDE:
                finished = True
                break
            step += 1
            pbar.update(1)
    pbar.write("[Pretrain Decoder Finished]")
    pbar.close()
    sys.stdout.flush()
    return loss_table
def pre_train_EN(self, out_z, out_x):
    """Pretrain only the encoder: regress a sampled z_hat onto the target z.

    For each (z, x) pair the encoder output is split into (mu, logvar),
    reparametrized into z_hat, and pushed toward z with an MSE loss via
    ``optim_EN``; pairs are cycled until ``self.niter_preEN`` steps have run.

    Returns:
        list[float]: per-step MSE-loss history.
    """
    self.net_mode(True)
    mse = torch.nn.MSELoss()
    pbar = tqdm(total=self.niter_preEN)
    step = 0
    pbar.update(step)
    loss_table = []
    finished = False
    while not finished:
        for z, x in zip(out_z, out_x):
            x = Variable(cuda(x.float(), self.use_cuda))
            z = Variable(cuda(z.float(), self.use_cuda))
            distributions = self.net._encode(x)
            # Encoder emits [mu | logvar] concatenated along dim 1.
            mu = distributions[:, :self.z_dim]
            logvar = distributions[:, self.z_dim:]
            z_hat = reparametrize(mu, logvar)
            loss = mse(z_hat, z)
            loss_table.append(loss.data.item())
            self.optim_EN.zero_grad()
            loss.backward()
            self.optim_EN.step()
            # Stop once the step budget is exhausted.
            if step >= self.niter_preEN:
                finished = True
                break
            step += 1
            pbar.update(1)
    pbar.write("[Pretrain Encoder Finished]")
    pbar.close()
    sys.stdout.flush()
    return loss_table
def encoder_curves(args, data_loader, flag=False):
    """Train a fresh Solver's encoder to regress stored latent targets.

    Seeds torch/numpy from ``args.seed``, builds a ``Solver``, and for
    ``niter_preEN`` steps fits the encoder means to the ``mz`` targets from
    the loader (or to ``[y | zeros]`` when ``flag`` is set).

    Returns:
        list[float]: per-step MSE-loss history.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    solver = Solver(args)
    solver.net_mode(True)
    mse = torch.nn.MSELoss()
    pbar = tqdm(total=solver.niter_preEN)
    step = 0
    pbar.update(step)
    loss_table = []
    finished = False
    while not finished:
        for x, y, mz in data_loader:
            x = Variable(cuda(x.float(), solver.use_cuda))
            y = Variable(cuda(y.float(), solver.use_cuda))
            mz = Variable(cuda(mz.float(), solver.use_cuda))
            mz_bar = mz[:, :, :args.z_dim]
            mz_hat = solver.net._encode(x)[:, :args.z_dim]
            if flag:
                # Replace the target with [y | zeros] of twice y's width.
                mz_bar = torch.cat((y, torch.zeros_like(y)), dim=2)
            # NOTE(review): mz_bar is indexed with three dims while mz_hat has
            # two — this relies on broadcasting; confirm the loader's shapes.
            loss = mse(mz_bar, mz_hat)
            loss_table.append(loss.data.item())
            solver.optim_EN.zero_grad()
            loss.backward()
            solver.optim_EN.step()
            if step >= solver.niter_preEN:
                finished = True
                break
            step += 1
            pbar.update(1)
    pbar.write("[Pretrain Encoder Finished]")
    pbar.close()
    sys.stdout.flush()
    return loss_table
def decoder_curves(args, data_loader, flag=False, decoder_dist='bernoulli'):
    """Train a fresh Solver's decoder to reconstruct x from sampled latents.

    Seeds torch/numpy from ``args.seed``, builds a ``Solver``, samples z from
    the stored distribution parameters ``mz = [mu | logvar]`` (or uses
    ``[y | zeros]`` when ``flag`` is set), and minimizes the reconstruction
    loss for ``niter_preDE`` steps with ``optim_DE``.

    Args:
        args: experiment configuration (needs ``seed``, ``z_dim``, ...).
        data_loader: yields (x, y, mz) batches.
        flag: if True, use ``[y | zeros]`` as the latent instead of sampling.
        decoder_dist: reconstruction distribution; previously hard-coded to
            'bernoulli', now a backward-compatible parameter.

    Returns:
        list[float]: per-step reconstruction-loss history.
    """
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    net = Solver(args)
    net.net_mode(True)
    out = False
    pre_DE_cnt = 0
    pbar = tqdm(total=net.niter_preDE)
    pbar.update(pre_DE_cnt)
    loss_table = []
    while not out:
        for x, y, mz in data_loader:
            x = Variable(cuda(x.float(), net.use_cuda))
            y = Variable(cuda(y.float(), net.use_cuda))
            mz = Variable(cuda(mz.float(), net.use_cuda))
            # Sample z from the stored distribution parameters [mu | logvar].
            z = reparametrize(mz[:, :, :args.z_dim], mz[:, :, args.z_dim:])
            if flag:
                y_zeros = torch.zeros_like(y)
                z = torch.cat((y, y_zeros), dim=2)
            x_recon = net.net._decode(z)
            loss = reconstruction_loss(x, x_recon, decoder_dist)
            loss_table.append(loss.data.item())
            net.optim_DE.zero_grad()
            loss.backward()
            net.optim_DE.step()
            if pre_DE_cnt >= net.niter_preDE:
                out = True
                break
            pre_DE_cnt += 1
            pbar.update(1)
    # BUG FIX: message previously said "Encoder" (copy-paste from encoder_curves).
    pbar.write("[Pretrain Decoder Finished]")
    pbar.close()
    sys.stdout.flush()
    return loss_table
def gen_z(self, gen_size=10, fullz=False):
    """Encode ``gen_size`` batches from the data loader into latents.

    Puts the net in eval mode, collects ``gen_size`` batches (cycling the
    loader if needed), and restores train mode before returning.

    Args:
        gen_size: number of batches to collect.
        fullz: if True, return the raw encoder outputs ``[mu | logvar]``
            instead of sampled z's.

    Returns:
        (out_z, out_y, out_x) — or (out_distr_list, out_y, out_x) when
        ``fullz`` is True. Each is a list of length ``gen_size``; z entries
        are B x z_dim, y entries are the batch labels with the first column
        dropped.
    """
    self.net_mode(train=False)
    out_z, out_y, out_x, out_distr_list = [], [], [], []
    gen_cnt = 0
    done = False
    while not done:
        for x, y in self.data_loader:
            out_y.append(y.squeeze(1)[:, 1:])  # drop the first label column
            out_x.append(x)
            gen_cnt += 1
            x = Variable(cuda(x.float(), self.use_cuda))
            distri = self.net.encoder(x).data
            mu = distri[:, :self.z_dim]
            logvar = distri[:, self.z_dim:]
            out_z.append(reparametrize(mu, logvar))
            out_distr_list.append(distri)
            if gen_cnt >= gen_size:
                done = True
                break
    self.net_mode(train=True)
    if fullz:
        return out_distr_list, out_y, out_x
    return out_z, out_y, out_x
def __init__(self, args):
    """Build the Solver: model, optimizers, output directories, metrics, data."""
    # Hardware / training-schedule configuration.
    self.use_cuda = args.cuda and torch.cuda.is_available()
    self.max_iter_per_gen = args.max_iter_per_gen
    self.max_gen = args.max_gen
    self.global_iter = 0
    self.global_gen = 0
    # Model hyper-parameters.
    self.z_dim = args.z_dim
    self.beta = args.beta
    self.lr = args.lr
    self.beta1 = args.beta1
    self.beta2 = args.beta2
    self.nb_preENDE = args.nb_preENDE
    self.niter_preEN = args.niter_preEN
    self.niter_preDE = args.niter_preDE
    self.nc = 1
    self.decoder_dist = 'bernoulli'
    # Network and optimizers. All three optimizers cover every parameter of
    # the net (encoder-only / decoder-only pretraining still steps the full
    # parameter set).
    net_cls = BetaVAE_H
    self.net = cuda(net_cls(self.z_dim, self.nc), self.use_cuda)
    def _adam():
        return optim.Adam(self.net.parameters(), lr=self.lr,
                          betas=(self.beta1, self.beta2))
    self.optim = _adam()
    self.optim_EN = _adam()
    self.optim_DE = _adam()
    # Experiment output directories (created if missing).
    self.exp_name = args.exp_name
    self.ckpt_dir = os.path.join('exp_results/' + args.exp_name, args.ckpt_dir)
    os.makedirs(self.ckpt_dir, exist_ok=True)
    self.ckpt_name = args.ckpt_name
    if self.ckpt_name is not None:
        self.load_checkpoint(self.ckpt_name)
    self.metric_dir = os.path.join('exp_results/' + args.exp_name + '/metrics')
    os.makedirs(self.metric_dir, exist_ok=True)
    self.imgs_dir = os.path.join('exp_results/' + args.exp_name + '/images')
    os.makedirs(self.imgs_dir, exist_ok=True)
    # Checkpoint / metric cadence and metric helpers.
    self.save_step = args.save_step
    self.metric_step = args.metric_step
    self.metric_topsim = Metric_topsim(args)
    self.top_sim_batches = args.top_sim_batches
    self.metric_R = Metric_R(args)
    self.save_gifs = args.save_gifs
    # Data.
    self.dset_dir = args.dset_dir
    self.batch_size = args.batch_size
    self.data_loader = return_data(args)
    self.gather = DataGather()
def save_gif(self, limit=3, inter=2 / 3, loc=-1):
    """Render latent-traversal image grids and GIFs for fixed and random images.

    Encodes one random image and three fixed-index images, then sweeps each
    latent dimension over [-limit, limit] in steps of `inter`, decoding every
    point. Saves one jpg grid per interpolation value under
    `imgs_dir/<global_iter>/` and stitches them into a gif per image key.

    Args:
        limit: half-width of the traversal range for each latent dimension.
        inter: step size of the traversal.
        loc: if != -1, traverse only that single latent dimension.
    """
    self.net_mode(train=False)
    import random
    decoder = self.net.decoder
    encoder = self.net.encoder
    # Values swept over a latent dimension; +0.1 makes the upper bound inclusive.
    interpolation = torch.arange(-limit, limit + 0.1, inter)
    n_dsets = len(self.data_loader.dataset)
    rand_idx = random.randint(1, n_dsets - 1)
    random_img, _ = self.data_loader.dataset.__getitem__(rand_idx)
    # NOTE(review): `volatile=True` is the legacy (pre-0.4) PyTorch no-grad
    # flag; on modern torch it is ignored — confirm the torch version in use.
    random_img = Variable(cuda(random_img.float(), self.use_cuda), volatile=True).unsqueeze(0)
    # Use only the mu half of the encoder output as the latent code.
    random_img_z = encoder(random_img)[:, :self.z_dim]
    # Hard-coded dataset indices — presumably one exemplar per dSprites shape;
    # TODO confirm against the dataset ordering.
    fixed_idx1 = 87040  # square
    fixed_idx2 = 332800  # ellipse
    fixed_idx3 = 578560  # heart
    fixed_img1, _ = self.data_loader.dataset.__getitem__(fixed_idx1)
    fixed_img1 = Variable(cuda(fixed_img1.float(), self.use_cuda), volatile=True).unsqueeze(0)
    fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
    fixed_img2, _ = self.data_loader.dataset.__getitem__(fixed_idx2)
    fixed_img2 = Variable(cuda(fixed_img2.float(), self.use_cuda), volatile=True).unsqueeze(0)
    fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
    fixed_img3, _ = self.data_loader.dataset.__getitem__(fixed_idx3)
    fixed_img3 = Variable(cuda(fixed_img3.float(), self.use_cuda), volatile=True).unsqueeze(0)
    fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
    Z = {
        'fixed_square': fixed_img_z1,
        'fixed_ellipse': fixed_img_z2,
        'fixed_heart': fixed_img_z3,
        'random_img': random_img_z
    }
    gifs = []
    for key in Z.keys():
        z_ori = Z[key]
        samples = []
        for row in range(self.z_dim):
            # When a specific dimension is requested, skip all others.
            if loc != -1 and row != loc:
                continue
            z = z_ori.clone()
            for val in interpolation:
                # Overwrite one latent dimension and decode.
                z[:, row] = val
                sample = F.sigmoid(decoder(z)).data
                samples.append(sample)
                gifs.append(sample)
        samples = torch.cat(samples, dim=0).cpu()
    output_dir = os.path.join(self.imgs_dir, str(self.global_iter))
    os.makedirs(output_dir, exist_ok=True)
    # Reshape the flat sample list to (key, interp step, z_dim, C, H, W)
    # via transpose so each saved grid holds one interpolation value across
    # all latent dimensions.
    gifs = torch.cat(gifs)
    gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
    for i, key in enumerate(Z.keys()):
        for j, val in enumerate(interpolation):
            save_image(tensor=gifs[i][j].cpu(),
                       filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                       nrow=self.z_dim, pad_value=1)
        # Stitch the per-step jpgs for this key into an animated gif.
        grid2gif(os.path.join(output_dir, key + '*.jpg'),
                 os.path.join(output_dir, key + '.gif'), delay=10)
    self.net_mode(train=True)
def interact_train(self):
    """Run one generation of beta-VAE training with periodic metric logging.

    Trains for `max_iter_per_gen` iterations on `self.data_loader`, minimizing
    recon_loss + beta * KL with `self.optim`. Every `metric_step` iterations
    it computes topological-similarity and disentanglement metrics and appends
    a line to `metric_dir/results.txt`; checkpoints are saved on a
    `save_step` cadence. All metric histories for this generation are dumped
    to `metrics_gen<global_gen>.npz` at the end.

    Returns:
        (recon_loss_list, bvae_loss_list): per-iteration loss histories.
    """
    self.net_mode(train=True)
    out = False
    # Metric histories collected every `metric_step` iterations.
    indx_list = []
    loss_list = []
    corr_list = []
    dist_list = []
    comp_list = []
    info_list = []
    R_list = []
    # Per-iteration loss histories (returned to the caller).
    recon_loss_list = []
    bvae_loss_list = []
    pbar = tqdm(total=self.max_iter_per_gen)
    pbar.update(0)
    local_iter = 0
    with open(self.metric_dir + '/results.txt', 'a') as f:
        f.write('====== Experiment name: ' + self.exp_name + '==============\n')
    while not out:
        for x, y in self.data_loader:
            self.global_iter += 1
            local_iter += 1
            pbar.update(1)
            x = Variable(cuda(x.float(), self.use_cuda))
            x_recon, mu, logvar = self.net(x)
            recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            # Standard beta-VAE objective.
            beta_vae_loss = recon_loss + self.beta * total_kld
            recon_loss_list.append(recon_loss.data.item())
            bvae_loss_list.append(beta_vae_loss.data.item())
            self.optim.zero_grad()
            beta_vae_loss.backward()
            self.optim.step()
            # Disabled debug dump of (distribution, labels, images) batches.
            if False:  #self.global_iter%(self.metric_step*5) == 0:
                out_fz, out_y, out_x = self.gen_z(self.top_sim_batches * 10, fullz=True)
                fz_upk = torch.stack(out_fz).view(-1, self.z_dim * 2).cpu()
                y_upk = torch.stack(out_y).view(-1, 5).cpu()
                x_upk = torch.stack(out_x).view(-1, 64, 64).cpu()
                np.savez(
                    self.metric_dir + '/saved_xyz_' + str(self.global_iter) + '.npz',
                    out_fz=np.asarray(
                        fz_upk
                    ),  # Here out_z is the distribution ([mu:sigma])
                    out_y=np.asarray(y_upk),
                    out_x=np.asarray(x_upk))
                #test_z = np.load(os.path.join(net.metric_dir+'/saved_xyz0.npz'))
            # Periodic metric evaluation and logging.
            if self.global_iter % self.metric_step == 0:
                out_z, out_y, _ = self.gen_z(self.top_sim_batches)
                # Topological similarity on a 20-batch subsample only.
                corr = self.metric_topsim.top_sim_zy(
                    out_z[:20], out_y[:20])
                dist, comp, info, R = self.metric_R.dise_comp_info(
                    out_z, out_y, 'random_forest')
                indx_list.append(self.global_iter)
                loss_list.append(recon_loss.data.item())
                corr_list.append(corr)
                dist_list.append(dist)
                comp_list.append(comp)
                info_list.append(info)
                R_list.append(R)
                #print('======================================')
                with open(self.metric_dir + '/results.txt', 'a') as f:
                    f.write(
                        '\n [{:0>7d}] \t loss:{:.3f} \t corr:{:.3f} \t dise:{:.3f} \t comp:{:.3f}\t info:{:.3f}'
                        .format(self.global_iter, recon_loss.data.item(),
                                corr, dist[-1], comp[-1], info[-1]))
                #print('======================================')
            # Rolling "last" checkpoint (and optional traversal gifs).
            if self.global_iter % self.save_step == 1:
                self.save_checkpoint('last')
                #pbar.write('Saved checkpoint(iter:{})'.format(self.global_iter))
                if self.save_gifs:
                    self.save_gif()
            # Long-horizon numbered checkpoint.
            if self.global_iter % 50000 == 0:
                self.save_checkpoint(str(self.global_iter))
            if local_iter >= self.max_iter_per_gen:
                out = True
                break
    # Persist all per-generation metric histories in one archive.
    np.savez(
        self.metric_dir + '/metrics_gen' + str(self.global_gen) + '.npz',
        indx=np.asarray(indx_list),  # (len,)
        loss=np.asarray(loss_list),  # (len,)
        corr=np.asarray(corr_list),  # (len,)
        dist=np.asarray(dist_list),  # (len, z_dim+1)
        comp=np.asarray(comp_list),  # (len, y_dim+1)
        info=np.asarray(info_list),  # (len, y_dim+1)
        R=np.asarray(R_list))  # (len, z_dim, y_dim)
    pbar.write("[Training Finished]")
    pbar.close()
    sys.stdout.flush()
    return recon_loss_list, bvae_loss_list