def get_dummy_gan_model():
  """Builds a small `GANModel` fixture with one dummy variable per network.

  A scalar variable named 'dummy_var' is created under a 'generator' and a
  'discriminator' variable scope. The discriminator outputs are constant
  tensors multiplied by those variables, so they depend on them.

  Returns:
    A `tfgan_tuples.GANModel` with constant data tensors, the two dummy
    variables, their scopes, and no generator inputs or model functions.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('generator') as generator_vs:
    generator_dummy = variable_scope.get_variable('dummy_var', initializer=0.0)
  with variable_scope.variable_scope('discriminator') as discriminator_vs:
    discriminator_dummy = variable_scope.get_variable(
        'dummy_var', initializer=0.0)

  # Tensors are created in the same order as the fields they populate.
  fake_data = array_ops.ones([3, 4])
  true_data = array_ops.zeros([3, 4])
  dis_real_out = array_ops.ones([1, 2, 3]) * discriminator_dummy
  dis_gen_out = (
      array_ops.ones([1, 2, 3]) * generator_dummy * discriminator_dummy)

  return tfgan_tuples.GANModel(
      generator_inputs=None,
      generated_data=fake_data,
      generator_variables=[generator_dummy],
      generator_scope=generator_vs,
      generator_fn=None,
      real_data=true_data,
      discriminator_real_outputs=dis_real_out,
      discriminator_gen_outputs=dis_gen_out,
      discriminator_variables=[discriminator_dummy],
      discriminator_scope=discriminator_vs,
      discriminator_fn=None)
# Example 2
def get_gan_model():
    """Returns a `GANModel` fixture wired to `generator_model`/`discriminator_model`.

    Only the scopes, the model functions, the real data and the discriminator
    outputs are populated. Generator inputs, generated data and both variable
    lists are left as None.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    # Each scope is entered only to capture its VariableScope object.
    with variable_scope.variable_scope('generator') as g_scope:
        pass
    with variable_scope.variable_scope('discriminator') as d_scope:
        pass

    fields = dict(
        generator_inputs=None,
        generated_data=None,
        generator_variables=None,
        generator_scope=g_scope,
        generator_fn=generator_model,
        real_data=array_ops.ones([1, 2, 3]),
        discriminator_real_outputs=array_ops.ones([1, 2, 3]),
        discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
        discriminator_variables=None,
        discriminator_scope=d_scope,
        discriminator_fn=discriminator_model,
    )
    return namedtuples.GANModel(**fields)
# Example 3
def get_gan_model():
    """Returns a `GANModel` fixture with image-shaped data and small variables.

    Generator inputs and generated data are zeros of shape [4, 32, 32, 3] and
    real data is ones of the same shape. Discriminator outputs are ones of
    shape [1, 2, 3]. Three integer `Variable`s stand in for the generator (two)
    and discriminator (one) variables.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    # https://github.com/imdone/tensorflow/issues/732
    with variable_scope.variable_scope('generator') as generator_vs:
        pass
    with variable_scope.variable_scope('discriminator') as discriminator_vs:
        pass

    image_shape = [4, 32, 32, 3]
    dis_out_shape = [1, 2, 3]
    # Build tensors/variables in the original field order so op and variable
    # names come out the same.
    gen_inputs = array_ops.zeros(image_shape)
    gen_data = array_ops.zeros(image_shape)
    gen_vars = [variables.Variable(0), variables.Variable(1)]
    real_images = array_ops.ones(image_shape)
    dis_real_out = array_ops.ones(dis_out_shape)
    dis_gen_out = array_ops.ones(dis_out_shape)
    dis_vars = [variables.Variable(0)]

    return namedtuples.GANModel(
        generator_inputs=gen_inputs,
        generated_data=gen_data,
        generator_variables=gen_vars,
        generator_scope=generator_vs,
        generator_fn=generator_model,
        real_data=real_images,
        discriminator_real_outputs=dis_real_out,
        discriminator_gen_outputs=dis_gen_out,
        discriminator_variables=dis_vars,
        discriminator_scope=discriminator_vs,
        discriminator_fn=discriminator_model)
def _make_prediction_gan_model(generator_inputs, generator_fn,
                               generator_scope):
    """Make a `GANModel` from just the generator.

    Args:
      generator_inputs: Raw generator inputs; normalized with
        `tfgan_train._convert_tensor_or_l_or_d` before being fed to
        `generator_fn`.
      generator_fn: Callable mapping the normalized inputs to generated data.
      generator_scope: Scope passed to `variable_scope.variable_scope`; the
        generator is built inside it.

    Returns:
      A `tfgan_tuples.GANModel` whose real-data and discriminator fields are
      all None.
    """
    with variable_scope.variable_scope(generator_scope) as scope:
        # pylint:disable=protected-access
        inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)
        # pylint:enable=protected-access
        outputs = generator_fn(inputs)
    trainable = variable_lib.get_trainable_variables(scope)

    # Only the generator is built here, so every other field stays None.
    unused_fields = ('real_data', 'discriminator_real_outputs',
                     'discriminator_gen_outputs', 'discriminator_variables',
                     'discriminator_scope', 'discriminator_fn')
    return tfgan_tuples.GANModel(
        generator_inputs=inputs,
        generated_data=outputs,
        generator_variables=trainable,
        generator_scope=scope,
        generator_fn=generator_fn,
        **dict.fromkeys(unused_fields))
# Example 5
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  Args:
    generator_inputs: Raw generator inputs; normalized with
      `tfgan_train._convert_tensor_or_l_or_d` before being fed to
      `generator_fn`.
    generator_fn: Callable producing generated data. If it declares a `mode`
      parameter (positional-or-keyword or keyword-only), it is called with
      `mode=model_fn_lib.ModeKeys.PREDICT` bound via `functools.partial`.
    generator_scope: Scope passed to `variable_scope.variable_scope`; the
      generator is built inside it.

  Returns:
    A `tfgan_tuples.GANModel` with only the generator fields populated; the
    real-data and discriminator fields are None. `generator_fn` in the tuple is
    the (possibly `mode`-bound) callable that was actually invoked.
  """
  # `inspect.getargspec` is deprecated on Python 3 and removed in 3.11. Before
  # that, it raises ValueError for callables with annotations or keyword-only
  # parameters. Use `getfullargspec` where it exists and keep `getargspec` as
  # the Python 2 fallback (short-circuit: never evaluated on Python 3).
  get_argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
  argspec = get_argspec(generator_fn)
  # A keyword-only `mode` can also be bound with functools.partial.
  accepted_args = list(argspec.args) + list(
      getattr(argspec, 'kwonlyargs', None) or [])
  if 'mode' in accepted_args:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
