def plot_latent_interpolations2d(self, attr_str1, attr_str2, num_points=10, sample_id=9):
    """Decode a 2-D grid of latent traversals along two attribute dimensions.

    The latent code of one data sample is replicated over a
    ``num_points x num_points`` grid; the two latent dimensions
    associated with ``attr_str1`` and ``attr_str2`` (looked up in the
    interpretability metrics) are overwritten with grid coordinates in
    [-4, 4]. The decoded grid, the original image, and its
    reconstruction are saved as PNGs in the model save directory.

    :param attr_str1: attribute name for the first traversal axis
    :param attr_str2: attribute name for the second traversal axis
    :param num_points: number of interpolation steps per axis
    :param sample_id: index of the data sample whose latent code is traversed
        (previously hard-coded to 9; the default keeps old behavior)
    """
    x1 = torch.linspace(-4., 4.0, num_points)
    x2 = torch.linspace(-4., 4.0, num_points)
    z1, z2 = torch.meshgrid([x1, x2])
    total_num_points = z1.size(0) * z1.size(1)
    _, _, data_loader = self.dataset.data_loaders(batch_size=1)
    interp_dict = self.compute_eval_metrics()["interpretability"]
    dim1 = interp_dict[attr_str1][0]
    dim2 = interp_dict[attr_str2][0]
    for curr_id, batch in tqdm(enumerate(data_loader)):
        if curr_id < sample_id:
            continue
        inputs, labels = self.process_batch_data(batch)
        inputs = to_cuda_variable(inputs)
        recons, _, _, z, _ = self.model(inputs)
        recons = torch.sigmoid(recons)
        z = z.repeat(total_num_points, 1)
        # flatten the meshgrid coordinates into the two chosen latent dims
        z[:, dim1] = z1.contiguous().view(-1)
        z[:, dim2] = z2.contiguous().view(-1)
        outputs = torch.sigmoid(self.model.decode(z))
        save_filepath = os.path.join(
            Trainer.get_save_dir(self.model),
            f'latent_interpolations_2d_({attr_str1},{attr_str2})_{curr_id}.png'
        )
        save_image(outputs.cpu(), save_filepath, nrow=num_points, pad_value=1.0)
        # save original image
        org_save_filepath = os.path.join(
            Trainer.get_save_dir(self.model), f'original_{curr_id}.png')
        save_image(inputs.cpu(), org_save_filepath, nrow=1, pad_value=1.0)
        # save reconstruction
        recons_save_filepath = os.path.join(
            Trainer.get_save_dir(self.model), f'recons_{curr_id}.png')
        save_image(recons.cpu(), recons_save_filepath, nrow=1, pad_value=1.0)
        # stop right after the target sample instead of pulling one more
        # batch just to trigger a `== sample_id + 1` break
        break
def plot_latent_surface(self, attr_str, dim1=0, dim2=1, grid_res=0.1):
    """Plot the decoder's attribute response over a 2-D latent slice.

    A random latent code is replicated over a dense grid in
    [-5, 5) x [-5, 5); dimensions ``dim1`` and ``dim2`` are overwritten
    with the grid coordinates. Each grid point is decoded in
    mini-batches, the morphometric attribute ``attr_str`` is measured
    on the decoded output, and the labels are plotted over the
    (dim1, dim2) plane.

    :param attr_str: attribute to measure on the decoded outputs
    :param dim1: first latent dimension of the slice
    :param dim2: second latent dimension of the slice
    :param grid_res: grid spacing along each axis
    """
    # create the dataspace
    x1 = torch.arange(-5., 5., grid_res)
    x2 = torch.arange(-5., 5., grid_res)
    z1, z2 = torch.meshgrid([x1, x2])
    num_points = z1.size(0) * z1.size(1)
    z = torch.randn(1, self.model.z_dim)
    z = z.repeat(num_points, 1)
    z[:, dim1] = z1.contiguous().view(-1)
    z[:, dim2] = z2.contiguous().view(-1)
    z = to_cuda_variable(z)
    mini_batch_size = 500
    attr_labels_all = []
    # Stepped range covers a final partial mini-batch as well; the old
    # `num_points // mini_batch_size` loop silently dropped the
    # remainder whenever num_points was not divisible by 500 (it is
    # divisible at the default grid_res, so defaults are unaffected).
    for start in tqdm(range(0, num_points, mini_batch_size)):
        z_batch = z[start:start + mini_batch_size, :]
        outputs = torch.sigmoid(self.model.decode(z_batch))
        labels = self.compute_mnist_morpho_labels(outputs, attr_str)
        attr_labels_all.append(torch.from_numpy(labels))
    attr_labels_all = to_numpy(torch.cat(attr_labels_all, 0))
    # all grid points now have labels, so keep the full latent array
    z = to_numpy(z)
    save_filename = os.path.join(Trainer.get_save_dir(self.model),
                                 f'latent_surface_{attr_str}.png')
    plot_dim(z, attr_labels_all, save_filename, dim1=dim1, dim2=dim2)
