def get_dummy_gan_model(generated_data=None):
  """Builds a minimal `tfgan.GANModel` for use in tests.

  Args:
    generated_data: Optional tensor stored as the model's `generated_data`.
      When None, `tf.ones([3, 4])` is used instead.

  Returns:
    A `tfgan.GANModel` with no generator/discriminator functions. Its
    discriminator outputs are built by multiplying constant tensors with one
    scalar generator variable and one scalar discriminator variable, so
    values derived from those outputs depend on both variables.
  """
  generated_data = (
      tf.ones([3, 4]) if generated_data is None else generated_data)

  def _scoped_dummy_var(scope_name):
    # TODO(joelshor): Find a better way of creating a variable scope.
    # AUTO_REUSE returns the existing variable on repeat calls instead of
    # raising, so this helper can be invoked more than once per graph.
    with tf.compat.v1.variable_scope(
        scope_name, reuse=tf.compat.v1.AUTO_REUSE) as scope:
      var = tf.compat.v1.get_variable(
          'dummy_var', initializer=0.0, use_resource=False)
    return scope, var

  gen_scope, gen_var = _scoped_dummy_var('generator')
  dis_scope, dis_var = _scoped_dummy_var('discriminator')

  real_data = tf.zeros([3, 4])
  output_shape = [1, 2, 3]
  disc_real = tf.ones(output_shape) * dis_var
  disc_gen = tf.ones(output_shape) * gen_var * dis_var

  return tfgan.GANModel(
      generator_inputs=None,
      generated_data=generated_data,
      generator_variables=[gen_var],
      generator_scope=gen_scope,
      generator_fn=None,
      real_data=real_data,
      discriminator_real_outputs=disc_real,
      discriminator_gen_outputs=disc_gen,
      discriminator_variables=[dis_var],
      discriminator_scope=dis_scope,
      discriminator_fn=None)
def get_gan_model():
  """Builds a `tfgan.GANModel` populated with constant tensors for tests.

  The generator/discriminator scopes are empty variable scopes; the variable
  lists hold freshly created `tf.Variable`s, and the model functions are the
  `generator_model` / `discriminator_model` helpers from this module.

  Returns:
    A `tfgan.GANModel` with image-shaped ([4, 32, 32, 3]) generator inputs,
    generated data, and real data, and [1, 2, 3]-shaped discriminator outputs.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with tf.compat.v1.variable_scope('generator') as gen_scope:
    pass
  with tf.compat.v1.variable_scope('discriminator') as dis_scope:
    pass

  image_shape = [4, 32, 32, 3]
  output_shape = [1, 2, 3]
  # Tensors/variables are created in the same order as the keyword arguments
  # below so that graph op/variable names stay stable.
  gen_inputs = tf.zeros(image_shape)
  fake_images = tf.zeros(image_shape)
  gen_vars = [tf.Variable(0), tf.Variable(1)]
  real_images = tf.ones(image_shape)
  disc_real = tf.ones(output_shape)
  disc_gen = tf.ones(output_shape)
  dis_vars = [tf.Variable(0)]

  return tfgan.GANModel(
      generator_inputs=gen_inputs,
      generated_data=fake_images,
      generator_variables=gen_vars,
      generator_scope=gen_scope,
      generator_fn=generator_model,
      real_data=real_images,
      discriminator_real_outputs=disc_real,
      discriminator_gen_outputs=disc_gen,
      discriminator_variables=dis_vars,
      discriminator_scope=dis_scope,
      discriminator_fn=discriminator_model)
def get_dummy_gan_model(generated_data=None):
  """Returns a minimal `tfgan.GANModel` for testing.

  Args:
    generated_data: Optional tensor stored as the model's `generated_data`.
      When None (the default), `tf.ones([3, 4])` is used, matching the value
      this helper previously hard-coded.

  Returns:
    A `tfgan.GANModel` with no generator/discriminator functions. Its
    discriminator outputs are constant tensors multiplied by one scalar
    generator variable and one scalar discriminator variable.
  """
  if generated_data is None:
    generated_data = tf.ones([3, 4])
  # TODO(joelshor): Find a better way of creating a variable scope.
  # AUTO_REUSE returns the existing `dummy_var` on repeat calls instead of
  # raising a "variable already exists" error, so this helper can be called
  # more than once in the same graph.
  with tf.compat.v1.variable_scope(
      'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
    gen_var = tf.compat.v1.get_variable('dummy_var', initializer=0.0)
  with tf.compat.v1.variable_scope(
      'discriminator', reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
    dis_var = tf.compat.v1.get_variable('dummy_var', initializer=0.0)
  return tfgan.GANModel(
      generator_inputs=None,
      generated_data=generated_data,
      generator_variables=[gen_var],
      generator_scope=gen_scope,
      generator_fn=None,
      real_data=tf.zeros([3, 4]),
      discriminator_real_outputs=tf.ones([1, 2, 3]) * dis_var,
      discriminator_gen_outputs=tf.ones([1, 2, 3]) * gen_var * dis_var,
      discriminator_variables=[dis_var],
      discriminator_scope=dis_scope,
      discriminator_fn=None)
def _partial_model(generator_inputs_np):
  """Returns a `tfgan.GANModel` with only `generator_inputs` populated.

  Args:
    generator_inputs_np: Value converted with `tf.constant(..., tf.float32)`
      and stored as the model's `generator_inputs`.

  Returns:
    A `tfgan.GANModel` whose every other field is None.
  """
  # Size the all-None placeholder from the namedtuple's own field list rather
  # than a hard-coded count, so it keeps working if GANModel gains fields.
  empty_model = tfgan.GANModel(*[None] * len(tfgan.GANModel._fields))
  return empty_model._replace(
      generator_inputs=tf.constant(generator_inputs_np, dtype=tf.float32))
def _model(disc_gen_outputs, disc_real_outputs):
  """Returns a `tfgan.GANModel` holding only discriminator outputs.

  Args:
    disc_gen_outputs: Stored as `discriminator_gen_outputs`.
    disc_real_outputs: Stored as `discriminator_real_outputs`.

  Returns:
    A `tfgan.GANModel` whose every other field is None.
  """
  # Size the all-None placeholder from the namedtuple's own field list rather
  # than a hard-coded count, so it keeps working if GANModel gains fields.
  empty_model = tfgan.GANModel(*[None] * len(tfgan.GANModel._fields))
  return empty_model._replace(
      discriminator_real_outputs=disc_real_outputs,
      discriminator_gen_outputs=disc_gen_outputs)