# Example 6
def gan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        # Options.
        check_shapes=True):
    """Returns GAN model outputs and variables.

    The generator is built under `generator_scope`. The discriminator is then
    applied to the generated data under `discriminator_scope`. Finally that
    scope is re-entered with `reuse=True` to apply the discriminator to the
    real data, so both passes share variables.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      real_data: A Tensor representing the real data.
      generator_inputs: A Tensor or list of Tensors to the generator. In the
        vanilla GAN case, this might be a single noise Tensor. In the
        conditional GAN case, this might be the generator's conditioning.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created.
      check_shapes: If `True`, check that generator produces Tensors that are
        the same shape as real data. Otherwise, skip this check.

    Returns:
      A `namedtuples.GANModel` namedtuple.

    Raises:
      ValueError: If `check_shapes` is True and the generator output shape is
        not compatible with the shape of `real_data`.
    """
    # Generator.
    with variable_scope.variable_scope(generator_scope) as g_scope:
        g_inputs = _convert_tensor_or_l_or_d(generator_inputs)
        fake = generator_fn(g_inputs)

    # Discriminator on generated data, then on real data with shared weights.
    with variable_scope.variable_scope(discriminator_scope) as d_scope:
        d_fake = discriminator_fn(fake, g_inputs)
    with variable_scope.variable_scope(d_scope, reuse=True):
        real = ops.convert_to_tensor(real_data)
        d_real = discriminator_fn(real, g_inputs)

    if check_shapes and not fake.shape.is_compatible_with(real.shape):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (fake.shape, real.shape))

    return namedtuples.GANModel(
        generator_inputs=g_inputs,
        generated_data=fake,
        generator_variables=variables_lib.get_trainable_variables(g_scope),
        generator_scope=g_scope,
        generator_fn=generator_fn,
        real_data=real,
        discriminator_real_outputs=d_real,
        discriminator_gen_outputs=d_fake,
        discriminator_variables=variables_lib.get_trainable_variables(d_scope),
        discriminator_scope=d_scope,
        discriminator_fn=discriminator_fn)
# Example 7
 def _partial_model(generator_inputs_np):
     """Returns a `GANModel` whose only populated field is `generator_inputs`.

     Args:
       generator_inputs_np: Value converted to a float32 constant for the
         `generator_inputs` field (presumably a NumPy array, judging by the
         name; confirm at the call site).

     Returns:
       A `namedtuples.GANModel` with every other field set to None.
     """
     # Size the all-None placeholder from the namedtuple's own field list
     # instead of hard-coding its arity, so it tracks changes to GANModel.
     empty_model = namedtuples.GANModel(
         *[None] * len(namedtuples.GANModel._fields))
     return empty_model._replace(generator_inputs=constant_op.constant(
         generator_inputs_np, dtype=dtypes.float32))