예제 #1
0
def test_import_module_from_s3_script_with_error(user_module_name):
    """Importing an uploaded script that fails on import raises ImportModuleError.

    The script ``USER_SCRIPT_WITH_ERROR`` is uploaded first (outside the
    ``pytest.raises`` context), and only the import call is expected to raise
    ``errors.ImportModuleError``. Caching is disabled for the import.
    """
    script = test.UserModule(USER_SCRIPT_WITH_ERROR)
    uploaded = script.upload()

    with pytest.raises(errors.ImportModuleError):
        modules.import_module_from_s3(uploaded.url, user_module_name, cache=False)
예제 #2
0
def test_import_module_from_s3_script(user_module_name):
    """Upload ``USER_SCRIPT``, import it without caching, and check ``validate()``.

    The imported module must expose a ``validate`` callable whose result is
    truthy.
    """
    uploaded = test.UserModule(USER_SCRIPT).upload()

    imported = modules.import_module_from_s3(
        uploaded.url, user_module_name, cache=False)

    assert imported.validate()
예제 #3
0
def test_import_module_from_s3_script_with_additional_files(user_module_name):
    """Upload a script plus ``ADDITIONAL_FILE`` and check the import validates.

    ``add_file`` is applied before ``upload``. The module is imported with
    caching disabled, and its ``validate()`` result must be truthy.
    """
    bundle = test.UserModule(USER_SCRIPT_WITH_ADDITIONAL_FILE)
    bundle = bundle.add_file(ADDITIONAL_FILE)
    uploaded = bundle.upload()

    imported = modules.import_module_from_s3(
        uploaded.url, user_module_name, cache=False)

    assert imported.validate()
예제 #4
0
def main(environ, start_response):
    """WSGI-style entry point that serves a request with the user's module.

    Builds a ``ServingEnv``, imports the configured user module with
    ``modules.import_module_from_s3``, wraps it in a transformer, initializes
    that transformer, and then delegates the request to a ``worker.Worker``
    whose ``transform_fn`` is the transformer's ``transform`` method.

    Returns:
        Whatever the ``worker.Worker`` instance returns when called with
        ``(environ, start_response)``.
    """
    serving_env = env.ServingEnv()

    # NOTE(review): the import, transformer setup and worker construction all
    # happen on every call; confirm this per-request setup is intended.
    transformer = _user_module_transformer(
        modules.import_module_from_s3(serving_env.module_dir,
                                      serving_env.module_name))
    transformer.initialize()

    wsgi_app = worker.Worker(transform_fn=transformer.transform,
                             module_name=serving_env.module_name)

    return wsgi_app(environ, start_response)
예제 #5
0
def framework_training_fn():
    """Run the user script's ``train`` function and persist the returned model.

    The user script is imported from ``training_env.module_dir``.
    ``train`` is called with the keyword arguments that
    ``functions.matching_args`` selects from the training environment.
    A truthy result is saved as follows:
      * if the script defines ``save``, via ``save(model, model_dir)``;
      * otherwise via ``model.save(<model_dir>/saved_model)``.
    Falsy results are not saved.
    """
    training_env = sagemaker_containers.training_env()

    # NOTE(review): the third positional argument is presumably ``cache``
    # (the tests call this function with ``cache=False``) — confirm.
    user_script = modules.import_module_from_s3(training_env.module_dir,
                                                training_env.module_name, False)

    train_kwargs = functions.matching_args(user_script.train, training_env)
    model = user_script.train(**train_kwargs)

    # Guard clauses: nothing to persist unless train() produced a truthy value.
    if not model:
        return

    if hasattr(user_script, 'save'):
        user_script.save(model, training_env.model_dir)
        return

    model.save(os.path.join(training_env.model_dir, 'saved_model'))
예제 #6
0
def train():
    """Load the user training script, call its ``train``, and save the result.

    ``train`` receives only the keyword arguments that
    ``functions.matching_args`` picks from the training environment.
    When the returned model is truthy, it is persisted through one of two paths:
      * the script's own ``save(model, model_dir)`` hook, if the script has one;
      * otherwise ``model.save`` on ``<model_dir>/saved_model``.
    """
    training_env = sagemaker_containers.training_env()

    # NOTE(review): the positional ``False`` is presumably the ``cache`` flag
    # (the tests pass ``cache=False`` by keyword) — confirm.
    script = modules.import_module_from_s3(
        training_env.module_dir, training_env.module_name, False)

    train_fn = script.train
    model = train_fn(**functions.matching_args(train_fn, training_env))

    if model:
        model_dir = training_env.model_dir
        if hasattr(script, 'save'):
            script.save(model, model_dir)
        else:
            model.save(os.path.join(model_dir, 'saved_model'))
예제 #7
0
def test_import_module_from_s3_script_with_requirements(user_module_name):
    """Upload a script with ``REQUIREMENTS_FILE`` and check its ``say()`` banner.

    The script and requirements file are uploaded together, and the module is
    imported with caching disabled. ``say()`` must return the ASCII-art banner
    below. The banner presumably comes from a package listed in the
    requirements file — confirm against ``REQUIREMENTS_FILE``.
    """
    user_module = test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(
        REQUIREMENTS_FILE).upload()

    module = modules.import_module_from_s3(user_module.url,
                                           user_module_name,
                                           cache=False)

    # Raw string: the banner contains backslash sequences such as '\/', '\_'
    # and '\ ' that are invalid escapes in a normal literal. Those raise
    # SyntaxWarning on Python 3.12+ and are slated to become errors.
    # The raw literal produces exactly the same characters.
    # Each '.' is replaced by a space (presumably so editors don't strip the
    # trailing padding). strip() then drops the surrounding newlines and the
    # first row's leading space, so the expected text begins with '____'.
    expected_banner = r"""
 ____                   __  __       _.............
/ ___|  __ _  __ _  ___|  \/  | __ _| | _____ _ __.
\___ \ / _` |/ _` |/ _ \ |\/| |/ _` | |/ / _ \ '__|
 ___) | (_| | (_| |  __/ |  | | (_| |   <  __/ |...
|____/ \__,_|\__, |\___|_|  |_|\__,_|_|\_\___|_|...
             |___/.................................
""".replace('.', ' ').strip()

    assert module.say() == expected_banner