def test_mcm_3_classes(dtype):
    """Multi-label inputs: labels 0 and 1 are predicted perfectly, label 2 never.

    Each per-label expectation is a 2x2 matrix laid out as
    ``[[TN, FP], [FN, TP]]`` (the layout the expected values below encode).
    """
    # Rows are samples, columns are the three labels; a row may carry
    # several active labels at once (e.g. the first row has labels 0 and 2).
    y_true = tf.constant(
        [
            [1, 0, 1],
            [0, 1, 0],
            [1, 0, 1],
            [0, 1, 0],
        ],
        dtype=dtype,
    )
    # Columns 0 and 1 match y_true exactly; column 2 is its complement.
    y_pred = tf.constant(
        [
            [1, 0, 0],
            [0, 1, 1],
            [1, 0, 0],
            [0, 1, 1],
        ],
        dtype=dtype,
    )
    expected = [
        [[2, 0], [0, 2]],  # label 0: 2 TN, 2 TP
        [[2, 0], [0, 2]],  # label 1: 2 TN, 2 TP
        [[0, 2], [2, 0]],  # label 2: 2 FP, 2 FN
    ]

    metric = MultiLabelConfusionMatrix(num_classes=3, dtype=dtype)
    metric.update_state(y_true, y_pred)
    check_results(metric, expected)


# Example #2
def test_multiclass(dtype):
    """Single-label (one-hot) inputs over four classes and ten samples.

    Each per-class expectation is a 2x2 matrix laid out as
    ``[[TN, FP], [FN, TP]]`` (the layout the expected values below encode).
    """
    num_classes = 4
    # Class index of each sample; expanded to one-hot rows below.
    true_labels = [0, 2, 3, 1, 1, 0, 2, 0, 2, 3]
    # Samples 3 and 4 (true class 1) are predicted as class 0, and
    # sample 8 (true class 2) is predicted as class 1; all others match.
    pred_labels = [0, 2, 3, 0, 0, 0, 2, 0, 1, 3]

    y_true = tf.one_hot(true_labels, depth=num_classes, dtype=dtype)
    y_pred = tf.one_hot(pred_labels, depth=num_classes, dtype=dtype)

    expected = [
        [[5, 2], [0, 3]],  # class 0: 3 TP, 2 FP (samples 3, 4)
        [[7, 1], [2, 0]],  # class 1: 1 FP (sample 8), 2 FN (samples 3, 4)
        [[7, 0], [1, 2]],  # class 2: 2 TP, 1 FN (sample 8)
        [[8, 0], [0, 2]],  # class 3: 2 TP
    ]

    metric = MultiLabelConfusionMatrix(num_classes=num_classes, dtype=dtype)
    metric.update_state(y_true, y_pred)
    check_results(metric, expected)