def test_predict_should_predict_in_test_mode():
    """Check the output of ``Pipeline.predict`` on a train/test-gated pipeline.

    The pipeline holds a ``MultiplyByN(2)`` step wrapped in ``TestOnlyWrapper``
    followed by a ``MultiplyByN(4)`` step wrapped in ``TrainOnlyWrapper``.
    Predicting on ``[1, 1]`` is expected to yield ``[2, 2]``, i.e. the result
    of the x2 step alone.
    """
    # Arrange: both callback wrappers share the same pair of tapes.
    fit_tape = TapeCallbackFunction()
    transform_tape = TapeCallbackFunction()
    test_only_step = TestOnlyWrapper(
        CallbackWrapper(MultiplyByN(2), transform_tape, fit_tape)
    )
    train_only_step = TrainOnlyWrapper(
        CallbackWrapper(MultiplyByN(4), transform_tape, fit_tape)
    )
    pipeline = Pipeline([test_only_step, train_only_step])

    # Act
    predictions = pipeline.predict(np.array([1, 1]))

    # Assert: only the multiplication by 2 shows up in the result.
    expected = np.array([2, 2])
    assert np.array_equal(predictions, expected)
def test_predict_should_transform_with_initial_is_train_mode_after_predict():
    """Check that ``transform`` after ``predict`` still yields the x4 result.

    The pipeline holds a ``MultiplyByN(2)`` step wrapped in ``TestOnlyWrapper``
    followed by a ``MultiplyByN(4)`` step wrapped in ``TrainOnlyWrapper``.
    After one ``predict`` call (whose result is discarded), transforming
    ``[1, 1]`` is expected to yield ``[4, 4]``, i.e. the result of the x4 step
    alone.
    """
    # Arrange: both callback wrappers share the same pair of tapes.
    fit_tape = TapeCallbackFunction()
    transform_tape = TapeCallbackFunction()
    test_only_step = TestOnlyWrapper(
        CallbackWrapper(MultiplyByN(2), transform_tape, fit_tape)
    )
    train_only_step = TrainOnlyWrapper(
        CallbackWrapper(MultiplyByN(4), transform_tape, fit_tape)
    )
    pipeline = Pipeline([test_only_step, train_only_step])

    # Act: predict first, then transform with the same pipeline object.
    pipeline.predict(np.array([1, 1]))
    transformed = pipeline.transform(np.array([1, 1]))

    # Assert: only the multiplication by 4 shows up in the result.
    expected = np.array([4, 4])
    assert np.array_equal(transformed, expected)
def test_handle_predict_should_predict_in_test_mode():
    """Check the data inputs returned by ``Pipeline.handle_predict``.

    The pipeline holds a ``MultiplyByN(2)`` step wrapped in ``TestOnlyWrapper``
    followed by a ``MultiplyByN(4)`` step wrapped in ``TrainOnlyWrapper``.
    Calling ``handle_predict`` with a container whose data inputs are
    ``[1, 1]`` is expected to return a container whose ``data_inputs`` are
    ``[2, 2]``, i.e. the result of the x2 step alone.
    """
    # Arrange: both callback wrappers share the same pair of tapes.
    fit_tape = TapeCallbackFunction()
    transform_tape = TapeCallbackFunction()
    test_only_step = TestOnlyWrapper(
        CallbackWrapper(MultiplyByN(2), transform_tape, fit_tape)
    )
    train_only_step = TrainOnlyWrapper(
        CallbackWrapper(MultiplyByN(4), transform_tape, fit_tape)
    )
    pipeline = Pipeline([test_only_step, train_only_step])

    # Container is built before the context, as in the original call.
    input_container = DataContainer(
        data_inputs=np.array([1, 1]),
        expected_outputs=np.array([1, 1]),
    )
    execution_context = ExecutionContext()

    # Act
    result_container = pipeline.handle_predict(
        data_container=input_container,
        context=execution_context,
    )

    # Assert: only the multiplication by 2 shows up in the data inputs.
    expected = np.array([2, 2])
    assert np.array_equal(result_container.data_inputs, expected)
DATA_INPUTS = np.array(range(5)) EXPECTED_OUTPUTS = np.array(range(5, 10)) EXPECTED_PROCESSED_OUTPUTS = np.array([5.0, 6.0, 7.0, 8.0, 9.0]) tape_transform_preprocessing = TapeCallbackFunction() tape_fit_preprocessing = TapeCallbackFunction() tape_transform_postprocessing = TapeCallbackFunction() tape_fit_postprocessing = TapeCallbackFunction() tape_inverse_transform_preprocessing = TapeCallbackFunction() @pytest.mark.parametrize('test_case', [ NeuraxleTestCase( pipeline=Pipeline([ ReversiblePreprocessingWrapper( preprocessing_step=CallbackWrapper(MultiplyByN(2), tape_transform_preprocessing, tape_fit_postprocessing, tape_inverse_transform_preprocessing), postprocessing_step=CallbackWrapper(AddN(10), tape_transform_postprocessing, tape_fit_postprocessing) )] ), callbacks=[tape_transform_preprocessing, tape_fit_preprocessing, tape_transform_postprocessing, tape_fit_postprocessing, tape_inverse_transform_preprocessing], expected_callbacks_data=[ [DATA_INPUTS], [], [DATA_INPUTS * 2], [], [(DATA_INPUTS * 2) + 10] ], data_inputs=DATA_INPUTS, expected_processed_outputs=EXPECTED_PROCESSED_OUTPUTS, execution_mode=ExecutionMode.TRANSFORM ),