def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
    """Make a `StarGANModel` from just the generator.

    Only the generator is built. Every discriminator-related field and
    `reconstructed_data` of the returned tuple are `None`.

    Args:
        input_data: Generator input. It is converted with
            `tfgan_train._convert_tensor_or_l_or_d` inside the generator
            scope.
        input_data_domain_label: Domain label passed to the generator as its
            second argument. It is converted the same way. It is recorded on
            the result as `generated_data_domain_target`, while the result's
            `input_data_domain_label` field is left as `None`.
        generator_fn: Callable invoked as
            `generator_fn(input_data, input_data_domain_label)`. If its
            signature declares a `mode` parameter (regular or keyword-only),
            `mode` is bound to `model_fn_lib.ModeKeys.PREDICT`.
        generator_scope: Scope passed to `variable_scope.variable_scope` when
            building the generator.

    Returns:
        A `tfgan_tuples.StarGANModel` holding the converted inputs, the
        generated data, the trainable variables found in the generator scope,
        the entered generator scope object, and the (possibly mode-bound)
        `generator_fn`.
    """
    # If `generator_fn` has an argument `mode`, pass mode to it.
    # `inspect.getargspec` is deprecated and raises ValueError for
    # signatures with annotations or keyword-only parameters. It was also
    # removed in Python 3.11. Prefer `getfullargspec`, and fall back only on
    # interpreters that lack it (Python 2).
    get_argspec = (getattr(inspect, 'getfullargspec', None) or
                   inspect.getargspec)
    argspec = get_argspec(generator_fn)
    param_names = list(argspec.args)
    param_names.extend(getattr(argspec, 'kwonlyargs', None) or [])
    if 'mode' in param_names:
        generator_fn = functools.partial(generator_fn,
                                         mode=model_fn_lib.ModeKeys.PREDICT)
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # pylint:disable=protected-access
        input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
        input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
            input_data_domain_label)
        # pylint:enable=protected-access
        generated_data = generator_fn(input_data, input_data_domain_label)
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

    return tfgan_tuples.StarGANModel(
        input_data=input_data,
        # At prediction time the label fed to the generator is the domain
        # to translate into, so it is stored as the generated-data target.
        input_data_domain_label=None,
        generated_data=generated_data,
        generated_data_domain_target=input_data_domain_label,
        reconstructed_data=None,
        discriminator_input_data_source_predication=None,
        discriminator_generated_data_source_predication=None,
        discriminator_input_data_domain_predication=None,
        discriminator_generated_data_domain_predication=None,
        generator_variables=generator_variables,
        # Store the scope object that was actually entered, not the raw
        # argument (which may be a plain string). This matches the other
        # StarGANModel constructions, which store scope objects.
        generator_scope=gen_scope,
        generator_fn=generator_fn,
        discriminator_variables=None,
        discriminator_scope=None,
        discriminator_fn=None)
