Exemplo n.º 1
0
def get_dummy_gan_model():
    """Return a StarGANModel backed by two scalar dummy variables.

    Analogous to get_gan_model(). A scalar `dummy_var` (initializer 0.0) is
    created once under a 'generator' variable scope and once under a
    'discriminator' variable scope. Every data tensor is a constant of ones.
    The real-data discriminator predictions are multiplied by the
    discriminator variable. The generated-data predictions are multiplied by
    both variables, so each prediction depends on the variables listed for
    its network. No generator_fn or discriminator_fn is attached.

    Returns:
      A tfgan.StarGANModel whose generator/discriminator variable lists each
      hold the matching dummy variable, together with the captured scopes.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    with tf.compat.v1.variable_scope('generator') as generator_scope:
        gen_dummy = tf.compat.v1.get_variable('dummy_var', initializer=0.0)
    with tf.compat.v1.variable_scope('discriminator') as discriminator_scope:
        dis_dummy = tf.compat.v1.get_variable('dummy_var', initializer=0.0)

    image_shape = [1, 2, 2, 3]
    label_shape = [1, 2]

    # The tensors are built in the same order the original keyword arguments
    # were evaluated, so the generated graph (and its op names) is unchanged.
    real_images = tf.ones(image_shape)
    real_labels = tf.ones(label_shape)
    fake_images = tf.ones(image_shape)
    target_labels = tf.ones(label_shape)
    cycled_images = tf.ones(image_shape)
    real_source_pred = tf.ones([1]) * dis_dummy
    fake_source_pred = tf.ones([1]) * gen_dummy * dis_dummy
    real_domain_pred = tf.ones(label_shape) * dis_dummy
    fake_domain_pred = tf.ones(label_shape) * gen_dummy * dis_dummy

    return tfgan.StarGANModel(
        input_data=real_images,
        input_data_domain_label=real_labels,
        generated_data=fake_images,
        generated_data_domain_target=target_labels,
        reconstructed_data=cycled_images,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=[gen_dummy],
        generator_scope=generator_scope,
        generator_fn=None,
        discriminator_variables=[dis_dummy],
        discriminator_scope=discriminator_scope,
        discriminator_fn=None)
Exemplo n.º 2
0
    def setUp(self):
        """Build a partially populated StarGANModel fixture.

        Stores constant input/generated images, a constant domain label, two
        constant source predictions, a small differentiable discriminator and
        a handle to the 'discriminator' variable scope on ``self``. It then
        wraps them in ``self.model``. Every other model field is None.
        """
        super(StarGANLossWrapperTest, self).setUp()

        def _discriminator_fn(inputs, num_domains):
            """Differentiable dummy discriminator for StarGAN.

            Flattens ``inputs``. It returns the per-example mean of the
            flattened values as the source output, and a dense layer with
            ``num_domains`` units as the domain output.
            """
            flat = tf.compat.v1.layers.flatten(inputs)
            source_out = tf.reduce_mean(input_tensor=flat, axis=1)
            domain_out = tf.compat.v1.layers.dense(flat, num_domains)
            return source_out, domain_out

        self.input_data = tf.ones([1, 2, 2, 3])
        self.input_data_domain_label = tf.constant([[0, 1]])
        self.generated_data = tf.ones([1, 2, 2, 3])
        self.discriminator_input_data_source_predication = tf.ones([1])
        self.discriminator_generated_data_source_predication = tf.ones([1])

        # The scope is entered and left immediately; only the handle is kept.
        with tf.compat.v1.variable_scope('discriminator') as scope_handle:
            pass
        self.discriminator_scope = scope_handle
        self.discriminator_fn = _discriminator_fn

        self.model = tfgan.StarGANModel(
            input_data=self.input_data,
            input_data_domain_label=self.input_data_domain_label,
            generated_data=self.generated_data,
            generated_data_domain_target=None,
            reconstructed_data=None,
            discriminator_input_data_source_predication=(
                self.discriminator_input_data_source_predication),
            discriminator_generated_data_source_predication=(
                self.discriminator_generated_data_source_predication),
            discriminator_input_data_domain_predication=None,
            discriminator_generated_data_domain_predication=None,
            generator_variables=None,
            generator_scope=None,
            generator_fn=None,
            discriminator_variables=None,
            discriminator_scope=self.discriminator_scope,
            discriminator_fn=self.discriminator_fn)
Exemplo n.º 3
0
def get_stargan_model():
    """Return a StarGANModel wired to the module's StarGAN model functions.

    Analogous to get_gan_model(). Only ``generated_data`` is computed: it is
    the output of ``stargan_generator_model`` applied to a ones tensor. Every
    other data field is a constant of ones. All of these tensors are created
    inside the 'generator' variable scope. A handle to an (otherwise empty)
    'discriminator' scope is captured first. The model records
    ``stargan_generator_model`` and ``discriminator_model`` as its generator
    and discriminator functions. Both variable lists are None.

    Returns:
      A tfgan.StarGANModel.
    """
    image_shape = [1, 2, 2, 3]
    label_shape = [1, 2]

    # TODO(joelshor): Find a better way of creating a variable scope.
    with tf.compat.v1.variable_scope('discriminator') as dis_scope:
        pass

    with tf.compat.v1.variable_scope('generator') as gen_scope:
        # Creation order matches the original keyword-argument order so the
        # graph built under this scope is unchanged.
        real_images = tf.ones(image_shape)
        real_labels = tf.ones(label_shape)
        fake_images = stargan_generator_model(tf.ones(image_shape), None)
        target_labels = tf.ones(label_shape)
        cycled_images = tf.ones(image_shape)
        real_source_pred = tf.ones([1])
        fake_source_pred = tf.ones([1])
        real_domain_pred = tf.ones(label_shape)
        fake_domain_pred = tf.ones(label_shape)

        stargan_model = tfgan.StarGANModel(
            input_data=real_images,
            input_data_domain_label=real_labels,
            generated_data=fake_images,
            generated_data_domain_target=target_labels,
            reconstructed_data=cycled_images,
            discriminator_input_data_source_predication=real_source_pred,
            discriminator_generated_data_source_predication=fake_source_pred,
            discriminator_input_data_domain_predication=real_domain_pred,
            discriminator_generated_data_domain_predication=fake_domain_pred,
            generator_variables=None,
            generator_scope=gen_scope,
            generator_fn=stargan_generator_model,
            discriminator_variables=None,
            discriminator_scope=dis_scope,
            discriminator_fn=discriminator_model)
    return stargan_model