def define_model(self, images_x, images_y):
    """Defines a CycleGAN model that maps between images_x and images_y.

    Whether the identity-augmented variant is built is controlled by the
    instance attribute ``self._use_identity_loss``, not by an argument.

    Args:
      images_x: A 4D float `Tensor` of NHWC format. Images in set X.
      images_y: A 4D float `Tensor` of NHWC format. Images in set Y.

    Returns:
      The model built by `cyclegan_model_with_identity` when
      ``self._use_identity_loss`` is truthy, otherwise the `CycleGANModel`
      namedtuple built by `tfgan.cyclegan_model`.
    """
    # Both builders accept the same keyword arguments, so choose one and
    # call it once instead of duplicating the call in each branch.
    if self._use_identity_loss:
        build_model = cyclegan_model_with_identity
    else:
        build_model = tfgan.cyclegan_model

    # NOTE: image summaries (tfgan.eval.add_cyclegan_image_summaries) are
    # not added by this method.
    return build_model(
        generator_fn=_shadowdata_generator_model,
        discriminator_fn=_shadowdata_discriminator_model,
        data_x=images_x,
        data_y=images_y)
def _define_model(images_x, images_y):
    """Builds a CycleGAN between image sets X and Y and adds image summaries.

    Args:
      images_x: A 4D float `Tensor` of NHWC format. Images in set X.
      images_y: A 4D float `Tensor` of NHWC format. Images in set Y.

    Returns:
      A `CycleGANModel` namedtuple.
    """
    model = tfgan.cyclegan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        data_x=images_x,
        data_y=images_y,
    )
    # Register summaries of the generated images for visualization.
    tfgan.eval.add_cyclegan_image_summaries(model)
    return model
def create_callable_cyclegan_model():
    """Returns a `tfgan.cyclegan_model` built from new `Generator` and
    `Discriminator` instances on constant [1, 2] inputs (zeros for X, ones
    for Y)."""
    generator = Generator()
    discriminator = Discriminator()
    zeros_x = tf.zeros([1, 2])
    ones_y = tf.ones([1, 2])
    return tfgan.cyclegan_model(
        generator, discriminator, data_x=zeros_x, data_y=ones_y)
def cyclegan_model_with_identity(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope="Generator",
    discriminator_scope="Discriminator",
    model_x2y_scope="ModelX2Y",
    model_y2x_scope="ModelY2X",
    # Options.
    check_shapes=True):
    """Returns a CycleGAN model extended with identity-mapping outputs.

    Builds a regular `tfgan.cyclegan_model` and then, reusing the already
    created generator weights, runs each generator on real samples of its
    *target* domain so that an identity loss can be applied.

    See https://arxiv.org/abs/1703.10593 for more details.

    Args:
      generator_fn: A python lambda that takes `data_x` or `data_y` as inputs
        and returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
      data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created. Defaults to
        'Generator'.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created. Defaults to
        'Discriminator'.
      model_x2y_scope: Optional variable scope for model x2y variables.
        Defaults to 'ModelX2Y'.
      model_y2x_scope: Optional variable scope for model y2x variables.
        Defaults to 'ModelY2X'.
      check_shapes: If `True`, check that generator produces Tensors that are
        the same shape as `data_x` (`data_y`). Otherwise, skip this check.

    Returns:
      A `CycleGANModelWithIdentity` constructed from the underlying model's
      `model_x2y`, `model_y2x`, `reconstructed_x` and `reconstructed_y`, with
      two extra attributes:
        identity_x: the Y->X generator applied to real `data_x` (intended to
          be compared against `data_x` by an identity loss).
        identity_y: the X->Y generator applied to real `data_y` (intended to
          be compared against `data_y` by an identity loss).

    Raises:
      ValueError: If `check_shapes` is True and `data_x` or the generator
        output does not have the same shape as `data_y`.
      ValueError: If TF is executing eagerly.
      (Both are raised by the underlying `tfgan.cyclegan_model` call.)
    """
    original_model = tfgan.cyclegan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        data_x=data_x,
        data_y=data_y,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        model_x2y_scope=model_x2y_scope,
        model_y2x_scope=model_y2x_scope,
        check_shapes=check_shapes)
    model_x2y = original_model.model_x2y
    model_y2x = original_model.model_y2x

    # Identity mapping (CycleGAN paper): a generator fed a real sample of its
    # own *output* domain should return it unchanged, i.e. G_y2x(x) ~ x and
    # G_x2y(y) ~ y. Feeding data_x to the X->Y generator would merely
    # recompute model_x2y.generated_data, which is not an identity mapping.
    # reuse=True makes these calls share the generators' existing weights.
    with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
        identity_x = model_y2x.generator_fn(data_x)
    with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
        identity_y = model_x2y.generator_fn(data_y)

    model_w_identity = CycleGANModelWithIdentity(
        model_x2y=model_x2y,
        model_y2x=model_y2x,
        reconstructed_x=original_model.reconstructed_x,
        reconstructed_y=original_model.reconstructed_y)
    # NOTE(review): identity outputs are attached as plain attributes after
    # construction; this assumes CycleGANModelWithIdentity instances accept
    # new attributes (e.g. no restrictive __slots__) -- confirm its definition.
    model_w_identity.identity_x = identity_x
    model_w_identity.identity_y = identity_y
    return model_w_identity
def create_cyclegan_model():
    """Returns a `tfgan.cyclegan_model` over `generator_model` and
    `discriminator_model` with constant [1, 2] inputs (zeros for X, ones
    for Y)."""
    data_x = tf.zeros([1, 2])
    data_y = tf.ones([1, 2])
    return tfgan.cyclegan_model(
        generator_model,
        discriminator_model,
        data_x=data_x,
        data_y=data_y,
    )