示例#1
0
def visualize_dataset(dataset_name,
                      output_path,
                      num_animations=5,
                      num_frames=20,
                      fps=10):
    """Saves still grids and GIF animations for a named ground-truth data set.

    All files are written to `output_path/dataset_name`; that directory is
    created when it does not exist yet. A `RandomState` seeded with 0 drives
    all sampling.

    Stills: for every factor index, 16 factor vectors are sampled, every other
    factor is overwritten with the value of the first sample, and the
    resulting observations are saved as `variations_of_factor<i>.png`.

    Animations: for each sampled base factor vector, every factor in turn is
    replaced by the output of `visualize_util.cycle_factor` over `num_frames`
    rows, and the per-factor observation sequences are saved together as
    `animation<i>.gif`.

    Args:
      dataset_name: String with name of dataset as defined in named_data.py.
      output_path: String with path in which to create the visualizations.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of frames in each animation.
      fps: Integer with frame rate for the animation.
    """
    ground_truth = named_data.get_named_ground_truth_data(dataset_name)
    rng = np.random.RandomState(0)

    # Make sure the per-dataset output folder exists.
    target_dir = os.path.join(output_path, dataset_name)
    already_there = gfile.IsDirectory(target_dir)
    if not already_there:
        gfile.MakeDirs(target_dir)

    # Still images: vary a single factor, pin all the others to sample 0.
    for factor_idx in range(ground_truth.num_factors):
        sampled = ground_truth.sample_factors(16, rng)
        pinned = [
            other for other in range(ground_truth.num_factors)
            if other != factor_idx
        ]
        sampled[:, pinned] = sampled[0, pinned]
        observations = ground_truth.sample_observations_from_factors(
            sampled, rng)
        still_file = os.path.join(
            target_dir, "variations_of_factor%s.png" % factor_idx)
        visualize_util.grid_save_images(observations, still_file)

    # Animations: cycle each factor of a sampled base vector in turn.
    for anim_idx in range(num_animations):
        anchor = ground_truth.sample_factors(1, rng)
        sequences = []
        for column, cardinality in enumerate(ground_truth.factors_num_values):
            tiled = np.repeat(anchor, num_frames, axis=0)
            tiled[:, column] = visualize_util.cycle_factor(
                anchor[0, column], cardinality, num_frames)
            sequences.append(
                ground_truth.sample_observations_from_factors(tiled, rng))
        gif_file = os.path.join(target_dir, "animation%d.gif" % anim_idx)
        visualize_util.save_animation(np.array(sequences), gif_file, fps)
def visualize_weakly_supervised_dataset(data,
                                        path,
                                        num_animations=10,
                                        num_frames=20,
                                        fps=10):
    """Saves one GIF of factor trajectories produced by `simple_dynamics`.

    For each animation a single factor vector is sampled and rendered, then
    `simple_dynamics` is applied `num_frames` times; the first element of its
    return value becomes the next factor vector and is rendered as well. Each
    animation therefore holds `num_frames + 1` frames. All animations are
    written together to `path/animation.gif`. A `RandomState` seeded with 0
    drives all sampling.

    Args:
      data: Data set object; must provide `sample_factors` and
        `sample_observations_from_factors`, and is also passed to
        `simple_dynamics`.
      path: String with path in which to create the visualizations. The
        directory is created if it does not exist.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of dynamics steps per animation.
      fps: Integer with frame rate for the animation.
    """
    rng = np.random.RandomState(0)

    # Create output folder if necessary.
    folder_present = tf.compat.v1.gfile.IsDirectory(path)
    if not folder_present:
        tf.compat.v1.gfile.MakeDirs(path)

    def _render(factor_vector):
        # Observation for a batch of one factor vector, batch axis removed.
        observation = data.sample_observations_from_factors(
            factor_vector, rng)
        return np.squeeze(observation, axis=0)

    # Build one frame sequence per animation.
    sequences = []
    for _ in range(num_animations):
        state = data.sample_factors(1, rng)
        frames = [_render(state)]
        for _step in range(num_frames):
            # Only the first returned value is used; the second is discarded.
            state, _discarded = simple_dynamics(state, data, rng)
            frames.append(_render(state))
        sequences.append(frames)

    gif_file = os.path.join(path, "animation.gif")
    visualize_util.save_animation(np.array(sequences), gif_file, fps)
