def test_handle_classification_mode_threshold():
    """Check "threshold" mode at 0.5 against the per-row argmax.

    Ten random two-column rows are normalised to sum to one, given a
    singleton middle axis, and passed to ``handle_classification_mode``
    with ``threshold_value=0.5``. The result must have shape ``(10, 1)``
    and equal the column index of the larger score in each row.
    """
    raw_scores = np.random.rand(10, 2)
    # Row-normalise so the two columns of every sample sum to one.
    probs = raw_scores / raw_scores.sum(axis=1, keepdims=True)

    # Reference labels: index of the larger column, kept as a column vector.
    expected_labels = probs.argmax(axis=1).reshape(-1, 1)

    # Insert a length-one axis in the middle: (10, 2) -> (10, 1, 2).
    batch = probs[:, np.newaxis, :]
    result = handle_classification_mode(batch, "threshold", threshold_value=0.5)

    assert result.shape == (10, 1)
    assert np.array_equal(result, expected_labels)
def test_handle_classification_mode_direct():
    """Check that "direct" mode hands the input back unchanged.

    Ten random two-column rows are normalised to sum to one, given a
    singleton middle axis, and passed to ``handle_classification_mode``
    in "direct" mode. The result must keep shape ``(10, 1, 2)`` and
    compare equal to the array that was passed in.
    """
    raw_scores = np.random.rand(10, 2)
    # Row-normalise so the two columns of every sample sum to one.
    probs = raw_scores / raw_scores.sum(axis=1, keepdims=True)

    # Insert a length-one axis in the middle: (10, 2) -> (10, 1, 2).
    batch = probs[:, np.newaxis, :]
    result = handle_classification_mode(batch, "direct")

    assert result.shape == (10, 1, 2)
    # The reference is the very array that was passed in.
    assert np.array_equal(result, batch)
def test_handle_classification_mode_threshold_nonstandard():
    """Check "threshold" mode with a non-default cut-off of 0.3.

    Ten random two-column rows are normalised to sum to one. The expected
    label is 1 wherever the second column is at least 0.3 and 0 elsewhere.
    ``handle_classification_mode`` is then called on the ``(10, 1, 2)``
    version of the data with ``threshold_value=0.3``; the result must have
    shape ``(10, 1)`` and equal the expected labels.
    """
    raw_scores = np.random.rand(10, 2)
    # Row-normalise so the two columns of every sample sum to one.
    probs = raw_scores / raw_scores.sum(axis=1, keepdims=True)

    # Reference labels from the second column, as a float column vector.
    above_cutoff = probs[:, 1] >= 0.3
    expected_labels = above_cutoff.astype(float).reshape(-1, 1)

    # Insert a length-one axis in the middle: (10, 2) -> (10, 1, 2).
    batch = probs[:, np.newaxis, :]
    result = handle_classification_mode(batch, "threshold", threshold_value=0.3)

    assert result.shape == (10, 1)
    assert np.array_equal(result, expected_labels)