def get_dummy_gan_model():
  """Returns a small `GANModel` whose outputs depend on two scalar variables.

  A scalar variable named `dummy_var` (initialized to 0.0) is created under a
  'generator' scope and another under a 'discriminator' scope. The
  discriminator outputs are products of constant tensors with those
  variables, so they are connected to both variable sets.

  Returns:
    A `tfgan_tuples.GANModel` with `None` generator/discriminator functions and
    `None` generator inputs.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  def _scope_with_scalar(scope_name):
    # Opens `scope_name`, creates its scalar variable, and returns both.
    with variable_scope.variable_scope(scope_name) as scope:
      scalar = variable_scope.get_variable('dummy_var', initializer=0.0)
    return scope, scalar

  gen_scope, gen_var = _scope_with_scalar('generator')
  dis_scope, dis_var = _scope_with_scalar('discriminator')

  # Built in the same order as the fields of the returned tuple.
  fake_data = array_ops.ones([3, 4])
  true_data = array_ops.zeros([3, 4])
  real_logits = array_ops.ones([1, 2, 3]) * dis_var
  fake_logits = array_ops.ones([1, 2, 3]) * gen_var * dis_var

  return tfgan_tuples.GANModel(
      generator_inputs=None,
      generated_data=fake_data,
      generator_variables=[gen_var],
      generator_scope=gen_scope,
      generator_fn=None,
      real_data=true_data,
      discriminator_real_outputs=real_logits,
      discriminator_gen_outputs=fake_logits,
      discriminator_variables=[dis_var],
      discriminator_scope=dis_scope,
      discriminator_fn=None)
def get_gan_model():
  """Returns a `GANModel` with empty scopes and constant discriminator data.

  The generator and discriminator scopes are opened only to capture scope
  objects; no variables are created. `real_data` and both discriminator
  outputs are independent `ones([1, 2, 3])` tensors.

  Returns:
    A `namedtuples.GANModel` wired to `generator_model` and
    `discriminator_model`, with `None` inputs, generated data and variables.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  captured = []
  for scope_name in ('generator', 'discriminator'):
    with variable_scope.variable_scope(scope_name) as scope:
      captured.append(scope)
  gen_scope, dis_scope = captured

  # Three separate tensors, created in the original field order.
  real, dis_real_out, dis_gen_out = [
      array_ops.ones([1, 2, 3]) for _ in range(3)]

  return namedtuples.GANModel(
      generator_inputs=None,
      generated_data=None,
      generator_variables=None,
      generator_scope=gen_scope,
      generator_fn=generator_model,
      real_data=real,
      discriminator_real_outputs=dis_real_out,
      discriminator_gen_outputs=dis_gen_out,
      discriminator_variables=None,
      discriminator_scope=dis_scope,
      discriminator_fn=discriminator_model)
