def train(env):
    """Run MXNet training through ``entry_point.run``.

    When the ``LAUNCH_PS_ENV_NAME`` framework parameter is truthy, the hosts
    are verified, a scheduler process is started on the scheduler host, a
    server process is started, and the worker-role variables are merged into
    ``os.environ`` before the user entry point is launched.

    Args:
        env: training environment object providing hosts, hyperparameters,
            additional framework parameters and the user entry point.
    """
    logger.info('MXNet training environment: {}'.format(env.to_env_vars()))

    launch_parameter_server = env.additional_framework_parameters.get(LAUNCH_PS_ENV_NAME, False)
    if launch_parameter_server:
        _verify_hosts(env.hosts)

        ps_port = env.hyperparameters.get('_ps_port', '8000')
        ps_verbose = env.hyperparameters.get('_ps_verbose', '0')
        ps_args = (env.hosts, ps_port, ps_verbose)

        logger.info('Starting distributed training task')
        is_scheduler = scheduler_host(env.hosts) == env.current_host
        if is_scheduler:
            _run_mxnet_process('scheduler', *ps_args)
        _run_mxnet_process('server', *ps_args)
        os.environ.update(_env_vars_for_role('worker', *ps_args))

    mpi_enabled = env.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME)
    runner_type = runner.MPIRunnerType if mpi_enabled else runner.ProcessRunnerType

    entry_point.run(
        uri=env.module_dir,
        user_entry_point=env.user_entry_point,
        args=env.to_cmd_args(),
        env_vars=env.to_env_vars(),
        runner_type=runner_type,
    )
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """Running the entry point looks up "algo-1" and, last of all, "algo-2"."""
    fake_runner = MagicMock(spec=process.ProcessRunner)
    run_kwargs = {
        "uri": "s3://url",
        "user_entry_point": "launcher.py",
        "args": ["42"],
        "runner_type": fake_runner,
    }

    entry_point.run(**run_kwargs)

    # assert_called_with checks only the most recent call; assert_any_call
    # accepts a match anywhere in the call history.
    gethostbyname.assert_called_with("algo-2")
    gethostbyname.assert_any_call("algo-1")
def train(training_env):
    """Launch the user's training script via ``entry_point.run``.

    Args:
        training_env: environment object supplying ``module_dir``,
            ``user_entry_point``, ``to_cmd_args()`` and ``to_env_vars()``.
    """
    logger.info("Invoking user training script.")

    module_uri = training_env.module_dir
    script_name = training_env.user_entry_point
    cmd_args = training_env.to_cmd_args()
    env_vars = training_env.to_env_vars()

    entry_point.run(module_uri, script_name, cmd_args, env_vars)
def test_run_module_with_extra_opts(
    gethostbyname, chmod, download_and_extract, get_runner, sys_path
):
    """``extra_opts`` given to ``entry_point.run`` is forwarded to ``get_runner``."""
    entry_point_name = "default_user_module_name"
    cli_args = ["--some-arg", "42"]
    opts = dict(foo="bar")

    entry_point.run(
        uri="s3://url",
        user_entry_point=entry_point_name,
        args=cli_args,
        extra_opts=opts,
    )

    # No env_vars were supplied, so an empty dict is expected in that position.
    get_runner.assert_called_with(
        runner.ProcessRunnerType, entry_point_name, cli_args, {}, opts
    )
def framework_training_with_script_mode_fn(capture_error):
    """Run the user entry point described by a fresh ``environment.Environment``.

    Args:
        capture_error: forwarded unchanged to ``entry_point.run`` as the
            ``capture_error`` keyword argument.
    """
    env = environment.Environment()

    module_uri = env.module_dir
    script_name = env.user_entry_point
    cmd_args = env.to_cmd_args()
    env_vars = env.to_env_vars()

    entry_point.run(module_uri, script_name, cmd_args, env_vars, capture_error=capture_error)
