def plot_decoder_img(NN, test_data, pdf_path, global_iter, sbd, type, n=20):
    """Save target images beside their decoder reconstructions to a multi-page PDF.

    For each of the first ``n`` items of ``test_data``, ``sample['x']`` is decoded
    with ``NN._decode``. Each page shows ``sample['y'][0]`` next to the sigmoid of
    the first channel of the reconstruction. The page is titled with
    ``sample['x']`` rounded to 2 decimals.

    Args:
        NN: model exposing ``_decode``.
        test_data: indexable dataset whose items are dicts holding tensors under
            ``'x'`` (decoder input) and ``'y'`` (image to show as the target).
        pdf_path: directory the PDF is written into.
        global_iter: iteration number appended to the file name.
        sbd: if truthy, ``sample['x']`` gets a batch dimension and is passed
            through ``spatial_broadcast_decoder`` before decoding.
        type: ``'test'`` -> ``testing_recon{global_iter}.pdf``, ``'gnrl'`` ->
            ``gnrl_recon{global_iter}.pdf`` (matched case-insensitively).
        n: number of items to plot.

    Raises:
        ValueError: if ``type`` is neither ``'test'`` nor ``'gnrl'``.
    """
    prefixes = {'test': 'testing_recon', 'gnrl': 'gnrl_recon'}
    kind = str(type).lower()
    if kind not in prefixes:
        # Previously an unknown type fell through and PdfPages was pointed at
        # ``pdf_path`` itself.
        raise ValueError("type must be 'test' or 'gnrl', got {!r}".format(type))
    pdf_file = "{}/{}{}.pdf".format(pdf_path, prefixes[kind], global_iter)

    # Built once instead of once per item.
    sbd_model = spatial_broadcast_decoder() if sbd else None

    # no_grad: .numpy() fails on tensors that require grad. The context manager
    # guarantees the PDF is closed even if plotting raises.
    with torch.no_grad(), matplotlib.backends.backend_pdf.PdfPages(pdf_file) as pdf:
        for i in range(n):
            sample = test_data[i]
            latent = sample['x']
            target = sample['y'].numpy()

            decoder_input = latent
            if sbd_model is not None:
                decoder_input = sbd_model(torch.unsqueeze(latent, 0))
            x_recon = torch.sigmoid(NN._decode(decoder_input)).numpy()

            fig, (ax_target, ax_recon) = plt.subplots(
                1, 2, gridspec_kw={'width_ratios': [1, 1]})
            ax_target.imshow(target[0, :, :])
            ax_recon.imshow(x_recon[0, 0, :, :])
            fig.tight_layout()
            # Title with the latent values themselves, not the broadcast grid
            # produced in the sbd branch.
            fig.suptitle(np.around(latent.numpy(), 2))
            pdf.savefig(fig, dpi=300)
            plt.close(fig)
def test_plots(self):
    """Write qualitative diagnostics for ``self.net`` into ``self.output_dir``.

    Works on a deep copy of the network that is unwrapped from ``nn.DataParallel``,
    moved to the CPU and put in eval mode. ``self.net`` itself is left untouched.

    Steps:
      1. (non-AE only) Decode 16 latent vectors drawn from the prior.
      2. Build latent histograms with ``construct_z_hist``.
      3. (non-AE only) Write latent traversals for 3 generalisation examples.
      4. Plot generalisation reconstructions with ``plotsave_tests``.
    """
    net_copy = deepcopy(self.net)
    if isinstance(net_copy, nn.DataParallel):
        # __init__ may wrap the net in DataParallel, which does not expose the
        # wrapped model's _encode/_decode used here and by the helpers below.
        net_copy = net_copy.module
    net_copy.to('cpu')
    # Inference behaviour (dropout / batch-norm) for the copy only.
    net_copy.eval()

    if not self.AE:
        print("creating sample images!")
        # Prior samples: uniform for the z_dim_bern dims, standard normal for
        # the z_dim_gauss dims. Bernoulli dims come first in the concatenation.
        n_samples = 16
        if self.z_dim_bern == 0:
            sample = torch.randn(n_samples, self.z_dim_gauss)
        elif self.z_dim_gauss == 0:
            sample = torch.rand(n_samples, self.z_dim_bern)
        else:
            gauss_part = torch.randn(n_samples, self.z_dim_gauss)
            bern_part = torch.rand(n_samples, self.z_dim_bern)
            sample = torch.cat((bern_part, gauss_part), 1)
        with torch.no_grad():
            if self.sbd:
                decoder_input = spatial_broadcast_decoder()(sample)
            else:
                decoder_input = sample
            # Decode with the trained network. The sbd branch used to build a
            # fresh SB_decoder here, whose weights are not the trained ones.
            test_recon = net_copy._decode(decoder_input)
        # NOTE(review): test_recon is not saved; the save_image call that
        # consumed it was commented out in the original code.

    print("Constructing Z hist!")
    construct_z_hist(net_copy, self.gnrl_dl, self.global_iter, self.output_dir,
                     self.AE, dim='blah')

    if not self.AE:
        print("Traversing!")
        with torch.no_grad():
            for i in range(3):
                # Randomly offset example so successive calls show different items.
                example = self.gnrl_data[i + random.randint(0, 20)]
                traverse_z(net_copy, example, ID=str(i), output_dir=self.output_dir,
                           global_iter=self.global_iter, sbd=self.sbd,
                           num_frames=100)

    print('Reconstructing generalisation images!')
    with torch.no_grad():
        plotsave_tests(net_copy, self.gnrl_data, self.output_dir, self.global_iter,
                       type="Gnrl", n=30)
