def get_cyclegan_model():
    """Builds a `CycleGANModel` from two separately scoped GAN models.

    The X->Y model is created by `get_gan_model()` inside variable scope
    'x2y' and the Y->X model inside 'y2x'. Both reconstructions are
    all-zero tensors of shape [3, 30, 35, 6].
    """
    with variable_scope.variable_scope('x2y'):
        forward_model = get_gan_model()
    with variable_scope.variable_scope('y2x'):
        backward_model = get_gan_model()

    # Create the two zero tensors in the same order as before (x, then y).
    recon_shape = [3, 30, 35, 6]
    recon_x = array_ops.zeros(recon_shape)
    recon_y = array_ops.zeros(recon_shape)

    return namedtuples.CycleGANModel(
        model_x2y=forward_model,
        model_y2x=backward_model,
        reconstructed_x=recon_x,
        reconstructed_y=recon_y)
def test_correct_loss(self):
    """Checks that `cycle_consistency_loss` evaluates to 5.0 (within 1e-5).

    The loss is computed on a `CycleGANModel` that combines the fixture models
    `self._model_x2y` / `self._model_y2x` with fixed float32 reconstructions
    [9, 8] (for X) and [7, 2] (for Y).
    """
    recon_x = constant_op.constant([9, 8], dtype=dtypes.float32)
    recon_y = constant_op.constant([7, 2], dtype=dtypes.float32)
    cyclegan = namedtuples.CycleGANModel(
        model_x2y=self._model_x2y,
        model_y2x=self._model_y2x,
        reconstructed_x=recon_x,
        reconstructed_y=recon_y)
    loss = tfgan_losses.cycle_consistency_loss(cyclegan)

    expected_loss = 5.0
    with self.test_session(use_gpu=True):
        variables.global_variables_initializer().run()
        self.assertNear(expected_loss, loss.eval(), 1e-5)
def cyclegan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # data X and Y.
        data_x,
        data_y,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        model_x2y_scope='ModelX2Y',
        model_y2x_scope='ModelY2X',
        # Options.
        check_shapes=True):
    """Returns a CycleGAN model's outputs and variables.

    See https://arxiv.org/abs/1703.10593 for more details.

    Two GAN models are built with `gan_model`, each under its own variable
    scope:
    - `model_x2y` uses `data_x` as generator input and `data_y` as real data.
    - `model_y2x` uses `data_y` as generator input and `data_x` as real data.

    Each model's generator is then reapplied, with variable reuse, to the
    other model's generated data to form the cycle reconstructions.

    Args:
      generator_fn: A python lambda that takes `data_x` or `data_y` as inputs
        and returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
      data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created. Defaults to
        'Generator'.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created. Defaults to
        'Discriminator'.
      model_x2y_scope: Optional variable scope for model x2y variables.
        Defaults to 'ModelX2Y'.
      model_y2x_scope: Optional variable scope for model y2x variables.
        Defaults to 'ModelY2X'.
      check_shapes: Forwarded to `gan_model` for both directions. If `True`,
        check that each generator's output has the same shape as that
        direction's real data (`data_y` for x2y, `data_x` for y2x).
        Otherwise, skip this check.

    Returns:
      A `CycleGANModel` namedtuple holding:
      - `model_x2y` and `model_y2x`;
      - `reconstructed_x`: the y2x generator applied to
        `model_x2y.generated_data`;
      - `reconstructed_y`: the x2y generator applied to
        `model_y2x.generated_data`.

    Raises:
      ValueError: If `check_shapes` is True and a generator output does not
        have the same shape as that direction's real data. This check is
        delegated to `gan_model`. This function never compares `data_x` with
        `data_y` directly, so their documented same-shape requirement is not
        enforced here.
    """
    # Create models.
    def _define_partial_model(input_data, output_data):
        # Builds one translation direction. `input_data` feeds the generator
        # and `output_data` is passed to `gan_model` as the real data.
        return gan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=output_data,
            generator_inputs=input_data,
            generator_scope=generator_scope,
            discriminator_scope=discriminator_scope,
            check_shapes=check_shapes)

    with variable_scope.variable_scope(model_x2y_scope):
        model_x2y = _define_partial_model(data_x, data_y)
    with variable_scope.variable_scope(model_y2x_scope):
        model_y2x = _define_partial_model(data_y, data_x)

    # Reconstruct X by running the y2x generator on the x2y model's output.
    # `reuse=True` re-enters the existing y2x generator scope, so its
    # variables are shared rather than duplicated.
    with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
        reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
    # Symmetrically, reconstruct Y while reusing the x2y generator's variables.
    with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
        reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

    # Pass fields by name so the result does not depend on field order.
    return namedtuples.CycleGANModel(
        model_x2y=model_x2y,
        model_y2x=model_y2x,
        reconstructed_x=reconstructed_x,
        reconstructed_y=reconstructed_y)
def get_callable_cyclegan_model():
    """Builds a `CycleGANModel` from two `get_callable_gan_model()` results.

    `reconstructed_x` is an all-ones tensor and `reconstructed_y` an
    all-zeros tensor, both of shape [1, 2, 3].
    """
    x2y = get_callable_gan_model()
    y2x = get_callable_gan_model()
    recon_shape = [1, 2, 3]
    return namedtuples.CycleGANModel(
        model_x2y=x2y,
        model_y2x=y2x,
        reconstructed_x=array_ops.ones(recon_shape),
        reconstructed_y=array_ops.zeros(recon_shape))