def get_gan_model():
  """Returns a `GANModel` with [4, 32, 32, 3] data tensors and plain variables.

  The generator and discriminator scopes are opened only to capture scope
  objects. Generator inputs and generated data are zeros and real data is ones,
  all of shape [4, 32, 32, 3]. Both discriminator outputs are `ones([1, 2, 3])`.
  The variable lists hold freshly created `variables.Variable` objects.

  Returns:
    A `namedtuples.GANModel` wired to `generator_model` and
    `discriminator_model`.
  """
  # TODO(joelshor): Find a better way of creating a variable scope.
  # Tracked at https://github.com/imdone/tensorflow/issues/732 (id:731).
  scopes = {}
  for scope_name in ('generator', 'discriminator'):
    with variable_scope.variable_scope(scope_name) as scope:
      scopes[scope_name] = scope

  data_shape = [4, 32, 32, 3]
  # Ops and variables are created in the same order as the original fields.
  gen_inputs = array_ops.zeros(data_shape)
  gen_data = array_ops.zeros(data_shape)
  gen_vars = [variables.Variable(0), variables.Variable(1)]
  real = array_ops.ones(data_shape)
  dis_real_out = array_ops.ones([1, 2, 3])
  dis_gen_out = array_ops.ones([1, 2, 3])
  dis_vars = [variables.Variable(0)]

  return namedtuples.GANModel(
      generator_inputs=gen_inputs,
      generated_data=gen_data,
      generator_variables=gen_vars,
      generator_scope=scopes['generator'],
      generator_fn=generator_model,
      real_data=real,
      discriminator_real_outputs=dis_real_out,
      discriminator_gen_outputs=dis_gen_out,
      discriminator_variables=dis_vars,
      discriminator_scope=scopes['discriminator'],
      discriminator_fn=discriminator_model)
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Builds a generator-only `GANModel` for prediction.

  Args:
    generator_inputs: Inputs to the generator; normalized with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    generator_fn: Callable mapping the normalized inputs to generated data.
    generator_scope: Variable scope (name or scope object) to build the
      generator under.

  Returns:
    A `tfgan_tuples.GANModel` whose `real_data` and discriminator-related
    fields are all `None`.
  """
  with variable_scope.variable_scope(generator_scope) as scope:
    # pylint:disable=protected-access
    converted_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)
    # pylint:enable=protected-access
    outputs = generator_fn(converted_inputs)
  trainable = variable_lib.get_trainable_variables(scope)

  return tfgan_tuples.GANModel(
      generator_inputs=converted_inputs,
      generated_data=outputs,
      generator_variables=trainable,
      generator_scope=scope,
      generator_fn=generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  If `generator_fn` accepts a `mode` argument (positional-or-keyword or
  keyword-only), it is bound to `model_fn_lib.ModeKeys.PREDICT`. The bound
  callable is what gets called and stored in the returned tuple.

  Args:
    generator_inputs: Inputs to the generator; normalized with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    generator_fn: Callable mapping the normalized inputs to generated data.
    generator_scope: Variable scope (name or scope object) to build the
      generator under.

  Returns:
    A `tfgan_tuples.GANModel` whose `real_data` and discriminator-related
    fields are all `None`.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `inspect.getargspec` was removed in Python 3.11. On earlier Python 3
  # versions it also raised ValueError for callables with annotations or
  # keyword-only parameters. `getfullargspec` handles both cases.
  argspec = inspect.getfullargspec(generator_fn)
  if 'mode' in argspec.args or 'mode' in argspec.kwonlyargs:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)

  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs=generator_inputs,
      generated_data=generated_data,
      generator_variables=generator_variables,
      generator_scope=gen_scope,
      generator_fn=generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Builds the generator and discriminator graphs and bundles them.

  The generator is built once under `generator_scope`. The discriminator is
  called twice:
    1. On the generated data, under `discriminator_scope`.
    2. On `real_data`, under the same scope with `reuse=True`.
  Variables are therefore shared between the two discriminator calls.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data. Converted with
      `ops.convert_to_tensor` before it is passed to `discriminator_fn`.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the
      conditional GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, verify that the generator output's static shape
      is compatible with `real_data`'s shape. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If `check_shapes` is `True` and the generator output's shape
      is not compatible with the shape of `real_data`.
  """
  # Generator pass.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)

  # Discriminator pass on generated data. Any variables it creates are placed
  # in this scope.
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(
        generated_data, generator_inputs)

  # Discriminator pass on real data, reusing the variables from the pass above.
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  # Short-circuits, so shapes are only inspected when the check is enabled.
  shape_mismatch = (
      check_shapes and
      not generated_data.shape.is_compatible_with(real_data.shape))
  if shape_mismatch:
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.shape, real_data.shape))

  # Collect each network's trainable variables by scope.
  gen_trainables = variables_lib.get_trainable_variables(gen_scope)
  dis_trainables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs=generator_inputs,
      generated_data=generated_data,
      generator_variables=gen_trainables,
      generator_scope=gen_scope,
      generator_fn=generator_fn,
      real_data=real_data,
      discriminator_real_outputs=discriminator_real_outputs,
      discriminator_gen_outputs=discriminator_gen_outputs,
      discriminator_variables=dis_trainables,
      discriminator_scope=dis_scope,
      discriminator_fn=discriminator_fn)
def _partial_model(generator_inputs_np):
  """Returns a `GANModel` with only `generator_inputs` populated.

  Args:
    generator_inputs_np: Array-like value turned into a float32 constant.

  Returns:
    A `namedtuples.GANModel` whose 11 fields are `None` except
    `generator_inputs`.
  """
  blank = namedtuples.GANModel(*((None,) * 11))
  inputs = constant_op.constant(generator_inputs_np, dtype=dtypes.float32)
  return blank._replace(generator_inputs=inputs)