def test_run_waits_hostname_resolution(gethostbyname, hosts, install, download_and_extract):
    """The "algo-1" lookup is repeated after failures before "algo-2" is tried."""
    # Two failed lookups followed by two successful ones.
    failures = [ValueError() for _ in range(2)]
    gethostbyname.side_effect = failures + [True, True]

    fake_runner = MagicMock(spec=process.ProcessRunner)
    entry_point.run(
        uri="s3://url",
        user_entry_point="launcher.py",
        args=["42"],
        runner_type=fake_runner,
    )

    expected_lookups = [call("algo-1")] * 3 + [call("algo-2")]
    gethostbyname.assert_has_calls(expected_lookups)
def test_run_module_with_env_vars(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """User env vars reach ``get_runner`` together with an empty ``PYTHONPATH`` entry."""
    entry_point_name = "default_user_module_name"
    cli_args = ["--some-arg", "42"]
    expected_env_vars = dict(FOO="BAR", PYTHONPATH="")

    entry_point.run(
        uri="s3://url",
        user_entry_point=entry_point_name,
        args=cli_args,
        env_vars={"FOO": "BAR"},
    )

    # extra_opts was not given, so None is expected in the last position.
    get_runner.assert_called_with(
        runner.ProcessRunnerType, entry_point_name, cli_args, expected_env_vars, None
    )
def _run_worker(env, cmd_args, tf_config):
    """Run the user entry point with ``TF_CONFIG`` added to its environment.

    Args:
        env: training environment providing ``module_dir``,
            ``user_entry_point`` and ``to_env_vars()``.
        cmd_args: command-line arguments passed to the entry point.
        tf_config: JSON-serializable object stored, serialized, under the
            ``TF_CONFIG`` environment variable.
    """
    worker_env_vars = env.to_env_vars()
    serialized_config = json.dumps(tf_config)
    worker_env_vars["TF_CONFIG"] = serialized_config

    run_kwargs = {
        "uri": env.module_dir,
        "user_entry_point": env.user_entry_point,
        "args": cmd_args,
        "env_vars": worker_env_vars,
        "capture_error": True,
    }
    entry_point.run(**run_kwargs)
def train(training_environment):
    """Run PyTorch training on a user supplied module.

    The user supplied module is run in either a local or distributed SageMaker
    environment.

    The user supplied module and its dependencies are downloaded from S3.
    Training is invoked by calling a "train" function in the user supplied module.

    If the environment contains multiple hosts, then a distributed learning
    task is started.

    A user-script failure whose message contains the known
    ``gloo::EnforceNotMet`` termination text is logged as a warning and
    swallowed; any other ``errors.ExecuteUserScriptError`` is re-raised.

    Args:
        training_environment: training environment object containing environment
            variables, training arguments and hyperparameters.

    Raises:
        errors.ExecuteUserScriptError: if the user script fails for any reason
            other than the known gloo exception.
    """
    # Block until all host DNS lookups succeed. Relies on retrying dns_lookup.
    logger.info('Block until all host DNS lookups succeed.')
    for host in training_environment.hosts:
        _dns_lookup(host)

    _set_nccl_environment(training_environment.network_interface_name)

    _set_distributed_environment(training_environment.hosts)

    mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled')

    smdataparallel_enabled = training_environment.additional_framework_parameters.get(
        LAUNCH_SMDATAPARALLEL_ENV_NAME, False
    )

    # MPI takes precedence over SMDataParallel; a plain process is the default.
    if mpi_enabled:
        runner_type = runner.MPIRunnerType
    elif smdataparallel_enabled:
        runner_type = runner.SMDataParallelRunnerType
        logger.info('Invoking SMDataParallel')
    else:
        runner_type = runner.ProcessRunnerType

    logger.info('Invoking user training script.')
    try:
        entry_point.run(uri=training_environment.module_dir,
                        user_entry_point=training_environment.user_entry_point,
                        args=training_environment.to_cmd_args(),
                        env_vars=training_environment.to_env_vars(),
                        capture_error=True,
                        runner_type=runner_type)
    except errors.ExecuteUserScriptError as err:
        message = str(err)
        if "terminate called after throwing an instance of 'gloo::EnforceNotMet'" in message:
            # Logger.warn is a deprecated alias (removed in Python 3.13).
            logger.warning('Known exception: %s', message)
        else:
            # Bare raise re-raises the active exception with its traceback intact.
            raise
def test_run_module_no_wait(gethostbyname, chmod, download_and_extract):
    """With ``wait=False`` the runner's ``run`` is invoked as ``run(False, False)``."""
    fake_runner = MagicMock(spec=process.ProcessRunner)
    run_kwargs = dict(
        uri="s3://url",
        user_entry_point="default_user_module_name",
        args=["42"],
        wait=False,
        runner_type=fake_runner,
    )

    entry_point.run(**run_kwargs)

    fake_runner.run.assert_called_with(False, False)
def test_script_entry_point_with_python_package(
    gethostbyname, check_error, chmod, entry_point_type_module
):
    """The "train.sh" entry point under ``environment.code_dir`` is chmod-ed to 0o777."""
    fake_runner = MagicMock(spec=process.ProcessRunner)

    entry_point.run(
        uri="s3://dummy-uri",
        user_entry_point="train.sh",
        args=["dummy_arg"],
        runner_type=fake_runner,
    )

    script_path = os.path.join(environment.code_dir, "train.sh")
    # 0o777 is the octal spelling of 511.
    chmod.assert_called_with(script_path, 0o777)
def test_run_module_wait(gethostbyname, check_error, chmod, download_and_extract):
    """Default waiting run: code is downloaded, the runner is run, the script is chmod-ed."""
    fake_runner = MagicMock(spec=process.ProcessRunner)
    source_uri = "s3://url"
    script_name = "launcher.sh"

    entry_point.run(
        uri=source_uri,
        user_entry_point=script_name,
        args=["42"],
        capture_error=True,
        runner_type=fake_runner,
    )

    download_and_extract.assert_called_with(uri=source_uri, path=environment.code_dir)
    fake_runner.run.assert_called_with(True, True)

    script_path = os.path.join(environment.code_dir, script_name)
    # 0o777 is the octal spelling of 511.
    chmod.assert_called_with(script_path, 0o777)
def train(training_environment):
    """Run Scikit-learn training on a user supplied module in a local SageMaker environment.

    The user supplied module and its dependencies are downloaded from S3.
    Training is invoked by calling a "train" function in the user supplied module.

    Args:
        training_environment: training environment object containing environment
            variables, training arguments and hyperparameters.
    """
    logger.info('Invoking user training script.')

    run_kwargs = {
        'uri': training_environment.module_dir,
        'user_entry_point': training_environment.user_entry_point,
        'args': training_environment.to_cmd_args(),
        'env_vars': training_environment.to_env_vars(),
        'runner_type': runner.ProcessRunnerType,
    }
    entry_point.run(**run_kwargs)
def train(env, cmd_args):
    """Get training job environment from env and run the training job.

    With more than one host and the parameter-server flag set, a parameter
    server and a worker are launched and non-master hosts then wait for the
    master to go down. Otherwise the user entry point is run directly, using
    the MPI runner when ``sagemaker_mpi_enabled`` is truthy.

    Args:
        env (sagemaker_training.env.TrainingEnv): Instance of TrainingEnv class
        cmd_args: command-line arguments handed to the user entry point.
    """
    ps_requested = env.additional_framework_parameters.get(
        SAGEMAKER_PARAMETER_SERVER_ENABLED, False
    )
    use_parameter_servers = len(env.hosts) > 1 and ps_requested

    if not use_parameter_servers:
        mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")
        runner_type = runner.MPIRunnerType if mpi_enabled else runner.ProcessRunnerType

        entry_point.run(
            uri=env.module_dir,
            user_entry_point=env.user_entry_point,
            args=cmd_args,
            env_vars=env.to_env_vars(),
            capture_error=True,
            runner_type=runner_type,
        )
        return

    tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)

    logger.info("Running distributed training job with parameter servers")
    logger.info("Launching parameter server process")
    _run_ps(env, tf_config["cluster"])
    logger.info("Launching worker process")
    _run_worker(env, cmd_args, tf_config)

    is_master = _is_host_master(env.hosts, env.current_host)
    if not is_master:
        _wait_until_master_is_down(env.hosts[0])
