def download_and_install(uri, name=DEFAULT_MODULE_NAME, cache=True):  # type: (str, str, bool) -> None
    """Download, prepare and install a compressed tar file from S3 or local directory as a module.

    The SageMaker Python SDK saves the user provided scripts as compressed tar files in S3.
    This function downloads this compressed file and, if provided, transforms it into a module
    before installing it.

    This method is the predecessor of
    :meth:`~sagemaker_containers.beta.framework.files.download_and_extract`
    and has been kept for backward-compatibility purposes.

    Args:
        uri (str): the location of the module.
        name (str): name of the script or module.
        cache (bool): defaults to True. When True and ``exists(name)`` is truthy, nothing is
            downloaded or installed.
    """
    should_use_cache = cache and exists(name)

    if not should_use_cache:
        with _files.tmpdir() as tmpdir:
            module_path = os.path.join(tmpdir, "module_dir")
            os.makedirs(module_path)

            # Extract directly into the directory that is prepared and installed. Extracting into a
            # separate path would leave module_path empty, and the user's code would never be
            # installed.
            _files.download_and_extract(uri, module_path)

            prepare(module_path, name)
            install(module_path)
def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None):  # type: (str, str, bool) -> module
    """Fetch user code from S3 or a local path, install it as a module and import it.

    The SageMaker Python SDK (https://github.com/aws/sagemaker-python-sdk) stores user-provided
    scripts as compressed tar files in S3. This function downloads/extracts that code into
    ``_env.code_dir``, prepares it as a module called ``name``, installs it, then imports and
    reloads it.

    Args:
        uri (str): the location of the module.
        name (str): name of the script or module.
        cache (bool): only forwarded to ``_warning_cache_deprecation``; it does not otherwise
            affect this function. The code is downloaded and installed on every call.

    Returns:
        (module): the imported module

    Raises:
        _errors.ImportModuleError: wraps any exception raised while importing or reloading the
            module, keeping the original traceback.
    """
    _warning_cache_deprecation(cache)

    code_dir = _env.code_dir
    _files.download_and_extract(uri, code_dir)
    prepare(code_dir, name)
    install(code_dir)

    try:
        imported = importlib.import_module(name)
        # Reload so that a module imported earlier in this process re-executes the freshly
        # installed code.
        six.moves.reload_module(imported)  # pylint: disable=too-many-function-args
    except Exception as err:  # pylint: disable=broad-except
        six.reraise(_errors.ImportModuleError, _errors.ImportModuleError(err), sys.exc_info()[2])
    else:
        return imported
def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, wait=True, capture_error=False):
    # type: (str, list, dict, str, bool, bool, bool) -> Popen
    """Fetch user code from S3 or a local path, install it as a module and execute it.

    The SageMaker Python SDK (https://github.com/aws/sagemaker-python-sdk) stores user-provided
    scripts as compressed tar files in S3. This function downloads/extracts that code into
    ``_env.code_dir``, prepares and installs it as module ``name``, writes the environment
    variables, and hands off to ``run``.

    Args:
        uri (str): the location of the module.
        args (list): A list of program arguments.
        env_vars (dict): A map containing the environment variables to be written. A copy is
            used, so the caller's mapping is not passed on.
        name (str): name of the script or module.
        cache (bool): only forwarded to ``_warning_cache_deprecation``; it does not otherwise
            affect this function. The code is downloaded and installed on every call.
        wait (bool): If True run_module will wait for the user module to exit and check the exit
            code, otherwise it will launch the user module with subprocess and return the process
            object.
        capture_error (bool): forwarded unchanged to ``run``.

    Returns:
        whatever ``run(name, args, env_vars, wait, capture_error)`` returns.
    """
    _warning_cache_deprecation(cache)

    env_copy = (env_vars or {}).copy()
    code_dir = _env.code_dir

    # NOTE(review): this passes (uri, name, path), while other callers of download_and_extract in
    # this file pass (uri, path). Confirm against the signature of _files.download_and_extract.
    _files.download_and_extract(uri, name, code_dir)
    prepare(code_dir, name)
    install(code_dir)

    _env.write_env_vars(env_copy)
    return run(name, args, env_copy, wait, capture_error)
def test_download_and_extract_source_dir(copy, rmtree, s3_download):
    """A local channel directory is copied into code_dir; S3 is never contacted."""
    source_uri = _env.channel_path("code")
    target = _env.code_dir

    _files.download_and_extract(source_uri, target)

    # The existing target is removed before the mocked copy is called with (source, target).
    s3_download.assert_not_called()
    rmtree.assert_any_call(target)
    copy.assert_called_with(source_uri, target)
