def test_get_runner_by_mpi_with_extra_args(training_env):
    """Explicitly passed arguments are used by both MPI runner flavours.

    On the master host the runner must be an ``mpi.MasterRunner`` holding the
    given entry point, args, env vars and custom MPI options. After switching
    ``is_master`` off, a ``mpi.WorkerRunner`` is built from the same explicit
    values. In both halves the environment mocks that would otherwise supply
    those values must not have been called.
    """
    env = training_env()
    env.num_gpus = 0

    master = runner.get(runner.MPIRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS, MPI_OPTS)

    assert isinstance(master, mpi.MasterRunner)

    expected_master_state = (
        ("_user_entry_point", USER_SCRIPT),
        ("_args", CMD_ARGS),
        ("_env_vars", ENV_VARS),
        ("_process_per_host", 2),
        ("_num_processes", 4),
        ("_custom_mpi_options", NCCL_DEBUG_MPI_OPT),
    )
    for attr_name, expected_value in expected_master_state:
        assert getattr(master, attr_name) == expected_value

    for mock_name in (
        "to_cmd_args",
        "to_env_vars",
        "user_entry_point",
        "additional_framework_parameters",
    ):
        getattr(env, mock_name).assert_not_called()

    # Flip to a non-master host; no MPI options are passed this time.
    env.is_master = False
    worker = runner.get(runner.MPIRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS)

    assert isinstance(worker, mpi.WorkerRunner)

    expected_worker_state = (
        ("_user_entry_point", USER_SCRIPT),
        ("_args", CMD_ARGS),
        ("_env_vars", ENV_VARS),
    )
    for attr_name, expected_value in expected_worker_state:
        assert getattr(worker, attr_name) == expected_value

    for mock_name in ("to_cmd_args", "to_env_vars", "user_entry_point"):
        getattr(env, mock_name).assert_not_called()
def test_get_runner_by_mpi_returns_runnner(training_env):
    """Without explicit arguments, both MPI runner flavours read them from the environment.

    The environment's ``to_cmd_args``/``to_env_vars`` mocks are reset between
    the master and worker halves. Otherwise the worker-side ``assert_called``
    checks would pass on calls made while the master runner was built.
    """
    env = training_env()
    env.num_gpus = 0

    master = runner.get(runner.MPIRunnerType)

    assert isinstance(master, mpi.MasterRunner)
    env.to_cmd_args.assert_called()
    env.to_env_vars.assert_called()

    # Clear the call history so the assertions below prove that the
    # WorkerRunner path itself consulted the environment.
    env.to_cmd_args.reset_mock()
    env.to_env_vars.reset_mock()

    env.is_master = False
    worker = runner.get(runner.MPIRunnerType)

    assert isinstance(worker, mpi.WorkerRunner)
    env.to_cmd_args.assert_called()
    env.to_env_vars.assert_called()
def test_runnner_with_default_gpu_processes_per_host(training_env):
    """With no framework parameters set, the MPI master's processes-per-host matches the GPU count."""
    gpu_count = 2

    env = training_env()
    env.additional_framework_parameters = {}
    env.num_gpus = gpu_count

    master = runner.get(runner.MPIRunnerType)

    assert isinstance(master, mpi.MasterRunner)
    assert master._process_per_host == gpu_count
def test_get_runner_by_process_with_extra_args(training_env):
    """Explicit arguments build a ProcessRunner without calling the environment mocks."""
    process_runner = runner.get(runner.ProcessRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS)

    assert isinstance(process_runner, process.ProcessRunner)

    for attr_name, expected_value in (
        ("_user_entry_point", USER_SCRIPT),
        ("_args", CMD_ARGS),
        ("_env_vars", ENV_VARS),
    ):
        assert getattr(process_runner, attr_name) == expected_value

    env = training_env()
    for mock_name in ("to_cmd_args", "to_env_vars", "user_entry_point"):
        getattr(env, mock_name).assert_not_called()
