Пример #1
0
def get_callable_acgan_model():
    """Build a dummy `ACGANModel` on top of the callable GAN test model.

    The base GAN fields are taken unchanged, in order, from
    `get_callable_gan_model()`. The three AC-GAN-specific fields are filled
    with `array_ops.one_hot` tensors of depth 10, built from 3 class indices
    each.

    Returns:
      A `namedtuples.ACGANModel`.
    """
    # Evaluated first, exactly as the original `*get_callable_gan_model()`
    # argument was, so graph ops are created in the same order.
    base_gan_fields = get_callable_gan_model()

    num_classes = 10
    # The real/generated classification "logits" match the labels except in
    # their last index (3 and 4 instead of 2). NOTE(review): presumably this is
    # so losses computed from them are non-trivial and distinguishable;
    # confirm against the tests that use this helper.
    labels = array_ops.one_hot([0, 1, 2], num_classes)
    real_class_logits = array_ops.one_hot([0, 1, 3], num_classes)
    gen_class_logits = array_ops.one_hot([0, 1, 4], num_classes)

    return namedtuples.ACGANModel(
        *base_gan_fields,
        one_hot_labels=labels,
        discriminator_real_classification_logits=real_class_logits,
        discriminator_gen_classification_logits=gen_class_logits)
Пример #2
0
def acgan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        one_hot_labels,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        check_shapes=True):
    """Returns an ACGANModel containing all the pieces needed for ACGAN training.

    `acgan_model` is the same as `gan_model`, except that the discriminator
    also outputs logits that classify its input (real or generated). Because
    of that, the model carries an explicit `one_hot_labels` field. The
    `discriminator_fn` must also return a 2-tuple holding the real/fake logits
    and the classification logits.

    See https://arxiv.org/abs/1610.09585 for more details.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a tuple consisting of two Tensors:
          (1) real/fake logits in the range [-inf, inf]
          (2) classification logits in the range [-inf, inf]
      real_data: A Tensor representing the real data. It is converted with
        `ops.convert_to_tensor` before use.
      generator_inputs: A Tensor or list of Tensors to the generator. In the
        vanilla GAN case, this might be a single noise Tensor. In the
        conditional GAN case, this might be the generator's conditioning.
      one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by
        acgan_loss. Stored in the returned model as-is.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created.
      check_shapes: If `True`, check that the generator produces Tensors whose
        shape is compatible with the real data. Otherwise, skip this check.

    Returns:
      An ACGANModel namedtuple.

    Raises:
      ValueError: If the generator outputs a Tensor whose shape isn't
        compatible with `real_data` (only when `check_shapes` is True).
      TypeError: If the discriminator does not output a tuple consisting of
        (discrimination logits, classification logits).
    """
    # NOTE: the statements below run in graph-construction order (scope entry,
    # op and variable creation). Their relative order is significant and must
    # not be shuffled.

    # Generator: normalize its inputs, then build it inside its own scope.
    with variable_scope.variable_scope(generator_scope) as gen_var_scope:
        generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
        generated_data = generator_fn(generator_inputs)

    # Discriminator, first pass, on generated data. This scope is entered
    # without `reuse`, so it is where discriminator variables get created.
    with variable_scope.variable_scope(discriminator_scope) as dis_var_scope:
        gen_pass = _validate_acgan_discriminator_outputs(
            discriminator_fn(generated_data, generator_inputs))
        gen_outputs, gen_class_logits = gen_pass

    # Discriminator, second pass, on real data, reusing the same variables.
    with variable_scope.variable_scope(dis_var_scope, reuse=True):
        real_data = ops.convert_to_tensor(real_data)
        real_pass = _validate_acgan_discriminator_outputs(
            discriminator_fn(real_data, generator_inputs))
        real_outputs, real_class_logits = real_pass

    # `and` short-circuits, so `.shape` is only touched when checking is on.
    if check_shapes and not generated_data.shape.is_compatible_with(
            real_data.shape):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (generated_data.shape, real_data.shape))

    # Trainable variables owned by each sub-network, looked up by scope.
    gen_vars = variables_lib.get_trainable_variables(gen_var_scope)
    dis_vars = variables_lib.get_trainable_variables(dis_var_scope)

    # Positional construction: argument order must match ACGANModel's fields.
    return namedtuples.ACGANModel(
        generator_inputs,
        generated_data,
        gen_vars,
        gen_var_scope,
        generator_fn,
        real_data,
        real_outputs,
        gen_outputs,
        dis_vars,
        dis_var_scope,
        discriminator_fn,
        one_hot_labels,
        real_class_logits,
        gen_class_logits)