Ejemplo n.º 1
0
    def setUp(self):
        """
        Build the per-test fixture: a fresh MephistoDB in a temporary
        directory, a mock task and task run, a deployed MockArchitect (with
        its server running), MockProvider resources for the run, launched
        units for the mock assignment data, and a ClientIOHandler /
        WorkerPool pair.

        Subclasses must set ``DB_CLASS`` to the MephistoDB implementation
        under test.
        """
        # Use getattr so a subclass that never assigns DB_CLASS (e.g. when it
        # is only annotated on a base class) trips the assert with its
        # intended message instead of an AttributeError. Checking before
        # mkdtemp() also avoids leaving an orphaned temp directory behind.
        db_class = getattr(self, "DB_CLASS", None)
        assert db_class is not None, "Did not specify db to use"

        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = db_class(database_path)
        self.task_id = self.db.new_task("test_mock",
                                        MockBlueprint.BLUEPRINT_TYPE)
        self.task_run_id = get_test_task_run(self.db)
        self.task_run = TaskRun.get(self.db, self.task_run_id)
        # Set by tests that start a LiveTaskRun; presumably consulted by
        # tearDown() to decide what to shut down.
        self.live_run = None

        architect_config = OmegaConf.structured(
            MephistoConfig(architect=MockArchitectArgs(
                should_run_server=True)))

        self.architect = MockArchitect(self.db, architect_config, EMPTY_STATE,
                                       self.task_run, self.data_dir)
        self.architect.prepare()
        self.architect.deploy()
        # FIXME: reaches into a private MockArchitect helper for socket URLs
        self.urls = self.architect._get_socket_urls()
        self.url = self.urls[0]
        self.provider = MockProvider(self.db)
        self.provider.setup_resources_for_task_run(self.task_run,
                                                   self.task_run.args,
                                                   EMPTY_STATE, self.url)
        self.launcher = TaskLauncher(self.db, self.task_run,
                                     self.get_mock_assignment_data_array())
        self.launcher.create_assignments()
        self.launcher.launch_units(self.url)
        self.client_io = ClientIOHandler(self.db)
        self.worker_pool = WorkerPool(self.db)
Ejemplo n.º 2
0
    def setUp(self):
        """
        Create a LocalMephistoDB in a fresh temporary directory, register a
        mock task and task run, deploy a MockArchitect with its server
        running, set up MockProvider resources for the run, and launch units
        for the mock assignment data.
        """
        self.data_dir = tempfile.mkdtemp()
        db_file = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(db_file)
        self.task_id = self.db.new_task(
            "test_mock", MockBlueprint.BLUEPRINT_TYPE
        )
        self.task_run_id = get_test_task_run(self.db)
        self.task_run = TaskRun(self.db, self.task_run_id)

        mock_architect_args = MockArchitectArgs(should_run_server=True)
        architect_config = OmegaConf.structured(
            MephistoConfig(architect=mock_architect_args)
        )

        self.architect = MockArchitect(
            self.db,
            architect_config,
            EMPTY_STATE,
            self.task_run,
            self.data_dir,
        )
        self.architect.prepare()
        self.architect.deploy()
        # FIXME: relies on a private MockArchitect helper for socket URLs
        self.urls = self.architect._get_socket_urls()
        self.url = self.urls[0]

        self.provider = MockProvider(self.db)
        self.provider.setup_resources_for_task_run(
            self.task_run, self.task_run.args, EMPTY_STATE, self.url
        )

        assignment_data = self.get_mock_assignment_data_array()
        self.launcher = TaskLauncher(self.db, self.task_run, assignment_data)
        self.launcher.create_assignments()
        self.launcher.launch_units(self.url)
        self.sup = None

    def test_launch_assignments_with_concurrent_unit_cap(self):
        """
        For each concurrency cap from 1 to 5, launch the mock assignments and
        keep flipping LAUNCHED units to COMPLETED until every unit is
        COMPLETED. Along the way, check that the launcher's LimitedDict never
        reports exceeding its limit and that the whole drain finishes within
        MAX_WAIT_TIME_UNIT_LAUNCH. The fixture is torn down and rebuilt
        between caps.
        """
        for cap in (1, 2, 3, 4, 5):
            assignment_data = self.get_mock_assignment_data_array()
            launcher = TaskLauncher(
                self.db,
                self.task_run,
                assignment_data,
                max_num_concurrent_units=cap,
            )
            # Swap in a LimitedDict so the test can read `exceed_limit`.
            launcher.launched_units = LimitedDict(
                launcher.max_num_concurrent_units)
            launcher.create_assignments()
            launcher.launch_units("dummy-url:3000")

            began = time.time()
            while {unit.get_status() for unit in launcher.units
                   } != {AssignmentState.COMPLETED}:
                for unit in launcher.units:
                    if unit.get_status() == AssignmentState.LAUNCHED:
                        unit.set_db_status(AssignmentState.COMPLETED)
                    time.sleep(0.1)
                self.assertEqual(launcher.launched_units.exceed_limit, False)
                elapsed = time.time() - began
                self.assertLessEqual(elapsed, MAX_WAIT_TIME_UNIT_LAUNCH)
            launcher.expire_units()
            # Rebuild the fixture so the next cap starts from a clean state.
            self.tearDown()
            self.setUp()

    def test_assignments_generator(self):
        """
        Build a TaskLauncher from a generator of initialization data and check
        that constructing it and calling create_assignments() takes at most
        half of NUM_GENERATED_ASSIGNMENTS * WAIT_TIME_TILL_NEXT_ASSIGNMENT
        seconds (presumably proving creation does not wait for the whole
        generator to be exhausted).
        """
        data_generator = self.get_mock_assignment_data_generator()

        began = time.time()
        launcher = TaskLauncher(self.db, self.task_run, data_generator)
        launcher.create_assignments()
        elapsed = time.time() - began

        budget = (NUM_GENERATED_ASSIGNMENTS * WAIT_TIME_TILL_NEXT_ASSIGNMENT) / 2
        self.assertLessEqual(elapsed, budget)

 def test_init_on_task_run(self):
     """
     A freshly constructed TaskLauncher references the given db and task
     run, has no assignments or units yet, and reports the mock provider
     type.
     """
     assignment_data = self.get_mock_assignment_data_array()
     fresh_launcher = TaskLauncher(self.db, self.task_run, assignment_data)
     self.assertEqual(self.db, fresh_launcher.db)
     self.assertEqual(self.task_run, fresh_launcher.task_run)
     # Nothing is created until create_assignments() is called.
     self.assertEqual(len(fresh_launcher.assignments), 0)
     self.assertEqual(len(fresh_launcher.units), 0)
     self.assertEqual(fresh_launcher.provider_type,
                      MockProvider.PROVIDER_TYPE)

    def test_create_launch_expire_assignments(self):
        """
        Walk the mock assignments through their lifecycle. After
        create_assignments() every unit and assignment is CREATED, after
        launch_units() they are LAUNCHED, and after expire_units() they are
        EXPIRED.
        """
        assignment_data = self.get_mock_assignment_data_array()
        launcher = TaskLauncher(self.db, self.task_run, assignment_data)
        launcher.create_assignments()

        def assert_all_in_state(state, pause_between_units=False):
            # Units are checked via their db status, assignments via
            # get_status(); optionally pause after each unit check.
            for unit in launcher.units:
                self.assertEqual(unit.get_db_status(), state)
                if pause_between_units:
                    time.sleep(WAIT_TIME_TILL_NEXT_UNIT)
            for assignment in launcher.assignments:
                self.assertEqual(assignment.get_status(), state)

        self.assertEqual(
            len(launcher.assignments),
            len(assignment_data),
            "Inequal number of assignments existed than were launched",
        )
        self.assertEqual(
            len(launcher.units),
            len(assignment_data) * len(assignment_data[0].unit_data),
            "Inequal number of units created than were expected",
        )
        assert_all_in_state(AssignmentState.CREATED)

        launcher.launch_units("dummy-url:3000")
        assert_all_in_state(AssignmentState.LAUNCHED, pause_between_units=True)

        launcher.expire_units()
        assert_all_in_state(AssignmentState.EXPIRED)
