Example #1
0
def get_dummy_gan_model(generated_data=None):
  """Builds a minimal `tfgan.GANModel` for use in tests.

  A scalar variable named `dummy_var` (initial value 0.0, non-resource) is
  created, or reused via `AUTO_REUSE`, inside each of the `generator` and
  `discriminator` variable scopes. The dummy discriminator outputs are
  multiplied by those variables so they depend on them.

  Args:
    generated_data: Optional value for the `generated_data` field. When None,
      `tf.ones([3, 4])` is used.

  Returns:
    A `tfgan.GANModel` whose `generator_inputs`, `generator_fn` and
    `discriminator_fn` fields are None.
  """
  if generated_data is None:
    generated_data = tf.ones([3, 4])

  def _scoped_dummy_var(scope_name):
    # Creates (or reuses) `<scope_name>/dummy_var`; returns it with its scope.
    # TODO(joelshor): Find a better way of creating a variable scope.
    with tf.compat.v1.variable_scope(
        scope_name, reuse=tf.compat.v1.AUTO_REUSE) as scope:
      var = tf.compat.v1.get_variable(
          'dummy_var', initializer=0.0, use_resource=False)
    return var, scope

  gen_var, gen_scope = _scoped_dummy_var('generator')
  dis_var, dis_scope = _scoped_dummy_var('discriminator')

  # Tensors are built in the same order as the keyword arguments below.
  real_data = tf.zeros([3, 4])
  disc_real_outputs = tf.ones([1, 2, 3]) * dis_var
  disc_gen_outputs = tf.ones([1, 2, 3]) * gen_var * dis_var

  return tfgan.GANModel(
      generator_inputs=None,
      generated_data=generated_data,
      generator_variables=[gen_var],
      generator_scope=gen_scope,
      generator_fn=None,
      real_data=real_data,
      discriminator_real_outputs=disc_real_outputs,
      discriminator_gen_outputs=disc_gen_outputs,
      discriminator_variables=[dis_var],
      discriminator_scope=dis_scope,
      discriminator_fn=None)
Example #2
0
def get_gan_model():
    """Builds a placeholder `tfgan.GANModel` over 4x32x32x3 image tensors.

    The generator and discriminator variable scopes are opened only to
    capture their scope objects; no variables are created inside them.

    Returns:
        A `tfgan.GANModel` wired to `generator_model` and
        `discriminator_model`, with constant tensors as data and outputs.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    scopes = {}
    for scope_name in ('generator', 'discriminator'):
        with tf.compat.v1.variable_scope(scope_name) as scope:
            scopes[scope_name] = scope

    # Build tensors/variables in the same order the fields are listed below.
    image_shape = [4, 32, 32, 3]
    gen_inputs = tf.zeros(image_shape)
    fake_images = tf.zeros(image_shape)
    gen_vars = [tf.Variable(0), tf.Variable(1)]
    real_images = tf.ones(image_shape)
    disc_real = tf.ones([1, 2, 3])
    disc_gen = tf.ones([1, 2, 3])
    disc_vars = [tf.Variable(0)]

    return tfgan.GANModel(
        generator_inputs=gen_inputs,
        generated_data=fake_images,
        generator_variables=gen_vars,
        generator_scope=scopes['generator'],
        generator_fn=generator_model,
        real_data=real_images,
        discriminator_real_outputs=disc_real,
        discriminator_gen_outputs=disc_gen,
        discriminator_variables=disc_vars,
        discriminator_scope=scopes['discriminator'],
        discriminator_fn=discriminator_model,
    )
Example #3
0
def get_dummy_gan_model(generated_data=None):
    """Returns a minimal `tfgan.GANModel` for testing.

    A scalar variable named `dummy_var` (initial value 0.0) is created, or
    reused, in each of the `generator` and `discriminator` variable scopes.
    The dummy discriminator outputs are multiplied by those variables so they
    depend on them.

    Args:
        generated_data: Optional value for the `generated_data` field. When
            None (the default, matching the previous zero-argument form),
            `tf.ones([3, 4])` is used.

    Returns:
        A `tfgan.GANModel` whose `generator_inputs`, `generator_fn` and
        `discriminator_fn` fields are None.
    """
    if generated_data is None:
        generated_data = tf.ones([3, 4])
    # TODO(joelshor): Find a better way of creating a variable scope.
    # AUTO_REUSE lets this helper be called more than once in the same graph;
    # without it, get_variable raises because `dummy_var` already exists.
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        gen_var = tf.compat.v1.get_variable('dummy_var', initializer=0.0)
    with tf.compat.v1.variable_scope(
            'discriminator', reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
        dis_var = tf.compat.v1.get_variable('dummy_var', initializer=0.0)
    return tfgan.GANModel(
        generator_inputs=None,
        generated_data=generated_data,
        generator_variables=[gen_var],
        generator_scope=gen_scope,
        generator_fn=None,
        real_data=tf.zeros([3, 4]),
        discriminator_real_outputs=tf.ones([1, 2, 3]) * dis_var,
        discriminator_gen_outputs=tf.ones([1, 2, 3]) * gen_var * dis_var,
        discriminator_variables=[dis_var],
        discriminator_scope=dis_scope,
        discriminator_fn=None)
Example #4
0
 def _partial_model(generator_inputs_np):
     """Returns a `tfgan.GANModel` with only `generator_inputs` populated.

     Args:
       generator_inputs_np: Value passed to `tf.constant` for the
         `generator_inputs` field (presumably a NumPy array, per the name).

     Returns:
       A `tfgan.GANModel` whose `generator_inputs` is a float32 constant and
       whose every other field is None.
     """
     # Size the all-None tuple from the namedtuple's own field list rather
     # than a hard-coded count, so it keeps working if GANModel gains fields.
     model = tfgan.GANModel(*[None] * len(tfgan.GANModel._fields))
     return model._replace(generator_inputs=tf.constant(
         generator_inputs_np, dtype=tf.float32))
Example #5
0
 def _model(disc_gen_outputs, disc_real_outputs):
     """Returns a `tfgan.GANModel` holding only the discriminator outputs.

     Args:
       disc_gen_outputs: Value for `discriminator_gen_outputs`.
       disc_real_outputs: Value for `discriminator_real_outputs`.

     Returns:
       A `tfgan.GANModel` with those two fields set and every other field None.
     """
     # Derive the field count from the namedtuple instead of hard-coding 11.
     model = tfgan.GANModel(*[None] * len(tfgan.GANModel._fields))
     return model._replace(discriminator_real_outputs=disc_real_outputs,
                           discriminator_gen_outputs=disc_gen_outputs)