Exemplo n.º 1
0
 def _add_comparison_summary(gan_model, reconstructions):
     """Writes an 'image_comparison' summary of input, output and reconstruction.

     Takes at most the first example of the generator inputs, the generated
     data and `reconstructions`, in that order. It tiles them with a column
     count equal to the number of tiles, so they sit side by side in one row.
     """
     sources = (gan_model.generator_inputs,
                gan_model.generated_data,
                reconstructions)
     tiles = []
     for batch in sources:
         # Keep only the leading example of each batch.
         tiles += tf.unstack(batch[:1])
     grid = eval_utils.image_reshaper(tiles, num_cols=len(tiles))
     tf.summary.image('image_comparison', grid, max_outputs=1)
Exemplo n.º 2
0
def add_image_comparison_summaries(gan_model, num_comparisons=2,
                                   display_diffs=False):
  """Adds image summaries to compare triplets of images.

  The first image is the generator input, the second is the generator output,
  and the third is the real data. This style of comparison is useful for
  image translation problems, where the generator input is a corrupted image,
  the generator output is the reconstruction, and the real data is the target.

  All tiles go into a single 'image_comparison' summary. The tile list is
  ordered as follows: all generator inputs, then all generator outputs, then
  all real images, then (optionally) all diffs. The grid has one column per
  compared example. Assuming `eval_utils.image_reshaper` fills the grid row
  by row, each group therefore occupies one row and each column is a triplet.

  Args:
    gan_model: A GANModel tuple.
    num_comparisons: The maximum number of image triplets to display. If the
      batch holds fewer examples, only that many columns are drawn.
    display_diffs: Also display the absolute difference between generated and
      target images, as an extra row.

  Raises:
    ValueError: If real data, generated data, and generator inputs aren't
      images.
    ValueError: If the generator input, real, and generated data aren't all the
      same size.
  """
  _assert_is_image(gan_model.generator_inputs)
  _assert_is_image(gan_model.generated_data)
  _assert_is_image(gan_model.real_data)

  gan_model.generated_data.shape.assert_is_compatible_with(
      gan_model.generator_inputs.shape)
  gan_model.real_data.shape.assert_is_compatible_with(
      gan_model.generated_data.shape)

  # Unstack each source once; the diff row below reuses these lists.
  input_list = tf.unstack(gan_model.generator_inputs[:num_comparisons])
  generated_list = tf.unstack(gan_model.generated_data[:num_comparisons])
  real_list = tf.unstack(gan_model.real_data[:num_comparisons])

  image_list = input_list + generated_list + real_list
  if display_diffs:
    # The difference is taken on `_to_float` outputs (presumably a float cast,
    # which avoids unsigned wraparound). It is then cast back to the generated
    # data's dtype so every tile in `image_list` shares one dtype and the
    # tiles can be combined into a single grid.
    image_list.extend(
        tf.cast(tf.abs(_to_float(generated) - _to_float(real)),
                generated.dtype)
        for generated, real in zip(generated_list, real_list))

  # Slicing can yield fewer than `num_comparisons` examples when the batch is
  # small. Using the count actually obtained keeps each group of tiles aligned
  # to exactly one row of the grid.
  num_cols = len(input_list)

  # Reshape image and display.
  tf.summary.image(
      'image_comparison',
      eval_utils.image_reshaper(image_list, num_cols=num_cols),
      max_outputs=1)
Exemplo n.º 3
0
    def _build_image(image):
        """Builds one summary grid for `image` translated into every domain.

        The tile list starts with the original image, followed by the
        generator output for each of the `num_domains` one-hot targets. When
        `display_diffs` is set, it continues with a zero placeholder and then
        the signed difference (generated minus input) for each domain. Tiles
        are laid out `num_domains + 1` per row. Returns the grid with its
        leading size-1 axis removed.
        """
        # num_domains identical copies of the input. The 4-element tile
        # multiples mean `image` must be rank 3.
        batch = tf.tile(tf.expand_dims(image, axis=0), [num_domains, 1, 1, 1])

        # One-hot target labels for domains 0, 1, ..., num_domains - 1.
        labels = tf.one_hot(list(range(num_domains)), num_domains)

        def _unbatch(stacked):
            # Split into num_domains pieces along axis 0, then drop that axis.
            return [tf.squeeze(piece, axis=0)
                    for piece in tf.split(stacked, num_domains)]

        with tf.compat.v1.variable_scope(stargan_model.generator_scope,
                                         reuse=True):
            # Reuse the already-built generator variables.
            translated = stargan_model.generator_fn(batch, labels)
            tiles = [image] + _unbatch(translated)

            if display_diffs:
                deltas = _unbatch(translated - batch)
                # Placeholder under the original image so the diff tiles
                # line up with their generated counterparts.
                tiles.append(tf.zeros_like(image))
                tiles += deltas

            grid = eval_utils.image_reshaper(tiles, num_cols=num_domains + 1)

        # image_reshaper returns a tensor with a leading axis of size 1.
        return tf.squeeze(grid, axis=0)
Exemplo n.º 4
0
 def test_image_reshaper_image(self):
     """25 images tiled 2 per row need ceil(25 / 2) = 13 rows."""
     num_images, side, channels, cols = 25, 32, 3, 2
     rows = (num_images + cols - 1) // cols  # ceiling division -> 13
     grid = eval_utils.image_reshaper(
         images=tf.zeros([num_images, side, side, channels]), num_cols=cols)
     grid.shape.assert_is_compatible_with(
         [1, rows * side, cols * side, channels])