def get_dummy_gan_model():
    """Return a `StarGANModel` whose predictions depend on two dummy variables.

    A scalar variable named 'dummy_var' (initialized to 0.0) is created in a
    'generator' variable scope, and a second one in a 'discriminator' scope.

    Data and label fields are all-ones tensors: images are `[1, 2, 2, 3]` and
    domain labels are `[1, 2]`. Real-data discriminator predictions are
    all-ones tensors multiplied by the discriminator variable. Generated-data
    predictions are multiplied by the generator variable and then by the
    discriminator variable. Both variable lists are populated, and
    `generator_fn` / `discriminator_fn` are `None`.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    with variable_scope.variable_scope('generator') as generator_scope:
        generator_var = variable_scope.get_variable(
            'dummy_var', initializer=0.0)
    with variable_scope.variable_scope('discriminator') as discriminator_scope:
        discriminator_var = variable_scope.get_variable(
            'dummy_var', initializer=0.0)

    image_shape = [1, 2, 2, 3]
    label_shape = [1, 2]

    # Tensors are created in the same order as the fields they populate.
    input_images = array_ops.ones(image_shape)
    input_labels = array_ops.ones(label_shape)
    generated_images = array_ops.ones(image_shape)
    target_labels = array_ops.ones(label_shape)
    reconstructed_images = array_ops.ones(image_shape)
    real_source_pred = array_ops.ones([1]) * discriminator_var
    fake_source_pred = array_ops.ones([1]) * generator_var * discriminator_var
    real_domain_pred = array_ops.ones(label_shape) * discriminator_var
    fake_domain_pred = (
        array_ops.ones(label_shape) * generator_var * discriminator_var)

    return tfgan_tuples.StarGANModel(
        input_data=input_images,
        input_data_domain_label=input_labels,
        generated_data=generated_images,
        generated_data_domain_target=target_labels,
        reconstructed_data=reconstructed_images,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=[generator_var],
        generator_scope=generator_scope,
        generator_fn=None,
        discriminator_variables=[discriminator_var],
        discriminator_scope=discriminator_scope,
        discriminator_fn=None)
    def setUp(self):
        """Create the tensors and the partial `StarGANModel` used by the tests.

        The model carries:
          - the real and generated data;
          - the real-data domain labels;
          - the source predictions for both;
          - the discriminator function with its (empty) 'discriminator'
            variable scope.

        Every other field is `None`. The discriminator function and scope are
        also exposed as `self.discriminator_fn` and `self.discriminator_scope`.
        """
        super(StarGANLossWrapperTest, self).setUp()

        image_shape = [1, 2, 2, 3]
        self.input_data = array_ops.ones(image_shape)
        self.input_data_domain_label = constant_op.constant([[0, 1]])
        self.generated_data = array_ops.ones(image_shape)
        self.discriminator_input_data_source_predication = array_ops.ones([1])
        self.discriminator_generated_data_source_predication = (
            array_ops.ones([1]))

        def _discriminator_fn(inputs, num_domains):
            """Differentiable dummy StarGAN discriminator.

            Returns a `(source, domain)` pair:
              - source: the mean over axis 1 of the flattened input;
              - domain: a bias-free, activation-free fully connected
                projection of it onto `num_domains` outputs.
            """
            flat = layers.flatten(inputs)
            source_output = math_ops.reduce_mean(flat, axis=1)
            domain_output = layers.fully_connected(
                inputs=flat,
                num_outputs=num_domains,
                activation_fn=None,
                normalizer_fn=None,
                biases_initializer=None)
            return source_output, domain_output

        # The scope is entered and left immediately, only to capture a scope
        # object for the model.
        with variable_scope.variable_scope('discriminator') as scope:
            pass
        self.discriminator_fn = _discriminator_fn
        self.discriminator_scope = scope

        self.model = namedtuples.StarGANModel(
            input_data=self.input_data,
            input_data_domain_label=self.input_data_domain_label,
            generated_data=self.generated_data,
            generated_data_domain_target=None,
            reconstructed_data=None,
            discriminator_input_data_source_predication=(
                self.discriminator_input_data_source_predication),
            discriminator_generated_data_source_predication=(
                self.discriminator_generated_data_source_predication),
            discriminator_input_data_domain_predication=None,
            discriminator_generated_data_domain_predication=None,
            generator_variables=None,
            generator_scope=None,
            generator_fn=None,
            discriminator_variables=None,
            discriminator_scope=self.discriminator_scope,
            discriminator_fn=self.discriminator_fn)
Exemplo n.º 4
0
def get_stargan_model():
    """Return a `StarGANModel` made of all-ones placeholder tensors.

    Tensor shapes:
      - image-like fields: `[1, 2, 2, 3]`;
      - domain labels and domain predictions: `[1, 2]`;
      - source predictions: `[1]`.

    Both variable lists are `None`. The generator and discriminator functions
    are `stargan_generator_model` and `stargan_discriminator_model`. Each is
    paired with an empty variable scope named 'generator' or 'discriminator'.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    # Each scope is entered and left immediately; only the scope object is
    # kept.
    with variable_scope.variable_scope('generator') as generator_scope:
        pass
    with variable_scope.variable_scope('discriminator') as discriminator_scope:
        pass

    image_shape = [1, 2, 2, 3]
    domain_shape = [1, 2]
    source_shape = [1]

    # Tensors are created in the same order as the fields they populate.
    input_images = array_ops.ones(image_shape)
    input_labels = array_ops.ones(domain_shape)
    generated_images = array_ops.ones(image_shape)
    target_labels = array_ops.ones(domain_shape)
    reconstructed_images = array_ops.ones(image_shape)
    real_source_pred = array_ops.ones(source_shape)
    fake_source_pred = array_ops.ones(source_shape)
    real_domain_pred = array_ops.ones(domain_shape)
    fake_domain_pred = array_ops.ones(domain_shape)

    return namedtuples.StarGANModel(
        input_data=input_images,
        input_data_domain_label=input_labels,
        generated_data=generated_images,
        generated_data_domain_target=target_labels,
        reconstructed_data=reconstructed_images,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=None,
        generator_scope=generator_scope,
        generator_fn=stargan_generator_model,
        discriminator_variables=None,
        discriminator_scope=discriminator_scope,
        discriminator_fn=stargan_discriminator_model)