def train(env):
    """Run a user MXNet training script, starting parameter-server roles if requested.

    When the ``LAUNCH_PS_ENV_NAME`` framework parameter is truthy, the host list is
    verified, the scheduler process is started on the scheduler host, a server
    process is started on this host, and the worker role's environment variables are
    merged into ``os.environ``. The user entry point is then run with the MPI runner
    when ``LAUNCH_MPI_ENV_NAME`` is truthy, otherwise with the plain process runner.

    Args:
        env: training environment exposing hosts, hyperparameters, framework
            parameters and the user entry point to run.
    """
    logger.info('MXNet training environment: {}'.format(env.to_env_vars()))

    framework_params = env.additional_framework_parameters

    if framework_params.get(LAUNCH_PS_ENV_NAME, False):
        hosts = env.hosts
        _verify_hosts(hosts)

        hyperparameters = env.hyperparameters
        port = hyperparameters.get('_ps_port', '8000')
        verbosity = hyperparameters.get('_ps_verbose', '0')

        logger.info('Starting distributed training task')
        # Only the scheduler host launches the scheduler; every host launches a server.
        if scheduler_host(hosts) == env.current_host:
            _run_mxnet_process('scheduler', hosts, port, verbosity)
        _run_mxnet_process('server', hosts, port, verbosity)
        # Export the worker role's variables into os.environ (presumably inherited
        # by the user script launched below).
        os.environ.update(_env_vars_for_role('worker', hosts, port, verbosity))

    use_mpi = framework_params.get(LAUNCH_MPI_ENV_NAME)
    runner_type = runner.MPIRunnerType if use_mpi else runner.ProcessRunnerType

    entry_point.run(uri=env.module_dir,
                    user_entry_point=env.user_entry_point,
                    args=env.to_cmd_args(),
                    env_vars=env.to_env_vars(),
                    runner_type=runner_type)
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """entry_point.run must look up the host names before launching the runner."""
    mocked_runner = MagicMock(spec=process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point="launcher.py",
        args=["42"],
        runner_type=mocked_runner,
    )

    # The most recent lookup is for "algo-2"; "algo-1" must have been looked up too.
    gethostbyname.assert_called_with("algo-2")
    gethostbyname.assert_any_call("algo-1")
def train(training_env):
    """Run the user entry point with the environment's command-line args and env vars.

    Args:
        training_env: training environment providing the module dir, entry point,
            command-line arguments and environment variables.
    """
    logger.info("Invoking user training script.")

    entry_point.run(
        uri=training_env.module_dir,
        user_entry_point=training_env.user_entry_point,
        args=training_env.to_cmd_args(),
        env_vars=training_env.to_env_vars(),
    )
def test_run_module_with_extra_opts(
    gethostbyname, chmod, download_and_extract, get_runner, sys_path
):
    """extra_opts given to entry_point.run are forwarded to the runner factory."""
    user_module = "default_user_module_name"
    cli_args = ["--some-arg", "42"]
    runner_opts = {"foo": "bar"}

    entry_point.run(
        uri="s3://url",
        user_entry_point=user_module,
        args=cli_args,
        extra_opts=runner_opts,
    )

    get_runner.assert_called_with(
        runner.ProcessRunnerType, user_module, cli_args, {}, runner_opts
    )
def framework_training_with_script_mode_fn(capture_error):
    """Run the user script described by a freshly built ``environment.Environment``.

    Args:
        capture_error: forwarded unchanged to ``entry_point.run``.
    """
    env = environment.Environment()

    entry_point.run(
        uri=env.module_dir,
        user_entry_point=env.user_entry_point,
        args=env.to_cmd_args(),
        env_vars=env.to_env_vars(),
        capture_error=capture_error,
    )
