Ejemplo n.º 1
0
class TestOperator(unittest.TestCase):
    """
    Unit testing for the Mephisto Supervisor
    """
    def setUp(self):
        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(database_path)
        self.requester_name, _req_id = get_test_requester(self.db)
        self.operator = None

    def tearDown(self):
        if self.operator is not None:
            self.operator.shutdown()
        self.db.shutdown()
        shutil.rmtree(self.data_dir)
        self.assertTrue(
            len(threading.enumerate()) == 1,
            f"Expected only main thread at teardown, found {threading.enumerate()}",
        )

    def wait_for_complete_assignment(self, assignment, timeout: int):
        start_time = time.time()
        while time.time() - start_time < timeout:
            if assignment.get_status() == AssignmentState.COMPLETED:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, timeout,
                        "Assignment not completed in time")

    def await_server_start(self, architect: "MockArchitect"):
        start_time = time.time()
        assert architect.server is not None, "Cannot wait on empty server"
        while time.time() - start_time < 5:
            if len(architect.server.subs) > 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, 5,
                        "Mock server not up in time")

    def test_initialize_supervisor(self):
        """Quick test to ensure that the operator can be initialized"""
        self.operator = Operator(self.db)

    def test_run_job_concurrent(self):
        """Ensure that the supervisor object can even be created"""
        self.operator = Operator(self.db)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(
                num_assignments=1,
                is_concurrent=True,
            ),
            provider=MockProviderArgs(requester_name=self.requester_name, ),
            architect=MockArchitectArgs(should_run_server=True),
            task=MOCK_TASK_ARGS,
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id, mock_agent_details)

        # Give up to 5 seconds for whole mock task to complete
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure the assignment is completed
        task_run = tracked_run.task_run
        assignment = task_run.get_assignments()[0]
        self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)

    def test_run_job_not_concurrent(self):
        """Ensure that the supervisor object can even be created"""
        self.operator = Operator(self.db)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(
                num_assignments=1,
                is_concurrent=False,
            ),
            provider=MockProviderArgs(requester_name=self.requester_name, ),
            architect=MockArchitectArgs(should_run_server=True),
            task=MOCK_TASK_ARGS,
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id, mock_agent_details)

        # Give up to 5 seconds for both tasks to complete
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure the assignment is completed
        task_run = tracked_run.task_run
        assignment = task_run.get_assignments()[0]
        self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)

    def test_run_jobs_with_restrictions(self):
        """Ensure allowed_concurrent and maximum_units_per_worker work"""
        self.operator = Operator(self.db)
        provider_args = MockProviderArgs(requester_name=self.requester_name, )
        architect_args = MockArchitectArgs(should_run_server=True)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(
                num_assignments=3,
                is_concurrent=True,
            ),
            provider=provider_args,
            architect=architect_args,
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name='max-unit-test',
            ),
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        self.await_server_start(tracked_run.architect)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_1 = workers[0].db_id

        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Try to register a second agent, which should fail due to concurrency
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Second agent was created")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_2 = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Second agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[1].get_unit().get_assignment(), 3)

        # Pass a second task as well
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 3, "Agent was not created properly")
        mock_agent_details = "FAKE_ASSIGNMENT_4"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Fourth agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[3].get_unit().get_assignment(), 3)

        # Both workers should have saturated their tasks, and not be granted agents
        mock_agent_details = "FAKE_ASSIGNMENT_5"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Additional agent was created")
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Additional agent was created")

        # new workers should be able to work on these just fine though
        mock_worker_name = "MOCK_WORKER_3"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_3 = workers[0].db_id
        mock_worker_name = "MOCK_WORKER_4"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_4 = workers[0].db_id

        # Register agents from new workers
        mock_agent_details = "FAKE_ASSIGNMENT_5"
        architect.server.register_mock_agent(worker_id_3, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 5, "Additional agent was not created")
        mock_agent_details = "FAKE_ASSIGNMENT_6"
        architect.server.register_mock_agent(worker_id_4, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 6, "Additional agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[5].get_unit().get_assignment(), 3)

        # Give up to 5 seconds for whole mock task to complete
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure all assignments are completed
        task_run = tracked_run.task_run
        assignments = task_run.get_assignments()
        for assignment in assignments:
            self.assertEqual(assignment.get_status(),
                             AssignmentState.COMPLETED)

        # Create a new task
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(
                num_assignments=1,
                is_concurrent=True,
            ),
            provider=MockProviderArgs(requester_name=self.requester_name, ),
            architect=MockArchitectArgs(should_run_server=True),
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name='max-unit-test',
            ),
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]
        self.await_server_start(tracked_run.architect)
        architect = tracked_run.architect

        # Workers one and two still shouldn't be able to make agents
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )

        # Three and four should though
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_3, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 7, "Additional agent was not created")
        mock_agent_details = "FAKE_ASSIGNMENT_8"
        architect.server.register_mock_agent(worker_id_4, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 8, "Additional agent was not created")

        # Ensure the task run completed and that all assignments are done
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")
        task_run = tracked_run.task_run
        assignments = task_run.get_assignments()
        for assignment in assignments:
            self.assertEqual(assignment.get_status(),
                             AssignmentState.COMPLETED)
