def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False):
    """Computes the cycle consistency loss for a CycleGAN model.

    Thin wrapper that validates `cyclegan_model`, unpacks the generator
    inputs and reconstructions for both directions, and delegates to
    `losses_impl.cycle_consistency_loss`.

    Args:
      cyclegan_model: A `CycleGANModel` namedtuple.
      scope: The scope for the operations performed in computing the loss.
        Defaults to None.
      add_summaries: Whether or not to add detailed summaries for the loss.
        Defaults to False.

    Returns:
      A scalar `Tensor` of cycle consistency loss.

    Raises:
      ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple.
    """
    is_cyclegan = isinstance(cyclegan_model, namedtuples.CycleGANModel)
    if not is_cyclegan:
        raise ValueError(
            '`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.' %
            type(cyclegan_model))

    # The x2y generator's inputs are paired with `reconstructed_x`, and the
    # y2x generator's inputs with `reconstructed_y`.
    inputs_x = cyclegan_model.model_x2y.generator_inputs
    cycled_x = cyclegan_model.reconstructed_x
    inputs_y = cyclegan_model.model_y2x.generator_inputs
    cycled_y = cyclegan_model.reconstructed_y

    # `scope` and `add_summaries` are forwarded positionally, exactly as given.
    return losses_impl.cycle_consistency_loss(
        inputs_x, cycled_x, inputs_y, cycled_y, scope, add_summaries)
def test_correct_loss(self):
    """Checks the loss on the fixture data evaluates to 5.25 within 1e-5."""
    expected_loss = 5.25
    tolerance = 1e-5

    # Build the loss op from the fixture tensors before opening the session.
    cycle_loss = tfgan_losses.cycle_consistency_loss(
        self._data_x,
        self._reconstructed_data_x,
        self._data_y,
        self._reconstructed_data_y)

    with self.test_session(use_gpu=True):
        init_op = variables.global_variables_initializer()
        init_op.run()
        self.assertNear(expected_loss, cycle_loss.eval(), tolerance)
def test_correct_loss(self):
    """Verifies the cycle consistency loss of the fixture data is 5.25."""
    # Inputs and their reconstructions come from the test fixture attributes.
    loss_inputs = (
        self._data_x,
        self._reconstructed_data_x,
        self._data_y,
        self._reconstructed_data_y,
    )
    loss_tensor = tfgan_losses.cycle_consistency_loss(*loss_inputs)

    with self.test_session(use_gpu=True):
        # Run the global initializer first, then evaluate the loss tensor.
        variables.global_variables_initializer().run()
        actual_loss = loss_tensor.eval()
        self.assertNear(5.25, actual_loss, 1e-5)