def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
    """A local channel directory passed with an entry-point name goes through the mocked move."""
    source_uri = _env.channel_path('code')
    target = _env.code_dir

    _files.download_and_extract(source_uri, 'train.sh', target)

    s3_download.assert_not_called()
    rmtree.assert_any_call(target)
    move.assert_called_with(source_uri, target)
def train(train_env: TrainingEnv):
    """Run the training job described by ``train_env``.

    Steps, in order:
      1. download/extract the code at ``train_env.module_dir`` into ``_env.code_dir``;
      2. check the environment (``_check_env``);
      3. update the ``const.CONDA_TRAINING_ENV`` conda env from the code's MLProject dependencies;
      4. run training with ``train_env.hyperparameters`` plus the run parameters split out of
         ``train_env.additional_framework_parameters``;
      5. pass the resulting run id and ``train_env.model_dir`` to ``_save_results``.

    :param train_env: training environment; this function reads its ``module_dir``,
        ``hyperparameters``, ``additional_framework_parameters`` and ``model_dir``.
    :return: None
    """
    workdir = _env.code_dir

    logger.info('Download code')
    _files.download_and_extract(train_env.module_dir, workdir)

    logger.info('Checking environment')
    _check_env()

    logger.info(f'Update {const.CONDA_TRAINING_ENV} conda env using MLProject dependencies')
    _update_codna_env(workdir)

    logger.info('Run training')
    framework_params = _split_run_params(train_env.additional_framework_parameters)
    training_run_id = _run_training(workdir, train_env.hyperparameters, framework_params)

    logger.info('Save results')
    _save_results(training_run_id, train_env.model_dir)
def run(uri, user_entry_point, args, env_vars=None, wait=True, capture_error=False,
        runner=_runner.ProcessRunnerType):
    # type: (str, str, List[str], Dict[str, str], bool, bool, _runner.RunnerType) -> None
    """Download, prepare and execute a compressed tar file from S3 or provided directory as an user entrypoint.

    Runs the user entry point, passing env_vars as environment variables and args as command
    arguments.

    If the entry point is:
        - A Python package: executes the packages as

            >>> env_vars python -m module_name + args

        - A Python script: executes the script as

            >>> env_vars python module_name + args

        - Any other: executes the command as

            >>> env_vars /bin/sh -c ./module_name + args

    Example:

        >>>import sagemaker_containers
        >>>from sagemaker_containers.beta.framework import entry_point

        >>>env = sagemaker_containers.training_env()
        {'channel-input-dirs': {'training': '/opt/ml/input/training'},
         'model_dir': '/opt/ml/model', ...}

        >>>hyperparameters = env.hyperparameters
        {'batch-size': 128, 'model_dir': '/opt/ml/model'}

        >>>args = mapping.to_cmd_args(hyperparameters)
        ['--batch-size', '128', '--model_dir', '/opt/ml/model']

        >>>env_vars = mapping.to_env_vars()
        {'SAGEMAKER_CHANNELS': 'training', 'SAGEMAKER_CHANNEL_TRAINING': '/opt/ml/input/training',
         'MODEL_DIR': '/opt/ml/model', ...}

        >>>entry_point.run(uri, 'user_script', args, env_vars)
        SAGEMAKER_CHANNELS=training SAGEMAKER_CHANNEL_TRAINING=/opt/ml/input/training \
        SAGEMAKER_MODEL_DIR=/opt/ml/model python -m user_script --batch-size 128 --model_dir /opt/ml/model

    Args:
        uri (str): the location of the module.
        user_entry_point (str): name of the user provided entry point
        args (list): A list of program arguments.
        env_vars (dict): A map containing the environment variables to be written. A copy is used,
            so the caller's mapping is not passed on.
        wait (bool): forwarded to the runner's ``run`` call.
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
        runner (_runner.RunnerType): the type of runner to create (default:
            ``_runner.ProcessRunnerType``).

    Returns:
        whatever the selected runner's ``run(wait, capture_error)`` returns.
    """
    env_vars = env_vars or {}
    env_vars = env_vars.copy()

    # NOTE(review): this passes (uri, user_entry_point, path), while other callers of
    # download_and_extract in this file pass (uri, path). Confirm against the signature in use.
    _files.download_and_extract(uri, user_entry_point, _env.code_dir)

    install(user_entry_point, _env.code_dir, capture_error)

    _env.write_env_vars(env_vars)

    # Give the runner the caller's entry point, arguments and environment. Without them, ``args``
    # (and the caller-supplied entry point/env vars) would never reach the runner.
    return _runner.get(runner, user_entry_point, args, env_vars).run(wait, capture_error)
