def test_task_run(self) -> None:
    """Test creation and querying of task_runs.

    Creates one task run for a test task/requester, then checks that it can
    be fetched by id, is found by ``find_task_runs`` unfiltered and when
    filtered by ``task_id`` / ``requester_id``, is not matched by an unknown
    task id, and only shows up under ``is_completed=True`` after
    ``update_task_run`` marks it complete.
    """
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db
    _, task_id = get_test_task(db)
    _, requester_id = get_test_requester(db)

    # Check creation and retrieval of a task_run
    init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
    task_run_id = db.new_task_run(
        task_id, requester_id, init_params, "mock", "mock"
    )
    self.assertIsNotNone(task_run_id)
    self.assertIsInstance(task_run_id, str)
    task_run_row = db.get_task_run(task_run_id)
    self.assertEqual(task_run_row["init_params"], init_params)
    task_run = TaskRun(db, task_run_id)
    self.assertEqual(task_run.task_id, task_id)

    def assert_only_created_run(task_runs) -> None:
        """Assert ``task_runs`` holds exactly the task run created above."""
        self.assertEqual(len(task_runs), 1)
        found = task_runs[0]
        self.assertIsInstance(found, TaskRun)
        self.assertEqual(found.db_id, task_run_id)
        self.assertEqual(found.task_id, task_id)
        self.assertEqual(found.requester_id, requester_id)

    # Check finding for task_runs, unfiltered and by each foreign key
    assert_only_created_run(db.find_task_runs())
    assert_only_created_run(db.find_task_runs(task_id=task_id))
    assert_only_created_run(db.find_task_runs(requester_id=requester_id))

    # An unknown task id must match nothing. The fake id comes from the
    # "Task" table because it is used as a task_id filter.
    task_runs = db.find_task_runs(task_id=self.get_fake_id("Task"))
    self.assertEqual(len(task_runs), 0)
    task_runs = db.find_task_runs(is_completed=True)
    self.assertEqual(len(task_runs), 0)

    # Test updating the completion status, requery
    db.update_task_run(task_run_id, True)
    assert_only_created_run(db.find_task_runs(is_completed=True))
def setup_resources_for_task_run(
    self,
    task_run: "TaskRun",
    args: "DictConfig",
    shared_state: "SharedTaskState",
    server_url: str,
) -> None:
    """
    Set up SNS queue to receive agent events from MTurk, and produce the
    HIT type for this task run.

    Qualifications in ``shared_state.qualifications`` whose
    ``applicable_providers`` is None or includes this provider are kept.
    For each kept qualification with no datastore mapping yet, a new MTurk
    qualification is created and its id is stored in the qualification dict
    under ``"QualificationTypeId"`` (this mutates the dicts in
    ``shared_state.qualifications``). Any ``mturk_specific_qualifications``
    are appended. A HIT type is then created and registered in the
    datastore for this task run.
    """
    requester = cast("MTurkRequester", task_run.get_requester())
    # NOTE(review): the session is unused while SNS setup below is stubbed
    # out; the call is kept so any side effects of fetching it are preserved.
    session = self.datastore.get_session_for_requester(requester._requester_name)
    # Reuse the run's cached TaskConfig rather than constructing a second one.
    task_config = task_run.get_task_config()
    task_run_id = task_run.db_id

    # Set up SNS queue
    # TODO(OWN) implement arn?
    # task_name = task_run.get_task().task_name
    # arn_id = setup_sns_topic(session, task_name, server_url, task_run_id)
    arn_id = "TEST"

    # Set up HIT config
    config_dir = os.path.join(self.datastore.datastore_root, task_run_id)

    # Find or create relevant qualifications. An "applicable_providers" of
    # None means the qualification applies to every provider.
    qualifications = [
        qualification
        for qualification in shared_state.qualifications
        if qualification["applicable_providers"] is None
        or self.PROVIDER_TYPE in qualification["applicable_providers"]
    ]
    for qualification in qualifications:
        qualification_name = qualification["qualification_name"]
        # Sandbox qualifications live under a separate, suffixed name.
        if requester.PROVIDER_TYPE == "mturk_sandbox":
            qualification_name += "_sandbox"
        if self.datastore.get_qualification_mapping(qualification_name) is None:
            qualification[
                "QualificationTypeId"
            ] = requester._create_new_mturk_qualification(qualification_name)

    mturk_specific = getattr(shared_state, "mturk_specific_qualifications", None)
    if mturk_specific is not None:
        qualifications += mturk_specific

    # Set up HIT type
    client = self._get_client(requester._requester_name)
    hit_type_id = create_hit_type(client, task_config, qualifications)
    self.datastore.register_run(task_run_id, arn_id, hit_type_id, config_dir)
