Example #1
0
    def test_no_shape_check(self):
        """`check_shapes` decides whether `gan_model` validates generator output.

        The dummy generator returns a plain tuple rather than a Tensor. With
        `check_shapes=True`, building the model is expected to raise an
        AttributeError whose message contains 'object has no attribute'.
        With `check_shapes=False`, the same call must complete without error.
        """
        def dummy_generator_model(_):
            return (None, None)

        def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
            return 1

        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(AttributeError,
                                    'object has no attribute'):
            train.gan_model(dummy_generator_model,
                            dummy_discriminator_model,
                            real_data=array_ops.zeros([1, 2]),
                            generator_inputs=array_ops.zeros([1]),
                            check_shapes=True)
        train.gan_model(dummy_generator_model,
                        dummy_discriminator_model,
                        real_data=array_ops.zeros([1, 2]),
                        generator_inputs=array_ops.zeros([1]),
                        check_shapes=False)
Example #2
0
    def test_doesnt_crash_when_in_nested_scope(self):
        """`gan_loss` works both inside and outside the model's variable scope."""
        with variable_scope.variable_scope('outer_scope'):
            scoped_model = train.gan_model(
                generator_model, discriminator_model,
                real_data=array_ops.zeros([1, 2]),
                generator_inputs=random_ops.random_normal([1, 2]))

            # First: compute the loss while still inside 'outer_scope'.
            train.gan_loss(scoped_model, gradient_penalty_weight=1.0)

        # Second: compute it again after the scope has been exited.
        train.gan_loss(scoped_model, gradient_penalty_weight=1.0)
Example #3
0
    def test_discriminator_only_sees_pool(self):
        """Checks that discriminator only sees pooled values."""
        def constant_generator(_):
            # Produces a constant, so any un-pooled input would be detectable.
            return constant_op.constant(0.0)

        def random_pool(_):
            # Ignores its input and hands back two fresh random scalars.
            return (random_ops.random_uniform([]),
                    random_ops.random_uniform([]))

        def assert_sees_only_pooled(inputs, _):
            """Discriminator that fails if it is given a constant Tensor."""
            self.assertFalse(constant_op.is_constant(inputs))
            return inputs

        base_model = train.gan_model(
            constant_generator,
            discriminator_model,
            real_data=array_ops.zeros([]),
            generator_inputs=random_ops.random_normal([]))
        checked_model = base_model._replace(
            discriminator_fn=assert_sees_only_pooled)
        train.gan_loss(checked_model, tensor_pool_fn=random_pool)
Example #4
0
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
    """Construct a `GANModel`, and optionally pass in `mode`.

    If `generator_fn` or `discriminator_fn` declares a parameter named `mode`
    that can be passed by keyword, it is wrapped with `functools.partial` so
    that `mode` is supplied automatically.

    Args:
      generator_fn: Generator network callable, forwarded to
        `tfgan_train.gan_model`.
      discriminator_fn: Discriminator network callable, forwarded to
        `tfgan_train.gan_model`.
      real_data: Forwarded to `tfgan_train.gan_model`.
      generator_inputs: Forwarded to `tfgan_train.gan_model`.
      generator_scope: Forwarded as `generator_scope`.
      add_summaries: Falsy to skip summaries. Otherwise a single key, or a
        tuple/list of keys, into `_summary_type_map`. Each selected summary
        function is called with the constructed model.
      mode: Value bound to the `mode` argument of network functions that
        accept one.

    Returns:
      The model returned by `tfgan_train.gan_model` (built with
      `check_shapes=False`).
    """
    def _bind_mode(fn):
        # `inspect.getargspec` is deprecated, raises ValueError for callables
        # with keyword-only parameters or annotations, and was removed in
        # Python 3.11. `inspect.signature` handles all of these cases. Only
        # bind `mode` when it can actually be passed as a keyword argument.
        param = inspect.signature(fn).parameters.get('mode')
        if param is not None and param.kind in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY):
            return functools.partial(fn, mode=mode)
        return fn

    gan_model = tfgan_train.gan_model(_bind_mode(generator_fn),
                                      _bind_mode(discriminator_fn),
                                      real_data,
                                      generator_inputs,
                                      generator_scope=generator_scope,
                                      check_shapes=False)
    if add_summaries:
        if not isinstance(add_summaries, (tuple, list)):
            add_summaries = [add_summaries]
        # Create summaries with the name scope reset, independent of any
        # scope the caller may be inside.
        with ops.name_scope(None):
            for summary_type in add_summaries:
                _summary_type_map[summary_type](gan_model)

    return gan_model
Example #5
0
  def model_fn(features, labels, mode, params):
    """Model function defining an inpainting estimator.

    Creates a clipped `z` variable as the generator input. Builds a GAN model
    from the enclosing `generator_fn`/`discriminator_fn`, with `mode` bound
    into both. Then minimizes `loss_fn` with respect to `z` only.

    Args:
      features: Forwarded to `loss_fn`.
      labels: Used as the GAN's real data and forwarded to `loss_fn`.
      mode: Estimator mode. Bound into both networks and the returned spec.
      params: Dict providing 'batch_size', 'z_shape', 'add_summaries',
        'input_clip', 'learning_rate' and 'opt_kwargs'.

    Returns:
      An `EstimatorSpec` whose predictions are the generated data.
    """
    shape_of_z = [params['batch_size']] + params['z_shape']
    summaries_on = params['add_summaries']
    clip = params['input_clip']

    def _clip_z(value):
      # Variable constraint that keeps z inside [-clip, clip].
      return clip_ops.clip_by_value(value, -clip, clip)

    z_var = variable_scope.get_variable(
        name=INPUT_NAME,
        initializer=random_ops.truncated_normal(shape_of_z),
        constraint=_clip_z)

    model = tfgan_train.gan_model(
        generator_fn=functools.partial(generator_fn, mode=mode),
        discriminator_fn=functools.partial(discriminator_fn, mode=mode),
        real_data=labels,
        generator_inputs=z_var,
        check_shapes=False)

    total_loss = loss_fn(model, features, labels, summaries_on)

    # Build the optimizer inside its own variable scope so that estimator
    # variables don't cause save/load problems when restoring checkpoints.
    with variable_scope.variable_scope(OPTIMIZER_NAME):
      opt_instance = optimizer(learning_rate=params['learning_rate'],
                               **params['opt_kwargs'])
      step = training_util.get_or_create_global_step()
      # Only z is trained; network weights are left untouched.
      minimize_op = opt_instance.minimize(
          loss=total_loss, global_step=step, var_list=[z_var])

    if summaries_on:
      grads_wrt_z = gradients_impl.gradients(total_loss, z_var)
      summary.scalar('z_loss/z_grads', clip_ops.global_norm(grads_wrt_z))
      summary.scalar('z_loss/loss', total_loss)

    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=model.generated_data,
        loss=total_loss,
        train_op=minimize_op)
Example #6
0
def create_callable_gan_model():
    """Build a GAN model from `Generator()` and `Discriminator()` instances.

    Real data is a [1, 2] tensor of zeros; generator inputs are [1, 2]
    normally-distributed noise.
    """
    generator, discriminator = Generator(), Discriminator()
    return train.gan_model(
        generator,
        discriminator,
        real_data=array_ops.zeros([1, 2]),
        generator_inputs=random_ops.random_normal([1, 2]),
    )