def test_main_tuning_mpi_model_dir(
    configure_s3_env,
    read_hyperparameters,
    training_env,
    set_level,
    train,
    logger,
    single_machine_training_env,
):
    """Verify the model dir passed to ``configure_s3_env`` is ``/opt/ml/model``.

    Runs ``training.main()`` with ``training_env`` returning
    ``single_machine_training_env`` and ``SAGEMAKER_REGION`` set to
    ``REGION``, then expects exactly one ``configure_s3_env`` call with
    ``('/opt/ml/model', REGION)``.
    """
    # NOTE(review): the hyperparameters that select the tuning + MPI path are
    # presumably supplied by patch decorators outside this view - confirm.
    expected_model_dir = '/opt/ml/model'

    training_env.return_value = single_machine_training_env
    os.environ.update({'SAGEMAKER_REGION': REGION})

    training.main()

    configure_s3_env.assert_called_once_with(expected_model_dir, REGION)
def test_main_tuning_model_dir(
    configure_s3_env,
    read_hyperparameters,
    training_env,
    set_level,
    train,
    logger,
    single_machine_training_env,
):
    """Verify the model dir passed to ``configure_s3_env`` includes the job name.

    Runs ``training.main()`` with ``training_env`` returning
    ``single_machine_training_env`` and ``SAGEMAKER_REGION`` set to
    ``REGION``, then expects exactly one ``configure_s3_env`` call whose
    first argument is ``'<MODEL_DIR>/<job_name>/model'`` and whose second
    argument is ``REGION``.
    """
    # NOTE(review): the hyperparameters that select the tuning path are
    # presumably supplied by patch decorators outside this view - confirm.
    training_env.return_value = single_machine_training_env
    os.environ.update({'SAGEMAKER_REGION': REGION})

    training.main()

    # Read the job name only after main() has run, as the original test did.
    job_name = single_machine_training_env.job_name
    expected_model_dir = '{}/{}/model'.format(MODEL_DIR, job_name)
    configure_s3_env.assert_called_once_with(expected_model_dir, REGION)
def test_main(
    configure_s3_env,
    read_hyperparameters,
    training_env,
    set_level,
    train,
    logger,
    single_machine_training_env,
):
    """Verify the calls ``training.main()`` makes to its patched collaborators.

    With ``training_env`` returning ``single_machine_training_env`` and
    ``SAGEMAKER_REGION`` set to ``REGION``, a single ``training.main()`` run
    is expected to:

    * call ``read_hyperparameters`` once with no arguments,
    * call ``training_env`` once with ``hyperparameters={}``,
    * call ``train`` once with the environment and ``MODEL_DIR_CMD_LIST``,
    * call ``configure_s3_env`` exactly once (arguments not checked here).
    """
    env = single_machine_training_env
    training_env.return_value = env
    os.environ.update({'SAGEMAKER_REGION': REGION})

    training.main()

    read_hyperparameters.assert_called_once_with()
    training_env.assert_called_once_with(hyperparameters={})
    train.assert_called_once_with(env, MODEL_DIR_CMD_LIST)
    configure_s3_env.assert_called_once()
def test_main_simple_training_model_dir(
    configure_s3_env,
    read_hyperparameters,
    training_env,
    set_level,
    train,
    logger,
    single_machine_training_env,
):
    """Verify ``configure_s3_env`` receives ``MODEL_DIR`` unchanged.

    Runs ``training.main()`` with ``training_env`` returning
    ``single_machine_training_env`` and ``SAGEMAKER_REGION`` set to
    ``REGION``, then expects exactly one ``configure_s3_env`` call with
    ``(MODEL_DIR, REGION)``.
    """
    expected_call_args = (MODEL_DIR, REGION)

    training_env.return_value = single_machine_training_env
    os.environ.update({"SAGEMAKER_REGION": REGION})

    training.main()

    configure_s3_env.assert_called_once_with(*expected_call_args)