def test_download_and_extract_tar(extractall, s3_download):
    """A local .tar.gz uri is extracted into code_dir without any S3 download."""
    # Create an empty gzip tarball on disk so the uri points at a real local archive.
    with tarfile.open(name="test.tar.gz", mode="w:gz") as t:
        uri = t.name

    try:
        _files.download_and_extract(uri, _env.code_dir)

        s3_download.assert_not_called()
        extractall.assert_called_with(path=_env.code_dir)
    finally:
        # Always delete the archive, even if an assertion above fails, so no stray file is left.
        os.remove(uri)
def test_download_and_extract_file(copy, s3_download):
    """A single local file (this test module) is copied into code_dir; S3 is never contacted."""
    source_uri = __file__
    target = _env.code_dir

    _files.download_and_extract(source_uri, target)

    s3_download.assert_not_called()
    copy.assert_called_with(source_uri, target)
def run(
    uri,
    user_entry_point,
    args,
    env_vars=None,
    wait=True,
    capture_error=False,
    runner=_runner.ProcessRunnerType,
    extra_opts=None,
):
    # type: (str, str, List[str], Dict[str, str], bool, bool, _runner.RunnerType,Dict[str, str]) -> None  # pylint: disable=line-too-long # noqa ignore=E501
    """Download, prepare and execute user code from S3 or a local directory as a user entry point.

    Runs the user entry point, passing env_vars as environment variables and args as command
    arguments.

    If the entry point is:
        - A Python package: executes the packages as

            >>> env_vars python -m module_name + args

        - A Python script: executes the script as

            >>> env_vars python module_name + args

        - Any other: executes the command as

            >>> env_vars /bin/sh -c ./module_name + args

    Example:

        >>>import sagemaker_containers
        >>>from sagemaker_containers.beta.framework import entry_point

        >>>env = sagemaker_containers.training_env()
        {'channel-input-dirs': {'training': '/opt/ml/input/training'},
         'model_dir': '/opt/ml/model', ...}

        >>>hyperparameters = env.hyperparameters
        {'batch-size': 128, 'model_dir': '/opt/ml/model'}

        >>>args = mapping.to_cmd_args(hyperparameters)
        ['--batch-size', '128', '--model_dir', '/opt/ml/model']

        >>>env_vars = mapping.to_env_vars()
        {'SAGEMAKER_CHANNELS': 'training', 'SAGEMAKER_CHANNEL_TRAINING': '/opt/ml/input/training',
         'MODEL_DIR': '/opt/ml/model', ...}

        >>>entry_point.run(uri, 'user_script', args, env_vars)
        SAGEMAKER_CHANNELS=training SAGEMAKER_CHANNEL_TRAINING=/opt/ml/input/training \
        SAGEMAKER_MODEL_DIR=/opt/ml/model python -m user_script --batch-size 128 --model_dir /opt/ml/model

    Args:
        uri (str): the location of the module.
        user_entry_point (str): name of the user provided entry point
        args (list): A list of program arguments.
        env_vars (dict): A map containing the environment variables to be written (default: None).
            A copy is used, so the caller's mapping is not passed on.
        wait (bool): If the user entry point should be run to completion before this method
            returns (default: True).
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
        runner (sagemaker_containers.beta.framework.runner.RunnerType): the type of runner object
            to be created (default: sagemaker_containers.beta.framework.runner.ProcessRunnerType).
        extra_opts (dict): Additional options for running the entry point (default: None).
            Currently, this only applies for MPI.

    Returns:
        the value returned by the created runner's ``run(wait, capture_error)`` call.
    """
    env_snapshot = (env_vars or {}).copy()
    code_dir = _env.code_dir

    _files.download_and_extract(uri, code_dir)
    install(user_entry_point, code_dir, capture_error)
    _env.write_env_vars(env_snapshot)

    # NOTE(review): presumably blocks until peer hostnames resolve before launching; confirm in
    # _wait_hostname_resolution.
    _wait_hostname_resolution()

    entry_point_runner = _runner.get(runner, user_entry_point, args, env_snapshot, extra_opts)
    return entry_point_runner.run(wait, capture_error)
def test_download_and_and_extract_file(copy, s3_download):
    """With an entry-point name, the source is copied to code_dir/<name>; S3 is never contacted."""
    source_uri = _env.channel_path('code')
    expected_dst = os.path.join(_env.code_dir, 'train.sh')

    _files.download_and_extract(source_uri, 'train.sh', _env.code_dir)

    s3_download.assert_not_called()
    copy.assert_called_with(source_uri, expected_dst)