def test_task_run(self) -> None:
    """Test creation and querying of task_runs.

    Creates one task run for a test task/requester pair. Then checks that:

    - it can be fetched by id with ``init_params`` unchanged;
    - ``find_task_runs`` returns it unfiltered and when filtered by its task
      or requester;
    - filters on ids that don't exist return nothing;
    - ``update_task_run`` moves it into the ``is_completed=True`` result set.
    """
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db
    _task_name, task_id = get_test_task(db)
    _requester_name, requester_id = get_test_requester(db)

    # Check creation and retrieval of a task_run
    init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
    task_run_id = db.new_task_run(task_id, requester_id, init_params, "mock", "mock")
    self.assertIsNotNone(task_run_id)
    self.assertIsInstance(task_run_id, str)
    task_run_row = db.get_task_run(task_run_id)
    self.assertEqual(task_run_row["init_params"], init_params)
    task_run = TaskRun(db, task_run_id)
    self.assertEqual(task_run.task_id, task_id)

    # Each of these queries (unfiltered, by task, by requester) should return
    # exactly the run created above, with its ids intact.
    for filters in ({}, {"task_id": task_id}, {"requester_id": requester_id}):
        context = f"find_task_runs(**{filters!r})"
        task_runs = db.find_task_runs(**filters)
        self.assertEqual(len(task_runs), 1, msg=context)
        found = task_runs[0]
        self.assertIsInstance(found, TaskRun, msg=context)
        self.assertEqual(found.db_id, task_run_id, msg=context)
        self.assertEqual(found.task_id, task_id, msg=context)
        self.assertEqual(found.requester_id, requester_id, msg=context)

    # Filters on ids that don't exist match nothing. The fake id kind matches
    # the column being filtered (a Task id for task_id, a Requester id for
    # requester_id).
    task_runs = db.find_task_runs(task_id=self.get_fake_id("Task"))
    self.assertEqual(len(task_runs), 0)
    task_runs = db.find_task_runs(requester_id=self.get_fake_id("Requester"))
    self.assertEqual(len(task_runs), 0)

    # Nothing has been marked completed yet
    task_runs = db.find_task_runs(is_completed=True)
    self.assertEqual(len(task_runs), 0)

    # Test updating the completion status, requery
    db.update_task_run(task_run_id, True)
    task_runs = db.find_task_runs(is_completed=True)
    self.assertEqual(len(task_runs), 1)
    self.assertIsInstance(task_runs[0], TaskRun)
    self.assertEqual(task_runs[0].db_id, task_run_id)
    self.assertEqual(task_runs[0].task_id, task_id)
def test_task_run_fails(self) -> None:
    """Ensure task_runs fail to be created or loaded under failure conditions.

    ``new_task_run`` must raise ``EntryDoesNotExistException`` for an unknown
    task id and for an unknown requester id. Neither failed attempt may leave
    a task run behind.
    """
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db
    _task_name, task_id = get_test_task(db)
    _requester_name, requester_id = get_test_requester(db)
    # Serialize init_params exactly as test_task_run and
    # test_update_task_failures do when they create a task run successfully.
    # That way each call below differs from a valid call only in the id under
    # test.
    init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))

    # Can't create task run with invalid ids
    with self.assertRaises(EntryDoesNotExistException):
        db.new_task_run(
            self.get_fake_id("Task"), requester_id, init_params, "mock", "mock"
        )
    with self.assertRaises(EntryDoesNotExistException):
        db.new_task_run(
            task_id, self.get_fake_id("Requester"), init_params, "mock", "mock"
        )

    # Ensure no task_runs were created
    task_runs = db.find_task_runs()
    self.assertEqual(len(task_runs), 0)
def test_update_task_failures(self) -> None:
    """Check that update_task rejects invalid changes.

    Covers three rejected edits:

    - renaming a task onto a name another task already holds;
    - renaming to the empty string;
    - pointing a task at a project id that doesn't exist.

    A valid rename must succeed while the task has no runs. Once a task run
    references the task, further renames must be refused.
    """
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db
    kind = "mock"
    taken_name = "test_task"
    db.new_task(taken_name, kind)
    original_name = "test_task_2"
    target_id = db.new_task(original_name, kind)
    fresh_name = "test_task_3"

    # Renaming onto a name held by the other task is rejected
    with self.assertRaises(EntryAlreadyExistsException):
        db.update_task(target_id, task_name=taken_name)
    # An empty name is rejected
    with self.assertRaises(MephistoDBException):
        db.update_task(target_id, task_name="")
    # A project id that doesn't exist is rejected
    with self.assertRaises(EntryDoesNotExistException):
        db.update_task(target_id, project_id=self.get_fake_id("Project"))

    # A valid rename goes through while no task run exists yet
    db.update_task(target_id, task_name=fresh_name)

    # After a task run is created for this task, even renaming back to its
    # original name (unused again after the rename above) is refused.
    _requester, requester_id = get_test_requester(db)
    serialized_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
    db.new_task_run(target_id, requester_id, serialized_params, "mock", "mock")
    with self.assertRaises(MephistoDBException):
        db.update_task(target_id, task_name=original_name)
def setUp(self):
    """Give each test a fresh LocalMephistoDB stored in a new temp directory.

    Also registers a test requester, keeping its name on ``self``, and resets
    ``self.operator`` to None.
    """
    self.data_dir = tempfile.mkdtemp()
    self.db = LocalMephistoDB(os.path.join(self.data_dir, "mephisto.db"))
    requester_name, _ = get_test_requester(self.db)
    self.requester_name = requester_name
    self.operator = None