Example #1
0
 def test_no_shape_check(self):
   """`gan_model` skips real/generated shape validation when check_shapes=False."""
   def dummy_generator_model(_):
     # Returns a non-Tensor, so any attempt to inspect its shape fails.
     return (None, None)
   def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
     return 1
   # With shape checking on, comparing shapes of the non-Tensor generator
   # output raises AttributeError. NOTE: `assertRaisesRegexp` was deprecated
   # and removed in Python 3.12; use the modern `assertRaisesRegex` spelling.
   with self.assertRaisesRegex(AttributeError, 'object has no attribute'):
     train.gan_model(
         dummy_generator_model,
         dummy_discriminator_model,
         real_data=array_ops.zeros([1, 2]),
         generator_inputs=array_ops.zeros([1]),
         check_shapes=True)
   # With shape checking off, model construction succeeds.
   train.gan_model(
       dummy_generator_model,
       dummy_discriminator_model,
       real_data=array_ops.zeros([1, 2]),
       generator_inputs=array_ops.zeros([1]),
       check_shapes=False)
 def test_no_shape_check(self):
   def dummy_generator_model(_):
     return (None, None)
   def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
     return 1
   with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
     train.gan_model(
         dummy_generator_model,
         dummy_discriminator_model,
         real_data=array_ops.zeros([1, 2]),
         generator_inputs=array_ops.zeros([1]),
         check_shapes=True)
   train.gan_model(
       dummy_generator_model,
       dummy_discriminator_model,
       real_data=array_ops.zeros([1, 2]),
       generator_inputs=array_ops.zeros([1]),
       check_shapes=False)
Example #3
0
    def test_doesnt_crash_when_in_nested_scope(self):
        """Gradient-penalty loss builds both inside and outside a variable scope."""
        with variable_scope.variable_scope('outer_scope'):
            model = train.gan_model(
                generator_model,
                discriminator_model,
                real_data=array_ops.zeros([1, 2]),
                generator_inputs=random_ops.random_normal([1, 2]))

            # Building the penalized loss within the scope must not crash.
            train.gan_loss(model, gradient_penalty_weight=1.0)

        # Nor should it crash after the scope has been exited.
        train.gan_loss(model, gradient_penalty_weight=1.0)
  def test_doesnt_crash_when_in_nested_scope(self):
    """gan_loss with a gradient penalty works within and after a scope."""
    with variable_scope.variable_scope('outer_scope'):
      built_model = train.gan_model(
          generator_model,
          discriminator_model,
          real_data=array_ops.zeros([1, 2]),
          generator_inputs=random_ops.random_normal([1, 2]))

      # Inside the variable scope.
      train.gan_loss(built_model, gradient_penalty_weight=1.0)

    # Outside the variable scope.
    train.gan_loss(built_model, gradient_penalty_weight=1.0)
Example #5
0
 def test_discriminator_only_sees_pool(self):
   """Checks that the discriminator is fed pooled (non-constant) values."""
   def constant_generator(_):
     return constant_op.constant(0.0)
   def pool_replacement_fn(_):
     # Replace pool contents with fresh random values.
     return (random_ops.random_uniform([]), random_ops.random_uniform([]))
   def asserting_discriminator(inputs, _):
     """Fails the test if an unpooled (constant) Tensor reaches it."""
     self.assertFalse(constant_op.is_constant(inputs))
     return inputs
   model = train.gan_model(
       constant_generator,
       discriminator_model,
       real_data=array_ops.zeros([]),
       generator_inputs=random_ops.random_normal([]))
   # Swap in the checking discriminator before the loss is computed.
   model = model._replace(discriminator_fn=asserting_discriminator)
   train.gan_loss(model, tensor_pool_fn=pool_replacement_fn)
