def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator.

  Only the generator is built. `reconstructed_data`, `input_data_domain_label`
  and every discriminator-related field of the returned tuple are `None`.

  Args:
    input_data: Generator input; converted with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    input_data_domain_label: Domain label handed to `generator_fn` alongside
      `input_data`; converted the same way.
    generator_fn: Callable invoked as
      `generator_fn(input_data, input_data_domain_label)`. If it declares a
      `mode` parameter (positional-or-keyword or keyword-only), it is called
      with `mode=model_fn_lib.ModeKeys.PREDICT`.
    generator_scope: Scope passed to `variable_scope.variable_scope` for the
      generator; stored unchanged in the returned tuple's `generator_scope`.

  Returns:
    A `tfgan_tuples.StarGANModel` with the generator outputs and variables.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `inspect.getargspec` is deprecated, rejects callables that have
  # keyword-only parameters or annotations, and was removed in Python 3.11.
  # Prefer `getfullargspec`; fall back to `getargspec` only where it is the
  # sole option. `or` keeps the fallback lookup lazy so it never runs on 3.11+.
  get_arg_spec = (getattr(inspect, 'getfullargspec', None) or
                  getattr(inspect, 'getargspec'))
  arg_spec = get_arg_spec(generator_fn)
  # `ArgSpec` (legacy) has no `kwonlyargs`; `FullArgSpec` may hold `None`-free
  # lists, but guard both cases.
  arg_names = list(arg_spec.args) + list(
      getattr(arg_spec, 'kwonlyargs', None) or [])
  if 'mode' in arg_names:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      # The label fed to `generator_fn` above acts as the generation target,
      # so it is recorded as `generated_data_domain_target`; no separate label
      # describing `input_data` itself is available here.
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
def get_stargan_model():
  """Similar to `get_gan_model()`, but returns a `StarGANModel`.

  Only `stargan_generator_model` is actually invoked (on a ones image with a
  `None` target, inside the 'generator' scope). Every other tensor field is a
  ones tensor, and both variable lists are `None`.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('discriminator') as dis_scope:
    pass

  image_shape = [1, 2, 2, 3]
  label_shape = [1, 2]
  with variable_scope.variable_scope('generator') as gen_scope:
    # Tensors are created in the same order the tuple fields are listed.
    real_images = array_ops.ones(image_shape)
    real_labels = array_ops.ones(label_shape)
    fake_images = stargan_generator_model(array_ops.ones(image_shape), None)
    target_labels = array_ops.ones(label_shape)
    cycled_images = array_ops.ones(image_shape)
    src_pred_real = array_ops.ones([1])
    src_pred_fake = array_ops.ones([1])
    dom_pred_real = array_ops.ones(label_shape)
    dom_pred_fake = array_ops.ones(label_shape)
    return namedtuples.StarGANModel(
        input_data=real_images,
        input_data_domain_label=real_labels,
        generated_data=fake_images,
        generated_data_domain_target=target_labels,
        reconstructed_data=cycled_images,
        discriminator_input_data_source_predication=src_pred_real,
        discriminator_generated_data_source_predication=src_pred_fake,
        discriminator_input_data_domain_predication=dom_pred_real,
        discriminator_generated_data_domain_predication=dom_pred_fake,
        generator_variables=None,
        generator_scope=gen_scope,
        generator_fn=stargan_generator_model,
        discriminator_variables=None,
        discriminator_scope=dis_scope,
        discriminator_fn=discriminator_model)
def get_dummy_gan_model():
  """Similar to `get_gan_model()`, with one scalar dummy variable per network.

  A `dummy_var` (initialized to 0.0) is created under both the 'generator' and
  'discriminator' scopes. The predictions on real data are ones tensors scaled
  by the discriminator variable; the predictions on generated data are scaled
  by both variables.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('generator') as gen_scope:
    gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
  with variable_scope.variable_scope('discriminator') as dis_scope:
    dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)

  image_shape = [1, 2, 2, 3]
  label_shape = [1, 2]
  # Tensors are created in the same order the tuple fields are listed.
  real_images = array_ops.ones(image_shape)
  real_labels = array_ops.ones(label_shape)
  fake_images = array_ops.ones(image_shape)
  target_labels = array_ops.ones(label_shape)
  cycled_images = array_ops.ones(image_shape)
  src_pred_real = array_ops.ones([1]) * dis_var
  src_pred_fake = array_ops.ones([1]) * gen_var * dis_var
  dom_pred_real = array_ops.ones(label_shape) * dis_var
  dom_pred_fake = array_ops.ones(label_shape) * gen_var * dis_var

  return tfgan_tuples.StarGANModel(
      input_data=real_images,
      input_data_domain_label=real_labels,
      generated_data=fake_images,
      generated_data_domain_target=target_labels,
      reconstructed_data=cycled_images,
      discriminator_input_data_source_predication=src_pred_real,
      discriminator_generated_data_source_predication=src_pred_fake,
      discriminator_input_data_domain_predication=dom_pred_real,
      discriminator_generated_data_domain_predication=dom_pred_fake,
      generator_variables=[gen_var],
      generator_scope=gen_scope,
      generator_fn=None,
      discriminator_variables=[dis_var],
      discriminator_scope=dis_scope,
      discriminator_fn=None)
