Code example #1 (score: 0)
File: image_vae_trainer.py — Project: zbxzc35/ar-vae
 def plot_latent_interpolations2d(self,
                                  attr_str1,
                                  attr_str2,
                                  num_points=10):
     """Save a 2-D grid of decodings sweeping two attribute-linked latent dims.

     Encodes one fixed test sample (index 9), tiles its latent code over a
     ``num_points x num_points`` grid spanning [-4, 4] on the two latent
     dimensions most associated with ``attr_str1`` and ``attr_str2`` (per
     the interpretability metric), decodes the grid, and writes the grid,
     the original image, and its reconstruction as PNGs to the model's
     save directory.

     :param attr_str1: attribute name swept along the first grid axis
     :param attr_str2: attribute name swept along the second grid axis
     :param num_points: number of interpolation steps per axis
     """
     x1 = torch.linspace(-4., 4.0, num_points)
     x2 = torch.linspace(-4., 4.0, num_points)
     z1, z2 = torch.meshgrid([x1, x2])
     total_num_points = z1.size(0) * z1.size(1)
     _, _, data_loader = self.dataset.data_loaders(batch_size=1)
     interp_dict = self.compute_eval_metrics()["interpretability"]
     dim1 = interp_dict[attr_str1][0]
     dim2 = interp_dict[attr_str2][0]
     for sample_id, batch in tqdm(enumerate(data_loader)):
         # only sample #9 is plotted; skip everything before it
         if sample_id != 9:
             continue
         inputs, _ = self.process_batch_data(batch)  # labels unused here
         inputs = to_cuda_variable(inputs)
         recons, _, _, z, _ = self.model(inputs)
         recons = torch.sigmoid(recons)
         # tile the sample's latent code, then overwrite the two swept dims
         # (view(-1) yields the exact target shape; the original view(1, -1)
         # relied on implicit broadcasting)
         z = z.repeat(total_num_points, 1)
         z[:, dim1] = z1.contiguous().view(-1)
         z[:, dim2] = z2.contiguous().view(-1)
         outputs = torch.sigmoid(self.model.decode(z))
         save_filepath = os.path.join(
             Trainer.get_save_dir(self.model),
             f'latent_interpolations_2d_({attr_str1},{attr_str2})_{sample_id}.png'
         )
         save_image(outputs.cpu(),
                    save_filepath,
                    nrow=num_points,
                    pad_value=1.0)
         # save original image
         org_save_filepath = os.path.join(
             Trainer.get_save_dir(self.model),
             f'original_{sample_id}.png')
         save_image(inputs.cpu(),
                    org_save_filepath,
                    nrow=1,
                    pad_value=1.0)
         # save reconstruction
         recons_save_filepath = os.path.join(
             Trainer.get_save_dir(self.model),
             f'recons_{sample_id}.png')
         save_image(recons.cpu(),
                    recons_save_filepath,
                    nrow=1,
                    pad_value=1.0)
         # done with the one plotted sample — stop instead of loading one
         # more batch just to reach the old `sample_id == 10` break
         break