def plot_latent_interpolations(self, attr_str='slant', num_points=10):
    """Save 1-D latent traversals along the dimension tied to ``attr_str``.

    For a fixed set of data samples, the latent dimension most
    associated with ``attr_str`` (per the interpretability metrics) is
    replaced by ``num_points`` values in [-4, 4]; the decoded
    interpolation strip, the original image, and its reconstruction are
    written to the model save directory.
    """
    traversal = torch.linspace(-4, 4.0, num_points)
    _, _, data_loader = self.dataset.data_loaders(batch_size=1)
    interp_dict = self.compute_eval_metrics()["interpretability"]
    dim = interp_dict[attr_str][0]
    # for MNIST [5, 1, 30, 19, 23, 21, 17, 61, 9, 28]
    selected_ids = {5, 1, 30, 19, 23, 21, 17, 61, 9, 28}
    save_dir = Trainer.get_save_dir(self.model)
    for sample_id, batch in tqdm(enumerate(data_loader)):
        if sample_id == 62:
            break
        if sample_id not in selected_ids:
            continue
        inputs, labels = self.process_batch_data(batch)
        inputs = to_cuda_variable(inputs)
        recons, _, _, z, _ = self.model(inputs)
        recons = torch.sigmoid(recons)
        z = z.repeat(num_points, 1)
        z[:, dim] = traversal.contiguous()
        outputs = torch.sigmoid(self.model.decode(z))
        # save interpolation
        save_image(
            outputs.cpu(),
            os.path.join(
                save_dir,
                f'latent_interpolations_{attr_str}_{sample_id}.png'),
            nrow=num_points, pad_value=1.0)
        # save original image
        save_image(
            inputs.cpu(),
            os.path.join(save_dir, f'original_{sample_id}.png'),
            nrow=1, pad_value=1.0)
        # save reconstruction
        save_image(
            recons.cpu(),
            os.path.join(save_dir, f'recons_{sample_id}.png'),
            nrow=1, pad_value=1.0)
def plot_latent_reconstructions(self, num_points=10):
    """Save side-by-side originals and reconstructions for one batch.

    Pulls a single batch of ``num_points`` samples, runs it through the
    model, and writes two image grids to the model save directory: the
    inputs and their sigmoid-activated reconstructions.
    """
    _, _, data_loader = self.dataset.data_loaders(batch_size=num_points)
    for batch_id, batch in tqdm(enumerate(data_loader)):
        batch_inputs, _ = self.process_batch_data(batch)
        batch_inputs = to_cuda_variable(batch_inputs)
        batch_recons, _, _, _, _ = self.model(batch_inputs)
        batch_recons = torch.sigmoid(batch_recons)
        # save original image grid
        save_image(
            batch_inputs.cpu(),
            os.path.join(Trainer.get_save_dir(self.model),
                         f'r_original_{batch_id}.png'),
            nrow=num_points, pad_value=1.0)
        # save reconstruction grid
        save_image(
            batch_recons.cpu(),
            os.path.join(Trainer.get_save_dir(self.model),
                         f'r_recons_{batch_id}.png'),
            nrow=num_points, pad_value=1.0)
        # only the first batch is needed
        break
def plot_data_dist(self, latent_codes, attributes, attr_str, dim1=0, dim2=1):
    """Scatter-plot latent codes colored by one attribute's values.

    Selects the column of ``attributes`` mapped to ``attr_str`` via
    ``self.attr_dict`` and plots it over latent dimensions
    (``dim1``, ``dim2``), saving the figure to the model save directory.

    :return: the image returned by ``plot_dim``
    """
    save_filename = os.path.join(
        Trainer.get_save_dir(self.model),
        'data_dist_' + attr_str + '.png')
    attr_values = attributes[:, self.attr_dict[attr_str]]
    return plot_dim(
        latent_codes, attr_values, save_filename,
        dim1=dim1, dim2=dim2, xlim=4.0, ylim=4.0)
def create_latent_gifs(self, sample_id=9, num_points=10):
    """Create an animated GIF of latent traversals for one data sample.

    For every attribute in ``self.attr_dict`` (except ``digit_identity``
    and ``color``), the latent dimension associated with that attribute
    is swept over ``num_points`` values in [-4, 4] and decoded. The
    per-attribute frames are stacked side by side; each sweep step
    becomes one GIF frame.

    :param sample_id: index of the data sample to animate
    :param num_points: number of frames in the traversal
    """
    x1 = torch.linspace(-4, 4.0, num_points)
    _, _, data_loader = self.dataset.data_loaders(batch_size=1)
    interp_dict = self.compute_eval_metrics()["interpretability"]
    for sid, batch in tqdm(enumerate(data_loader)):
        if sid < sample_id:
            continue
        inputs, labels = self.process_batch_data(batch)
        inputs = to_cuda_variable(inputs)
        _, _, _, z, _ = self.model(inputs)
        z = z.repeat(num_points, 1)
        outputs = []
        for attr_str in self.attr_dict.keys():
            # these attributes are not traversed
            if attr_str == 'digit_identity' or attr_str == 'color':
                continue
            dim = interp_dict[attr_str][0]
            z_copy = z.clone()
            z_copy[:, dim] = x1.contiguous()
            outputs.append(torch.sigmoid(self.model.decode(z_copy)))
        # stack attribute sweeps side by side; axis 0 indexes GIF frames
        outputs = torch.unsqueeze(torch.cat(outputs, dim=1), dim=2)
        interps = []
        for n in range(outputs.shape[0]):
            image_grid = make_grid(
                outputs[n], padding=2, pad_value=1.0).detach().cpu()
            np_image = image_grid.mul(255).clamp(0, 255).byte().permute(
                1, 2, 0).numpy()
            interps.append(Image.fromarray(np_image))
        # save gif
        gif_filepath = os.path.join(
            Trainer.get_save_dir(self.model),
            f'gif_interpolations_{self.dataset_type}_{sample_id}.gif')
        save_gif_from_list(interps, gif_filepath)
        # break immediately: the old `sid > sample_id` check needlessly
        # fetched one extra batch from the loader before stopping
        break