def _test_correct_helper(self, use_weight_factor):
    """Check the value returned by `tfgan_losses.combine_adversarial_loss`.

    Builds a main loss (`2 * v`) and an adversarial loss (`3 * v`) from a
    single variable `v` initialized to 1.0, combines them, and compares the
    evaluated result with a value computed in plain Python.

    Args:
      use_weight_factor: If truthy, the losses are combined with
        `weight_factor=constant(2.0)` and `gradient_ratio=None`. Otherwise
        they are combined with `weight_factor=None` and
        `gradient_ratio=constant(0.5)`.
    """
    shared_var = variables.Variable(1.0)
    variable_list = [shared_var]
    main_loss = shared_var * 2
    adversarial_loss = shared_var * 3
    epsilon = 1e-6

    if use_weight_factor:
        weight_factor, gradient_ratio = constant_op.constant(2.0), None
        adv_coeff = 2.0
    else:
        weight_factor, gradient_ratio = None, constant_op.constant(0.5)
        # NOTE(review): the 2.0 and 3 here presumably stand for the gradient
        # magnitudes of the main and adversarial losses w.r.t. the variable;
        # confirm against the loss implementation.
        adv_coeff = 2.0 / (3 * 0.5 + epsilon)

    # Same formula in both modes; only the adversarial coefficient differs.
    expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3

    combined_loss = tfgan_losses.combine_adversarial_loss(
        main_loss,
        adversarial_loss,
        weight_factor=weight_factor,
        gradient_ratio=gradient_ratio,
        gradient_ratio_epsilon=epsilon,
        variables=variable_list)

    with self.test_session(use_gpu=True):
        # The variable must be initialized before the loss can be evaluated.
        variables.global_variables_initializer().run()
        self.assertNear(expected_loss, combined_loss.eval(), 1e-5)
def _test_no_weight_skips_adversarial_loss_helper(self, use_weight_factor):
    """Test the 0 adversarial weight or grad ratio skips adversarial loss.

    Combines two constant losses (both 1.0) with either a zero weight factor
    or a zero gradient ratio and asserts that the evaluated result is 1.0.

    Args:
      use_weight_factor: If truthy, pass `weight_factor=0.0` and
        `gradient_ratio=None`; otherwise pass `weight_factor=None` and
        `gradient_ratio=0.0`.
    """
    # NOTE(review): both losses are the same constant, so the assertion
    # below cannot tell which of the two was returned.
    main_loss = constant_op.constant(1.0)
    adversarial_loss = constant_op.constant(1.0)

    # Exactly one of the two knobs is set to zero; the other stays unset.
    if use_weight_factor:
        weight_factor, gradient_ratio = 0.0, None
    else:
        weight_factor, gradient_ratio = None, 0.0

    combined_loss = tfgan_losses.combine_adversarial_loss(
        main_loss,
        adversarial_loss,
        weight_factor=weight_factor,
        gradient_ratio=gradient_ratio,
        gradient_summaries=False)

    with self.test_session(use_gpu=True):
        self.assertEqual(1.0, combined_loss.eval())
def combine_adversarial_loss(gan_loss,
                             gan_model,
                             non_adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             scalar_summaries=True,
                             gradient_summaries=True):
    """Fold a non-adversarial loss into the generator loss of a GANLoss.

    Delegates to `losses_impl.combine_adversarial_loss`, treating
    `gan_loss.generator_loss` as the adversarial loss and
    `non_adversarial_loss` as the main loss, then returns a copy of
    `gan_loss` whose `generator_loss` is the combined value.

    Args:
      gan_loss: A GANLoss namedtuple. Its `generator_loss` field is used as
        the adversarial loss.
      gan_model: A GANModel namedtuple. Its `generator_variables` field is
        forwarded as the variables argument.
      non_adversarial_loss: Forwarded as the main loss (first argument) of
        `losses_impl.combine_adversarial_loss`.
      weight_factor: Forwarded unchanged to
        `losses_impl.combine_adversarial_loss`.
      gradient_ratio: Forwarded unchanged to
        `losses_impl.combine_adversarial_loss`.
      gradient_ratio_epsilon: Forwarded unchanged to
        `losses_impl.combine_adversarial_loss`.
      scalar_summaries: Forwarded unchanged to
        `losses_impl.combine_adversarial_loss`.
      gradient_summaries: Forwarded unchanged to
        `losses_impl.combine_adversarial_loss`.

    Returns:
      The result of `gan_loss._replace(generator_loss=...)`, i.e. `gan_loss`
      with only its `generator_loss` field swapped for the combined loss.
    """
    adversarial_loss = gan_loss.generator_loss
    generator_vars = gan_model.generator_variables

    # Arguments are passed positionally, in the order the callee expects:
    # main loss, adversarial loss, weighting options, variables, summaries.
    merged_loss = losses_impl.combine_adversarial_loss(
        non_adversarial_loss,
        adversarial_loss,
        weight_factor,
        gradient_ratio,
        gradient_ratio_epsilon,
        generator_vars,
        scalar_summaries,
        gradient_summaries)

    return gan_loss._replace(generator_loss=merged_loss)