コード例 #1
0
def test_multiple_classes():
    """Compare the 3-class MCC metric against scikit-learn's implementation.

    Ground-truth and prediction rows are one-hot vectors built from integer
    class indices. They are fed to the metric as float32 tensors, and the
    metric's result is checked against the scikit-learn reference value.
    """
    true_classes = [0, 2, 0, 2, 1, 1, 0, 0, 2, 1]
    predicted_classes = [2, 0, 2, 2, 2, 2, 2, 0, 2, 2]

    # Each row of the 3x3 identity matrix is the one-hot encoding of its index.
    identity = np.eye(3)
    gt_label = identity[true_classes]
    preds = identity[predicted_classes]

    y_true = tf.constant(gt_label, dtype=tf.float32)
    y_pred = tf.constant(preds, dtype=tf.float32)

    metric = MatthewsCorrelationCoefficient(3)
    metric.update_state(y_true, y_pred)

    # Reference: the scikit-learn Matthews implementation (sklearn_matthew),
    # applied to the integer labels recovered from the one-hot rows.
    reference = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1))
    check_results(metric, reference)
コード例 #2
0
def test_binary_classes():
    """Check the single-output (binary) MCC on a hand-computed example.

    With labels [1, 1, 1, 0] and predictions [1, 0, 1, 1] the counts are
    TP=2, FN=1, FP=1, TN=0, so
    MCC = (2*0 - 1*1) / sqrt(3 * 3 * 1 * 1) = -1/3.
    """
    labels = [[1.0], [1.0], [1.0], [0.0]]
    predictions = [[1.0], [0.0], [1.0], [1.0]]
    y_true = tf.constant(labels, dtype=tf.float32)
    y_pred = tf.constant(predictions, dtype=tf.float32)

    metric = MatthewsCorrelationCoefficient(1)
    metric.update_state(y_true, y_pred)

    expected = [-0.33333334]  # -1/3 as rounded in float32
    check_results(metric, expected)
コード例 #3
0
def test_reset_states_graph():
    """After reset_states() runs inside a tf.function, the metric should read 0."""
    y_true = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32)
    y_pred = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32)

    metric = MatthewsCorrelationCoefficient(1)
    # Accumulate some non-trivial state first so the reset has something to clear.
    metric.update_state(y_true, y_pred)

    @tf.function
    def _reset_in_graph():
        # Wrapping in tf.function exercises the reset in graph mode, not eagerly.
        metric.reset_states()

    _reset_in_graph()

    check_results(metric, [0])
コード例 #4
0
def test_multiple_classes():
    """Check the 3-column MCC on a multi-label example.

    Rows may have several active columns (e.g. [1, 0, 1]). Each expected value
    equals the binary MCC of the matching column:
      column 0: labels [1,1,1,0] vs preds [1,0,1,1] -> -1/3
      column 1: labels [0,0,0,1] vs preds [0,0,0,1] -> 1
      column 2: labels [0,0,1,1] vs preds [0,0,1,0] -> 2/sqrt(12)
    """
    label_rows = [
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 1.0],
        [0.0, 1.0, 1.0],
    ]
    prediction_rows = [
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
    ]
    y_true = tf.constant(label_rows, dtype=tf.float32)
    y_pred = tf.constant(prediction_rows, dtype=tf.float32)

    metric = MatthewsCorrelationCoefficient(3)
    metric.update_state(y_true, y_pred)

    per_column_expected = [-0.33333334, 1.0, 0.57735026]
    check_results(metric, per_column_expected)