def test_threshold_predictions_multiclass():
    """Test thresholding of multiclass predictions.

    With no threshold argument, each row of a row-normalized score matrix
    is expected to map to the index of its largest entry (``argmax``), and
    the result is expected to be a 1-D array with one label per row.
    """
    # Local seeded generator: inputs are identical on every run, so a
    # failure can be reproduced, and global RNG state is left untouched.
    rng = np.random.RandomState(0)
    y = rng.rand(10, 5)
    # Normalize each row to sum to 1 so rows look like class probabilities.
    y = y / y.sum(axis=1, keepdims=True)
    y_thresh = threshold_predictions(y)
    assert y_thresh.shape == (10,)
    # assert_array_equal reports the mismatching positions on failure,
    # unlike a bare `(a == b).all()`.
    np.testing.assert_array_equal(y_thresh, np.argmax(y, axis=1))
def test_threshold_predictions_binary():
    """Test thresholding of binary predictions.

    For a two-column, row-normalized score matrix, thresholding at 0.5 is
    expected to agree with ``argmax`` over the two columns and to yield a
    1-D array with one label per row.
    """
    # Local seeded generator: inputs are identical on every run, so a
    # failure can be reproduced, and global RNG state is left untouched.
    rng = np.random.RandomState(1)
    # Get a random prediction matrix whose rows sum to 1.
    y = rng.rand(10, 2)
    y = y / y.sum(axis=1, keepdims=True)
    # NOTE(review): a row of exactly [0.5, 0.5] would make agreement with
    # argmax depend on the threshold's tie-breaking rule; continuous random
    # draws make that case practically impossible here.
    y_thresh = threshold_predictions(y, 0.5)
    assert y_thresh.shape == (10,)
    # assert_array_equal reports the mismatching positions on failure,
    # unlike a bare `(a == b).all()`.
    np.testing.assert_array_equal(y_thresh, np.argmax(y, axis=1))