def test_get_channel_dir_after_ease_fix(training):
    """Channel dirs resolve to the bare ``<training>/input/data/<channel>``.

    ``os.path.exists`` is forced to return False for the duration of the
    test, and ``_get_channel_dir`` is expected to return the plain channel
    directory with no extra sub-path appended, for both the "training" and
    "validation" channels.
    """
    data_root = os.path.join(training, "input", "data")
    with patch('os.path.exists', return_value=False):
        env = TrainingEnvironment(training)
        for channel in ("training", "validation"):
            assert env._get_channel_dir(channel) == os.path.join(
                data_root, channel)
def test_get_channel_dir(training):
    """Channel dirs get a per-channel sub-path appended when paths "exist".

    ``os.path.exists`` is forced to return True. The expected suffixes
    ("blah/blah" for training, "xxx/yyy" for validation) are not set up
    inside this test. Presumably they come from hyperparameters written by
    the ``training`` fixture (TODO confirm against the fixture).
    """
    data_root = os.path.join(training, "input", "data")
    expected_dirs = {
        "training": os.path.join(data_root, "training", "blah/blah"),
        "validation": os.path.join(data_root, "validation", "xxx/yyy"),
    }
    with patch('os.path.exists', return_value=True):
        env = TrainingEnvironment(training)
        for channel, expected in expected_dirs.items():
            assert env._get_channel_dir(channel) == expected
def test_get_channel_dir_no_s3_uri_in_hp(training):
    """A channel without an S3-URI hyperparameter gets no sub-path.

    Writes ``hyperparameters.json`` that declares
    ``sagemaker_s3_uri_training`` but has no entry for the validation
    channel. With ``os.path.exists`` forced to return True, the test checks
    that the training channel dir has "blah/blah" appended, while the
    validation channel dir is the bare ``<training>/input/data/validation``.
    """
    hyperparameters = {
        "sagemaker_s3_uri_training": "blah/blah",
        "sagemaker_region": "us-west-2",
    }
    data_root = os.path.join(training, "input", "data")
    with patch('os.path.exists', return_value=True):
        # The config write stays inside the patch, matching the original ordering.
        _write_config_file(
            training,
            'hyperparameters.json',
            _serialize_hyperparameters(hyperparameters),
        )
        env = TrainingEnvironment(training)
        assert env._get_channel_dir("training") == os.path.join(
            data_root, "training", "blah/blah")
        assert env._get_channel_dir("validation") == os.path.join(
            data_root, "validation")