Пример #1
0
def test_run_module_with_env_vars(chmod, download_and_extract, get_runner, sys_path):
    """Check that ``env_vars`` given to ``entry_point.run`` reach ``get_runner``.

    The runner factory is expected to receive the supplied variables plus an
    empty ``PYTHONPATH`` entry, and ``None`` for the extra options.
    """
    user_module = 'default_user_module_name'
    cli_args = ['--some-arg', '42']

    entry_point.run(
        uri='s3://url',
        user_entry_point=user_module,
        args=cli_args,
        env_vars={'FOO': 'BAR'},
    )

    get_runner.assert_called_with(
        _runner.ProcessRunnerType,
        user_module,
        cli_args,
        {'FOO': 'BAR', 'PYTHONPATH': ''},
        None,
    )
Пример #2
0
def test_run_module_with_extra_opts(chmod, download_and_extract, get_runner, sys_path):
    """Check that ``extra_opts`` are handed to ``get_runner`` as given.

    No ``env_vars`` are supplied, so the factory is expected to receive an
    empty dict in that position.
    """
    user_module = 'default_user_module_name'
    cli_args = ['--some-arg', '42']
    runner_opts = {'foo': 'bar'}

    entry_point.run(
        uri='s3://url',
        user_entry_point=user_module,
        args=cli_args,
        extra_opts=runner_opts,
    )

    expected_env_vars = {}
    get_runner.assert_called_with(
        _runner.ProcessRunnerType, user_module, cli_args, expected_env_vars, runner_opts
    )
Пример #3
0
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """Check that running an entry point looks up both ``algo-1`` and ``algo-2``.

    ``algo-2`` must be the most recent lookup; ``algo-1`` may appear anywhere
    in the call history.
    """
    process_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri='s3://url',
        user_entry_point='launcher.py',
        args=['42'],
        runner=process_runner,
    )

    latest_host, earlier_host = 'algo-2', 'algo-1'
    gethostbyname.assert_called_with(latest_host)
    gethostbyname.assert_any_call(earlier_host)
Пример #4
0
def test_run_module_no_wait(chmod, download_and_extract):
    """Check that ``wait=False`` makes the runner execute as ``run(False, False)``."""
    process_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri='s3://url',
        user_entry_point='default_user_module_name',
        args=['42'],
        wait=False,
        runner=process_runner,
    )

    process_runner.run.assert_called_with(False, False)
Пример #5
0
def test_run_module_wait(chmod, download_and_extract):
    """Check the default (waiting) run of a shell entry point with error capture.

    Verifies that the code is downloaded into ``_env.code_dir``, that the runner
    is invoked as ``run(True, True)``, and that the script gets mode ``0o777``.
    """
    script_name = 'launcher.sh'
    process_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri='s3://url',
        user_entry_point=script_name,
        args=['42'],
        capture_error=True,
        runner=process_runner,
    )

    download_and_extract.assert_called_with('s3://url', _env.code_dir)
    process_runner.run.assert_called_with(True, True)

    script_path = os.path.join(_env.code_dir, script_name)
    chmod.assert_called_with(script_path, 0o777)  # 0o777 == 511
Пример #6
0
def test_run_waits_hostname_resolution(gethostbyname, hosts, install, download_and_extract):
    """Check that a failing hostname lookup is retried until it succeeds.

    The mocked resolver raises ``ValueError`` twice before returning, so
    ``algo-1`` is expected to be looked up three times, followed by ``algo-2``.
    """
    lookup_outcomes = [ValueError(), ValueError(), True, True]
    gethostbyname.side_effect = lookup_outcomes

    process_runner = MagicMock(spec=_process.ProcessRunner)
    entry_point.run(
        uri='s3://url',
        user_entry_point='launcher.py',
        args=['42'],
        runner=process_runner,
    )

    expected_lookups = [
        call('algo-1'),
        call('algo-1'),
        call('algo-1'),
        call('algo-2'),
    ]
    gethostbyname.assert_has_calls(expected_lookups)
Пример #7
0
def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract):
    """Check that running an entry point looks up both ``algo-1`` and ``algo-2``.

    ``algo-2`` must be the most recent lookup; ``algo-1`` may appear anywhere
    in the call history.
    """
    process_runner = MagicMock(spec=_process.ProcessRunner)
    run_kwargs = {
        "uri": "s3://url",
        "user_entry_point": "launcher.py",
        "args": ["42"],
        "runner": process_runner,
    }

    entry_point.run(**run_kwargs)

    gethostbyname.assert_called_with("algo-2")
    gethostbyname.assert_any_call("algo-1")
