Example #1
  def _test_correct_helper(self, use_weight_factor):
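    """Test that combine_adversarial_loss produces the expected combined value."""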
    variable_list = [variables.Variable(1.0)]
    main_loss = variable_list[0] * 2
    adversarial_loss = variable_list[0] * 3
    gradient_ratio_epsilon = 1e-6
    if use_weight_factor:
      weight_factor = constant_op.constant(2.0)
      gradient_ratio = None
      adv_coeff = 2.0
      expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3
    else:
      weight_factor = None
      gradient_ratio = constant_op.constant(0.5)
      adv_coeff = 2.0 / (3 * 0.5 + gradient_ratio_epsilon)
      expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3
    combined_loss = tfgan_losses.combine_adversarial_loss(
        main_loss,
        adversarial_loss,
        weight_factor=weight_factor,
        gradient_ratio=gradient_ratio,
        gradient_ratio_epsilon=gradient_ratio_epsilon,
        variables=variable_list)

    with self.test_session(use_gpu=True):
      variables.global_variables_initializer().run()
      self.assertNear(expected_loss, combined_loss.eval(), 1e-5)
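
The expected values asserted above can be checked by hand. Below is a minimal plain-Python sketch of the same arithmetic (no TensorFlow required), following the formulas the helper itself uses: with a weight_factor the adversarial term is simply scaled, while with a gradient_ratio the coefficient is derived from the gradient magnitudes of the two losses with respect to the single variable (2 and 3 respectively).

epsilon = 1e-6

# weight_factor branch: the adversarial loss is scaled by the factor 2.0.
adv_coeff = 2.0
print(1.0 * 2 + adv_coeff * 1.0 * 3)  # 8.0

# gradient_ratio branch: coefficient built from the gradient magnitudes
# 2 (main loss) and 3 (adversarial loss) and the requested ratio 0.5.
adv_coeff = 2.0 / (3 * 0.5 + epsilon)
print(1.0 * 2 + adv_coeff * 1.0 * 3)  # ~6.0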
Example #3
  def _test_no_weight_skips_adversarial_loss_helper(self, use_weight_factor):
    """Test the 0 adversarial weight or grad ratio skips adversarial loss."""
    main_loss = constant_op.constant(1.0)
    adversarial_loss = constant_op.constant(1.0)

    weight_factor = 0.0 if use_weight_factor else None
    gradient_ratio = None if use_weight_factor else 0.0

    combined_loss = tfgan_losses.combine_adversarial_loss(
        main_loss,
        adversarial_loss,
        weight_factor=weight_factor,
        gradient_ratio=gradient_ratio,
        gradient_summaries=False)

    with self.test_session(use_gpu=True):
      self.assertEqual(1.0, combined_loss.eval())
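
In the original test classes these private helpers are presumably invoked by thin public test methods, one per parameterization. A sketch of what such wrappers could look like is given below; the test_* names are illustrative assumptions, not taken from this document.

  # Hypothetical wrappers around the helpers above; the real method names may differ.
  def test_correct_useweightfactor(self):
    self._test_correct_helper(True)

  def test_correct_nouseweightfactor(self):
    self._test_correct_helper(False)

  def test_no_weight_skips_adversarial_loss_useweightfactor(self):
    self._test_no_weight_skips_adversarial_loss_helper(True)

  def test_no_weight_skips_adversarial_loss_nouseweightfactor(self):
    self._test_no_weight_skips_adversarial_loss_helper(False)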
Example #5
def combine_adversarial_loss(gan_loss,
                             gan_model,
                             non_adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             scalar_summaries=True,
                             gradient_summaries=True):
  """Combine adversarial loss and main loss.

  Uses `combine_adversarial_loss` to combine the losses, and returns
  a modified GANLoss namedtuple.

  Args:
    gan_loss: A GANLoss namedtuple. Assume the GANLoss.generator_loss is the
      adversarial loss.
    gan_model: A GANModel namedtuple. Used to access the generator's variables.
    non_adversarial_loss: Same as `main_loss` from
      `combine_adversarial_loss`.
    weight_factor: Same as `weight_factor` from
      `combine_adversarial_loss`.
    gradient_ratio: Same as `gradient_ratio` from
      `combine_adversarial_loss`.
    gradient_ratio_epsilon: Same as `gradient_ratio_epsilon` from
      `combine_adversarial_loss`.
    scalar_summaries: Same as `scalar_summaries` from
      `combine_adversarial_loss`.
    gradient_summaries: Same as `gradient_summaries` from
      `combine_adversarial_loss`.

  Returns:
    A modified GANLoss namedtuple, with `non_adversarial_loss` included
    appropriately.
  """
  combined_loss = losses_impl.combine_adversarial_loss(
      non_adversarial_loss,
      gan_loss.generator_loss,
      weight_factor,
      gradient_ratio,
      gradient_ratio_epsilon,
      gan_model.generator_variables,
      scalar_summaries,
      gradient_summaries)
  return gan_loss._replace(generator_loss=combined_loss)
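
For context, here is a sketch of how this wrapper might sit in a TF-GAN training setup. Everything other than combine_adversarial_loss itself (tfgan.gan_loss, the GANModel fields used, the mean-squared-error reconstruction term, weight_factor=1.0) is an assumption about the surrounding pipeline, not something stated in this document.

import tensorflow as tf
tfgan = tf.contrib.gan

# gan_model: a GANModel namedtuple built elsewhere, e.g. with tfgan.gan_model(...).
gan_loss = tfgan.gan_loss(gan_model)  # adversarial generator/discriminator losses

# A non-adversarial term, e.g. pixel-wise reconstruction against the real data.
reconstruction_loss = tf.losses.mean_squared_error(
    gan_model.real_data, gan_model.generated_data)

# Fold the reconstruction term into the generator loss of the GANLoss tuple.
gan_loss = combine_adversarial_loss(
    gan_loss, gan_model, reconstruction_loss, weight_factor=1.0)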