예제 #1
0
 def test_weighted(self):
     """CategoricalHinge must average per-sample losses by sample weight.

     Using the categorical hinge max(0, neg - pos + 1), where pos is the sum
     of y_pred over the true classes and neg is the largest y_pred over the
     false classes, the rows below have per-sample losses [1, 0, 0, 1]
     (unweighted mean 0.5). With weights [1, 1.5, 2, 3.5] the weighted mean
     is (1*1 + 3.5*1) / (1 + 1.5 + 2 + 3.5) = 4.5 / 8 = 0.5625.
     """
     cat_hinge_obj = metrics.CategoricalHinge()
     # Initialize the metric's state variables before the first update.
     self.evaluate(variables.variables_initializer(cat_hinge_obj.variables))
     y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                    (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
     y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                    (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
     # Weights are chosen so the weighted mean (0.5625) differs from the
     # unweighted mean (0.5). With (1, 1.5, 2, 2.5) both came out as 0.5,
     # so a metric that ignored sample_weight would still have passed.
     sample_weight = constant_op.constant((1., 1.5, 2., 3.5))
     # Calling the metric updates its state and returns the running result.
     result = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
     self.assertAllClose(0.5625, self.evaluate(result), atol=1e-5)
예제 #2
0
    def test_config(self):
        """Name and dtype must survive a get_config/from_config round trip."""
        want_name, want_dtype = 'cat_hinge', dtypes.int32

        def check(metric):
            # Both the public name and the stored dtype must match.
            self.assertEqual(metric.name, want_name)
            self.assertEqual(metric._dtype, want_dtype)

        first = metrics.CategoricalHinge(name=want_name, dtype=want_dtype)
        check(first)

        # Serialize to a config and rebuild a fresh instance from it.
        second = metrics.CategoricalHinge.from_config(first.get_config())
        check(second)
예제 #3
0
    def test_unweighted(self):
        """Without sample weights, CategoricalHinge is the plain mean loss.

        Using the categorical hinge max(0, neg - pos + 1), where pos is the
        sum of y_pred over the true classes and neg is the largest y_pred
        over the false classes, the rows below have per-sample losses
        [1, 0, 0, 1], so the expected mean is 0.5.
        """
        cat_hinge_obj = metrics.CategoricalHinge()
        # Initialize the metric's state variables before the first update.
        self.evaluate(variables.variables_initializer(cat_hinge_obj.variables))
        y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                       (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
        y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                       (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))

        # Run the accumulation step first, then read the result separately.
        update_op = cat_hinge_obj.update_state(y_true, y_pred)
        self.evaluate(update_op)
        result = cat_hinge_obj.result()
        # Evaluate explicitly rather than relying on assertAllClose to
        # evaluate a tensor argument on its own.
        self.assertAllClose(0.5, self.evaluate(result), atol=1e-5)