def test_preprocessor_error_message():
    """Check that a PreprocessorError is raised when preprocessing fails.

    The indexer wraps an array holding only 2 samples, while the inputs
    below reference larger indices, so preprocessing them must fail for
    both tuples and points.
    """
    indexer = ArrayIndexer(np.array([[1.2, 3.3], [3.1, 3.2]]))

    # Tuples: some indices exceed the number of samples the indexer holds.
    bad_tuples = np.array([[[2, 3], [3, 3]], [[2, 3], [3, 2]]])
    with pytest.raises(PreprocessorError):
        preprocess_tuples(bad_tuples, indexer)

    # Points: some indices exceed the number of samples the indexer holds.
    bad_points = np.array([[1], [2], [3], [3]])
    with pytest.raises(PreprocessorError):
        preprocess_points(bad_points, indexer)
def test_preprocess_tuples_simple_example():
    """Test preprocess_tuples on a tiny example with a hand-written result.

    ``fun`` ignores its argument and always returns the same 3x2 array, so
    the expected output is known in advance: element ``j`` of tuple ``i``
    equals row ``i`` of ``fun``'s fixed output.
    """
    array = np.array([[1, 2], [2, 3], [4, 5]])

    def fun(row):
        # Deliberately ignores ``row`` so the expected result can be
        # written down by hand.
        return np.array([[1, 1], [3, 3], [4, 4]])

    expected_result = np.array([[[1, 1], [1, 1]],
                                [[3, 3], [3, 3]],
                                [[4, 4], [4, 4]]])

    # assert_array_equal also compares shapes; the previous
    # ``(result == expected).all()`` broadcasts and would accept a
    # wrongly-shaped result such as one of shape (3, 1, 2).
    np.testing.assert_array_equal(preprocess_tuples(array, fun),
                                  expected_result)
def test_check_tuples_invalid_n_samples(estimator, context, load_tuples,
                                        preprocessor):
    """Check that check_input raises the expected ValueError when there are
    fewer tuples than ``ensure_min_samples`` requires."""
    tuples = load_tuples()
    # The shape quoted in the message is that of the preprocessed tuples
    # when a preprocessor is given and the tuples are 2-D, and that of the
    # raw tuples otherwise.
    use_preprocessed = preprocessor is not None and tuples.ndim == 2
    if use_preprocessed:
        reported_shape = preprocess_tuples(tuples, preprocessor).shape
    else:
        reported_shape = tuples.shape
    msg = ("Found array with 2 sample(s) (shape={}) while a minimum of 3 "
           "is required{}.".format(reported_shape, context))
    with pytest.raises(ValueError) as raised_error:
        check_input(tuples, type_of_inputs='tuples', preprocessor=preprocessor,
                    ensure_min_samples=3, estimator=estimator)
    assert str(raised_error.value) == msg
def test_check_tuples_invalid_tuple_size(estimator, context, load_tuples,
                                         preprocessor):
    """Check that check_input raises the expected ValueError when the tuples
    do not have the requested ``tuple_size``."""
    tuples = load_tuples()
    # The message reports the preprocessed tuples when a preprocessor is
    # given and the tuples are 2-D, and the raw tuples otherwise.
    if preprocessor is not None and tuples.ndim == 2:
        shown = preprocess_tuples(tuples, preprocessor)
    else:
        shown = tuples
    expected_msg = ("Tuples of 3 element(s) expected{}. Got tuples of 2 "
                    "element(s) instead (shape={}):\ninput={}.\n"
                    .format(context, shown.shape, shown))
    with pytest.raises(ValueError) as raised_error:
        check_input(tuples, type_of_inputs='tuples', tuple_size=3,
                    preprocessor=preprocessor, estimator=estimator)
    assert str(raised_error.value) == expected_msg