def test_parameter_server():
    """A parameter-server script started with ``wait=False`` keeps running.

    The launched process is always killed, even when the liveness assertion
    fails, so a failing test does not leave the process behind.
    """
    module = test.UserModule(test.File(name="user_script.py", data=PARAMETER_SERVER_SCRIPT))
    hyperparameters = dict(sagemaker_program="user_script.py")

    test.prepare(
        user_module=module,
        hyperparameters=hyperparameters,
        channels=[test.Channel.create(name="training")],
    )
    training_env = environment.Environment()
    process = entry_point.run(
        training_env.module_dir,
        training_env.user_entry_point,
        training_env.to_cmd_args(),
        training_env.to_env_vars(),
        wait=False,
    )

    try:
        # confirm the ps process is still hanging
        assert process.poll() is None
    finally:
        # Runs on both pass and failure so the background process never leaks.
        process.kill()
def _run_worker(env, cmd_args, tf_config):
    """Run the user entry point with ``TF_CONFIG`` added to its environment.

    Args:
        env: training environment providing ``module_dir``,
            ``user_entry_point`` and ``to_env_vars()``.
        cmd_args: command-line arguments passed to the entry point.
        tf_config: JSON-serializable object stored, serialized, under the
            ``TF_CONFIG`` environment variable.
    """
    worker_env_vars = env.to_env_vars()
    serialized_config = json.dumps(tf_config)
    worker_env_vars['TF_CONFIG'] = serialized_config

    module_uri = env.module_dir
    script_name = env.user_entry_point
    entry_point.run(module_uri, script_name, cmd_args, worker_env_vars)
