Example #1
 def test_gin_dict_dir(self):
   """Tests namespacing functionality based on saved gin config."""
   parameter_name = "test.value"
   gin.bind_parameter(parameter_name, 1)
   _ = test_fn()
   config_path = os.path.join(self.get_temp_dir(), "config.gin")
   with tf.gfile.GFile(config_path, "w") as f:
     f.write(gin.operative_config_str())
   self.assertDictEqual(results.gin_dict(config_path), {parameter_name: "1"})
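For reference, a minimal sketch of what a results.gin_dict-style parser could look like, inferred from how these tests use it (a hypothetical reimplementation, not the library's actual code): it reads "name = value" lines from a saved config, or from the live operative config, into a dict of strings.

import gin

def gin_dict_sketch(config_path=None):
    """Hypothetical: parses a gin config into {parameter_name: value_string}."""
    if config_path is not None:
        with open(config_path) as f:
            text = f.read()
    else:
        text = gin.operative_config_str()
    parsed = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines, comments, and import statements.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        parsed[key.strip()] = value.strip()
    return parsed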
Example #2
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name="",
             eval_pytorch=False):
    """Loads a representation TFHub module and computes disentanglement metrics.

    Args:
      model_dir: String with path to directory where the representation function
        is saved.
      output_dir: String with the path where the results should be saved.
      overwrite: Boolean indicating whether to overwrite output directory.
      evaluation_fn: Function used to evaluate the representation (see metrics/
        for examples).
      random_seed: Integer with random seed used for training.
      name: Optional string with name of the metric (can be used to name metrics).
      eval_pytorch: Boolean indicating whether to evaluate a saved PyTorch
        model (pytorch_model.pt) in preference to the TFHub module when both
        exist.
    """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    try:
        # Automatically set the proper data set if necessary. We replace the active
        # gin config as this will lead to a valid gin config file where the data set
        # is present.
        if gin.query_parameter("dataset.name") == "auto":
            # Obtain the dataset name from the gin config of the previous step.
            gin_config_file = os.path.join(model_dir, "results", "gin",
                                           "postprocess.gin")
            gin_dict = results.gin_dict(gin_config_file)
            with gin.unlock_config():
                gin.bind_parameter("dataset.name",
                                   gin_dict["dataset.name"].replace("'", ""))
        dataset = named_data.get_named_ground_truth_data()
    except NotFoundError:
        # If we did not train with disentanglement_lib, there is no "previous step",
        # so we'll have to rely on the environment variable.
        if gin.query_parameter("dataset.name") == "auto":
            with gin.unlock_config():
                gin.bind_parameter("dataset.name", get_dataset_name())
        dataset = named_data.get_named_ground_truth_data()

    eval_tf = True
    if eval_pytorch and os.path.exists(
            os.path.join(model_dir, 'pytorch_model.pt')):
        eval_tf = False

    if os.path.exists(os.path.join(model_dir, 'tfhub')) and eval_tf:
        # Path to TFHub module of previously trained representation.
        module_path = os.path.join(model_dir, "tfhub")
        # Evaluate results with TensorFlow.
        results_dict = _evaluate_with_tensorflow(module_path, evaluation_fn,
                                                 dataset, random_seed)
    elif os.path.exists(os.path.join(model_dir, 'pytorch_model.pt')):
        # Path to Pytorch JIT Module of previously trained representation.
        module_path = os.path.join(model_dir, 'pytorch_model.pt')
        # Evaluate results with PyTorch.
        results_dict = _evaluate_with_pytorch(module_path, evaluation_fn,
                                              dataset, random_seed)
    elif os.path.exists(os.path.join(model_dir, 'python_model.dill')):
        # Path to the dilled function
        module_path = os.path.join(model_dir, 'python_model.dill')
        # Evaluate results with NumPy.
        results_dict = _evaluate_with_numpy(module_path, evaluation_fn,
                                            dataset, random_seed)
    else:
        raise RuntimeError(
            "`model_dir` must contain a TFHub, PyTorch, or dilled Python model.")

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
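The NumPy branch above loads a dilled representation function. Below is a minimal sketch of that code path, under the assumption that the dill file stores a callable mapping images to representations; _evaluate_with_numpy_sketch is hypothetical and the real helper may differ.

import dill
import numpy as np

def _evaluate_with_numpy_sketch(module_path, evaluation_fn, dataset, random_seed):
    # Load the pickled (dilled) representation function from disk.
    with open(module_path, "rb") as f:
        representation_fn = dill.load(f)
    # Delegate scoring to the gin-configured evaluation function.
    return evaluation_fn(
        dataset,
        representation_fn,
        random_state=np.random.RandomState(random_seed))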
Example #3
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             postprocess_fn=gin.REQUIRED,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             pca_components=gin.REQUIRED,
             name=""):
    """Loads a trained Gaussian encoder and decoder.

  Args:
    model_dir: String with path to directory where the model is saved.
    output_dir: String with the path where the representation should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    postprocess_fn: Function used to extract the representation (see methods.py
      for examples).
    random_seed: Integer with random seed used for postprocessing (may be
      unused).
    name: Optional string with name of the representation (can be used to name
      representations).
  """
    # We do not use the variable 'name'. Instead, it can be used to name
    # representations as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "train.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()

    # Path to TFHub module of previously trained model.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _gaussian_encoder(x):
            """Encodes images using trained model."""
            # Push images through the TFHub module.
            output = f(dict(images=x),
                       signature="gaussian_encoder",
                       as_dict=True)
            # Convert to numpy arrays and return.
            return np.array(output["mean"]), np.array(output["logvar"])

        def _decoder(z):
            """Encodes images using trained model."""
            # Push images through the TFHub module.
            output = f(dict(latent_vectors=z),
                       signature="decoder",
                       as_dict=True)
            # Convert to numpy arrays and return.
            return np.array(output['images'])

        # Run the postprocessing function which returns a transformation function
        # that can be used to create the representation from the mean and log
        # variance of the Gaussian distribution given by the encoder. Also returns
        # path to a checkpoint if the transformation requires variables.
        transform_fn, transform_checkpoint_path = postprocess_fn(
            dataset, _gaussian_encoder, np.random.RandomState(random_seed),
            output_dir)

        print('\n\n\n Calculating recall')
        # Computes scores of the representation based on the evaluation_fn.
        if _has_kwarg_or_kwargs(evaluation_fn, "artifact_dir"):
            artifact_dir = os.path.join(model_dir, "artifacts")
            results_dict_list = evaluation_fn(
                dataset,
                _gaussian_encoder,
                transform_fn,
                _decoder,
                random_state=np.random.RandomState(random_seed),
                artifact_dir=artifact_dir)
        else:
            # Legacy code path to allow for old evaluation metrics.
            warnings.warn(
                "Evaluation function does not appear to accept an"
                " `artifact_dir` argument. This may not be compatible with "
                "future versions.", DeprecationWarning)
            results_dict_list = evaluation_fn(
                dataset,
                _gaussian_encoder,
                transform_fn,
                _decoder,
                random_state=np.random.RandomState(random_seed))

    # Save the results (and all previous results in the pipeline) on disk.
    for results_dict, pca_comp in zip(results_dict_list, pca_components):
        results_dir = os.path.join(output_dir, "results")
        results_dict["elapsed_time"] = time.time() - experiment_timer
        filename = "evaluation_pca_{}comp".format(pca_comp)
        results.update_result_directory(results_dir, filename, results_dict)
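Both legacy-compatibility branches above hinge on _has_kwarg_or_kwargs. A minimal sketch of such a helper, assuming it only needs to inspect the function signature (the library's version may additionally unwrap gin decorators):

import inspect

def _has_kwarg_or_kwargs_sketch(f, kwarg):
    """Returns True if f accepts the keyword explicitly or via **kwargs."""
    spec = inspect.getfullargspec(f)
    return (kwarg in spec.args
            or kwarg in spec.kwonlyargs
            or spec.varkw is not None)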
Example #4
def evaluate(model_dirs,
             output_dir,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
    """Loads a trained estimator and evaluates it according to beta-VAE metric."""
    # The name will be part of the gin config and can be used to tag results.
    del name

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper dataset if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the dataset
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dirs[0], "results", "gin",
                                       "train.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            print(gin_dict["dataset.name"])
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))

    if tf.io.gfile.isdir(output_dir):
        tf.io.gfile.rmtree(output_dir)

    dataset = named_data.get_named_ground_truth_data()

    with contextlib.ExitStack() as stack:
        representation_functions = []
        eval_functions = [
            stack.enter_context(
                hub.eval_function_for_module(os.path.join(model_dir, "tfhub")))
            for model_dir in model_dirs
        ]
        for f in eval_functions:

            def _representation_function(x, f=f):
                def compute_gaussian_kl(z_mean, z_logvar):
                    return np.mean(
                        0.5 *
                        (np.square(z_mean) + np.exp(z_logvar) - z_logvar - 1),
                        axis=0)

                encoding = f(dict(images=x),
                             signature="gaussian_encoder",
                             as_dict=True)

                return np.array(encoding["mean"]), compute_gaussian_kl(
                    np.array(encoding["mean"]), np.array(encoding["logvar"]))

            representation_functions.append(_representation_function)

        results_dict = evaluation_fn(
            dataset,
            representation_functions,
            random_state=np.random.RandomState(random_seed))

    original_results_dir = os.path.join(model_dirs[0], "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
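The compute_gaussian_kl closure above averages, over the batch, the per-dimension KL divergence KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1). A quick numeric sanity check of that formula:

import numpy as np

# Two latent dimensions: mu = 0 and mu = 1, both with sigma^2 = 1 (logvar = 0).
z_mean = np.array([[0.0, 1.0]])
z_logvar = np.array([[0.0, 0.0]])
kl = np.mean(0.5 * (np.square(z_mean) + np.exp(z_logvar) - z_logvar - 1), axis=0)
print(kl)  # [0.0, 0.5]: a zero-mean, unit-variance code has zero KL.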
Example #5
def visualize(model_dir,
              output_dir,
              overwrite=False,
              num_animations=5,
              num_frames=20,
              fps=10,
              num_points_irs=10000):
    """Takes trained model from model_dir and visualizes it in output_dir.

    Args:
      model_dir: Path to directory where the trained model is saved.
      output_dir: Path to output directory.
      overwrite: Boolean indicating whether to overwrite output directory.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of frames in each animation.
      fps: Integer with frame rate for the animation.
      num_points_irs: Number of points to be used for the IRS plots.
    """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Create the output directory if necessary.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # Automatically infer the activation function from gin config.
    activation_str = gin_dict["reconstruction_loss.activation"]
    if activation_str == "'logits'":
        activation = sigmoid
    elif activation_str == "'tanh'":
        activation = tanh
    else:
        raise ValueError(
            "Activation function  could not be infered from gin config.")

    dataset = named_data.get_named_ground_truth_data()
    num_pics = 64
    module_path = os.path.join(model_dir, "tfhub")

    with hub.eval_function_for_module(module_path) as f:
        # Save reconstructions.
        real_pics = dataset.sample_observations(num_pics, random_state)
        raw_pics = f(dict(images=real_pics),
                     signature="reconstructions",
                     as_dict=True)["images"]
        pics = activation(raw_pics)
        paired_pics = np.concatenate((real_pics, pics), axis=2)
        paired_pics = [
            paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
        ]
        results_dir = os.path.join(output_dir, "reconstructions")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

        # Save samples.
        def _decoder(latent_vectors):
            return f(dict(latent_vectors=latent_vectors),
                     signature="decoder",
                     as_dict=True)["images"]

        num_latent = int(gin_dict["encoder.num_latent"])
        num_pics = 64
        random_codes = random_state.normal(0, 1, [num_pics, num_latent])
        pics = activation(_decoder(random_codes))
        results_dir = os.path.join(output_dir, "sampled")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            pics, os.path.join(results_dir, "samples.jpg"))

        # Save latent traversals.
        result = f(
            dict(images=dataset.sample_observations(num_pics, random_state)),
            signature="gaussian_encoder",
            as_dict=True)
        means = result["mean"]
        logvars = result["logvar"]
        results_dir = os.path.join(output_dir, "traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        for i in range(means.shape[1]):
            pics = activation(
                latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
            file_name = os.path.join(results_dir, "traversals{}.jpg".format(i))
            visualize_util.grid_save_images([pics], file_name)

        # Save the latent traversal animations.
        results_dir = os.path.join(output_dir, "animated_traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)

        # Cycle through quantiles of a standard Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j], num_frames)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "std_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through quantiles of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j],
                    num_frames,
                    loc=loc,
                    scale=np.sqrt(total_variance))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fitted_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through [-2, 2] interval.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, -2., 2.)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fixed_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through +-2 std dev of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                scale = np.sqrt(total_variance)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, loc - 2. * scale,
                    loc + 2. * scale)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "conf_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through minmax of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, np.min(means[:, j]),
                    np.max(means[:, j]))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "minmax_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Interventional effects visualization.
        factors = dataset.sample_factors(num_points_irs, random_state)
        obs = dataset.sample_observations_from_factors(factors, random_state)
        batch_size = 64
        num_outputs = 0
        latents = []
        while num_outputs < obs.shape[0]:
            input_batch = obs[num_outputs:min(num_outputs +
                                              batch_size, obs.shape[0])]
            output_batch = f(dict(images=input_batch),
                             signature="gaussian_encoder",
                             as_dict=True)["mean"]
            latents.append(output_batch)
            num_outputs += batch_size
        latents = np.concatenate(latents)

        results_dir = os.path.join(output_dir, "interventional_effects")
        vis_all_interventional_effects(factors, latents, results_dir)

    # Finally, we clear the gin config that we have set.
    gin.clear_config()
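The animations above rely on visualize_util.cycle_gaussian and cycle_interval. A hypothetical sketch of the Gaussian variant, assuming it cycles one latent value through quantiles of the given Gaussian and back so the clip loops smoothly, starting at the quantile of base_value (the library's implementation may differ):

import numpy as np
from scipy import stats

def cycle_gaussian_sketch(base_value, num_frames, loc=0.0, scale=1.0):
    """Hypothetical: one full there-and-back sweep through Gaussian quantiles."""
    start = stats.norm.cdf(base_value, loc=loc, scale=scale)
    raw = (np.linspace(0.0, 2.0, num_frames, endpoint=False) + start) % 2.0
    quantiles = np.where(raw > 1.0, 2.0 - raw, raw)   # Triangle wave in [0, 1].
    quantiles = np.clip(quantiles, 1e-3, 1.0 - 1e-3)  # Keep ppf finite.
    return stats.norm.ppf(quantiles, loc=loc, scale=scale)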
Example #6
def reason(
    input_dir,
    output_dir,
    overwrite=False,
    model=gin.REQUIRED,
    num_iterations=gin.REQUIRED,
    training_steps_per_iteration=gin.REQUIRED,
    eval_steps_per_iteration=gin.REQUIRED,
    random_seed=gin.REQUIRED,
    batch_size=gin.REQUIRED,
    name="",
):
    """Trains the estimator and exports the snapshot and the gin config.

  The use of this function requires the gin binding 'dataset.name' to be
  specified if a model is trained from scratch as that determines the data set
  used for training.

  Args:
    input_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    model: GaussianEncoderModel that should be trained and exported.
    num_iterations: Integer with number of training iterations.
    training_steps_per_iteration: Integer with number of training steps per
      iteration.
    eval_steps_per_iteration: Integer with number of validation and test steps
      per iteration.
    random_seed: Integer with random seed used for training.
    batch_size: Integer with the batch size.
    name: Optional string with name of the model (can be used to name models).
  """
    # We do not use the variable 'name'. Instead, it can be used to name results
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Create a numpy random state. We will sample the random seeds for training
    # and evaluation from this.
    random_state = np.random.RandomState(random_seed)

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        if input_dir is None:
            raise ValueError(
                "Cannot automatically infer data set for methods with"
                " no prior model directory.")
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(input_dir, "results", "gin",
                                       "postprocess.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = pgm_data.get_pgm_dataset()

    # Set the path to the TFHub embedding if we are training based on a
    # pre-trained embedding.
    if input_dir is not None:
        tfhub_dir = os.path.join(input_dir, "tfhub")
        with gin.unlock_config():
            gin.bind_parameter("HubEmbedding.hub_path", tfhub_dir)

    # We create a TPUEstimator based on the provided model. This is primarily so
    # that we could switch to TPU training in the future. For now, we train
    # locally on GPUs.
    run_config = contrib_tpu.RunConfig(
        tf_random_seed=random_seed,
        keep_checkpoint_max=1,
        tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=500))
    tpu_estimator = contrib_tpu.TPUEstimator(use_tpu=False,
                                             model_fn=model.model_fn,
                                             model_dir=os.path.join(
                                                 output_dir, "tf_checkpoint"),
                                             train_batch_size=batch_size,
                                             eval_batch_size=batch_size,
                                             config=run_config)

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Create a dictionary to keep track of all relevant information.
    results_dict_of_dicts = {}
    validation_scores = []
    all_dicts = []

    for i in range(num_iterations):
        steps_so_far = i * training_steps_per_iteration
        tf.logging.info("Training to %d steps.", steps_so_far)
        # Train the model for the specified steps.
        tpu_estimator.train(input_fn=dataset.make_input_fn(
            random_state.randint(2**32)),
                            steps=training_steps_per_iteration)
        # Compute validation scores used for model selection.
        validation_results = tpu_estimator.evaluate(
            input_fn=dataset.make_input_fn(
                random_state.randint(2**32),
                num_batches=eval_steps_per_iteration))
        validation_scores.append(validation_results["accuracy"])
        tf.logging.info("Validation results %s", validation_results)
        # Compute test scores for final results.
        test_results = tpu_estimator.evaluate(input_fn=dataset.make_input_fn(
            random_state.randint(2**32), num_batches=eval_steps_per_iteration),
                                              name="test")
        dict_at_iteration = results.namespaced_dict(val=validation_results,
                                                    test=test_results)
        results_dict_of_dicts["step{}".format(
            steps_so_far)] = dict_at_iteration
        all_dicts.append(dict_at_iteration)

    # Select the best number of steps based on the validation scores and add it
    # as a special key to the dictionary.
    best_index = np.argmax(validation_scores)
    results_dict_of_dicts["best"] = all_dicts[best_index]

    # Save the results. The result dir will contain all the results and config
    # files that we copied along, as we progress in the pipeline. The idea is that
    # these files will be available for analysis at the end.
    if input_dir is not None:
        original_results_dir = os.path.join(input_dir, "results")
    else:
        original_results_dir = None
    results_dict = results.namespaced_dict(**results_dict_of_dicts)
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "abstract_reasoning",
                                    results_dict, original_results_dir)
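The training loop above assumes dataset.make_input_fn returns an Estimator-compatible input_fn. A toy factory honoring that contract (shapes and label space are made up for this sketch):

import numpy as np
import tensorflow as tf

def make_input_fn_sketch(seed, num_batches=None):
    """Hypothetical: builds an input_fn yielding (features, labels) batches."""
    def input_fn(params):
        batch_size = params["batch_size"]  # Injected by TPUEstimator.
        rng = np.random.RandomState(seed)
        images = rng.rand(512, 64, 64, 3).astype(np.float32)
        labels = rng.randint(0, 8, size=(512,)).astype(np.int32)
        ds = tf.data.Dataset.from_tensor_slices((images, labels))
        ds = ds.repeat().batch(batch_size, drop_remainder=True)
        if num_batches is not None:
            ds = ds.take(num_batches)  # Fixed-length evaluation.
        return ds
    return input_fn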
Example #7
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
    """Loads a representation TFHub module and computes disentanglement metrics.

  Args:
    model_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    evaluation_fn: Function used to evaluate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for training.
    name: Optional string with name of the metric (can be used to name metrics).
  """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "postprocess.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()

    # Path to TFHub module of previously trained representation.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _representation_function(x):
            """Computes representation vector for input images."""
            output = f(dict(images=x),
                       signature="representation",
                       as_dict=True)
            return np.array(output["default"])

        # Computes scores of the representation based on the evaluation_fn.
        if _has_kwarg_or_kwargs(evaluation_fn, "artifact_dir"):
            artifact_dir = os.path.join(model_dir, "artifacts")
            results_dict = evaluation_fn(
                dataset,
                _representation_function,
                random_state=np.random.RandomState(random_seed),
                artifact_dir=artifact_dir)
        else:
            # Legacy code path to allow for old evaluation metrics.
            warnings.warn(
                "Evaluation function does not appear to accept an"
                " `artifact_dir` argument. This may not be compatible with "
                "future versions.", DeprecationWarning)
            results_dict = evaluation_fn(
                dataset,
                _representation_function,
                random_state=np.random.RandomState(random_seed))

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
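For illustration, a toy evaluation_fn obeying the contract used above: it takes the dataset, a representation function, a random_state, and optionally an artifact_dir, and returns a dict of scalar scores (the score itself is meaningless here):

import numpy as np

def dummy_metric(ground_truth_data, representation_function, random_state,
                 artifact_dir=None):
    """Toy metric matching the evaluation_fn signature; illustrative only."""
    del artifact_dir  # Unused by this toy metric.
    observations = ground_truth_data.sample_observations(16, random_state)
    representations = representation_function(observations)
    return {"mean_abs_code": float(np.mean(np.abs(representations)))}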
Example #8
def validate(model_dir,
             output_dir,
             overwrite=False,
             validation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             num_labelled_samples=gin.REQUIRED,
             name=""):
    """Loads a representation TFHub module and computes disentanglement metrics.

  Args:
    model_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    validation_fn: Function used to validate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for training.
    num_labelled_samples: How many labelled samples are available.
    name: Optional string with name of the metric (can be used to name metrics).
  """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "postprocess.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()
    observations, labels, _ = semi_supervised_utils.sample_supervised_data(
        random_seed, dataset, num_labelled_samples)
    # Path to TFHub module of previously trained representation.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _representation_function(x):
            """Computes representation vector for input images."""
            output = f(dict(images=x),
                       signature="representation",
                       as_dict=True)
            return np.array(output["default"])

        # Computes scores of the representation based on the validation_fn.
        results_dict = validation_fn(observations, np.transpose(labels),
                                     _representation_function)

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "validation", results_dict,
                                    original_results_dir)
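Analogously, a toy validation_fn with the positional signature used above (observations, transposed labels, representation function); the correlation heuristic is purely illustrative and assumes labels arrive with shape [num_factors, num_points]:

import numpy as np

def dummy_validation_fn(observations, labels, representation_function):
    """Toy validation function; returns a dict of scalar scores."""
    codes = representation_function(observations)
    num_codes = codes.shape[1]
    # Cross-correlate each latent code with each labelled factor.
    corr = np.corrcoef(codes.T, labels)[:num_codes, num_codes:]
    return {"max_abs_correlation": float(np.max(np.abs(corr)))}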
Example #9
def postprocess(model_dir,
                output_dir,
                overwrite=False,
                postprocess_fn=gin.REQUIRED,
                random_seed=gin.REQUIRED,
                name=""):
    """Loads a trained Gaussian encoder and extracts representation.

  Args:
    model_dir: String with path to directory where the model is saved.
    output_dir: String with the path where the representation should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    postprocess_fn: Function used to extract the representation (see methods.py
      for examples).
    random_seed: Integer with random seed used for postprocessing (may be
      unused).
    name: Optional string with name of the representation (can be used to name
      representations).
  """
    # We do not use the variable 'name'. Instead, it can be used to name
    # representations as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "train.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()

    # Path to TFHub module of previously trained model.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _gaussian_encoder(x):
            """Encodes images using trained model."""
            # Push images through the TFHub module.
            output = f(dict(images=x),
                       signature="gaussian_encoder",
                       as_dict=True)
            # Convert to numpy arrays and return.
            return {key: np.array(values) for key, values in output.items()}

        # Run the postprocessing function which returns a transformation function
        # that can be used to create the representation from the mean and log
        # variance of the Gaussian distribution given by the encoder. Also returns
        # path to a checkpoint if the transformation requires variables.
        transform_fn, transform_checkpoint_path = postprocess_fn(
            dataset, _gaussian_encoder, np.random.RandomState(random_seed),
            output_dir)

        # Takes the "gaussian_encoder" signature, extracts the representation and
        # then saves under the signature "representation".
        tfhub_module_dir = os.path.join(output_dir, "tfhub")
        convolute_hub.convolute_and_save(module_path, "gaussian_encoder",
                                         tfhub_module_dir, transform_fn,
                                         transform_checkpoint_path,
                                         "representation")

    # We first copy over all the prior results and configs.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict = dict(elapsed_time=time.time() - experiment_timer)
    results.update_result_directory(results_dir, "postprocess", results_dict,
                                    original_results_dir)
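A minimal postprocess_fn sketch consistent with the contract described in the comments above: it returns a transform_fn mapping the encoder's mean and log-variance to a representation, plus a checkpoint path, which is None when the transform has no variables (the implementations in methods.py may differ):

def mean_representation_sketch(ground_truth_data, gaussian_encoder,
                               random_state, save_path):
    """Hypothetical postprocess_fn: uses the posterior mean as representation."""
    del ground_truth_data, random_state, save_path  # Unused by this transform.

    def transform_fn(mean, logvar):
        del logvar  # The mean representation ignores the variances.
        return mean

    return transform_fn, None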
Example #10
def visualize_supervised(supervised_model_dir,
                         trained_vae_model_dir,
                         output_dir,
                         overwrite=False):
    """Takes trained model from model_dir and visualizes it in output_dir.

  Args:
    model_dir: Path to directory where the trained model is saved.
    output_dir: Path to output directory.
    overwrite: Boolean indicating whether to overwrite output directory.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
    num_points_irs: Number of points to be used for the IRS plots.
  """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Create the output directory if necessary.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(supervised_model_dir, "results", "gin",
                                   "evaluate.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # Automatically infer the activation function from gin config.
    activation_str = gin_dict["reconstruction_loss.activation"]
    if activation_str == "'logits'":
        activation = sigmoid
    elif activation_str == "'tanh'":
        activation = tanh
    else:
        raise ValueError(
            "Activation function  could not be infered from gin config.")

    _, dataset = named_data.get_named_ground_truth_data()
    num_pics = 64
    supervised_module_path = os.path.join(supervised_model_dir, "tfhub")

    with hub.eval_function_for_module(supervised_module_path) as f:
        trained_vae_path = os.path.join(trained_vae_model_dir, "tfhub")
        with hub.eval_function_for_module(trained_vae_path) as g:

            def _representation_function(x):
                """Computes representation vector for input images."""
                output = g(dict(images=x),
                           signature="representation",
                           as_dict=True)
                return np.array(output["default"])

            # Save reconstructions.
            real_pics = dataset.sample_observations(num_pics, random_state)
            representations = _representation_function(real_pics)

        print(real_pics.shape, representations.shape)
        decoded_pics = f(dict(representations=representations),
                         signature="reconstructions",
                         as_dict=True)['images']
        pics = activation(decoded_pics)
        paired_pics = np.concatenate((real_pics, pics), axis=2)
        paired_pics = [
            paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
        ]
        results_dir = os.path.join(output_dir, "reconstructions")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

    # Finally, we clear the gin config that we have set.
    gin.clear_config()
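Both visualization examples map the activation names read from the gin config to plain numpy functions. Sketches of the assumed sigmoid and tanh follow; the scripts may actually import these from scipy or elsewhere:

import numpy as np

def sigmoid(x):
    """Maps decoder logits to pixel values in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def tanh(x):
    """Maps decoder outputs to pixel values in (-1, 1)."""
    return np.tanh(x)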
Example #11
 def test_gin_dict_live(self):
     """Tests namespacing functionality based on live gin config."""
     parameter_name = "test.value"
     gin.bind_parameter(parameter_name, 1)
     _ = test_fn()
     self.assertDictEqual(results.gin_dict(), {parameter_name: "1"})