Ejemplo n.º 1
0
    def define_model(self, images_x, images_y):
        """Defines a CycleGAN model that maps between images_x and images_y.

        Whether identity outputs are built is controlled by the instance flag
        `self._use_identity_loss` (not by an argument):

        * truthy: the model is built with `cyclegan_model_with_identity`, which
          also computes the identity outputs used by an identity-mapping loss;
        * falsy: the plain `tfgan.cyclegan_model` is used.

        Both variants use `_shadowdata_generator_model` and
        `_shadowdata_discriminator_model` as the network functions.

        Args:
          images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
          images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

        Returns:
          The object returned by `cyclegan_model_with_identity` when
          `self._use_identity_loss` is set, otherwise a `CycleGANModel`
          namedtuple from `tfgan.cyclegan_model`.
        """
        # Both builders take the same keyword arguments, so select the builder
        # once and keep a single call site.
        build_model = (cyclegan_model_with_identity
                       if self._use_identity_loss
                       else tfgan.cyclegan_model)
        cyclegan_model = build_model(
            generator_fn=_shadowdata_generator_model,
            discriminator_fn=_shadowdata_discriminator_model,
            data_x=images_x,
            data_y=images_y)

        # NOTE(review): image summaries for generated images
        # (`tfgan.eval.add_cyclegan_image_summaries`) are deliberately not
        # added here; re-enable them if TensorBoard image summaries are needed.
        return cyclegan_model
Ejemplo n.º 2
0
def _define_model(images_x, images_y):
    """Build a CycleGAN model translating between two image sets.

    The generator and discriminator come from `networks.generator` and
    `networks.discriminator`. As a side effect, image summaries for the
    generated outputs are registered through
    `tfgan.eval.add_cyclegan_image_summaries`.

    Args:
      images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
      images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

    Returns:
      A `CycleGANModel` namedtuple.
    """
    model = tfgan.cyclegan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        data_x=images_x,
        data_y=images_y,
    )
    tfgan.eval.add_cyclegan_image_summaries(model)
    return model
Ejemplo n.º 3
0
def create_callable_cyclegan_model():
  """Build a CycleGAN model from fresh `Generator` / `Discriminator` objects.

  Domain X is a [1, 2] tensor of zeros and domain Y a [1, 2] tensor of ones.
  """
  generator = Generator()
  discriminator = Discriminator()
  inputs_x = tf.zeros([1, 2])
  inputs_y = tf.ones([1, 2])
  return tfgan.cyclegan_model(
      generator, discriminator, data_x=inputs_x, data_y=inputs_y)
Ejemplo n.º 4
0
def cyclegan_model_with_identity(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # data X and Y.
        data_x,
        data_y,
        # Optional scopes.
        generator_scope="Generator",
        discriminator_scope="Discriminator",
        model_x2y_scope="ModelX2Y",
        model_y2x_scope="ModelY2X",
        # Options.
        check_shapes=True):
    """Returns a CycleGAN model plus the outputs needed for an identity loss.

    Builds the standard `tfgan.cyclegan_model`. It then runs each generator
    once more, with shared weights, on images that are already in that
    generator's *output* domain. This is the identity-mapping term of
    https://arxiv.org/abs/1703.10593 (Sec. 5.2),
    ||G(y) - y|| + ||F(x) - x||, where G: X->Y and F: Y->X:

      identity_x = F(data_x)  (Y->X generator on X images; should stay ~ x)
      identity_y = G(data_y)  (X->Y generator on Y images; should stay ~ y)

    Args:
      generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
      data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created. Defaults to 'Generator'.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created. Defaults to
        'Discriminator'.
      model_x2y_scope: Optional variable scope for model x2y variables. Defaults
        to 'ModelX2Y'.
      model_y2x_scope: Optional variable scope for model y2x variables. Defaults
        to 'ModelY2X'.
      check_shapes: If `True`, check that generator produces Tensors that are the
        same shape as `data_x` (`data_y`). Otherwise, skip this check.

    Returns:
      A `CycleGANModelWithIdentity` built from the `model_x2y`, `model_y2x`,
      `reconstructed_x` and `reconstructed_y` of the underlying CycleGAN model,
      with `identity_x` and `identity_y` attached as attributes.

    Raises:
      ValueError: If `check_shapes` is True and `data_x` or the generator output
        does not have the same shape as `data_y`.
      ValueError: If TF is executing eagerly.
    """
    original_model = tfgan.cyclegan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        data_x=data_x,
        data_y=data_y,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        model_x2y_scope=model_x2y_scope,
        model_y2x_scope=model_y2x_scope,
        check_shapes=check_shapes)

    # Identity mapping: each generator must see images already in its output
    # domain. Running the X->Y generator on X images (as was done before)
    # yields a translation, not an identity, so the scopes are chosen so that
    # the Y->X generator sees data_x and the X->Y generator sees data_y.
    # reuse=True shares the weights created by `tfgan.cyclegan_model` above.
    model_y2x = original_model.model_y2x
    with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
        identity_x = model_y2x.generator_fn(data_x)
    model_x2y = original_model.model_x2y
    with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
        identity_y = model_x2y.generator_fn(data_y)

    model_w_identity = CycleGANModelWithIdentity(
        model_x2y=original_model.model_x2y,
        model_y2x=original_model.model_y2x,
        reconstructed_x=original_model.reconstructed_x,
        reconstructed_y=original_model.reconstructed_y)
    # NOTE(review): identity outputs are attached after construction, which
    # assumes `CycleGANModelWithIdentity` instances accept new attributes
    # (e.g. not a slotted namedtuple) -- confirm against its definition.
    model_w_identity.identity_x = identity_x
    model_w_identity.identity_y = identity_y

    return model_w_identity
Ejemplo n.º 5
0
def create_cyclegan_model():
  """Build a CycleGAN model from the `generator_model` / `discriminator_model` fns.

  Domain X is a [1, 2] tensor of zeros and domain Y a [1, 2] tensor of ones.
  """
  inputs_x = tf.zeros([1, 2])
  inputs_y = tf.ones([1, 2])
  return tfgan.cyclegan_model(
      generator_model, discriminator_model, data_x=inputs_x, data_y=inputs_y)