Exemplo n.º 1
0
def infogan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # Real data and conditioning.
        real_data,
        unstructured_generator_inputs,
        structured_generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator'):
    """Build an InfoGAN model and return its outputs and variables.

    See https://arxiv.org/abs/1606.03657 for more details.

    The generator is built once under `generator_scope`. The discriminator is
    first built under `discriminator_scope` on the generated data. It is then
    applied a second time to `real_data`, in the same scope with
    `reuse=True`, so both discriminator passes share one set of variables.
    Both discriminator passes receive the full generator inputs as their
    second argument.

    Args:
      generator_fn: A python callable that takes the generator inputs and
        returns the outputs of the GAN generator. The generator inputs are the
        unstructured inputs followed by the structured inputs.
      discriminator_fn: A python callable that takes (`real_data` or generated
        data, generator inputs) and returns a 2-tuple
        `(logits, distribution_list)`. `logits` are in the range [-inf, inf].
        `distribution_list` is a list of TensorFlow distributions representing
        the predicted noise distribution of the ith structured noise input.
      real_data: A Tensor representing the real data. It is passed through
        `ops.convert_to_tensor` before use.
      unstructured_generator_inputs: A list of Tensors to the generator.
        These tensors represent the unstructured noise or conditioning.
        The value is normalized by `_convert_tensor_or_l_or_d` (presumably
        also accepting a single Tensor or a dict -- confirm against that
        helper).
      structured_generator_inputs: A list of Tensors to the generator.
        These tensors must have high mutual information with the recognizer.
        The value is normalized the same way as the unstructured inputs.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created.

    Returns:
      An InfoGANModel namedtuple. Its discriminator-function entry is a
      wrapper that returns only the logits (element 0) of `discriminator_fn`.
      The predicted distributions stored in the tuple are the ones produced
      for the generated data.

    Raises:
      ValueError: If the generator output's shape is not compatible with the
        shape of `real_data`.
      ValueError: If the discriminator output is malformed, e.g. it does not
        unpack into exactly two values, or (presumably) the predicted
        distributions are rejected by `_validate_distributions`.
    """
    # Create models
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # Normalize both input groups so they can be concatenated; the
        # structured inputs are appended after the unstructured ones.
        unstructured_generator_inputs = _convert_tensor_or_l_or_d(
            unstructured_generator_inputs)
        structured_generator_inputs = _convert_tensor_or_l_or_d(
            structured_generator_inputs)
        generator_inputs = (unstructured_generator_inputs +
                            structured_generator_inputs)
        generated_data = generator_fn(generator_inputs)
    # First discriminator pass (on generated data) creates the discriminator
    # variables.
    with variable_scope.variable_scope(discriminator_scope) as disc_scope:
        dis_gen_outputs, predicted_distributions = discriminator_fn(
            generated_data, generator_inputs)
    # Check the predicted distributions against the structured inputs they
    # are supposed to model.
    _validate_distributions(predicted_distributions,
                            structured_generator_inputs)
    # Re-enter the same scope with reuse=True so the real-data pass shares the
    # variables created above. Its predicted distributions are discarded.
    with variable_scope.variable_scope(disc_scope, reuse=True):
        real_data = ops.convert_to_tensor(real_data)
        dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

    # `is_compatible_with` treats unknown dimensions as compatible, so this
    # check only rejects definite mismatches. It runs after graph construction.
    if not generated_data.get_shape().is_compatible_with(
            real_data.get_shape()):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

    # Get model-specific variables.
    generator_variables = variables_lib.get_trainable_variables(gen_scope)
    discriminator_variables = variables_lib.get_trainable_variables(disc_scope)

    # Arguments are positional: their order must match the InfoGANModel field
    # order defined in `namedtuples` (not visible here).
    return namedtuples.InfoGANModel(
        generator_inputs,
        generated_data,
        generator_variables,
        gen_scope,
        generator_fn,
        real_data,
        dis_real_outputs,
        dis_gen_outputs,
        discriminator_variables,
        disc_scope,
        # Plain-GAN consumers expect a discriminator returning only logits.
        lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
        structured_generator_inputs,
        predicted_distributions)
Exemplo n.º 2
0
def get_callable_infogan_model():
    """Build an InfoGANModel test fixture on top of the callable GAN model.

    The fields returned by `get_callable_gan_model()` fill the leading
    positions of the tuple. The InfoGAN-specific entries are:

    - one constant structured generator input;
    - one `Categorical` built from `[1.0]` as the predicted distribution;
    - `infogan_discriminator_model` as the discriminator-and-aux function.
    """
    # Evaluate in the same order as a single inline call would, so graph ops
    # are created (and named) identically.
    gan_fields = get_callable_gan_model()
    structured_inputs = [constant_op.constant(0)]
    distributions = [categorical.Categorical([1.0])]
    return namedtuples.InfoGANModel(
        *gan_fields,
        structured_generator_inputs=structured_inputs,
        predicted_distributions=distributions,
        discriminator_and_aux_fn=infogan_discriminator_model)
Exemplo n.º 3
0
def get_infogan_model():
    """Build an InfoGANModel test fixture on top of the base GAN model.

    The fields returned by `get_gan_model()` fill the leading positions of the
    tuple. The InfoGAN-specific entries are one constant structured generator
    input and one `Categorical` built from `[1.0]` as the predicted
    distribution.
    """
    # Keep the original evaluation order so graph ops are created (and named)
    # identically.
    base_fields = get_gan_model()
    structured_inputs = [constant_op.constant(0)]
    predicted = [categorical.Categorical([1.0])]
    return namedtuples.InfoGANModel(
        *base_fields,
        structured_generator_inputs=structured_inputs,
        predicted_distributions=predicted)