Ejemplo n.º 7
0
    def validate_and_run_config_or_die(
            self,
            run_config: DictConfig,
            shared_state: Optional[SharedTaskState] = None) -> str:
        """
        Validate ``run_config`` and launch a job for it.

        Steps:
        1. Resolve the requester. The mock requester is created when the
           configured name is ``MOCK_REQUESTER`` and none exists yet.
        2. Check that the requester's provider type matches the config.
        3. Run every abstraction's ``assert_task_args`` before touching the
           database.
        4. Find or create the task, then create a new task run.
        5. Build and deploy the architect, then register the job with the
           supervisor.
        6. Create and launch the units.

        Args:
            run_config: the OmegaConf config for this run.
            shared_state: state shared with the blueprint and task runner. A
                new ``SharedTaskState`` is used when omitted. Its
                ``qualifications`` list is extended in place with the block /
                onboarding qualifications when those are configured.

        Returns:
            The ``db_id`` of the newly created task run.

        Raises:
            EntryDoesNotExistException: no requester with the configured name
                exists and the name is not ``MOCK_REQUESTER``.
            AssertionError: the requester's provider type differs from
                ``run_config.provider._provider_type``.
            Any exception (including KeyboardInterrupt) raised while building
            the run is logged and re-raised after a best-effort architect
            shutdown.
        """
        if shared_state is None:
            shared_state = SharedTaskState()

        # First try to find the requester:
        requester_name = run_config.provider.requester_name
        requesters = self.db.find_requesters(requester_name=requester_name)
        if len(requesters) == 0:
            if run_config.provider.requester_name == "MOCK_REQUESTER":
                requesters = [get_mock_requester(self.db)]
            else:
                raise EntryDoesNotExistException(
                    f"No requester found with name {requester_name}")
        requester = requesters[0]
        requester_id = requester.db_id
        provider_type = requester.provider_type
        assert provider_type == run_config.provider._provider_type, (
            f"Found requester for name {requester_name} is not "
            f"of the specified type {run_config.provider._provider_type}, "
            f"but is instead {provider_type}.")

        # Next get the abstraction classes, and run validation
        # before anything is actually created in the database
        blueprint_type = run_config.blueprint._blueprint_type
        architect_type = run_config.architect._architect_type
        BlueprintClass = get_blueprint_from_type(blueprint_type)
        ArchitectClass = get_architect_from_type(architect_type)
        CrowdProviderClass = get_crowd_provider_from_type(provider_type)

        BlueprintClass.assert_task_args(run_config, shared_state)
        ArchitectClass.assert_task_args(run_config, shared_state)
        CrowdProviderClass.assert_task_args(run_config, shared_state)

        # Find an existing task or create a new one
        task_name = run_config.task.get("task_name", None)
        if task_name is None:
            task_name = blueprint_type
            logger.warning(
                f"Task is using the default blueprint name {task_name} as a name, "
                "as no task_name is provided")
        tasks = self.db.find_tasks(task_name=task_name)
        task_id = None
        if len(tasks) == 0:
            task_id = self.db.new_task(task_name, blueprint_type)
        else:
            task_id = tasks[0].db_id

        logger.info(f"Creating a task run under task name: {task_name}")

        # Create a new task run
        new_run_id = self.db.new_task_run(
            task_id,
            requester_id,
            json.dumps(OmegaConf.to_container(run_config, resolve=True)),
            provider_type,
            blueprint_type,
            requester.is_sandbox(),
        )
        task_run = TaskRun(self.db, new_run_id)

        # Bind before the try so the error handler can tell whether an
        # architect was ever constructed. Without this, a failure before
        # ArchitectClass(...) makes architect.shutdown() raise
        # UnboundLocalError, which would be logged as a misleading
        # "Could not shut down architect" error.
        architect = None
        try:
            # Register the blueprint with args to the task run,
            # ensure cached
            blueprint = task_run.get_blueprint(args=run_config,
                                               shared_state=shared_state)

            # If anything fails after here, we have to cleanup the architect
            build_dir = os.path.join(task_run.get_run_dir(), "build")
            os.makedirs(build_dir, exist_ok=True)
            architect = ArchitectClass(self.db, run_config, shared_state,
                                       task_run, build_dir)

            # Setup and deploy the server
            built_dir = architect.prepare()
            task_url = architect.deploy()

            # TODO(#102) maybe the cleanup (destruction of the server configuration?) should only
            # happen after everything has already been reviewed, this way it's possible to
            # retrieve the exact build directory to review a task for real
            architect.cleanup()

            # Create the backend runner
            task_runner = BlueprintClass.TaskRunnerClass(
                task_run, run_config, shared_state)

            # Small hack for auto appending block qualification.
            # NOTE: this appends to shared_state.qualifications in place.
            existing_qualifications = shared_state.qualifications
            if run_config.blueprint.get("block_qualification",
                                        None) is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        run_config.blueprint.block_qualification,
                        QUAL_NOT_EXIST, None))
            if run_config.blueprint.get("onboarding_qualification",
                                        None) is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        OnboardingRequired.get_failed_qual(
                            run_config.blueprint.onboarding_qualification),
                        QUAL_NOT_EXIST,
                        None,
                    ))
            shared_state.qualifications = existing_qualifications

            # Register the task with the provider
            provider = CrowdProviderClass(self.db)
            provider.setup_resources_for_task_run(task_run, run_config,
                                                  shared_state, task_url)

            initialization_data_array = blueprint.get_initialization_data()

            # Link the job together
            job = self.supervisor.register_job(architect, task_runner,
                                               provider,
                                               existing_qualifications)
            if self.supervisor.sending_thread is None:
                self.supervisor.launch_sending_thread()
        except (KeyboardInterrupt, Exception):
            logger.error(
                "Encountered error while launching run, shutting down",
                exc_info=True)
            # Only shut down an architect that was actually constructed.
            if architect is not None:
                try:
                    architect.shutdown()
                except (KeyboardInterrupt, Exception) as architect_exception:
                    logger.exception(
                        f"Could not shut down architect: {architect_exception}",
                        exc_info=True,
                    )
            # Bare raise re-raises the original launch error; the inner
            # handler above has already exited, so its exception is no
            # longer the active one.
            raise

        launcher = TaskLauncher(self.db, task_run, initialization_data_array)
        launcher.create_assignments()
        launcher.launch_units(task_url)

        self._task_runs_tracked[task_run.db_id] = TrackedRun(
            task_run=task_run,
            task_launcher=launcher,
            task_runner=task_runner,
            architect=architect,
            job=job,
        )
        task_run.update_completion_progress(status=False)

        return task_run.db_id
