def plot_decoder_img(NN, test_data, pdf_path, global_iter, sbd, type, n=20):
    """Save a multi-page PDF comparing targets with decoder reconstructions.

    For each of the first ``n`` items of ``test_data`` the decoder input
    ``item['x']`` is decoded with ``NN._decode`` and the target
    ``item['y']`` is drawn next to the sigmoid of the reconstruction, one
    page per item. The decoder input values are shown as the page title.

    Args:
        NN: model exposing ``_decode``.
        test_data: indexable dataset whose items are dicts with tensor
            entries ``'x'`` (decoder input) and ``'y'`` (target image).
        pdf_path: output directory; the file name is derived from ``type``
            and ``global_iter``.
        global_iter: iteration number embedded in the file name.
        sbd: if truthy, ``x`` gets a batch dimension and is passed through
            ``spatial_broadcast_decoder`` before decoding.
        type: ``'test'`` or ``'gnrl'``; selects the file-name prefix.
            (Shadows the builtin, but the name is kept for keyword callers.)
        n: number of items to plot.

    Raises:
        ValueError: if ``type`` is neither ``'test'`` nor ``'gnrl'``.
    """
    prefixes = {'test': 'testing_recon', 'gnrl': 'gnrl_recon'}
    if type not in prefixes:
        # Without this check the directory path itself was handed to
        # PdfPages, which then failed trying to open it as a file.
        raise ValueError(
            "type must be 'test' or 'gnrl', got {!r}".format(type))
    pdf_file = "{}/{}{}.pdf".format(pdf_path, prefixes[type], global_iter)

    # Built once rather than once per plotted item.
    sbd_model = spatial_broadcast_decoder() if sbd else None

    # The context manager closes the PDF even if plotting raises.
    with matplotlib.backends.backend_pdf.PdfPages(pdf_file) as pdf:
        for i in range(n):
            sample = test_data[i]
            x = sample['x']
            y = sample['y']
            if sbd_model is not None:
                x = sbd_model(torch.unsqueeze(x, 0))
            # NOTE(review): without sbd, x is decoded without adding a batch
            # dimension - confirm this is the shape _decode expects.
            x_recon = NN._decode(x)

            # detach() so .numpy() also works outside torch.no_grad().
            x = x.detach().numpy()
            y = y.numpy()
            x_recon = torch.sigmoid(x_recon).detach().numpy()

            f, (a1, a2) = plt.subplots(
                1, 2, gridspec_kw={'width_ratios': [1, 1]})
            # Index 0 on the leading axes; presumably y is (C, H, W) and
            # x_recon is (N, C, H, W) - TODO confirm.
            a1.imshow(y[0, :, :])
            a2.imshow(x_recon[0, 0, :, :])
            f.tight_layout()
            f.suptitle(np.around(x, 2))
            pdf.savefig(f, dpi=300)
            # Close this specific figure, not whatever is "current".
            plt.close(f)
