Beispiel #1
0
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
    """Make a `StarGANModel` from just the generator.

    Only the generator subgraph is built; every discriminator-related field
    and `reconstructed_data` of the returned model are None.

    Args:
      input_data: Real inputs; converted with
        `tfgan_train._convert_tensor_or_l_or_d` before being fed to the
        generator.
      input_data_domain_label: Domain label passed to the generator as its
        target; converted the same way as `input_data`.
      generator_fn: Callable invoked as
        `generator_fn(input_data, input_data_domain_label)`. If it declares a
        `mode` parameter (positional or keyword-only), it is wrapped with
        `functools.partial(..., mode=ModeKeys.PREDICT)`.
      generator_scope: Variable scope (or scope name) in which the generator
        is built.

    Returns:
      A `tfgan_tuples.StarGANModel`.
      - `generated_data_domain_target` holds the supplied
        `input_data_domain_label`, and `input_data_domain_label` is None.
      - `generator_fn` is the (possibly `mode`-wrapped) callable that was used.
      - `generator_variables` are the trainable variables in the generator
        scope.
    """
    # `inspect.getargspec` is deprecated and was removed in Python 3.11. On
    # earlier Python 3 versions it also raises ValueError for callables with
    # keyword-only parameters or annotations. Prefer `getfullargspec` and fall
    # back only where it does not exist (Python 2).
    get_argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    argspec = get_argspec(generator_fn)
    # `getfullargspec` lists keyword-only parameters separately from `args`.
    accepted_args = list(argspec.args) + list(
        getattr(argspec, 'kwonlyargs', None) or [])
    # If `generator_fn` has an argument `mode`, pass mode to it.
    if 'mode' in accepted_args:
        generator_fn = functools.partial(generator_fn,
                                         mode=model_fn_lib.ModeKeys.PREDICT)
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # pylint:disable=protected-access
        input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
        input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
            input_data_domain_label)
        # pylint:enable=protected-access
        generated_data = generator_fn(input_data, input_data_domain_label)
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

    return tfgan_tuples.StarGANModel(
        input_data=input_data,
        input_data_domain_label=None,
        generated_data=generated_data,
        # At prediction time the supplied label is the generation target.
        generated_data_domain_target=input_data_domain_label,
        reconstructed_data=None,
        discriminator_input_data_source_predication=None,
        discriminator_generated_data_source_predication=None,
        discriminator_input_data_domain_predication=None,
        discriminator_generated_data_domain_predication=None,
        generator_variables=generator_variables,
        generator_scope=generator_scope,
        generator_fn=generator_fn,
        discriminator_variables=None,
        discriminator_scope=None,
        discriminator_fn=None)