def _parse_args_from_classes(
    BlueprintClass: Type["Blueprint"],
    ArchitectClass: Type["Architect"],
    CrowdProviderClass: Type["CrowdProvider"],
    argument_list: List[str],
) -> Tuple[Dict[str, Any], List[str]]:
    """Parse the given arguments over the parsers for the given types.

    Each class contributes its arguments to its own named argument group,
    registered in the order blueprint, crowd_provider, architect,
    task_config.

    Returns:
        A ``(known, unknown)`` pair: a dict of the recognized arguments and
        the list of leftover argument strings that no group consumed.

    Raises:
        Exception: if argparse attempts to exit (e.g. on a parse error);
            the original ``SystemExit`` is attached as the cause.
    """
    parser = ArgumentParser()
    for group_name, arg_source in (
        ("blueprint", BlueprintClass),
        ("crowd_provider", CrowdProviderClass),
        ("architect", ArchitectClass),
        ("task_config", TaskConfig),
    ):
        arg_source.add_args_to_group(parser.add_argument_group(group_name))

    try:
        known, unknown = parser.parse_known_args(argument_list)
    except SystemExit as err:
        # argparse exits the process on parse failures; surface that as an
        # ordinary exception instead, keeping the exit as the cause.
        raise Exception("Argparse broke - must fix") from err
    return vars(known), unknown
def test_task_run_fails(self) -> None:
    """Ensure task_runs fail to be created or loaded under failure conditions.

    Creating a task run with a nonexistent task id, or with a nonexistent
    requester id, must raise ``EntryDoesNotExistException`` and leave no
    task runs behind.
    """
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db
    _, real_task_id = get_test_task(db)
    _, real_requester_id = get_test_requester(db)
    params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))

    # Each attempt swaps in exactly one invalid id; both must be rejected.
    bad_attempts = (
        lambda: db.new_task_run(
            self.get_fake_id("Task"), real_requester_id, params, "mock", "mock"
        ),
        lambda: db.new_task_run(
            real_task_id, self.get_fake_id("Requester"), params, "mock", "mock"
        ),
    )
    for attempt in bad_attempts:
        self.assertRaises(EntryDoesNotExistException, attempt)

    # Nothing should have been written by the failed attempts
    self.assertEqual(len(db.find_task_runs()), 0)
def test_update_task_failures(self) -> None:
    """Ensure failure conditions trigger for updating tasks.

    ``update_task`` must reject renaming onto an existing name, an empty
    name, or an unknown project id. A valid rename succeeds while the task
    has no runs, but is refused once a task run exists for it.
    """
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db
    first_name, second_name, third_name = "test_task", "test_task_2", "test_task_3"
    kind = "mock"
    db.new_task(first_name, kind)
    target_id = db.new_task(second_name, kind)

    # Renaming onto a name another task already uses is rejected
    with self.assertRaises(EntryAlreadyExistsException):
        db.update_task(target_id, task_name=first_name)
    # An empty name is rejected
    with self.assertRaises(MephistoDBException):
        db.update_task(target_id, task_name="")
    # A project id that was never created is rejected
    with self.assertRaises(EntryDoesNotExistException):
        db.update_task(target_id, project_id=self.get_fake_id("Project"))

    # A valid rename goes through while the task has no runs...
    db.update_task(target_id, task_name=third_name)

    # ...but is refused once a task run references the task
    _, requester_id = get_test_requester(db)
    params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
    db.new_task_run(target_id, requester_id, params, "mock", "mock")
    with self.assertRaises(MephistoDBException):
        db.update_task(target_id, task_name=second_name)
def get_task_config(self) -> "TaskConfig":
    """Return this object's TaskConfig, building it on first use and caching it."""
    cached = self.__task_config
    if cached is None:
        cached = TaskConfig(self)
        self.__task_config = cached
    return cached
def get_test_task_run(db: MephistoDB) -> str:
    """Helper to create a task run for tests.

    Obtains a test task and test requester through ``get_test_task`` and
    ``get_test_requester``, then creates a "mock"/"mock" task run for them
    and returns the new task run's id.

    The mock task parameters are serialized as YAML (via OmegaConf) and then
    JSON-encoded. This matches how the database tests build ``init_params``,
    so runs from this helper store the same string form.
    """
    import json

    from omegaconf import OmegaConf

    _, task_id = get_test_task(db)
    _, requester_id = get_test_requester(db)
    init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
    return db.new_task_run(task_id, requester_id, init_params, "mock", "mock")