Ejemplo n.º 2
0
    class TestAcuteEval(unittest.TestCase):
        """
        Test the ACUTE-Eval crowdsourcing task.
        """

        def setUp(self):
            self.operator = None

        def tearDown(self):
            if self.operator is not None:
                self.operator.shutdown()

        def test_base_task(self):

            # Define the configuration settings
            relative_task_directory = os.path.relpath(
                TASK_DIRECTORY, os.path.dirname(__file__)
            )
            relative_config_path = os.path.join(relative_task_directory, 'conf')
            with initialize(config_path=relative_config_path):
                self.config = compose(
                    config_name="example",
                    overrides=[
                        f'+mephisto.blueprint._blueprint_type={AcuteEvalBlueprint.BLUEPRINT_TYPE}',
                        f'+mephisto/architect=mock',
                        f'+mephisto/provider=mock',
                        f'mephisto.blueprint.block_on_onboarding_fail={False}',
                        f'+task_dir={TASK_DIRECTORY}',
                        f'+current_time={int(time.time())}',
                    ],
                )
                # TODO: when Hydra 1.1 is released with support for recursive defaults,
                #  don't manually specify all missing blueprint args anymore, but
                #  instead define the blueprint in the defaults list directly.
                #  Currently, the blueprint can't be set in the defaults list without
                #  overriding params in the YAML file, as documented at
                #  https://github.com/facebookresearch/hydra/issues/326 and as fixed in
                #  https://github.com/facebookresearch/hydra/pull/1044.

            self.data_dir = tempfile.mkdtemp()
            database_path = os.path.join(self.data_dir, "mephisto.db")
            self.db = LocalMephistoDB(database_path)
            self.config = augment_config_from_db(self.config, self.db)
            self.config.mephisto.architect.should_run_server = True
            self.operator = Operator(self.db)
            self.operator.validate_and_run_config(
                self.config.mephisto, shared_state=None
            )
            channel_info = list(self.operator.supervisor.channels.values())[0]
            server = channel_info.job.architect.server

            # Register a worker
            mock_worker_name = "MOCK_WORKER"
            server.register_mock_worker(mock_worker_name)
            workers = self.db.find_workers(worker_name=mock_worker_name)
            worker_id = workers[0].db_id

            # Register an agent
            mock_agent_details = "FAKE_ASSIGNMENT"
            server.register_mock_agent(worker_id, mock_agent_details)
            agent = self.db.find_agents()[0]
            agent_id_1 = agent.db_id

            # Set initial data
            server.request_init_data(agent_id_1)

            # Make agent act
            server.send_agent_act(
                agent_id_1, {"MEPHISTO_is_submit": True, "task_data": DESIRED_OUTPUTS}
            )

            # Check that the inputs and outputs are as expected
            agent = self.db.find_agents()[0]
            state = agent.state.get_data()
            self.assertEqual(DESIRED_INPUTS, state['inputs'])
            self.assertEqual(DESIRED_OUTPUTS, state['outputs'])
