def test_transformer_implementation():
    """Serve a model_fn/predict_fn Transformer end to end through a test client.

    Writes the resource, input-data and hyperparameter configs and saves a
    fake model with weights [6, 9, 42]. It then posts the payload [6, 9, 42.0]
    three times with different request/response content types. Each time it
    checks that the returned values are [36.0, 81.0, 1764.0] in the expected
    encoding.
    """
    # Environment configuration consumed by the training/serving helpers.
    test.create_resource_config()
    test.create_input_data_config()
    hyperparameters = {"sagemaker_program": "user_script.py"}
    test.create_hyperparameters_config(hyperparameters)

    # Persist the fake model under env.model_dir.
    saved_model_location = os.path.join(env.model_dir, "fake_ml_model")
    fake_model = fake_ml_framework.Model(weights=[6, 9, 42])
    fake_model.save(saved_model_location)

    user_transformer = transformer.Transformer(model_fn=model_fn, predict_fn=predict_fn)
    user_transformer.initialize()

    app = worker.Worker(transform_fn=user_transformer.transform, module_name="fake_ml_model")
    with app.test_client() as test_client:
        data = [6, 9, 42.0]

        # NPY request -> JSON response.
        json_resp = post(test_client, data, content_types.NPY, content_types.JSON)
        assert json_resp.status_code == http_client.OK
        assert json_resp.get_data(as_text=True) == "[36.0, 81.0, 1764.0]"

        # JSON request -> CSV response.
        csv_resp = post(test_client, data, content_types.JSON, content_types.CSV)
        assert csv_resp.status_code == http_client.OK
        assert csv_resp.get_data(as_text=True) == "36.0\n81.0\n1764.0\n"

        # CSV request -> NPY response, decoded back to an array for comparison.
        npy_resp = post(test_client, data, content_types.CSV, content_types.NPY)
        assert npy_resp.status_code == http_client.OK
        decoded = encoders.npy_to_numpy(npy_resp.get_data())
        expected = np.asarray([36.0, 81.0, 1764.0])
        np.testing.assert_array_almost_equal(decoded, expected)
def _user_module_transformer(user_module):
    """Build a Transformer from the handlers ``user_module`` defines.

    Each of ``model_fn``, ``input_fn``, ``predict_fn`` and ``output_fn`` is
    taken from ``user_module`` when present. Otherwise the matching
    ``default_*`` function is used.

    Args:
        user_module: object whose handler attributes are looked up via getattr.

    Returns:
        transformer.Transformer: transformer wired with the resolved handlers.
    """
    # Handler name -> fallback, listed in the same lookup order as before.
    fallbacks = {
        "model_fn": default_model_fn,
        "input_fn": default_input_fn,
        "predict_fn": default_predict_fn,
        "output_fn": default_output_fn,
    }
    handlers = {name: getattr(user_module, name, fallback) for name, fallback in fallbacks.items()}
    return transformer.Transformer(**handlers)
# Example #3
# 0
def _user_module_transformer(user_module):
    """Build a Transformer from the handler functions defined in ``user_module``.

    ``model_fn`` falls back to ``default_model_fn`` when the module does not
    define it. The module may define ``transform_fn``, or any of
    ``input_fn`` / ``predict_fn`` / ``output_fn``, but not both kinds. When
    ``transform_fn`` is absent, each of the three step handlers the module
    does not define is replaced by its ``default_*`` counterpart.

    Args:
        user_module: object whose handler attributes are looked up via getattr.

    Returns:
        transformer.Transformer: transformer wired with the resolved handlers.

    Raises:
        exc.UserError: if ``transform_fn`` is defined together with
            ``input_fn``, ``predict_fn`` and/or ``output_fn``.
    """
    model_fn = getattr(user_module, "model_fn", default_model_fn)
    input_fn = getattr(user_module, "input_fn", None)
    predict_fn = getattr(user_module, "predict_fn", None)
    output_fn = getattr(user_module, "output_fn", None)
    transform_fn = getattr(user_module, "transform_fn", None)

    # transform_fn and the individual step handlers are mutually exclusive here.
    if transform_fn and (input_fn or predict_fn or output_fn):
        raise exc.UserError(
            "Cannot use transform_fn implementation with input_fn, predict_fn, and/or output_fn"
        )

    if transform_fn is not None:
        return transformer.Transformer(model_fn=model_fn, transform_fn=transform_fn)

    # Honour every user-supplied step handler (including predict_fn) and fall
    # back to the default implementation only for the ones that are missing.
    return transformer.Transformer(
        model_fn=model_fn,
        input_fn=input_fn or default_input_fn,
        predict_fn=predict_fn or default_predict_fn,
        output_fn=output_fn or default_output_fn,
    )
# Example #4
# 0
def _transformer_with_transform_fn(model_fn, transform_fn):
    """Create a Transformer from ``model_fn`` and ``transform_fn``, then initialize it.

    Returns:
        transformer.Transformer: the instance after ``initialize()`` has run.
    """
    instance = transformer.Transformer(transform_fn=transform_fn, model_fn=model_fn)
    instance.initialize()
    return instance