def testGetLossRaisesErrors(self):
    """Invalid loss specifications raise errors that mention the given name."""
    # Each entry: (expected exception, message regex, offending loss value).
    # 'custom_name' is the second argument to `_get_loss`; the regexes expect
    # it to appear in the error message.
    bad_inputs = (
        (TypeError, 'custom_name.*must be.*MinDiffLoss.*string.*4.*int', 4),
        (ValueError, 'custom_name.*must be.*supported values.*bad_name',
         'bad_name'),
    )
    for expected_error, message_pattern, bad_loss in bad_inputs:
        with self.assertRaisesRegex(expected_error, message_pattern):
            utils._get_loss(bad_loss, 'custom_name')
def testForMMDLoss(self):
    """MMD aliases and MMDLoss instances all resolve to an MMDLoss."""
    expected_type = mmd_loss.MMDLoss
    # Both accepted string aliases should resolve to an MMDLoss.
    for alias in ('mmd', 'mmd_loss'):
        self.assertIsInstance(utils._get_loss(alias), expected_type)
    # Passing an MMDLoss instance should also yield an MMDLoss.
    self.assertIsInstance(utils._get_loss(mmd_loss.MMDLoss()), expected_type)
    # A user-supplied name must survive the lookup.
    custom_name = 'custom_name'
    resolved = utils._get_loss(mmd_loss.MMDLoss(name=custom_name))
    self.assertIsInstance(resolved, expected_type)
    self.assertEqual(resolved.name, custom_name)
def __init__(self,
             original_model: tf.keras.Model,
             loss,
             loss_weight: float = 1.0,
             predictions_transform=None,
             **kwargs):
    """Initializes a MinDiffModel instance.

    Args:
      original_model: The model being wrapped; stored as
        `self._original_model`. It may or may not already be built.
      loss: MinDiff loss specification, resolved through
        `loss_utils._get_loss`. Presumably a `MinDiffLoss` instance, a
        supported loss-name string, or `None` (TODO confirm against
        `loss_utils`).
      loss_weight: Scalar weight for the MinDiff loss; stored as
        `self._loss_weight`. Presumably multiplied with the MinDiff loss
        elsewhere in this class (verify). Annotated `float`; per PEP 484 an
        `int` is also accepted.
      predictions_transform: Optional callable, stored as
        `self._predictions_transform`.
      **kwargs: Forwarded unchanged to the superclass constructor.

    Raises:
      ValueError: If `predictions_transform` is passed in but not callable.
    """
    super(MinDiffModel, self).__init__(**kwargs)
    # Set _auto_track_sub_layers to true to ensure we track the
    # original_model and MinDiff layers. This must happen before those
    # attributes are assigned below.
    self._auto_track_sub_layers = True  # Track sub layers.
    self.built = True  # This Model is built, original_model may or may not be.

    self._original_model = original_model
    self._loss = loss_utils._get_loss(loss)
    self._loss_weight = loss_weight
    # Running mean of the MinDiff loss, reported under "min_diff_loss".
    self._min_diff_loss_metric = tf.keras.metrics.Mean("min_diff_loss")

    if (predictions_transform is not None and
            not callable(predictions_transform)):
        raise ValueError(
            "`predictions_transform` must be callable if passed "
            "in, given: {}".format(predictions_transform))
    self._predictions_transform = predictions_transform

    # Clear input_spec in case there is one. We cannot make any strong
    # assertions because `min_diff_data` may or may not be included and can
    # have different shapes since weight is optional.
    self.input_spec = None
def testForCustomLoss(self):
    """A user-defined MinDiffLoss subclass instance is returned as-is."""

    # The class name is kept as `CustomLoss` on purpose: loss names may be
    # derived from it.
    class CustomLoss(base_loss.MinDiffLoss):

        def call(self, x, y):
            pass

    custom_instance = CustomLoss()
    self.assertIs(utils._get_loss(custom_instance), custom_instance)
def testForAbsoluteCorrelationLoss(self):
    """All absolute-correlation aliases and instances resolve correctly."""
    expected_type = abscorrloss.AbsoluteCorrelationLoss
    accepted_names = (
        'abs_corr',
        'abS_coRr',  # Strangely capitalized; the lookup must still accept it.
        'abs_corr_loss',  # Other accepted name.
        'absolute_correlation',  # Other accepted name.
        'absolute_correlation_loss',  # Other accepted name.
    )
    for accepted_name in accepted_names:
        self.assertIsInstance(utils._get_loss(accepted_name), expected_type)

    # An instance built with a custom name keeps that name after lookup.
    custom_name = 'custom_name'
    resolved = utils._get_loss(abscorrloss.AbsoluteCorrelationLoss(custom_name))
    self.assertIsInstance(resolved, expected_type)
    self.assertEqual(resolved.name, custom_name)
def testAcceptsNone(self):
    """Passing None to `_get_loss` yields None rather than raising."""
    self.assertIsNone(utils._get_loss(None))