def test_run_module_with_env_vars(chmod, download_and_extract, get_runner, sys_path):
    """Env vars given to ``entry_point.run`` are forwarded to the runner factory.

    With the ``sys_path`` fixture active, the forwarded mapping is expected to
    also carry ``PYTHONPATH`` set to the empty string.
    """
    user_module = 'default_user_module_name'
    cli_args = ['--some-arg', '42']

    entry_point.run(uri='s3://url', user_entry_point=user_module, args=cli_args,
                    env_vars=dict(FOO='BAR'))

    forwarded_env = dict(FOO='BAR', PYTHONPATH='')
    get_runner.assert_called_with(_runner.ProcessRunnerType, user_module, cli_args,
                                  forwarded_env, None)
def test_run_module_with_extra_opts(chmod, download_and_extract, get_runner, sys_path):
    """``extra_opts`` passed to ``entry_point.run`` reach the runner factory unchanged."""
    user_module = 'default_user_module_name'
    cli_args = ['--some-arg', '42']
    opts = dict(foo='bar')

    run_kwargs = dict(uri='s3://url', user_entry_point=user_module, args=cli_args,
                      extra_opts=opts)
    entry_point.run(**run_kwargs)

    # No env vars were supplied, so the factory is expected to get an empty mapping.
    get_runner.assert_called_with(_runner.ProcessRunnerType, user_module, cli_args, {}, opts)
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """``entry_point.run`` looks up both 'algo-1' and 'algo-2' via gethostbyname.

    The hostnames are presumably supplied by the ``hosts`` fixture; 'algo-2' is
    expected to be the last lookup.
    """
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(uri='s3://url', user_entry_point='launcher.py', args=['42'],
                    runner=fake_runner)

    gethostbyname.assert_any_call('algo-1')
    gethostbyname.assert_called_with('algo-2')
def test_run_module_no_wait(chmod, download_and_extract):
    """With ``wait=False`` the runner's ``run`` is invoked as ``run(False, False)``."""
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(uri='s3://url', user_entry_point='default_user_module_name',
                    args=['42'], wait=False, runner=fake_runner)

    fake_runner.run.assert_called_with(False, False)
def test_run_module_wait(chmod, download_and_extract):
    """Running a shell entry point downloads the code, runs it, and makes it executable.

    Checks that the code is fetched into ``_env.code_dir``, the runner is
    invoked as ``run(True, True)``, and the script is chmod-ed to 0o777.
    """
    script = 'launcher.sh'
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(uri='s3://url', user_entry_point=script, args=['42'],
                    capture_error=True, runner=fake_runner)

    download_and_extract.assert_called_with('s3://url', _env.code_dir)
    fake_runner.run.assert_called_with(True, True)
    # 0o777 == 511, the mode the original test asserted in decimal.
    chmod.assert_called_with(os.path.join(_env.code_dir, script), 0o777)
def test_run_waits_hostname_resolution(gethostbyname, hosts, install, download_and_extract):
    """``entry_point.run`` keeps retrying a hostname whose lookup raises.

    The first two lookups raise ValueError; 'algo-1' is therefore expected to
    be queried three times before 'algo-2' is queried.
    """
    gethostbyname.side_effect = [ValueError(), ValueError(), True, True]
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(uri='s3://url', user_entry_point='launcher.py', args=['42'],
                    runner=fake_runner)

    expected_lookups = [call('algo-1')] * 3 + [call('algo-2')]
    gethostbyname.assert_has_calls(expected_lookups)
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """``entry_point.run`` resolves 'algo-1' and, last of all, 'algo-2'.

    The hostnames are presumably provided by the ``hosts`` fixture.
    """
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point="launcher.py",
        args=["42"],
        runner=fake_runner,
    )

    gethostbyname.assert_any_call("algo-1")
    gethostbyname.assert_called_with("algo-2")
def test_run_module_no_wait(gethostbyname, chmod, download_and_extract):
    """Passing ``wait=False`` results in ``runner.run(False, False)``."""
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point="default_user_module_name",
        args=["42"],
        wait=False,
        runner=fake_runner,
    )

    fake_runner.run.assert_called_with(False, False)
