コード例 #1
0
    def test_task_run(self) -> None:
        """Create one task_run, then verify direct retrieval, filtered finds and completion updates"""
        assert self.db is not None, "No db initialized"
        db: MephistoDB = self.db

        _task_name, task_id = get_test_task(db)
        _requester_name, requester_id = get_test_requester(db)

        # Create the run and confirm it can be read back both raw and as a TaskRun
        mock_params = TaskConfig.get_mock_params()
        init_params = json.dumps(OmegaConf.to_yaml(mock_params))
        task_run_id = db.new_task_run(task_id, requester_id, init_params, "mock", "mock")
        self.assertIsNotNone(task_run_id)
        self.assertTrue(isinstance(task_run_id, str))
        self.assertEqual(db.get_task_run(task_run_id)["init_params"], init_params)
        self.assertEqual(TaskRun(db, task_run_id).task_id, task_id)

        def check_only_our_run(found) -> None:
            # A matching query must return exactly the run created above
            self.assertEqual(len(found), 1)
            run = found[0]
            self.assertTrue(isinstance(run, TaskRun))
            self.assertEqual(run.db_id, task_run_id)
            self.assertEqual(run.task_id, task_id)
            self.assertEqual(run.requester_id, requester_id)

        # Unfiltered, by task, and by requester all find the single run
        for filters in ({}, {"task_id": task_id}, {"requester_id": requester_id}):
            check_only_our_run(db.find_task_runs(**filters))

        # Queries that should match nothing
        fake_task_id = self.get_fake_id("TaskRun")
        self.assertEqual(len(db.find_task_runs(task_id=fake_task_id)), 0)
        self.assertEqual(len(db.find_task_runs(is_completed=True)), 0)

        # Mark the run completed; it must now show up in the completed query
        db.update_task_run(task_run_id, True)
        completed = db.find_task_runs(is_completed=True)
        self.assertEqual(len(completed), 1)
        self.assertTrue(isinstance(completed[0], TaskRun))
        self.assertEqual(completed[0].db_id, task_run_id)
コード例 #2
0
    def setup_resources_for_task_run(
        self,
        task_run: "TaskRun",
        args: "DictConfig",
        shared_state: "SharedTaskState",
        server_url: str,
    ) -> None:
        """
        Set up SNS queue to receive agent events from MTurk, and produce the
        HIT type for this task run.

        Qualification handling:

        - Entries of ``shared_state.qualifications`` apply to this provider when
          ``applicable_providers`` is None or contains ``self.PROVIDER_TYPE``.
        - For each applicable entry whose name (``_sandbox``-suffixed for
          sandbox requesters) has no datastore mapping, a new MTurk
          qualification is created.
        - The new qualification's id is written back into that same dict under
          ``"QualificationTypeId"``, so ``shared_state`` is mutated in place.

        ``shared_state.mturk_specific_qualifications`` is appended when present.
        The resulting HIT type is then registered with the datastore for this
        run.

        :param task_run: the run to provision MTurk resources for
        :param args: not read by this method
        :param shared_state: supplies qualification dicts (mutated, see above)
        :param server_url: not read while SNS setup is stubbed out
        """
        requester = cast("MTurkRequester", task_run.get_requester())
        # NOTE(review): `session` is only needed by the stubbed-out SNS setup
        # below; the call is kept in case it has side effects on the datastore.
        session = self.datastore.get_session_for_requester(
            requester._requester_name
        )
        # Use the run's cached TaskConfig rather than constructing a second one.
        task_config = task_run.get_task_config()

        # Set up SNS queue
        # TODO(OWN) implement arn?
        task_run_id = task_run.db_id
        # task_name = task_run.get_task().task_name
        # arn_id = setup_sns_topic(session, task_name, server_url, task_run_id)
        arn_id = "TEST"

        # Set up HIT config
        config_dir = os.path.join(self.datastore.datastore_root, task_run_id)

        # Keep only qualifications that are unrestricted or target this provider
        qualifications = [
            qualification
            for qualification in shared_state.qualifications
            if qualification["applicable_providers"] is None
            or self.PROVIDER_TYPE in qualification["applicable_providers"]
        ]

        # Sandbox qualifications live under a distinct name
        name_suffix = "_sandbox" if requester.PROVIDER_TYPE == "mturk_sandbox" else ""
        for qualification in qualifications:
            qualification_name = qualification["qualification_name"] + name_suffix
            if self.datastore.get_qualification_mapping(qualification_name) is None:
                qualification[
                    "QualificationTypeId"
                ] = requester._create_new_mturk_qualification(qualification_name)

        if hasattr(shared_state, "mturk_specific_qualifications"):
            qualifications += shared_state.mturk_specific_qualifications

        # Set up HIT type
        client = self._get_client(requester._requester_name)
        hit_type_id = create_hit_type(client, task_config, qualifications)
        self.datastore.register_run(task_run_id, arn_id, hit_type_id, config_dir)