def __init__(self, args):
    """Configure a supervised encoder/decoder experiment from parsed ``args``.

    Reads ``{args.dset_dir}digts.csv`` to infer the number of digits per item,
    derives the latent size from ``args.encoder_target_type``, builds the
    ``multi_VAE`` network, its data loaders and optimizer, and creates the
    output and checkpoint directories. A checkpoint named by
    ``args.ckpt_name`` is loaded last, once the object is fully initialised.

    Raises:
        NotImplementedError: for an unsupported ``encoder_target_type``,
            ``dataset``, ``testing_method`` or ``optim_type``.
    """
    self.use_cuda = args.cuda and torch.cuda.is_available()
    self.testing_method = args.testing_method
    self.encoder = args.encoder
    self.decoder = args.decoder
    self.n_filter = args.n_filter
    self.n_rep = args.n_rep
    self.kernel_size = args.kernel_size
    self.padding = args.padding
    self.sbd = args.sbd
    self.AE = args.AE
    self.encoder_target_type = args.encoder_target_type

    # Only the column count matters: more than 37 columns -> 3 digits, else 2.
    # nrows=1 avoids parsing the whole file just to count columns.
    data_csv = "{}digts.csv".format(args.dset_dir)
    n_columns = pd.read_csv(data_csv, header=None, nrows=1).shape[1]
    self.n_digits = 3 if n_columns > 37 else 2

    z_dims = {
        'joint': 10,
        'black_white': 20,
        'depth_black_white': 21,
        'depth_black_white_xy_xy': 25,
        'depth_ordered_one_hot': 10 * self.n_digits,     # 20 or 30
        'depth_ordered_one_hot_xy': 12 * self.n_digits,  # 24 or 36
    }
    if args.encoder_target_type not in z_dims:
        # Previously self.z_dim was left unset and failed later with AttributeError.
        raise NotImplementedError(
            "unsupported encoder_target_type: {!r}".format(args.encoder_target_type))
    self.z_dim = z_dims[args.encoder_target_type]

    dataset = args.dataset.lower()
    if dataset == 'digits_gray':
        self.nc = 1
    elif dataset == 'digits_col':
        self.nc = 3
    else:
        raise NotImplementedError

    net = multi_VAE(self.encoder, self.decoder, self.z_dim, 0, self.n_filter,
                    self.nc, self.n_rep, self.sbd, self.kernel_size, self.padding,
                    self.AE)
    if self.sbd:
        self.decoder = SB_decoder(self.z_dim, 0, self.n_filter, self.nc)
        self.sbd_model = spatial_broadcast_decoder()

    # Count trainable parameters. A name containing 'encoder' is counted as
    # encoder even if it also contains 'decoder'.
    encoder_size = 0
    decoder_size = 0
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if 'encoder' in name:
            encoder_size += param.numel()
        elif 'decoder' in name:
            decoder_size += param.numel()
    if self.testing_method == 'supervised_encoder':
        print(encoder_size, "parameters in the ENCODER!")
        self.params = encoder_size
    elif self.testing_method == 'supervised_decoder':
        print(decoder_size, "parameters in the DECODER!")
        self.params = decoder_size

    print("CUDA availability: " + str(torch.cuda.is_available()))
    # Honour args.cuda: the device previously ignored self.use_cuda.
    self.device = torch.device("cuda:0" if self.use_cuda else "cpu")
    if self.use_cuda and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    self.net = net.to(self.device)
    print("net on cuda: " + str(next(self.net.parameters()).is_cuda))

    if self.testing_method == 'supervised_encoder':
        self.train_dl, self.gnrl_dl = return_data_sup_encoder(args)
    elif self.testing_method == 'supervised_decoder':
        self.train_dl, self.gnrl_dl, self.gnrl_data = return_data_sup_decoder(args)
    else:
        raise NotImplementedError

    self.lr = args.lr
    self.l2_loss = args.l2_loss
    self.beta1 = args.beta1
    self.beta2 = args.beta2
    if args.optim_type == 'Adam':
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))
    elif args.optim_type == 'SGD':
        self.optim = optim.SGD(self.net.parameters(), lr=self.lr, momentum=0.9)
    else:
        # Previously self.optim was silently left unset.
        raise NotImplementedError(
            "unsupported optim_type: {!r}".format(args.optim_type))

    self.max_epoch = args.max_epoch
    self.global_iter = 0
    self.gather_step = args.gather_step
    self.display_step = args.display_step
    self.save_step = args.save_step
    self.image_size = args.image_size

    self.viz_name = args.viz_name
    self.viz_port = args.viz_port
    self.viz_on = args.viz_on
    self.win_recon = None
    self.win_kld = None
    self.win_mu = None
    self.win_var = None

    self.save_output = args.save_output
    self.output_dir = args.output_dir
    os.makedirs(self.output_dir, exist_ok=True)
    self.ckpt_dir = os.path.join(args.output_dir, args.viz_name)
    os.makedirs(self.ckpt_dir, exist_ok=True)

    self.dset_dir = args.dset_dir
    self.dataset = args.dataset
    self.batch_size = args.batch_size
    self.gather = DataGather(self.testing_method, self.encoder_target_type,
                             self.n_digits)

    # Load last so restored state is not overwritten by the defaults above.
    self.ckpt_name = args.ckpt_name
    if self.ckpt_name is not None:
        self.load_checkpoint(self.ckpt_name)
