def import_module(uri, name=DEFAULT_MODULE_NAME):  # type: (str, str) -> module
    """Fetch user code from ``uri``, install it, and import it as a module.

    The SageMaker Python SDK (https://github.com/aws/sagemaker-python-sdk) stores user-provided
    scripts as compressed tar files in S3. The content at ``uri`` (S3 or a provided directory) is
    downloaded/extracted into ``environment.code_dir``, prepared as a module called ``name``,
    installed, and finally imported.

    Args:
        uri (str): The location of the module.
        name (str): Name of the script or module.

    Returns:
        (module): The imported (and reloaded) module.

    Raises:
        errors.ImportModuleError: if importing or reloading the module fails; the original
            traceback is preserved. Errors raised while downloading, preparing or installing
            propagate unchanged.
    """
    files.download_and_extract(uri, environment.code_dir)
    prepare(environment.code_dir, name)
    install(environment.code_dir)

    try:
        imported = importlib.import_module(name)
        # import_module hands back any copy already cached in sys.modules; reloading makes
        # sure the code that was just installed is what actually runs.
        six.moves.reload_module(imported)  # pylint: disable=too-many-function-args
    except Exception as exc:  # pylint: disable=broad-except
        tb = sys.exc_info()[2]
        six.reraise(errors.ImportModuleError, errors.ImportModuleError(exc), tb)
    else:
        return imported
# Example #2
def test_download_and_extract_source_dir(copy, rmtree, s3_download):
    """A local source channel is copied into the code dir rather than fetched from S3."""
    source_dir = environment.channel_path("code")
    target_dir = environment.code_dir

    files.download_and_extract(source_dir, target_dir)

    # A local path must never trigger an S3 download.
    s3_download.assert_not_called()
    # The code dir is removed via rmtree and the source is copied into it.
    rmtree.assert_any_call(target_dir)
    copy.assert_called_with(source_dir, target_dir)
# Example #3
def test_download_and_extract_tar(extractall, s3_download):
    """A local tarball is extracted into the code dir without any S3 download.

    An empty ``test.tar.gz`` fixture is written to the current working directory. It is always
    removed afterwards, even when an assertion fails, so a failing run does not leave it behind.
    """
    with tarfile.open(name="test.tar.gz", mode="w:gz") as tar:
        uri = tar.name

    try:
        files.download_and_extract(uri, environment.code_dir)

        s3_download.assert_not_called()
        extractall.assert_called_with(path=environment.code_dir)
    finally:
        # Cleanup must run even if the assertions above fail.
        os.remove(uri)
def run(
    uri,
    user_entry_point,
    args,
    env_vars=None,
    wait=True,
    capture_error=False,
    runner_type=runner.ProcessRunnerType,
    extra_opts=None,
):
    """Download, prepare and execute a compressed tar file from S3 or a provided directory as a
    user entry point.

    The user entry point is run with ``env_vars`` as environment variables and ``args`` as
    command-line arguments.

    If the entry point is:
        - A Python package: executes the package as >>> env_vars python -m module_name + args
        - A Python script: executes the script as >>> env_vars python module_name + args
        - Any other: executes the command as >>> env_vars /bin/sh -c ./module_name + args

    Example:
         >>> from sagemaker_training import entry_point, environment, mapping

         >>> env = environment.Environment()
         {'channel-input-dirs': {'training': '/opt/ml/input/training'},
          'model_dir': '/opt/ml/model', ...}

         >>> hyperparameters = env.hyperparameters
         {'batch-size': 128, 'model_dir': '/opt/ml/model'}

         >>> args = mapping.to_cmd_args(hyperparameters)
         ['--batch-size', '128', '--model_dir', '/opt/ml/model']

         >>> env_vars = mapping.to_env_vars()
         {'SAGEMAKER_CHANNELS': 'training',
          'SAGEMAKER_CHANNEL_TRAINING': '/opt/ml/input/training',
          'MODEL_DIR': '/opt/ml/model', ...}

         >>> entry_point.run(uri, 'user_script', args, env_vars)
         SAGEMAKER_CHANNELS=training SAGEMAKER_CHANNEL_TRAINING=/opt/ml/input/training
         MODEL_DIR=/opt/ml/model python -m user_script --batch-size 128
         --model_dir /opt/ml/model

    Args:
        uri (str): The location of the module or script. This can be an S3 uri, a path to
            a local directory, or a path to a local tarball.
        user_entry_point (str): Name of the user provided entry point.
        args ([str]): A list of program arguments.
        env_vars (dict(str,str)): A map containing the environment variables to be written
            (default: None). The caller's dict is copied, not modified.
        wait (bool): If the user entry point should be run to completion before this method
            returns (default: True).
        capture_error (bool): If True, the running process captures the stderr, and appends
            it to the returned Exception message in case of errors (default: False).
        runner_type (sagemaker_training.runner.RunnerType): The type of runner object to
            be created (default: sagemaker_training.runner.ProcessRunnerType).
        extra_opts (dict(str,str)): Additional options for running the entry point
            (default: None). Currently, this only applies for MPI.

    Returns:
        The value returned by ``run(wait, capture_error)`` on the runner built by
        ``runner.get``. NOTE(review): older docs described this as a
        ``sagemaker_training.process.ProcessRunner``; confirm the actual return type of
        the runner's ``run`` method.
    """
    # Work on a copy so nothing done with env_vars below can affect the caller's dict.
    env_vars = (env_vars or {}).copy()

    files.download_and_extract(uri=uri, path=environment.code_dir)
    install(name=user_entry_point, path=environment.code_dir, capture_error=capture_error)

    environment.write_env_vars(env_vars)

    # NOTE(review): presumably blocks until cluster hostnames are resolvable before the
    # entry point starts; confirm against _wait_hostname_resolution.
    _wait_hostname_resolution()

    return runner.get(runner_type, user_entry_point, args, env_vars, extra_opts).run(
        wait, capture_error
    )
# Example #5
def test_download_and_extract_file(copy, s3_download):
    """A single local file is copied into the code dir without touching S3."""
    local_file = __file__
    target_dir = environment.code_dir

    files.download_and_extract(local_file, target_dir)

    # A local path must never trigger an S3 download; the file is copied instead.
    s3_download.assert_not_called()
    copy.assert_called_with(local_file, target_dir)