コード例 #1
0
def test_mcm_3_classes(dtype):
    """Per-class confusion matrices for a small 3-class multi-label batch.

    Prediction columns 0 and 1 match the label columns exactly, while
    column 2 is the element-wise complement of its labels, so class 2 only
    accumulates mismatches.
    """
    label_rows = [[1, 0, 1], [0, 1, 0], [1, 0, 1], [0, 1, 0]]
    pred_rows = [[1, 0, 0], [0, 1, 1], [1, 0, 0], [0, 1, 1]]
    actuals = tf.constant(label_rows, dtype=dtype)
    preds = tf.constant(pred_rows, dtype=dtype)

    # Build the metric and feed it one batch.
    metric = MultiLabelConfusionMatrix(num_classes=3, dtype=dtype)
    metric.update_state(actuals, preds)

    expected = [
        [[2, 0], [0, 2]],  # class 0: predicted exactly
        [[2, 0], [0, 2]],  # class 1: predicted exactly
        [[0, 2], [2, 0]],  # class 2: every prediction wrong
    ]
    check_results(metric, expected)
コード例 #2
0
def test_config():
    """Round-trip through get_config/from_config keeps num_classes and dtype.

    The metric is built without an explicit dtype, so this also asserts
    that the default dtype is tf.float32.
    """

    def _assert_expected(metric):
        assert metric.num_classes == 3
        assert metric.dtype == tf.float32

    original = MultiLabelConfusionMatrix(num_classes=3)
    _assert_expected(original)
    # Rebuild from the serialized config and verify nothing was lost.
    restored = MultiLabelConfusionMatrix.from_config(original.get_config())
    _assert_expected(restored)
コード例 #3
0
 def test_config(self):
     """Round-trip through get_config/from_config keeps num_classes and dtype.

     The metric is built without an explicit dtype, so this also asserts
     that the default dtype is tf.int32.
     """

     def _assert_expected(metric):
         self.assertEqual(metric.num_classes, 3)
         self.assertEqual(metric.dtype, tf.int32)

     original = MultiLabelConfusionMatrix(num_classes=3)
     _assert_expected(original)
     # Rebuild from the serialized config and verify nothing was lost.
     restored = MultiLabelConfusionMatrix.from_config(original.get_config())
     _assert_expected(restored)
コード例 #4
0
def test_multiclass(dtype):
    """Multi-label confusion matrices on one-hot (single-label) data.

    Each of the 10 samples has exactly one active label and one active
    prediction over 4 classes. The only misclassifications are samples 3
    and 4 (class 1 predicted as class 0) and sample 8 (class 2 predicted as
    class 1).
    """
    num_classes = 4
    # Class index per sample; expanded into one-hot rows below.
    true_classes = [0, 2, 3, 1, 1, 0, 2, 0, 2, 3]
    pred_classes = [0, 2, 3, 0, 0, 0, 2, 0, 1, 3]

    def one_hot_rows(indices):
        return [
            [1 if col == idx else 0 for col in range(num_classes)]
            for idx in indices
        ]

    actuals = tf.constant(one_hot_rows(true_classes), dtype=dtype)
    preds = tf.constant(one_hot_rows(pred_classes), dtype=dtype)

    metric = MultiLabelConfusionMatrix(num_classes=num_classes, dtype=dtype)
    metric.update_state(actuals, preds)

    # One 2x2 matrix per class, laid out as [[TN, FP], [FN, TP]].
    check_results(
        metric,
        [
            [[5, 2], [0, 3]],  # class 0: 2 false positives (samples 3, 4)
            [[7, 1], [2, 0]],  # class 1: 1 FP (sample 8), 2 FN (3, 4)
            [[7, 0], [1, 2]],  # class 2: 1 false negative (sample 8)
            [[8, 0], [0, 2]],  # class 3: predicted exactly
        ],
    )
コード例 #5
0
 def initialize_vars(self, n_classes, input_dtype):
     """Create a MultiLabelConfusionMatrix and run its variable initializer.

     Args:
       n_classes: forwarded to the metric as ``num_classes``.
       input_dtype: forwarded to the metric as ``dtype``.

     Returns:
       The metric object, after the initializer op for its variables has
       been run through ``self.evaluate``.
     """
     metric = MultiLabelConfusionMatrix(
         num_classes=n_classes,
         dtype=input_dtype,
     )
     # NOTE(review): explicit initialization presumably targets TF1-style
     # graph execution in this test harness; confirm against the test base.
     init_op = tf.compat.v1.variables_initializer(metric.variables)
     self.evaluate(init_op)
     return metric