def test_transformer_implementation():
    """Exercise a Transformer served through a Worker test client.

    A fake model is saved under the model directory, a Transformer is built
    from ``model_fn`` / ``predict_fn`` and initialized, and the same payload
    is then posted three times with different content-type pairs. Each
    response must be HTTP OK and carry the expected values.
    """
    test.create_resource_config()
    test.create_input_data_config()
    test.create_hyperparameters_config({"sagemaker_program": "user_script.py"})

    # Persist the fake model before the transformer is initialized.
    saved_model_path = os.path.join(env.model_dir, "fake_ml_model")
    fake_model = fake_ml_framework.Model(weights=[6, 9, 42])
    fake_model.save(saved_model_path)

    transform = transformer.Transformer(model_fn=model_fn, predict_fn=predict_fn)
    transform.initialize()

    serving_app = worker.Worker(transform_fn=transform.transform,
                                module_name="fake_ml_model")

    with serving_app.test_client() as client:
        payload = [6, 9, 42.0]

        # NOTE(review): the last two arguments of post() are presumably the
        # request content type followed by the accepted response type.

        # NPY in, JSON out.
        json_response = post(client, payload, content_types.NPY, content_types.JSON)
        assert json_response.status_code == http_client.OK
        assert json_response.get_data(as_text=True) == "[36.0, 81.0, 1764.0]"

        # JSON in, CSV out.
        csv_response = post(client, payload, content_types.JSON, content_types.CSV)
        assert csv_response.status_code == http_client.OK
        assert csv_response.get_data(as_text=True) == "36.0\n81.0\n1764.0\n"

        # CSV in, NPY out; the raw body is decoded back into an array first.
        npy_response = post(client, payload, content_types.CSV, content_types.NPY)
        assert npy_response.status_code == http_client.OK
        decoded = encoders.npy_to_numpy(npy_response.get_data())

        expected = np.asarray([36.0, 81.0, 1764.0])
        np.testing.assert_array_almost_equal(decoded, expected)
Exemple #2
0
def main(environ, start_response):
    """Delegate a request to the module-level ``app``, building it on first use.

    When the global ``app`` is still ``None``, the user module named by the
    serving environment is imported, wrapped in a transformer, initialized,
    and exposed through a ``worker.Worker`` that is stored in ``app``.

    Args:
        environ: passed through unchanged to ``app``.
        start_response: passed through unchanged to ``app``.

    Returns:
        Whatever ``app(environ, start_response)`` returns.
    """
    global app

    # Already built: go straight to the cached application.
    if app is not None:
        return app(environ, start_response)

    serving_env = env.ServingEnv()
    module_name = serving_env.module_name

    user_module = modules.import_module(serving_env.module_dir, module_name)

    user_transformer = _user_module_transformer(user_module)
    user_transformer.initialize()

    app = worker.Worker(transform_fn=user_transformer.transform,
                        module_name=module_name)
    return app(environ, start_response)
Exemple #3
0
def main(environ, start_response):
    """Delegate a request to the module-level ``app``, creating it lazily.

    On the first call (global ``app`` is ``None``) the serving environment is
    read, ``_update_mxnet_env_vars()`` is invoked, the user module is imported,
    and a transformer built from it and the model directory is wrapped in a
    ``worker.Worker`` stored in ``app``.

    Args:
        environ: passed through unchanged to ``app``.
        start_response: passed through unchanged to ``app``.

    Returns:
        Whatever ``app(environ, start_response)`` returns.
    """
    global app

    if app is None:
        serving_env = env.ServingEnv()
        # Runs after the environment is read and before the user module import.
        _update_mxnet_env_vars()

        module_name = serving_env.module_name
        loaded_module = modules.import_module(serving_env.module_dir, module_name)
        transformer_for_module = _user_module_transformer(loaded_module,
                                                          serving_env.model_dir)

        app = worker.Worker(transform_fn=transformer_for_module.transform,
                            module_name=module_name)

    return app(environ, start_response)
Exemple #4
0
def main(environ, start_response):
    """Delegate a request to the module-level ``app``, building it on first use.

    First-call setup: set the logger level from the serving environment,
    import the user module, build and initialize its transformer, and wrap
    it in a ``worker.Worker`` configured with ``default_healthcheck_fn``.

    Args:
        environ: passed through unchanged to ``app``.
        start_response: passed through unchanged to ``app``.

    Returns:
        Whatever ``app(environ, start_response)`` returns.
    """
    global app

    # Fast path once the application has been constructed.
    if app is not None:
        return app(environ, start_response)

    serving_env = env.ServingEnv()
    logger.setLevel(serving_env.log_level)

    module_name = serving_env.module_name
    loaded_module = modules.import_module(serving_env.module_dir, module_name)

    module_transformer = _user_module_transformer(loaded_module)
    module_transformer.initialize()

    app = worker.Worker(
        transform_fn=module_transformer.transform,
        module_name=module_name,
        healthcheck_fn=default_healthcheck_fn,
    )
    return app(environ, start_response)
Exemple #5
0
def main(environ, start_response):
    """Delegate a request to the module-level ``app``, creating it lazily.

    On the first call, ``import_module`` is invoked with the module name and
    module directory from the serving environment. It yields a transformer
    and an execution-parameters function, both of which are handed to the
    ``worker.Worker`` stored in ``app``.

    Args:
        environ: passed through unchanged to ``app``.
        start_response: passed through unchanged to ``app``.

    Returns:
        Whatever ``app(environ, start_response)`` returns.
    """
    global app

    if app is None:
        serving_env = env.ServingEnv()
        module_name = serving_env.module_name

        # Note the argument order here: module name first, then directory.
        imported = import_module(module_name, serving_env.module_dir)
        module_transformer, execution_parameters_fn = imported

        app = worker.Worker(
            transform_fn=module_transformer.transform,
            module_name=module_name,
            execution_parameters_fn=execution_parameters_fn,
        )

    return app(environ, start_response)
Exemple #6
0
def main(environ, start_response):
    """Build a worker for the user module and delegate the request to it.

    Nothing is cached here: every call reads the serving environment, sets
    the logger level, imports the user module via
    ``modules.import_module_from_s3``, initializes a transformer for it, and
    constructs a new ``worker.Worker``.

    Args:
        environ: passed through unchanged to the worker.
        start_response: passed through unchanged to the worker.

    Returns:
        Whatever the worker returns for ``(environ, start_response)``.
    """
    serving_env = env.ServingEnv()
    logger.setLevel(serving_env.log_level)

    module_name = serving_env.module_name
    loaded_module = modules.import_module_from_s3(serving_env.module_dir, module_name)

    module_transformer = _user_module_transformer(loaded_module)
    module_transformer.initialize()

    # ``app`` is a plain local in this function; it is rebuilt on each call.
    app = worker.Worker(
        transform_fn=module_transformer.transform,
        module_name=module_name,
    )
    return app(environ, start_response)