def run(
    uri,
    user_entry_point,
    args,
    env_vars=None,
    wait=True,
    capture_error=False,
    runner_type=runner.ProcessRunnerType,
    extra_opts=None,
):
    """Download, prepare and execute a user entry point.

    The code located at ``uri`` (an S3 uri, a local directory or a local tarball)
    is downloaded and extracted into ``environment.code_dir``, and installed from
    there. The user entry point is then run, passing ``env_vars`` as environment
    variables and ``args`` as command arguments.

    If the entry point is:
        - A Python package: executes the package as >>> env_vars python -m module_name + args
        - A Python script: executes the script as >>> env_vars python module_name + args
        - Any other: executes the command as >>> env_vars /bin/sh -c ./module_name + args

    Example:
         >>>from sagemaker_training import entry_point, environment, mapping

         >>>env = environment.Environment()
         {'channel-input-dirs': {'training': '/opt/ml/input/training'},
          'model_dir': '/opt/ml/model', ...}

         >>>hyperparameters = env.hyperparameters
         {'batch-size': 128, 'model_dir': '/opt/ml/model'}

         >>>args = mapping.to_cmd_args(hyperparameters)
         ['--batch-size', '128', '--model_dir', '/opt/ml/model']

         >>>env_vars = mapping.to_env_vars()
         {'SAGEMAKER_CHANNELS': 'training',
          'SAGEMAKER_CHANNEL_TRAINING': '/opt/ml/input/training',
          'SAGEMAKER_MODEL_DIR': '/opt/ml/model', ...}

         >>>entry_point.run('s3://bucket/path/sourcedir.tar.gz', 'user_script', args, env_vars)
         SAGEMAKER_CHANNELS=training SAGEMAKER_CHANNEL_TRAINING=/opt/ml/input/training
             SAGEMAKER_MODEL_DIR=/opt/ml/model python -m user_script --batch-size 128
             --model_dir /opt/ml/model

    Args:
        uri (str): The location of the module or script. This can be an S3 uri, a path to
            a local directory, or a path to a local tarball.
        user_entry_point (str): Name of the user provided entry point.
        args ([str]): A list of program arguments.
        env_vars (dict(str,str)): A map containing the environment variables to be written
            (default: None). The caller's mapping is copied, never modified in place.
        wait (bool): If the user entry point should be run to completion before this method returns
            (default: True).
        capture_error (bool): Default false. If True, the running process captures the
            stderr, and appends it to the returned Exception message in case of errors.
        runner_type (sagemaker_training.runner.RunnerType): The type of runner object to
            be created (default: sagemaker_training.runner.ProcessRunnerType).
        extra_opts (dict(str,str)): Additional options for running the entry point (default: None).
            Currently, this only applies for MPI.

    Returns:
        The result of calling ``run(wait, capture_error)`` on the runner that
        ``runner.get`` builds for ``runner_type``. See that runner's ``run`` method
        for the exact return type.
    """
    # Work on a private copy so nothing downstream can alter the caller's mapping.
    env_vars = env_vars.copy() if env_vars else {}

    files.download_and_extract(uri=uri, path=environment.code_dir)
    install(name=user_entry_point, path=environment.code_dir, capture_error=capture_error)

    environment.write_env_vars(env_vars)

    # NOTE(review): presumably blocks until peer hostnames resolve so multi-host
    # runners can reach each other before launch; confirm in the helper.
    _wait_hostname_resolution()

    selected_runner = runner.get(runner_type, user_entry_point, args, env_vars, extra_opts)
    return selected_runner.run(wait, capture_error)
def test_get_runner_by_process_returns_runnner(training_env):
    """Without explicit arguments, a ProcessRunner pulls args and env vars from the environment."""
    process_runner = runner.get(runner.ProcessRunnerType)

    assert isinstance(process_runner, process.ProcessRunner)

    env = training_env()
    for mock_name in ("to_cmd_args", "to_env_vars"):
        getattr(env, mock_name).assert_called()
def test_get_runner_returns_runnner_itself(runner_class):
    """Passing an existing runner object to ``runner.get`` hands that object straight back."""
    existing_runner = MagicMock(spec=runner_class)

    returned = runner.get(existing_runner)

    assert returned == existing_runner
def test_get_runner_invalid_identifier():
    """An identifier that is neither a runner type nor a runner object raises ValueError."""
    bogus_identifier = 42

    with pytest.raises(ValueError):
        runner.get(bogus_identifier)