def add_regularization_loss_summaries(gan_model):
  """Adds scalar summaries for the regularization losses of a GAN model.

  One summary is written for the generator and one for the discriminator.
  Each is written only when the corresponding scope on `gan_model` is set
  (truthy). The value comes from `loss_util.get_regularization_loss`, called
  with that scope's name.

  Args:
    gan_model: A GANModel tuple.
  """
  generator_scope = gan_model.generator_scope
  if generator_scope:
    generator_loss = loss_util.get_regularization_loss(generator_scope.name)
    summary.scalar('generator_regularization_loss', generator_loss)

  discriminator_scope = gan_model.discriminator_scope
  if discriminator_scope:
    discriminator_loss = loss_util.get_regularization_loss(
        discriminator_scope.name)
    summary.scalar('discriminator_regularization_loss', discriminator_loss)
def add_regularization_loss_summaries(gan_model):
  """Adds scalar summaries for the regularization losses of a GAN model.

  The generator and discriminator are visited in that order. A summary is
  emitted for each one whose scope attribute on `gan_model` is truthy.

  Args:
    gan_model: A GANModel tuple.
  """
  # Pairs of (attribute holding the scope, name of the emitted summary).
  for scope_attr, summary_name in (
      ('generator_scope', 'generator_regularization_loss'),
      ('discriminator_scope', 'discriminator_regularization_loss')):
    scope = getattr(gan_model, scope_attr)
    if not scope:
      continue
    summary.scalar(
        summary_name, loss_util.get_regularization_loss(scope.name))
Example #3
0
  def testGetRegularizationLoss(self):
    """Tests summing and scope filtering in `util.get_regularization_loss`."""
    # `cached_session()` is used in place of the deprecated `test_session()`.
    # Empty regularization collection should evaluate to 0.0.
    with self.cached_session():
      self.assertEqual(0.0, util.get_regularization_loss().eval())

    # Loss should sum: 2.0 + 3.0 == 5.0.
    ops.add_to_collection(
        ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
    ops.add_to_collection(
        ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
    with self.cached_session():
      self.assertEqual(5.0, util.get_regularization_loss().eval())

    # Check scope capture mechanism: when filtering by 'scope1', only the
    # loss added under that name scope is expected.
    with ops.name_scope('scope1'):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
    with self.cached_session():
      self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
  def testGetRegularizationLoss(self):
    """Tests summing and scope filtering in `util.get_regularization_loss`."""
    reg_key = ops.GraphKeys.REGULARIZATION_LOSSES

    # With nothing in the collection, the loss evaluates to 0.0.
    with self.cached_session():
      self.assertEqual(0.0, util.get_regularization_loss().eval())

    # Every collected loss contributes to the total: 2.0 + 3.0 == 5.0.
    for value in (2.0, 3.0):
      ops.add_to_collection(reg_key, constant_op.constant(value))
    with self.cached_session():
      self.assertEqual(5.0, util.get_regularization_loss().eval())

    # A loss added under a name scope is the only one expected when
    # filtering by that scope.
    with ops.name_scope('scope1'):
      ops.add_to_collection(reg_key, constant_op.constant(-1.0))
    with self.cached_session():
      scoped_loss = util.get_regularization_loss('scope1')
      self.assertEqual(-1.0, scoped_loss.eval())
Example #5
0
def add_regularization_loss_summaries(gan_model):
  """Adds scalar summaries for the regularization losses of a GAN model.

  For a `namedtuples.CycleGANModel`, the function recurses into
  `model_x2y` and `model_y2x`, each inside its own name scope. For any other
  model, it writes one summary per generator/discriminator scope that is set
  (truthy) on `gan_model`.

  Args:
    gan_model: A GANModel or CycleGANModel tuple.
  """
  if isinstance(gan_model, namedtuples.CycleGANModel):
    # (name scope to open, attribute holding the sub-model to summarize).
    for scope_name, sub_model_attr in (
        ('cyclegan_x2y_regularization_loss_summaries', 'model_x2y'),
        ('cyclegan_y2x_regularization_loss_summaries', 'model_y2x')):
      with ops.name_scope(scope_name):
        add_regularization_loss_summaries(getattr(gan_model, sub_model_attr))
    return

  generator_scope = gan_model.generator_scope
  if generator_scope:
    generator_loss = loss_util.get_regularization_loss(generator_scope.name)
    summary.scalar('generator_regularization_loss', generator_loss)

  discriminator_scope = gan_model.discriminator_scope
  if discriminator_scope:
    discriminator_loss = loss_util.get_regularization_loss(
        discriminator_scope.name)
    summary.scalar('discriminator_regularization_loss', discriminator_loss)
def add_regularization_loss_summaries(gan_model):
  """Adds scalar summaries for the regularization losses of a GAN model.

  For a plain model, the generator and discriminator are handled in that
  order. A summary is emitted for each one whose scope is truthy. For a
  `namedtuples.CycleGANModel`, both direction sub-models are summarized
  recursively, each inside a dedicated name scope.

  Args:
    gan_model: A GANModel or CycleGANModel tuple.
  """
  if not isinstance(gan_model, namedtuples.CycleGANModel):
    # Pairs of (attribute holding the scope, name of the emitted summary).
    for scope_attr, summary_name in (
        ('generator_scope', 'generator_regularization_loss'),
        ('discriminator_scope', 'discriminator_regularization_loss')):
      scope = getattr(gan_model, scope_attr)
      if scope:
        summary.scalar(
            summary_name, loss_util.get_regularization_loss(scope.name))
    return

  with ops.name_scope('cyclegan_x2y_regularization_loss_summaries'):
    add_regularization_loss_summaries(gan_model.model_x2y)
  with ops.name_scope('cyclegan_y2x_regularization_loss_summaries'):
    add_regularization_loss_summaries(gan_model.model_y2x)