def get_cyclegan_model():
    """Build a `CycleGANModel` fixture from two GAN models and zero tensors.

    Each direction's GAN model is created by `get_gan_model()` inside its own
    variable scope ('x2y' first, then 'y2x'), so the two models do not share
    variable names. Both reconstructions are all-zero tensors of shape
    [3, 30, 35, 6].

    Returns:
      A `namedtuples.CycleGANModel`.
    """
    models_by_scope = {}
    # Order matters: the 'x2y' scope is entered before the 'y2x' scope.
    for scope_name in ('x2y', 'y2x'):
        with variable_scope.variable_scope(scope_name):
            models_by_scope[scope_name] = get_gan_model()

    recon_shape = [3, 30, 35, 6]
    return namedtuples.CycleGANModel(
        model_x2y=models_by_scope['x2y'],
        model_y2x=models_by_scope['y2x'],
        reconstructed_x=array_ops.zeros(recon_shape),
        reconstructed_y=array_ops.zeros(recon_shape))
Example 2
0
 def test_correct_loss(self):
     """Check `cycle_consistency_loss` on fixed reconstruction tensors."""
     # Fixed float32 reconstructions paired with the test's fixture models.
     recon_x = constant_op.constant([9, 8], dtype=dtypes.float32)
     recon_y = constant_op.constant([7, 2], dtype=dtypes.float32)
     cyclegan = namedtuples.CycleGANModel(
         model_x2y=self._model_x2y,
         model_y2x=self._model_y2x,
         reconstructed_x=recon_x,
         reconstructed_y=recon_y)
     consistency_loss = tfgan_losses.cycle_consistency_loss(cyclegan)

     with self.test_session(use_gpu=True):
         # Initialize any variables the fixture models created before evaluating.
         variables.global_variables_initializer().run()
         self.assertNear(5.0, consistency_loss.eval(), 1e-5)
Example 3
0
def cyclegan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # data X and Y.
        data_x,
        data_y,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        model_x2y_scope='ModelX2Y',
        model_y2x_scope='ModelY2X',
        # Options.
        check_shapes=True):
    """Returns a CycleGAN model's outputs and variables.

    See https://arxiv.org/abs/1703.10593 for more details.

    Two `gan_model`s are built: one maps X to Y (under `model_x2y_scope`) and
    one maps Y to X (under `model_y2x_scope`). Each model's generated data is
    then passed back through the opposite model's generator, reusing that
    generator's variables, to produce the cycle reconstructions.

    Args:
      generator_fn: A python lambda that takes `data_x` or `data_y` as inputs
        and returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
      data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created. Defaults to
        'Generator'.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created. Defaults to
        'Discriminator'.
      model_x2y_scope: Optional variable scope for model x2y variables.
        Defaults to 'ModelX2Y'.
      model_y2x_scope: Optional variable scope for model y2x variables.
        Defaults to 'ModelY2X'.
      check_shapes: If `True`, check that the generators produce Tensors that
        are the same shape as `data_y` (`data_x`). Also check that the cycle
        reconstructions are shape-compatible with the data they reconstruct.
        Otherwise, skip these checks.

    Returns:
      A `CycleGANModel` namedtuple.

    Raises:
      ValueError: If `check_shapes` is True and either a generator output does
        not match the shape of its target data (checked inside `gan_model`)
        or a reconstruction is not shape-compatible with the data it
        reconstructs.
    """

    # Create models.
    def _define_partial_model(input_data, output_data):
        # One direction of the cycle: `input_data` feeds the generator and
        # `output_data` is the real data given to the discriminator.
        return gan_model(generator_fn=generator_fn,
                         discriminator_fn=discriminator_fn,
                         real_data=output_data,
                         generator_inputs=input_data,
                         generator_scope=generator_scope,
                         discriminator_scope=discriminator_scope,
                         check_shapes=check_shapes)

    def _check_reconstruction_shape(reconstructed, original, name):
        # Without this check, a mis-shaped reconstruction would only surface
        # later, or could be broadcast, when compared against the original data.
        if not reconstructed.shape.is_compatible_with(original.shape):
            raise ValueError(
                'Reconstructed %s shape (%s) must be the same shape as %s '
                '(%s).' % (name, reconstructed.shape, name, original.shape))

    with variable_scope.variable_scope(model_x2y_scope):
        model_x2y = _define_partial_model(data_x, data_y)
    with variable_scope.variable_scope(model_y2x_scope):
        model_y2x = _define_partial_model(data_y, data_x)

    # Reuse each generator's variables so the reconstruction passes share
    # weights with the forward passes built above.
    with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
        reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
    with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
        reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

    if check_shapes:
        # `model_y2x.real_data` holds `data_x` and `model_x2y.real_data` holds
        # `data_y`, as returned by `gan_model` (assumed to be converted to a
        # Tensor there, so `.shape` is a TensorShape -- confirm if gan_model
        # changes).
        _check_reconstruction_shape(reconstructed_x, model_y2x.real_data,
                                    'data_x')
        _check_reconstruction_shape(reconstructed_y, model_x2y.real_data,
                                    'data_y')

    return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                     reconstructed_y)
Example 4
0
def get_callable_cyclegan_model():
    """Build a `CycleGANModel` fixture from two `get_callable_gan_model()` results.

    The x reconstruction is an all-ones tensor and the y reconstruction an
    all-zeros tensor, both of shape [1, 2, 3].

    Returns:
      A `namedtuples.CycleGANModel`.
    """
    forward_model = get_callable_gan_model()
    backward_model = get_callable_gan_model()
    recon_shape = [1, 2, 3]
    return namedtuples.CycleGANModel(
        model_x2y=forward_model,
        model_y2x=backward_model,
        reconstructed_x=array_ops.ones(recon_shape),
        reconstructed_y=array_ops.zeros(recon_shape))