Esempio n. 2
0
    def test_plots(self):
        """Write diagnostic plots using a CPU copy of ``self.net``.

        Steps:
            1. Unless ``self.AE``, decode 16 random latent vectors. The result
               is currently not saved.
            2. Build the latent histogram with ``construct_z_hist``.
            3. Unless ``self.AE``, write latent traversals for three randomly
               offset items of ``self.gnrl_data`` via ``traverse_z``.
            4. Write reconstructions of the generalisation set via
               ``plotsave_tests``.

        ``self.net`` itself is not modified: all work uses a deep copy.
        """
        net_copy = deepcopy(self.net)
        net_copy.to('cpu')
        # Plotting is inference, so the copy goes into eval mode (dropout
        # off, batch-norm running statistics). self.net keeps its own mode
        # because only the copy is switched.
        net_copy.eval()

        if not self.AE:
            print("creating sample images!")
            # Latent samples: standard normal for Gaussian dims, uniform
            # [0, 1) for Bernoulli dims, with the Bernoulli dims first when
            # both kinds are present.
            if self.z_dim_bern == 0:
                sample = torch.randn(16, self.z_dim_gauss)
            elif self.z_dim_gauss == 0:
                sample = torch.rand(16, self.z_dim_bern)
            else:
                sample_2 = torch.randn(16, self.z_dim_gauss)
                sample_1 = torch.rand(16, self.z_dim_bern)
                sample = torch.cat((sample_1, sample_2), 1)
            with torch.no_grad():
                if self.sbd:
                    # NOTE(review): this SB_decoder is freshly constructed
                    # rather than taken from the trained net - confirm intent.
                    sbd_decoder = SB_decoder(self.z_dim_bern, self.z_dim_gauss,
                                             self.n_filter, self.nc)
                    sbd_model = spatial_broadcast_decoder()
                    sample = sbd_model(sample)
                    test_recon = sbd_decoder(sample)
                else:
                    test_recon = net_copy._decode(sample)
            # NOTE(review): test_recon is not saved anywhere; the image
            # export for it was disabled in an earlier revision.

        print("Constructing Z hist!")
        construct_z_hist(net_copy, self.gnrl_dl, self.global_iter,
                         self.output_dir, self.AE, dim='blah')

        if not self.AE:
            print("Traversing!")
            with torch.no_grad():
                for i in range(3):
                    # Random offset picks different examples on each call;
                    # assumes gnrl_data has at least 23 items - TODO confirm.
                    example_id = self.gnrl_data[i + random.randint(0, 20)]
                    traverse_z(net_copy, example_id, ID=str(i),
                               output_dir=self.output_dir,
                               global_iter=self.global_iter, sbd=self.sbd,
                               num_frames=100)

        print('Reconstructing generalisation images!')
        with torch.no_grad():
            plotsave_tests(net_copy, self.gnrl_data, self.output_dir,
                           self.global_iter, type="Gnrl", n=30)