Пример #8
0
def test_run_module_no_wait(gethostbyname, chmod, download_and_extract):
    """Check that ``wait=False`` makes the runner execute as ``run(False, False)``."""
    process_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri="s3://url",
        user_entry_point="default_user_module_name",
        args=["42"],
        wait=False,
        runner=process_runner,
    )

    expected_wait, expected_capture_error = False, False
    process_runner.run.assert_called_with(expected_wait, expected_capture_error)
Пример #9
0
def test_run_module_with_extra_opts(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """Check that ``extra_opts`` are handed to ``get_runner`` as given.

    No ``env_vars`` are supplied, so the factory is expected to receive an
    empty dict in that position.
    """
    user_module = "default_user_module_name"
    cli_args = ["--some-arg", "42"]
    runner_opts = {"foo": "bar"}

    entry_point.run(
        uri="s3://url", user_entry_point=user_module, args=cli_args, extra_opts=runner_opts
    )

    expected_call_args = (_runner.ProcessRunnerType, user_module, cli_args, {}, runner_opts)
    get_runner.assert_called_with(*expected_call_args)
Пример #10
0
def test_run_module_with_env_vars(gethostbyname, chmod, download_and_extract, get_runner, sys_path):
    """Check that ``env_vars`` given to ``entry_point.run`` reach ``get_runner``.

    The runner factory is expected to receive the supplied variables plus an
    empty ``PYTHONPATH`` entry, and ``None`` for the extra options.
    """
    user_module = "default_user_module_name"
    cli_args = ["--some-arg", "42"]

    entry_point.run(
        uri="s3://url", user_entry_point=user_module, args=cli_args, env_vars={"FOO": "BAR"}
    )

    get_runner.assert_called_with(
        _runner.ProcessRunnerType,
        user_module,
        cli_args,
        {"FOO": "BAR", "PYTHONPATH": ""},
        None,
    )
Пример #11
0
def test_run_module_wait(gethostbyname, check_error, chmod, download_and_extract):
    """Check the default (waiting) run of a shell entry point with error capture.

    Verifies that the code is downloaded into ``_env.code_dir``, that the runner
    is invoked as ``run(True, True)``, and that the script gets mode ``0o777``.
    """
    script_name = "launcher.sh"
    process_runner = MagicMock(spec=_process.ProcessRunner)

    entry_point.run(
        uri="s3://url", user_entry_point=script_name, args=["42"],
        capture_error=True, runner=process_runner,
    )

    download_and_extract.assert_called_with("s3://url", _env.code_dir)
    process_runner.run.assert_called_with(True, True)

    script_path = os.path.join(_env.code_dir, script_name)
    chmod.assert_called_with(script_path, 0o777)  # 0o777 == 511
Пример #12
0
def train():
    """Run a training job end to end and terminate with its exit code.

    Steps, as implemented below:

    1. Build the training environment and start syncing intermediate output to
       the location returned by ``env.sagemaker_s3_output()``.
    2. If ``env.framework_module`` is set (``"<module>:<callable>"``), import
       the module and invoke the callable. Otherwise run the user entry point
       with the MPI runner when MPI is enabled, or the process runner if not.
    3. Write the success file, or on error write the failure file and log it.
    4. Always join the intermediate-output sync (when it was started) and call
       ``_exit_processes`` with the resulting exit code.

    Exit codes: ``SUCCESS_CODE`` on success; ``DEFAULT_FAILURE_CODE`` for
    ``_errors.ClientError``; for any other exception, its ``errno`` when that
    converts to an integer, otherwise ``DEFAULT_FAILURE_CODE``.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        # TODO: [issue#144] There is a bug in the logic -
        # we need os.environ.get(_params.REGION_NAME_ENV)
        # in certain regions, but it is not going to be available unless
        # TrainingEnvironment has been initialized. It shouldn't be environment variable.
        region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get("S3_ENDPOINT_URL")
        intermediate_sync = _intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            _logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            _logging.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(_params.MPI_ENABLED)
            runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process

            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(),
                            env.to_env_vars(), runner=runner_type)

        logger.info('Reporting training SUCCESS')

        _files.write_success_file()
    except _errors.ClientError as e:

        failure_message = str(e)
        _files.write_failure_file(failure_message)

        logger.error(failure_message)

        # NOTE(review): the sync is joined again in ``finally``; this earlier
        # join looks redundant - confirm before removing.
        if intermediate_sync:
            intermediate_sync.join()

        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))

        _files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')

        logger.error(failure_msg)

        # ``errno`` exists on every OSError/IOError but is None unless the
        # exception was created with an explicit error number, so the getattr
        # default alone does not guarantee an integer exit code.
        error_number = getattr(e, 'errno', DEFAULT_FAILURE_CODE)
        try:
            exit_code = int(error_number)
        except (TypeError, ValueError):
            exit_code = DEFAULT_FAILURE_CODE
    finally:
        if intermediate_sync:
            intermediate_sync.join()

        _exit_processes(exit_code)
def test_run_module_no_wait(call, download_and_extract, entry_point_type_module):
    """Check that ``entry_point.run`` raises ``InstallModuleError`` here.

    NOTE(review): the two mock assertions sit inside the ``pytest.raises``
    block, after the call that is expected to raise, so they are skipped
    whenever that exception occurs. Confirm whether they are meant to run (and
    would pass) outside the block.
    """
    entry_point_name = 'default_user_module_name'

    with pytest.raises(_errors.InstallModuleError):
        entry_point.run(
            uri='s3://url',
            user_entry_point=entry_point_name,
            args=['42'],
            wait=False,
        )

        download_and_extract.assert_called_with('s3://url', entry_point_name, _env.code_dir)
        call.assert_called_with(entry_point_name, ['42'], {}, False)
def test_run_module_wait(chmod, call, download_and_extract):
    """Check a waiting run of a shell entry point with error capture enabled.

    Verifies the download arguments, that ``call`` receives the script, its
    args, an empty env-var dict and ``True, True``, and that the script gets
    mode ``0o777``.
    """
    script_name = 'launcher.sh'

    entry_point.run(
        uri='s3://url',
        user_entry_point=script_name,
        args=['42'],
        capture_error=True,
    )

    download_and_extract.assert_called_with('s3://url', script_name, _env.code_dir)
    call.assert_called_with(script_name, ['42'], {}, True, True)

    script_path = os.path.join(_env.code_dir, script_name)
    chmod.assert_called_with(script_path, 0o777)  # 0o777 == 511
Пример #15
0
def train():
    """Run a training job end to end and terminate with its exit code.

    Builds the training environment, starts syncing intermediate output, then
    either invokes the callable named by ``env.framework_module``
    (``"<module>:<callable>"``) or runs the user entry point with an MPI or
    process runner. A success or failure file is written accordingly, the
    intermediate sync is joined when it was started, and ``_exit_processes`` is
    always called with the final exit code.
    """
    sync_worker = None
    exit_code = SUCCESS_CODE
    try:
        env = sagemaker_containers.training_env()

        fallback_region = os.environ.get(_params.REGION_NAME_ENV)
        region = os.environ.get("AWS_REGION", fallback_region)
        s3_endpoint_url = os.environ.get(_params.S3_ENDPOINT_URL, None)
        s3_output = env.sagemaker_s3_output()
        sync_worker = _intermediate_output.start_sync(
            s3_output, region, endpoint_url=s3_endpoint_url
        )

        if env.framework_module:
            module_path, callable_name = env.framework_module.split(":")
            framework = importlib.import_module(module_path)

            # Logging is configured only after the framework import so the
            # framework can set up logging itself at import time.
            _logging.configure_logger(env.log_level)
            logger.info("Imported framework %s", module_path)

            getattr(framework, callable_name)()
        else:
            _logging.configure_logger(env.log_level)

            if env.additional_framework_parameters.get(_params.MPI_ENABLED):
                runner_type = _runner.RunnerType.MPI
            else:
                runner_type = _runner.RunnerType.Process

            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner=runner_type,
            )

        logger.info("Reporting training SUCCESS")
        _files.write_success_file()
    except _errors.ClientError as client_error:
        failure_message = str(client_error)
        _files.write_failure_file(failure_message)
        logger.error(failure_message)

        if sync_worker:
            sync_worker.join()

        exit_code = DEFAULT_FAILURE_CODE
    except Exception as framework_error:  # pylint: disable=broad-except
        trace_text = traceback.format_exc()
        failure_msg = "framework error: \n%s\n%s" % (trace_text, str(framework_error))

        _files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")
        logger.error(failure_msg)

        exit_code = _get_valid_failure_exit_code(
            getattr(framework_error, "errno", DEFAULT_FAILURE_CODE)
        )
    finally:
        if sync_worker:
            sync_worker.join()

        _exit_processes(exit_code)