예제 #1
0
def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False):
  """Computes the cycle consistency loss of a `CycleGANModel`.

  Pulls the generator inputs of both translation directions and the matching
  `reconstructed_*` tensors out of `cyclegan_model`. It then delegates the
  actual computation to `losses_impl.cycle_consistency_loss`.

  Args:
    cyclegan_model: A `CycleGANModel` namedtuple.
    scope: Optional scope for the operations that compute the loss.
      Defaults to None.
    add_summaries: If True, detailed summaries for the loss are added.
      Defaults to False.

  Returns:
    A scalar `Tensor` holding the cycle consistency loss.

  Raises:
    ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple.
  """
  expected_type = namedtuples.CycleGANModel
  if not isinstance(cyclegan_model, expected_type):
    message = ('`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.'
               % type(cyclegan_model))
    raise ValueError(message)

  # Each direction's generator inputs are paired with the reconstructed
  # tensor for the same domain (x with reconstructed_x, y with
  # reconstructed_y), in the positional order the implementation expects.
  inputs_x = cyclegan_model.model_x2y.generator_inputs
  recon_x = cyclegan_model.reconstructed_x
  inputs_y = cyclegan_model.model_y2x.generator_inputs
  recon_y = cyclegan_model.reconstructed_y
  return losses_impl.cycle_consistency_loss(
      inputs_x, recon_x, inputs_y, recon_y, scope, add_summaries)
예제 #2
0
def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False):
  """Returns the cycle consistency loss for a `CycleGANModel`.

  This is a thin adapter. It validates the model tuple, then passes its
  x/y generator inputs and their `reconstructed_*` counterparts to
  `losses_impl.cycle_consistency_loss`.

  Args:
    cyclegan_model: A `CycleGANModel` namedtuple.
    scope: Optional scope for the loss-computing operations. Defaults to None.
    add_summaries: Whether to emit detailed loss summaries. Defaults to False.

  Returns:
    A scalar `Tensor` with the cycle consistency loss.

  Raises:
    ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple.
  """
  if isinstance(cyclegan_model, namedtuples.CycleGANModel):
    return losses_impl.cycle_consistency_loss(
        cyclegan_model.model_x2y.generator_inputs,
        cyclegan_model.reconstructed_x,
        cyclegan_model.model_y2x.generator_inputs,
        cyclegan_model.reconstructed_y,
        scope,
        add_summaries)
  raise ValueError(
      '`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.'
      % (type(cyclegan_model),))
예제 #3
0
 def test_correct_loss(self):
   # 5.25 is the expected loss for the fixture tensors on `self`. Those are
   # set up elsewhere in this test class (presumably setUp; not visible here).
   expected_value = 5.25
   tolerance = 1e-5
   cycle_loss = tfgan_losses.cycle_consistency_loss(
       self._data_x,
       self._reconstructed_data_x,
       self._data_y,
       self._reconstructed_data_y)
   with self.test_session(use_gpu=True):
     init_op = variables.global_variables_initializer()
     init_op.run()
     self.assertNear(expected_value, cycle_loss.eval(), tolerance)
예제 #4
0
 def test_correct_loss(self):
     # Compare the computed cycle consistency loss against a known value for
     # the fixture tensors stored on `self` (their setup is not visible here).
     fixture_args = (self._data_x, self._reconstructed_data_x,
                     self._data_y, self._reconstructed_data_y)
     computed = tfgan_losses.cycle_consistency_loss(*fixture_args)
     with self.test_session(use_gpu=True):
         variables.global_variables_initializer().run()
         actual = computed.eval()
         self.assertNear(5.25, actual, 1e-5)