Ejemplo n.º 8
0
class BaseTestLiveRuns:
    """
    Unit testing for the Mephisto Live Runs,
    uses WebsocketChannel and MockArchitect
    """

    DB_CLASS: ClassVar[Type["MephistoDB"]]

    def setUp(self):
        """
        Build the per-test fixture: a fresh ``DB_CLASS`` database in a
        temporary directory, a mock task and task run, a deployed
        MockArchitect (with its server running), MockProvider resources for
        the run, launched units for the mock assignment data, and a
        ClientIOHandler / WorkerPool pair.

        Subclasses must assign ``DB_CLASS``.
        """
        # DB_CLASS is only annotated on this base class, never assigned, so
        # `self.DB_CLASS` would raise AttributeError before the assert's
        # message could ever be shown. getattr() makes the assert meaningful.
        # Checking before mkdtemp() also avoids leaking a temp directory.
        db_class = getattr(self, "DB_CLASS", None)
        assert db_class is not None, "Did not specify db to use"

        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = db_class(database_path)
        self.task_id = self.db.new_task("test_mock",
                                        MockBlueprint.BLUEPRINT_TYPE)
        self.task_run_id = get_test_task_run(self.db)
        self.task_run = TaskRun.get(self.db, self.task_run_id)
        # Tests that start a LiveTaskRun store it here; tearDown() then shuts
        # down the live run instead of the worker pool / client IO directly.
        self.live_run = None

        architect_config = OmegaConf.structured(
            MephistoConfig(architect=MockArchitectArgs(
                should_run_server=True)))

        self.architect = MockArchitect(self.db, architect_config, EMPTY_STATE,
                                       self.task_run, self.data_dir)
        self.architect.prepare()
        self.architect.deploy()
        # FIXME: reaches into a private MockArchitect helper for socket URLs
        self.urls = self.architect._get_socket_urls()
        self.url = self.urls[0]
        self.provider = MockProvider(self.db)
        self.provider.setup_resources_for_task_run(self.task_run,
                                                   self.task_run.args,
                                                   EMPTY_STATE, self.url)
        self.launcher = TaskLauncher(self.db, self.task_run,
                                     self.get_mock_assignment_data_array())
        self.launcher.create_assignments()
        self.launcher.launch_units(self.url)
        self.client_io = ClientIOHandler(self.db)
        self.worker_pool = WorkerPool(self.db)

    def tearDown(self):
        """
        Release everything setUp() created. Steps, in order:
        1. Expire the launched units.
        2. Clean up and shut down the architect.
        3. Shut down the live run. If no live run was started, shut down the
           worker pool and client IO handler directly.
        4. Shut down the database and delete the temporary data directory.

        Step 4 runs in ``finally`` blocks, so a failure in an earlier step
        still releases the database and does not leak the temp directory.
        The original error still propagates.
        """
        try:
            self.launcher.expire_units()
            self.architect.cleanup()
            self.architect.shutdown()
            if self.live_run is not None:
                # The live run was built with self.worker_pool and
                # self.client_io (see get_mock_run); presumably its shutdown()
                # releases them too.
                self.live_run.shutdown()
            else:
                self.worker_pool.shutdown()
                self.client_io.shutdown()
        finally:
            try:
                self.db.shutdown()
            finally:
                shutil.rmtree(self.data_dir, ignore_errors=True)

    def get_mock_run(self, blueprint, task_runner) -> LiveTaskRun:
        """
        Assemble a LiveTaskRun from this fixture's pieces and register it.

        The run is built from the fixture's task run, architect, provider,
        launcher, client IO handler and worker pool. It also receives the
        given blueprint and task runner, an empty list in the fifth position
        (presumably qualifications), and a LoopWrapper around a brand-new
        asyncio event loop. The run is then registered with both the client
        IO handler and the worker pool.
        """
        loop_wrapper = LoopWrapper(asyncio.new_event_loop())
        run = LiveTaskRun(
            self.task_run,
            self.architect,
            blueprint,
            self.provider,
            [],
            task_runner,
            self.launcher,
            self.client_io,
            self.worker_pool,
            loop_wrapper,
        )
        for registrar in (self.client_io, self.worker_pool):
            registrar.register_run(run)
        return run

    def get_mock_assignment_data_array(self) -> List[InitializationData]:
        """
        Return a two-element list in which both entries are the *same* mock
        InitializationData object from MockTaskRunner.
        """
        return [MockTaskRunner.get_mock_assignment_data()] * 2

    def make_registered_worker(self, worker_name) -> Worker:
        """
        Create a worker named ``<worker_name>_sandbox`` for the "mock"
        provider directly in the database, and return it as a Worker.
        """
        sandbox_name = worker_name + "_sandbox"
        new_worker_id = self.db.new_worker(sandbox_name, "mock")
        return Worker.get(self.db, new_worker_id)

    def _run_loop_until(
        self,
        live_run: LiveTaskRun,
        condition_met: Callable[[], bool],
        timeout,
        failure_message=None,
    ) -> bool:
        """
        Drive ``live_run``'s event loop until ``condition_met()`` is truthy or
        ``timeout`` seconds have passed.

        The loop is also installed as the current asyncio event loop. Each
        poll sleeps 0.01 s before checking the condition and 0.2 s after a
        failed check. ``failure_message`` is accepted but not used.

        Returns:
            True if the condition was met before the timeout, else False.
        """
        loop = live_run.loop_wrap.loop
        asyncio.set_event_loop(loop)

        async def poll_until_met():
            started = time.time()
            while time.time() - started < timeout:
                await asyncio.sleep(0.01)
                if condition_met():
                    return True
                await asyncio.sleep(0.2)
            return False

        return loop.run_until_complete(poll_until_met())

    def assert_sandbox_worker_created(self,
                                      live_run,
                                      worker_name,
                                      timeout=2) -> None:
        """
        Fail the test unless a worker named ``<worker_name>_sandbox`` shows up
        in the database within ``timeout`` seconds of running the event loop.
        """
        def sandbox_worker_exists():
            found = self.db.find_workers(worker_name=worker_name + "_sandbox")
            return len(found) > 0

        created = self._run_loop_until(live_run, sandbox_worker_exists,
                                       timeout)
        self.assertTrue(  # type: ignore
            created,
            f"Worker {worker_name} not created in time!",
        )

    def assert_agent_created(self, live_run, agent_num, timeout=2) -> None:
        """
        Fail the test unless the database holds exactly ``agent_num`` agents
        within ``timeout`` seconds of running the event loop. Then fetch the
        agent at index ``agent_num - 1`` and check that it is not None.
        """
        def has_expected_agent_count():
            return len(self.db.find_agents()) == agent_num

        reached = self._run_loop_until(live_run, has_expected_agent_count,
                                       timeout)
        self.assertTrue(  # type: ignore
            reached,
            f"Agent {agent_num} not created in time!",
        )
        target_agent = self.db.find_agents()[agent_num - 1]
        self.assertIsNotNone(target_agent)  # type: ignore

    def _await_current_tasks(self, live_run, timeout=5) -> None:
        """
        Run ``live_run``'s event loop until fewer than three asyncio tasks
        remain on it, or until ``timeout`` seconds pass. Whether the condition
        was reached is not checked.
        """
        # NOTE(review): the threshold of 3 presumably leaves room for
        # long-lived background tasks on the loop; TODO confirm.
        def few_tasks_left():
            return len(asyncio.all_tasks(live_run.loop_wrap.loop)) < 3

        self._run_loop_until(live_run, few_tasks_left, timeout)

    def await_channel_requests(self, live_run, timeout=2) -> None:
        """
        First let outstanding asyncio tasks settle (see _await_current_tasks).
        Then fail the test unless ``client_io.request_id_to_channel_id``
        empties within ``timeout`` seconds. Each phase gets its own
        ``timeout``.
        """
        self._await_current_tasks(live_run, timeout)

        def no_pending_requests():
            return len(live_run.client_io.request_id_to_channel_id) == 0

        drained = self._run_loop_until(live_run, no_pending_requests, timeout)
        self.assertTrue(  # type: ignore
            drained,
            "Channeled requests not processed in time!",
        )

    def test_channel_operations(self):
        """
        Get a channel from the architect, wired to the client IO handler's
        open / catastrophic-disconnect / message callbacks. Register the
        channel, then check it is alive before close() and closed after.
        """
        # The task runner is constructed but not otherwise used by this test.
        blueprint_args = MockBlueprint.ArgsClass()
        run_config = OmegaConf.structured(MephistoConfig(blueprint=blueprint_args))
        task_runner = MockBlueprint.TaskRunnerClass(self.task_run, run_config,
                                                    EMPTY_STATE)

        handler = self.client_io
        first_channel = self.architect.get_channels(
            handler._on_channel_open,
            handler._on_catastrophic_disconnect,
            handler._on_message,
        )[0]
        handler._register_channel(first_channel)

        self.assertTrue(first_channel.is_alive())
        first_channel.close()
        self.assertTrue(first_channel.is_closed())

    def test_register_concurrent_run(self):
        """
        Register two workers against a live run with ``is_concurrent=False``,
        so each worker gets its own running unit, and drive both to
        completion.

        Checks that:
        - re-registering the same worker neither duplicates the worker nor
          the agent;
        - each new worker launches a running unit;
        - acts and submissions are observed by the mock server;
        - the channel is closed once the live run shuts down.
        """
        # Handle baseline setup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        args.is_concurrent = False
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        blueprint = self.task_run.get_blueprint()
        live_run = self.get_mock_run(blueprint, task_runner)
        # Recorded so tearDown() shuts down the live run rather than the
        # worker pool / client IO directly.
        self.live_run = live_run
        live_run.client_io.launch_channels()
        self.assertEqual(len(live_run.client_io.channels), 1)
        channel = list(live_run.client_io.channels.values())[0]
        self.assertIsNotNone(channel)
        self.assertTrue(channel.is_alive())
        task_runner = live_run.task_runner
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )

        # Register a worker
        mock_worker_name = "MOCK_WORKER"

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        # Registering the same worker again must not create new rows.
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        workers = self.db.find_workers(worker_name=mock_worker_name +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(live_run.worker_pool.agents), 1,
                         "Agent not registered with worker pool")

        self.assertEqual(len(task_runner.running_units), 1,
                         "Ready task was not launched")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)

        self.assertEqual(len(task_runner.running_units), 2,
                         "Tasks were not launched")
        agents = [a for a in live_run.worker_pool.agents.values()]

        # Make both agents act. agent_2_data is never read, but the lookup
        # still fails with KeyError if agent 2 has no datastore entry.
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})
        self.await_channel_requests(live_run)

        # Give up to 1 second for the actual operations to occur
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(agent_1_data["acts"]) > 0,
                1,
            ),
            "Did not process messages in time",
        )
        self.architect.server.submit_mock_unit(agent_id_1, {"completed": True})
        self.architect.server.submit_mock_unit(agent_id_2, {"completed": True})
        self.await_channel_requests(live_run)

        # Give up to 1 second for the task to complete afterwards
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(task_runner.running_units) == 0,
                1,
            ),
            "Did not complete task in time",
        )

        # Give up to 1 second for all messages to propagate
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: self.architect.server.actions_observed == 2,
                1,
            ),
            "Not all actions observed in time",
        )

        live_run.shutdown()
        # Call is_closed(); asserting the bound method itself is always truthy
        # and would make this check vacuous.
        self.assertTrue(channel.is_closed())

    def test_register_run(self):
        """
        Register two workers against a live run with the default blueprint
        args and drive the single assignment to completion.

        Checks that:
        - re-registering the first worker does not duplicate its agent;
        - no assignment runs while only one worker has joined;
        - exactly one assignment starts once the second worker joins;
        - acts and submissions are processed and observed by the mock server;
        - the channel is closed after shutdown.
        """
        # Handle baseline setup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        blueprint = self.task_run.get_blueprint(args=config)
        live_run = self.get_mock_run(blueprint, task_runner)
        # Recorded so tearDown() shuts down the live run rather than the
        # worker pool / client IO directly.
        self.live_run = live_run
        live_run.client_io.launch_channels()
        self.assertEqual(len(live_run.client_io.channels), 1)
        channel = list(live_run.client_io.channels.values())[0]
        self.assertIsNotNone(channel)
        self.assertTrue(channel.is_alive())
        # From here on, use the live run's task runner.
        task_runner = live_run.task_runner
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )

        # Register a worker
        mock_worker_name = "MOCK_WORKER"

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        # Registering the same worker again must not create a second agent.
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(self.worker_pool.agents), 1,
                         "Agent not registered with worker pool")

        # With only one worker present the assignment has not started yet.
        self.assertEqual(len(task_runner.running_assignments), 0,
                         "Task was not yet ready")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)

        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")
        agents = [a for a in self.worker_pool.agents.values()]

        # Make both agents act. agent_2_data is never read, but the lookup
        # still fails with KeyError if agent 2 has no datastore entry.
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})
        self.await_channel_requests(live_run)

        # Give up to 1 second for the actual operation to occur
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(agent_1_data["acts"]) > 0,
                1,
            ),
            "Did not process messages in time",
        )

        self.architect.server.submit_mock_unit(agent_id_1, {"completed": True})
        self.architect.server.submit_mock_unit(agent_id_2, {"completed": True})

        # Give up to 1 second for the task to complete afterwards
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(task_runner.running_assignments) == 0,
                1,
            ),
            "Did not complete task in time",
        )

        # Give up to 1 second for all messages to propagate
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: self.architect.server.actions_observed == 2,
                1,
            ),
            "Not all actions observed in time",
        )

        live_run.shutdown()
        self.assertTrue(channel.is_closed())

    def test_register_concurrent_run_with_onboarding(self):
        """
        Run a concurrent task that requires onboarding. Check that:
        - a worker who fails onboarding never gets a real agent, even after
          re-registering;
        - a worker already holding a failing onboarding qualification is
          refused without being sent onboarding data.
        """
        # Handle baseline setup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        # NOTE: this mutates the task run's own args object in place.
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = True

        # LiveTaskRun expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)

        live_run = self.get_mock_run(blueprint, task_runner)
        # Recorded so tearDown() shuts down the live run rather than the
        # worker pool / client IO directly.
        self.live_run = live_run
        live_run.client_io.launch_channels()
        self.assertEqual(len(live_run.client_io.channels), 1)
        channel = list(live_run.client_io.channels.values())[0]
        self.assertIsNotNone(channel)
        self.assertTrue(channel.is_alive())
        task_runner = live_run.task_runner
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register an agent who fails onboarding
        mock_worker_name = "BAD_WORKER"

        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_0 = workers[0]
        # Only an onboarding agent exists at this point, no real agent.
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")

        # The last packet either reports status "onboarding" or must carry
        # onboard_data; either shape counts as onboarding being triggered.
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        if not last_packet["data"].get("status") == "onboarding":
            self.assertIn("onboard_data", last_packet["data"],
                          "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_0.db_id, onboard_agents[0].get_agent_id(), onboard_data)
        self.await_channel_requests(live_run, 4)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        # NOTE(review): this passes worker_0.db_id where the other
        # register_mock_agent calls pass a worker *name*; confirm intended.
        self.architect.server.register_mock_agent(worker_0.db_id,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(self.worker_pool.agents), 0,
                         "Failed agent registered with worker pool")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register a worker (created directly in the db, "<name>_sandbox")
        mock_worker_name = "MOCK_WORKER"
        worker_1 = self.make_registered_worker(mock_worker_name)

        # Fail to register a blocked agent: grant the onboarding
        # qualification with value 0 before the worker connects.
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")

        # The refused worker gets neither onboarding data nor an agent id.
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        self.architect.server.last_packet = None
        self.db.revoke_qualification(qualification_id, worker_1.db_id)

        # Register an onboarding agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")

        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        if not last_packet["data"].get("status") == "onboarding":
            self.assertIn("onboard_data", last_packet["data"],
                          "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_1.db_id, onboard_agents[1].get_agent_id(), onboard_data)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(self.worker_pool.agents), 1,
                         "Agent not registered with worker pool")

        self.assertEqual(
            len(task_runner.running_assignments),
            0,
            "Task was not yet ready, should not launch",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        worker_2 = self.make_registered_worker(mock_worker_name)

        # Register an agent that is already qualified
        mock_agent_details = "FAKE_ASSIGNMENT_4"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)

        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Second agent not created without onboarding")

        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")

        self.assertFalse(worker_0.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_0.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))
        agents = [a for a in self.worker_pool.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})
        self.await_channel_requests(live_run)

        # Give up to 1 seconds for the actual operation to occur
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(agent_1_data["acts"]) > 0,
                1,
            ),
            "Did not process messages in time",
        )

        self.architect.server.submit_mock_unit(agent_id_1, {"completed": True})
        self.architect.server.submit_mock_unit(agent_id_2, {"completed": True})

        # Give up to 1 seconds for the task to complete afterwards
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(task_runner.running_assignments) == 0,
                1,
            ),
            "Did not complete task in time",
        )

        # Give up to 1 seconds for all messages to propogate
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: self.architect.server.actions_observed == 2,
                1,
            ),
            "Not all actions observed in time",
        )

        live_run.shutdown()
        self.assertTrue(channel.is_closed())

    def test_register_run_with_onboarding(self):
        """
        Test registering and running a (non-concurrent) run with onboarding.

        Scenario, in order:
        - worker_1 holds the onboarding qualification with value 0 and is
          refused without being shown onboarding;
        - after that qualification is revoked, worker_1 is sent to
          onboarding, fails it, and still gets no agent after a refresh;
        - worker_2 is pre-granted the qualification with value 1, skips
          onboarding, and a unit is launched for it;
        - worker_3 goes through onboarding, passes, and gets an agent;
        - both live agents act and submit, the units complete, and both
          actions are observed by the mock server.
        """
        # Handle baseline setup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        # Register onboarding arguments for blueprint
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = False

        # LiveTaskRun expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)
        live_run = self.get_mock_run(blueprint, task_runner)
        self.live_run = live_run
        live_run.client_io.launch_channels()
        self.assertEqual(len(live_run.client_io.channels), 1)
        channel = list(live_run.client_io.channels.values())[0]
        self.assertIsNotNone(channel)
        self.assertTrue(channel.is_alive())
        task_runner = live_run.task_runner
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        worker_1 = self.make_registered_worker(mock_worker_name)

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register a blocked agent: this test treats an onboarding
        # qualification value of 0 as a disqualified worker
        mock_agent_details = "FAKE_ASSIGNMENT"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")

        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        self.architect.server.last_packet = None
        self.db.revoke_qualification(qualification_id, worker_1.db_id)

        # With the block revoked, the same worker should be sent to onboarding
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")

        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        if not last_packet["data"].get("status") == "onboarding":
            self.assertIn("onboard_data", last_packet["data"],
                          "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit (failing) onboarding from worker_1's onboarding agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_1.db_id, onboard_agents[0].get_agent_id(), onboard_data)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(self.worker_pool.agents), 0,
                         "Failed agent registered with worker pool")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        worker_2 = self.make_registered_worker(mock_worker_name)

        # Register an agent that is already qualified (value 1)
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)

        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Second agent not created without onboarding")

        self.assertEqual(len(task_runner.running_units), 1,
                         "Tasks were not launched")

        self.assertFalse(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))

        # Register a third worker, who has no onboarding qualification yet
        mock_worker_name = "MOCK_WORKER_3"
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_3 = workers[0]
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")
        self._await_current_tasks(live_run, 2)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        if not last_packet["data"].get("status") == "onboarding":
            self.assertIn("onboard_data", last_packet["data"],
                          "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit (passing) onboarding from worker_3's onboarding agent; the
        # onboarding agent at index 1 was created for worker_3 above, so its
        # submission must come from worker_3 (not worker_2)
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_3.db_id, onboard_agents[1].get_agent_id(), onboard_data)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(mock_worker_name,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Duplicate agent created after onboarding")
        agent = agents[1]
        self.assertIsNotNone(agent)
        self.assertEqual(
            len(self.worker_pool.agents),
            2,
            "Agent not registered to worker pool after onboarding",
        )

        self.assertEqual(len(task_runner.running_units), 2,
                         "Task not launched after onboarding")

        agents = [a for a in self.worker_pool.agents.values()]

        # Make both agents act; indexing agent_data also checks that each
        # agent has a datastore entry
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})
        self.await_channel_requests(live_run)

        # Give up to 1 seconds for the actual operation to occur
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(agent_1_data["acts"]) > 0,
                1,
            ),
            "Did not process messages in time",
        )

        self.architect.server.submit_mock_unit(agent_id_1, {"completed": True})
        self.architect.server.submit_mock_unit(agent_id_2, {"completed": True})

        # Give up to 1 seconds for the task to complete afterwards
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: len(task_runner.running_units) == 0,
                1,
            ),
            "Did not complete task in time",
        )

        # Give up to 1 seconds for all messages to propagate
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: self.architect.server.actions_observed == 2,
                1,
            ),
            "Not all actions observed in time",
        )

        live_run.shutdown()
        self.assertTrue(channel.is_closed())

    def test_register_run_with_screening(self):
        """
        Test registering and running a (non-concurrent) run with screening.

        With ``max_screening_units = 2``: two workers receive screening units
        and a third is refused until the first agent disconnects. Worker 2
        then fails screening and must receive the block qualification. Worker 3
        passes and must receive the passed qualification. Worker 3's next
        registration must get a regular (non-screening) unit.
        """
        if self.DB_CLASS != MephistoSingletonDB:
            # TODO(#97) This test only works with singleton for now due to disconnect simulation
            # Report as skipped rather than silently passing.
            self.skipTest(
                "screening test only supported with MephistoSingletonDB")

        # Handle baseline setup
        PASSED_QUALIFICATION_NAME = "test_screening_qualification"
        FAILED_QUALIFICATION_NAME = "failed_screening_qualification"

        # Register screening arguments for blueprint
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_screening_task = True
        task_run_args.blueprint.passed_qualification_name = PASSED_QUALIFICATION_NAME
        task_run_args.blueprint.block_qualification = FAILED_QUALIFICATION_NAME
        task_run_args.blueprint.max_screening_units = 2
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = False

        def screen_unit(unit):
            """Return the submitted ``success`` flag, or None if nothing to judge."""
            agent = unit.get_assigned_agent()
            if agent is None:
                return None  # No real data to evaluate

            output = agent.state.get_data()
            if output is None:
                return None  # no data to evaluate

            return output["success"]

        shared_state = MockSharedState()
        shared_state.on_unit_submitted = ScreenTaskRequired.create_validation_function(
            task_run_args,
            screen_unit,
        )

        # LiveTaskRun expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint(task_run_args, shared_state)

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      shared_state)
        live_run = self.get_mock_run(blueprint, task_runner)
        self.live_run = live_run
        live_run.client_io.launch_channels()
        self.assertEqual(len(live_run.client_io.channels), 1)
        channel = list(live_run.client_io.channels.values())[0]
        self.assertIsNotNone(channel)
        self.assertTrue(channel.is_alive())
        task_runner = live_run.task_runner
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )

        # Worker names used below
        mock_worker_name_1 = "MOCK_WORKER"
        mock_worker_name_2 = "MOCK_WORKER_2"
        mock_worker_name_3 = "MOCK_WORKER_3"

        # Register a screening agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(mock_worker_name_1,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name_1 +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "No agent created for screening")

        self.assertEqual(
            agents[0].get_unit().unit_index,
            SCREENING_UNIT_INDEX,
            "Agent not assigned screening unit",
        )

        # Register a second screening agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT2"
        self.architect.server.register_mock_agent(mock_worker_name_2,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name_2 +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_2 = workers[0]
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "No agent created for screening")

        self.assertEqual(
            agents[1].get_unit().unit_index,
            SCREENING_UNIT_INDEX,
            "Agent not assigned screening unit",
        )

        # Fail to register a third screening agent (max_screening_units is 2)
        mock_agent_details = "FAKE_ASSIGNMENT3"
        self.architect.server.register_mock_agent(mock_worker_name_3,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        workers = self.db.find_workers(worker_name=mock_worker_name_3 +
                                       "_sandbox")
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_3 = workers[0]
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Third agent created, when 2 was max")

        # Disconnect first agent, freeing a screening slot
        agents[0].update_status(AgentState.STATUS_DISCONNECT)

        # Register third screening agent
        mock_agent_details = "FAKE_ASSIGNMENT3"
        self.architect.server.register_mock_agent(mock_worker_name_3,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 3, "Third agent not created")

        self.assertEqual(
            agents[2].get_unit().unit_index,
            SCREENING_UNIT_INDEX,
            "Agent not assigned screening unit",
        )

        # Submit a failing screening from worker 2's agent
        screening_data = {"success": False}
        self.architect.server.send_agent_act(agents[1].get_agent_id(),
                                             screening_data)
        self.architect.server.submit_mock_unit(agents[1].get_agent_id(),
                                               screening_data)
        self.await_channel_requests(live_run)
        # Worker 2 should now hold the block (failed screening) qualification
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: worker_2.is_qualified(FAILED_QUALIFICATION_NAME),
                5,
            ),
            "Did not disqualify in time",
        )

        # Submit a passing screening from worker 3's agent
        screening_data = {"success": True}
        self.architect.server.send_agent_act(agents[2].get_agent_id(),
                                             screening_data)
        self.architect.server.submit_mock_unit(agents[2].get_agent_id(),
                                               screening_data)
        self.await_channel_requests(live_run)
        # Worker 3 should now hold the passed-screening qualification
        self.assertTrue(
            self._run_loop_until(
                live_run,
                lambda: worker_3.is_qualified(PASSED_QUALIFICATION_NAME),
                5,
            ),
            "Did not qualify in time",
        )

        # Accept a real task, and complete it, from worker 3
        mock_agent_details = "FAKE_ASSIGNMENT4"
        self.architect.server.register_mock_agent(mock_worker_name_3,
                                                  mock_agent_details)
        self.await_channel_requests(live_run)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "No agent created for task")

        self.architect.server.send_agent_act(agents[3].get_agent_id(),
                                             screening_data)
        self.architect.server.submit_mock_unit(agents[3].get_agent_id(),
                                               screening_data)
        self.await_channel_requests(live_run)

        self.assertNotEqual(
            agents[3].get_unit().unit_index,
            SCREENING_UNIT_INDEX,
            "Agent assigned screening unit",
        )

        live_run.shutdown()
        self.assertTrue(channel.is_closed())