def setUp(self):
  """Builds a minimal `StarGANModel` and a differentiable dummy discriminator."""
  super(StarGANLossWrapperTest, self).setUp()

  def _discriminator_fn(inputs, num_domains):
    """Differentiable dummy discriminator for StarGAN."""
    flat_inputs = layers.flatten(inputs)
    source_pred = math_ops.reduce_mean(flat_inputs, axis=1)
    domain_pred = layers.fully_connected(
        inputs=flat_inputs,
        num_outputs=num_domains,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None)
    return source_pred, domain_pred

  self.input_data = array_ops.ones([1, 2, 2, 3])
  self.input_data_domain_label = constant_op.constant([[0, 1]])
  self.generated_data = array_ops.ones([1, 2, 2, 3])
  self.discriminator_input_data_source_predication = array_ops.ones([1])
  self.discriminator_generated_data_source_predication = array_ops.ones([1])

  # Enter and immediately leave the scope just to obtain a scope object.
  with variable_scope.variable_scope('discriminator') as dis_scope:
    pass
  self.discriminator_fn = _discriminator_fn
  self.discriminator_scope = dis_scope

  self.model = namedtuples.StarGANModel(
      input_data=self.input_data,
      input_data_domain_label=self.input_data_domain_label,
      generated_data=self.generated_data,
      generated_data_domain_target=None,
      reconstructed_data=None,
      discriminator_input_data_source_predication=(
          self.discriminator_input_data_source_predication),
      discriminator_generated_data_source_predication=(
          self.discriminator_generated_data_source_predication),
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=None,
      generator_scope=None,
      generator_fn=None,
      discriminator_variables=None,
      discriminator_scope=dis_scope,
      discriminator_fn=_discriminator_fn)
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Returns a StarGAN model outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` and returns
      `generated_data`, the version of `inputs` transformed according to
      `targets`. `inputs` has shape (n, h, w, c), `targets` has shape
      (n, num_domains), and `generated_data` has the same shape as `inputs`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` and
      returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` is the discriminator's source (real/generated)
      prediction, with shape (n); `domain_prediction` is its domain
      prediction/classification, with shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) holding the
      real input images.
    input_data_domain_label: Tensor or a list of tensors of shape
      (batch_size, num_domains) holding the domain labels of the real images.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    StarGANModel namedtuple holding the tensors needed to compute the loss.

  Raises:
    ValueError: If the shape of `input_data_domain_label` is not rank 2 or is
      not fully defined in every dimension.
  """

  def _as_single_tensor(value):
    # A list/tuple of tensors is concatenated along axis 0; anything else is
    # returned unchanged.
    if not isinstance(value, (list, tuple)):
      return value
    return array_ops.concat([ops.convert_to_tensor(x) for x in value], 0)

  # Convert to tensor.
  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)

  # Merge list/tuple inputs into single tensors (data first, then labels).
  input_data = _as_single_tensor(input_data)
  input_data_domain_label = _as_single_tensor(input_data_domain_label)

  # batch_size and num_domains are read statically from the label shape, so it
  # must be rank 2 and fully defined.
  label_shape = input_data_domain_label.shape
  label_shape.assert_has_rank(2)
  label_shape.assert_is_fully_defined()
  batch_size, num_domains = label_shape.as_list()

  # Generator pass 1: translate input_data to randomly drawn target domains.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    target_domain = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    fake_data = generator_fn(input_data, target_domain)
  # Generator pass 2 (shared weights): map fake_data back to the input domains.
  with variable_scope.variable_scope(gen_scope, reuse=True):
    cycled_data = generator_fn(fake_data, input_data_domain_label)

  # Discriminator on generated data first, then on real data with reuse.
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    fake_source_pred, fake_domain_pred = discriminator_fn(
        fake_data, num_domains)
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_source_pred, real_domain_pred = discriminator_fn(
        input_data, num_domains)

  # Collect trainable variables from both networks.
  gen_vars = variables_lib.get_trainable_variables(gen_scope)
  dis_vars = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=fake_data,
      generated_data_domain_target=target_domain,
      reconstructed_data=cycled_data,
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=real_domain_pred,
      discriminator_generated_data_domain_predication=fake_domain_pred,
      generator_variables=gen_vars,
      generator_scope=gen_scope,
      generator_fn=generator_fn,
      discriminator_variables=dis_vars,
      discriminator_scope=dis_scope,
      discriminator_fn=discriminator_fn)