def test_run_module_with_extra_opts(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """``extra_opts`` are handed through to the runner factory as-is."""
    user_module = "default_user_module_name"
    cli_args = ["--some-arg", "42"]
    opts = dict(foo="bar")

    run_kwargs = dict(
        uri="s3://url",
        user_entry_point=user_module,
        args=cli_args,
        extra_opts=opts,
    )
    entry_point.run(**run_kwargs)

    # No env vars were given, so an empty mapping is expected.
    get_runner.assert_called_with(_runner.ProcessRunnerType, user_module, cli_args, {}, opts)
def test_run_module_with_env_vars(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """User env vars are forwarded to the runner factory.

    With the ``sys_path`` fixture active, the forwarded mapping is expected to
    include ``PYTHONPATH`` as the empty string.
    """
    user_module = "default_user_module_name"
    cli_args = ["--some-arg", "42"]

    entry_point.run(
        uri="s3://url",
        user_entry_point=user_module,
        args=cli_args,
        env_vars=dict(FOO="BAR"),
    )

    forwarded_env = dict(FOO="BAR", PYTHONPATH="")
    get_runner.assert_called_with(_runner.ProcessRunnerType, user_module, cli_args, forwarded_env, None)
def test_run_module_wait(gethostbyname, check_error, chmod, download_and_extract):
    """A shell entry point is downloaded, run with (True, True), and chmod-ed to 0o777."""
    script = "launcher.sh"
    fake_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point=script,
        args=["42"],
        capture_error=True,
        runner=fake_runner,
    )

    download_and_extract.assert_called_with("s3://url", _env.code_dir)
    fake_runner.run.assert_called_with(True, True)
    # 0o777 is the octal spelling of the 511 the original test asserted.
    chmod.assert_called_with(os.path.join(_env.code_dir, script), 0o777)
def train():
    """Run the training job and terminate via ``_exit_processes`` with a status code.

    Loads the training environment and starts syncing intermediate output to
    S3. It then does one of two things:

    - If ``env.framework_module`` ("module:function") is set, it imports that
      module and calls the named function.
    - Otherwise, it runs the user entry point with ``entry_point.run``, using
      the MPI runner when the MPI-enabled framework parameter is set.

    A success or failure file is written. In every case the intermediate sync
    is joined before exiting.

    Exit codes:
        SUCCESS_CODE when training completes;
        DEFAULT_FAILURE_CODE for ``_errors.ClientError``;
        for any other exception, its ``errno`` if that is an int in 1..255,
        otherwise DEFAULT_FAILURE_CODE.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        # TODO: [issue#144] There is a bug in the logic -
        # we need os.environ.get(_params.REGION_NAME_ENV)
        # in certain regions, but it is not going to be available unless
        # TrainingEnvironment has been initialized. It shouldn't be environment variable.
        region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get("S3_ENDPOINT_URL")
        intermediate_sync = _intermediate_output.start_sync(env.sagemaker_s3_output(), region,
                                                            endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            _logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            _logging.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(_params.MPI_ENABLED)
            runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process

            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(),
                            env.to_env_vars(), runner=runner_type)

        logger.info('Reporting training SUCCESS')
        _files.write_success_file()
    except _errors.ClientError as e:
        # Record the failure code before anything that could itself raise.
        # `finally` always calls _exit_processes(exit_code), so a secondary error
        # here would otherwise reach it with SUCCESS_CODE.
        exit_code = DEFAULT_FAILURE_CODE
        failure_message = str(e)
        _files.write_failure_file(failure_message)
        logger.error(failure_message)
        # The intermediate sync is joined in `finally`; no separate join is needed here.
    except Exception as e:
        exit_code = DEFAULT_FAILURE_CODE
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))
        _files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')
        logger.error(failure_msg)

        # `errno` may be present but None (e.g. OSError('x').errno is None), or lie
        # outside the 1-255 range of a failing process status. Passing such a value
        # on could yield exit status 0, i.e. report success (presumably
        # _exit_processes forwards the code to sys.exit/os._exit - confirm). Only use
        # errno when it is a usable failure status.
        error_number = getattr(e, 'errno', DEFAULT_FAILURE_CODE)
        if isinstance(error_number, int) and 0 < error_number < 256:
            exit_code = error_number
    finally:
        if intermediate_sync:
            intermediate_sync.join()
        _exit_processes(exit_code)
def test_run_module_no_wait(call, download_and_extract, entry_point_type_module):
    """``entry_point.run`` with ``wait=False`` is expected to raise InstallModuleError.

    Presumably the ``entry_point_type_module`` fixture makes the entry point an
    installable module whose install step fails under test - confirm.
    """
    with pytest.raises(_errors.InstallModuleError):
        entry_point.run(uri='s3://url', user_entry_point='default_user_module_name', args=['42'], wait=False)

        # NOTE(review): as laid out here, these assertions follow a call that is
        # expected to raise inside `pytest.raises`, so they never execute. Before
        # moving them out of the `with` block, confirm where InstallModuleError is
        # raised: `call` may never be reached in that case.
        download_and_extract.assert_called_with('s3://url', 'default_user_module_name', _env.code_dir)
        call.assert_called_with('default_user_module_name', ['42'], {}, False)
def test_run_module_wait(chmod, call, download_and_extract):
    """A shell entry point is downloaded, invoked with capture_error, and chmod-ed to 0o777.

    The ``call`` fixture is expected to receive the script name, its args, an
    empty env mapping, and two True flags (the last presumably capture_error).
    """
    script = 'launcher.sh'
    cli_args = ['42']

    entry_point.run(uri='s3://url', user_entry_point=script, args=cli_args, capture_error=True)

    download_and_extract.assert_called_with('s3://url', script, _env.code_dir)
    call.assert_called_with(script, cli_args, {}, True, True)
    # 0o777 == 511, the decimal mode used by the original assertion.
    chmod.assert_called_with(os.path.join(_env.code_dir, script), 0o777)
def train():
    """Run the training job and terminate via ``_exit_processes`` with a status code.

    Loads the training environment and starts syncing intermediate output to
    S3 (honouring a custom endpoint URL). It then does one of two things:

    - If ``env.framework_module`` ("module:function") is set, it imports that
      module and calls the named function.
    - Otherwise, it runs the user entry point with ``entry_point.run``, using
      the MPI runner when the MPI-enabled framework parameter is set.

    A success or failure file is written. In every case the intermediate sync
    is joined before exiting.

    Exit codes:
        SUCCESS_CODE when training completes;
        DEFAULT_FAILURE_CODE for ``_errors.ClientError``;
        for any other exception, the exception's ``errno`` (DEFAULT_FAILURE_CODE
        when absent) passed through ``_get_valid_failure_exit_code``.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        env = sagemaker_containers.training_env()

        region = os.environ.get("AWS_REGION", os.environ.get(_params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get(_params.S3_ENDPOINT_URL, None)
        intermediate_sync = _intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url
        )

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(":")

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            _logging.configure_logger(env.log_level)
            logger.info("Imported framework %s", framework_name)
            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            _logging.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(_params.MPI_ENABLED)
            runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process

            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner=runner_type,
            )

        logger.info("Reporting training SUCCESS")
        _files.write_success_file()
    except _errors.ClientError as e:
        # Record the failure code before anything that could itself raise.
        # `finally` always calls _exit_processes(exit_code), so a secondary error
        # here would otherwise reach it with SUCCESS_CODE.
        exit_code = DEFAULT_FAILURE_CODE
        failure_message = str(e)
        _files.write_failure_file(failure_message)
        logger.error(failure_message)
        # The intermediate sync is joined in `finally`; no separate join is needed here.
    except Exception as e:  # pylint: disable=broad-except
        # Same reasoning as above: never leave exit_code at SUCCESS_CODE on this path.
        exit_code = DEFAULT_FAILURE_CODE
        failure_msg = "framework error: \n%s\n%s" % (traceback.format_exc(), str(e))
        _files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")
        logger.error(failure_msg)
        error_number = getattr(e, "errno", DEFAULT_FAILURE_CODE)
        exit_code = _get_valid_failure_exit_code(error_number)
    finally:
        if intermediate_sync:
            intermediate_sync.join()
        _exit_processes(exit_code)