def test_non_write_ignored(process_mock, upload_file, inotify_mock, copy2):
    process = process_mock.return_value
    inotify = inotify_mock.return_value
    inotify.add_watch.return_value = 1

    # build a mask of every flag except CLOSE_WRITE and ISDIR
    mask = flags.CREATE
    for flag in flags:
        if flag is not flags.CLOSE_WRITE and flag is not flags.ISDIR:
            mask = mask | flag

    inotify.read.return_value = [Event(1, mask, 'cookie', 'file_name')]

    def watch():
        call = process_mock.call_args
        args, kwargs = call
        _intermediate_output._watch(kwargs['args'][0], kwargs['args'][1],
                                    kwargs['args'][2], kwargs['args'][3])

    process.start.side_effect = watch

    _files.write_success_file()
    _intermediate_output.start_sync(S3_BUCKET, REGION)

    inotify.add_watch.assert_called()
    inotify.read.assert_called()
    copy2.assert_not_called()
    upload_file.assert_not_called()
def test_new_folders_are_watched(process_mock, upload_file, inotify_mock, copy2):
    process = process_mock.return_value
    inotify = inotify_mock.return_value

    new_dir = 'new_dir'
    new_dir_path = os.path.join(_env.output_intermediate_dir, new_dir)

    inotify.add_watch.return_value = 1
    inotify.read.return_value = [Event(1, flags.CREATE | flags.ISDIR, 'cookie', new_dir)]

    def watch():
        os.makedirs(new_dir_path)
        call = process_mock.call_args
        args, kwargs = call
        _intermediate_output._watch(kwargs['args'][0], kwargs['args'][1],
                                    kwargs['args'][2], kwargs['args'][3])

    process.start.side_effect = watch

    _files.write_success_file()
    _intermediate_output.start_sync(S3_BUCKET, REGION)

    watch_flags = flags.CLOSE_WRITE | flags.CREATE
    inotify.add_watch.assert_any_call(_env.output_intermediate_dir, watch_flags)
    inotify.add_watch.assert_any_call(new_dir_path, watch_flags)
    inotify.read.assert_called()
    copy2.assert_not_called()
    upload_file.assert_not_called()
def test_modification_triggers_upload(process_mock, upload_file, inotify_mock, copy2):
    process = process_mock.return_value
    inotify = inotify_mock.return_value
    inotify.add_watch.return_value = 'wd'
    inotify.read.return_value = [Event('wd', flags.CLOSE_WRITE, 'cookie', 'file_name')]

    def watch():
        call = process_mock.call_args
        args, kwargs = call
        _intermediate_output._watch(kwargs['args'][0], kwargs['args'][1],
                                    kwargs['args'][2], kwargs['args'][3])

    process.start.side_effect = watch

    _files.write_success_file()
    _intermediate_output.start_sync(S3_BUCKET, REGION)

    inotify.add_watch.assert_called()
    inotify.read.assert_called()
    copy2.assert_called()
    upload_file.assert_called()
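# The three unit tests above receive their collaborators as pre-patched mocks:
# multiprocessing.Process, the S3 upload function, INotify, and shutil.copy2,
# plus inotify_simple's Event namedtuple and flags enum. A minimal sketch of
# the wiring they assume -- the exact patch targets and constants below are
# guesses inferred from the call sites, not the repository's own test setup:
from inotify_simple import Event, flags  # Event(wd, mask, cookie, name)
from mock import patch  # or unittest.mock on Python 3

S3_BUCKET = 's3://mybucket'  # assumed test constants
REGION = 'us-west-2'


# stacked @patch decorators feed mocks bottom-up into the test arguments
@patch('shutil.copy2')                               # -> copy2
@patch('inotify_simple.INotify')                     # -> inotify_mock
@patch('boto3.s3.transfer.S3Transfer.upload_file')   # -> upload_file
@patch('multiprocessing.Process')                    # -> process_mock
def test_example(process_mock, upload_file, inotify_mock, copy2):
    pass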
def train():
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        framework_name, entry_point_name = env.framework_module.split(':')

        try:
            framework = importlib.import_module(framework_name)
        except Exception:
            logger.error('Import failure loading %s. sys.path=%s', framework_name, sys.path)
            raise

        # the logger is configured after importing the framework library, allowing
        # the framework to configure logging at import time.
        _logging.configure_logger(env.log_level)
        logger.info('Imported framework %s', framework_name)

        entry_point = getattr(framework, entry_point_name)

        _modules.write_env_vars(env.to_env_vars())

        entry_point()
        logger.info('Reporting training SUCCESS')

        _files.write_success_file()
        _exit_processes(SUCCESS_CODE)
    except _errors.ClientError as e:
        failure_message = str(e)
        _files.write_failure_file(failure_message)

        logger.error(failure_message)
        _exit_processes(DEFAULT_FAILURE_CODE)
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))
        _files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')
        logger.error(failure_msg)

        exit_code = getattr(e, 'errno', DEFAULT_FAILURE_CODE)
        _exit_processes(exit_code)
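# Every version of train() in this excerpt finishes through _exit_processes,
# which is not shown. A minimal sketch of what it plausibly does, assuming the
# name and call sites reflect its behavior (the real helper may perform extra
# cleanup of child processes before exiting):
import sys


def _exit_processes(exit_code):  # signature assumed from the call sites
    """Terminate the training process, propagating exit_code to the platform."""
    sys.exit(exit_code)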
def test_nested_delayed_file():
    os.environ["TRAINING_JOB_NAME"] = _timestamp()
    p = _intermediate_output.start_sync(bucket_uri, region)

    dir1 = os.path.join(intermediate_path, "dir1")
    os.makedirs(dir1)
    time.sleep(3)

    dir2 = os.path.join(dir1, "dir2")
    os.makedirs(dir2)
    time.sleep(3)

    file1 = os.path.join(dir2, "file1.txt")
    write_file(file1, "file1")

    dir3 = os.path.join(intermediate_path, "dir3")
    os.makedirs(dir3)
    time.sleep(3)

    file2 = os.path.join(dir3, "file2.txt")
    write_file(file2, "file2")

    _files.write_success_file()
    p.join()

    # assert that all files that should be under intermediate are still there
    assert os.path.exists(file1)
    assert os.path.exists(file2)

    # assert the files exist in S3
    key_prefix = os.path.join(os.environ.get("TRAINING_JOB_NAME"), "output", "intermediate")
    client = boto3.client("s3", region)
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file1, intermediate_path)))
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file2, intermediate_path)))
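# The functional tests above call a few helpers this excerpt does not define.
# Minimal sketches under the semantics implied by the call sites -- the names
# match the tests, but the bodies are illustrative assumptions, and `bucket`
# is presumed to be a module-level test constant:
import time

import botocore.exceptions


def _timestamp():
    """A unique-ish training job name per test run."""
    return "test-job-%s" % time.strftime("%Y-%m-%d-%H-%M-%S")


def write_file(path, data, mode="w"):
    """Write data to path, overwriting any previous content."""
    with open(path, mode) as f:
        f.write(data)


def _file_exists_in_s3(client, key):
    """Return True if key exists in the test bucket, False otherwise."""
    try:
        client.head_object(Bucket=bucket, Key=key)
        return True
    except botocore.exceptions.ClientError:
        return False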
def train():
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        # TODO: iquintero - add error handling for ImportError to let the user know
        # if the framework module is not defined.
        env = sagemaker_containers.training_env()

        # TODO: [issue#144] There is a bug in the logic -
        # we need os.environ.get(_params.REGION_NAME_ENV)
        # in certain regions, but it is not going to be available unless
        # TrainingEnvironment has been initialized.
        # It shouldn't be an environment variable.
        region = os.environ.get('AWS_REGION', os.environ.get(_params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get('S3_ENDPOINT_URL')
        intermediate_sync = _intermediate_output.start_sync(env.sagemaker_s3_output(), region,
                                                            endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(':')

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            _logging.configure_logger(env.log_level)
            logger.info('Imported framework %s', framework_name)

            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            _logging.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(_params.MPI_ENABLED)
            runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process

            entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(),
                            env.to_env_vars(), runner=runner_type)

        logger.info('Reporting training SUCCESS')
        _files.write_success_file()
    except _errors.ClientError as e:
        failure_message = str(e)
        _files.write_failure_file(failure_message)

        logger.error(failure_message)

        if intermediate_sync:
            intermediate_sync.join()

        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:
        failure_msg = 'framework error: \n%s\n%s' % (traceback.format_exc(), str(e))
        _files.write_failure_file(failure_msg)
        logger.error('Reporting training FAILURE')
        logger.error(failure_msg)

        exit_code = getattr(e, 'errno', DEFAULT_FAILURE_CODE)
    finally:
        if intermediate_sync:
            intermediate_sync.join()

        _exit_processes(exit_code)
def test_write_success_file():
    file_path = os.path.join(_env.output_dir, 'success')
    empty_msg = ''

    _files.write_success_file()

    open.assert_called_with(file_path, 'w')
    open().write.assert_called_with(empty_msg)
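# test_write_success_file asserts against `open` directly, which only works if
# the builtin has been replaced with a mock for the whole test. One common way
# to wire that up -- an assumption about the setup, not this repo's conftest:
import os

from mock import mock_open, patch

from sagemaker_containers import _env, _files  # assumed imports


@patch('builtins.open', new_callable=mock_open)
def test_write_success_file_sketch(open):  # the mock shadows the builtin here
    _files.write_success_file()
    open.assert_called_with(os.path.join(_env.output_dir, 'success'), 'w')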
def test_intermediate_upload():
    os.environ["TRAINING_JOB_NAME"] = _timestamp()
    p = _intermediate_output.start_sync(bucket_uri, region)

    file1 = os.path.join(intermediate_path, "file1.txt")
    write_file(file1, "file1!")

    os.makedirs(os.path.join(intermediate_path, "dir1", "dir2", "dir3"))
    dir1 = os.path.join(intermediate_path, "dir1")
    dir2 = os.path.join(dir1, "dir2")
    dir3 = os.path.join(dir2, "dir3")
    file2 = os.path.join(dir1, "file2.txt")
    file3 = os.path.join(dir2, "file3.txt")
    file4 = os.path.join(dir3, "file4.txt")
    write_file(file2, "dir1_file2!")
    write_file(file3, "dir2_file3!")
    write_file(file4, "dir1_file4!")

    dir_to_delete1 = os.path.join(dir1, "dir4")
    file_to_delete1 = os.path.join(dir_to_delete1, "file_to_delete1.txt")
    os.makedirs(dir_to_delete1)
    write_file(file_to_delete1, "file_to_delete1!")
    os.remove(file_to_delete1)
    os.removedirs(dir_to_delete1)

    file_to_delete2_but_copy = os.path.join(intermediate_path, "file_to_delete2_but_copy.txt")
    write_file(file_to_delete2_but_copy, "file_to_delete2!")
    time.sleep(1)
    os.remove(file_to_delete2_but_copy)

    file_to_modify1 = os.path.join(dir3, "file_to_modify1.txt")
    write_file(file_to_modify1, "dir3_file_to_modify1_1!")
    write_file(file_to_modify1, "dir3_file_to_modify1_2!")
    write_file(file_to_modify1, "dir3_file_to_modify1_3!")
    content_to_assert = "dir3_file_to_modify1_4!"
    write_file(file_to_modify1, content_to_assert)

    # the last file to be moved
    file5 = os.path.join(intermediate_path, "file5.txt")
    write_file(file5, "file5!")

    _files.write_success_file()
    p.join()

    # shouldn't be moved
    file6 = os.path.join(intermediate_path, "file6.txt")
    write_file(file6, "file6!")

    # assert that all files that should be under intermediate are still there
    assert os.path.exists(file1)
    assert os.path.exists(file2)
    assert os.path.exists(file3)
    assert os.path.exists(file4)
    assert os.path.exists(file5)
    assert os.path.exists(file6)
    assert os.path.exists(file_to_modify1)

    # and all the deleted folders and files aren't there
    assert not os.path.exists(dir_to_delete1)
    assert not os.path.exists(file_to_delete1)
    assert not os.path.exists(file_to_delete2_but_copy)

    # assert files exist in S3
    key_prefix = os.path.join(os.environ.get("TRAINING_JOB_NAME"), "output", "intermediate")
    client = boto3.client("s3", region)
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file1, intermediate_path))
    )
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file2, intermediate_path))
    )
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file3, intermediate_path))
    )
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file4, intermediate_path))
    )
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file5, intermediate_path))
    )
    assert _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file_to_modify1, intermediate_path))
    )

    deleted_file = os.path.join(
        key_prefix, os.path.relpath(file_to_delete2_but_copy, intermediate_path)
    )
    assert _file_exists_in_s3(client, deleted_file)
    assert not _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(dir_to_delete1, intermediate_path))
    )
    assert not _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file_to_delete1, intermediate_path))
    )
    assert not _file_exists_in_s3(
        client, os.path.join(key_prefix, os.path.relpath(file6, intermediate_path))
    )

    # check that the modified file in S3 has the content written last
    s3 = boto3.resource("s3", region_name=region)
    key = os.path.join(key_prefix, os.path.relpath(file_to_modify1, intermediate_path))
    modified_file = os.path.join(_env.output_dir, "modified_file.txt")
    s3.Bucket(bucket).download_file(key, modified_file)
    with open(modified_file) as f:
        content = f.read()
    assert content == content_to_assert
def test_write_success_file(): file_path = os.path.join(_env.output_dir, "success") empty_msg = "" _files.write_success_file() open.assert_called_with(file_path, "w") open().write.assert_called_with(empty_msg)
def train(): """Placeholder docstring""" intermediate_sync = None exit_code = SUCCESS_CODE try: env = sagemaker_containers.training_env() region = os.environ.get("AWS_REGION", os.environ.get(_params.REGION_NAME_ENV)) s3_endpoint_url = os.environ.get(_params.S3_ENDPOINT_URL, None) intermediate_sync = _intermediate_output.start_sync( env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url) if env.framework_module: framework_name, entry_point_name = env.framework_module.split(":") framework = importlib.import_module(framework_name) # the logger is configured after importing the framework library, allowing # the framework to configure logging at import time. _logging.configure_logger(env.log_level) logger.info("Imported framework %s", framework_name) entrypoint = getattr(framework, entry_point_name) entrypoint() else: _logging.configure_logger(env.log_level) mpi_enabled = env.additional_framework_parameters.get( _params.MPI_ENABLED) runner_type = _runner.RunnerType.MPI if mpi_enabled else _runner.RunnerType.Process entry_point.run( env.module_dir, env.user_entry_point, env.to_cmd_args(), env.to_env_vars(), runner=runner_type, ) logger.info("Reporting training SUCCESS") _files.write_success_file() except _errors.ClientError as e: failure_message = str(e) _files.write_failure_file(failure_message) logger.error(failure_message) if intermediate_sync: intermediate_sync.join() exit_code = DEFAULT_FAILURE_CODE except Exception as e: # pylint: disable=broad-except failure_msg = "framework error: \n%s\n%s" % (traceback.format_exc(), str(e)) _files.write_failure_file(failure_msg) logger.error("Reporting training FAILURE") logger.error(failure_msg) error_number = getattr(e, "errno", DEFAULT_FAILURE_CODE) exit_code = _get_valid_failure_exit_code(error_number) finally: if intermediate_sync: intermediate_sync.join() _exit_processes(exit_code)