def test_no_shape_check(self):
  """`check_shapes` toggles whether a malformed generator output is fatal."""
  def dummy_generator_model(_):
    # Deliberately return a tuple rather than a Tensor.
    return (None, None)

  def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
    return 1

  def build(check_shapes):
    # Fresh input tensors are created on every call.
    return train.gan_model(
        dummy_generator_model,
        dummy_discriminator_model,
        real_data=array_ops.zeros([1, 2]),
        generator_inputs=array_ops.zeros([1]),
        check_shapes=check_shapes)

  # With shape checking enabled, the tuple output is expected to raise.
  with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
    build(True)
  # With shape checking disabled, construction must succeed.
  build(False)
def test_doesnt_crash_when_in_nested_scope(self):
  """`gan_loss` works both inside and outside the model's variable scope."""
  def compute_loss(model):
    return train.gan_loss(model, gradient_penalty_weight=1.0)

  with variable_scope.variable_scope('outer_scope'):
    real = array_ops.zeros([1, 2])
    noise = random_ops.random_normal([1, 2])
    gan_model = train.gan_model(
        generator_model,
        discriminator_model,
        real_data=real,
        generator_inputs=noise)

    # Still inside the scope the model was built in.
    compute_loss(gan_model)

  # After leaving that scope.
  compute_loss(gan_model)
def test_discriminator_only_sees_pool(self):
  """Checks that discriminator only sees pooled values."""
  def constant_generator(_):
    return constant_op.constant(0.0)

  def fake_pool(_):
    # Ignore the incoming values and hand back fresh, non-constant tensors.
    return (random_ops.random_uniform([]), random_ops.random_uniform([]))

  def pool_checking_discriminator(inputs, _):
    """Discriminator that checks that it only sees pooled Tensors."""
    self.assertFalse(constant_op.is_constant(inputs))
    return inputs

  base_model = train.gan_model(
      constant_generator,
      discriminator_model,
      real_data=array_ops.zeros([]),
      generator_inputs=random_ops.random_normal([]))
  pooled_model = base_model._replace(
      discriminator_fn=pool_checking_discriminator)
  train.gan_loss(pooled_model, tensor_pool_fn=fake_pool)
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
  """Construct a `GANModel`, and optionally pass in `mode`.

  Args:
    generator_fn: Callable building the generator. If it declares a parameter
      named `mode` (positional or keyword-only), `mode` is bound to it with
      `functools.partial`.
    discriminator_fn: Callable building the discriminator. Same `mode`
      handling as `generator_fn`.
    real_data: Passed through to `tfgan_train.gan_model`.
    generator_inputs: Passed through to `tfgan_train.gan_model`.
    generator_scope: Passed through to `tfgan_train.gan_model`.
    add_summaries: Falsy for no summaries, otherwise a single summary type or
      a tuple/list of them. Each is looked up in `_summary_type_map` and
      applied to the constructed model.
    mode: Value bound to the network functions' `mode` parameter, if any.

  Returns:
    The model returned by `tfgan_train.gan_model`, built with
    `check_shapes=False`.
  """
  def _accepts_mode(fn):
    """Returns True if `fn` declares a parameter named `mode`."""
    # `getargspec` is deprecated. For the stdlib module it raises ValueError
    # on annotated or keyword-only signatures, and it was removed in
    # Python 3.11. Prefer `getfullargspec` whenever the `inspect` module in
    # scope provides it.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is None:
      return 'mode' in inspect.getargspec(fn).args
    spec = getfullargspec(fn)
    kwonlyargs = getattr(spec, 'kwonlyargs', None) or []
    return 'mode' in spec.args or 'mode' in kwonlyargs

  # If network functions have an argument `mode`, pass mode to it.
  if _accepts_mode(generator_fn):
    generator_fn = functools.partial(generator_fn, mode=mode)
  if _accepts_mode(discriminator_fn):
    discriminator_fn = functools.partial(discriminator_fn, mode=mode)
  gan_model = tfgan_train.gan_model(
      generator_fn,
      discriminator_fn,
      real_data,
      generator_inputs,
      generator_scope=generator_scope,
      check_shapes=False)
  if add_summaries:
    if not isinstance(add_summaries, (tuple, list)):
      add_summaries = [add_summaries]
    # Reset to the root name scope so summary names are not prefixed by any
    # caller scope.
    with ops.name_scope(None):
      for summary_type in add_summaries:
        _summary_type_map[summary_type](gan_model)
  return gan_model
def model_fn(features, labels, mode, params):
  """Model function defining an inpainting estimator.

  Only the latent input variable `z` is optimized (`var_list=[z]`). Its
  values are kept within [-params['input_clip'], params['input_clip']] by the
  variable's constraint.
  """
  batch_size = params['batch_size']
  latent_shape = [batch_size] + params['z_shape']
  add_summaries = params['add_summaries']
  clip = params['input_clip']

  def clip_latent(value):
    return clip_ops.clip_by_value(value, -clip, clip)

  z = variable_scope.get_variable(
      name=INPUT_NAME,
      initializer=random_ops.truncated_normal(latent_shape),
      constraint=clip_latent)

  # Bind `mode` into both network functions before building the model.
  gan_model = tfgan_train.gan_model(
      generator_fn=functools.partial(generator_fn, mode=mode),
      discriminator_fn=functools.partial(discriminator_fn, mode=mode),
      real_data=labels,
      generator_inputs=z,
      check_shapes=False)

  loss = loss_fn(gan_model, features, labels, add_summaries)

  # Create the optimizer under its own variable scope so that estimator
  # variables don't cause save/load problems when restoring from checkpoints.
  with variable_scope.variable_scope(OPTIMIZER_NAME):
    learning_rate = params['learning_rate']
    opt = optimizer(learning_rate=learning_rate, **params['opt_kwargs'])
    global_step = training_util.get_or_create_global_step()
    train_op = opt.minimize(loss=loss, global_step=global_step, var_list=[z])

  if add_summaries:
    z_grad_norm = clip_ops.global_norm(gradients_impl.gradients(loss, z))
    summary.scalar('z_loss/z_grads', z_grad_norm)
    summary.scalar('z_loss/loss', loss)

  return model_fn_lib.EstimatorSpec(
      mode=mode,
      predictions=gan_model.generated_data,
      loss=loss,
      train_op=train_op)
def create_callable_gan_model():
  """Builds a GAN model from `Generator()` / `Discriminator()` instances."""
  generator, discriminator = Generator(), Discriminator()
  real = array_ops.zeros([1, 2])
  noise = random_ops.random_normal([1, 2])
  return train.gan_model(
      generator, discriminator, real_data=real, generator_inputs=noise)