def test_train(import_module):
    """Training imports the "my_framework" module and runs its entry point."""
    fake_framework = Mock()
    import_module.return_value = fake_framework

    trainer.train()

    import_module.assert_called_with("my_framework")
    fake_framework.entry_point.assert_called()
def test_train_no_intermediate(start_intermediate_folder_sync, import_module):
    """Training runs the framework entry point without starting intermediate folder sync.

    Both parameters are mocks; presumably they are injected by patch decorators
    not visible in this chunk.
    """
    framework = Mock()
    import_module.return_value = framework

    trainer.train()

    import_module.assert_called_with("my_framework")
    framework.entry_point.assert_called()
    # This was previously misspelled ``asser_not_called()``. That name is not
    # caught by Mock's unsafe-attribute guard, so it silently created and called
    # a child mock, and the check never actually ran.
    start_intermediate_folder_sync.assert_not_called()
def test_train_fails_with_no_error_number(_exit, import_module):
    """An entry point raising an exception without errno exits with the default failure code."""

    def _raise_without_errno():
        raise Exception("No errno defined.")

    import_module.return_value = Mock(entry_point=_raise_without_errno)

    trainer.train()

    _exit.assert_called_with(trainer.DEFAULT_FAILURE_CODE)
def test_train_fails(_exit, import_module):
    """An entry point raising OSError(ENOENT) makes training exit with ENOENT."""

    def _raise_enoent():
        raise OSError(errno.ENOENT, "No such file or directory")

    failing_framework = Mock(entry_point=_raise_enoent)
    import_module.return_value = failing_framework

    trainer.train()

    expected_code = errno.ENOENT
    _exit.assert_called_with(expected_code)
def test_train_with_success(_exit, import_module):
    """An entry point that returns normally makes training exit with the success code."""

    def _noop_entry_point():
        pass

    succeeding_framework = Mock(entry_point=_noop_entry_point)
    import_module.return_value = succeeding_framework

    trainer.train()

    _exit.assert_called_with(trainer.SUCCESS_CODE)
def test_train_with_client_error(_exit, import_module):
    """A ClientError from the entry point exits with the default failure code.

    This holds even though the error is constructed with ENOENT.
    """

    def _raise_client_error():
        raise errors.ClientError(errno.ENOENT, "No such file or directory")

    failing_framework = Mock(entry_point=_raise_client_error)
    import_module.return_value = failing_framework

    trainer.train()

    _exit.assert_called_with(trainer.DEFAULT_FAILURE_CODE)
def test_train_script(_exit, training_env, run):
    """Training runs the user entry point through the MPI runner and exits successfully."""
    trainer.train()

    env = training_env()
    expected_positional = (
        env.module_dir,
        env.user_entry_point,
        env.to_cmd_args(),
        env.to_env_vars(),
    )
    run.assert_called_with(*expected_positional, runner_type=runner.RunnerType.MPI)
    _exit.assert_called_with(trainer.SUCCESS_CODE)
def test_train_fails_with_invalid_error_number(_exit, import_module):
    """An exception whose ``errno`` is not an int exits with the default failure code."""

    class InvalidErrorNumberExceptionError(Exception):
        # Accepts any arguments and deliberately sets a non-integer errno.
        # It does not call super().__init__; Exception.__new__ still records args.
        def __init__(self, *args, **kwargs):
            self.errno = "invalid"

    def _raise_bad_errno():
        raise InvalidErrorNumberExceptionError("No such file or directory")

    import_module.return_value = Mock(entry_point=_raise_bad_errno)

    trainer.train()

    _exit.assert_called_with(trainer.DEFAULT_FAILURE_CODE)
def main():
    """Entry point for running training in the container.

    Delegates entirely to ``trainer.train``.
    """
    run_training = trainer.train
    run_training()