def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=tf.estimator.ModeKeys.PREDICT)
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = contrib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
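
# --- Usage sketch (not part of the original snippet) ---
# A minimal, hypothetical example of driving the helper above at predict time.
# It assumes the module-level names used in the function (tfgan_train,
# tfgan_tuples, contrib, inspect, functools) are importable as in the
# surrounding TF-GAN estimator module, and that the graph is built in TF1
# (non-eager) mode. The toy generator, placeholder shapes, and five-domain
# label are illustrative only.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # The helper builds a TF1-style graph.


def toy_stargan_generator(inputs, target_domain_label, mode=None):
  """Toy generator: adds a per-channel bias derived from the target domain."""
  del mode  # In a real model, `mode` would switch train/eval behavior.
  bias = tf.compat.v1.layers.dense(target_domain_label, units=3)
  return tf.nn.tanh(inputs + bias[:, tf.newaxis, tf.newaxis, :])


images = tf.compat.v1.placeholder(tf.float32, [None, 64, 64, 3])
target_labels = tf.compat.v1.placeholder(tf.float32, [None, 5])  # 5 domains.

stargan_model = _make_prediction_gan_model(
    images, target_labels, toy_stargan_generator, generator_scope='Generator')
# stargan_model.generated_data holds the translated images; the discriminator
# fields are all None, as set above.
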
Example #2
def make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
    """Make a `GANModel` from just the generator."""
    # If `generator_fn` has an argument `mode`, pass mode to it.
    if 'mode' in inspect.getargspec(generator_fn).args:
        generator_fn = functools.partial(generator_fn,
                                         mode=tf.estimator.ModeKeys.PREDICT)
    with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
        generator_inputs = tfgan_train._convert_tensor_or_l_or_d(
            generator_inputs)  # pylint:disable=protected-access
        generated_data = generator_fn(generator_inputs)

    generator_variables = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, gen_scope.name)

    return tfgan_tuples.GANModel(generator_inputs,
                                 generated_data,
                                 generator_variables,
                                 gen_scope,
                                 generator_fn,
                                 real_data=None,
                                 discriminator_real_outputs=None,
                                 discriminator_gen_outputs=None,
                                 discriminator_variables=None,
                                 discriminator_scope=None,
                                 discriminator_fn=None)
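
# --- Usage sketch (not part of the original snippet) ---
# A hypothetical call to the helper above for an unconditional GAN. It assumes
# tfgan_train and tfgan_tuples are in scope as in the surrounding module and
# that the graph is built in TF1 (non-eager) mode; the toy generator and the
# [16, 64] noise shape are illustrative only.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # The helper builds a TF1-style graph.


def toy_generator(noise, mode=None):
    """Toy generator: one dense layer producing flattened 28x28 'images'."""
    del mode  # Only layers that behave differently at predict time need this.
    return tf.compat.v1.layers.dense(noise, units=28 * 28, activation=tf.nn.tanh)


noise = tf.random.normal([16, 64])
gan_model = make_prediction_gan_model(
    noise, toy_generator, generator_scope='Generator')
# gan_model.generated_data is a [16, 784] Tensor; real_data and every
# discriminator field are None.
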
Example #3
def cut_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        feat_discriminator_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        feat_discriminator_scope='FeatDiscriminator',
        # Options.
        check_shapes=True):
    """Returns GAN model outputs and variables.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated_data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      feat_discriminator_fn: A python lambda defining the feature discriminator;
        accepted here but not used in the graph built by this function.
      real_data: A Tensor representing the real data.
      generator_inputs: A Tensor or list of Tensors to the generator. In the
        vanilla GAN case, this might be a single noise Tensor. In the conditional
        GAN case, this might be the generator's conditioning.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created.
      feat_discriminator_scope: Optional feature discriminator variable scope;
        accepted here but not used in the graph built by this function.
      check_shapes: If `True`, check that generator produces Tensors that are the
        same shape as real data. Otherwise, skip this check.

    Returns:
      A CUTModel namedtuple.

    Raises:
      ValueError: If the generator outputs a Tensor that isn't the same shape as
        `real_data`.
      ValueError: If TF is executing eagerly.
    """
    if tf.executing_eagerly():
        raise ValueError('`cut_model` doesn\'t work when executing eagerly.')
    # Create models
    with tf.compat.v1.variable_scope(
            generator_scope, reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
        generated_data = generator_fn(generator_inputs)
    with tf.compat.v1.variable_scope(
            discriminator_scope, reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
        discriminator_gen_outputs = discriminator_fn(generated_data,
                                                     generator_inputs)
    with tf.compat.v1.variable_scope(dis_scope, reuse=True):
        real_data = _convert_tensor_or_l_or_d(real_data)
        discriminator_real_outputs = discriminator_fn(real_data,
                                                      generator_inputs)

    if check_shapes:
        if not generated_data.shape.is_compatible_with(real_data.shape):
            raise ValueError(
                'Generator output shape (%s) must be the same shape as real data '
                '(%s).' % (generated_data.shape, real_data.shape))

    # Get model-specific variables.
    generator_variables = get_trainable_variables(gen_scope)
    discriminator_variables = get_trainable_variables(dis_scope)

    return CUTModel(generator_inputs, generated_data, generator_variables,
                    gen_scope, generator_fn, real_data,
                    discriminator_real_outputs, discriminator_gen_outputs,
                    discriminator_variables, dis_scope, discriminator_fn)
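
# --- Usage sketch (not part of the original snippet) ---
# A hypothetical call to cut_model with toy networks. It assumes the
# module-level names used above (_convert_tensor_or_l_or_d,
# get_trainable_variables, CUTModel) are in scope and that the graph is built
# in TF1 (non-eager) mode. As in the function body above, the feature
# discriminator is accepted but takes no part in the graph built here; a real
# CUT setup would additionally wire up the patchwise contrastive (NCE) loss.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # cut_model raises if run eagerly.


def toy_generator(inputs):
    # 'same' padding keeps the output shape equal to the input shape, so the
    # check_shapes comparison against real_data passes.
    return tf.compat.v1.layers.conv2d(
        inputs, filters=3, kernel_size=3, padding='same', activation=tf.nn.tanh)


def toy_discriminator(data, unused_conditioning):
    net = tf.compat.v1.layers.conv2d(data, filters=8, kernel_size=3, strides=2)
    return tf.compat.v1.layers.dense(tf.compat.v1.layers.flatten(net), units=1)


def toy_feat_discriminator(features):
    return tf.compat.v1.layers.dense(features, units=64)


source_images = tf.compat.v1.placeholder(tf.float32, [None, 128, 128, 3])
target_images = tf.compat.v1.placeholder(tf.float32, [None, 128, 128, 3])

cut = cut_model(
    generator_fn=toy_generator,
    discriminator_fn=toy_discriminator,
    feat_discriminator_fn=toy_feat_discriminator,
    real_data=target_images,
    generator_inputs=source_images)
# cut.generated_data and cut.discriminator_gen_outputs feed the GAN losses.
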