def testName(self):
        """Checks both the default and a user-supplied loss name."""
        # With no argument, the loss picks a snake_case default name.
        default_loss = loss_lib.AbsoluteCorrelationLoss()
        self.assertEqual(default_loss.name, 'absolute_correlation_loss')

        # A caller-provided name must be stored verbatim.
        named_loss = loss_lib.AbsoluteCorrelationLoss(name='custom_loss')
        self.assertEqual(named_loss.name, 'custom_loss')
    def testNegativeCorrelationNoWeights(self):
        """Unweighted loss where membership and predictions anti-correlate."""
        loss = loss_lib.AbsoluteCorrelationLoss()
        groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
        scores = tf.constant([[0.12], [0.7], [0.2], [0.86], [0.32]])

        # The absolute value turns the strong negative correlation into a
        # large positive loss; golden value pinned to the implementation.
        self.assertAllClose(0.959671, loss(groups, scores))
    def testNoWeights(self):
        """Unweighted loss on positively correlated inputs."""
        loss = loss_lib.AbsoluteCorrelationLoss()
        groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
        scores = tf.constant([[0.3], [0.1], [0.86], [0.06], [0.75]])

        # Golden value pinned against the reference implementation.
        self.assertAllClose(0.823209, loss(groups, scores))
  def testSerialization(self):
    """Round-trips the loss through Keras (de)serialization."""
    original = loss_lib.AbsoluteCorrelationLoss(name='custom_name')
    config = tf.keras.utils.serialize_keras_object(original)

    restored = tf.keras.utils.deserialize_keras_object(config)
    # Both the concrete class and the custom name must survive the round trip.
    self.assertIsInstance(restored, loss_lib.AbsoluteCorrelationLoss)
    self.assertEqual(restored.name, original.name)
    def testAllZeroWeights(self):
        """An all-zero sample_weight must zero out the loss entirely."""
        loss = loss_lib.AbsoluteCorrelationLoss()
        groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
        scores = tf.constant([[0.12], [0.7], [0.2], [0.86], [0.32]])
        weights = tf.constant([[0.0], [0.0], [0.0], [0.0], [0.0]])

        # No example carries weight, so the loss collapses to zero.
        self.assertAllClose(0, loss(groups, scores, weights))
    def testWithWeights(self):
        """Per-example sample weights shift the correlation-based loss."""
        loss = loss_lib.AbsoluteCorrelationLoss()
        groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
        scores = tf.constant([[0.3], [0.1], [0.86], [0.06], [0.75]])
        weights = tf.constant([[1.0], [2.0], [2.5], [1.2], [0.9]])

        # Golden value pinned against the reference implementation.
        self.assertAllClose(0.872562, loss(groups, scores, weights))
    def testSomeNegativeWeights(self):
        """Negative sample weights must raise InvalidArgumentError."""
        loss = loss_lib.AbsoluteCorrelationLoss()
        groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
        scores = tf.constant([[0.12], [0.7], [0.2], [0.86], [0.32]])
        weights = tf.constant([[1.0], [2.0], [-2.5], [1.2], [0.9]])

        with self.assertRaisesRegex(
                tf.errors.InvalidArgumentError,
                'sample_weight.*cannot contain any negative weights'):
            result = loss(groups, scores, weights)
            # In graph mode the validation op only fires once it is executed,
            # so force evaluation inside the assertion context.
            if not tf.executing_eagerly():
                with self.cached_session() as sess:
                    sess.run(result)
  def testGradientsWithAllZeroWeights(self):
    """Zero weights must not leak any gradient into the variables."""
    loss = loss_lib.AbsoluteCorrelationLoss()
    params = tf.constant([[0.1], [0.3], [0.5], [0.7], [0.9]])
    groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
    weights = tf.constant([[0.0], [0.0], [0.0], [0.0], [0.0]])

    with tf.GradientTape() as tape:
      tape.watch(params)
      scores = params * 3  # arbitrary linear operation.
      value = loss(groups, scores, weights)

    grads = tape.gradient(value, params)
    # With every weight zeroed the loss is constant, so gradients vanish.
    self.assertAllClose([[0.0], [0.0], [0.0], [0.0], [0.0]], grads)
# Example #9
 def testForAbsoluteCorrelationLoss(self):
     """_get_loss resolves every accepted alias to AbsoluteCorrelationLoss."""
     # Lookup is case-insensitive and accepts several alternate spellings.
     for alias in ('abs_corr',
                   'abS_coRr',
                   'abs_corr_loss',
                   'absolute_correlation',
                   'absolute_correlation_loss'):
         loss = utils._get_loss(alias)
         self.assertIsInstance(loss, abscorrloss.AbsoluteCorrelationLoss)

     # Passing an instance through should return it with its name intact.
     loss_name = 'custom_name'
     loss = utils._get_loss(abscorrloss.AbsoluteCorrelationLoss(loss_name))
     self.assertIsInstance(loss, abscorrloss.AbsoluteCorrelationLoss)
     self.assertEqual(loss.name, loss_name)
  def testGradients(self):
    """Pins gradient values so loss reimplementations cannot drift."""
    loss = loss_lib.AbsoluteCorrelationLoss()
    params = tf.constant([[0.1], [0.3], [0.5], [0.7], [0.9]])
    groups = tf.constant([[1.0], [0.0], [1.0], [0.0], [1.0]])
    weights = tf.constant([[1.0], [2.0], [2.5], [1.2], [0.9]])

    with tf.GradientTape() as tape:
      tape.watch(params)
      scores = params * 3  # arbitrary linear operation.
      value = loss(groups, scores, weights)

    grads = tape.gradient(value, params)
    # Golden gradients: non-trivial values that must stay stable across
    # changes to the loss implementation.
    expected = [[0.5481868], [-1.2328656], [1.1707227], [-0.83559656],
                [0.34955233]]
    self.assertAllClose(expected, grads)
 def testEnableSummaryHistogram(self):
     """Histogram summaries default to on and can be switched off."""
     self.assertTrue(
         loss_lib.AbsoluteCorrelationLoss().enable_summary_histogram)
     self.assertFalse(
         loss_lib.AbsoluteCorrelationLoss(
             enable_summary_histogram=False).enable_summary_histogram)