Ejemplo n.º 3
0
class TestSupervisor(unittest.TestCase):
    """
    Unit testing for the Mephisto Supervisor, 
    uses WebsocketChannel and MockArchitect
    """
    def setUp(self):
        """
        Build the fixture for every supervisor test: a scratch database with
        a mock task run, a prepared and deployed MockArchitect, a
        MockProvider set up for the run, and a TaskLauncher whose
        assignments have been created and launched against the architect's
        first socket url.
        """
        scratch_dir = tempfile.mkdtemp()
        self.data_dir = scratch_dir
        db = LocalMephistoDB(os.path.join(scratch_dir, "mephisto.db"))
        self.db = db
        self.task_id = db.new_task("test_mock", MockBlueprint.BLUEPRINT_TYPE)
        self.task_run_id = get_test_task_run(db)
        task_run = TaskRun(db, self.task_run_id)
        self.task_run = task_run

        # Architect that actually runs its mock server
        mock_architect_args = MockArchitectArgs(should_run_server=True)
        architect_config = OmegaConf.structured(
            MephistoConfig(architect=mock_architect_args)
        )
        architect = MockArchitect(
            db, architect_config, EMPTY_STATE, task_run, scratch_dir
        )
        self.architect = architect
        architect.prepare()
        architect.deploy()
        socket_urls = architect._get_socket_urls()  # FIXME
        self.urls = socket_urls
        self.url = socket_urls[0]

        # Provider resources for the run, pointed at the first socket url
        provider = MockProvider(db)
        self.provider = provider
        provider.setup_resources_for_task_run(task_run, task_run.args, self.url)

        # Create and launch the mock assignments
        launcher = TaskLauncher(
            db, task_run, self.get_mock_assignment_data_array()
        )
        self.launcher = launcher
        launcher.create_assignments()
        launcher.launch_units(self.url)

        # Tests that create a Supervisor store it here for tearDown
        self.sup = None

    def tearDown(self):
        if self.sup is not None:
            self.sup.shutdown()
        self.launcher.expire_units()
        self.architect.cleanup()
        self.architect.shutdown()
        self.db.shutdown()
        shutil.rmtree(self.data_dir)

    def get_mock_assignment_data_array(self) -> List[InitializationData]:
        """Return a two-element list holding the same mock assignment data twice."""
        shared_entry = MockTaskRunner.get_mock_assignment_data()
        # Both slots deliberately reference the one object fetched above.
        return [shared_entry for _ in range(2)]

    def test_initialize_supervisor(self):
        """Ensure that the supervisor object can even be created"""
        sup = Supervisor(self.db)
        # try/finally so the supervisor is shut down even when one of the
        # assertions below fails.
        try:
            self.assertIsNotNone(sup)
            self.assertDictEqual(sup.agents, {})
            self.assertDictEqual(sup.channels, {})
        finally:
            sup.shutdown()

    def test_channel_operations(self):
        """
        Initialize a channel, and ensure the basic
        startup and shutdown functions are working
        """
        sup = Supervisor(self.db)
        self.sup = sup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        # The job is only constructed, never registered with the supervisor;
        # the channel below is obtained straight from the architect.
        test_job = Job(
            architect=self.architect,
            task_runner=task_runner,
            provider=self.provider,
            qualifications=[],
            registered_channel_ids=[],
        )

        channels = self.architect.get_channels(sup._on_channel_open,
                                               sup._on_catastrophic_disconnect,
                                               sup._on_message)
        channel = channels[0]
        channel.open()
        # Close the channel even if the id check fails, so a failing test
        # does not leave it open.
        try:
            self.assertIsNotNone(channel.channel_id)
        finally:
            channel.close()
        self.assertTrue(channel.is_closed())

    def test_register_concurrent_job(self):
        """
        Register a job with ``is_concurrent`` turned off and run two units.

        Despite the method name, this test disables concurrency: it expects
        one running unit after the first agent registers and two after the
        second, then has both agents act and waits for the units to finish.
        """

        def wait_until(condition, failure_message, timeout=1):
            """Poll ``condition`` every 0.1s; fail the test once ``timeout`` expires."""
            # Returning on success (instead of re-reading the clock after
            # the loop) avoids reporting a timeout when the condition was
            # met just before the deadline.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if condition():
                    return
                time.sleep(0.1)
            self.fail(failure_message)

        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        args.is_concurrent = False
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        # is_alive must be called: the bare bound method is always truthy
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        server = self.architect.server
        self.assertEqual(
            len(server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")

        # Registering the same name again must not create a second worker
        server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        # Repeating the same registration must not duplicate the agent
        server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        self.assertEqual(len(task_runner.running_units), 1,
                         "Ready task was not launched")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        server.register_mock_agent(worker_id, mock_agent_details)

        self.assertEqual(len(task_runner.running_units), 2,
                         "Tasks were not launched")
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        # Looked up only to confirm the second agent's record is reachable;
        # the value itself is not read afterwards.
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        server.send_agent_act(agent_id_1, {"text": "message1"})
        server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 second for the actual operations to occur
        wait_until(
            lambda: len(agent_1_data["acts"]) > 0,
            "Did not process messages in time",
        )

        # Give up to 1 second for the task to complete afterwards
        wait_until(
            lambda: len(task_runner.running_units) == 0,
            "Did not complete task in time",
        )

        # Give up to 1 second for all messages to propagate
        wait_until(
            lambda: server.actions_observed == 2,
            "Not all actions observed in time",
        )

        sup.shutdown()
        # is_closed must be called for the assertion to check anything
        self.assertTrue(channel_info.channel.is_closed())

    def test_register_job(self):
        """Test registering and running a job run asynchronously"""
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker = workers[0]

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        self.assertEqual(len(task_runner.running_assignments), 0,
                         "Task was not yet ready")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)

        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 second for the actual operation to occur
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(agent_1_data["acts"]) > 0:
                break
            time.sleep(0.1)

        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not process messages in time")

        # Give up to 1 second for the task to complete afterwards
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(task_runner.running_assignments) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not complete task in time")

        # Give up to 1 second for all messages to propagate
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if self.architect.server.actions_observed == 2:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Not all actions observed in time")

        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())

    def test_register_concurrent_job_with_onboarding(self):
        """
        Test registering and running a concurrent job with onboarding.

        Registers a job on a fresh Supervisor with onboarding enabled and
        `is_concurrent = True`, then walks three mock workers through
        registration:

        * BAD_WORKER submits onboarding with `should_pass` False and must
          never get an agent.
        * MOCK_WORKER is first blocked by a qualification granted with value
          0, has it revoked, then passes onboarding and gets an agent.
        * MOCK_WORKER_2 is granted the qualification with value 1 up front and
          must get an agent without any onboarding step.

        Once two agents exist one assignment is expected to be running. Both
        agents then act, and the test waits (1 second per stage) for the act
        to be recorded, the assignment to finish, and the mock server to
        observe both actions before shutting the supervisor down.
        """
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        # Switch the shared task run's blueprint args to a concurrent task
        # with onboarding
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = True
        self.task_run.get_task_config()

        # Supervisor expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)

        # Registering the job should open exactly one live channel that the
        # mock server can see
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        # From here on, inspect the runner held by the registered job
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register an agent who fails onboarding
        mock_worker_name = "BAD_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_0 = workers[0]

        # Registering the same name again must not create a second worker
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        # First agent request should yield an onboarding agent, not a real one
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")
        # Short pause before reading last_packet - presumably gives the
        # response time to reach the mock server (TODO confirm)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        # Reset so the next packet check cannot see this stale packet
        self.architect.server.last_packet = None

        # Submit (failing) onboarding from the agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[0].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(sup.agents), 0,
                         "Failed agent registered with supervisor")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_1 = workers[0]

        # Registering the same name again must not create a second worker
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Fail to register a blocked agent: the onboarding qualification is
        # granted with value 0 before the agent request is made
        mock_agent_details = "FAKE_ASSIGNMENT"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        self.architect.server.last_packet = None
        # Lift the block so the same worker can go through onboarding below
        self.db.revoke_qualification(qualification_id, worker_id)

        # Register an onboarding agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        # Count is 2 because BAD_WORKER's onboarding agent is still in the db
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit (passing) onboarding from the agent
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[1].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        # Only one agent exists so far, so no assignment may be running yet
        self.assertEqual(
            len(task_runner.running_assignments),
            0,
            "Task was not yet ready, should not launch",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_2 = workers[0]
        worker_id = worker_2.db_id

        # Register an agent that is already qualified (value 1 granted before
        # the agent request, so no onboarding packet is expected)
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Second agent not created without onboarding")

        # With two agents registered, one assignment should now be running
        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")

        # Final qualification state: worker_0 failed, worker_1 and worker_2
        # are qualified
        self.assertFalse(worker_0.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_0.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        # NOTE: agent_2_data is looked up but only agent 1's acts are polled
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 second for the actual operation to occur
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(agent_1_data["acts"]) > 0:
                break
            time.sleep(0.1)

        # The loop only exits early on success, so elapsed >= timeout means
        # the condition was never met
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not process messages in time")

        # Give up to 1 second for the task to complete afterwards
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(task_runner.running_assignments) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not complete task in time")

        # Give up to 1 second for all messages to propagate
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if self.architect.server.actions_observed == 2:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Not all actions observed in time")

        # Shutting the supervisor down must close the job's channel
        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())

    def test_register_job_with_onboarding(self):
        """
        Test registering and running a non-concurrent job with onboarding.

        Registers a job on a fresh Supervisor with onboarding enabled and
        `is_concurrent = False`, then walks three mock workers through
        registration:

        * MOCK_WORKER is first blocked by a qualification granted with value
          0, has it revoked, then submits onboarding with `should_pass` False
          and must never get an agent.
        * MOCK_WORKER_2 is granted the qualification with value 1 up front and
          must get an agent (and a running unit) without onboarding.
        * MOCK_WORKER_3 goes through onboarding with `should_pass` True and
          gets the second agent and running unit.

        Both agents then act, and the test waits (1 second per stage) for the
        act to be recorded, all units to finish, and the mock server to
        observe both actions before shutting the supervisor down.
        """
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        # Register onboarding arguments for blueprint
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = False
        self.task_run.get_task_config()

        # Supervisor expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)
        # Registering the job should open exactly one live channel that the
        # mock server can see
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        # From here on, inspect the runner held by the registered job
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_1 = workers[0]

        # Registering the same name again must not create a second worker
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register a blocked agent: the onboarding qualification is
        # granted with value 0 before the agent request is made
        mock_agent_details = "FAKE_ASSIGNMENT"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")
        # Short pause before reading last_packet - presumably gives the
        # response time to reach the mock server (TODO confirm)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        # Reset so the next packet check cannot see this stale packet
        self.architect.server.last_packet = None
        # Lift the block so the same worker can go through onboarding below
        self.db.revoke_qualification(qualification_id, worker_id)

        # Request an agent again: this time onboarding should be triggered
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit (failing) onboarding from the agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[0].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(sup.agents), 0,
                         "Failed agent registered with supervisor")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_2 = workers[0]
        worker_id = worker_2.db_id

        # Register an agent that is already qualified (value 1 granted before
        # the agent request, so no onboarding packet is expected)
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Second agent not created without onboarding")

        # Non-concurrent job: a single agent is enough for one running unit
        self.assertEqual(len(task_runner.running_units), 1,
                         "Tasks were not launched")

        # Qualification state so far: worker_1 failed, worker_2 is qualified
        self.assertFalse(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))

        # Register another worker, who has to go through onboarding
        mock_worker_name = "MOCK_WORKER_3"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_3 = workers[0]
        worker_id = worker_3.db_id
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        # Still only worker_2's agent; worker_3 gets an onboarding agent first
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Agent should not be created yet - need onboarding")
        # Count is 2 because worker_1's onboarding agent is still in the db
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit (passing) onboarding from the agent
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[1].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Duplicate agent created after onboarding")
        agent = agents[1]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 2,
                         "Agent not registered supervisor after onboarding")

        self.assertEqual(len(task_runner.running_units), 2,
                         "Task not launched after onboarding")

        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        # NOTE: agent_2_data is looked up but only agent 1's acts are polled
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 second for the actual operation to occur
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(agent_1_data["acts"]) > 0:
                break
            time.sleep(0.1)

        # The loop only exits early on success, so elapsed >= timeout means
        # the condition was never met
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not process messages in time")

        # Give up to 1 second for the task to complete afterwards
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(task_runner.running_units) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not complete task in time")

        # Give up to 1 second for all messages to propagate
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if self.architect.server.actions_observed == 2:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Not all actions observed in time")

        # Shutting the supervisor down must close the job's channel
        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())
