def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator.

  Only the generator is built. All reconstruction and discriminator fields of
  the returned tuple are `None`.

  Args:
    input_data: Input fed to the generator. It is passed through
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    input_data_domain_label: Domain label fed to the generator alongside
      `input_data`. It is also passed through
      `tfgan_train._convert_tensor_or_l_or_d`.
    generator_fn: Callable invoked as
      `generator_fn(input_data, input_data_domain_label)`. If it declares a
      `mode` parameter (positional or keyword-only), it is called with
      `mode=model_fn_lib.ModeKeys.PREDICT`.
    generator_scope: Variable scope (or scope name) in which the generator is
      built. Trainable variables under it are collected.

  Returns:
    A `tfgan_tuples.StarGANModel` with these fields set:
      - `input_data` and `generated_data`.
      - `generated_data_domain_target`: the converted
        `input_data_domain_label`.
      - `generator_variables`, `generator_scope` and `generator_fn`.
    All other fields are `None`.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `getfullargspec` replaces the deprecated `getargspec`, which was removed
  # from the Python stdlib in 3.11 and raises on keyword-only args or
  # annotations. Keyword-only args are inspected too. This catches a `mode`
  # declared after `*`, and a `functools.partial` whose signature exposes
  # `mode` as keyword-only.
  argspec = inspect.getfullargspec(generator_fn)
  declared_args = list(argspec.args) + list(argspec.kwonlyargs or [])
  if 'mode' in declared_args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)

  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      # The supplied label is recorded as the generation target
      # (`generated_data_domain_target`). The input's own domain label is
      # left unset.
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
def get_dummy_gan_model():
  """Build a dummy `StarGANModel` with one scalar variable per network.

  Similar to get_gan_model().

  A `dummy_var` initialized to 0.0 is created under the 'generator' scope and
  another under the 'discriminator' scope. All data fields are constant ones.
  The four discriminator prediction fields are ones multiplied by those
  variables, so they depend on them.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('generator') as generator_scope:
    generator_var = variable_scope.get_variable('dummy_var', initializer=0.0)
  with variable_scope.variable_scope('discriminator') as discriminator_scope:
    discriminator_var = variable_scope.get_variable(
        'dummy_var', initializer=0.0)

  image_shape = [1, 2, 2, 3]
  label_shape = [1, 2]

  # Tensors are created here in the same order as the model fields they fill.
  real_images = array_ops.ones(image_shape)
  real_labels = array_ops.ones(label_shape)
  fake_images = array_ops.ones(image_shape)
  fake_labels = array_ops.ones(label_shape)
  cycled_images = array_ops.ones(image_shape)
  real_source_pred = array_ops.ones([1]) * discriminator_var
  fake_source_pred = array_ops.ones([1]) * generator_var * discriminator_var
  real_domain_pred = array_ops.ones(label_shape) * discriminator_var
  fake_domain_pred = (
      array_ops.ones(label_shape) * generator_var * discriminator_var)

  return tfgan_tuples.StarGANModel(
      input_data=real_images,
      input_data_domain_label=real_labels,
      generated_data=fake_images,
      generated_data_domain_target=fake_labels,
      reconstructed_data=cycled_images,
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=real_domain_pred,
      discriminator_generated_data_domain_predication=fake_domain_pred,
      generator_variables=[generator_var],
      generator_scope=generator_scope,
      generator_fn=None,
      discriminator_variables=[discriminator_var],
      discriminator_scope=discriminator_scope,
      discriminator_fn=None)
def setUp(self):
  """Build a minimal `StarGANModel` around a dummy discriminator.

  The model fills only these fields:
    - the input data and its domain label;
    - the generated data;
    - the two source predictions;
    - the discriminator function and scope.
  The discriminator function and scope are also stored on `self` for direct
  use by the tests.
  """
  super(StarGANLossWrapperTest, self).setUp()

  self.input_data = array_ops.ones([1, 2, 2, 3])
  self.input_data_domain_label = constant_op.constant([[0, 1]])
  self.generated_data = array_ops.ones([1, 2, 2, 3])
  self.discriminator_input_data_source_predication = array_ops.ones([1])
  self.discriminator_generated_data_source_predication = array_ops.ones([1])

  def _discriminator_fn(inputs, num_domains):
    """Differentiable dummy discriminator for StarGAN."""
    flat = layers.flatten(inputs)
    # Source output: mean over the flattened features.
    source_out = math_ops.reduce_mean(flat, axis=1)
    # Domain output: a bias-free linear projection to `num_domains` units.
    domain_out = layers.fully_connected(
        inputs=flat,
        num_outputs=num_domains,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None)
    return source_out, domain_out

  # The scope is entered only to capture it; nothing is created inside it here.
  with variable_scope.variable_scope('discriminator') as dis_scope:
    pass

  self.discriminator_fn = _discriminator_fn
  self.discriminator_scope = dis_scope

  self.model = namedtuples.StarGANModel(
      input_data=self.input_data,
      input_data_domain_label=self.input_data_domain_label,
      generated_data=self.generated_data,
      generated_data_domain_target=None,
      reconstructed_data=None,
      discriminator_input_data_source_predication=(
          self.discriminator_input_data_source_predication),
      discriminator_generated_data_source_predication=(
          self.discriminator_generated_data_source_predication),
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=None,
      generator_scope=None,
      generator_fn=None,
      discriminator_variables=None,
      discriminator_scope=self.discriminator_scope,
      discriminator_fn=self.discriminator_fn)
def get_stargan_model():
  """Return a `StarGANModel` whose data and prediction fields are all ones.

  Similar to get_gan_model().

  Empty 'generator' and 'discriminator' variable scopes are captured. The
  model functions are `stargan_generator_model` and
  `stargan_discriminator_model`. No variables are recorded.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('generator') as gen_scope:
    pass
  with variable_scope.variable_scope('discriminator') as dis_scope:
    pass

  image_shape = [1, 2, 2, 3]
  label_shape = [1, 2]
  ones = array_ops.ones

  # Keyword arguments are evaluated left to right, so the tensors are created
  # in the same order as the fields are listed.
  fields = dict(
      input_data=ones(image_shape),
      input_data_domain_label=ones(label_shape),
      generated_data=ones(image_shape),
      generated_data_domain_target=ones(label_shape),
      reconstructed_data=ones(image_shape),
      discriminator_input_data_source_predication=ones([1]),
      discriminator_generated_data_source_predication=ones([1]),
      discriminator_input_data_domain_predication=ones(label_shape),
      discriminator_generated_data_domain_predication=ones(label_shape),
      generator_variables=None,
      generator_scope=gen_scope,
      generator_fn=stargan_generator_model,
      discriminator_variables=None,
      discriminator_scope=dis_scope,
      discriminator_fn=stargan_discriminator_model,
  )
  return namedtuples.StarGANModel(**fields)