def test_get_channel_dir(training):
    """With ``os.path.exists`` forced to True, each channel dir gets a sub-path.

    The expected sub-paths ("blah/blah", "xxx/yyy") presumably come from the
    ``training`` fixture's configuration — TODO confirm against the fixture.
    """
    data_root = os.path.join(training, "input", "data")
    with patch('os.path.exists', return_value=True):
        environment = TrainingEnvironment(training)
        expected_training = os.path.join(data_root, "training", "blah/blah")
        expected_validation = os.path.join(data_root, "validation", "xxx/yyy")
        assert environment._get_channel_dir("training") == expected_training
        assert environment._get_channel_dir("validation") == expected_validation
def test_get_channel_dir_after_ease_fix(training):
    """With ``os.path.exists`` forced to False, channel dirs have no sub-path."""
    data_root = os.path.join(training, "input", "data")
    with patch('os.path.exists', return_value=False):
        environment = TrainingEnvironment(training)
        for channel in ("training", "validation"):
            assert environment._get_channel_dir(channel) == os.path.join(data_root, channel)
def test_pip_install_requirements_training(subprocess_call, path_exists, training):
    """``pip_install_requirements`` should run ``pip install -r`` on code/requirements.txt."""
    environment = TrainingEnvironment(training)
    path_exists.return_value = True
    environment.pip_install_requirements()

    requirements_path = os.path.join(training, 'code', 'requirements.txt')
    expected_command = ['pip', 'install', '-r', requirements_path]
    subprocess_call.assert_called_with(expected_command)
def test_download_user_module(untar, download_s3, gettemp, training):
    """The script archive is downloaded into the temp dir and untarred into ``code``."""
    local_archive = 'tmp/script.tar.gz'

    environment = TrainingEnvironment(training)
    gettemp.return_value = 'tmp'
    environment.user_script_archive = 'test.gz'
    environment.download_user_module()

    download_s3.assert_called_with('test.gz', local_archive)
    untar.assert_called_with(local_archive, os.path.join(training, 'code'))
def start(cls):
    """Run the training job, then terminate the process with its exit status.

    Configures logging, creates a ``TrainingEnvironment``, starts metrics if
    enabled, loads the framework, calls ``train()`` on it and writes the
    success file. Any exception is logged with its traceback and passed to
    ``TrainingEnvironment.write_failure_file``.

    The ``finally`` clause always calls ``os._exit``. Once the ``try`` block
    is entered, this function therefore never returns, and the re-raised
    exception never reaches a caller.

    Exit status:
        0 on success. On failure, the exception's ``errno`` when it is an int
        in 1..255; otherwise 1.
    """
    base_dir = None
    exit_code = 0
    cs.configure_logging()
    logger.info("Training starting")
    try:
        env = TrainingEnvironment()
        env.start_metrics_if_enabled()
        # Captured early so the failure file can use it if a later step fails.
        base_dir = env.base_dir
        fw = TrainingEnvironment.load_framework()
        fw.train()
        env.write_success_file()
    except Exception as e:
        # Decide the failure status before anything else. If logging or
        # writing the failure file raises below, the finally clause must
        # still exit non-zero.
        # ``errno`` may be absent or None (e.g. an OSError without an errno),
        # and os._exit requires an int. It may also be 0 or out of the
        # 8-bit exit-status range, where it could be reported as success.
        errno = getattr(e, 'errno', None)
        exit_code = errno if isinstance(errno, int) and 0 < errno < 256 else 1
        trc = traceback.format_exc()
        message = 'uncaught exception during training: {}\n{}\n'.format(
            e, trc)
        logger.error(message)
        TrainingEnvironment.write_failure_file(message, base_dir)
        raise
    finally:
        # Since threads in Python cannot be stopped, this is the only way to stop the application
        # https://stackoverflow.com/questions/9591350/what-is-difference-between-sys-exit0-and-os-exit0
        os._exit(exit_code)
