예제 #1
0
def get_gan_model():
    """Return a placeholder `namedtuples.GANModel` for tests.

    `generator_scope` and `discriminator_scope` are the variable scopes named
    'generator' and 'discriminator'; nothing is created inside them. The
    function fields are the module-level `generator_model` and
    `discriminator_model`. The real data and both discriminator outputs are
    `array_ops.ones([1, 2, 3])`. Every other field is None.

    NOTE(review): a second `get_gan_model` defined later in this file shadows
    this definition.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    # Each scope is entered only to capture the scope object.
    with variable_scope.variable_scope('generator') as g_scope:
        pass
    with variable_scope.variable_scope('discriminator') as d_scope:
        pass

    # Three independent ones tensors, created in the same order as the
    # corresponding GANModel fields below.
    real, real_outputs, gen_outputs = (
        array_ops.ones([1, 2, 3]) for _ in range(3))

    return namedtuples.GANModel(
        generator_inputs=None,
        generated_data=None,
        generator_variables=None,
        generator_scope=g_scope,
        generator_fn=generator_model,
        real_data=real,
        discriminator_real_outputs=real_outputs,
        discriminator_gen_outputs=gen_outputs,
        discriminator_variables=None,
        discriminator_scope=d_scope,
        discriminator_fn=discriminator_model)
def get_gan_model():
    """Return a small `tfgan_tuples.GANModel` backed by two dummy variables.

    One scalar variable named 'dummy_var' (initialized to 0.0) is created in
    the 'generator' scope and one in the 'discriminator' scope.
    `generated_data` is `array_ops.ones([3, 4])`. The discriminator's real
    outputs are scaled by the discriminator variable, and its generated
    outputs are scaled by both variables. Presumably this is so losses built
    from this model reach both variable lists (confirm against callers).
    `generator_inputs`, `generator_fn`, `real_data` and `discriminator_fn`
    are None.
    """
    # TODO(joelshor): Find a better way of creating a variable scope.
    with variable_scope.variable_scope('generator') as g_scope:
        g_var = variable_scope.get_variable('dummy_var', initializer=0.0)
    with variable_scope.variable_scope('discriminator') as d_scope:
        d_var = variable_scope.get_variable('dummy_var', initializer=0.0)

    # Ops are created in the same order as the original keyword arguments.
    fake = array_ops.ones([3, 4])
    real_outputs = array_ops.ones([1, 2, 3]) * d_var
    scaled_by_gen = array_ops.ones([1, 2, 3]) * g_var
    gen_outputs = scaled_by_gen * d_var

    return tfgan_tuples.GANModel(
        generator_inputs=None,
        generated_data=fake,
        generator_variables=[g_var],
        generator_scope=g_scope,
        generator_fn=None,
        real_data=None,
        discriminator_real_outputs=real_outputs,
        discriminator_gen_outputs=gen_outputs,
        discriminator_variables=[d_var],
        discriminator_scope=d_scope,
        discriminator_fn=None)
예제 #3
0
def _fn_has_argument(fn, arg_name):
    """Return True if callable `fn` declares a parameter named `arg_name`.

    Both positional-or-keyword and keyword-only parameters count, since
    either can be supplied as a keyword through `functools.partial`.
    """
    # `inspect.getargspec` is deprecated, raises ValueError for callables with
    # keyword-only parameters or annotations (e.g. a keyword-bound
    # `functools.partial`), and was removed in Python 3.11. Prefer
    # `getfullargspec`; fall back only on interpreters that lack it.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is None:
        return arg_name in inspect.getargspec(fn).args
    spec = getfullargspec(fn)
    return arg_name in spec.args or arg_name in spec.kwonlyargs


def _make_prediction_gan_model(generator_inputs, generator_fn,
                               generator_scope):
    """Make a `GANModel` from just the generator.

    Args:
      generator_inputs: Inputs for the generator. They are converted with
        `tfgan_train._convert_tensor_or_l_or_d` before being passed to
        `generator_fn`.
      generator_fn: Callable that maps the converted inputs to generated
        data. If it declares a `mode` parameter, it is called with
        `mode=model_fn_lib.ModeKeys.PREDICT`.
      generator_scope: Scope argument passed to
        `variable_scope.variable_scope`, under which the generator is built.

    Returns:
      A `tfgan_tuples.GANModel` with only the generator fields populated. All
      real-data and discriminator fields are None.
    """
    # If `generator_fn` has an argument `mode`, pass mode to it.
    if _fn_has_argument(generator_fn, 'mode'):
        generator_fn = functools.partial(generator_fn,
                                         mode=model_fn_lib.ModeKeys.PREDICT)
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # pylint:disable=protected-access
        generator_inputs = tfgan_train._convert_tensor_or_l_or_d(
            generator_inputs)
        # pylint:enable=protected-access
        generated_data = generator_fn(generator_inputs)
    # Trainable variables created under the generator scope.
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

    return tfgan_tuples.GANModel(generator_inputs,
                                 generated_data,
                                 generator_variables,
                                 gen_scope,
                                 generator_fn,
                                 real_data=None,
                                 discriminator_real_outputs=None,
                                 discriminator_gen_outputs=None,
                                 discriminator_variables=None,
                                 discriminator_scope=None,
                                 discriminator_fn=None)
 def _partial_model(generator_inputs_np):
     """Return a `GANModel` whose only populated field is `generator_inputs`.

     Args:
       generator_inputs_np: Value passed to `constant_op.constant` with
         dtype float32. Presumably a numpy array, per the `_np` suffix.

     Returns:
       A `namedtuples.GANModel` with every field None except
       `generator_inputs`.
     """
     gan_model_cls = namedtuples.GANModel
     # Derive the arity from the namedtuple instead of hard-coding 11, so this
     # helper keeps working if fields are added to `GANModel`.
     model = gan_model_cls(*[None] * len(gan_model_cls._fields))
     return model._replace(generator_inputs=constant_op.constant(
         generator_inputs_np, dtype=dtypes.float32))