Esempio n. 3
0
    def __init__(self, args):
        """Build the model, data loaders, optimiser and output folders from ``args``.

        The column count of ``{dset_dir}digts.csv`` decides whether images
        hold 2 or 3 digits. The latent size is then derived from
        ``args.encoder_target_type``, and a ``multi_VAE`` is constructed. The
        model is wrapped in ``nn.DataParallel`` when CUDA is enabled and
        several GPUs are visible. Data loaders are chosen by
        ``args.testing_method``, and ``args.ckpt_name`` is restored when given.

        Raises:
            NotImplementedError: if ``args.encoder_target_type``,
                ``args.dataset``, ``args.testing_method`` or
                ``args.optim_type`` is not supported.
        """
        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.testing_method = args.testing_method
        self.encoder = args.encoder
        self.decoder = args.decoder
        self.n_filter = args.n_filter
        self.n_rep = args.n_rep
        self.kernel_size = args.kernel_size
        self.padding = args.padding
        self.sbd = args.sbd
        self.AE = args.AE

        self.encoder_target_type = args.encoder_target_type

        # Only the column count is needed, so read just the first row.
        # More than 37 columns means 3-digit data; otherwise 2-digit.
        data_csv = "{}digts.csv".format(args.dset_dir)
        n_cols = pd.read_csv(data_csv, header=None, nrows=1).shape[1]
        self.n_digits = 3 if n_cols > 37 else 2

        # Latent size per target encoding. The one-hot encodings scale with
        # the number of digits: 10 or 12 units per digit.
        fixed_z_dims = {
            'joint': 10,
            'black_white': 20,
            'depth_black_white': 21,
            'depth_black_white_xy_xy': 25,
        }
        per_digit_z_dims = {
            'depth_ordered_one_hot': 10,
            'depth_ordered_one_hot_xy': 12,
        }
        if args.encoder_target_type in fixed_z_dims:
            self.z_dim = fixed_z_dims[args.encoder_target_type]
        elif args.encoder_target_type in per_digit_z_dims:
            self.z_dim = per_digit_z_dims[args.encoder_target_type] * self.n_digits
        else:
            # Previously self.z_dim was silently left undefined here.
            raise NotImplementedError(
                "unsupported encoder_target_type: {!r}".format(args.encoder_target_type))

        if args.dataset.lower() == 'digits_gray':
            self.nc = 1
        elif args.dataset.lower() == 'digits_col':
            self.nc = 3
        else:
            raise NotImplementedError

        net = multi_VAE(self.encoder, self.decoder, self.z_dim, 0, self.n_filter, self.nc,
                        self.n_rep, self.sbd, self.kernel_size, self.padding, self.AE)

        if self.sbd:
            # NOTE(review): this replaces the decoder name stored above with a
            # separately constructed SB_decoder that is not part of `net` -
            # confirm intent.
            self.decoder = SB_decoder(self.z_dim, 0, self.n_filter, self.nc)
            self.sbd_model = spatial_broadcast_decoder()

        # Count trainable parameters by sub-network name.
        encoder_size = 0
        decoder_size = 0
        for name, param in net.named_parameters():
            if param.requires_grad:
                if 'encoder' in name:
                    encoder_size += param.numel()
                elif 'decoder' in name:
                    decoder_size += param.numel()
        if self.testing_method == 'supervised_encoder':
            print(encoder_size, "parameters in the ENCODER!")
            self.params = encoder_size
        elif self.testing_method == 'supervised_decoder':
            print(decoder_size, "parameters in the DECODER!")
            self.params = decoder_size

        print("CUDA availability: " + str(torch.cuda.is_available()))
        # Respect args.cuda. Previously the GPU (and DataParallel) was used
        # whenever one was available, even with CUDA disabled in args.
        self.device = torch.device("cuda:0" if self.use_cuda else "cpu")
        if self.use_cuda and torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            net = nn.DataParallel(net)

        self.net = net.to(self.device)
        print("net on cuda: " + str(next(self.net.parameters()).is_cuda))

        if self.testing_method == 'supervised_encoder':
            self.train_dl, self.gnrl_dl = return_data_sup_encoder(args)
        elif self.testing_method == 'supervised_decoder':
            self.train_dl, self.gnrl_dl, self.gnrl_data = return_data_sup_decoder(args)
        else:
            raise NotImplementedError

        self.lr = args.lr
        self.l2_loss = args.l2_loss
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        if args.optim_type == 'Adam':
            self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                    betas=(self.beta1, self.beta2))
        elif args.optim_type == 'SGD':
            self.optim = optim.SGD(self.net.parameters(), lr=self.lr, momentum=0.9)
        else:
            # Previously self.optim was silently left undefined here.
            raise NotImplementedError(
                "unsupported optim_type: {!r}".format(args.optim_type))

        self.max_epoch = args.max_epoch
        self.global_iter = 0
        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.image_size = args.image_size

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None

        self.save_output = args.save_output
        self.output_dir = args.output_dir
        # exist_ok=True already tolerates an existing directory.
        os.makedirs(self.output_dir, exist_ok=True)

        self.ckpt_dir = os.path.join(args.output_dir, args.viz_name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size

        self.gather = DataGather(self.testing_method, self.encoder_target_type, self.n_digits)
def _sorted_prior_samples(dist, num_frames, pool_size=1000):
    """Return exactly ``num_frames`` sorted prior draws as a tensor.

    Draws ``max(pool_size, num_frames)`` values from a standard normal
    (``dist == 'normal'``) or from uniform [0, 1) (any other value). The
    draws are sorted, every ``pool // num_frames``-th value is kept, and the
    result is truncated to ``num_frames``. The truncation matters when
    ``num_frames`` does not divide the pool size: a plain strided slice then
    returns too many values (e.g. 31 for 30 frames).
    """
    pool = max(pool_size, num_frames)
    if dist == 'normal':
        draws = np.random.normal(loc=0, scale=1, size=pool)
    else:
        draws = np.random.uniform(low=0, high=1, size=pool)
    draws.sort()
    step = pool // num_frames
    return torch.from_numpy(draws[0::step][:num_frames])


def traverse_z(NN, example_id, ID, output_dir, global_iter, sbd, num_frames=100):
    """Write one traversal GIF per latent dimension for a single example.

    The example's ``'x'`` is encoded with ``NN._encode``. Then, for each
    latent dimension, ``num_frames`` copies of the code are decoded with that
    one dimension swept over sorted prior samples:
    - uniform [0, 1) for Bernoulli dims (after a sigmoid on the encoder
      output when both kinds of dims exist);
    - standard normal for Gaussian dims.

    Output layout under ``output_dir``:
    - ``traversals_gifs{global_iter}_{ID}/traversing_z_{k}.gif``: one GIF per
      dimension k.
    - ``encoded_z.txt``: the encoded latent code.
    - ``recon.png``: the reconstruction.
    - ``target.png``: the target image.

    Per-frame PNGs are written to ``traversals{global_iter}_{ID}/z{k}/`` and
    that folder is removed once its GIF is built.

    Args:
        NN: model exposing ``_encode``, ``_decode``, ``z_dim_tot``,
            ``z_dim_bern``, ``z_dim_gauss``, ``n_filter`` and ``nc``.
        example_id: dataset item, a dict with tensor entries ``'x'``
            (encoder input) and ``'y'`` (target image).
        ID: label used in the output folder names.
        output_dir: root output directory.
        global_iter: iteration number used in the output folder names.
        sbd: if truthy, decode through a spatial-broadcast decoder.
        num_frames: number of frames per traversal.
    """
    z_dim = NN.z_dim_tot
    n_bern, n_gauss = NN.z_dim_bern, NN.z_dim_gauss

    x_test_sample = torch.unsqueeze(example_id['x'], 0)  # add batch dim
    y_test_sample = example_id['y']

    # Encode the sample image. The leading columns of the encoder output are
    # used as the latent code, Bernoulli part first.
    z_distributions = NN._encode(x_test_sample)
    if n_bern == 0:
        z_sample = z_distributions[:, :n_gauss]
    elif n_gauss == 0:
        z_sample = z_distributions[:, :n_bern]
    else:
        p = torch.sigmoid(z_distributions[:, :n_bern])
        mu = z_distributions[:, n_bern:(n_bern + n_gauss)]
        z_sample = torch.cat((p, mu), 1)

    # Sorted prior values to sweep each latent over. In the mixed case the
    # uniform draw happens before the normal draw (same RNG order as before).
    if n_bern == 0:
        dist_samples = _sorted_prior_samples('normal', num_frames)
    elif n_gauss == 0:
        dist_samples = _sorted_prior_samples('uniform', num_frames)
    else:
        dist_samples_1 = _sorted_prior_samples('uniform', num_frames)
        dist_samples_2 = _sorted_prior_samples('normal', num_frames)

    # num_frames rows per latent dimension, each starting from the encoded
    # code. Block k varies only column k.
    traverse_input = z_sample.repeat(num_frames * z_dim, 1)
    for z in range(z_dim):
        rows = slice(z * num_frames, (z + 1) * num_frames)
        if n_bern != 0 and n_gauss != 0:
            # Bernoulli dims occupy columns 0..n_bern-1. The old `<=` sent the
            # first Gaussian dim through the uniform samples.
            traverse_input[rows, z] = dist_samples_1 if z < n_bern else dist_samples_2
        else:
            traverse_input[rows, z] = dist_samples

    if sbd:
        # NOTE(review): this SB_decoder is freshly constructed rather than
        # taken from NN - confirm intent.
        sbd_decoder = SB_decoder(n_bern, n_gauss, NN.n_filter, NN.nc)
        sbd_model = spatial_broadcast_decoder()
        # Keep z_sample itself intact so encoded_z.txt holds the encoded code
        # rather than its spatially broadcast version.
        x_recon = sbd_decoder(sbd_model(z_sample))
        reconst = sbd_decoder(sbd_model(traverse_input))
        print(reconst.shape)
    else:
        x_recon = NN._decode(z_sample)
        reconst = NN._decode(traverse_input)
        print(reconst.shape)

    gif_dir = '{}/traversals_gifs{}_{}'.format(output_dir, global_iter, ID)
    os.makedirs(gif_dir, exist_ok=True)

    for z in range(z_dim):
        frame_dir = '{}/traversals{}_{}/z{}'.format(output_dir, global_iter, ID, z)
        os.makedirs(frame_dir, exist_ok=True)
        images = []
        for e in range(num_frames):
            filename = '{}/img{}.png'.format(frame_dir, e)
            frame = torch.sigmoid(reconst[z * num_frames + e, 0, :, :].cpu())
            torchvision.utils.save_image(frame, filename)
            images.append(imageio.imread(filename))
        imageio.mimsave('{}/traversing_z_{}.gif'.format(gif_dir, z), images)
        # Individual frames are only needed to assemble the GIF.
        shutil.rmtree(frame_dir)

    # Per-example summary files, written once rather than once per latent.
    with open('{}/encoded_z.txt'.format(gif_dir), 'w') as f:
        f.write(str(z_sample.detach().cpu().numpy()))
    torchvision.utils.save_image(torch.sigmoid(x_recon[0, 0, :, :]),
                                 '{}/recon.png'.format(gif_dir))
    torchvision.utils.save_image(y_test_sample[0, :, :],
                                 '{}/target.png'.format(gif_dir))