def train():
    """The main function responsible for running training in the container.

    Builds the environment, starts the intermediate-output sync, then either
    calls the entry point named by ``env.framework_module``
    (``"<module>:<function>"``) or runs the user entry point directly. A
    success or failure file is written accordingly, and the function finishes
    by calling ``_exit_processes`` with the resulting exit code.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE

    try:
        env = environment.Environment()

        fallback_region = os.environ.get(params.REGION_NAME_ENV)
        region = os.environ.get("AWS_REGION", fallback_region)
        s3_endpoint_url = os.environ.get(params.S3_ENDPOINT_URL)
        intermediate_sync = intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url
        )

        if not env.framework_module:
            logging_config.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(params.MPI_ENABLED)
            use_mpi = mpi_enabled and (
                env.current_instance_group in env.distribution_instance_groups
            )
            if use_mpi:
                runner_type = runner.RunnerType.MPI
            else:
                runner_type = runner.RunnerType.Process

            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner_type=runner_type,
            )
        else:
            framework_name, entry_point_name = env.framework_module.split(":")

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            logging_config.configure_logger(env.log_level)
            logger.info("Imported framework %s", framework_name)

            framework_entry_point = getattr(framework, entry_point_name)
            framework_entry_point()

        logger.info("Reporting training SUCCESS")
        files.write_success_file()
    except errors.ClientError as err:
        message = str(err)
        files.write_failure_file(message)

        logger.error("Reporting training FAILURE")
        logger.error(message)

        if intermediate_sync:
            intermediate_sync.join()

        exit_code = DEFAULT_FAILURE_CODE
    except Exception as err:  # pylint: disable=broad-except
        message = "Framework Error: \n%s\n%s" % (traceback.format_exc(), str(err))
        files.write_failure_file(message)

        logger.error("Reporting training FAILURE")
        logger.error(message)

        # Use the exception's errno when it has one; otherwise the default failure code.
        error_number = getattr(err, "errno", DEFAULT_FAILURE_CODE)
        exit_code = _get_valid_failure_exit_code(error_number)
    finally:
        if intermediate_sync:
            intermediate_sync.join()

        _exit_processes(exit_code)