Ejemplo n.º 1
0
def test_import_module(reload, import_module, install, download_and_extract):
    """Importing from an S3 URI downloads, installs, then imports and reloads.

    The four parameters are injected collaborators (presumably ``mock.patch``
    decorators defined outside this view -- confirm against the decorators).
    """
    modules.import_module("s3://bucket/my-module")

    download_and_extract.assert_called_with("s3://bucket/my-module",
                                            environment.code_dir)
    install.assert_called_with(environment.code_dir)
    # Inspect the mock instead of calling it inside the assertion: calling it
    # would record an extra call, and since a mock hands back the same
    # ``return_value`` for any argument, the imported name was never checked.
    import_module.assert_called_with(modules.DEFAULT_MODULE_NAME)
    reload.assert_called_with(import_module.return_value)
Ejemplo n.º 2
0
def test_import_module_local_directory(reload, import_module, install, prepare,
                                       tarfile, s3_download):
    """A local tarball path is opened directly, with no S3 download."""
    local_tar = "/opt/ml/input/data/code/sourcedir.tar.gz"

    modules.import_module(local_tar)

    # A local path must not trigger a remote download.
    s3_download.assert_not_called()
    # The archive is opened in place as a gzip-compressed tar.
    tarfile.assert_called_with(name=local_tar, mode="r:gz")
    prepare.assert_called_once()
    install.assert_called_once()
def test_import_module_with_requirements(user_module, user_module_name,
                                         requirements_file):
    """A module uploaded together with a requirements file imports and runs.

    The imported module's ``say()`` must return REQUIREMENTS_TXT_ASSERT_STR.
    """
    uploaded = user_module.add_file(requirements_file).upload()

    imported = modules.import_module(uri=uploaded.url, name=user_module_name)

    assert imported.say() == REQUIREMENTS_TXT_ASSERT_STR
def test_import_module_with_local_script(user_module, user_module_name,
                                         tmpdir):
    """Importing from a local directory of script files yields a valid module."""
    code_dir = str(tmpdir)
    user_module.create_tmp_dir_with_files(code_dir)

    imported = modules.import_module(code_dir, user_module_name)
    assert imported.validate()
def test_import_module_with_local_tar(user_module, user_module_name,
                                      requirements_file):
    """Importing from a local tar archive works, and the archive is removed.

    The user module plus its requirements file is packed into a tarball via
    ``create_tar()``. The tarball is imported by path, and its ``say()`` must
    return REQUIREMENTS_TXT_ASSERT_STR.
    """
    user_module = user_module.add_file(requirements_file)
    tar_name = user_module.create_tar()

    # Remove the archive even when the import or the assertion fails, so a
    # failing run does not leave the tarball behind on disk.
    try:
        module = modules.import_module(tar_name, name=user_module_name)

        assert module.say() == REQUIREMENTS_TXT_ASSERT_STR
    finally:
        os.remove(tar_name)
Ejemplo n.º 6
0
def train():
    """Import the user's training script, run its ``train`` and persist the result.

    The script is imported from ``training_env.module_dir`` under the name
    ``training_env.module_name``. Its ``train`` is called with the keyword
    arguments that ``functions.matching_args`` builds from ``train`` and the
    environment. If a model is returned, it is saved in one of two ways:
    through the script's own ``save(model, model_dir)`` hook when defined,
    otherwise through ``model.save("<model_dir>/saved_model")``.
    """
    training_env = environment.Environment()

    script = modules.import_module(training_env.module_dir, training_env.module_name)

    model = script.train(**functions.matching_args(script.train, training_env))

    # Test against None instead of truthiness: a real model may be falsy
    # (e.g. defines __len__ returning 0), and an array-like result handed to
    # a custom ``save`` hook may refuse bool() outright.
    if model is not None:
        if hasattr(script, "save"):
            script.save(model, training_env.model_dir)
        else:
            model_file = os.path.join(training_env.model_dir, "saved_model")
            model.save(model_file)
def framework_training_fn():
    """Import the user module, run its ``train`` and persist the returned model.

    The module is imported from ``training_env.module_dir`` under the name
    ``training_env.module_name``. Its ``train`` is called with the keyword
    arguments that ``functions.matching_args`` builds from ``train`` and the
    environment. If a model is returned, it is saved in one of two ways:
    through the module's ``save(model, model_dir)`` hook when defined,
    otherwise through ``model.save("<model_dir>/saved_model")``.
    """
    training_env = environment.Environment()

    mod = modules.import_module(training_env.module_dir, training_env.module_name)

    model = mod.train(**functions.matching_args(mod.train, training_env))

    # Test against None instead of truthiness: a real model may be falsy
    # (e.g. defines __len__ returning 0), and an array-like result handed to
    # a custom ``save`` hook may refuse bool() outright.
    if model is not None:
        if hasattr(mod, "save"):
            mod.save(model, training_env.model_dir)
        else:
            model_file = os.path.join(training_env.model_dir, "saved_model")
            model.save(model_file)
def test_import_module_with_s3_script(user_module, user_module_name):
    """A user module uploaded to S3 can be imported by URL and validates."""
    user_module.upload()

    assert modules.import_module(user_module.url, user_module_name).validate()
def test_import_module_with_s3_script_with_error(user_module_name):
    """Importing the uploaded USER_SCRIPT_WITH_ERROR raises ImportModuleError."""
    broken = test.UserModule(USER_SCRIPT_WITH_ERROR)
    broken = broken.add_file(SETUP_FILE)
    broken = broken.upload()

    with pytest.raises(errors.ImportModuleError):
        modules.import_module(broken.url, user_module_name)