示例#3
0
def visualize(model_dir,
              output_dir,
              overwrite=False,
              num_animations=5,
              num_frames=20,
              fps=10,
              num_points_irs=10000):
    """Takes trained model from model_dir and visualizes it in output_dir.

    The following sub-directories of `output_dir` are used:
      reconstructions: sampled observations concatenated with their
        activated reconstructions, saved as one grid.
      sampled: grid of activated images decoded from standard-normal codes.
      traversals: one latent traversal image per file.
      animated_traversals: GIFs in which one latent dimension at a time is
        cycled (standard Gaussian, fitted Gaussian, fixed [-2, 2] interval,
        +-2 std dev interval and min/max interval variants).
      interventional_effects: path handed to
        `vis_all_interventional_effects`.

    The gin parameter "dataset.name" is bound from the training gin config.
    The whole gin config is cleared before this function returns, including
    when an error is raised after the parameter has been bound.

    Args:
      model_dir: Path to directory where the trained model is saved.
      output_dir: Path to output directory.
      overwrite: Boolean indicating whether to overwrite output directory.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of frames in each animation.
      fps: Integer with frame rate for the animation.
      num_points_irs: Number of points to be used for the IRS plots.

    Raises:
      ValueError: If `output_dir` already exists and `overwrite` is False, or
        if the activation function cannot be inferred from the gin config.
    """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Refuse to touch an existing output directory unless asked to.
    if tf.gfile.IsDirectory(output_dir):
        if not overwrite:
            raise ValueError(
                "Directory already exists and overwrite is False.")
        tf.gfile.DeleteRecursively(output_dir)

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # From here on the gin config has been modified, so make sure it is
    # cleared again even if one of the steps below raises.
    try:
        # Automatically infer the activation function from gin config.
        activation_str = gin_dict["reconstruction_loss.activation"]
        if activation_str == "'logits'":
            activation = sigmoid
        elif activation_str == "'tanh'":
            activation = tanh
        else:
            raise ValueError(
                "Activation function  could not be infered from gin config.")

        dataset = named_data.get_named_ground_truth_data()
        num_pics = 64
        module_path = os.path.join(model_dir, "tfhub")

        def _make_results_dir(name):
            """Returns output_dir/name, creating the directory if missing."""
            directory = os.path.join(output_dir, name)
            if not gfile.IsDirectory(directory):
                gfile.MakeDirs(directory)
            return directory

        with hub.eval_function_for_module(module_path) as f:

            def _decoder(latent_vectors):
                """Runs the "decoder" signature and returns its images."""
                return f(dict(latent_vectors=latent_vectors),
                         signature="decoder",
                         as_dict=True)["images"]

            def _encoder(images):
                """Runs the "gaussian_encoder" signature on images."""
                return f(dict(images=images),
                         signature="gaussian_encoder",
                         as_dict=True)

            # Save reconstructions.
            real_pics = dataset.sample_observations(num_pics, random_state)
            raw_pics = f(dict(images=real_pics),
                         signature="reconstructions",
                         as_dict=True)["images"]
            pics = activation(raw_pics)
            paired_pics = np.concatenate((real_pics, pics), axis=2)
            paired_pics = [
                paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
            ]
            results_dir = _make_results_dir("reconstructions")
            visualize_util.grid_save_images(
                paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

            # Save samples.
            num_latent = int(gin_dict["encoder.num_latent"])
            random_codes = random_state.normal(0, 1, [num_pics, num_latent])
            pics = activation(_decoder(random_codes))
            results_dir = _make_results_dir("sampled")
            visualize_util.grid_save_images(
                pics, os.path.join(results_dir, "samples.jpg"))

            # Save latent traversals.
            result = _encoder(
                dataset.sample_observations(num_pics, random_state))
            means = result["mean"]
            logvars = result["logvar"]
            results_dir = _make_results_dir("traversals")
            # NOTE(review): the loop bound is means.shape[1] while `i` indexes
            # the first axis of `means`; presumably this relies on the number
            # of latent dimensions not exceeding num_pics - confirm intent.
            for i in range(means.shape[1]):
                pics = activation(
                    latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
                file_name = os.path.join(results_dir,
                                         "traversals{}.jpg".format(i))
                visualize_util.grid_save_images([pics], file_name)

            # Save the latent traversal animations.
            animation_dir = _make_results_dir("animated_traversals")

            def _fitted_gaussian(j):
                """Returns (loc, scale) fitted to column j of the encodings."""
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                return loc, np.sqrt(total_variance)

            def _save_cycle_animations(file_pattern, cycle_fn):
                """Saves one GIF per base code, cycling each dimension.

                `cycle_fn(value, j)` must return the `num_frames` values that
                replace column `j` of the repeated base code; `file_pattern`
                is %-formatted with the index of the base code.
                """
                for i, base_code in enumerate(means[:num_animations]):
                    images = []
                    for j in range(base_code.shape[0]):
                        code = np.repeat(np.expand_dims(base_code, 0),
                                         num_frames,
                                         axis=0)
                        code[:, j] = cycle_fn(base_code[j], j)
                        images.append(np.array(activation(_decoder(code))))
                    filename = os.path.join(animation_dir, file_pattern % i)
                    visualize_util.save_animation(np.array(images), filename,
                                                  fps)

            def _std_gaussian_cycle(value, j):
                del j  # The standard Gaussian does not depend on the column.
                return visualize_util.cycle_gaussian(value, num_frames)

            def _fitted_gaussian_cycle(value, j):
                loc, scale = _fitted_gaussian(j)
                return visualize_util.cycle_gaussian(value,
                                                     num_frames,
                                                     loc=loc,
                                                     scale=scale)

            def _fixed_interval_cycle(value, j):
                del j  # The interval is the same for every column.
                return visualize_util.cycle_interval(value, num_frames, -2.,
                                                     2.)

            def _conf_interval_cycle(value, j):
                loc, scale = _fitted_gaussian(j)
                return visualize_util.cycle_interval(value, num_frames,
                                                     loc - 2. * scale,
                                                     loc + 2. * scale)

            def _minmax_interval_cycle(value, j):
                return visualize_util.cycle_interval(value, num_frames,
                                                     np.min(means[:, j]),
                                                     np.max(means[:, j]))

            # Cycle through quantiles of a standard Gaussian.
            _save_cycle_animations("std_gaussian_cycle%d.gif",
                                   _std_gaussian_cycle)
            # Cycle through quantiles of a fitted Gaussian.
            _save_cycle_animations("fitted_gaussian_cycle%d.gif",
                                   _fitted_gaussian_cycle)
            # Cycle through [-2, 2] interval.
            _save_cycle_animations("fixed_interval_cycle%d.gif",
                                   _fixed_interval_cycle)
            # Cycle linearly through +-2 std dev of a fitted Gaussian.
            _save_cycle_animations("conf_interval_cycle%d.gif",
                                   _conf_interval_cycle)
            # Cycle linearly through minmax of a fitted Gaussian.
            _save_cycle_animations("minmax_interval_cycle%d.gif",
                                   _minmax_interval_cycle)

            # Interventional effects visualization.
            factors = dataset.sample_factors(num_points_irs, random_state)
            obs = dataset.sample_observations_from_factors(
                factors, random_state)
            # Encode in fixed-size batches; the last slice may be shorter.
            batch_size = 64
            latents = [
                _encoder(obs[start:start + batch_size])["mean"]
                for start in range(0, obs.shape[0], batch_size)
            ]
            latents = np.concatenate(latents)

            results_dir = os.path.join(output_dir, "interventional_effects")
            vis_all_interventional_effects(factors, latents, results_dir)
    finally:
        # Finally, we clear the gin config that we have set.
        gin.clear_config()
示例#4
0
 def test_save_animation(self):
   """Smoke test: save_animation runs on two identical all-ones stacks."""
   temp_dir = self.create_tempdir().full_path
   gif_path = os.path.join(temp_dir, "animation.gif")
   # Array of ones with shape (18, 128, 256, 3), passed twice below.
   frames = np.ones(shape=(18, 128, 256, 3), dtype=np.float32)
   visualize_util.save_animation([frames, frames], gif_path, fps=18)