def testInvalidLogitsShape(self):
    """Rank-1 logits are rejected with a ValueError."""
    with self.cached_session():
        # Shape (2,) is a flat vector rather than (batch_size, num_classes).
        flat_logits = constant_op.constant([-1.0, 2.1], shape=(2, ))
        class_ids = constant_op.constant([0, 1])
        with self.assertRaises(ValueError):
            losses.sparse_multiclass_hinge_loss(class_ids, flat_logits)
 def testInvalidLabelsDtype(self):
     """An error is raised when labels have an invalid (floating-point) dtype."""
     with self.cached_session():
         logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
         # The values are whole numbers, but the float32 dtype itself is
         # what this test expects to be rejected.
         labels = constant_op.constant([1, 0], dtype=dtypes.float32)
         with self.assertRaises(ValueError):
             _ = losses.sparse_multiclass_hinge_loss(labels, logits)
    def testUnknownShape(self):
        """Loss is computed when logits/labels static shapes are unknown.

        The data and the expected result (zero loss) are the same as in
        `testZeroLossInt32Labels`. Here they are fed through placeholders
        whose batch and/or class dimensions are left unspecified.
        """
        logits_np = np.array([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
                              [0.5, 1.8, -1.0]])
        labels_np = np.array([0, 2, 1], dtype=np.int32)

        # Each pair is (batch_size, num_classes); None leaves that
        # dimension statically unknown.
        static_shapes = (
            (3, 3),
            (None, 3),
            (3, None),
            (None, None),
        )

        for dim_batch, dim_classes in static_shapes:
            with self.cached_session():
                logits_ph = array_ops.placeholder(
                    dtypes.float32, shape=(dim_batch, dim_classes))
                labels_ph = array_ops.placeholder(
                    dtypes.int32, shape=(dim_batch, ))
                hinge = losses.sparse_multiclass_hinge_loss(labels_ph,
                                                            logits_ph)
                feed = {logits_ph: logits_np, labels_ph: labels_np}
                self.assertAlmostEqual(hinge.eval(feed_dict=feed), 0.0, 3)
 def testZeroLossInt64Labels(self):
     """int64 labels give zero loss when true-class logits clearly dominate."""
     with self.cached_session():
         # In every row the true-class logit exceeds each other logit by at
         # least 1.0, so the expected loss is exactly zero.
         scores = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
                                        [-0.5, 0.8, -1.0]])
         int64_labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
         hinge = losses.sparse_multiclass_hinge_loss(int64_labels, scores)
         self.assertAlmostEqual(hinge.eval(), 0.0, 3)
 def testOutOfRangeLabels(self):
     """Evaluating the loss fails when a label is outside [0, num_classes)."""
     with self.cached_session():
         scores = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
                                        [0.5, 1.8, -1.0]])
         # Only 3 classes exist, so label 4 is invalid. Building the loss
         # succeeds; the failure is expected only when the graph is run.
         bad_labels = constant_op.constant([1, 0, 4])
         hinge = losses.sparse_multiclass_hinge_loss(bad_labels, scores)
         with self.assertRaises(errors.InvalidArgumentError):
             hinge.eval()
 def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
     """A ValueError is raised when weights and labels differ in rank and size."""
     with self.cached_session():
         scores = constant_op.constant([-1.0, 2.1], shape=(2, 1))
         # labels: rank 2 with 2 entries; weights: rank 1 with 3 entries.
         column_labels = constant_op.constant([1, 0], shape=(2, 1))
         mismatched_weights = constant_op.constant([1.1, 2.0, 2.8],
                                                   shape=(3, ))
         with self.assertRaises(ValueError):
             losses.sparse_multiclass_hinge_loss(column_labels, scores,
                                                 mismatched_weights)
 def testNoneWeightRaisesValueError(self):
     """Explicitly passing weights=None is rejected with a ValueError."""
     with self.cached_session():
         scores = constant_op.constant([-1.0, 2.1], shape=(2, 1))
         class_ids = constant_op.constant([1, 0])
         with self.assertRaises(ValueError):
             losses.sparse_multiclass_hinge_loss(class_ids, scores,
                                                 weights=None)
 def testNonZeroLossWithScalarTensorWeights(self):
     """A rank-0 weights tensor scales the loss of every example."""
     with self.cached_session():
         scores = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                        [0.2, -1.8, 4.0]])
         column_labels = constant_op.constant([1, 0, 2], shape=(3, 1))
         scalar_weight = constant_op.constant(5.0)
         # Per-example losses are 3.0, 0.3 and 0.0, so the expected value is
         # 5.0 * (3.0 + 0.3 + 0.0) / 3 = 5.5.
         weighted = losses.sparse_multiclass_hinge_loss(column_labels, scores,
                                                        scalar_weight)
         self.assertAlmostEqual(weighted.eval(), 5.5, 3)
 def testCorrectPredictionsSomeClassesInsideMargin(self):
     """Loss is > 0 when the true class wins, but not by a wide enough margin."""
     with self.cached_session():
         scores = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
                                        [1.5, 1.8, -1.0]])
         class_ids = constant_op.constant([0, 2, 1])
         hinge = losses.sparse_multiclass_hinge_loss(class_ids, scores)
         # Rows 0 and 2 are predicted correctly yet still contribute 0.6 and
         # 0.7; the mean is (0.6 + 0.0 + 0.7) / 3 ~= 0.4333.
         self.assertAlmostEqual(hinge.eval(), 0.4333, 3)
 def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
     """Weighted loss is correct when weights is a rank-2 (column) tensor.

     Labels are rank 1, and one weight is 0.0, so that example is treated
     as missing.
     """
     with self.cached_session():
         logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                        [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
         labels = constant_op.constant([1, 0, 2, 1])
         weights = constant_op.constant([[1.0], [0.0], [2.0], [4.0]])
         loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
         # Per-example losses are 3.0, 0.3, 0.0 and 0.8. Only the 3 examples
         # with non-zero weight count toward the denominator, so the overall
         # loss is 1/3 * (1.0*3.0 + 0.0*0.3 + 2.0*0.0 + 4.0*0.8) = 6.2/3.
         self.assertAlmostEqual(loss.eval(), 2.06666, 3)
 def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
     """Weighted loss is correct when weights is a rank-1 tensor.

     Labels are given as a (3, 1) column while weights have shape (3,).
     """
     with self.cached_session():
         logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                        [0.2, -1.8, 4.0]])
         labels = constant_op.constant([1, 0, 2], shape=(3, 1))
         weights = constant_op.constant([1.0, 0.5, 2.0], shape=(3, ))
         loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
         # Per-example losses are 3.0, 0.3 and 0.0, so the overall loss is
         # 1/3 * (1.0*3.0 + 0.5*0.3 + 2.0*0.0) = 1.05.
         self.assertAlmostEqual(loss.eval(), 1.05, 3)
 def testIncorrectPredictionsZeroWeights(self):
     """Loss is 0 when every weight is zero, even if predictions are wrong."""
     with self.cached_session():
         scores = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                        [0.2, -1.8, 4.0]])
         column_labels = constant_op.constant([1, 0, 2], shape=(3, 1))
         zero_weights = constant_op.constant([0.0, 0.0, 0.0], shape=(3, 1))
         weighted = losses.sparse_multiclass_hinge_loss(column_labels, scores,
                                                        zero_weights)
         # Every example is weighted out, so nothing contributes to the loss.
         self.assertAlmostEqual(weighted.eval(), 0.0, 3)
 def testIncorrectPredictionsColumnLabels(self):
     """Loss is > 0 for wrong predictions when labels is a rank-2 column."""
     with self.cached_session():
         logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                        [0.2, -1.8, 4.0]])
         labels = constant_op.constant([1, 0, 2], shape=(3, 1))
         loss = losses.sparse_multiclass_hinge_loss(labels, logits)
         # The first example incurs a high loss (3.0) since the logits of an
         # incorrect class (0) are higher than the logits of the ground truth.
         # The second example also incurs a (smaller) loss (0.3) and the third
         # none, giving a mean of (3.0 + 0.3 + 0.0) / 3 = 1.1.
         self.assertAlmostEqual(loss.eval(), 1.1, 3)
 def testIncorrectPredictions(self):
     """Loss is >0 when an incorrect class has higher logits than true class."""
     with self.cached_session():
         logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
                                        [0.5, -1.8, 2.0]])
         labels = constant_op.constant([1, 0, 2])
         loss = losses.sparse_multiclass_hinge_loss(labels, logits)
         # The first example incurs a high loss (3.2) since the logits of an
         # incorrect class (0) are higher than the logits of the ground truth.
         # The second example also incurs a (smaller) loss (0.4) and the third
         # none, giving a mean of (3.2 + 0.4 + 0.0) / 3 = 1.2.
         self.assertAlmostEqual(loss.eval(), 1.2, 3)