def train():
    """Run the training job, record its outcome, and exit the process.

    Builds the training environment and starts the intermediate-output sync
    (``_intermediate_output.start_sync``) for ``env.sagemaker_s3_output()``.
    It then does one of two things:

    * If ``env.framework_module`` is set (``"<module>:<function>"``), it
      imports that module and calls the named function.
    * Otherwise it runs the user entry point through ``entry_point.run``.

    On success a success file is written. On failure a failure file holding
    the error text is written. The sync, if one was started, is joined once
    in ``finally``. The resulting exit code is then handed to
    ``_exit_processes``.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        # TODO: [issue#144] There is a bug in the logic -
        # we need os.environ.get(_params.REGION_NAME_ENV)
        # in certain regions, but it is not going to be available unless
        # TrainingEnvironment has been initialized. It shouldn't be environment variable.
        region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
        intermediate_sync = _intermediate_output.start_sync(env.sagemaker_s3_output(), region)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')
            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging.configure_logger(env.log_level)
            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env.to_env_vars())

        logger.info('Reporting training SUCCESS')
        files.write_success_file()
    except errors.ClientError as e:
        failure_message = str(e)
        files.write_failure_file(failure_message)
        logger.error(failure_message)
        # The intermediate sync is joined exactly once, in ``finally`` below.
        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))
        files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')
        logger.error(failure_msg)
        # Prefer the exception's errno as the exit code. ``getattr`` with a
        # default is not enough: the attribute can exist but be None
        # (e.g. OSError('msg')) or 0, and a falsy exit code would not signal
        # failure.
        error_number = getattr(e, 'errno', None)
        exit_code = error_number if error_number else DEFAULT_FAILURE_CODE
    finally:
        # NOTE(review): a BaseException that is not an Exception (e.g.
        # SystemExit, KeyboardInterrupt) raised by the entry point skips both
        # handlers, so exit_code is still SUCCESS_CODE here. Confirm this is
        # intended.
        if intermediate_sync:
            intermediate_sync.join()
        _exit_processes(exit_code)
def framework_training_with_script_mode_fn(capture_error):
    """Run the current training environment's user entry point in script mode.

    Args:
        capture_error: forwarded unchanged to ``entry_point.run`` as its
            ``capture_error`` keyword argument.
    """
    env = sagemaker_containers.training_env()

    # Gather the run arguments in the same order the call evaluated them.
    module_dir = env.module_dir
    user_script = env.user_entry_point
    cmd_args = env.to_cmd_args()
    env_vars = env.to_env_vars()

    entry_point.run(module_dir, user_script, cmd_args, env_vars, capture_error=capture_error)
def train():
    """Run the training job, record its outcome, and exit the process.

    If ``env.framework_module`` is set (``"<module>:<function>"``), imports
    that module and calls the named function. Otherwise runs the user entry
    point through ``entry_point.run``.

    On success a success file is written and ``SUCCESS_CODE`` is used. On a
    ``errors.ClientError`` the error text goes to the failure file and
    ``DEFAULT_FAILURE_CODE`` is used. On any other ``Exception`` the
    traceback goes to the failure file, and the exception's non-zero
    ``errno`` is used when present, else ``DEFAULT_FAILURE_CODE``.

    The chosen code is handed to ``_exit_processes`` exactly once, outside
    the ``try``. An error raised while exiting therefore cannot be mistaken
    for a training failure after the success file was written.
    """
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')
            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing the framework to
            # configure logging at import time.
            logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging.configure_logger(env.log_level)
            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env.to_env_vars())

        logger.info('Reporting training SUCCESS')
        files.write_success_file()
        exit_code = SUCCESS_CODE
    except errors.ClientError as e:
        failure_message = str(e)
        files.write_failure_file(failure_message)
        logger.error(failure_message)
        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))
        files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')
        logger.error(failure_msg)
        # The errno attribute can exist but be None (e.g. OSError('msg')) or 0.
        # A falsy exit code would not signal failure, so fall back to the
        # default code in that case.
        error_number = getattr(e, 'errno', None)
        exit_code = error_number if error_number else DEFAULT_FAILURE_CODE

    _exit_processes(exit_code)
def test_parameter_server():
    """A parameter-server style user script must still be running right after launch.

    The script is registered as the training program and started with
    ``wait=False``. Immediately afterwards ``poll()`` must return ``None``.
    The launched process is cleaned up even when the assertion fails.
    """
    module = test.UserModule(test.File(name='user_script.py', data=PARAMETER_SERVER_SCRIPT))
    hyperparameters = dict(sagemaker_program='user_script.py')

    test.prepare(user_module=module, hyperparameters=hyperparameters,
                 channels=[test.Channel.create(name='training')])

    training_env = sagemaker_containers.training_env()
    process = entry_point.run(training_env.module_dir, training_env.user_entry_point,
                              training_env.to_cmd_args(), training_env.to_env_vars(), wait=False)
    try:
        # confirm the ps process is still hanging
        assert process.poll() is None
    finally:
        # Always clean up so a failing assertion does not leak a hanging
        # process. Only signal a process that is still running: signalling an
        # already-reaped pid can raise and would mask the assertion error.
        # wait() reaps the killed child. This assumes a subprocess.Popen-like
        # object, which the poll()/kill() calls match — confirm.
        if process.poll() is None:
            process.kill()
            process.wait()