def get_stargan_model():
    """Build a hand-assembled `StarGANModel` of constant tensors for tests.

    Similar to get_gan_model().
    - Every tensor field is an `array_ops.ones` tensor of a small fixed shape.
    - The exception is `generated_data`, which comes from calling
      `stargan_generator_model` (with a None target) inside the 'generator'
      variable scope.
    - `generator_variables` and `discriminator_variables` are None.

    Returns:
      A `namedtuples.StarGANModel` whose scopes are the 'generator' and
      'discriminator' variable scopes.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    # The scope is entered only to capture its VariableScope object.
    with variable_scope.variable_scope('discriminator') as disc_vscope:
        pass

    image_shape = [1, 2, 2, 3]
    label_shape = [1, 2]
    with variable_scope.variable_scope('generator') as gen_vscope:
        real_images = array_ops.ones(image_shape)
        real_labels = array_ops.ones(label_shape)
        fake_images = stargan_generator_model(
            array_ops.ones(image_shape), None)
        target_labels = array_ops.ones(label_shape)
        cycled_images = array_ops.ones(image_shape)
        real_source_pred = array_ops.ones([1])
        fake_source_pred = array_ops.ones([1])
        real_domain_pred = array_ops.ones(label_shape)
        fake_domain_pred = array_ops.ones(label_shape)
        return namedtuples.StarGANModel(
            input_data=real_images,
            input_data_domain_label=real_labels,
            generated_data=fake_images,
            generated_data_domain_target=target_labels,
            reconstructed_data=cycled_images,
            discriminator_input_data_source_predication=real_source_pred,
            discriminator_generated_data_source_predication=fake_source_pred,
            discriminator_input_data_domain_predication=real_domain_pred,
            discriminator_generated_data_domain_predication=fake_domain_pred,
            generator_variables=None,
            generator_scope=gen_vscope,
            generator_fn=stargan_generator_model,
            discriminator_variables=None,
            discriminator_scope=disc_vscope,
            discriminator_fn=discriminator_model)
Beispiel #3
0
def get_dummy_gan_model():
    """Build a minimal `StarGANModel` backed by one dummy variable per network.

    Similar to get_gan_model().
    - A scalar 'dummy_var' (initialized to 0.0) is created in each of the
      'generator' and 'discriminator' variable scopes.
    - Predictions on the input data are constant tensors multiplied by the
      discriminator variable.
    - Predictions on generated data are multiplied by both the generator and
      discriminator variables.
    - All other tensor fields are plain `array_ops.ones` tensors.

    Returns:
      A `tfgan_tuples.StarGANModel` listing each dummy variable as that
      network's variables.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    with variable_scope.variable_scope('generator') as g_vscope:
        g_var = variable_scope.get_variable('dummy_var', initializer=0.0)
    with variable_scope.variable_scope('discriminator') as d_vscope:
        d_var = variable_scope.get_variable('dummy_var', initializer=0.0)

    image_shape = [1, 2, 2, 3]
    label_shape = [1, 2]
    real_images = array_ops.ones(image_shape)
    real_labels = array_ops.ones(label_shape)
    fake_images = array_ops.ones(image_shape)
    target_labels = array_ops.ones(label_shape)
    cycled_images = array_ops.ones(image_shape)
    real_source_pred = array_ops.ones([1]) * d_var
    fake_source_pred = array_ops.ones([1]) * g_var * d_var
    real_domain_pred = array_ops.ones(label_shape) * d_var
    fake_domain_pred = array_ops.ones(label_shape) * g_var * d_var

    return tfgan_tuples.StarGANModel(
        input_data=real_images,
        input_data_domain_label=real_labels,
        generated_data=fake_images,
        generated_data_domain_target=target_labels,
        reconstructed_data=cycled_images,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=[g_var],
        generator_scope=g_vscope,
        generator_fn=None,
        discriminator_variables=[d_var],
        discriminator_scope=d_vscope,
        discriminator_fn=None)
    def setUp(self):
        """Build fixture tensors, a dummy discriminator and a `StarGANModel`."""
        super(StarGANLossWrapperTest, self).setUp()

        image_shape = [1, 2, 2, 3]
        self.input_data = array_ops.ones(image_shape)
        self.input_data_domain_label = constant_op.constant([[0, 1]])
        self.generated_data = array_ops.ones(image_shape)
        self.discriminator_input_data_source_predication = array_ops.ones([1])
        self.discriminator_generated_data_source_predication = (
            array_ops.ones([1]))

        def _discriminator_fn(inputs, num_domains):
            """Differentiable dummy discriminator for StarGAN.

            Returns a (source_prediction, domain_prediction) pair.
            - source_prediction is the per-example mean of the flattened
              inputs.
            - domain_prediction is a bias-free, activation-free
              fully-connected projection to `num_domains` outputs.
            """
            flat_inputs = layers.flatten(inputs)
            source_pred = math_ops.reduce_mean(flat_inputs, axis=1)
            domain_pred = layers.fully_connected(
                inputs=flat_inputs,
                num_outputs=num_domains,
                activation_fn=None,
                normalizer_fn=None,
                biases_initializer=None)
            return source_pred, domain_pred

        # The scope is entered only to capture its VariableScope object.
        with variable_scope.variable_scope('discriminator') as disc_vscope:
            pass
        self.discriminator_fn = _discriminator_fn
        self.discriminator_scope = disc_vscope

        self.model = namedtuples.StarGANModel(
            input_data=self.input_data,
            input_data_domain_label=self.input_data_domain_label,
            generated_data=self.generated_data,
            generated_data_domain_target=None,
            reconstructed_data=None,
            discriminator_input_data_source_predication=(
                self.discriminator_input_data_source_predication),
            discriminator_generated_data_source_predication=(
                self.discriminator_generated_data_source_predication),
            discriminator_input_data_domain_predication=None,
            discriminator_generated_data_domain_predication=None,
            generator_variables=None,
            generator_scope=None,
            generator_fn=None,
            discriminator_variables=None,
            discriminator_scope=disc_vscope,
            discriminator_fn=_discriminator_fn)
