Example #1
def __init__(self, name, n_samples=0, seed=0):
    self.name = name
    self.seed = seed
    self.random_state = np.random.RandomState(seed)
    self.dataset = named_data.get_named_ground_truth_data(self.name)
    self.n_samples = len(
        self.dataset.images) if n_samples == 0 else n_samples
def get_loader(name=None, batch_size=256, flat=False):
    """
    Makes a dataset and a data-loader.
    Parameters
    ----------
    name : str
        Name of the dataset use. Defaults to the output of `get_dataset_name`.
    batch_size : int
        Batch size.
    Returns
    -------
    SerialIterator
    """
    name = get_dataset_name() if name is None else name
    dataset = get_named_ground_truth_data(name)
    assert isinstance(dataset.images, np.ndarray), \
        "Input to the representation function must be a ndarray."
    if dataset.images.ndim == 3:
        # expand channel dim
        dataset.images = np.expand_dims(dataset.images, axis=3)
    assert dataset.images.ndim == 4, \
        "Input to the representation function must be a four dimensional NHWC tensor."
    # Convert from NHWC to NCHW
    dataset.images = np.moveaxis(dataset.images, 3, 1)
    if flat:
        dataset.images = dataset.images.reshape(-1, 3 * 64 * 64)
    cast_dataset = CastDataset(dataset.images)
    loader = SerialIterator(cast_dataset,
                            batch_size=batch_size,
                            repeat=True,
                            shuffle=True)
    return loader
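
A minimal usage sketch for `get_loader`; the dataset name is a placeholder and the call assumes the chainer-style `SerialIterator` API used above:

# Hypothetical usage; "dsprites_full" is an assumed dataset name.
loader = get_loader(name="dsprites_full", batch_size=64)
batch = loader.next()  # list of NCHW arrays produced by CastDataset
print(len(batch), batch[0].shape)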
Example #3
def get_pgm_dataset(pgm_type=gin.REQUIRED):
  """Returns a named PGM data set."""
  ground_truth_data = named_data.get_named_ground_truth_data()

  # Quantization for specific data sets (as described in
  # https://arxiv.org/abs/1905.12506).
  if isinstance(ground_truth_data, dsprites.AbstractDSprites):
    wrapped_data_set = Quantizer(ground_truth_data, [5, 6, 3, 3, 4, 4])
  elif isinstance(ground_truth_data, shapes3d.Shapes3D):
    wrapped_data_set = Quantizer(ground_truth_data, [10, 10, 10, 4, 4, 4])
  elif isinstance(ground_truth_data, dummy_data.DummyData):
    wrapped_data_set = ground_truth_data
  else:
    raise ValueError("Invalid data set.")

  # We support different ways to generate PGMs for each of the data sets (e.g.,
  # `easy_1`, `hard_3`, `easy_mixed`). `easy` and `hard` refer to the way the
  # alternative solutions of the PGMs are generated:
  #   - `easy`: Alternative answers are random other solutions that do not
  #             satisfy the constraints in the given PGM.
  #   - `hard`: Alternative answers are unique random modifications of the
  #             correct solution which makes the task substantially harder.
  if pgm_type.startswith("easy"):
    sampling = "easy"
  elif pgm_type.startswith("hard"):
    sampling = "hard"
  else:
    raise ValueError("Invalid sampling strategy.")

  # The suffix determines how many relations there are:
  #   - 1-3: Specifies whether always 1, 2, or 3 relations are constant in each
  #          row.
  #   - `mixed`: With probability 1/3 each, 1, 2, or 3 relations are constant
  #               in each row.
  if pgm_type.endswith("1"):
    relations_dist = [1., 0., 0.]
  elif pgm_type.endswith("2"):
    relations_dist = [0., 1., 0.]
  elif pgm_type.endswith("3"):
    relations_dist = [0., 0., 1.]
  elif pgm_type.endswith("mixed"):
    relations_dist = [1. / 3., 1. / 3., 1. / 3.]
  else:
    raise ValueError("Invalid number of relations.")

  return PGMDataset(
      wrapped_data_set,
      sampling_strategy=sampling,
      relations_dist=relations_dist)
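
A hedged sketch of how `get_pgm_dataset` might be configured through gin; the binding names below are assumptions and would need to match the actual gin registration:

import gin

# Assumed gin bindings; adjust to the real configurable names.
gin.parse_config_files_and_bindings([], [
    "dataset.name = 'dsprites_full'",
    "get_pgm_dataset.pgm_type = 'hard_mixed'",
])
pgm_dataset = get_pgm_dataset()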
def get_missing_loader(name=None, batch_size=256, flat=False):
    """
    Make a missing dataset and a data-loader.
    Parameters
    ----------
    name : str
        Name of the dataset use. Defaults to the output of `get_dataset_name`.
    batch_size : int
        Batch size.
    Returns
    -------
    SerialIterator
    """
    name = get_dataset_name() if name is None else name
    dataset = get_named_ground_truth_data(name)
    # Randomly drop factor combinations: each is kept with probability 0.5.
    dataset.images = dataset.images.reshape(3, 6, 40, 32, 32, 64, 64)
    partial_images = []
    for shape in range(3):
        for size in range(6):
            for rotation in range(40):
                for x_pos in range(32):
                    for y_pos in range(32):
                        if np.random.rand() < 0.5:
                            partial_images.append(dataset.images[shape, size,
                                                                 rotation,
                                                                 x_pos, y_pos])
    dataset.images = partial_images
    dataset.images = np.array(dataset.images)

    assert isinstance(dataset.images, np.ndarray), \
        "Input to the representation function must be a ndarray."
    if dataset.images.ndim == 3:
        # expand channel dim
        dataset.images = np.expand_dims(dataset.images, axis=3)
    assert dataset.images.ndim == 4, \
        "Input to the representation function must be a four dimensional NHWC tensor."
    # Convert from NHWC to NCHW
    dataset.images = np.moveaxis(dataset.images, 3, 1)
    if flat:
        dataset.images = dataset.images.reshape(-1, 3 * 64 * 64)

    cast_dataset = CastDataset(dataset.images)
    loader = SerialIterator(cast_dataset,
                            batch_size=batch_size,
                            repeat=True,
                            shuffle=True)
    return loader
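
The nested loops in `get_missing_loader` keep each factor combination with probability 0.5; a vectorized sketch of the same idea, assuming the dSprites factor layout hard-coded in the reshape above:

# Equivalent in expectation to the nested loops; assumes dSprites-shaped data.
images = dataset.images.reshape(3, 6, 40, 32, 32, 64, 64)
keep_mask = np.random.rand(3, 6, 40, 32, 32) < 0.5
dataset.images = images[keep_mask]  # shape: (n_kept, 64, 64)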
Example #5
def visualize_dataset(dataset_name,
                      output_path,
                      num_animations=5,
                      num_frames=20,
                      fps=10):
    """Visualizes the data set by saving images to output_path.

  For each latent factor, outputs 16 images where only that latent factor is
  varied while all others are kept constant.

  Args:
    dataset_name: String with name of dataset as defined in named_data.py.
    output_path: String with path in which to create the visualizations.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
  """
    data = named_data.get_named_ground_truth_data(dataset_name)
    random_state = np.random.RandomState(0)

    # Create output folder if necessary.
    path = os.path.join(output_path, dataset_name)
    if not gfile.IsDirectory(path):
        gfile.MakeDirs(path)

    # Create still images.
    for i in range(data.num_factors):
        factors = data.sample_factors(16, random_state)
        indices = [j for j in range(data.num_factors) if i != j]
        factors[:, indices] = factors[0, indices]
        images = data.sample_observations_from_factors(factors, random_state)
        visualize_util.grid_save_images(
            images, os.path.join(path, "variations_of_factor%s.png" % i))

    # Create animations.
    for i in range(num_animations):
        base_factor = data.sample_factors(1, random_state)
        images = []
        for j, num_atoms in enumerate(data.factors_num_values):
            factors = np.repeat(base_factor, num_frames, axis=0)
            factors[:,
                    j] = visualize_util.cycle_factor(base_factor[0, j],
                                                     num_atoms, num_frames)
            images.append(
                data.sample_observations_from_factors(factors, random_state))
        visualize_util.save_animation(
            np.array(images), os.path.join(path, "animation%d.gif" % i), fps)
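
A hypothetical invocation of `visualize_dataset`; the dataset name and output path are placeholders:

visualize_dataset("dsprites_full", "/tmp/dataset_visualizations",
                  num_animations=3, num_frames=20, fps=10)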
Example #6
def __init__(self, name, seed=0):
    """
    Parameters
    ----------
    name : str
        Name of the dataset to use. You may use `get_dataset_name`.
    seed : int
        Random seed.
    iterator_len : int
        Length of the dataset. This defines the length of one training epoch.
    """
    from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
    self.name = name
    self.seed = seed
    self.random_state = np.random.RandomState(seed)
    self.dataset = get_named_ground_truth_data(self.name)
    self.iterator_len = self.dataset.images.shape[0]
def train_pca(model_dir,
              overwrite=False,
              random_seed=gin.REQUIRED,
              num_pca_components=gin.REQUIRED,
              name="",
              model_num=None):
    """Trains the pca and saves it.

  Args:
    model_dir: String with path to directory where model output should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    random_seed: Integer with random seed used for training.
    num_pca_components: List with the numbers of PCA components to fit.
    name: Optional string with name of the model (can be used to name models).
    model_num: Optional integer with model number (can be used to identify
      models).
  """
    # Obtain the datasets.
    dataset_train, _ = named_data.get_named_ground_truth_data()

    # Create a numpy random state. We will sample the random seeds for training
    # and evaluation from this.
    random_state = np.random.RandomState(random_seed)

    # Save the pca
    for num_comp in num_pca_components:
        pca = PCA(n_components=num_comp, random_state=random_state)
        original_train_images = dataset_train.images.reshape(
            dataset_train.data_size, -1)
        print(original_train_images.shape)
        trained_pca = pca.fit(original_train_images)

        pca_model_name = 'pca_{0}_{1}comp.pkl'.format(dataset_train.name,
                                                      str(num_comp))
        pca_export_path = os.path.join(model_dir, pca_model_name)
        with open(pca_export_path, 'wb') as f:
            pickle.dump(trained_pca, f)
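
The fitted PCA models are pickled one file per component count; a hedged sketch for loading one back and projecting flattened images (the directory, dataset name, and component count are assumptions):

import os
import pickle

model_dir = "/tmp/pca_models"  # placeholder path
with open(os.path.join(model_dir, "pca_dsprites_full_64comp.pkl"), "rb") as f:
    pca = pickle.load(f)
# `images` is assumed to be an (N, H, W, C) array of observations.
codes = pca.transform(images.reshape(len(images), -1))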
def get_dataset_loader(batch_size, seed):
    dataset = named_data.get_named_ground_truth_data()
    loader = util.tf_data_set_from_ground_truth_data(dataset, seed)
    loader = loader.batch(batch_size, drop_remainder=True)
    return loader, dataset
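
A hedged TF1-style usage sketch for `get_dataset_loader`, consistent with the graph-mode `tf.gfile`/TFHub code elsewhere in these examples:

loader, dataset = get_dataset_loader(batch_size=64, seed=0)
batch = loader.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    images = sess.run(batch)  # one batch of observations as a numpy array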
def validate(model_dir,
             output_dir,
             overwrite=False,
             validation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             num_labelled_samples=gin.REQUIRED,
             name=""):
    """Loads a representation TFHub module and computes disentanglement metrics.

  Args:
    model_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    validation_fn: Function used to validate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for training.
    num_labelled_samples: How many labelled samples are available.
    name: Optional string with name of the metric (can be used to name metrics).
  """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up time to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "postprocess.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()
    observations, labels, _ = semi_supervised_utils.sample_supervised_data(
        random_seed, dataset, num_labelled_samples)
    # Path to TFHub module of previously trained representation.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _representation_function(x):
            """Computes representation vector for input images."""
            output = f(dict(images=x),
                       signature="representation",
                       as_dict=True)
            return np.array(output["default"])

        # Computes scores of the representation based on the evaluation_fn.
        results_dict = validation_fn(observations, np.transpose(labels),
                                     _representation_function)

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "validation", results_dict,
                                    original_results_dir)
Example #10
def compute_downstream_regression_on_representations(
        ground_truth_data,
        representation_function,
        random_state,
        holdout_dataset_name=gin.REQUIRED,
        artifact_dir=None,
        num_train=gin.REQUIRED,
        num_test=gin.REQUIRED,
        num_holdout=gin.REQUIRED,
        batch_size=16):
    """Computes loss of downstream task on representations.

  Args:
    ground_truth_data: Tuple of (train, test) GroundTruthData to be sampled
      from.
    representation_function: Function that takes observations as input and
      outputs a dim_representation sized representation for each observation.
    random_state: Numpy random state used for randomness.
    holdout_dataset_name: Name of the held-out data set used for the holdout
      evaluation.
    artifact_dir: Optional path to directory where artifacts can be saved.
    num_train: Number of points used for training.
    num_test: Number of points used for testing.
    num_holdout: Number of points used for the holdout evaluation.
    batch_size: Batch size for sampling.

  Returns:
    Dictionary with scores.
  """
    del artifact_dir
    ground_truth_train_data, ground_truth_test_data = ground_truth_data
    ground_truth_holdout_data = named_data.get_named_ground_truth_data(
        holdout_dataset_name)
    scores = {}
    for train_size in num_train:
        mus_train, ys_train = utils.generate_batch_label_code(
            ground_truth_train_data, representation_function, train_size,
            random_state, batch_size)
        mus_test, ys_test = utils.generate_batch_label_code(
            ground_truth_test_data, representation_function, num_test,
            random_state, batch_size)
        mus_holdout, ys_holdout = utils.generate_batch_label_code(
            ground_truth_holdout_data, representation_function, num_holdout,
            random_state, batch_size)

        predictor_model = utils.make_predictor_fn()

        train_err, test_err, holdout_err, random_normal_err, random_uniform_err = \
          _compute_mse_loss(
            np.transpose(mus_train), ys_train, np.transpose(mus_test),
            ys_test, np.transpose(mus_holdout), ys_holdout, predictor_model)
        size_string = str(train_size)
        scores[size_string + ":mean_train_mse"] = np.mean(train_err)
        scores[size_string + ":mean_test_mse"] = np.mean(test_err)
        scores[size_string + ":mean_holdout_mse"] = np.mean(holdout_err)
        scores[size_string +
               ":mean_random_normal_mse"] = np.mean(random_normal_err)
        scores[size_string +
               ":mean_random_uniform_mse"] = np.mean(random_uniform_err)
        scores[size_string + ":min_train_mse"] = np.min(train_err)
        scores[size_string + ":min_test_mse"] = np.min(test_err)
        scores[size_string + ":min_holdout_mse"] = np.min(holdout_err)
        scores[size_string +
               ":min_random_normal_mse"] = np.min(random_normal_err)
        scores[size_string +
               ":min_random_uniform_mse"] = np.min(random_uniform_err)
        for i in range(len(train_err)):
            scores[size_string +
                   ":train_mse_factor_{}".format(i)] = train_err[i]
            scores[size_string + ":test_mse_factor_{}".format(i)] = test_err[i]
            scores[size_string +
                   ":holdout_mse_factor_{}".format(i)] = holdout_err[i]
            scores[size_string + ":random_normal_mse_factor_{}".format(
                i)] = random_normal_err[i]
            scores[size_string + ":random_uniform_mse_factor_{}".format(
                i)] = random_uniform_err[i]
    return scores
def load_dataset(self):
    return get_named_ground_truth_data(self.name)
Example #12
def __init__(self, name, seed=0):
    self.name = name
    self.random_state = np.random.RandomState(seed)
    self.dataset = get_named_ground_truth_data(self.name)
    self.iterator_len = np.prod(self.dataset.factors_num_values)
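
A hedged sketch of a `__getitem__`/`__len__` pair that such a wrapper might expose; this is not part of the original snippet and assumes channel-first float32 output:

def __getitem__(self, item):
    # Sample a random factor configuration and render the matching observation.
    factors = self.dataset.sample_factors(1, self.random_state)
    obs = self.dataset.sample_observations_from_factors(factors, self.random_state)
    # NHWC -> CHW for a single observation.
    return np.moveaxis(obs[0], 2, 0).astype(np.float32)

def __len__(self):
    return self.iterator_len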
Example #13
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name="",
             eval_pytorch=False):
    """Loads a representation TFHub module and computes disentanglement metrics.

    Args:
      model_dir: String with path to directory where the representation function
        is saved.
      output_dir: String with the path where the results should be saved.
      overwrite: Boolean indicating whether to overwrite output directory.
      evaluation_fn: Function used to evaluate the representation (see metrics/
        for examples).
      random_seed: Integer with random seed used for training.
      name: Optional string with name of the metric (can be used to name metrics).
      eval_pytorch: Boolean indicating whether to prefer the PyTorch model over
        the TFHub module when both are present.
    """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up time to keep track of elapsed time in results.
    experiment_timer = time.time()

    try:
        # Automatically set the proper data set if necessary. We replace the active
        # gin config as this will lead to a valid gin config file where the data set
        # is present.
        if gin.query_parameter("dataset.name") == "auto":
            # Obtain the dataset name from the gin config of the previous step.
            gin_config_file = os.path.join(model_dir, "results", "gin",
                                           "postprocess.gin")
            gin_dict = results.gin_dict(gin_config_file)
            with gin.unlock_config():
                gin.bind_parameter("dataset.name",
                                   gin_dict["dataset.name"].replace("'", ""))
        dataset = named_data.get_named_ground_truth_data()
    except NotFoundError:
        # If we did not train with disentanglement_lib, there is no "previous step",
        # so we'll have to rely on the environment variable.
        if gin.query_parameter("dataset.name") == "auto":
            with gin.unlock_config():
                gin.bind_parameter("dataset.name", get_dataset_name())
        dataset = named_data.get_named_ground_truth_data()

    eval_tf = True
    if eval_pytorch and os.path.exists(
            os.path.join(model_dir, 'pytorch_model.pt')):
        eval_tf = False

    if os.path.exists(os.path.join(model_dir, 'tfhub')) and eval_tf:
        # Path to TFHub module of previously trained representation.
        module_path = os.path.join(model_dir, "tfhub")
        # Evaluate results with tensorflow
        results_dict = _evaluate_with_tensorflow(module_path, evaluation_fn,
                                                 dataset, random_seed)
    elif os.path.exists(os.path.join(model_dir, 'pytorch_model.pt')):
        # Path to Pytorch JIT Module of previously trained representation.
        module_path = os.path.join(model_dir, 'pytorch_model.pt')
        # Evaluate results with pytorch
        results_dict = _evaluate_with_pytorch(module_path, evaluation_fn,
                                              dataset, random_seed)
    elif os.path.exists(os.path.join(model_dir, 'python_model.dill')):
        # Path to the dilled function
        module_path = os.path.join(model_dir, 'python_model.dill')
        # Evaluate results with numpy
        results_dict = _evaluate_with_numpy(module_path, evaluation_fn,
                                            dataset, random_seed)
    else:
        raise RuntimeError(
            "`model_dir` must contain either a pytorch or a TFHub model.")

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
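
The helpers `_evaluate_with_tensorflow`, `_evaluate_with_pytorch`, and `_evaluate_with_numpy` are not shown above; a hedged sketch of what the PyTorch path might look like, with names and tensor layout as assumptions:

import numpy as np
import torch

def _evaluate_with_pytorch(module_path, evaluation_fn, dataset, random_seed):
    """Hypothetical helper: loads a TorchScript model and runs evaluation_fn."""
    model = torch.jit.load(module_path)
    model.eval()

    def _representation_function(x):
        # Assumes NHWC numpy input and a model expecting NCHW float tensors.
        with torch.no_grad():
            t = torch.from_numpy(np.moveaxis(x, 3, 1)).float()
            return model(t).cpu().numpy()

    return evaluation_fn(dataset, _representation_function,
                         random_state=np.random.RandomState(random_seed))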
Example #14
def evaluate(model_dirs,
             output_dir,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
    """Loads a trained estimator and evaluates it according to beta-VAE metric."""
    # The name will be part of the gin config and can be used to tag results.
    del name

    # Set up time to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper dataset if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the dataset
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dirs[0], "results", "gin",
                                       "train.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            print(gin_dict["dataset.name"])
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))

    output_dir = os.path.join(output_dir)
    if tf.io.gfile.isdir(output_dir):
        tf.io.gfile.rmtree(output_dir)

    dataset = named_data.get_named_ground_truth_data()

    with contextlib.ExitStack() as stack:
        representation_functions = []
        eval_functions = [
            stack.enter_context(
                hub.eval_function_for_module(os.path.join(model_dir, "tfhub")))
            for model_dir in model_dirs
        ]
        for f in eval_functions:

            def _representation_function(x, f=f):
                def compute_gaussian_kl(z_mean, z_logvar):
                    return np.mean(
                        0.5 *
                        (np.square(z_mean) + np.exp(z_logvar) - z_logvar - 1),
                        axis=0)

                encoding = f(dict(images=x),
                             signature="gaussian_encoder",
                             as_dict=True)

                return np.array(encoding["mean"]), compute_gaussian_kl(
                    np.array(encoding["mean"]), np.array(encoding["logvar"]))

            representation_functions.append(_representation_function)

        results_dict = evaluation_fn(
            dataset,
            representation_functions,
            random_state=np.random.RandomState(random_seed))

    original_results_dir = os.path.join(model_dirs[0], "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
Example #15
def main(unused_argv):
    # Obtain the study to reproduce.
    study = reproduce.STUDIES[FLAGS.study]
    dataset_names = ["cars3d", "smallnorb"]

    for dataset_name in dataset_names:
        postprocess_config_files = sorted(study.get_postprocess_config_files())
        for beta in [1e-3, 1e-2, 0.1, 1, 10, 100, 1000]:
            # Set correct output directory.
            if FLAGS.output_directory is None:
                output_directory = os.path.join("output", "{study}",
                                                dataset_name, "{beta}")
            else:
                output_directory = FLAGS.output_directory

            # Insert model number and study name into path if necessary.
            output_directory = output_directory.format(
                beta=str(beta), study="test_benchmark-experiment-6.1")

            # Model training (if model directory is not provided).

            model_bindings, model_config_file = get_model_configs(
                beta, dataset_name)
            logging.info("Training model...")
            model_dir = os.path.join(output_directory, "model")
            model_bindings = [
                "model.name = '{}'".format(
                    os.path.basename(model_config_file)).replace(".gin",
                                                                 ""),  # ,
                # "model.model_num = {}".format(FLAGS.model_num),
            ] + model_bindings
            train.train_with_gin(model_dir, FLAGS.overwrite,
                                 [model_config_file], model_bindings)

            # We visualize reconstructions, samples and latent space traversals.
            visualize_dir = os.path.join(output_directory, "visualizations")
            visualize_model.visualize(model_dir, visualize_dir,
                                      FLAGS.overwrite)

            # We extract the different representations and save them to disk.
            random_state = np.random.RandomState(0)
            postprocess_config_files = sorted(
                study.get_postprocess_config_files())
            for config in postprocess_config_files:
                post_name = os.path.basename(config).replace(".gin", "")
                logging.info("Extracting representation %s...", post_name)
                post_dir = os.path.join(output_directory, "postprocessed",
                                        post_name)
                postprocess_bindings = [
                    "postprocess.random_seed = {}".format(
                        random_state.randint(2**16)),
                    "postprocess.name = '{}'".format(post_name)
                ]
                postprocess.postprocess_with_gin(model_dir, post_dir,
                                                 FLAGS.overwrite, [config],
                                                 postprocess_bindings)

            # Get representations and save them to disk.
            gin.parse_config_files_and_bindings(
                [], ["dataset.name = {}".format("'{}'".format(dataset_name))])
            dataset = named_data.get_named_ground_truth_data()
            factors, reps = get_representations(dataset, post_dir,
                                                dataset_name)
            pickle.dump(factors, open(os.path.join(post_dir, "factors.p"),
                                      "wb"))
            pickle.dump(reps, open(os.path.join(post_dir, "reps.p"), "wb"))
            gin.clear_config()
Example #16
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
    """Loads a representation TFHub module and computes disentanglement metrics.

  Args:
    model_dir: String with path to directory where the representation function
      is saved.
    output_dir: String with the path where the results should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    evaluation_fn: Function used to evaluate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for training.
    name: Optional string with name of the metric (can be used to name metrics).
  """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up time to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "postprocess.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()

    # Path to TFHub module of previously trained representation.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _representation_function(x):
            """Computes representation vector for input images."""
            output = f(dict(images=x),
                       signature="representation",
                       as_dict=True)
            return np.array(output["default"])

        # Computes scores of the representation based on the evaluation_fn.
        if _has_kwarg_or_kwargs(evaluation_fn, "artifact_dir"):
            artifact_dir = os.path.join(model_dir, "artifacts")
            results_dict = evaluation_fn(
                dataset,
                _representation_function,
                random_state=np.random.RandomState(random_seed),
                artifact_dir=artifact_dir)
        else:
            # Legacy code path to allow for old evaluation metrics.
            warnings.warn(
                "Evaluation function does not appear to accept an"
                " `artifact_dir` argument. This may not be compatible with "
                "future versions.", DeprecationWarning)
            results_dict = evaluation_fn(
                dataset,
                _representation_function,
                random_state=np.random.RandomState(random_seed))

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
Example #17
def visualize(model_dir,
              output_dir,
              overwrite=False,
              num_animations=5,
              num_frames=20,
              fps=10,
              num_points_irs=10000):
    """Takes trained model from model_dir and visualizes it in output_dir.

    Args:
      model_dir: Path to directory where the trained model is saved.
      output_dir: Path to output directory.
      overwrite: Boolean indicating whether to overwrite output directory.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of frames in each animation.
      fps: Integer with frame rate for the animation.
      num_points_irs: Number of points to be used for the IRS plots.
    """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Create the output directory if necessary.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # Automatically infer the activation function from gin config.
    activation_str = gin_dict["reconstruction_loss.activation"]
    if activation_str == "'logits'":
        activation = sigmoid
    elif activation_str == "'tanh'":
        activation = tanh
    else:
        raise ValueError(
            "Activation function  could not be infered from gin config.")

    dataset = named_data.get_named_ground_truth_data()
    num_pics = 64
    module_path = os.path.join(model_dir, "tfhub")

    with hub.eval_function_for_module(module_path) as f:
        # Save reconstructions.
        real_pics = dataset.sample_observations(num_pics, random_state)
        raw_pics = f(dict(images=real_pics),
                     signature="reconstructions",
                     as_dict=True)["images"]
        pics = activation(raw_pics)
        paired_pics = np.concatenate((real_pics, pics), axis=2)
        paired_pics = [
            paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
        ]
        results_dir = os.path.join(output_dir, "reconstructions")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

        # Save samples.
        def _decoder(latent_vectors):
            return f(dict(latent_vectors=latent_vectors),
                     signature="decoder",
                     as_dict=True)["images"]

        num_latent = int(gin_dict["encoder.num_latent"])
        num_pics = 64
        random_codes = random_state.normal(0, 1, [num_pics, num_latent])
        pics = activation(_decoder(random_codes))
        results_dir = os.path.join(output_dir, "sampled")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            pics, os.path.join(results_dir, "samples.jpg"))

        # Save latent traversals.
        result = f(
            dict(images=dataset.sample_observations(num_pics, random_state)),
            signature="gaussian_encoder",
            as_dict=True)
        means = result["mean"]
        logvars = result["logvar"]
        results_dir = os.path.join(output_dir, "traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        for i in range(means.shape[1]):
            pics = activation(
                latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
            file_name = os.path.join(results_dir, "traversals{}.jpg".format(i))
            visualize_util.grid_save_images([pics], file_name)

        # Save the latent traversal animations.
        results_dir = os.path.join(output_dir, "animated_traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)

        # Cycle through quantiles of a standard Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j], num_frames)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "std_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through quantiles of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j],
                    num_frames,
                    loc=loc,
                    scale=np.sqrt(total_variance))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fitted_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through [-2, 2] interval.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, -2., 2.)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fixed_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through +-2 std dev of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                scale = np.sqrt(total_variance)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, loc - 2. * scale,
                    loc + 2. * scale)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "conf_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through minmax of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, np.min(means[:, j]),
                    np.max(means[:, j]))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "minmax_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Interventional effects visualization.
        factors = dataset.sample_factors(num_points_irs, random_state)
        obs = dataset.sample_observations_from_factors(factors, random_state)
        batch_size = 64
        num_outputs = 0
        latents = []
        while num_outputs < obs.shape[0]:
            input_batch = obs[num_outputs:min(num_outputs +
                                              batch_size, obs.shape[0])]
            output_batch = f(dict(images=input_batch),
                             signature="gaussian_encoder",
                             as_dict=True)["mean"]
            latents.append(output_batch)
            num_outputs += batch_size
        latents = np.concatenate(latents)

        results_dir = os.path.join(output_dir, "interventional_effects")
        vis_all_interventional_effects(factors, latents, results_dir)

    # Finally, we clear the gin config that we have set.
    gin.clear_config()
Example #18
def visualize_supervised(supervised_model_dir,
                         trained_vae_model_dir,
                         output_dir,
                         overwrite=False):
    """Takes trained model from model_dir and visualizes it in output_dir.

  Args:
    model_dir: Path to directory where the trained model is saved.
    output_dir: Path to output directory.
    overwrite: Boolean indicating whether to overwrite output directory.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
    num_points_irs: Number of points to be used for the IRS plots.
  """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Create the output directory if necessary.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(supervised_model_dir, "results", "gin",
                                   "evaluate.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # Automatically infer the activation function from gin config.
    activation_str = gin_dict["reconstruction_loss.activation"]
    if activation_str == "'logits'":
        activation = sigmoid
    elif activation_str == "'tanh'":
        activation = tanh
    else:
        raise ValueError(
            "Activation function  could not be infered from gin config.")

    _, dataset = named_data.get_named_ground_truth_data()
    num_pics = 64
    supervised_module_path = os.path.join(supervised_model_dir, "tfhub")

    with hub.eval_function_for_module(supervised_module_path) as f:
        trained_vae_path = os.path.join(trained_vae_model_dir, "tfhub")
        with hub.eval_function_for_module(trained_vae_path) as g:

            def _representation_function(x):
                """Computes representation vector for input images."""
                output = g(dict(images=x),
                           signature="representation",
                           as_dict=True)
                return np.array(output["default"])

            # Save reconstructions.
            real_pics = dataset.sample_observations(num_pics, random_state)
            #      real_pics, _ = dataset.sample_observations_and_labels(num_pics, random_state)
            representations = _representation_function(real_pics)

        print(real_pics.shape, representations.shape)
        decoded_pics = f(dict(representations=representations),
                         signature="reconstructions",
                         as_dict=True)['images']
        pics = activation(decoded_pics)
        paired_pics = np.concatenate((real_pics, pics), axis=2)
        paired_pics = [
            paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
        ]
        results_dir = os.path.join(output_dir, "reconstructions")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

    # Finally, we clear the gin config that we have set.
    gin.clear_config()
Example #19
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             postprocess_fn=gin.REQUIRED,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             pca_components=gin.REQUIRED,
             name=""):
    """Loads a trained Gaussian encoder and decoder.

  Args:
    model_dir: String with path to directory where the model is saved.
    output_dir: String with the path where the representation should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    postprocess_fn: Function used to extract the representation (see methods.py
      for examples).
    evaluation_fn: Function used to evaluate the representation (see metrics/
      for examples).
    random_seed: Integer with random seed used for postprocessing (may be
      unused).
    pca_components: List with the numbers of PCA components; one results
      dictionary is expected per component count.
    name: Optional string with name of the representation (can be used to name
      representations).
  """
    # We do not use the variable 'name'. Instead, it can be used to name
    # representations as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "train.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()

    # Path to TFHub module of previously trained model.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _gaussian_encoder(x):
            """Encodes images using trained model."""
            # Push images through the TFHub module.
            output = f(dict(images=x),
                       signature="gaussian_encoder",
                       as_dict=True)
            # Convert to numpy arrays and return.
            return np.array(output["mean"]), np.array(output["logvar"])

        def _decoder(z):
            """Encodes images using trained model."""
            # Push images through the TFHub module.
            output = f(dict(latent_vectors=z),
                       signature="decoder",
                       as_dict=True)
            # Convert to numpy arrays and return.
            return np.array(output['images'])

        # Run the postprocessing function which returns a transformation function
        # that can be used to create the representation from the mean and log
        # variance of the Gaussian distribution given by the encoder. Also returns
        # path to a checkpoint if the transformation requires variables.
        transform_fn, transform_checkpoint_path = postprocess_fn(
            dataset, _gaussian_encoder, np.random.RandomState(random_seed),
            output_dir)

        print('\n\n\n Calculating recall')
        # Computes scores of the representation based on the evaluation_fn.
        if _has_kwarg_or_kwargs(evaluation_fn, "artifact_dir"):
            artifact_dir = os.path.join(model_dir, "artifacts")
            results_dict_list = evaluation_fn(
                dataset,
                _gaussian_encoder,
                transform_fn,
                _decoder,
                random_state=np.random.RandomState(random_seed),
                artifact_dir=artifact_dir)
        else:
            # Legacy code path to allow for old evaluation metrics.
            warnings.warn(
                "Evaluation function does not appear to accept an"
                " `artifact_dir` argument. This may not be compatible with "
                "future versions.", DeprecationWarning)
            results_dict_list = evaluation_fn(
                dataset,
                _gaussian_encoder,
                transform_fn,
                _decoder,
                random_state=np.random.RandomState(random_seed))

    # Save the results (and all previous results in the pipeline) on disk.
    for results_dict, pca_comp in list(zip(results_dict_list, pca_components)):
        results_dir = os.path.join(output_dir, "results")
        results_dict["elapsed_time"] = time.time() - experiment_timer
        filename = "evaluation_pca_{}comp".format(pca_comp)
        results.update_result_directory(results_dir, filename, results_dict)
Example #20
def train(model_dir,
          overwrite=False,
          model=gin.REQUIRED,
          training_steps=gin.REQUIRED,
          unsupervised_data_seed=gin.REQUIRED,
          supervised_data_seed=gin.REQUIRED,
          model_seed=gin.REQUIRED,
          batch_size=gin.REQUIRED,
          num_labelled_samples=gin.REQUIRED,
          train_percentage=gin.REQUIRED,
          name=""):
  """Trains the estimator and exports the snapshot and the gin config.

  The use of this function requires the gin binding 'dataset.name' to be
  specified as that determines the data set used for training.

  Args:
    model_dir: String with path to directory where model output should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    model: GaussianEncoderModel that should be trained and exported.
    training_steps: Integer with number of training steps.
    unsupervised_data_seed: Integer with random seed used for the unsupervised
      data.
    supervised_data_seed: Integer with random seed for supervised data.
    model_seed: Integer with random seed used for the model.
    batch_size: Integer with the batch size.
    num_labelled_samples: Integer with number of labelled observations for
      training.
    train_percentage: Fraction of the labelled data to use for training, in (0, 1).
    name: Optional string with name of the model (can be used to name models).
  """
  # We do not use the variable 'name'. Instead, it can be used to name results
  # as it will be part of the saved gin config.
  del name

  # Delete the output directory if necessary.
  if tf.gfile.IsDirectory(model_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(model_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Obtain the dataset.
  dataset = named_data.get_named_ground_truth_data()
  (sampled_observations,
   sampled_factors,
   factor_sizes) = semi_supervised_utils.sample_supervised_data(
       supervised_data_seed, dataset, num_labelled_samples)
  # We instantiate the model class.
  if  issubclass(model, semi_supervised_vae.BaseS2VAE):
    model = model(factor_sizes)
  else:
    model = model()

  # We create a TPUEstimator based on the provided model. This is primarily so
  # that we could switch to TPU training in the future. For now, we train
  # locally on GPUs.
  run_config = tpu_config.RunConfig(
      tf_random_seed=model_seed,
      keep_checkpoint_max=1,
      tpu_config=tpu_config.TPUConfig(iterations_per_loop=500))
  tpu_estimator = TPUEstimator(
      use_tpu=False,
      model_fn=model.model_fn,
      model_dir=model_dir,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      config=run_config)

  # Set up time to keep track of elapsed time in results.
  experiment_timer = time.time()
  # Do the actual training.
  tpu_estimator.train(
      input_fn=_make_input_fn(dataset, num_labelled_samples,
                              unsupervised_data_seed, sampled_observations,
                              sampled_factors, train_percentage),
      steps=training_steps)
  # Save model as a TFHub module.
  output_shape = named_data.get_named_ground_truth_data().observation_shape
  module_export_path = os.path.join(model_dir, "tfhub")
  gaussian_encoder_model.export_as_tf_hub(model, output_shape,
                                          tpu_estimator.latest_checkpoint(),
                                          module_export_path)

  # Save the results. The result dir will contain all the results and config
  # files that we copied along, as we progress in the pipeline. The idea is that
  # these files will be available for analysis at the end.
  results_dict = tpu_estimator.evaluate(
      input_fn=_make_input_fn(
          dataset,
          num_labelled_samples,
          unsupervised_data_seed,
          sampled_observations,
          sampled_factors,
          train_percentage,
          num_batches=num_labelled_samples,
          validation=True))
  results_dir = os.path.join(model_dir, "results")
  results_dict["elapsed_time"] = time.time() - experiment_timer
  results.update_result_directory(results_dir, "train", results_dict)
Example #21
def train(model_dir,
          overwrite=False,
          model=gin.REQUIRED,
          training_steps=gin.REQUIRED,
          random_seed=gin.REQUIRED,
          batch_size=gin.REQUIRED,
          eval_steps=1000,
          name="",
          model_num=None):
  """Trains the estimator and exports the snapshot and the gin config.

  The use of this function requires the gin binding 'dataset.name' to be
  specified as that determines the data set used for training.

  Args:
    model_dir: String with path to directory where model output should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    model: GaussianEncoderModel that should be trained and exported.
    training_steps: Integer with number of training steps.
    random_seed: Integer with random seed used for training.
    batch_size: Integer with the batch size.
    eval_steps: Optional integer with number of steps used for evaluation.
    name: Optional string with name of the model (can be used to name models).
    model_num: Optional integer with model number (can be used to identify
      models).
  """
  # We do not use the variables 'name' and 'model_num'. Instead, they can be
  # used to name results as they will be part of the saved gin config.
  del name, model_num

  # Delete the output directory if it already exists.
  if tf.gfile.IsDirectory(model_dir):
    if overwrite:
      tf.gfile.DeleteRecursively(model_dir)
    else:
      raise ValueError("Directory already exists and overwrite is False.")

  # Create a numpy random state. We will sample the random seeds for training
  # and evaluation from this.
  random_state = np.random.RandomState(random_seed)

  # Obtain the dataset.
  dataset = named_data.get_named_ground_truth_data()

  # We create a TPUEstimator based on the provided model. This is primarily so
  # that we could switch to TPU training in the future. For now, we train
  # locally on GPUs.
  run_config = tf.contrib.tpu.RunConfig(
      tf_random_seed=random_seed,
      keep_checkpoint_max=1,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=500))
  tpu_estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=False,
      model_fn=model.model_fn,
      model_dir=os.path.join(model_dir, "tf_checkpoint"),
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      config=run_config)

  # Set up time to keep track of elapsed time in results.
  experiment_timer = time.time()

  # Do the actual training.
  tpu_estimator.train(
      input_fn=_make_input_fn(dataset, random_state.randint(2**16)),
      steps=training_steps)

  # Save model as a TFHub module.
  output_shape = named_data.get_named_ground_truth_data().observation_shape
  module_export_path = os.path.join(model_dir, "tfhub")
  gaussian_encoder_model.export_as_tf_hub(model, output_shape,
                                          tpu_estimator.latest_checkpoint(),
                                          module_export_path)

  # Save the results. The result dir will contain all the results and config
  # files that we copied along, as we progress in the pipeline. The idea is that
  # these files will be available for analysis at the end.
  results_dict = tpu_estimator.evaluate(
      input_fn=_make_input_fn(
          dataset, random_state.randint(2**16), num_batches=eval_steps))
  results_dir = os.path.join(model_dir, "results")
  results_dict["elapsed_time"] = time.time() - experiment_timer
  results.update_result_directory(results_dir, "train", results_dict)
Example #22
def postprocess(model_dir,
                output_dir,
                overwrite=False,
                postprocess_fn=gin.REQUIRED,
                random_seed=gin.REQUIRED,
                name=""):
    """Loads a trained Gaussian encoder and extracts representation.

  Args:
    model_dir: String with path to directory where the model is saved.
    output_dir: String with the path where the representation should be saved.
    overwrite: Boolean indicating whether to overwrite output directory.
    postprocess_fn: Function used to extract the representation (see methods.py
      for examples).
    random_seed: Integer with random seed used for postprocessing (may be
      unused).
    name: Optional string with name of the representation (can be used to name
      representations).
  """
    # We do not use the variable 'name'. Instead, it can be used to name
    # representations as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Set up timer to keep track of elapsed time in results.
    experiment_timer = time.time()

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    if gin.query_parameter("dataset.name") == "auto":
        # Obtain the dataset name from the gin config of the previous step.
        gin_config_file = os.path.join(model_dir, "results", "gin",
                                       "train.gin")
        gin_dict = results.gin_dict(gin_config_file)
        with gin.unlock_config():
            gin.bind_parameter("dataset.name",
                               gin_dict["dataset.name"].replace("'", ""))
    dataset = named_data.get_named_ground_truth_data()

    # Path to TFHub module of previously trained model.
    module_path = os.path.join(model_dir, "tfhub")
    with hub.eval_function_for_module(module_path) as f:

        def _gaussian_encoder(x):
            """Encodes images using trained model."""
            # Push images through the TFHub module.
            output = f(dict(images=x),
                       signature="gaussian_encoder",
                       as_dict=True)
            # Convert to numpy arrays and return.
            return {key: np.array(values) for key, values in output.items()}

        # Run the postprocessing function which returns a transformation function
        # that can be used to create the representation from the mean and log
        # variance of the Gaussian distribution given by the encoder. Also returns
        # path to a checkpoint if the transformation requires variables.
        transform_fn, transform_checkpoint_path = postprocess_fn(
            dataset, _gaussian_encoder, np.random.RandomState(random_seed),
            output_dir)

        # Takes the "gaussian_encoder" signature, extracts the representation and
        # then saves under the signature "representation".
        tfhub_module_dir = os.path.join(output_dir, "tfhub")
        convolute_hub.convolute_and_save(module_path, "gaussian_encoder",
                                         tfhub_module_dir, transform_fn,
                                         transform_checkpoint_path,
                                         "representation")

    # We first copy over all the prior results and configs.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict = dict(elapsed_time=time.time() - experiment_timer)
    results.update_result_directory(results_dir, "postprocess", results_dict,
                                    original_results_dir)