コード例 #1
0
def test_get_channel_dir(training):
    """Channel dirs gain a per-channel sub-path when os.path.exists is True.

    The sub-paths ("blah/blah", "xxx/yyy") presumably come from the
    ``training`` fixture's hyperparameters — confirm against the fixture.
    """
    with patch('os.path.exists', return_value=True):
        env = TrainingEnvironment(training)
        data_dir = os.path.join(training, "input", "data")
        expected_dirs = {
            "training": os.path.join(data_dir, "training", "blah/blah"),
            "validation": os.path.join(data_dir, "validation", "xxx/yyy"),
        }
        for channel, channel_dir in expected_dirs.items():
            assert env._get_channel_dir(channel) == channel_dir
コード例 #2
0
def test_get_channel_dir_after_ease_fix(training):
    """Channel dirs fall back to <training>/input/data/<channel>.

    This happens when os.path.exists reports False, i.e. the nested sub-path
    is not on disk.
    """
    with patch('os.path.exists', return_value=False):
        env = TrainingEnvironment(training)
        for channel in ("training", "validation"):
            bare_dir = os.path.join(training, "input", "data", channel)
            assert env._get_channel_dir(channel) == bare_dir
コード例 #3
0
def test_pip_install_requirements_training(subprocess_call, path_exists,
                                           training):
    """pip installs from the requirements file under <training>/code."""
    environment = TrainingEnvironment(training)
    # Only flip the exists() mock after construction, as the original did.
    path_exists.return_value = True

    environment.pip_install_requirements()

    requirements_path = os.path.join(training, 'code', 'requirements.txt')
    expected_command = ['pip', 'install', '-r', requirements_path]
    subprocess_call.assert_called_with(expected_command)
コード例 #4
0
def test_download_user_module(untar, download_s3, gettemp, training):
    """The user archive is downloaded under the mocked temp location.

    It is then untarred into <training>/code.
    """
    environment = TrainingEnvironment(training)
    gettemp.return_value = 'tmp'
    environment.user_script_archive = 'test.gz'

    environment.download_user_module()

    local_tarball = 'tmp/script.tar.gz'
    code_dir = os.path.join(training, 'code')
    download_s3.assert_called_with('test.gz', local_tarball)
    untar.assert_called_with(local_tarball, code_dir)
コード例 #5
0
    def start(cls):
        """Run training through the loaded framework, then terminate the process.

        Steps:
        - Configure logging and build a ``TrainingEnvironment``.
        - Start metrics if enabled.
        - Load the framework and call its ``train()``.
        - On success, write the success file. On any exception, log the
          formatted traceback and write it to the failure file.

        This never returns normally: the ``finally`` clause always calls
        ``os._exit``. The exit status is 0 on success and non-zero on failure.
        The re-raised exception therefore only reaches a caller when
        ``os._exit`` has been replaced (e.g. in tests).
        """
        base_dir = None
        exit_code = 0
        cs.configure_logging()
        logger.info("Training starting")
        try:
            env = TrainingEnvironment()
            env.start_metrics_if_enabled()
            base_dir = env.base_dir

            fw = TrainingEnvironment.load_framework()
            fw.train()
            env.write_success_file()
        except Exception as e:
            # Pick the failure status *before* any further I/O. If writing the
            # failure file itself raises, the process must still exit non-zero
            # instead of reporting success via os._exit(0).
            errno = getattr(e, 'errno', None)
            # Some exceptions carry errno=None (e.g. OSError("msg")), which
            # os._exit() rejects with TypeError. errno 0 would signal success,
            # and values >= 256 are truncated modulo 256 by the OS.
            if isinstance(errno, int) and 0 < errno < 256:
                exit_code = errno
            else:
                exit_code = 1
            trc = traceback.format_exc()
            message = 'uncaught exception during training: {}\n{}\n'.format(
                e, trc)
            logger.error(message)
            TrainingEnvironment.write_failure_file(message, base_dir)
            raise
        finally:
            # Since threads in Python cannot be stopped, this is the only way to stop the application
            # https://stackoverflow.com/questions/9591350/what-is-difference-between-sys-exit0-and-os-exit0
            os._exit(exit_code)