def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
                          generator_inputs, generator_scope, add_summaries):
  """Make a `GANModel` for training."""
  model = tfgan_train.gan_model(
      generator_fn,
      discriminator_fn,
      real_data,
      generator_inputs,
      generator_scope=generator_scope,
      check_shapes=_use_check_shapes(real_data))
  if add_summaries:
    # Accept either a single summary type or a sequence of them.
    summary_types = (add_summaries if isinstance(add_summaries, (tuple, list))
                     else [add_summaries])
    # Emit summaries from the root name scope.
    with ops.name_scope(None):
      for summary_type in summary_types:
        _summary_type_map[summary_type](model)

  return model
def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
                          generator_inputs, generator_scope, add_summaries):
    """Make a `GANModel` for training."""
    built = tfgan_train.gan_model(
        generator_fn,
        discriminator_fn,
        real_data,
        generator_inputs,
        generator_scope=generator_scope,
        check_shapes=_use_check_shapes(real_data))
    if not add_summaries:
        return built
    # Normalize a lone summary type into a one-element list.
    if not isinstance(add_summaries, (tuple, list)):
        add_summaries = [add_summaries]
    # Add summaries at the top-level name scope.
    with ops.name_scope(None):
        for summary_kind in add_summaries:
            _summary_type_map[summary_kind](built)

    return built
Example #8
0
 def test_discriminator_only_sees_pool(self):
   """Checks that discriminator only sees pooled values."""
   test_case = self
   def gen_fn(_):
     return constant_op.constant(0.0)
   gan_model_tuple = train.gan_model(
       gen_fn,
       discriminator_model,
       real_data=array_ops.zeros([]),
       generator_inputs=random_ops.random_normal([]))
   def pool_fn(_):
     # The pool hands back fresh random Tensors, never the constant.
     return (random_ops.random_uniform([]), random_ops.random_uniform([]))
   def dis_fn(inputs, _):
     """Asserts the incoming Tensor is not the generator's constant."""
     test_case.assertFalse(constant_op.is_constant(inputs))
     return inputs
   # Install the checking discriminator before computing the loss.
   gan_model_tuple = gan_model_tuple._replace(discriminator_fn=dis_fn)
   train.gan_loss(gan_model_tuple, tensor_pool_fn=pool_fn)
