def get_callable_acgan_model():
    """Returns an `ACGANModel` fixture built on top of the callable GAN model.

    The positional fields come from unpacking `get_callable_gan_model()`; the
    three ACGAN-specific fields are fixed one-hot tensors of depth 10 that
    differ only in the index of their last row.
    """
    # Evaluate the base model first so ops are created in the same order as
    # when everything is written as one call expression.
    base_fields = get_callable_gan_model()
    num_classes = 10
    labels = array_ops.one_hot([0, 1, 2], num_classes)
    real_class_logits = array_ops.one_hot([0, 1, 3], num_classes)
    gen_class_logits = array_ops.one_hot([0, 1, 4], num_classes)
    return namedtuples.ACGANModel(
        *base_fields,
        one_hot_labels=labels,
        discriminator_real_classification_logits=real_class_logits,
        discriminator_gen_classification_logits=gen_class_logits)
def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    check_shapes=True):
    """Builds an `ACGANModel` holding every piece needed for ACGAN training.

    This mirrors `gan_model`, except that the discriminator also emits
    classification logits for its input (real or generated). For that reason
    the caller must supply `one_hot_labels`, and `discriminator_fn` must return
    a 2-tuple of (real/fake logits, classification logits).

    See https://arxiv.org/abs/1610.09585 for more details.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` and returns
        the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`, and returns a tuple of two Tensors:
          (1) real/fake logits in the range [-inf, inf]
          (2) classification logits in the range [-inf, inf]
      real_data: A Tensor representing the real data.
      generator_inputs: A Tensor or list of Tensors fed to the generator. For a
        vanilla GAN this might be a single noise Tensor; for a conditional GAN
        it might be the generator's conditioning.
      one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
        acgan_loss.
      generator_scope: Optional generator variable scope. Useful for reusing a
        subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful for
        reusing a subgraph that has already been created.
      check_shapes: If `True`, verify that the generator output is
        shape-compatible with the real data. Otherwise, skip this check.

    Returns:
      An ACGANModel namedtuple.

    Raises:
      ValueError: If the generator outputs a Tensor that isn't the same shape
        as `real_data`.
      TypeError: If the discriminator does not output a tuple consisting of
        (discrimination logits, classification logits).
    """
    # Generator pass: normalise the inputs, then produce the fake samples.
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
        generated_data = generator_fn(generator_inputs)

    # Discriminator pass over generated data; this opens the scope that the
    # real-data pass below re-enters.
    with variable_scope.variable_scope(discriminator_scope) as dis_scope:
        gen_pair = _validate_acgan_discriminator_outputs(
            discriminator_fn(generated_data, generator_inputs))
        discriminator_gen_outputs, discriminator_gen_classification_logits = (
            gen_pair)

    # Discriminator pass over real data, re-entering the same scope with
    # reuse=True.
    with variable_scope.variable_scope(dis_scope, reuse=True):
        real_data = ops.convert_to_tensor(real_data)
        real_pair = _validate_acgan_discriminator_outputs(
            discriminator_fn(real_data, generator_inputs))
        discriminator_real_outputs, discriminator_real_classification_logits = (
            real_pair)

    # Optional sanity check that fake and real samples have compatible shapes.
    if check_shapes and not generated_data.shape.is_compatible_with(
            real_data.shape):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (generated_data.shape, real_data.shape))

    # Collect the trainable variables that live under each sub-model's scope.
    generator_variables = variables_lib.get_trainable_variables(gen_scope)
    discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

    return namedtuples.ACGANModel(
        generator_inputs,
        generated_data,
        generator_variables,
        gen_scope,
        generator_fn,
        real_data,
        discriminator_real_outputs,
        discriminator_gen_outputs,
        discriminator_variables,
        dis_scope,
        discriminator_fn,
        one_hot_labels,
        discriminator_real_classification_logits,
        discriminator_gen_classification_logits)