def _sorted_traversal_values(draw, num_frames):
    """Return a 1-D tensor of exactly ``num_frames`` ascending random values.

    ``draw(size)`` must return a 1-D numpy array of ``size`` samples. At least
    1000 samples are drawn and sorted, then every ``step``-th one is kept,
    starting from the smallest. For ``num_frames=100`` this matches the
    original ``[0::10]`` slice of 1000 samples.
    """
    n_draw = max(1000, num_frames)
    step = n_draw // num_frames  # >= 1, and ceil(n_draw / step) >= num_frames
    values = np.sort(draw(n_draw))
    return torch.from_numpy(values[::step][:num_frames])


def traverse_z(NN, example_id, ID, output_dir, global_iter, sbd, num_frames=100):
    """Render latent-traversal GIFs for one example.

    ``example_id['x']`` gets a batch dimension and is encoded with
    ``NN._encode``. For each latent dim ``d`` the function decodes ``num_frames``
    copies of that encoding in which only dim ``d`` is swept through sorted
    random values: uniform[0, 1] for the first ``NN.z_dim_bern`` dims and
    standard normal for the rest.

    Outputs, all under ``{output_dir}/traversals_gifs{global_iter}_{ID}/``:
      * ``traversing_z_{d}.gif``: sigmoid of the first channel of each frame.
        The frames are first written as PNGs under
        ``{output_dir}/traversals{global_iter}_{ID}/``; that folder is removed
        at the end.
      * ``encoded_z.txt``: the encoding.
      * ``recon.png``: sigmoid of the first channel of the decoded,
        unmodified encoding.
      * ``target.png``: ``example_id['y'][0]``.

    Args:
        NN: model exposing ``_encode``, ``_decode``, ``z_dim_tot``,
            ``z_dim_bern`` and ``z_dim_gauss``.
        example_id: dict with tensors under ``'x'`` and ``'y'``.
        ID: label inserted into the output folder names.
        output_dir: root directory for all outputs.
        global_iter: iteration number inserted into the folder names.
        sbd: if truthy, decoder inputs are passed through
            ``spatial_broadcast_decoder`` before ``NN._decode``. This matches
            ``plot_decoder_img``.
        num_frames: frames per GIF.

    Raises:
        ValueError: if ``num_frames`` < 1.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be >= 1, got {}".format(num_frames))
    z_dim = NN.z_dim_tot
    z_dim_bern = NN.z_dim_bern
    z_dim_gauss = NN.z_dim_gauss
    y_test_sample = example_id['y']
    x_test_sample = torch.unsqueeze(example_id['x'], 0)

    def draw_normal(size):
        return np.random.normal(loc=0, scale=1, size=size)

    def draw_uniform(size):
        return np.random.uniform(low=0, high=1, size=size)

    with torch.no_grad():
        z_distributions = NN._encode(x_test_sample)
        if z_dim_bern == 0:
            z_sample = z_distributions[:, :z_dim_gauss]
        elif z_dim_gauss == 0:
            # NOTE(review): unlike the mixed case no sigmoid is applied here,
            # although the sweep values are uniform in [0, 1] - confirm intent.
            z_sample = z_distributions[:, :z_dim_bern]
        else:
            p = torch.sigmoid(z_distributions[:, :z_dim_bern])
            mu = z_distributions[:, z_dim_bern:(z_dim_bern + z_dim_gauss)]
            z_sample = torch.cat((p, mu), 1)

        if z_dim_bern == 0:
            bern_values = gauss_values = _sorted_traversal_values(draw_normal, num_frames)
        elif z_dim_gauss == 0:
            bern_values = gauss_values = _sorted_traversal_values(draw_uniform, num_frames)
        else:
            bern_values = _sorted_traversal_values(draw_uniform, num_frames)
            gauss_values = _sorted_traversal_values(draw_normal, num_frames)

        # Row block d (num_frames rows) is the encoding with only dim d varied.
        traverse_input = torch.mul(torch.ones(num_frames * z_dim, 1), z_sample)
        for dim in range(z_dim):
            rows = slice(dim * num_frames, (dim + 1) * num_frames)
            # Bernoulli dims are indices [0, z_dim_bern); the old `<=` test
            # wrongly gave the first Gaussian dim uniform values.
            traverse_input[rows, dim] = bern_values if dim < z_dim_bern else gauss_values

        if sbd:
            # Decode with the trained network, as plot_decoder_img does; a
            # freshly built SB_decoder does not carry the trained weights.
            sbd_model = spatial_broadcast_decoder()
            x_recon = NN._decode(sbd_model(z_sample))
            reconst = NN._decode(sbd_model(traverse_input))
        else:
            x_recon = NN._decode(z_sample)
            reconst = NN._decode(traverse_input)
        print(reconst.shape)

    frames_root = '{}/traversals{}_{}'.format(output_dir, global_iter, ID)
    gif_dir = '{}/traversals_gifs{}_{}'.format(output_dir, global_iter, ID)
    os.makedirs(gif_dir, exist_ok=True)

    for dim in range(z_dim):
        frame_dir = '{}/z{}'.format(frames_root, dim)
        os.makedirs(frame_dir, exist_ok=True)
        images = []
        for e in range(num_frames):
            filename = '{}/img{}.png'.format(frame_dir, e)
            frame = reconst[dim * num_frames + e, 0, :, :].cpu()
            torchvision.utils.save_image(torch.sigmoid(frame), filename)
            images.append(imageio.imread(filename))
        imageio.mimsave('{}/traversing_z_{}.gif'.format(gif_dir, dim), images)

    # Write the encoding itself (previously the broadcast tensor in the sbd case).
    with open('{}/encoded_z.txt'.format(gif_dir), 'w') as f:
        f.write(str(z_sample.cpu().numpy()))
    torchvision.utils.save_image(torch.sigmoid(x_recon[0, 0, :, :]),
                                 '{}/recon.png'.format(gif_dir))
    torchvision.utils.save_image(y_test_sample[0, :, :],
                                 '{}/target.png'.format(gif_dir))

    # Remove all per-dim PNG folders; previously only the last dim's was removed.
    if os.path.isdir(frames_root):
        shutil.rmtree(frames_root)