import os

import numpy as np
from scipy import stats

import gin.tf
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
from tensorflow.compat.v1 import gfile

# The imports below follow disentanglement_lib's package layout; adjust the
# paths if this code is vendored differently.
from disentanglement_lib.data.ground_truth import named_data
from disentanglement_lib.methods.weak.train_weak_lib import simple_dynamics
from disentanglement_lib.utils import results
from disentanglement_lib.visualize import visualize_util
from disentanglement_lib.visualize.visualize_irs import vis_all_interventional_effects


def visualize_dataset(dataset_name, output_path, num_animations=5,
                      num_frames=20, fps=10):
  """Visualizes the data set by saving images to output_path.

  For each latent factor, outputs 16 images where only that latent factor is
  varied while all others are kept constant. Additionally creates
  num_animations animations in which each factor is cycled in turn.

  Args:
    dataset_name: String with name of dataset as defined in named_data.py.
    output_path: String with path in which to create the visualizations.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
  """
  data = named_data.get_named_ground_truth_data(dataset_name)
  random_state = np.random.RandomState(0)

  # Create output folder if necessary.
  path = os.path.join(output_path, dataset_name)
  if not gfile.IsDirectory(path):
    gfile.MakeDirs(path)

  # Create still images: for factor i, fix all other factors to the values of
  # the first sample so that only factor i varies across the 16 images.
  for i in range(data.num_factors):
    factors = data.sample_factors(16, random_state)
    indices = [j for j in range(data.num_factors) if i != j]
    factors[:, indices] = factors[0, indices]
    images = data.sample_observations_from_factors(factors, random_state)
    visualize_util.grid_save_images(
        images, os.path.join(path, "variations_of_factor%s.png" % i))

  # Create animations: starting from a random base configuration, cycle each
  # factor through all of its values.
  for i in range(num_animations):
    base_factor = data.sample_factors(1, random_state)
    images = []
    for j, num_atoms in enumerate(data.factors_num_values):
      factors = np.repeat(base_factor, num_frames, axis=0)
      factors[:, j] = visualize_util.cycle_factor(base_factor[0, j], num_atoms,
                                                  num_frames)
      images.append(
          data.sample_observations_from_factors(factors, random_state))
    visualize_util.save_animation(
        np.array(images), os.path.join(path, "animation%d.gif" % i), fps)
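# Usage sketch (hypothetical names): "dsprites_full" is assumed to be a
# dataset name registered in named_data.py and /tmp/vis an arbitrary output
# directory; any registered dataset name works.
def _demo_visualize_dataset():
  visualize_dataset("dsprites_full", "/tmp/vis", num_animations=2,
                    num_frames=20, fps=10)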
def visualize_weakly_supervised_dataset(data, path, num_animations=10,
                                        num_frames=20, fps=10):
  """Visualizes the data set by saving animations to path.

  Each animation starts from a randomly sampled configuration of latent
  factors and, in every frame, perturbs the factors via simple_dynamics.

  Args:
    data: GroundTruthData object for the data set to visualize.
    path: String with path in which to create the visualizations.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
  """
  random_state = np.random.RandomState(0)

  # Create output folder if necessary.
  if not gfile.IsDirectory(path):
    gfile.MakeDirs(path)

  # Create animations.
  images = []
  for i in range(num_animations):
    images.append([])
    factor = data.sample_factors(1, random_state)
    images[i].append(
        np.squeeze(data.sample_observations_from_factors(factor, random_state),
                   axis=0))
    for _ in range(num_frames):
      factor, _ = simple_dynamics(factor, data, random_state)
      images[i].append(
          np.squeeze(
              data.sample_observations_from_factors(factor, random_state),
              axis=0))
  visualize_util.save_animation(
      np.array(images), os.path.join(path, "animation.gif"), fps)
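# Usage sketch (hypothetical names): unlike visualize_dataset above, this
# function takes the ground-truth data object itself, so we construct it
# first; "dsprites_full" and /tmp/vis_weak are placeholder choices.
def _demo_visualize_weakly_supervised():
  data = named_data.get_named_ground_truth_data("dsprites_full")
  visualize_weakly_supervised_dataset(data, "/tmp/vis_weak", num_animations=3)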
# Activation helpers that map raw decoder outputs to images in [0, 1]
# (matching the helpers defined in disentanglement_lib's visualize_model.py).
def sigmoid(x):
  return stats.logistic.cdf(x)


def tanh(x):
  return np.tanh(x) / 2. + .5


def visualize(model_dir,
              output_dir,
              overwrite=False,
              num_animations=5,
              num_frames=20,
              fps=10,
              num_points_irs=10000):
  """Takes the trained model from model_dir and visualizes it in output_dir.

  Args:
    model_dir: Path to directory where the trained model is saved.
    output_dir: Path to output directory.
    overwrite: Boolean indicating whether to overwrite output directory.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
    num_points_irs: Number of points to be used for the IRS plots.
  """
  # Fix the random seed for reproducibility.
  random_state = np.random.RandomState(0)

  # Create the output directory if necessary.
  if tf.gfile.IsDirectory(output_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(output_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Automatically set the proper data set if necessary. We replace the active
  # gin config, as this leads to a valid gin config file in which the data set
  # is present.
  # Obtain the dataset name from the gin config of the previous step.
  gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
  gin_dict = results.gin_dict(gin_config_file)
  gin.bind_parameter("dataset.name",
                     gin_dict["dataset.name"].replace("'", ""))

  # Automatically infer the activation function from the gin config.
  activation_str = gin_dict["reconstruction_loss.activation"]
  if activation_str == "'logits'":
    activation = sigmoid
  elif activation_str == "'tanh'":
    activation = tanh
  else:
    raise ValueError(
        "Activation function could not be inferred from gin config.")

  dataset = named_data.get_named_ground_truth_data()
  num_pics = 64
  module_path = os.path.join(model_dir, "tfhub")

  with hub.eval_function_for_module(module_path) as f:
    # Save reconstructions: real images and their reconstructions side by
    # side.
    real_pics = dataset.sample_observations(num_pics, random_state)
    raw_pics = f(
        dict(images=real_pics), signature="reconstructions",
        as_dict=True)["images"]
    pics = activation(raw_pics)
    paired_pics = np.concatenate((real_pics, pics), axis=2)
    paired_pics = [
        paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
    ]
    results_dir = os.path.join(output_dir, "reconstructions")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)
    visualize_util.grid_save_images(
        paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

    # Save samples decoded from random latent codes.
    def _decoder(latent_vectors):
      return f(
          dict(latent_vectors=latent_vectors),
          signature="decoder",
          as_dict=True)["images"]

    num_latent = int(gin_dict["encoder.num_latent"])
    num_pics = 64
    random_codes = random_state.normal(0, 1, [num_pics, num_latent])
    pics = activation(_decoder(random_codes))
    results_dir = os.path.join(output_dir, "sampled")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)
    visualize_util.grid_save_images(
        pics, os.path.join(results_dir, "samples.jpg"))

    # Save latent traversals.
    result = f(
        dict(images=dataset.sample_observations(num_pics, random_state)),
        signature="gaussian_encoder",
        as_dict=True)
    means = result["mean"]
    logvars = result["logvar"]
    results_dir = os.path.join(output_dir, "traversals")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)
    for i in range(means.shape[1]):
      # latent_traversal_1d_multi_dim is defined alongside this function in
      # disentanglement_lib's visualize_model.py (not reproduced here).
      pics = activation(
          latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
      file_name = os.path.join(results_dir, "traversals{}.jpg".format(i))
      visualize_util.grid_save_images([pics], file_name)

    # Save the latent traversal animations.
    results_dir = os.path.join(output_dir, "animated_traversals")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)

    # Cycle through quantiles of a standard Gaussian.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        code[:, j] = visualize_util.cycle_gaussian(base_code[j], num_frames)
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "std_gaussian_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle through quantiles of a fitted Gaussian.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        loc = np.mean(means[:, j])
        # Law of total variance: E[sigma^2] + Var[mu] over the encoded batch.
        total_variance = np.mean(np.exp(logvars[:, j])) + np.var(means[:, j])
        code[:, j] = visualize_util.cycle_gaussian(
            base_code[j], num_frames, loc=loc, scale=np.sqrt(total_variance))
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "fitted_gaussian_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle through the [-2, 2] interval.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        code[:, j] = visualize_util.cycle_interval(base_code[j], num_frames,
                                                   -2., 2.)
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "fixed_interval_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle linearly through +-2 std dev of a fitted Gaussian.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        loc = np.mean(means[:, j])
        total_variance = np.mean(np.exp(logvars[:, j])) + np.var(means[:, j])
        scale = np.sqrt(total_variance)
        code[:, j] = visualize_util.cycle_interval(base_code[j], num_frames,
                                                   loc - 2. * scale,
                                                   loc + 2. * scale)
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "conf_interval_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle linearly through the min-max range of the encoded means.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        code[:, j] = visualize_util.cycle_interval(base_code[j], num_frames,
                                                   np.min(means[:, j]),
                                                   np.max(means[:, j]))
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "minmax_interval_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Interventional effects visualization: encode a large sample in batches
    # and plot how the latents respond to the ground-truth factors.
    factors = dataset.sample_factors(num_points_irs, random_state)
    obs = dataset.sample_observations_from_factors(factors, random_state)
    batch_size = 64
    num_outputs = 0
    latents = []
    while num_outputs < obs.shape[0]:
      input_batch = obs[num_outputs:min(num_outputs + batch_size,
                                        obs.shape[0])]
      output_batch = f(
          dict(images=input_batch), signature="gaussian_encoder",
          as_dict=True)["mean"]
      latents.append(output_batch)
      num_outputs += batch_size
    latents = np.concatenate(latents)
    results_dir = os.path.join(output_dir, "interventional_effects")
    vis_all_interventional_effects(factors, latents, results_dir)

  # Finally, clear the gin config that we have set.
  gin.clear_config()
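# Usage sketch (hypothetical paths): model_dir is assumed to contain the
# artifacts written by the training step, i.e. a "tfhub" module directory and
# results/gin/train.gin; output subfolders (reconstructions/, sampled/,
# traversals/, ...) are created under output_dir.
def _demo_visualize_model():
  visualize("/tmp/model", "/tmp/model_vis", overwrite=True)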
# Test for visualize_util.save_animation; create_tempdir is provided by
# absltest.TestCase.
from absl.testing import absltest


class VisualizeUtilTest(absltest.TestCase):

  def test_save_animation(self):
    path = os.path.join(self.create_tempdir().full_path, "animation.gif")
    # Two identical animations, each with 18 frames of 128x256 RGB images.
    images = np.ones((18, 128, 256, 3), dtype=np.float32)
    visualize_util.save_animation([images, images], path, fps=18)