Code example #2 (score: 0)
File: image_vae_trainer.py — Project: zbxzc35/ar-vae
    def plot_latent_surface(self, attr_str, dim1=0, dim2=1, grid_res=0.1):
        """Plot an attribute's values over a 2-D slice of the latent space.

        Builds a regular grid over latent dimensions ``dim1`` and ``dim2``
        in [-5, 5) (all other dimensions fixed at one random draw), decodes
        every grid point in mini-batches, computes the morphometric label
        ``attr_str`` for each decoded image, and scatter-plots the labels
        against the two swept latent dimensions.

        :param attr_str: name of the attribute to evaluate
        :param dim1: first latent dimension of the slice
        :param dim2: second latent dimension of the slice
        :param grid_res: grid spacing in latent units
        """
        # create the dataspace
        x1 = torch.arange(-5., 5., grid_res)
        x2 = torch.arange(-5., 5., grid_res)
        z1, z2 = torch.meshgrid([x1, x2])
        num_points = z1.size(0) * z1.size(1)
        # one random base code, tiled; the two plotted dims are overwritten
        z = torch.randn(1, self.model.z_dim)
        z = z.repeat(num_points, 1)
        z[:, dim1] = z1.contiguous().view(1, -1)
        z[:, dim2] = z2.contiguous().view(1, -1)
        z = to_cuda_variable(z)

        mini_batch_size = 500
        # ceil division so a trailing partial mini-batch is processed too;
        # the previous floor division silently discarded up to 499 grid
        # points whenever num_points was not a multiple of mini_batch_size
        num_mini_batches = -(-num_points // mini_batch_size)
        attr_labels_all = []
        for i in tqdm(range(num_mini_batches)):
            z_batch = z[i * mini_batch_size:(i + 1) * mini_batch_size, :]
            outputs = torch.sigmoid(self.model.decode(z_batch))
            labels = self.compute_mnist_morpho_labels(outputs, attr_str)
            attr_labels_all.append(torch.from_numpy(labels))
        attr_labels_all = to_numpy(torch.cat(attr_labels_all, 0))
        # keep exactly as many codes as labels were computed for
        z = to_numpy(z)[:attr_labels_all.shape[0], :]
        save_filename = os.path.join(Trainer.get_save_dir(self.model),
                                     f'latent_surface_{attr_str}.png')
        plot_dim(z, attr_labels_all, save_filename, dim1=dim1, dim2=dim2)
Code example #3 (score: 0)
File: image_vae_trainer.py — Project: zbxzc35/ar-vae
 def plot_latent_interpolations(self, attr_str='slant', num_points=10):
     """Save 1-D latent interpolations along the dimension tied to *attr_str*.

     For each sample in a fixed set of test indices, the latent dimension
     most associated with the attribute (per the interpretability metric)
     is swept over [-4, 4] while the remaining dimensions keep the sample's
     encoding; the decoded sweep, the original image, and its
     reconstruction are each written as PNGs to the model's save directory.

     :param attr_str: attribute whose associated latent dim is swept
     :param num_points: number of interpolation steps
     """
     sweep = torch.linspace(-4, 4.0, num_points)
     _, _, data_loader = self.dataset.data_loaders(batch_size=1)
     interp_dict = self.compute_eval_metrics()["interpretability"]
     dim = interp_dict[attr_str][0]
     # for MNIST [5, 1, 30, 19, 23, 21, 17, 61, 9, 28]
     selected_ids = [5, 1, 30, 19, 23, 21, 17, 61, 9, 28]

     def _save(tensor, filename, n_per_row):
         # write one tensor as a PNG grid into the model's save directory
         path = os.path.join(Trainer.get_save_dir(self.model), filename)
         save_image(tensor.cpu(), path, nrow=n_per_row, pad_value=1.0)

     for sample_id, batch in tqdm(enumerate(data_loader)):
         if sample_id in selected_ids:
             inputs, labels = self.process_batch_data(batch)
             inputs = to_cuda_variable(inputs)
             recons, _, _, z, _ = self.model(inputs)
             recons = torch.sigmoid(recons)
             # tile the code and sweep the attribute dimension
             z = z.repeat(num_points, 1)
             z[:, dim] = sweep.contiguous()
             outputs = torch.sigmoid(self.model.decode(z))
             # save interpolation, original image, and reconstruction
             _save(outputs,
                   f'latent_interpolations_{attr_str}_{sample_id}.png',
                   num_points)
             _save(inputs, f'original_{sample_id}.png', 1)
             _save(recons, f'recons_{sample_id}.png', 1)
         if sample_id == 62:
             break
Code example #4 (score: 0)
File: image_vae_trainer.py — Project: zbxzc35/ar-vae
 def plot_latent_reconstructions(self, num_points=10):
     """Save one batch of test images alongside their reconstructions.

     Pulls a single batch of ``num_points`` samples from the test loader,
     runs it through the model, and writes both the inputs and the
     sigmoid-activated reconstructions as PNG grids.

     :param num_points: batch size (and grid row length) to plot
     """
     _, _, data_loader = self.dataset.data_loaders(batch_size=num_points)
     for sample_id, batch in tqdm(enumerate(data_loader)):
         inputs, labels = self.process_batch_data(batch)
         inputs = to_cuda_variable(inputs)
         recons, _, _, z, _ = self.model(inputs)
         recons = torch.sigmoid(recons)
         # save original image
         org_save_filepath = os.path.join(Trainer.get_save_dir(self.model),
                                          f'r_original_{sample_id}.png')
         save_image(inputs.cpu(),
                    org_save_filepath,
                    nrow=num_points,
                    pad_value=1.0)
         # save reconstruction
         recons_save_filepath = os.path.join(
             Trainer.get_save_dir(self.model), f'r_recons_{sample_id}.png')
         save_image(recons.cpu(),
                    recons_save_filepath,
                    nrow=num_points,
                    pad_value=1.0)
         # only the first batch is plotted
         break
Code example #5 (score: 0)
File: image_vae_trainer.py — Project: zbxzc35/ar-vae
 def plot_data_dist(self,
                    latent_codes,
                    attributes,
                    attr_str,
                    dim1=0,
                    dim2=1):
     """Scatter-plot latent codes colored by one attribute.

     Plots dimensions ``dim1`` vs ``dim2`` of ``latent_codes``, coloring
     each point by the ``attr_str`` column of ``attributes``, and saves the
     figure to the model's save directory.

     :return: the image object produced by ``plot_dim``
     """
     attr_values = attributes[:, self.attr_dict[attr_str]]
     save_filename = os.path.join(Trainer.get_save_dir(self.model),
                                  'data_dist_' + attr_str + '.png')
     return plot_dim(latent_codes,
                     attr_values,
                     save_filename,
                     dim1=dim1,
                     dim2=dim2,
                     xlim=4.0,
                     ylim=4.0)
Code example #6 (score: 0)
File: image_vae_trainer.py — Project: zbxzc35/ar-vae
    def create_latent_gifs(self, sample_id=9, num_points=10):
        """Build an animated GIF of latent traversals for one test sample.

        For every attribute in ``self.attr_dict`` (except 'digit_identity'
        and 'color'), the latent dimension tied to that attribute is swept
        over [-4, 4] while the rest of the code keeps the sample's
        encoding. Each sweep step becomes one GIF frame showing all
        attribute traversals together.

        :param sample_id: index of the test sample to encode
        :param num_points: number of frames (sweep steps) in the GIF
        """
        sweep_values = torch.linspace(-4, 4.0, num_points)
        _, _, data_loader = self.dataset.data_loaders(batch_size=1)
        interp_dict = self.compute_eval_metrics()["interpretability"]
        for sid, batch in tqdm(enumerate(data_loader)):
            if sid == sample_id:
                inputs, labels = self.process_batch_data(batch)
                inputs = to_cuda_variable(inputs)
                _, _, _, z, _ = self.model(inputs)
                z = z.repeat(num_points, 1)
                decoded = []
                for attr_str in self.attr_dict.keys():
                    # these attributes are not traversed
                    if attr_str in ('digit_identity', 'color'):
                        continue
                    dim = interp_dict[attr_str][0]
                    z_sweep = z.clone()
                    z_sweep[:, dim] = sweep_values.contiguous()
                    decoded.append(torch.sigmoid(self.model.decode(z_sweep)))
                # concatenate the per-attribute sweeps along dim 1, then
                # insert a singleton dim 2 before gridding each step
                frames_tensor = torch.cat(decoded, dim=1).unsqueeze(2)
                frames = []
                for step in range(frames_tensor.shape[0]):
                    grid = make_grid(frames_tensor[step],
                                     padding=2,
                                     pad_value=1.0).detach().cpu()
                    # scale to 0-255 uint8 HWC for PIL
                    arr = (grid.mul(255)
                               .clamp(0, 255)
                               .byte()
                               .permute(1, 2, 0)
                               .numpy())
                    frames.append(Image.fromarray(arr))
                # save gif
                gif_filepath = os.path.join(
                    Trainer.get_save_dir(self.model),
                    f'gif_interpolations_{self.dataset_type}_{sample_id}.gif')
                save_gif_from_list(frames, gif_filepath)
            if sid > sample_id:
                break