コード例 #6
0
def test_get_channel_dir_no_s3_uri_in_hp(training):
    """Only channels with a sagemaker_s3_uri_<channel> hyperparameter get a sub-path."""
    hyperparameters = {
        "sagemaker_s3_uri_training": "blah/blah",
        "sagemaker_region": "us-west-2",
    }
    with patch('os.path.exists', return_value=True):
        serialized = _serialize_hyperparameters(hyperparameters)
        _write_config_file(training, 'hyperparameters.json', serialized)
        env = TrainingEnvironment(training)
        data_dir = os.path.join(training, "input", "data")
        training_dir = os.path.join(data_dir, "training", "blah/blah")
        validation_dir = os.path.join(data_dir, "validation")
        assert env._get_channel_dir("training") == training_dir
        assert env._get_channel_dir("validation") == validation_dir
コード例 #7
0
def test_training_environment_get_env_variables(training):
    """After construction, os.environ holds the job name and the current host."""
    with patch('os.path.exists', return_value=True):
        environment = TrainingEnvironment(training)
        job_name = os.environ[ContainerEnvironment.JOB_NAME_ENV]
        assert job_name == "my_job_name"
        exported_host = os.environ[ContainerEnvironment.CURRENT_HOST_ENV]
        assert exported_host == environment.current_host
コード例 #8
0
def test_training_job_name(training):
    """The environment exposes the fixture's training job name."""
    expected_name = 'training_job_name'
    assert TrainingEnvironment(training).job_name == expected_name
コード例 #9
0
def test_user_requirements_file_training(training):
    """The user requirements file name comes through unchanged."""
    requirements = TrainingEnvironment(training).user_requirements_file
    assert requirements == 'requirements.txt'
コード例 #10
0
def test_user_script_name_training(training):
    """The user script name (including its .py suffix) is exposed."""
    script_name = TrainingEnvironment(training).user_script_name
    assert script_name == "myscript.py"
コード例 #11
0
def test_user_script_archive_training(training):
    """The user script archive location is exposed as the full S3 URI."""
    archive_uri = TrainingEnvironment(training).user_script_archive
    assert archive_uri == "s3://mybucket/code.tar.gz"
コード例 #12
0
def test_hosts_single(training):
    """With the default fixture, the host list is the single host algo-1."""
    expected_hosts = ['algo-1']
    assert TrainingEnvironment(training).hosts == expected_hosts
コード例 #13
0
def test_current_host(training):
    """With the default fixture, the current host is algo-1."""
    host = TrainingEnvironment(training).current_host
    assert host == 'algo-1'
コード例 #14
0
def test_current_host_unset(training):
    """current_host is the empty string for a resource config written with '' and no hosts."""
    _write_resource_config(training, '', [])
    host = TrainingEnvironment(training).current_host
    assert host == ""
コード例 #15
0
def test_channels(training):
    """The channel configuration matches INPUT_DATA_CONFIG."""
    channels = TrainingEnvironment(training).channels
    assert channels == INPUT_DATA_CONFIG
コード例 #16
0
def test_import_user_module(import_module, training):
    """The user script is imported by its name without the .py suffix."""
    environment = TrainingEnvironment(training)
    environment.import_user_module()
    expected_module = 'myscript'
    import_module.assert_called_with(expected_module)
コード例 #17
0
def test_hosts_unset(training):
    """An empty host list in the resource config yields an empty hosts list."""
    _write_resource_config(training, '', [])
    hosts = TrainingEnvironment(training).hosts
    assert hosts == []
コード例 #18
0
def test_import_user_module_without_py(import_module, training):
    """A script name lacking a .py suffix is imported as-is."""
    script = 'nopy'
    environment = TrainingEnvironment(training)
    environment.user_script_name = script
    environment.import_user_module()
    import_module.assert_called_with(script)
コード例 #19
0
def test_hosts(training):
    """Every host written to the resource config is exposed, in order."""
    expected_hosts = ['algo-1', 'algo-2', 'algo-3']
    _write_resource_config(training, 'algo-1', expected_hosts)
    assert TrainingEnvironment(training).hosts == expected_hosts