def test_import_module(reload, import_module, install, download_and_extract):
    """Importing from an S3 URI downloads, installs and reloads the module.

    The arguments are presumably mocks injected by patch decorators that
    are not visible here -- verify against the surrounding file.
    """
    s3_uri = "s3://bucket/my-module"
    modules.import_module(s3_uri)

    code_dir = environment.code_dir
    download_and_extract.assert_called_with(s3_uri, code_dir)
    install.assert_called_with(code_dir)

    # NOTE(review): calling a Mock returns its return_value regardless of the
    # arguments, so this is the object the patched import_module produced.
    # Calling it here also records one more call on the mock.
    expected_module = import_module(modules.DEFAULT_MODULE_NAME)
    reload.assert_called_with(expected_module)
def test_import_module_local_directory(reload, import_module, install, prepare, tarfile, s3_download):
    """A local tarball path is opened directly and never fetched from S3.

    The arguments are presumably mocks injected by patch decorators that
    are not visible here -- verify against the surrounding file.
    """
    local_archive = "/opt/ml/input/data/code/sourcedir.tar.gz"

    modules.import_module(local_archive)

    # A local path must not trigger any S3 download.
    s3_download.assert_not_called()
    tarfile.assert_called_with(name=local_archive, mode="r:gz")
    prepare.assert_called_once()
    install.assert_called_once()
def test_import_module_with_requirements(user_module, user_module_name, requirements_file):
    """A module uploaded together with a requirements file imports and runs."""
    uploaded_module = user_module.add_file(requirements_file).upload()

    imported = modules.import_module(uri=uploaded_module.url, name=user_module_name)

    assert imported.say() == REQUIREMENTS_TXT_ASSERT_STR
def test_import_module_with_local_script(user_module, user_module_name, tmpdir):
    """A module can be imported from a local directory of script files."""
    code_dir = str(tmpdir)
    user_module.create_tmp_dir_with_files(code_dir)

    imported = modules.import_module(code_dir, user_module_name)

    assert imported.validate()
def test_import_module_with_local_tar(user_module, user_module_name, requirements_file):
    """A module packed into a local tar archive imports and runs.

    The archive created for the test is always deleted afterwards, even
    when the import or the assertion fails.
    """
    user_module = user_module.add_file(requirements_file)
    tar_name = user_module.create_tar()
    try:
        module = modules.import_module(tar_name, name=user_module_name)
        assert module.say() == REQUIREMENTS_TXT_ASSERT_STR
    finally:
        # Previously the cleanup ran only on success, so a failing run
        # left the tar file behind.
        os.remove(tar_name)
def train():
    """Run the user script's ``train`` function and persist what it returns.

    The script is imported from the module directory/name recorded on a
    fresh ``environment.Environment()``. Keyword arguments for ``train`` come
    from ``functions.matching_args`` (presumably the environment values whose
    names match ``train``'s parameters -- verify).

    If ``train`` returns a truthy model, it is saved with the script's own
    ``save(model, model_dir)`` when the script defines one. Otherwise it is
    saved with ``model.save(<model_dir>/saved_model)``. A falsy return value
    means nothing is saved.
    """
    env = environment.Environment()
    user_script = modules.import_module(env.module_dir, env.module_name)

    train_fn = user_script.train
    trained = train_fn(**functions.matching_args(train_fn, env))

    if not trained:
        return

    if hasattr(user_script, "save"):
        # Prefer the script's custom persistence hook.
        user_script.save(trained, env.model_dir)
        return

    trained.save(os.path.join(env.model_dir, "saved_model"))
def framework_training_fn():
    """Import the user module, call its ``train`` and save the returned model.

    Arguments for ``train`` are produced by ``functions.matching_args`` from
    the training environment (presumably the environment values matching
    ``train``'s parameters -- verify).

    When ``train`` returns a falsy value, nothing is saved. Otherwise the
    module's ``save(model, model_dir)`` is used if the module defines it.
    Failing that, ``model.save`` is called with ``<model_dir>/saved_model``.
    """
    training_env = environment.Environment()
    user_mod = modules.import_module(training_env.module_dir, training_env.module_name)

    fit = user_mod.train
    trained_model = fit(**functions.matching_args(fit, training_env))

    if not trained_model:
        return

    if not hasattr(user_mod, "save"):
        # No custom hook: fall back to the model's own save method.
        trained_model.save(os.path.join(training_env.model_dir, "saved_model"))
        return

    user_mod.save(trained_model, training_env.model_dir)
def test_import_module_with_s3_script(user_module, user_module_name):
    """A script uploaded to S3 can be imported from its URL."""
    user_module.upload()
    s3_url = user_module.url

    imported = modules.import_module(s3_url, user_module_name)

    assert imported.validate()
def test_import_module_with_s3_script_with_error(user_module_name):
    """Importing a broken script from S3 raises ``errors.ImportModuleError``."""
    broken_module = test.UserModule(USER_SCRIPT_WITH_ERROR)
    broken_module = broken_module.add_file(SETUP_FILE)
    broken_module = broken_module.upload()

    with pytest.raises(errors.ImportModuleError):
        modules.import_module(broken_module.url, user_module_name)