예제 #1
0
    def testGetLossRaisesErrors(self):
        """Checks that `_get_loss` rejects invalid loss arguments.

        An int (4) is expected to raise `TypeError` and an unrecognized
        string (`'bad_name'`) is expected to raise `ValueError`. Both error
        messages must mention the custom argument name that was passed in.
        """
        type_error_regex = 'custom_name.*must be.*MinDiffLoss.*string.*4.*int'
        value_error_regex = 'custom_name.*must be.*supported values.*bad_name'

        # Callable form of assertRaisesRegex: the trailing arguments are
        # forwarded to `utils._get_loss`.
        self.assertRaisesRegex(
            TypeError, type_error_regex, utils._get_loss, 4, 'custom_name')
        self.assertRaisesRegex(
            ValueError, value_error_regex, utils._get_loss, 'bad_name',
            'custom_name')
예제 #2
0
 def testForMMDLoss(self):
     """Checks that `_get_loss` resolves MMD loss strings and instances.

     Both accepted string identifiers and already-constructed `MMDLoss`
     objects must yield an `MMDLoss`; a custom name given to the instance
     must still be present on the returned loss.
     """
     # String identifiers.
     for identifier in ('mmd', 'mmd_loss'):
         self.assertIsInstance(utils._get_loss(identifier), mmd_loss.MMDLoss)

     # Instance constructed with default arguments.
     default_instance = mmd_loss.MMDLoss()
     self.assertIsInstance(utils._get_loss(default_instance), mmd_loss.MMDLoss)

     # Instance constructed with a custom name.
     expected_name = 'custom_name'
     named_result = utils._get_loss(mmd_loss.MMDLoss(name=expected_name))
     self.assertIsInstance(named_result, mmd_loss.MMDLoss)
     self.assertEqual(named_result.name, expected_name)
    def __init__(self,
                 original_model: tf.keras.Model,
                 loss,
                 loss_weight: complex = 1.0,
                 predictions_transform=None,
                 **kwargs):
        """Initializes a MinDiffModel instance.

        Args:
          original_model: The `tf.keras.Model` kept as `_original_model`.
          loss: Loss specification; resolved via `loss_utils._get_loss` and
            kept as `_loss`.
          loss_weight: Value kept as `_loss_weight`. Defaults to 1.0.
          predictions_transform: Optional callable kept as
            `_predictions_transform`. Defaults to `None`.
          **kwargs: Forwarded unchanged to the parent class initializer.

        Raises:
          ValueError: If `predictions_transform` is passed in but not callable.
        """
        super(MinDiffModel, self).__init__(**kwargs)

        # Enable automatic sub-layer tracking so that both original_model and
        # the MinDiff layers are tracked.
        self._auto_track_sub_layers = True
        # This wrapper counts as built; original_model may or may not be.
        self.built = True

        self._original_model = original_model
        self._loss = loss_utils._get_loss(loss)
        self._loss_weight = loss_weight
        self._min_diff_loss_metric = tf.keras.metrics.Mean("min_diff_loss")

        # A transform is acceptable when it is omitted or when it is callable.
        transform_is_valid = (predictions_transform is None
                              or callable(predictions_transform))
        if not transform_is_valid:
            message = ("`predictions_transform` must be callable if passed "
                       "in, given: {}")
            raise ValueError(message.format(predictions_transform))
        self._predictions_transform = predictions_transform

        # Drop any existing input_spec: no strong assertion can be made since
        # `min_diff_data` may or may not be included and its shape can vary
        # (weight is optional).
        self.input_spec = None
예제 #4
0
    def testForCustomLoss(self):
        """Checks that a user-defined loss instance is returned as-is."""

        class CustomLoss(base_loss.MinDiffLoss):
            # Minimal subclass: `call` is a stub that returns None.
            def call(self, x, y):
                return None

        custom_instance = CustomLoss()
        # Identity, not just equality: the very same object must come back.
        self.assertIs(utils._get_loss(custom_instance), custom_instance)
예제 #5
0
 def testForAbsoluteCorrelationLoss(self):
     """Checks that `_get_loss` resolves absolute correlation losses.

     Every accepted string alias (including an oddly capitalized one) and an
     already-constructed instance must yield an `AbsoluteCorrelationLoss`; a
     custom name given to the instance must be present on the returned loss.
     """
     expected_type = abscorrloss.AbsoluteCorrelationLoss
     accepted_names = (
         'abs_corr',
         'abS_coRr',  # Strangely capitalized.
         'abs_corr_loss',
         'absolute_correlation',
         'absolute_correlation_loss',
     )
     for identifier in accepted_names:
         self.assertIsInstance(utils._get_loss(identifier), expected_type)

     # Instance constructed with a custom (positional) name.
     expected_name = 'custom_name'
     named_result = utils._get_loss(expected_type(expected_name))
     self.assertIsInstance(named_result, expected_type)
     self.assertEqual(named_result.name, expected_name)
예제 #6
0
 def testAcceptsNone(self):
     """Checks that passing `None` to `_get_loss` yields `None`."""
     self.assertIsNone(utils._get_loss(None))