def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Builds an InfoGAN and returns its outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A Python callable that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A Python callable that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of Tensorflow distributions representing the predicted noise
      distribution of the ith structured noise.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator.
      These tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator.
      These tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Generator: normalize both groups of noise, then feed their concatenation.
  with variable_scope.variable_scope(generator_scope) as gen_vs:
    unstructured = _convert_tensor_or_l_or_d(unstructured_generator_inputs)
    structured = _convert_tensor_or_l_or_d(structured_generator_inputs)
    gen_inputs = unstructured + structured
    fake_data = generator_fn(gen_inputs)

  # Discriminator applied to generated data. Besides logits it returns the
  # recognizer's predicted distributions for the structured noise.
  with variable_scope.variable_scope(discriminator_scope) as disc_vs:
    fake_logits, predicted_dists = discriminator_fn(fake_data, gen_inputs)
  _validate_distributions(predicted_dists, structured)

  # Discriminator applied to real data, re-entering the same scope with
  # reuse=True so it shares the variables from the call above.
  with variable_scope.variable_scope(disc_vs, reuse=True):
    real_tensor = ops.convert_to_tensor(real_data)
    real_logits, _ = discriminator_fn(real_tensor, gen_inputs)

  fake_shape = fake_data.get_shape()
  real_shape = real_tensor.get_shape()
  if not fake_shape.is_compatible_with(real_shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake_shape, real_shape))

  # Collect the trainable variables belonging to each network's scope.
  gen_vars = variables_lib.get_trainable_variables(gen_vs)
  disc_vars = variables_lib.get_trainable_variables(disc_vs)

  def logits_only_discriminator(data, inputs):
    # Drop the predicted distributions so the stored discriminator conforms
    # to the non-InfoGAN API, which returns logits only.
    return discriminator_fn(data, inputs)[0]

  return namedtuples.InfoGANModel(
      gen_inputs,
      fake_data,
      gen_vars,
      gen_vs,
      generator_fn,
      real_tensor,
      real_logits,
      fake_logits,
      disc_vars,
      disc_vs,
      logits_only_discriminator,
      structured,
      predicted_dists)
def get_callable_infogan_model():
  """Returns an InfoGANModel whose base fields come from `get_callable_gan_model()`.

  The InfoGAN-specific fields are single-element lists: one constant Tensor as
  the structured input and one Categorical over a single category as the
  predicted distribution. `infogan_discriminator_model` is used as the
  discriminator-with-auxiliary-outputs function.
  """
  # Build the base model first so op creation order matches the original.
  gan_fields = get_callable_gan_model()
  infogan_extras = dict(
      structured_generator_inputs=[constant_op.constant(0)],
      predicted_distributions=[categorical.Categorical([1.0])],
      discriminator_and_aux_fn=infogan_discriminator_model)
  return namedtuples.InfoGANModel(*gan_fields, **infogan_extras)
def get_infogan_model():
  """Returns an InfoGANModel whose base fields come from `get_gan_model()`.

  The InfoGAN-specific fields are single-element lists: one constant Tensor as
  the structured input and one Categorical over a single category as the
  predicted distribution.
  """
  # Build the base model first so op creation order matches the original.
  gan_fields = get_gan_model()
  infogan_extras = dict(
      structured_generator_inputs=[constant_op.constant(0)],
      predicted_distributions=[categorical.Categorical([1.0])])
  return namedtuples.InfoGANModel(*gan_fields, **infogan_extras)