Example #9
0
    def model_fn(features, labels, mode, params):
        """Model function defining an inpainting estimator."""
        batch_size = params['batch_size']
        z_shape = [batch_size] + params['z_shape']
        add_summaries = params['add_summaries']
        input_clip = params['input_clip']

        # `z` is the only trainable variable: optimization searches latent
        # space, keeping z within [-input_clip, input_clip].
        def _clip(unclipped):
            return clip_ops.clip_by_value(unclipped, -input_clip, input_clip)

        z = variable_scope.get_variable(
            name=INPUT_NAME,
            initializer=random_ops.truncated_normal(z_shape),
            constraint=_clip)

        gan_model = tfgan_train.gan_model(
            generator_fn=functools.partial(generator_fn, mode=mode),
            discriminator_fn=functools.partial(discriminator_fn, mode=mode),
            real_data=labels,
            generator_inputs=z,
            check_shapes=False)

        loss = loss_fn(gan_model, features, labels, add_summaries)

        # Scope optimizer variables so estimator variables dont cause
        # save/load problems when restoring from ckpts.
        with variable_scope.variable_scope(OPTIMIZER_NAME):
            opt = optimizer(learning_rate=params['learning_rate'],
                            **params['opt_kwargs'])
            train_op = opt.minimize(
                loss=loss,
                global_step=training_util.get_or_create_global_step(),
                var_list=[z])

        if add_summaries:
            z_grads = gradients_impl.gradients(loss, z)
            summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads))
            summary.scalar('z_loss/loss', loss)

        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=gan_model.generated_data,
            loss=loss,
            train_op=train_op)
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
    """Make a `GANModel`, and optionally pass in `mode`."""
    # If `generator_fn` has an argument `mode`, bind the current mode to it.
    # `inspect.getargspec` was deprecated in Python 3.0 and removed in 3.11
    # (it also raises on functions with keyword-only arguments); prefer
    # `getfullargspec` when available, falling back for old interpreters.
    argspec_fn = getattr(inspect, 'getfullargspec', inspect.getargspec)
    if 'mode' in argspec_fn(generator_fn).args:
        generator_fn = functools.partial(generator_fn, mode=mode)
    gan_model = tfgan_train.gan_model(generator_fn,
                                      discriminator_fn,
                                      real_data,
                                      generator_inputs,
                                      generator_scope=generator_scope,
                                      check_shapes=False)
    if add_summaries:
        # A single summary type is treated as a one-element list.
        if not isinstance(add_summaries, (tuple, list)):
            add_summaries = [add_summaries]
        # Emit summaries from the root name scope.
        with ops.name_scope(None):
            for summary_type in add_summaries:
                _summary_type_map[summary_type](gan_model)

    return gan_model
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
  """Make a `GANModel`, and optionally pass in `mode`."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `inspect.getargspec` was removed in Python 3.11 and chokes on
  # keyword-only arguments; use `getfullargspec` where it exists.
  argspec_fn = getattr(inspect, 'getfullargspec', inspect.getargspec)
  if 'mode' in argspec_fn(generator_fn).args:
    generator_fn = functools.partial(generator_fn, mode=mode)
  gan_model = tfgan_train.gan_model(
      generator_fn,
      discriminator_fn,
      real_data,
      generator_inputs,
      generator_scope=generator_scope,
      check_shapes=False)
  if add_summaries:
    # Accept a single summary type or a sequence of them.
    if not isinstance(add_summaries, (tuple, list)):
      add_summaries = [add_summaries]
    # Summaries are written at the top-level name scope.
    with ops.name_scope(None):
      for summary_type in add_summaries:
        _summary_type_map[summary_type](gan_model)

  return gan_model
  def model_fn(features, labels, mode, params):
    """Model function defining an inpainting estimator."""
    batch_size = params['batch_size']
    z_shape = [batch_size] + params['z_shape']
    add_summaries = params['add_summaries']
    clip_value = params['input_clip']

    # The latent code `z` is the sole trainable variable; the constraint
    # keeps it within [-clip_value, clip_value].
    z = variable_scope.get_variable(
        name=INPUT_NAME,
        initializer=random_ops.truncated_normal(z_shape),
        constraint=lambda t: clip_ops.clip_by_value(t, -clip_value, clip_value))

    generator = functools.partial(generator_fn, mode=mode)
    discriminator = functools.partial(discriminator_fn, mode=mode)
    gan_model = tfgan_train.gan_model(
        generator_fn=generator,
        discriminator_fn=discriminator,
        real_data=labels,
        generator_inputs=z,
        check_shapes=False)

    loss = loss_fn(gan_model, features, labels, add_summaries)

    # Keep optimizer slot variables under their own scope so that estimator
    # variables dont cause save/load problems when restoring from ckpts.
    with variable_scope.variable_scope(OPTIMIZER_NAME):
      opt = optimizer(learning_rate=params['learning_rate'],
                      **params['opt_kwargs'])
      train_op = opt.minimize(
          loss=loss,
          global_step=training_util.get_or_create_global_step(),
          var_list=[z])

    if add_summaries:
      z_grads = gradients_impl.gradients(loss, z)
      summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads))
      summary.scalar('z_loss/loss', loss)

    return model_fn_lib.EstimatorSpec(mode=mode,
                                      predictions=gan_model.generated_data,
                                      loss=loss,
                                      train_op=train_op)
Example #13
0
def create_callable_gan_model():
    """Build a `GANModel` whose generator and discriminator are callables."""
    gen = Generator()
    dis = Discriminator()
    return train.gan_model(gen,
                           dis,
                           real_data=array_ops.zeros([1, 2]),
                           generator_inputs=random_ops.random_normal([1, 2]))
Example #14
0
def create_callable_gan_model():
  """Construct a GAN model from class-based (callable) networks."""
  return train.gan_model(
      Generator(), Discriminator(),
      real_data=array_ops.zeros([1, 2]),
      generator_inputs=random_ops.random_normal([1, 2]))