def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=tf.estimator.ModeKeys.PREDICT)
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = contrib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
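A hedged usage sketch (not part of the original snippet) showing how this helper might be called at predict time. The toy generator, tensor shapes, and scope name below are illustrative assumptions.

import tensorflow as tf

def _toy_stargan_generator(inputs, domain_label, mode=None):
  # Placeholder generator: ignores the target domain and applies a single
  # trainable scale to the inputs.
  del domain_label, mode
  scale = tf.compat.v1.get_variable('scale', initializer=1.0)
  return inputs * scale

images = tf.zeros([4, 64, 64, 3])            # input_data
target_domain = tf.one_hot([1, 0, 2, 1], 3)  # input_data_domain_label
gan_model = _make_prediction_gan_model(
    images, target_domain, _toy_stargan_generator, 'Generator')
predictions = gan_model.generated_data  # Same shape as `images`.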
Example #2
def make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=tf.estimator.ModeKeys.PREDICT)
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = contrib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
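A similar hedged sketch for the unconditional variant; the noise shape and toy generator are illustrative assumptions, not part of the original code.

import tensorflow as tf

def _toy_generator(noise, mode=None):
  # Placeholder generator: a single dense layer mapping noise to flat "images".
  del mode
  return tf.compat.v1.layers.dense(noise, units=784)

noise = tf.random.normal([16, 64])
gan_model = make_prediction_gan_model(noise, _toy_generator, 'Generator')
samples = gan_model.generated_data  # Tensor of shape [16, 784].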
Example #3
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
    """Utility to combine main and adversarial losses.

  This utility combines the main and adversarial losses in one of two ways.
  1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
  2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
    used to make sure both losses affect weights roughly equally, as in
    https://arxiv.org/pdf/1705.05823.

  One can optionally also visualize the scalar and gradient behavior of the
  losses.

  Args:
    main_loss: A float Tensor of any shape, indicating the main loss. The size
      of the first dimension must be the same as the first dimension of
      adversarial_loss. If main_loss and adversarial_loss are not compatible
      shapes, both will be mean-reduced to just their first dimension (assumed
      to be the batch dimension).
    adversarial_loss: A float Tensor of any shape, indicating the adversarial
      loss. The size of the first dimension must be the same as the first
      dimension of main_loss. If main_loss and adversarial_loss are not
      compatible shapes, both will be mean-reduced to just their first dimension
      (assumed to be the batch dimension).
    weight_factor: If not `None`, the coefficient by which to multiply the
      adversarial loss. Exactly one of this and `gradient_ratio` must be
      non-None.
    gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
      Specifically,
        gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
      Exactly one of this and `weight_factor` must be non-None.
    gradient_ratio_epsilon: An epsilon to add to the adversarial loss
      coefficient denominator, to avoid division-by-zero.
    variables: List of variables to calculate gradients with respect to. If not
      present, defaults to all trainable variables.
    scalar_summaries: Create scalar summaries of losses. If main_loss and
      adversarial_loss are not scalars, they will be mean-reduced to scalars for
      summary computation.
    gradient_summaries: Create gradient summaries of losses.
    scope: Optional name scope.

  Returns:
    A float Tensor indicating the desired combined loss. If main_loss and
    adversarial_loss are both scalars then this will also be a scalar, otherwise
    it will be of shape [main_loss.shape[0]].

  Raises:
    ValueError: Malformed input.
    RuntimeError: If `tf.gradients` needs to be computed and TensorFlow is
      executing eagerly.
  """
    _validate_args(weight_factor, gradient_ratio)
    if variables is None:
        variables = contrib.get_trainable_variables()

    with tf.compat.v1.name_scope(scope,
                                 'adversarial_loss',
                                 values=[main_loss, adversarial_loss]):
        # If losses are not the same shape, reduce them to both be shape [batch,].
        if not main_loss.shape.is_compatible_with(adversarial_loss.shape):
            if main_loss.shape[0] != adversarial_loss.shape[0]:
                raise ValueError(
                    'main_loss and adversarial_loss must have the same sized first '
                    'dimension. Found %d and %d.' %
                    (main_loss.shape[0], adversarial_loss.shape[0]))
            tf.compat.v1.logging.warning(
                'Applying mean reduction per-batch-element to main and adversarial '
                'losses to make shapes compatible. If this is undesirable, ensure '
                'that the shapes are compatible before passing them into '
                'combine_adversarial_loss.')
            main_loss = tf.math.reduce_mean(input_tensor=main_loss,
                                            axis=list(
                                                range(1,
                                                      main_loss.shape.rank)))
            adversarial_loss = tf.math.reduce_mean(
                input_tensor=adversarial_loss,
                axis=list(range(1, adversarial_loss.shape.rank)))

        # Compute gradients if we will need them.
        if gradient_summaries or gradient_ratio is not None:
            # `tf.gradients` doesn't work in eager.
            if tf.executing_eagerly():
                raise RuntimeError('`tf.gradients` doesn\'t work in eager.')
            main_loss_grad_mag = numerically_stable_global_norm(
                tf.gradients(ys=main_loss, xs=variables))
            adv_loss_grad_mag = numerically_stable_global_norm(
                tf.gradients(ys=adversarial_loss, xs=variables))

        # Add summaries, if applicable.
        if scalar_summaries:
            tf.compat.v1.summary.scalar(
                'main_loss', tf.math.reduce_mean(input_tensor=main_loss))
            tf.compat.v1.summary.scalar(
                'adversarial_loss',
                tf.math.reduce_mean(input_tensor=adversarial_loss))
        if gradient_summaries:
            tf.compat.v1.summary.scalar('main_loss_gradients',
                                        main_loss_grad_mag)
            tf.compat.v1.summary.scalar('adversarial_loss_gradients',
                                        adv_loss_grad_mag)

        # Combine losses in the appropriate way.
        # If `weight_factor` is always `0`, avoid computing the adversarial loss
        # tensor entirely.
        if _used_weight((weight_factor, gradient_ratio)) == 0:
            final_loss = main_loss
        elif weight_factor is not None:
            final_loss = (main_loss +
                          tf.stop_gradient(weight_factor) * adversarial_loss)
        elif gradient_ratio is not None:
            grad_mag_ratio = main_loss_grad_mag / (adv_loss_grad_mag +
                                                   gradient_ratio_epsilon)
            adv_coeff = grad_mag_ratio / gradient_ratio
            tf.compat.v1.summary.scalar('adversarial_coefficient', adv_coeff)
            final_loss = (main_loss +
                          tf.stop_gradient(adv_coeff) * adversarial_loss)

    return final_loss
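A hedged usage sketch of the fixed-coefficient mode described above; the per-example loss tensors are random stand-ins for real losses.

import tensorflow as tf

batch_size = 8
main_loss = tf.random.uniform([batch_size])         # e.g. per-example L1 loss.
adversarial_loss = tf.random.uniform([batch_size])  # e.g. per-example generator loss.
combined = combine_adversarial_loss(
    main_loss, adversarial_loss,
    weight_factor=1e-3,         # fixed coefficient on the adversarial term.
    gradient_summaries=False)   # skip the tf.gradients computation entirely.
# `combined` has shape [8]: main_loss + 1e-3 * adversarial_loss.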
Example #4
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
    """Utility to combine main and adversarial losses.

  This utility combines the main and adversarial losses in one of two ways.
  1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
  2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
    used to make sure both losses affect weights roughly equally, as in
    https://arxiv.org/pdf/1705.05823.

  One can optionally also visualize the scalar and gradient behavior of the
  losses.

  Args:
    main_loss: A floating scalar Tensor indicating the main loss.
    adversarial_loss: A floating scalar Tensor indicating the adversarial loss.
    weight_factor: If not `None`, the coefficient by which to multiply the
      adversarial loss. Exactly one of this and `gradient_ratio` must be
      non-None.
    gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
      Specifically,
        gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
      Exactly one of this and `weight_factor` must be non-None.
    gradient_ratio_epsilon: An epsilon to add to the adversarial loss
      coefficient denominator, to avoid division-by-zero.
    variables: List of variables to calculate gradients with respect to. If not
      present, defaults to all trainable variables.
    scalar_summaries: Create scalar summaries of losses.
    gradient_summaries: Create gradient summaries of losses.
    scope: Optional name scope.

  Returns:
    A floating scalar Tensor indicating the desired combined loss.

  Raises:
    ValueError: Malformed input.
    RuntimeError: If `tf.gradients` needs to be computed and TensorFlow is
      executing eagerly.
  """
    _validate_args([main_loss, adversarial_loss], weight_factor,
                   gradient_ratio)
    if variables is None:
        variables = contrib.get_trainable_variables()

    with tf.compat.v1.name_scope(scope,
                                 'adversarial_loss',
                                 values=[main_loss, adversarial_loss]):
        # Compute gradients if we will need them.
        if gradient_summaries or gradient_ratio is not None:
            # `tf.gradients` doesn't work in eager.
            if tf.executing_eagerly():
                raise RuntimeError('`tf.gradients` doesn\'t work in eager.')
            main_loss_grad_mag = numerically_stable_global_norm(
                tf.gradients(ys=main_loss, xs=variables))
            adv_loss_grad_mag = numerically_stable_global_norm(
                tf.gradients(ys=adversarial_loss, xs=variables))

        # Add summaries, if applicable.
        if scalar_summaries:
            tf.compat.v1.summary.scalar('main_loss', main_loss)
            tf.compat.v1.summary.scalar('adversarial_loss', adversarial_loss)
        if gradient_summaries:
            tf.compat.v1.summary.scalar('main_loss_gradients',
                                        main_loss_grad_mag)
            tf.compat.v1.summary.scalar('adversarial_loss_gradients',
                                        adv_loss_grad_mag)

        # Combine losses in the appropriate way.
        # If `weight_factor` is always `0`, avoid computing the adversarial loss
        # tensor entirely.
        if _used_weight((weight_factor, gradient_ratio)) == 0:
            final_loss = main_loss
        elif weight_factor is not None:
            final_loss = (main_loss +
                          tf.stop_gradient(weight_factor) * adversarial_loss)
        elif gradient_ratio is not None:
            grad_mag_ratio = main_loss_grad_mag / (adv_loss_grad_mag +
                                                   gradient_ratio_epsilon)
            adv_coeff = grad_mag_ratio / gradient_ratio
            tf.compat.v1.summary.scalar('adversarial_coefficient', adv_coeff)
            final_loss = (main_loss +
                          tf.stop_gradient(adv_coeff) * adversarial_loss)

    return final_loss
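And a hedged sketch of the gradient-ratio mode. The coefficient applied to the adversarial loss works out to

  adv_coeff = grad_mag(main_loss) / (grad_mag(adversarial_loss) + eps) / gradient_ratio

so that, after scaling, grad_mag(main_loss) / grad_mag(adv_coeff * adversarial_loss) is approximately `gradient_ratio`. The loss tensors and `generator_variables` below are assumed to come from the surrounding training code.

total_loss = combine_adversarial_loss(
    main_loss, adversarial_loss,
    gradient_ratio=1.0,              # main and adversarial gradients of equal magnitude.
    variables=generator_variables,   # restrict tf.gradients to the generator weights.
    scalar_summaries=False)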
Example #5
  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)
    num_trainable_vars = len(get_trainable_variables())

    if create_global_step:
      gstep = tf.compat.v1.get_variable(
          'custom_gstep',
          dtype=tf.int32,
          initializer=0,
          trainable=False)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(hook, get_sync_optimizer_hook_type())
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())

      sess.run(g_opt.chief_init_op)
      sess.run(d_opt.chief_init_op)

      gstep_before = sess.run(global_step)

      # Start required queue runner for SyncReplicasOptimizer.
      coord = tf.train.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      sess.run(g_sync_init_op)
      sess.run(d_sync_init_op)

      sess.run(train_ops.generator_train_op)
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, sess.run(global_step))

      sess.run(train_ops.discriminator_train_op)
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, sess.run(global_step))

      coord.request_stop()
      coord.join(g_threads + d_threads)
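The test above relies on a `get_sync_optimizer()` helper that is not shown in this snippet. A minimal sketch of what such a helper could look like, assuming a single-replica setup; the wrapped optimizer and learning rate are illustrative.

def get_sync_optimizer():
  # Wrap a plain optimizer in SyncReplicasOptimizer, which provides the token
  # queue, chief_init_op, and session-run hooks exercised by the test.
  return tf.compat.v1.train.SyncReplicasOptimizer(
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)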
Example #6
def cut_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        feat_discriminator_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        feat_discriminator_scope='FeatDiscriminator',
        # Options.
        check_shapes=True):
    """Returns GAN model outputs and variables.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated_data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      real_data: A Tensor representing the real data.
      generator_inputs: A Tensor or list of Tensors to the generator. In the
        vanilla GAN case, this might be a single noise Tensor. In the conditional
        GAN case, this might be the generator's conditioning.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if you
        want to reuse a subgraph that has already been created.
      check_shapes: If `True`, check that generator produces Tensors that are the
        same shape as real data. Otherwise, skip this check.

    Returns:
      A GANModel namedtuple.

    Raises:
      ValueError: If the generator outputs a Tensor that isn't the same shape as
        `real_data`.
      ValueError: If TF is executing eagerly.
    """
    if tf.executing_eagerly():
        raise ValueError('`cut_model` doesn\'t work when executing eagerly.')
    # Create models
    with tf.compat.v1.variable_scope(
            generator_scope, reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
        generated_data = generator_fn(generator_inputs)
    with tf.compat.v1.variable_scope(
            discriminator_scope, reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
        discriminator_gen_outputs = discriminator_fn(generated_data,
                                                     generator_inputs)
    with tf.compat.v1.variable_scope(dis_scope, reuse=True):
        real_data = _convert_tensor_or_l_or_d(real_data)
        discriminator_real_outputs = discriminator_fn(real_data,
                                                      generator_inputs)

    if check_shapes:
        if not generated_data.shape.is_compatible_with(real_data.shape):
            raise ValueError(
                'Generator output shape (%s) must be the same shape as real data '
                '(%s).' % (generated_data.shape, real_data.shape))

    # Get model-specific variables.
    generator_variables = get_trainable_variables(gen_scope)
    discriminator_variables = get_trainable_variables(dis_scope)

    return CUTModel(generator_inputs, generated_data, generator_variables,
                    gen_scope, generator_fn, real_data,
                    discriminator_real_outputs, discriminator_gen_outputs,
                    discriminator_variables, dis_scope, discriminator_fn)
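A hedged usage sketch for `cut_model`; the lambdas and tensor shapes are placeholder assumptions, and the returned `CUTModel` is assumed to mirror the `GANModel` field layout used in the positional construction above.

import tensorflow as tf

source_images = tf.zeros([4, 64, 64, 3])  # generator_inputs (source-domain images).
target_images = tf.zeros([4, 64, 64, 3])  # real_data (target-domain images).

model = cut_model(
    generator_fn=lambda x: tf.compat.v1.layers.conv2d(x, filters=3, kernel_size=1),
    discriminator_fn=lambda data, conditioning: tf.compat.v1.layers.conv2d(
        data, filters=1, kernel_size=1),
    feat_discriminator_fn=lambda feats: feats,  # placeholder feature head.
    real_data=target_images,
    generator_inputs=source_images)
generated = model.generated_data  # Same shape as `target_images` (check_shapes).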