def test_gin_dict_dir(self):
  """Tests namespacing functionality based on saved gin config."""
  parameter_name = "test.value"
  gin.bind_parameter(parameter_name, 1)
  _ = test_fn()
  config_path = os.path.join(self.get_temp_dir(), "config.gin")
  with tf.gfile.GFile(config_path, "w") as f:
    f.write(gin.operative_config_str())
  self.assertDictEqual(
      results.gin_dict(config_path), {parameter_name: "1"})
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
  """Loads a representation TFHub module and computes disentanglement metrics.

  Args:
    model_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    evaluation_fn: Function used to evaluate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for evaluation.
    name: Optional string with name of the metric (can be used to name
      metrics).
  """
  # We do not use the variable 'name'. Instead, it can be used to name scores
  # as it will be part of the saved gin config.
  del name

  # Delete the output directory if it already exists.
  if tf.gfile.IsDirectory(output_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(output_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Set up timer to keep track of elapsed time in results.
  experiment_timer = time.time()

  # Automatically set the proper data set if necessary. We replace the active
  # gin config as this will lead to a valid gin config file where the data set
  # is present.
  if gin.query_parameter("dataset.name") == "auto":
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin",
                                   "postprocess.gin")
    gin_dict = results.gin_dict(gin_config_file)
    with gin.unlock_config():
      gin.bind_parameter("dataset.name",
                         gin_dict["dataset.name"].replace("'", ""))
  dataset = named_data.get_named_ground_truth_data()

  # Path to TFHub module of previously trained representation.
  module_path = os.path.join(model_dir, "tfhub")
  with hub.eval_function_for_module(module_path) as f:

    def _representation_function(x):
      """Computes representation vector for input images."""
      output = f(dict(images=x), signature="representation", as_dict=True)
      return np.array(output["default"])

    # Computes scores of the representation based on the evaluation_fn.
    if _has_kwarg_or_kwargs(evaluation_fn, "artifact_dir"):
      artifact_dir = os.path.join(model_dir, "artifacts")
      results_dict = evaluation_fn(
          dataset,
          _representation_function,
          random_state=np.random.RandomState(random_seed),
          artifact_dir=artifact_dir)
    else:
      # Legacy code path to allow for old evaluation metrics.
      warnings.warn(
          "Evaluation function does not appear to accept an"
          " `artifact_dir` argument. This may not be compatible with "
          "future versions.", DeprecationWarning)
      results_dict = evaluation_fn(
          dataset,
          _representation_function,
          random_state=np.random.RandomState(random_seed))

  # Save the results (and all previous results in the pipeline) on disk.
  original_results_dir = os.path.join(model_dir, "results")
  results_dir = os.path.join(output_dir, "results")
  results_dict["elapsed_time"] = time.time() - experiment_timer
  results.update_result_directory(results_dir, "evaluation", results_dict,
                                  original_results_dir)
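
# A minimal usage sketch for `evaluate`, illustrative only: it assumes this
# function is gin-configurable under the scope "evaluation", that a metric
# such as `@mig` is registered with gin, and that "dataset" is a gin
# configurable with a "name" parameter. The paths and bindings below are
# placeholder assumptions, not part of the library.
def _example_evaluate_usage():
  """Hypothetical helper showing one way to configure and run `evaluate`."""
  gin_bindings = [
      "dataset.name = 'auto'",  # Resolved from the previous step's gin config.
      "evaluation.evaluation_fn = @mig",  # Assumed registered metric.
      "evaluation.random_seed = 0",
  ]
  gin.parse_config_files_and_bindings([], gin_bindings)
  evaluate("/tmp/model", "/tmp/model/metrics/mig", overwrite=True)
  gin.clear_config()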
def postprocess(model_dir,
                output_dir,
                overwrite=False,
                postprocess_fn=gin.REQUIRED,
                random_seed=gin.REQUIRED,
                name=""):
  """Loads a trained Gaussian encoder and extracts representation.

  Args:
    model_dir: String with path to directory where the model is saved.
    output_dir: String with the path where the representation should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    postprocess_fn: Function used to extract the representation (see
      methods.py for examples).
    random_seed: Integer with random seed used for postprocessing (may be
      unused).
    name: Optional string with name of the representation (can be used to name
      representations).
  """
  # We do not use the variable 'name'. Instead, it can be used to name
  # representations as it will be part of the saved gin config.
  del name

  # Delete the output directory if it already exists.
  if tf.gfile.IsDirectory(output_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(output_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Set up timer to keep track of elapsed time in results.
  experiment_timer = time.time()

  # Automatically set the proper data set if necessary. We replace the active
  # gin config as this will lead to a valid gin config file where the data set
  # is present.
  if gin.query_parameter("dataset.name") == "auto":
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    with gin.unlock_config():
      gin.bind_parameter("dataset.name",
                         gin_dict["dataset.name"].replace("'", ""))
  dataset = named_data.get_named_ground_truth_data()

  # Path to TFHub module of previously trained model.
  module_path = os.path.join(model_dir, "tfhub")
  with hub.eval_function_for_module(module_path) as f:

    def _gaussian_encoder(x):
      """Encodes images using trained model."""
      # Push images through the TFHub module.
      output = f(dict(images=x), signature="gaussian_encoder", as_dict=True)
      # Convert to numpy arrays and return.
      return {key: np.array(values) for key, values in output.items()}

    # Run the postprocessing function which returns a transformation function
    # that can be used to create the representation from the mean and log
    # variance of the Gaussian distribution given by the encoder. Also returns
    # path to a checkpoint if the transformation requires variables.
    transform_fn, transform_checkpoint_path = postprocess_fn(
        dataset, _gaussian_encoder, np.random.RandomState(random_seed),
        output_dir)

    # Takes the "gaussian_encoder" signature, extracts the representation and
    # then saves under the signature "representation".
    tfhub_module_dir = os.path.join(output_dir, "tfhub")
    convolute_hub.convolute_and_save(module_path, "gaussian_encoder",
                                     tfhub_module_dir, transform_fn,
                                     transform_checkpoint_path,
                                     "representation")

  # Copy over all the prior results and configs, and record the elapsed time.
  original_results_dir = os.path.join(model_dir, "results")
  results_dir = os.path.join(output_dir, "results")
  results_dict = dict(elapsed_time=time.time() - experiment_timer)
  results.update_result_directory(results_dir, "postprocess", results_dict,
                                  original_results_dir)
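
# A minimal usage sketch for `postprocess`, illustrative only: it assumes the
# function is gin-configurable under the scope "postprocess" and that an
# extractor such as `@mean_representation` is registered with gin. The paths
# and bindings below are placeholder assumptions.
def _example_postprocess_usage():
  """Hypothetical helper showing one way to configure and run `postprocess`."""
  gin_bindings = [
      "dataset.name = 'auto'",  # Resolved from the training step's gin config.
      "postprocess.postprocess_fn = @mean_representation",  # Assumed extractor.
      "postprocess.random_seed = 0",
  ]
  gin.parse_config_files_and_bindings([], gin_bindings)
  postprocess("/tmp/model", "/tmp/model/postprocessed/mean", overwrite=True)
  gin.clear_config()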
def test_gin_dict_live(self):
  """Tests namespacing functionality based on live gin config."""
  parameter_name = "test.value"
  gin.bind_parameter(parameter_name, 1)
  _ = test_fn()
  self.assertDictEqual(results.gin_dict(), {parameter_name: "1"})
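
# The two tests above bind the gin parameter "test.value" and call `test_fn`,
# which is defined elsewhere in the test module. A minimal sketch consistent
# with that parameter name would be a gin configurable registered under the
# scope "test" (an assumption; the actual helper may differ):
@gin.configurable("test")
def test_fn(value=gin.REQUIRED):
  """Returns the gin-bound value so it appears in the operative config."""
  return value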
def visualize(model_dir,
              output_dir,
              overwrite=False,
              num_animations=5,
              num_frames=20,
              fps=10,
              num_points_irs=10000):
  """Takes trained model from model_dir and visualizes it in output_dir.

  Args:
    model_dir: Path to directory where the trained model is saved.
    output_dir: Path to output directory.
    overwrite: Boolean indicating whether to overwrite output directory.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
    num_points_irs: Number of points to be used for the IRS plots.
  """
  # Fix the random seed for reproducibility.
  random_state = np.random.RandomState(0)

  # Delete the output directory if it already exists.
  if tf.gfile.IsDirectory(output_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(output_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Automatically set the proper data set if necessary. We replace the active
  # gin config as this will lead to a valid gin config file where the data set
  # is present.
  # Obtain the dataset name from the gin config of the previous step.
  gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
  gin_dict = results.gin_dict(gin_config_file)
  gin.bind_parameter("dataset.name",
                     gin_dict["dataset.name"].replace("'", ""))

  # Automatically infer the activation function from the gin config.
  activation_str = gin_dict["reconstruction_loss.activation"]
  if activation_str == "'logits'":
    activation = sigmoid
  elif activation_str == "'tanh'":
    activation = tanh
  else:
    raise ValueError(
        "Activation function could not be inferred from gin config.")

  dataset = named_data.get_named_ground_truth_data()
  num_pics = 64
  module_path = os.path.join(model_dir, "tfhub")

  with hub.eval_function_for_module(module_path) as f:
    # Save reconstructions.
    real_pics = dataset.sample_observations(num_pics, random_state)
    raw_pics = f(
        dict(images=real_pics), signature="reconstructions",
        as_dict=True)["images"]
    pics = activation(raw_pics)
    paired_pics = np.concatenate((real_pics, pics), axis=2)
    paired_pics = [paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])]
    results_dir = os.path.join(output_dir, "reconstructions")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)
    visualize_util.grid_save_images(
        paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

    # Save samples.
    def _decoder(latent_vectors):
      return f(
          dict(latent_vectors=latent_vectors),
          signature="decoder",
          as_dict=True)["images"]

    num_latent = int(gin_dict["encoder.num_latent"])
    num_pics = 64
    random_codes = random_state.normal(0, 1, [num_pics, num_latent])
    pics = activation(_decoder(random_codes))
    results_dir = os.path.join(output_dir, "sampled")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)
    visualize_util.grid_save_images(pics,
                                    os.path.join(results_dir, "samples.jpg"))

    # Save latent traversals.
    result = f(
        dict(images=dataset.sample_observations(num_pics, random_state)),
        signature="gaussian_encoder",
        as_dict=True)
    means = result["mean"]
    logvars = result["logvar"]
    results_dir = os.path.join(output_dir, "traversals")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)
    for i in range(means.shape[1]):
      pics = activation(
          latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
      file_name = os.path.join(results_dir, "traversals{}.jpg".format(i))
      visualize_util.grid_save_images([pics], file_name)

    # Save the latent traversal animations.
    results_dir = os.path.join(output_dir, "animated_traversals")
    if not gfile.IsDirectory(results_dir):
      gfile.MakeDirs(results_dir)

    # Cycle through quantiles of a standard Gaussian.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        code[:, j] = visualize_util.cycle_gaussian(base_code[j], num_frames)
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "std_gaussian_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle through quantiles of a fitted Gaussian.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        loc = np.mean(means[:, j])
        total_variance = np.mean(np.exp(logvars[:, j])) + np.var(means[:, j])
        code[:, j] = visualize_util.cycle_gaussian(
            base_code[j], num_frames, loc=loc, scale=np.sqrt(total_variance))
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "fitted_gaussian_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle through [-2, 2] interval.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        code[:, j] = visualize_util.cycle_interval(base_code[j], num_frames,
                                                   -2., 2.)
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "fixed_interval_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle linearly through +-2 std dev of a fitted Gaussian.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        loc = np.mean(means[:, j])
        total_variance = np.mean(np.exp(logvars[:, j])) + np.var(means[:, j])
        scale = np.sqrt(total_variance)
        code[:, j] = visualize_util.cycle_interval(
            base_code[j], num_frames, loc - 2. * scale, loc + 2. * scale)
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "conf_interval_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Cycle linearly through the min-max range of the encoded means.
    for i, base_code in enumerate(means[:num_animations]):
      images = []
      for j in range(base_code.shape[0]):
        code = np.repeat(np.expand_dims(base_code, 0), num_frames, axis=0)
        code[:, j] = visualize_util.cycle_interval(
            base_code[j], num_frames, np.min(means[:, j]), np.max(means[:, j]))
        images.append(np.array(activation(_decoder(code))))
      filename = os.path.join(results_dir, "minmax_interval_cycle%d.gif" % i)
      visualize_util.save_animation(np.array(images), filename, fps)

    # Interventional effects visualization.
    factors = dataset.sample_factors(num_points_irs, random_state)
    obs = dataset.sample_observations_from_factors(factors, random_state)
    latents = f(
        dict(images=obs), signature="gaussian_encoder", as_dict=True)["mean"]
    results_dir = os.path.join(output_dir, "interventional_effects")
    vis_all_interventional_effects(factors, latents, results_dir)

  # Finally, we clear the gin config that we have set.
  gin.clear_config()
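
# A minimal usage sketch for `visualize`, illustrative only: unlike the steps
# above, no extra gin bindings are needed here, since the dataset name and the
# activation are read from the train.gin saved inside model_dir. The paths
# below are placeholder assumptions.
def _example_visualize_usage():
  """Hypothetical helper showing one way to run `visualize`."""
  visualize(
      "/tmp/model",
      "/tmp/model/visualizations",
      overwrite=True,
      num_animations=3,
      num_frames=20,
      fps=10)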