Ejemplo n.º 4
0
class CrowdsourcingTestMixin:
    """
    Mixin for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config specification
    and agent registration.
    """
    def setUp(self):
        self.operator = None

    def tearDown(self):
        if self.operator is not None:
            self.operator.shutdown()

    def _set_up_config(
        self,
        blueprint_type: str,
        task_directory: str,
        overrides: Optional[List[str]] = None,
    ):
        """
        Compose the test config and create a throwaway database for it.

        The config is built with the Hydra compose() API from the `example`
        config found in `<task_directory>/conf`, using the mock architect and
        mock provider. The database lives in a new temporary directory, which
        is stored on `self.data_dir`.

        :param blueprint_type: string uniquely specifying Blueprint class
        :param task_directory: directory containing the `conf/` configuration folder.
          Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, appended after the
          ones this method always sets
        """

        # The config folder is addressed relative to this file's directory
        # (presumably because Hydra resolves config_path that way - confirm).
        this_dir = os.path.dirname(__file__)
        relative_config_path = os.path.join(
            os.path.relpath(task_directory, this_dir), 'conf')
        extra_overrides = [] if overrides is None else overrides

        with initialize(config_path=relative_config_path):
            # TODO: when Hydra 1.1 is released with support for recursive defaults,
            #  don't manually specify all missing blueprint args anymore, but
            #  instead define the blueprint in the defaults list directly.
            #  Currently, the blueprint can't be set in the defaults list without
            #  overriding params in the YAML file, as documented at
            #  https://github.com/facebookresearch/hydra/issues/326 and as fixed in
            #  https://github.com/facebookresearch/hydra/pull/1044.
            required_overrides = [
                f'+mephisto.blueprint._blueprint_type={blueprint_type}',
                '+mephisto/architect=mock',
                '+mephisto/provider=mock',
                f'+task_dir={task_directory}',
                f'+current_time={int(time.time())}',
            ]
            self.config = compose(
                config_name="example",
                overrides=required_overrides + extra_overrides,
            )

        # Fresh on-disk database for this test, inside its own temp directory
        self.data_dir = tempfile.mkdtemp()
        self.db = LocalMephistoDB(os.path.join(self.data_dir, "mephisto.db"))
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional[SharedTaskState] = None):
        """
        Start an operator on the test database and keep a handle to its server.

        Runs the `mephisto` section of `self.config` through the operator,
        then stores the architect server of the first supervisor channel on
        `self.server`.
        """
        operator = Operator(self.db)
        # Store it before running, so tearDown can shut it down if the run fails
        self.operator = operator
        operator.validate_and_run_config(self.config.mephisto,
                                         shared_state=shared_state)
        channels = list(operator.supervisor.channels.values())
        self.server = channels[0].job.architect.server
    def _register_mock_agents(self, num_agents: int = 1) -> List[str]:
        """
        Register mock agents for testing, taking the place of crowdsourcing workers.

        Specify the number of agents to register. Return the agents' IDs after creation.
        """

        for idx in range(num_agents):

            # Register the worker
            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            self.server.register_mock_worker(mock_worker_name)
            workers = self.db.find_workers(worker_name=mock_worker_name)
            worker_id = workers[0].db_id

            # Register the agent
            mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
            self.server.register_mock_agent(worker_id, mock_agent_details)

        # Get all agents' IDs
        agents = self.db.find_agents()
        agent_ids = [agent.db_id for agent in agents]

        return agent_ids