def train():
    """Run the training job end to end and exit with its status code.

    Creates the training environment, starts the intermediate-output sync,
    and then runs either the framework entry point named by
    ``env.framework_module`` (``"<module>:<callable>"``) or, when that is
    empty, the user entry point through ``entry_point.run``. A success or
    failure file is written through ``files`` and the function finishes by
    calling ``_exit_processes`` with the resulting exit code.

    Exit code passed to ``_exit_processes``:
        * ``SUCCESS_CODE`` when nothing is raised.
        * ``DEFAULT_FAILURE_CODE`` for ``errors.ClientError``.
        * For any other exception, its ``errno`` attribute when that is a
          non-zero integer, otherwise ``DEFAULT_FAILURE_CODE``.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        # TODO: [issue#144] There is a bug in the logic -
        # we need os.environ.get(_params.REGION_NAME_ENV)
        # in certain regions, but it is not going to be available unless
        # TrainingEnvironment has been initialized. It shouldn't be environment variable.
        region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
        intermediate_sync = _intermediate_output.start_sync(env.sagemaker_s3_output(), region)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging.configure_logger(env.log_level)
            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env.to_env_vars())

        logger.info('Reporting training SUCCESS')

        files.write_success_file()
    except errors.ClientError as e:

        failure_message = str(e)
        files.write_failure_file(failure_message)

        logger.error(failure_message)

        # The intermediate sync is joined exactly once, in ``finally`` below.
        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))

        files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')

        logger.error(failure_msg)

        # ``errno`` can exist but be None (e.g. ``OSError('msg')``), be zero, or
        # be a non-integer; none of those is a usable failure exit code.
        errno_value = getattr(e, 'errno', None)
        if isinstance(errno_value, int) and errno_value != 0:
            exit_code = errno_value
        else:
            exit_code = DEFAULT_FAILURE_CODE
    finally:
        if intermediate_sync:
            intermediate_sync.join()

        _exit_processes(exit_code)
示例#2
0
def framework_training_with_script_mode_fn(capture_error):
    """Launch the user entry point described by the training environment.

    Reads the module directory, user entry point, command-line arguments and
    environment variables from ``sagemaker_containers.training_env()`` and
    hands them to ``entry_point.run``.

    Args:
        capture_error: forwarded unchanged to ``entry_point.run`` as the
            ``capture_error`` keyword argument.

    Returns:
        None. The value returned by ``entry_point.run`` is discarded.
    """
    env = sagemaker_containers.training_env()
    launch = entry_point.run

    # Positional arguments, gathered in the order ``entry_point.run`` receives them.
    positional = (
        env.module_dir,
        env.user_entry_point,
        env.to_cmd_args(),
        env.to_env_vars(),
    )
    launch(*positional, capture_error=capture_error)
示例#3
0
def train():
    """Run the training job and exit with its status code.

    Creates the training environment and runs either the framework entry
    point named by ``env.framework_module`` (``"<module>:<callable>"``) or,
    when that is empty, the user entry point through ``entry_point.run``.
    A success or failure file is written through ``files`` and every path
    ends with a call to ``_exit_processes``.

    Exit code passed to ``_exit_processes``:
        * ``SUCCESS_CODE`` when nothing is raised.
        * ``DEFAULT_FAILURE_CODE`` for ``errors.ClientError``.
        * For any other exception, its ``errno`` attribute when that is a
          non-zero integer, otherwise ``DEFAULT_FAILURE_CODE``.
    """
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging.configure_logger(env.log_level)
            entry_point.run(env.module_dir, env.user_entry_point,
                            env.to_cmd_args(), env.to_env_vars())

        logger.info('Reporting training SUCCESS')
        files.write_success_file()
        _exit_processes(SUCCESS_CODE)

    except errors.ClientError as e:

        failure_message = str(e)
        files.write_failure_file(failure_message)

        logger.error(failure_message)
        _exit_processes(DEFAULT_FAILURE_CODE)
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(),
                                                     str(e))

        files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')

        logger.error(failure_msg)

        # ``errno`` can exist but be None (e.g. ``OSError('msg')``), be zero, or
        # be a non-integer; none of those is a usable failure exit code.
        errno_value = getattr(e, 'errno', None)
        if isinstance(errno_value, int) and errno_value != 0:
            exit_code = errno_value
        else:
            exit_code = DEFAULT_FAILURE_CODE
        _exit_processes(exit_code)
示例#4
0
def test_parameter_server():
    """Check that a parameter-server user script keeps running when not waited on.

    Prepares a user module built from ``PARAMETER_SERVER_SCRIPT``, starts it
    through ``entry_point.run(..., wait=False)`` and asserts that the returned
    process has not exited. The process is killed afterwards even when the
    assertion fails, so a failing test does not leave it running.
    """
    module = test.UserModule(test.File(name='user_script.py', data=PARAMETER_SERVER_SCRIPT))
    hyperparameters = dict(sagemaker_program='user_script.py')

    test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[test.Channel.create(name='training')])
    training_env = sagemaker_containers.training_env()
    process = entry_point.run(training_env.module_dir, training_env.user_entry_point,
                              training_env.to_cmd_args(), training_env.to_env_vars(), wait=False)
    try:
        # confirm the ps process is still hanging
        assert process.poll() is None
    finally:
        # Only signal a process that is still alive, so cleanup cannot raise
        # and mask the assertion result.
        if process.poll() is None:
            process.kill()