コード例 #3
0
ファイル: operator.py プロジェクト: chateval/Mephisto
    def _parse_args_from_classes(
        BlueprintClass: Type["Blueprint"],
        ArchitectClass: Type["Architect"],
        CrowdProviderClass: Type["CrowdProvider"],
        argument_list: List[str],
    ) -> Tuple[Dict[str, Any], List[str]]:
        """
        Parse the given arguments over the parsers for the given types.

        A single ArgumentParser is built with four argument groups, in this
        order: "blueprint", "crowd_provider", "architect", "task_config". Each
        group is populated by the matching class's ``add_args_to_group``.
        Arguments no group recognizes are returned instead of rejected.

        :param argument_list: raw command-line style arguments to parse
        :return: (dict of parsed known arguments, list of unrecognized arguments)
        :raises Exception: if argparse tries to exit (e.g. on a bad value or
            ``-h``); the original SystemExit is chained as ``__cause__``.
        """
        parser = ArgumentParser()
        group_sources = (
            ("blueprint", BlueprintClass),
            ("crowd_provider", CrowdProviderClass),
            ("architect", ArchitectClass),
            ("task_config", TaskConfig),
        )
        for group_name, ArgSource in group_sources:
            ArgSource.add_args_to_group(parser.add_argument_group(group_name))

        try:
            known, unknown = parser.parse_known_args(argument_list)
        except SystemExit as err:
            # Keep the exit code/context visible instead of discarding it
            raise Exception("Argparse broke - must fix") from err
        return vars(known), unknown
コード例 #4
0
    def test_task_run_fails(self) -> None:
        """A task_run must not be creatable when its task or requester id is unknown"""
        assert self.db is not None, "No db initialized"
        db: MephistoDB = self.db

        _task_name, task_id = get_test_task(db)
        _requester_name, requester_id = get_test_requester(db)
        init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))

        # Replace one foreign key at a time with a fabricated id
        for missing_table in ("Task", "Requester"):
            with self.assertRaises(EntryDoesNotExistException):
                fake_id = self.get_fake_id(missing_table)
                run_task_id, run_requester_id = (
                    (fake_id, requester_id)
                    if missing_table == "Task"
                    else (task_id, fake_id)
                )
                db.new_task_run(
                    run_task_id, run_requester_id, init_params, "mock", "mock"
                )

        # The rejected attempts must not have left any rows behind
        self.assertEqual(len(db.find_task_runs()), 0)
コード例 #5
0
    def test_update_task_failures(self) -> None:
        """Task updates must reject bad names/projects, and renames once a run exists"""
        assert self.db is not None, "No db initialized"
        db: MephistoDB = self.db

        task_type = "mock"
        first_name, second_name, renamed = "test_task", "test_task_2", "test_task_3"
        db.new_task(first_name, task_type)
        target_id = db.new_task(second_name, task_type)

        # Renaming onto another task's name is refused
        with self.assertRaises(EntryAlreadyExistsException):
            db.update_task(target_id, task_name=first_name)

        # An empty name is refused
        with self.assertRaises(MephistoDBException):
            db.update_task(target_id, task_name="")

        # Pointing at an unknown project is refused
        with self.assertRaises(EntryDoesNotExistException):
            db.update_task(target_id, project_id=self.get_fake_id("Project"))

        # A valid rename succeeds while the task has no runs yet
        db.update_task(target_id, task_name=renamed)

        # After a task_run is attached, renaming is refused
        _requester_name, requester_id = get_test_requester(db)
        init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
        db.new_task_run(target_id, requester_id, init_params, "mock", "mock")
        with self.assertRaises(MephistoDBException):
            db.update_task(target_id, task_name=second_name)
コード例 #6
0
 def get_task_config(self) -> "TaskConfig":
     """Return the TaskConfig for this object, building it on first call and reusing it after."""
     cached = self.__task_config
     if cached is not None:
         return cached
     self.__task_config = TaskConfig(self)
     return self.__task_config
コード例 #7
0
ファイル: utils.py プロジェクト: chateval/Mephisto
def get_test_task_run(db: MephistoDB) -> str:
    """
    Helper to create a task run for tests.

    Creates a test task and requester (via ``get_test_task`` /
    ``get_test_requester``), then registers a "mock" task run for them. The
    mock TaskConfig params are serialized exactly as the db tests do before
    calling ``new_task_run``: YAML via OmegaConf, then JSON-encoded.

    :param db: database to create the task, requester and run in
    :return: the db id of the new task run
    """
    import json

    from omegaconf import OmegaConf

    _task_name, task_id = get_test_task(db)
    _requester_name, requester_id = get_test_requester(db)
    # init_params is stored and later compared as a serialized string, so the
    # raw params object must not be passed through as-is.
    init_params = json.dumps(OmegaConf.to_yaml(TaskConfig.get_mock_params()))
    return db.new_task_run(task_id, requester_id, init_params, "mock", "mock")