Exemplo n.º 1
0
 def test_generator_inference(self):
     """Run a single forward pass of the ResNet CycleGAN generator."""
     zero_batch = tf.zeros([2, 32, 32, 3])
     generated, _ = generator.cyclegan_generator_resnet(zero_batch)
     with self.cached_session() as session:
         # Variables created by the generator must be initialized before the
         # output tensor can be evaluated.
         init_op = tf.compat.v1.global_variables_initializer()
         session.run(init_op)
         session.run(generated)
Exemplo n.º 2
0
  def input_and_output_same_shape(self, kernel_size):
    """Check that the generator output shape equals its input shape.

    Args:
      kernel_size: Value forwarded as the `kernel_size` argument of
        `generator.cyclegan_generator_resnet`.
    """
    if tf.executing_eagerly():
      # tf.compat.v1.placeholder() is not compatible with eager execution
      # (it raises RuntimeError), so this static-shape check only runs in
      # graph mode.
      return
    img_batch = tf.compat.v1.placeholder(tf.float32, shape=[None, 32, 32, 3])
    output_img_batch, _ = generator.cyclegan_generator_resnet(
        img_batch, kernel_size=kernel_size)

    self.assertAllEqual(img_batch.shape.as_list(),
                        output_img_batch.shape.as_list())
Exemplo n.º 3
0
    def test_generator_unknown_batch_dim(self):
        """Generator output keeps a partially-unknown input shape intact."""
        # Placeholders only exist in graph mode, so the check is skipped
        # entirely when eager execution is active.
        if not tf.executing_eagerly():
            unknown_dims_img = tf.compat.v1.placeholder(
                tf.float32, shape=[None, 32, None, 3])
            generated, _ = generator.cyclegan_generator_resnet(
                unknown_dims_img)
            self.assertAllEqual([None, 32, None, 3],
                                generated.shape.as_list())
Exemplo n.º 4
0
def generator(input_images):
    """Thin wrapper around CycleGAN generator to conform to the TF-GAN API.

    Args:
      input_images: A batch of images to translate. Images should be
        normalized already. Shape is [batch, height, width, channels].

    Returns:
      Returns generated image batch.

    Raises:
      ValueError: If shape of last dimension (channels) is not defined. Also
        raised by `assert_has_rank` when `input_images` is not rank 4.
    """
    input_images.shape.assert_has_rank(4)
    dims = input_images.shape.as_list()
    if dims[-1] is None:
        raise ValueError(
            'Last dimension shape must be known but is None: %s' % dims)
    # The generator returns a pair; only the image batch is needed here.
    translated, _ = gmodule.cyclegan_generator_resnet(input_images)
    return translated
Exemplo n.º 5
0
def generator(input_images):
    """Thin wrapper around CycleGAN generator to conform to the TF-GAN API.

    Compared with the plain CycleGAN generator, this variant (the version
    used for DME) adds a 1x1 convolutional path from the input to the output
    and models the result as a residual added on top of the input.

    Args:
      input_images: A batch of images to translate. Images should be
        normalized already. Shape is [batch, height, width, channels].

    Returns:
      Generated image batch: `input_images` plus a learned residual with the
      same number of channels.

    Raises:
      ValueError: If shape of last dimension (channels) is not defined. Also
        raised by `assert_has_rank` when `input_images` is not rank 4.
    """
    input_images.shape.assert_has_rank(4)
    input_size = input_images.shape.as_list()
    channels = input_size[-1]
    if channels is None:
        raise ValueError('Last dimension shape must be known but is None: %s' %
                         input_size)
    output_images, _ = gmodule.cyclegan_generator_resnet(input_images,
                                                         tanh_linear_slope=0.1)
    # Optionally add image to summaries.
    # tf.summary.image('generator_preconcat_residue', output_images)

    # Difference between CycleGAN and the version used for DME:
    # 1. We have a 1x1 convolutional path from the input to the output: the
    #    generator output and the raw input are stacked on the channel axis
    #    and projected back to `channels` feature maps.
    concat_images = tf.concat([output_images, input_images], axis=3)
    # `tf.layers` only exists in TF1; `tf.compat.v1.layers.conv2d` is the same
    # function in TF1 (same default layer name, so variable names are
    # unchanged) and is still available in TF2.
    residual = tf.compat.v1.layers.conv2d(concat_images,
                                          channels, [1, 1],
                                          activation=None)
    # Optionally add image to summaries.
    # tf.summary.image('generator_residue', residual)

    # 2. We model this function as a residual.
    return residual + input_images
Exemplo n.º 6
0
 def test_generator_graph(self, shape):
     """Generator must preserve the shape of small and non-square inputs."""
     ones_batch = tf.ones(shape)
     result, _ = generator.cyclegan_generator_resnet(ones_batch)
     result_shape = result.shape.as_list()
     self.assertAllEqual(shape, result_shape)
Exemplo n.º 7
0
    def test_generator_unknown_batch_dim(self):
        """Check that generator can take unknown batch dimension inputs."""
        if tf.executing_eagerly():
            # tf.placeholder() is not compatible with eager execution.
            return
        img = tf.compat.v1.placeholder(tf.float32, shape=[None, 32, None, 3])
        output_imgs, _ = generator.cyclegan_generator_resnet(img)

        # Unknown (None) dimensions in the input must stay unknown in the
        # output; known ones must be preserved.
        self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list())