예제 #1
0
def test_threshold_predictions_multiclass():
    """Test thresholding of multiclass predictions.

    Builds a random 10 x 5 matrix whose rows are normalized to sum to 1,
    calls ``threshold_predictions`` without an explicit threshold, and
    checks that it returns one label per row equal to the row-wise argmax.
    """
    # Local, fixed-seed generator: the inputs are identical on every run,
    # so a failure is reproducible instead of flaky, and the global NumPy
    # RNG state is left untouched for other tests.
    rng = np.random.RandomState(0)
    y = rng.rand(10, 5)
    # Normalize each row to sum to 1 so rows resemble class probabilities.
    y = y / np.sum(y, axis=1, keepdims=True)
    y_thresh = threshold_predictions(y)
    assert y_thresh.shape == (10,)
    # assert_array_equal reports the mismatching elements on failure,
    # unlike a bare `(a == b).all()`.
    np.testing.assert_array_equal(y_thresh, np.argmax(y, axis=1))
예제 #2
0
def test_threshold_predictions_binary():
    """Test thresholding of binary predictions.

    Builds a random 10 x 2 matrix whose rows are normalized to sum to 1,
    thresholds it at 0.5 with ``threshold_predictions``, and checks that
    the result is one label per row equal to the row-wise argmax.
    """
    # Local, fixed-seed generator: the inputs are identical on every run,
    # so a failure is reproducible instead of flaky, and the global NumPy
    # RNG state is left untouched for other tests.
    rng = np.random.RandomState(1)
    # Get a random prediction matrix and normalize each row to sum to 1.
    y = rng.rand(10, 2)
    y = y / np.sum(y, axis=1, keepdims=True)
    y_thresh = threshold_predictions(y, 0.5)
    assert y_thresh.shape == (10,)
    # assert_array_equal reports the mismatching elements on failure,
    # unlike a bare `(a == b).all()`.
    np.testing.assert_array_equal(y_thresh, np.argmax(y, axis=1))