Ejemplo n.º 1
0
def add_gan_model_image_summaries(gan_model,
                                  grid_size=4,
                                  model_summaries=True):
    """Adds image-grid summaries for real and generated images.

    Builds one `grid_size` x `grid_size` image grid from the leading
    `grid_size**2` examples of `gan_model.real_data` and another from
    `gan_model.generated_data`. Each grid is recorded with
    `tf.compat.v1.summary.image` (one output each), under the names
    'real_data' and 'generated_data' respectively.

    Args:
      gan_model: A GANModel tuple. CycleGANModels are rejected; use
        `add_cyclegan_image_summaries` for those.
      grid_size: Number of images along each side of the grid. The grid is
        built from the first `grid_size**2` examples of each tensor.
      model_summaries: If True, also add summaries of the model via
        `add_gan_model_summaries`.

    Raises:
      ValueError: If `gan_model` is a CycleGANModel, or if real and generated
        data aren't images.
    """
    if isinstance(gan_model, namedtuples.CycleGANModel):
        raise ValueError(
            '`add_gan_model_image_summaries` does not take CycleGANModels. Please '
            'use `add_cyclegan_image_summaries` instead.')
    # Validate both tensors up front so that an invalid generated tensor is
    # reported before any summary (e.g. the real-data grid) has been emitted.
    _assert_is_image(gan_model.real_data)
    _assert_is_image(gan_model.generated_data)

    num_images = grid_size**2
    # Real and generated data receive identical treatment; only the summary
    # name and the source tensor differ, so handle both in one loop.
    # NOTE(review): eval_utils.image_grid presumably needs exactly
    # grid_size**2 images, i.e. a batch of at least that size — confirm.
    for summary_name, images in (('real_data', gan_model.real_data),
                                 ('generated_data', gan_model.generated_data)):
        # Read the static shape once. The indexing below treats it as
        # [batch, height, width, channels].
        static_shape = images.shape.as_list()
        tf.compat.v1.summary.image(
            summary_name,
            eval_utils.image_grid(
                images[:num_images],
                grid_shape=(grid_size, grid_size),
                image_shape=static_shape[1:3],
                num_channels=static_shape[3]),
            max_outputs=1)

    if model_summaries:
        add_gan_model_summaries(gan_model)
Ejemplo n.º 2
0
 def test_image_grid(self):
     """Smoke test: a 5x5 grid can be built from 25 zero-valued images.

     The call passes only `input_tensor` and `grid_shape`, so `image_grid`
     falls back to its own defaults for the remaining arguments. The test
     passes as long as no exception is raised.
     """
     blank_batch = tf.zeros([25, 32, 32, 3])
     eval_utils.image_grid(input_tensor=blank_batch, grid_shape=(5, 5))