def test_weighted(self):
  """Checks the sample-weighted CategoricalHinge value on a 4-row batch.

  The expected weighted mean is 0.5. Assuming the usual categorical hinge
  definition max(0, max_neg - pos + 1), the per-row losses are (1, 0, 0, 1).
  The weights (1, 1.5, 2, 2.5) then give (1 * 1 + 2.5 * 1) / 7 = 0.5.
  """
  hinge_metric = metrics.CategoricalHinge()
  init_op = variables.variables_initializer(hinge_metric.variables)
  self.evaluate(init_op)

  # Raw batch data: one row per sample, five classes per row.
  labels = ((0, 1, 0, 1, 0),
            (0, 0, 1, 1, 1),
            (1, 1, 1, 1, 0),
            (0, 0, 0, 0, 1))
  predictions = ((0, 0, 1, 1, 0),
                 (1, 1, 1, 1, 1),
                 (0, 1, 0, 1, 0),
                 (1, 1, 1, 1, 1))
  weights = (1., 1.5, 2., 2.5)

  # Arguments are evaluated left to right, so the constants are created
  # in the order labels -> predictions -> weights.
  weighted_value = hinge_metric(
      constant_op.constant(labels),
      constant_op.constant(predictions),
      sample_weight=constant_op.constant(weights))
  self.assertAllClose(0.5, self.evaluate(weighted_value), atol=1e-5)
def test_config(self):
  """Checks that name and dtype are kept, including through get/from_config."""

  def _assert_name_and_dtype(metric_obj):
    # Inspects the private `_dtype` attribute directly, as the test always has.
    self.assertEqual(metric_obj.name, 'cat_hinge')
    self.assertEqual(metric_obj._dtype, dtypes.int32)

  original = metrics.CategoricalHinge(name='cat_hinge', dtype=dtypes.int32)
  _assert_name_and_dtype(original)

  # Serialize to a config and rebuild; the copy must carry the same settings.
  restored = metrics.CategoricalHinge.from_config(original.get_config())
  _assert_name_and_dtype(restored)
def test_unweighted(self):
  """Checks the unweighted CategoricalHinge value on a 4-row batch.

  The state is accumulated with `update_state` and read back with `result`.
  The expected value is 0.5. Assuming the usual categorical hinge definition,
  the per-row losses are (1, 0, 0, 1), whose plain mean is 0.5.
  """
  cat_hinge_obj = metrics.CategoricalHinge()
  self.evaluate(variables.variables_initializer(cat_hinge_obj.variables))
  y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
                                 (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
  y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
                                 (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))

  update_op = cat_hinge_obj.update_state(y_true, y_pred)
  self.evaluate(update_op)
  result = cat_hinge_obj.result()
  # Evaluate the result Tensor explicitly, as the weighted path does, so the
  # assertion compares a concrete value rather than relying on
  # assertAllClose to resolve a Tensor.
  self.assertAllClose(0.5, self.evaluate(result), atol=1e-5)