def get_dummy_gan_model():
  """Returns a StarGANModel made of constant tensors and two dummy variables.

  Similar to get_gan_model(). A scalar variable named 'dummy_var' (initialized
  to 0.0) is created under the 'generator' variable scope and a second one
  under the 'discriminator' variable scope. The discriminator outputs for the
  input data are multiplied by the discriminator variable only; the outputs
  for the generated data are multiplied by the generator variable and then by
  the discriminator variable.

  Returns:
    A `tfgan.StarGANModel` whose `generator_fn` and `discriminator_fn` are
    None.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with tf.compat.v1.variable_scope('generator') as generator_scope:
    generator_var = tf.compat.v1.get_variable('dummy_var', initializer=0.0)
  with tf.compat.v1.variable_scope('discriminator') as discriminator_scope:
    discriminator_var = tf.compat.v1.get_variable('dummy_var', initializer=0.0)

  # Tensors are built in the same order as the StarGANModel fields they feed.
  real_images = tf.ones([1, 2, 2, 3])
  real_labels = tf.ones([1, 2])
  fake_images = tf.ones([1, 2, 2, 3])
  target_labels = tf.ones([1, 2])
  reconstructed_images = tf.ones([1, 2, 2, 3])

  real_source_pred = tf.ones([1]) * discriminator_var
  fake_source_pred = tf.ones([1]) * generator_var * discriminator_var
  real_domain_pred = tf.ones([1, 2]) * discriminator_var
  fake_domain_pred = tf.ones([1, 2]) * generator_var * discriminator_var

  return tfgan.StarGANModel(
      input_data=real_images,
      input_data_domain_label=real_labels,
      generated_data=fake_images,
      generated_data_domain_target=target_labels,
      reconstructed_data=reconstructed_images,
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=real_domain_pred,
      discriminator_generated_data_domain_predication=fake_domain_pred,
      generator_variables=[generator_var],
      generator_scope=generator_scope,
      generator_fn=None,
      discriminator_variables=[discriminator_var],
      discriminator_scope=discriminator_scope,
      discriminator_fn=None)
def setUp(self):
  """Prepares the tensors, dummy discriminator and StarGANModel for the tests.

  Stores on `self` the input/generated image tensors, the input domain label,
  the two source-prediction tensors, the dummy discriminator function, the
  'discriminator' variable scope, and a `tfgan.StarGANModel` assembled from
  them. All remaining model fields are set to None.
  """
  super(StarGANLossWrapperTest, self).setUp()

  real_images = tf.ones([1, 2, 2, 3])
  real_label = tf.constant([[0, 1]])
  fake_images = tf.ones([1, 2, 2, 3])
  real_source_pred = tf.ones([1])
  fake_source_pred = tf.ones([1])

  self.input_data = real_images
  self.input_data_domain_label = real_label
  self.generated_data = fake_images
  self.discriminator_input_data_source_predication = real_source_pred
  self.discriminator_generated_data_source_predication = fake_source_pred

  def _discriminator_fn(inputs, num_domains):
    """Differentiable dummy discriminator for StarGAN."""
    flattened = tf.compat.v1.layers.flatten(inputs)
    # Source output: mean over the flattened features of each example.
    source_output = tf.reduce_mean(input_tensor=flattened, axis=1)
    # Class output: dense layer with one unit per domain.
    class_output = tf.compat.v1.layers.dense(flattened, num_domains)
    return source_output, class_output

  # Enter the scope only to obtain a handle on it; nothing is created inside.
  with tf.compat.v1.variable_scope('discriminator') as discriminator_scope:
    pass

  self.model = tfgan.StarGANModel(
      input_data=real_images,
      input_data_domain_label=real_label,
      generated_data=fake_images,
      generated_data_domain_target=None,
      reconstructed_data=None,
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=None,
      generator_scope=None,
      generator_fn=None,
      discriminator_variables=None,
      discriminator_scope=discriminator_scope,
      discriminator_fn=_discriminator_fn)

  self.discriminator_fn = _discriminator_fn
  self.discriminator_scope = discriminator_scope
def get_stargan_model():
  """Returns a StarGANModel whose generated data comes from the generator fn.

  Similar to get_gan_model(). An empty 'discriminator' variable scope is
  opened first to obtain its handle. Every tensor of the model, including the
  output of `stargan_generator_model`, is then created inside the 'generator'
  variable scope.

  Returns:
    A `tfgan.StarGANModel` using `stargan_generator_model` as `generator_fn`
    and `discriminator_model` as `discriminator_fn`, with both variable lists
    set to None.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with tf.compat.v1.variable_scope('discriminator') as discriminator_scope:
    pass

  with tf.compat.v1.variable_scope('generator') as generator_scope:
    # Tensors are built in the same order as the StarGANModel fields they feed.
    real_images = tf.ones([1, 2, 2, 3])
    real_labels = tf.ones([1, 2])
    fake_images = stargan_generator_model(tf.ones([1, 2, 2, 3]), None)
    target_labels = tf.ones([1, 2])
    reconstructed_images = tf.ones([1, 2, 2, 3])
    real_source_pred = tf.ones([1])
    fake_source_pred = tf.ones([1])
    real_domain_pred = tf.ones([1, 2])
    fake_domain_pred = tf.ones([1, 2])

    stargan_model = tfgan.StarGANModel(
        input_data=real_images,
        input_data_domain_label=real_labels,
        generated_data=fake_images,
        generated_data_domain_target=target_labels,
        reconstructed_data=reconstructed_images,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=None,
        generator_scope=generator_scope,
        generator_fn=stargan_generator_model,
        discriminator_variables=None,
        discriminator_scope=discriminator_scope,
        discriminator_fn=discriminator_model)
    return stargan_model