Ejemplo n.º 9
0
    def _create_live_task_run(
        self,
        run_config: DictConfig,
        shared_state: SharedTaskState,
        task_run: TaskRun,
        architect_class: Type["Architect"],
        blueprint_class: Type["Blueprint"],
        provider_class: Type["CrowdProvider"],
    ) -> LiveTaskRun:
        """
        Construct and wire together every component of a LiveTaskRun.

        Components are created in this order: blueprint, architect (building
        into ``<run dir>/build``), task runner, provider, task launcher,
        worker pool and client IO handler. The worker pool and client IO
        handler are then registered with the returned LiveTaskRun.

        Side effect: if the blueprint config sets ``block_qualification``
        and/or ``onboarding_qualification``, a qualification dict built with
        QUAL_NOT_EXIST is appended to ``shared_state.qualifications``. The
        existing list object is mutated in place, not copied.
        """
        # Passing the args explicitly registers the blueprint on the run so
        # it is cached for later lookups
        blueprint = task_run.get_blueprint(args=run_config,
                                           shared_state=shared_state)

        # The architect builds into a "build" folder inside the run directory
        run_build_dir = os.path.join(task_run.get_run_dir(), "build")
        os.makedirs(run_build_dir, exist_ok=True)
        architect = architect_class(self.db, run_config, shared_state,
                                    task_run, run_build_dir)

        # Backend runner for this blueprint type
        runner_cls = blueprint_class.TaskRunnerClass
        task_runner = runner_cls(task_run, run_config, shared_state)

        # Small hack for auto appending block qualification
        # TODO(OWN) we can use blueprint.mro() to discover BlueprintMixins and extract from there
        quals = shared_state.qualifications  # alias: appends mutate the shared list
        bp_cfg = run_config.blueprint
        if bp_cfg.get("block_qualification", None) is not None:
            quals.append(
                make_qualification_dict(bp_cfg.block_qualification,
                                        QUAL_NOT_EXIST, None))
        if bp_cfg.get("onboarding_qualification", None) is not None:
            failed_onboarding_qual = OnboardingRequired.get_failed_qual(
                bp_cfg.onboarding_qualification)
            quals.append(
                make_qualification_dict(failed_onboarding_qual,
                                        QUAL_NOT_EXIST, None))
        shared_state.qualifications = quals

        provider = provider_class(self.db)

        # The launcher pulls per-assignment data from the blueprint
        init_data = blueprint.get_initialization_data()
        max_units = run_config.task.max_num_concurrent_units
        launcher = TaskLauncher(
            self.db,
            task_run,
            init_data,
            max_num_concurrent_units=max_units,
        )

        worker_pool = WorkerPool(self.db)
        client_io = ClientIOHandler(self.db)
        live_run = LiveTaskRun(
            task_run=task_run,
            architect=architect,
            blueprint=blueprint,
            provider=provider,
            qualifications=shared_state.qualifications,
            task_runner=task_runner,
            task_launcher=launcher,
            client_io=client_io,
            worker_pool=worker_pool,
            loop_wrap=self._loop_wrapper,
        )
        # Worker pool first, then client IO, mirroring construction order
        for component in (worker_pool, client_io):
            component.register_run(live_run)

        return live_run