def test_get_channel_dir_no_s3_uri_in_hp(training):
    """A channel with no S3 URI hyperparameter gets no sub-path.

    Only the training channel has a ``sagemaker_s3_uri_*`` entry here. So
    only the training dir is expected to carry the "blah/blah" sub-path.
    """
    hyperparameters = {
        "sagemaker_s3_uri_training": "blah/blah",
        "sagemaker_region": "us-west-2",
    }
    data_root = os.path.join(training, "input", "data")
    with patch('os.path.exists', return_value=True):
        _write_config_file(training, 'hyperparameters.json',
                           _serialize_hyperparameters(hyperparameters))
        environment = TrainingEnvironment(training)
        expected_training = os.path.join(data_root, "training", "blah/blah")
        expected_validation = os.path.join(data_root, "validation")
        assert environment._get_channel_dir("training") == expected_training
        assert environment._get_channel_dir("validation") == expected_validation
def test_training_environment_get_env_variables(training):
    """Creating the environment should expose job name and current host in ``os.environ``."""
    with patch('os.path.exists', return_value=True):
        environment = TrainingEnvironment(training)
        job_name = os.environ[ContainerEnvironment.JOB_NAME_ENV]
        assert job_name == "my_job_name"
        host = os.environ[ContainerEnvironment.CURRENT_HOST_ENV]
        assert host == environment.current_host
def test_training_job_name(training):
    """The environment built from the ``training`` fixture reports its job name."""
    expected = 'training_job_name'
    assert TrainingEnvironment(training).job_name == expected
def test_user_requirements_file_training(training):
    """The user requirements file name is exposed on the environment."""
    expected = 'requirements.txt'
    assert TrainingEnvironment(training).user_requirements_file == expected
def test_user_script_name_training(training):
    """The user script name is exposed on the environment."""
    expected = "myscript.py"
    assert TrainingEnvironment(training).user_script_name == expected
def test_user_script_archive_training(training):
    """The user script archive location is exposed on the environment."""
    expected = "s3://mybucket/code.tar.gz"
    assert TrainingEnvironment(training).user_script_archive == expected
def test_hosts_single(training):
    """The default ``training`` fixture describes a single host."""
    expected = ['algo-1']
    assert TrainingEnvironment(training).hosts == expected
def test_current_host(training):
    """The current host is read from the ``training`` fixture's resource config."""
    expected = 'algo-1'
    assert TrainingEnvironment(training).current_host == expected
def test_current_host_unset(training):
    """An empty current host in the resource config yields an empty string."""
    no_host, no_hosts = '', []
    _write_resource_config(training, no_host, no_hosts)
    assert TrainingEnvironment(training).current_host == ""
def test_channels(training):
    """Channels should match the input data config used to build the fixture."""
    channels = TrainingEnvironment(training).channels
    assert channels == INPUT_DATA_CONFIG
def test_import_user_module(import_module, training):
    """The user script is imported by its module name (without ``.py``)."""
    TrainingEnvironment(training).import_user_module()
    import_module.assert_called_with('myscript')
def test_hosts_unset(training):
    """An empty host list in the resource config yields no hosts."""
    no_host, no_hosts = '', []
    _write_resource_config(training, no_host, no_hosts)
    assert TrainingEnvironment(training).hosts == []
def test_import_user_module_without_py(import_module, training):
    """A script name with no ``.py`` suffix is imported unchanged."""
    script_name = 'nopy'
    environment = TrainingEnvironment(training)
    environment.user_script_name = script_name
    environment.import_user_module()
    import_module.assert_called_with(script_name)
def test_hosts(training):
    """All hosts written to the resource config are reported, in order."""
    expected_hosts = ['algo-1', 'algo-2', 'algo-3']
    current = expected_hosts[0]
    _write_resource_config(training, current, expected_hosts)
    assert TrainingEnvironment(training).hosts == expected_hosts