Beispiel #5
0
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
    """Returns StarGAN model outputs and variables.

    See https://arxiv.org/abs/1711.09020 for more details.

    Args:
      generator_fn: A python lambda that takes `inputs` and `targets` as inputs
        and returns `generated_data`, the transformed version of `inputs`
        based on `targets`. `inputs` has shape (n, h, w, c), `targets` has
        shape (n, num_domains), and `generated_data` has the same shape as
        `inputs`.
      discriminator_fn: A python lambda that takes `inputs` and `num_domains`
        as inputs and returns a tuple (`source_prediction`,
        `domain_prediction`).
        - `source_prediction` is the discriminator's source (real/generated)
          prediction and has shape (n).
        - `domain_prediction` is its domain prediction/classification and has
          shape (n, num_domains).
      input_data: Tensor or list of tensors of shape (n, h, w, c) representing
        the real input images. A list/tuple is concatenated along axis 0.
      input_data_domain_label: Tensor or list of tensors of shape
        (batch_size, num_domains) representing the domain labels of the real
        images. A list/tuple is concatenated along axis 0.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created.

    Returns:
      A StarGANModel namedtuple holding the tensors needed to compute the loss.

    Raises:
      ValueError: If the shape of `input_data_domain_label` is not rank 2 or is
        not fully defined in every dimension.
    """

    def _concat_if_sequence(value):
        # Merge a list/tuple of tensors into a single tensor along axis 0;
        # any other value is returned unchanged.
        if not isinstance(value, (list, tuple)):
            return value
        return array_ops.concat(
            [ops.convert_to_tensor(item) for item in value], 0)

    # Convert both arguments to tensors first, then collapse sequences.
    input_data = _convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = _convert_tensor_or_l_or_d(
        input_data_domain_label)
    input_data = _concat_if_sequence(input_data)
    input_data_domain_label = _concat_if_sequence(input_data_domain_label)

    # The label tensor fixes batch_size and num_domains.
    label_shape = input_data_domain_label.shape
    label_shape.assert_has_rank(2)
    label_shape.assert_is_fully_defined()
    batch_size, num_domains = label_shape.as_list()

    # Generator pass 1: translate input_data to random target domains.
    with variable_scope.variable_scope(generator_scope) as gen_vscope:
        generated_data_domain_target = _generate_stargan_random_domain_target(
            batch_size, num_domains)
        generated_data = generator_fn(input_data, generated_data_domain_target)

    # Generator pass 2 (reused scope): translate back to the original labels.
    with variable_scope.variable_scope(gen_vscope, reuse=True):
        reconstructed_data = generator_fn(generated_data,
                                          input_data_domain_label)

    # Discriminator pass 1 runs in `discriminator_scope` on generated data.
    with variable_scope.variable_scope(discriminator_scope) as disc_vscope:
        fake_source_pred, fake_domain_pred = discriminator_fn(
            generated_data, num_domains)

    # Discriminator pass 2 reuses that scope on the real input data.
    with variable_scope.variable_scope(disc_vscope, reuse=True):
        real_source_pred, real_domain_pred = discriminator_fn(
            input_data, num_domains)

    # Trainable variables of each network, looked up by scope.
    gen_vars = variables_lib.get_trainable_variables(gen_vscope)
    disc_vars = variables_lib.get_trainable_variables(disc_vscope)

    return namedtuples.StarGANModel(
        input_data=input_data,
        input_data_domain_label=input_data_domain_label,
        generated_data=generated_data,
        generated_data_domain_target=generated_data_domain_target,
        reconstructed_data=reconstructed_data,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=gen_vars,
        generator_scope=gen_vscope,
        generator_fn=generator_fn,
        discriminator_variables=disc_vars,
        discriminator_scope=disc_vscope,
        discriminator_fn=discriminator_fn)