def test_run_waits_hostname_resolution(gethostbyname, hosts, install, download_and_extract):
    """entry_point.run keeps retrying a host lookup until it stops raising."""
    # The first two lookups fail, after which both hosts resolve.
    gethostbyname.side_effect = [ValueError(), ValueError(), True, True]

    entry_point.run(
        uri="s3://url",
        user_entry_point="launcher.py",
        args=["42"],
        runner_type=MagicMock(spec=process.ProcessRunner),
    )

    expected_lookups = [call("algo-1")] * 3 + [call("algo-2")]
    gethostbyname.assert_has_calls(expected_lookups)
def test_run_module_with_env_vars(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """User env vars reach the runner factory alongside a PYTHONPATH entry."""
    user_module = "default_user_module_name"
    cli_args = ["--some-arg", "42"]

    entry_point.run(
        uri="s3://url",
        user_entry_point=user_module,
        args=cli_args,
        env_vars={"FOO": "BAR"},
    )

    get_runner.assert_called_with(
        runner.ProcessRunnerType,
        user_module,
        cli_args,
        {"FOO": "BAR", "PYTHONPATH": ""},
        None,
    )
def _run_worker(env, cmd_args, tf_config):
    """Run the user script as the worker, exporting ``tf_config`` as TF_CONFIG.

    Args:
        env: training environment supplying the module dir, entry point and env vars.
        cmd_args: command-line arguments for the user script.
        tf_config: JSON-serializable config exported to the script as ``TF_CONFIG``.
    """
    worker_env_vars = env.to_env_vars()
    worker_env_vars.update(TF_CONFIG=json.dumps(tf_config))

    entry_point.run(
        uri=env.module_dir,
        user_entry_point=env.user_entry_point,
        args=cmd_args,
        env_vars=worker_env_vars,
        capture_error=True,
    )
Example #9
0
def train(training_environment):
    """Run PyTorch training on a user supplied module.

    The user supplied module is run in either a local or distributed SageMaker
    environment. Before launching, every host name is resolved, the NCCL and
    distributed environments are configured, and then the user entry point is run
    via ``entry_point.run``.

    The runner is chosen from the additional framework parameters: MPI when
    ``sagemaker_mpi_enabled`` is set, otherwise SMDataParallel when the
    ``LAUNCH_SMDATAPARALLEL_ENV_NAME`` flag is set, otherwise a plain process.

    Args:
        training_environment: training environment object containing environment
            variables, training arguments and hyperparameters.

    Raises:
        errors.ExecuteUserScriptError: if the user script fails, except when the
            failure message contains the known ``gloo::EnforceNotMet`` termination,
            which is only logged as a warning.
    """
    # Block until all host DNS lookups succeed. Relies on retrying dns_lookup.
    logger.info('Block until all host DNS lookups succeed.')
    for host in training_environment.hosts:
        _dns_lookup(host)

    _set_nccl_environment(training_environment.network_interface_name)

    _set_distributed_environment(training_environment.hosts)

    framework_params = training_environment.additional_framework_parameters
    mpi_enabled = framework_params.get('sagemaker_mpi_enabled')
    smdataparallel_enabled = framework_params.get(LAUNCH_SMDATAPARALLEL_ENV_NAME, False)

    if mpi_enabled:
        runner_type = runner.MPIRunnerType
    elif smdataparallel_enabled:
        runner_type = runner.SMDataParallelRunnerType
        logger.info('Invoking SMDataParallel')
    else:
        runner_type = runner.ProcessRunnerType

    logger.info('Invoking user training script.')
    try:
        entry_point.run(uri=training_environment.module_dir,
                        user_entry_point=training_environment.user_entry_point,
                        args=training_environment.to_cmd_args(),
                        env_vars=training_environment.to_env_vars(),
                        capture_error=True,
                        runner_type=runner_type)
    except errors.ExecuteUserScriptError as err:
        message = str(err)
        # A gloo EnforceNotMet abort is deliberately treated as non-fatal here and
        # only logged; every other script failure propagates unchanged.
        if 'terminate called after throwing an instance of \'gloo::EnforceNotMet\'' in message:
            logger.warning('Known exception: %s', message)
        else:
            # Bare raise re-raises the active exception with its original traceback.
            raise
Example #10
0
def test_run_module_no_wait(gethostbyname, chmod, download_and_extract):
    """Passing wait=False must be reflected in the runner's run() call."""
    fake_runner = MagicMock(spec=process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point="default_user_module_name",
        args=["42"],
        wait=False,
        runner_type=fake_runner,
    )

    # Presumably (wait, capture_error); capture_error is left at its default here.
    fake_runner.run.assert_called_with(False, False)
def test_script_entry_point_with_python_package(
    gethostbyname, check_error, chmod, entry_point_type_module
):
    """A shell-script entry point is chmod'ed to 0o777 (511) inside the code dir."""
    script_name = "train.sh"

    entry_point.run(
        uri="s3://dummy-uri",
        user_entry_point=script_name,
        args=["dummy_arg"],
        runner_type=MagicMock(spec=process.ProcessRunner),
    )

    expected_path = os.path.join(environment.code_dir, script_name)
    chmod.assert_called_with(expected_path, 0o777)
def test_run_module_wait(gethostbyname, check_error, chmod, download_and_extract):
    """A blocking run with capture_error downloads the code, chmods the script and runs it."""
    fake_runner = MagicMock(spec=process.ProcessRunner)
    script_name = "launcher.sh"

    entry_point.run(
        uri="s3://url",
        user_entry_point=script_name,
        args=["42"],
        capture_error=True,
        runner_type=fake_runner,
    )

    code_dir = environment.code_dir
    download_and_extract.assert_called_with(uri="s3://url", path=code_dir)
    # Presumably (wait, capture_error): wait defaults to True, capture_error was set.
    fake_runner.run.assert_called_with(True, True)
    chmod.assert_called_with(os.path.join(code_dir, script_name), 0o777)
def train(training_environment):
    """Run Scikit-learn training by launching the user entry point as a plain process.

    The user module is executed through ``entry_point.run`` with the
    ``ProcessRunnerType`` runner.

    Args:
        training_environment: training environment object containing environment variables,
                               training arguments and hyperparameters
    """
    logger.info('Invoking user training script.')

    module_dir = training_environment.module_dir
    user_entry_point = training_environment.user_entry_point
    cmd_args = training_environment.to_cmd_args()
    env_vars = training_environment.to_env_vars()

    entry_point.run(uri=module_dir,
                    user_entry_point=user_entry_point,
                    args=cmd_args,
                    env_vars=env_vars,
                    runner_type=runner.ProcessRunnerType)
Example #14
0
def train(env, cmd_args):
    """Get training job environment from env and run the training job.

    With more than one host and the parameter-server framework parameter enabled,
    a parameter server and a worker are launched on this host, and non-master hosts
    then block in ``_wait_until_master_is_down`` on the first host. Otherwise the
    user entry point is run directly, with the MPI runner when
    ``sagemaker_mpi_enabled`` is set and the plain process runner otherwise.

    Args:
        env (sagemaker_training.env.TrainingEnv): Instance of TrainingEnv class
        cmd_args: command-line arguments passed to the user script.
    """
    framework_params = env.additional_framework_parameters
    use_parameter_server = framework_params.get(
        SAGEMAKER_PARAMETER_SERVER_ENABLED, False)

    if not (len(env.hosts) > 1 and use_parameter_server):
        # Single host, or parameter server disabled: run the script directly.
        use_mpi = framework_params.get("sagemaker_mpi_enabled")
        runner_type = runner.MPIRunnerType if use_mpi else runner.ProcessRunnerType

        entry_point.run(
            uri=env.module_dir,
            user_entry_point=env.user_entry_point,
            args=cmd_args,
            env_vars=env.to_env_vars(),
            capture_error=True,
            runner_type=runner_type,
        )
        return

    tf_config = _build_tf_config(hosts=env.hosts,
                                 current_host=env.current_host)

    logger.info("Running distributed training job with parameter servers")
    logger.info("Launching parameter server process")
    _run_ps(env, tf_config["cluster"])
    logger.info("Launching worker process")
    _run_worker(env, cmd_args, tf_config)

    if not _is_host_master(env.hosts, env.current_host):
        _wait_until_master_is_down(env.hosts[0])
def test_parameter_server():
    """A parameter-server script started with wait=False keeps running until killed."""
    module = test.UserModule(test.File(name="user_script.py", data=PARAMETER_SERVER_SCRIPT))
    hyperparameters = dict(sagemaker_program="user_script.py")

    test.prepare(
        user_module=module,
        hyperparameters=hyperparameters,
        channels=[test.Channel.create(name="training")],
    )
    training_env = environment.Environment()
    # Named ps_process (not `process`) to avoid shadowing the `process` module.
    ps_process = entry_point.run(
        training_env.module_dir,
        training_env.user_entry_point,
        training_env.to_cmd_args(),
        training_env.to_env_vars(),
        wait=False,
    )
    try:
        # confirm the ps process is still hanging
        assert ps_process.poll() is None
    finally:
        # Always kill the background process, even if the assertion fails, so a
        # failing test does not leak a hanging parameter-server process.
        ps_process.kill()
Example #16
0
def _run_worker(env, cmd_args, tf_config):
    """Run the user script as the worker, exporting ``tf_config`` as TF_CONFIG.

    Args:
        env: training environment supplying the module dir, entry point and env vars.
        cmd_args: command-line arguments for the user script.
        tf_config: JSON-serializable config exported to the script as ``TF_CONFIG``.
    """
    worker_env_vars = env.to_env_vars()
    worker_env_vars.update(TF_CONFIG=json.dumps(tf_config))

    entry_point.run(
        uri=env.module_dir,
        user_entry_point=env.user_entry_point,
        args=cmd_args,
        env_vars=worker_env_vars,
    )
Example #17
0
def train():
    """The main function responsible for running training in the container.

    Starts intermediate-output syncing (``intermediate_output.start_sync`` with the
    job's S3 output location). Then, if ``env.framework_module`` is set, it is split
    as ``"<module>:<function>"``, imported, and that function is called. Otherwise
    the user entry point is run directly. On success a success file is written. On
    failure a failure file with the error message is written and a failure exit code
    is chosen. In every case the sync is joined once (in ``finally``) and
    ``_exit_processes`` is called with the exit code.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        env = environment.Environment()

        region = os.environ.get("AWS_REGION",
                                os.environ.get(params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get(params.S3_ENDPOINT_URL, None)
        intermediate_sync = intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(":")

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            logging_config.configure_logger(env.log_level)
            logger.info("Imported framework %s", framework_name)
            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging_config.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(
                params.MPI_ENABLED)
            # MPI only when enabled AND this instance's group is one of the
            # distribution instance groups; otherwise a plain process runner.
            runner_type = (runner.RunnerType.MPI if mpi_enabled and
                           (env.current_instance_group
                            in env.distribution_instance_groups) else
                           runner.RunnerType.Process)

            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner_type=runner_type,
            )
        logger.info("Reporting training SUCCESS")

        files.write_success_file()
    except errors.ClientError as e:

        failure_msg = str(e)
        files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")

        logger.error(failure_msg)

        # The intermediate sync is joined in `finally` on every path; joining it
        # here as well was redundant.
        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:  # pylint: disable=broad-except
        # Unexpected errors: record the full traceback in the failure file.
        failure_msg = "Framework Error: \n%s\n%s" % (traceback.format_exc(),
                                                     str(e))

        files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")

        logger.error(failure_msg)

        # Prefer the exception's errno when present; _get_valid_failure_exit_code
        # presumably maps it to a usable process exit status.
        error_number = getattr(e, "errno", DEFAULT_FAILURE_CODE)
        exit_code = _get_valid_failure_exit_code(error_number)
    finally:
        if intermediate_sync:
            intermediate_sync.join()
        _exit_processes(exit_code)