Ejemplo n.º 10
0
class TestSupervisor(unittest.TestCase):
    """
    Unit testing for the Mephisto Supervisor,
    uses WebsocketChannel and MockArchitect
    """
    def setUp(self):
        """
        Build a fresh fixture for every test.

        Creates a temporary data directory holding a LocalMephistoDB, a mock
        task and task run, a deployed MockArchitect running its server, a
        MockProvider set up for the run, and a TaskLauncher whose assignments
        have been created and whose units have been launched. ``self.sup``
        starts as ``None``; tests that create a Supervisor store it there so
        tearDown can shut it down.
        """
        self.sup = None

        # Storage: a scratch directory and database private to this test
        self.data_dir = tempfile.mkdtemp()
        self.db = LocalMephistoDB(os.path.join(self.data_dir, "mephisto.db"))
        self.task_id = self.db.new_task(
            "test_mock", MockBlueprint.BLUEPRINT_TYPE)
        self.task_run_id = get_test_task_run(self.db)
        self.task_run = TaskRun(self.db, self.task_run_id)

        # Architect: a mock that actually runs its server
        mock_arch_args = MockArchitectArgs(should_run_server=True)
        arch_cfg = OmegaConf.structured(MephistoConfig(architect=mock_arch_args))
        self.architect = MockArchitect(
            self.db, arch_cfg, EMPTY_STATE, self.task_run, self.data_dir)
        self.architect.prepare()
        self.architect.deploy()
        self.urls = self.architect._get_socket_urls()  # FIXME
        self.url = self.urls[0]

        # Provider resources and launched units for the task run
        self.provider = MockProvider(self.db)
        self.provider.setup_resources_for_task_run(
            self.task_run, self.task_run.args, EMPTY_STATE, self.url)
        self.launcher = TaskLauncher(
            self.db, self.task_run, self.get_mock_assignment_data_array())
        self.launcher.create_assignments()
        self.launcher.launch_units(self.url)

    def tearDown(self):
        """
        Release everything created by setUp, plus ``self.sup`` if a test set it.

        Steps run in this order: supervisor shutdown (if any), unit expiry,
        architect cleanup, architect shutdown, database shutdown, then removal
        of the temporary data directory. Every step is attempted even if an
        earlier one raises, so a failing shutdown cannot leak the mock
        server, the database handle or the temp directory. Any exception is
        re-raised after all steps have run (chained by ExitStack).
        """
        import contextlib

        with contextlib.ExitStack() as stack:
            # ExitStack runs callbacks in LIFO order, so register the step
            # that must run last first.
            stack.callback(shutil.rmtree, self.data_dir, ignore_errors=True)
            stack.callback(self.db.shutdown)
            stack.callback(self.architect.shutdown)
            stack.callback(self.architect.cleanup)
            stack.callback(self.launcher.expire_units)
            if self.sup is not None:
                stack.callback(self.sup.shutdown)

    def get_mock_assignment_data_array(self) -> List[InitializationData]:
        """
        Return initialization data for two mock assignments.

        Both entries are the *same* object returned by a single call to
        ``MockTaskRunner.get_mock_assignment_data()``.
        """
        return [MockTaskRunner.get_mock_assignment_data()] * 2

    def test_initialize_supervisor(self):
        """
        Ensure that the supervisor object can even be created.

        A new Supervisor is expected to start with no agents and no channels.
        """
        sup = Supervisor(self.db)
        # Hand the supervisor to the fixture (as the other tests do) so that
        # tearDown still shuts it down if one of the assertions below fails.
        self.sup = sup
        self.assertIsNotNone(sup)
        self.assertDictEqual(sup.agents, {})
        self.assertDictEqual(sup.channels, {})
        sup.shutdown()

    def test_channel_operations(self):
        """
        Check basic channel lifecycle.

        Opens the first channel returned by the architect (wired to the
        supervisor's callbacks), checks that it was given a channel id, then
        closes it and checks that it reports itself closed.
        """
        sup = Supervisor(self.db)
        self.sup = sup

        # Build a mock task runner and wrap it in a Job; the Job is constructed
        # here but not used further by this test.
        blueprint_args = MockBlueprint.ArgsClass()
        runner_config = OmegaConf.structured(
            MephistoConfig(blueprint=blueprint_args))
        runner = MockBlueprint.TaskRunnerClass(
            self.task_run, runner_config, EMPTY_STATE)
        _job = Job(
            architect=self.architect,
            task_runner=runner,
            provider=self.provider,
            qualifications=[],
            registered_channel_ids=[],
        )

        channel = self.architect.get_channels(
            sup._on_channel_open,
            sup._on_catastrophic_disconnect,
            sup._on_message,
        )[0]
        channel.open()
        self.assertIsNotNone(channel.channel_id)
        channel.close()
        self.assertTrue(channel.is_closed())

    def test_register_concurrent_job(self):
        """
        Test registering and running a job with ``is_concurrent = False``.

        With this setting each unit launches as soon as its own agent
        registers: ``running_units`` is 1 after the first agent and 2 after
        the second. Both agents then act, and the test waits for the units to
        finish and for the server to observe both actions.
        NOTE(review): the test name says "concurrent" while the blueprint
        args set ``is_concurrent = False``; confirm which is intended.
        """
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        args.is_concurrent = False
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        # is_alive must be *called*: a bare bound method is always truthy
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker (twice, to check it is not re-created)
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Register an agent (twice, to check it is not duplicated)
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        self.assertEqual(len(task_runner.running_units), 1,
                         "Ready task was not launched")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)

        self.assertEqual(len(task_runner.running_units), 2,
                         "Tasks were not launched")
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Each wait below polls for up to TIMEOUT_TIME seconds and then
        # asserts on the awaited condition itself, so a condition that
        # becomes true during the final sleep still passes.
        TIMEOUT_TIME = 1

        # Wait for the actual operations to occur
        deadline = time.time() + TIMEOUT_TIME
        while len(agent_1_data["acts"]) == 0 and time.time() < deadline:
            time.sleep(0.1)
        self.assertGreater(len(agent_1_data["acts"]), 0,
                           "Did not process messages in time")

        # Wait for the task to complete afterwards
        deadline = time.time() + TIMEOUT_TIME
        while len(task_runner.running_units) != 0 and time.time() < deadline:
            time.sleep(0.1)
        self.assertEqual(len(task_runner.running_units), 0,
                         "Did not complete task in time")

        # Wait for all messages to propagate
        deadline = time.time() + TIMEOUT_TIME
        while (self.architect.server.actions_observed != 2
               and time.time() < deadline):
            time.sleep(0.1)
        self.assertEqual(self.architect.server.actions_observed, 2,
                         "Not all actions observed in time")

        sup.shutdown()
        # is_closed must be *called*: a bare bound method is always truthy
        self.assertTrue(channel_info.channel.is_closed())

    def test_register_job(self):
        """
        Test registering and running a job run asynchronously.

        Only ``timeout_time`` is overridden on the mock blueprint args. The
        assignment does not launch after the first agent registers
        (``running_assignments`` stays 0) and launches once the second agent
        registers. Both agents then act, and the test waits for the
        assignment to finish and for the server to observe both actions.
        """
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker (twice, to check it is not re-created)
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Register an agent (twice, to check it is not duplicated)
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        self.assertEqual(len(task_runner.running_assignments), 0,
                         "Task was not yet ready")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)

        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Each wait below polls for up to TIMEOUT_TIME seconds and then
        # asserts on the awaited condition itself, so a condition that
        # becomes true during the final sleep still passes.
        TIMEOUT_TIME = 1

        # Wait for the actual operation to occur
        deadline = time.time() + TIMEOUT_TIME
        while len(agent_1_data["acts"]) == 0 and time.time() < deadline:
            time.sleep(0.1)
        self.assertGreater(len(agent_1_data["acts"]), 0,
                           "Did not process messages in time")

        # Wait for the task to complete afterwards
        deadline = time.time() + TIMEOUT_TIME
        while (len(task_runner.running_assignments) != 0
               and time.time() < deadline):
            time.sleep(0.1)
        self.assertEqual(len(task_runner.running_assignments), 0,
                         "Did not complete task in time")

        # Wait for all messages to propagate
        deadline = time.time() + TIMEOUT_TIME
        while (self.architect.server.actions_observed != 2
               and time.time() < deadline):
            time.sleep(0.1)
        self.assertEqual(self.architect.server.actions_observed, 2,
                         "Not all actions observed in time")

        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())

    def test_register_concurrent_job_with_onboarding(self):
        """
        Test registering and running a job with onboarding.

        The blueprint sets ``use_onboarding = True`` and
        ``is_concurrent = True``, so the assignment only launches once both
        qualified agents have registered. The test covers four workers:
        - a worker that fails onboarding;
        - a worker blocked by a 0-valued onboarding qualification;
        - the same worker passing onboarding once the block is revoked;
        - a worker already granted the qualification, who skips onboarding.
        Both agents then act, and the test waits for completion.
        """
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = True
        self.task_run.get_task_config()

        # Supervisor expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)

        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register an agent who fails onboarding
        mock_worker_name = "BAD_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_0 = workers[0]

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[0].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(sup.agents), 0,
                         "Failed agent registered with supervisor")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_1 = workers[0]

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Fail to register a blocked agent (qualification granted with value 0)
        mock_agent_details = "FAKE_ASSIGNMENT"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        self.architect.server.last_packet = None
        self.db.revoke_qualification(qualification_id, worker_id)

        # Register an onboarding agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[1].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        self.assertEqual(
            len(task_runner.running_assignments),
            0,
            "Task was not yet ready, should not launch",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_2 = workers[0]
        worker_id = worker_2.db_id

        # Register an agent that is already qualified
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Second agent not created without onboarding")

        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")

        self.assertFalse(worker_0.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_0.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Each wait below polls for up to TIMEOUT_TIME seconds and then
        # asserts on the awaited condition itself, so a condition that
        # becomes true during the final sleep still passes.
        TIMEOUT_TIME = 1

        # Wait for the actual operation to occur
        deadline = time.time() + TIMEOUT_TIME
        while len(agent_1_data["acts"]) == 0 and time.time() < deadline:
            time.sleep(0.1)
        self.assertGreater(len(agent_1_data["acts"]), 0,
                           "Did not process messages in time")

        # Wait for the task to complete afterwards
        deadline = time.time() + TIMEOUT_TIME
        while (len(task_runner.running_assignments) != 0
               and time.time() < deadline):
            time.sleep(0.1)
        self.assertEqual(len(task_runner.running_assignments), 0,
                         "Did not complete task in time")

        # Wait for all messages to propagate
        deadline = time.time() + TIMEOUT_TIME
        while (self.architect.server.actions_observed != 2
               and time.time() < deadline):
            time.sleep(0.1)
        self.assertEqual(self.architect.server.actions_observed, 2,
                         "Not all actions observed in time")

        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())

    def test_register_job_with_onboarding(self):
        """Test registering and running a job with onboarding"""
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        # Register onboarding arguments for blueprint
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = False
        self.task_run.get_task_config()

        # Supervisor expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_1 = workers[0]

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register a blocked agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        self.architect.server.last_packet = None
        self.db.revoke_qualification(qualification_id, worker_id)

        # Register an agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[0].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(sup.agents), 0,
                         "Failed agent registered with supervisor")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_2 = workers[0]
        worker_id = worker_2.db_id

        # Register an agent that is already qualified
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Second agent not created without onboarding")

        self.assertEqual(len(task_runner.running_units), 1,
                         "Tasks were not launched")

        self.assertFalse(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))

        # Register another worker
        mock_worker_name = "MOCK_WORKER_3"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_3 = workers[0]
        worker_id = worker_3.db_id
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[1].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Duplicate agent created after onboarding")
        agent = agents[1]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 2,
                         "Agent not registered supervisor after onboarding")

        self.assertEqual(len(task_runner.running_units), 2,
                         "Task not launched after onboarding")

        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 seconds for the actual operation to occur
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(agent_1_data["acts"]) > 0:
                break
            time.sleep(0.1)

        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not process messages in time")

        # Give up to 1 seconds for the task to complete afterwards
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(task_runner.running_units) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not complete task in time")

        # Give up to 1 seconds for all messages to propogate
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if self.architect.server.actions_observed == 2:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Not all actions observed in time")

        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())