class AbstractCrowdsourcingTest:
    """
    Abstract class for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config
    specification and agent registration.
    """

    def _setup(self):
        """
        To be run before a test.

        Seeds the `random`, NumPy and PyTorch generators and resets the handles
        that the other helpers fill in. Should be called in a pytest
        setup/teardown fixture.
        """
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        self.operator = None
        self.server = None
        self.data_dir = None

    def _teardown(self):
        """
        To be run after a test.

        Force-shuts down the operator and shuts down the mock server if they
        were created, then removes the temporary directory made by
        `_set_up_config()`. Should be called in a pytest setup/teardown fixture.
        """
        import shutil

        try:
            if self.operator is not None:
                self.operator.force_shutdown()
            if self.server is not None:
                self.server.shutdown_mock()
        finally:
            # `_set_up_config()` may never have been called, so the attribute
            # may be missing. Clean up even if a shutdown call above raised, so
            # that temporary directories do not pile up across test runs.
            data_dir = getattr(self, 'data_dir', None)
            if data_dir is not None:
                shutil.rmtree(data_dir, ignore_errors=True)
                self.data_dir = None

    def _set_up_config(
        self,
        task_directory: str,
        overrides: Optional[List[str]] = None,
        config_name: str = "example",
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory
        to store the test database.

        :param task_directory: directory containing the `hydra_configs/conf/`
            configuration folder. Will be injected as `${task_dir}` in YAML
            files.
        :param overrides: additional config overrides, appended after the
            default ones
        :param config_name: name of the Hydra config to compose
        """
        # Define the configuration settings
        relative_task_directory = os.path.relpath(
            task_directory, os.path.dirname(__file__)
        )
        relative_config_path = os.path.join(
            relative_task_directory, 'hydra_configs', 'conf'
        )
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name=config_name,
                overrides=[
                    'mephisto/architect=mock',
                    'mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ]
                + overrides,
            )

        # The database lives in a throwaway directory, removed in `_teardown()`
        self.data_dir = tempfile.mkdtemp()
        self.database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(self.database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional["SharedTaskState"] = None):
        """
        Set up the operator and server.

        :param shared_state: optional shared state passed through to
            `Operator.validate_and_run_config()`
        """
        self.operator = Operator(self.db)
        self.operator.validate_and_run_config(
            self.config.mephisto, shared_state=shared_state
        )
        self.server = self._get_channel_info().job.architect.server

    def _get_channel_info(self):
        """
        Return channel info for the currently running job.

        :raises ValueError: if the operator's supervisor has no channels
        """
        channels = list(self.operator.supervisor.channels.values())
        if not channels:
            raise ValueError('No channel could be detected!')
        return channels[0]

    def _register_mock_agents(
        self, num_agents: int = 1, assume_onboarding: bool = False
    ) -> List[str]:
        """
        Register mock agents for testing and onboard them if needed, taking the
        place of crowdsourcing workers.

        Registration of each agent is retried with exponentially growing waits
        whenever an `IndexError` signals that the worker or agent cannot be
        found yet.

        :param num_agents: the number of agents to register
        :param assume_onboarding: if True, also submit successful onboarding
            data for each agent after registering it
        :return: the agents' IDs after creation
        :raises ValueError: if an agent cannot be registered within the allowed
            number of attempts, or if the final number of agents in the
            database differs from `num_agents`
        """
        max_num_tries = 6
        initial_wait_time = 0.5  # In seconds

        for idx in range(num_agents):
            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            num_tries = 0
            wait_time = initial_wait_time
            while num_tries < max_num_tries:
                try:
                    # Register the worker
                    self.server.register_mock_worker(mock_worker_name)
                    workers = self.db.find_workers(worker_name=mock_worker_name)
                    worker_id = workers[0].db_id

                    # Register the agent
                    mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
                    self.server.register_mock_agent(worker_id, mock_agent_details)

                    if assume_onboarding:
                        # Submit onboarding from the agent
                        # NOTE(review): this always uses the first onboarding
                        # agent returned by the DB, even when idx > 0 - confirm
                        # that this is intended for num_agents > 1.
                        onboard_agents = self.db.find_onboarding_agents()
                        onboard_data = {"onboarding_data": {"success": True}}
                        self.server.register_mock_agent_after_onboarding(
                            worker_id,
                            onboard_agents[0].get_agent_id(),
                            onboard_data,
                        )

                    # Make sure the agent can be found, or else raise an
                    # IndexError
                    _ = self.db.find_agents()[idx]
                    break
                except IndexError:
                    num_tries += 1
                    print(
                        f'The agent could not be registered after {num_tries:d} '
                        f'attempt(s), out of {max_num_tries:d} attempts total. Waiting '
                        f'for {wait_time:0.1f} seconds...'
                    )
                    time.sleep(wait_time)
                    wait_time *= 2  # Wait for longer next time
            else:
                # Only reached when every attempt failed (no `break`)
                raise ValueError('The worker could not be registered!')

        # Get all agents' IDs
        agents = self.db.find_agents()
        if len(agents) != num_agents:
            raise ValueError(
                f'The actual number of agents is {len(agents):d} instead of the '
                f'desired {num_agents:d}!'
            )
        return [agent.db_id for agent in agents]
class AbstractCrowdsourcingTest(unittest.TestCase):
    """
    Abstract class for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config
    specification and agent registration.
    """

    def setUp(self):
        """
        Reset the handles that the other helpers fill in.
        """
        self.operator = None
        self.data_dir = None

    def tearDown(self):
        """
        Shut down the operator if one was created, then remove the temporary
        directory made by `_set_up_config()`.
        """
        import shutil

        try:
            if self.operator is not None:
                self.operator.shutdown()
        finally:
            # `_set_up_config()` may never have been called, so the attribute
            # may be missing. Clean up even if the shutdown above raised, so
            # that temporary directories do not pile up across test runs.
            data_dir = getattr(self, 'data_dir', None)
            if data_dir is not None:
                shutil.rmtree(data_dir, ignore_errors=True)
                self.data_dir = None

    def _set_up_config(
        self,
        blueprint_type: str,
        task_directory: str,
        overrides: Optional[List[str]] = None,
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory
        to store the test database.

        :param blueprint_type: string uniquely specifying Blueprint class
        :param task_directory: directory containing the `conf/` configuration
            folder. Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, appended after the
            default ones
        """
        # Define the configuration settings
        relative_task_directory = os.path.relpath(
            task_directory, os.path.dirname(__file__)
        )
        relative_config_path = os.path.join(relative_task_directory, 'conf')
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name="example",
                overrides=[
                    f'+mephisto.blueprint._blueprint_type={blueprint_type}',
                    '+mephisto/architect=mock',
                    '+mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ]
                + overrides,
            )
            # TODO: when Hydra 1.1 is released with support for recursive defaults,
            #  don't manually specify all missing blueprint args anymore, but
            #  instead define the blueprint in the defaults list directly.
            #  Currently, the blueprint can't be set in the defaults list without
            #  overriding params in the YAML file, as documented at
            #  https://github.com/facebookresearch/hydra/issues/326 and as fixed in
            #  https://github.com/facebookresearch/hydra/pull/1044.

        # The database lives in a throwaway directory, removed in `tearDown()`
        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional["SharedTaskState"] = None):
        """
        Set up the operator and server.

        :param shared_state: optional shared state passed through to
            `Operator.validate_and_run_config()`
        :raises IndexError: if the operator's supervisor has no channels
        """
        self.operator = Operator(self.db)
        self.operator.validate_and_run_config(
            self.config.mephisto, shared_state=shared_state
        )
        # The server is taken from the first channel the supervisor reports
        channel_info = list(self.operator.supervisor.channels.values())[0]
        self.server = channel_info.job.architect.server

    def _register_mock_agents(self, num_agents: int = 1) -> List[str]:
        """
        Register mock agents for testing, taking the place of crowdsourcing
        workers.

        :param num_agents: the number of agents to register
        :return: the IDs of all agents found in the database after creation
        """
        for idx in range(num_agents):
            # Register the worker
            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            self.server.register_mock_worker(mock_worker_name)
            workers = self.db.find_workers(worker_name=mock_worker_name)
            worker_id = workers[0].db_id

            # Register the agent
            mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
            self.server.register_mock_agent(worker_id, mock_agent_details)

        # Get all agents' IDs
        return [agent.db_id for agent in self.db.find_agents()]
def main():
    """
    Script to crawl through the database for a specific task run and ensure that
    all of the states of units and related MTurk data is synced up.

    Interactively asks for the task run ID, reports completed units that have
    no agent or whose agent timed out, asks the user to pick the relevant
    in-flight HIT type, lists agents that look like candidates for reconciling
    unattributed MTurk assignments, and finally offers to mark the problematic
    units as expired.
    """
    task_run_id = input("Enter task run ID to check integrity of: \n")
    db = LocalMephistoDB()
    task_run = TaskRun(db, task_run_id)
    units = task_run.get_units()

    # Split the completed units by whether an agent is attached to them
    completed_statuses = ("completed", "accepted", "soft_rejected")
    completed_agentless_units = []
    completed_agented_units = []
    for unit in units:
        if unit.get_status() not in completed_statuses:
            continue
        if unit.get_assigned_agent() is None:
            completed_agentless_units.append(unit)
        else:
            completed_agented_units.append(unit)
    completed_timeout_units = [
        u
        for u in completed_agented_units
        if u.get_assigned_agent().get_status() == "timeout"
    ]

    if not completed_agentless_units and not completed_timeout_units:
        print("It appears everything is as should be!")
        return

    print(
        f"Found {len(completed_agentless_units)} completed units without an agent, and "
        f"{len(completed_timeout_units)} completed units with a timed out agent.\n"
        "We'll need to query MTurk HITs to determine where these fall..."
    )
    print(completed_timeout_units[-5:])

    # `input()` returns a string, so the previous run's ID has to be computed
    # through int(); subtracting from the string directly raises a TypeError.
    # NOTE(review): the agents of the preceding task run are included as in the
    # original code; the reason is not visible here - confirm it is still wanted.
    previous_task_run_id = str(int(task_run_id) - 1)
    all_agents = db.find_agents(task_run_id=task_run_id) + db.find_agents(
        task_run_id=previous_task_run_id
    )

    requester = units[0].get_requester()
    client = requester._get_client(requester._requester_name)

    outstanding = get_outstanding_hits(client)
    print(
        f"Found {len(outstanding)} different HIT types in flight for this account. "
        "Select the relevant one below."
    )
    selected_hit_type_id = None
    for hit_type_id, hits in outstanding.items():
        print(f"{hit_type_id}({len(hits)} hits): {hits[0]['Title']}")
        if input("Is this correct?: y/(n) ").lower().startswith("y"):
            selected_hit_type_id = hit_type_id
            break
    if selected_hit_type_id is None:
        # Either there are no HITs in flight or the user rejected every HIT
        # type; do not silently continue with an unconfirmed one.
        print("No HIT type was selected, so there is nothing to reconcile against.")
        return

    task_hits = outstanding[selected_hit_type_id]
    print(f"Querying assignments for the {len(task_hits)} tasks.")
    task_assignments_uf = [
        get_assignments_for_hit(client, h["HITId"]) for h in task_hits
    ]
    # Keep the first assignment of every HIT that has at least one
    task_assignments = [t[0] for t in task_assignments_uf if len(t) != 0]
    print(f"Found {len(task_assignments)} assignments to map.")

    print("Constructing worker-to-agent mapping...")
    worker_id_to_agents = {}
    for a in all_agents:
        worker_id = a.get_worker().worker_name
        worker_id_to_agents.setdefault(worker_id, []).append(a)

    print("Constructing hit-id to unit mapping for completed...")
    hit_ids_to_unit = {}
    for u in units:
        hit_id = u.get_mturk_hit_id()
        if hit_id is not None:
            hit_ids_to_unit[hit_id] = u

    unattributed_assignments = [
        t for t in task_assignments if t["HITId"] not in hit_ids_to_unit
    ]

    print(f"Found {len(unattributed_assignments)} assignments with no mapping!")

    print("Mapping unattributed assignments to workers")

    for assignment in unattributed_assignments:
        worker_id = assignment["WorkerId"]
        worker_agents = worker_id_to_agents.get(worker_id)
        print(f"Worker: {worker_id}. Current agents: {worker_agents}")
        if worker_agents is None:
            continue
        for agent in worker_agents:
            if agent.get_status() != "timeout":
                continue

            # Only agents that are still the one assigned to their unit are
            # reported as candidates
            units_agent = agent.get_unit().get_assigned_agent()
            if units_agent is None or units_agent.db_id != agent.db_id:
                continue

            print(
                f"Agent {agent} would be a good candidate to reconcile {assignment['HITId']}"
            )
            # TODO(WISH) automate the below
            print(
                "You can do this manually by selecting the best candidate, then "
                "updating the MTurk datastore to assign this HITId and assignmentId "
                "to the given agent and its associated unit. You can then either "
                "approve if you can reconcile the agent state, or soft_reject "
                "to pay out properly. "
            )

    do_cleanup = input(
        "If all are reconciled, would you like to clean up remaining timeouts? y/(n)"
    )
    if do_cleanup.lower().startswith("y"):
        for unit in completed_agentless_units:
            unit.set_db_status("expired")
        for unit in completed_timeout_units:
            unit.set_db_status("expired")
class AbstractCrowdsourcingTest:
    """
    Abstract class for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config
    specification and agent registration.
    """

    def _setup(self):
        """
        To be run before a test.

        Seeds the `random`, NumPy and PyTorch generators and resets the handles
        that the other helpers fill in. Should be called in a pytest
        setup/teardown fixture.
        """
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        self.operator = None
        self.data_dir = None

    def _teardown(self):
        """
        To be run after a test.

        Force-shuts down the operator if one was created, then removes the
        temporary directory made by `_set_up_config()`. Should be called in a
        pytest setup/teardown fixture.
        """
        import shutil

        try:
            if self.operator is not None:
                self.operator.force_shutdown()
        finally:
            # `_set_up_config()` may never have been called, so the attribute
            # may be missing. Clean up even if the shutdown above raised, so
            # that temporary directories do not pile up across test runs.
            data_dir = getattr(self, 'data_dir', None)
            if data_dir is not None:
                shutil.rmtree(data_dir, ignore_errors=True)
                self.data_dir = None

    def _set_up_config(
        self,
        blueprint_type: str,
        task_directory: str,
        overrides: Optional[List[str]] = None,
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory
        to store the test database.

        :param blueprint_type: string uniquely specifying Blueprint class
        :param task_directory: directory containing the `conf/` configuration
            folder. Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, appended after the
            default ones
        """
        # Define the configuration settings
        relative_task_directory = os.path.relpath(
            task_directory, os.path.dirname(__file__)
        )
        relative_config_path = os.path.join(relative_task_directory, 'conf')
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name="example",
                overrides=[
                    f'+mephisto.blueprint._blueprint_type={blueprint_type}',
                    '+mephisto/architect=mock',
                    '+mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ]
                + overrides,
            )
            # TODO: when Hydra 1.1 is released with support for recursive defaults,
            #  don't manually specify all missing blueprint args anymore, but
            #  instead define the blueprint in the defaults list directly.
            #  Currently, the blueprint can't be set in the defaults list without
            #  overriding params in the YAML file, as documented at
            #  https://github.com/facebookresearch/hydra/issues/326 and as fixed in
            #  https://github.com/facebookresearch/hydra/pull/1044.

        # The database lives in a throwaway directory, removed in `_teardown()`
        self.data_dir = tempfile.mkdtemp()
        self.database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(self.database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional["SharedTaskState"] = None):
        """
        Set up the operator and server.

        :param shared_state: optional shared state passed through to
            `Operator.validate_and_run_config()`
        """
        self.operator = Operator(self.db)
        self.operator.validate_and_run_config(
            self.config.mephisto, shared_state=shared_state
        )
        self.server = self._get_channel_info().job.architect.server

    def _get_channel_info(self):
        """
        Return channel info for the currently running job.

        :raises ValueError: if the operator's supervisor has no channels
        """
        channels = list(self.operator.supervisor.channels.values())
        if not channels:
            raise ValueError('No channel could be detected!')
        return channels[0]

    def _register_mock_agents(self, num_agents: int = 1) -> List[str]:
        """
        Register mock agents for testing, taking the place of crowdsourcing
        workers.

        Registration of each worker is retried with exponentially growing waits
        whenever the mock server raises an `IndexError`.

        :param num_agents: the number of agents to register
        :return: the IDs of all agents found in the database after creation
        :raises ValueError: if a worker cannot be registered within the allowed
            number of attempts
        """
        max_num_tries = 6
        initial_wait_time = 0.5  # In seconds

        for idx in range(num_agents):
            # Register the worker
            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            num_tries = 0
            wait_time = initial_wait_time
            while num_tries < max_num_tries:
                try:
                    self.server.register_mock_worker(mock_worker_name)
                    break
                except IndexError:
                    num_tries += 1
                    print(
                        f'A subscriber could not be found after {num_tries:d} '
                        f'attempt(s), out of {max_num_tries:d} attempts total. Waiting '
                        f'for {wait_time:0.1f} seconds...'
                    )
                    time.sleep(wait_time)
                    wait_time *= 2  # Wait for longer next time
            else:
                # Only reached when every attempt failed (no `break`)
                raise ValueError('The worker could not be registered!')
            workers = self.db.find_workers(worker_name=mock_worker_name)
            worker_id = workers[0].db_id

            # Register the agent
            mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
            self.server.register_mock_agent(worker_id, mock_agent_details)

        # Get all agents' IDs
        return [agent.db_id for agent in self.db.find_agents()]
class TestSupervisor(unittest.TestCase): """ Unit testing for the Mephisto Supervisor, uses WebsocketChannel and MockArchitect """ def setUp(self): self.data_dir = tempfile.mkdtemp() database_path = os.path.join(self.data_dir, "mephisto.db") self.db = LocalMephistoDB(database_path) self.task_id = self.db.new_task("test_mock", MockBlueprint.BLUEPRINT_TYPE) self.task_run_id = get_test_task_run(self.db) self.task_run = TaskRun(self.db, self.task_run_id) architect_config = OmegaConf.structured( MephistoConfig(architect=MockArchitectArgs( should_run_server=True))) self.architect = MockArchitect(self.db, architect_config, EMPTY_STATE, self.task_run, self.data_dir) self.architect.prepare() self.architect.deploy() self.urls = self.architect._get_socket_urls() # FIXME self.url = self.urls[0] self.provider = MockProvider(self.db) self.provider.setup_resources_for_task_run(self.task_run, self.task_run.args, EMPTY_STATE, self.url) self.launcher = TaskLauncher(self.db, self.task_run, self.get_mock_assignment_data_array()) self.launcher.create_assignments() self.launcher.launch_units(self.url) self.sup = None def tearDown(self): if self.sup is not None: self.sup.shutdown() self.launcher.expire_units() self.architect.cleanup() self.architect.shutdown() self.db.shutdown() shutil.rmtree(self.data_dir, ignore_errors=True) def get_mock_assignment_data_array(self) -> List[InitializationData]: mock_data = MockTaskRunner.get_mock_assignment_data() return [mock_data, mock_data] def test_initialize_supervisor(self): """Ensure that the supervisor object can even be created""" sup = Supervisor(self.db) self.assertIsNotNone(sup) self.assertDictEqual(sup.agents, {}) self.assertDictEqual(sup.channels, {}) sup.shutdown() def test_channel_operations(self): """ Initialize a channel, and ensure the basic startup and shutdown functions are working """ sup = Supervisor(self.db) self.sup = sup TaskRunnerClass = MockBlueprint.TaskRunnerClass args = MockBlueprint.ArgsClass() config = 
OmegaConf.structured(MephistoConfig(blueprint=args)) task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE) test_job = Job( architect=self.architect, task_runner=task_runner, provider=self.provider, qualifications=[], registered_channel_ids=[], ) channels = self.architect.get_channels(sup._on_channel_open, sup._on_catastrophic_disconnect, sup._on_message) channel = channels[0] channel.open() channel_id = channel.channel_id self.assertIsNotNone(channel_id) channel.close() self.assertTrue(channel.is_closed()) def test_register_concurrent_job(self): """Test registering and running a job that requires multiple workers""" # Handle baseline setup sup = Supervisor(self.db) self.sup = sup TaskRunnerClass = MockBlueprint.TaskRunnerClass args = MockBlueprint.ArgsClass() args.timeout_time = 5 args.is_concurrent = False config = OmegaConf.structured(MephistoConfig(blueprint=args)) task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE) sup.register_job(self.architect, task_runner, self.provider) self.assertEqual(len(sup.channels), 1) channel_info = list(sup.channels.values())[0] self.assertIsNotNone(channel_info) self.assertTrue(channel_info.channel.is_alive) channel_id = channel_info.channel_id task_runner = channel_info.job.task_runner self.assertIsNotNone(channel_id) self.assertEqual( len(self.architect.server.subs), 1, "MockServer doesn't see registered channel", ) self.assertIsNotNone( self.architect.server.last_alive_packet, "No alive packet received by server", ) sup.launch_sending_thread() self.assertIsNotNone(sup.sending_thread) # Register a worker mock_worker_name = "MOCK_WORKER" self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker not successfully registered") worker = workers[0] self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, 
"Worker potentially re-registered") worker_id = workers[0].db_id self.assertEqual(len(task_runner.running_assignments), 0) # Register an agent mock_agent_details = "FAKE_ASSIGNMENT" self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 1, "Agent was not created properly") self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 1, "Agent may have been duplicated") agent = agents[0] self.assertIsNotNone(agent) self.assertEqual(len(sup.agents), 1, "Agent not registered with supervisor") self.assertEqual(len(task_runner.running_units), 1, "Ready task was not launched") # Register another worker mock_worker_name = "MOCK_WORKER_2" self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) worker_id = workers[0].db_id # Register an agent mock_agent_details = "FAKE_ASSIGNMENT_2" self.architect.server.register_mock_agent(worker_id, mock_agent_details) self.assertEqual(len(task_runner.running_units), 2, "Tasks were not launched") agents = [a.agent for a in sup.agents.values()] # Make both agents act agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id agent_1_data = agents[0].datastore.agent_data[agent_id_1] agent_2_data = agents[1].datastore.agent_data[agent_id_2] self.architect.server.send_agent_act(agent_id_1, {"text": "message1"}) self.architect.server.send_agent_act(agent_id_2, {"text": "message2"}) # Give up to 1 seconds for the actual operations to occur start_time = time.time() TIMEOUT_TIME = 1 while time.time() - start_time < TIMEOUT_TIME: if len(agent_1_data["acts"]) > 0: break time.sleep(0.1) self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not process messages in time") # Give up to 1 seconds for the task to complete afterwards start_time = time.time() TIMEOUT_TIME = 1 while time.time() - start_time < TIMEOUT_TIME: if 
len(task_runner.running_units) == 0: break time.sleep(0.1) self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not complete task in time") # Give up to 1 seconds for all messages to propogate start_time = time.time() TIMEOUT_TIME = 1 while time.time() - start_time < TIMEOUT_TIME: if self.architect.server.actions_observed == 2: break time.sleep(0.1) self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Not all actions observed in time") sup.shutdown() self.assertTrue(channel_info.channel.is_closed) def test_register_job(self): """Test registering and running a job run asynchronously""" # Handle baseline setup sup = Supervisor(self.db) self.sup = sup TaskRunnerClass = MockBlueprint.TaskRunnerClass args = MockBlueprint.ArgsClass() args.timeout_time = 5 config = OmegaConf.structured(MephistoConfig(blueprint=args)) task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE) sup.register_job(self.architect, task_runner, self.provider) self.assertEqual(len(sup.channels), 1) channel_info = list(sup.channels.values())[0] self.assertIsNotNone(channel_info) self.assertTrue(channel_info.channel.is_alive()) channel_id = channel_info.channel_id task_runner = channel_info.job.task_runner self.assertIsNotNone(channel_id) self.assertEqual( len(self.architect.server.subs), 1, "MockServer doesn't see registered channel", ) self.assertIsNotNone( self.architect.server.last_alive_packet, "No alive packet received by server", ) sup.launch_sending_thread() self.assertIsNotNone(sup.sending_thread) # Register a worker mock_worker_name = "MOCK_WORKER" self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker not successfully registered") worker = workers[0] self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker potentially re-registered") worker_id = 
workers[0].db_id self.assertEqual(len(task_runner.running_assignments), 0) # Register an agent mock_agent_details = "FAKE_ASSIGNMENT" self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 1, "Agent was not created properly") self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 1, "Agent may have been duplicated") agent = agents[0] self.assertIsNotNone(agent) self.assertEqual(len(sup.agents), 1, "Agent not registered with supervisor") self.assertEqual(len(task_runner.running_assignments), 0, "Task was not yet ready") # Register another worker mock_worker_name = "MOCK_WORKER_2" self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) worker_id = workers[0].db_id # Register an agent mock_agent_details = "FAKE_ASSIGNMENT_2" self.architect.server.register_mock_agent(worker_id, mock_agent_details) self.assertEqual(len(task_runner.running_assignments), 1, "Task was not launched") agents = [a.agent for a in sup.agents.values()] # Make both agents act agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id agent_1_data = agents[0].datastore.agent_data[agent_id_1] agent_2_data = agents[1].datastore.agent_data[agent_id_2] self.architect.server.send_agent_act(agent_id_1, {"text": "message1"}) self.architect.server.send_agent_act(agent_id_2, {"text": "message2"}) # Give up to 1 seconds for the actual operation to occur start_time = time.time() TIMEOUT_TIME = 1 while time.time() - start_time < TIMEOUT_TIME: if len(agent_1_data["acts"]) > 0: break time.sleep(0.1) self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not process messages in time") # Give up to 1 seconds for the task to complete afterwards start_time = time.time() TIMEOUT_TIME = 1 while time.time() - start_time < TIMEOUT_TIME: if len(task_runner.running_assignments) == 0: break 
time.sleep(0.1) self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not complete task in time") # Give up to 1 seconds for all messages to propogate start_time = time.time() TIMEOUT_TIME = 1 while time.time() - start_time < TIMEOUT_TIME: if self.architect.server.actions_observed == 2: break time.sleep(0.1) self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Not all actions observed in time") sup.shutdown() self.assertTrue(channel_info.channel.is_closed()) def test_register_concurrent_job_with_onboarding(self): """Test registering and running a job with onboarding""" # Handle baseline setup sup = Supervisor(self.db) self.sup = sup TEST_QUALIFICATION_NAME = "test_onboarding_qualification" task_run_args = self.task_run.args task_run_args.blueprint.use_onboarding = True task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME task_run_args.blueprint.timeout_time = 5 task_run_args.blueprint.is_concurrent = True self.task_run.get_task_config() # Supervisor expects that blueprint setup has already occurred blueprint = self.task_run.get_blueprint() TaskRunnerClass = MockBlueprint.TaskRunnerClass task_runner = TaskRunnerClass(self.task_run, task_run_args, EMPTY_STATE) sup.register_job(self.architect, task_runner, self.provider) self.assertEqual(len(sup.channels), 1) channel_info = list(sup.channels.values())[0] self.assertIsNotNone(channel_info) self.assertTrue(channel_info.channel.is_alive()) channel_id = channel_info.channel_id task_runner = channel_info.job.task_runner self.assertIsNotNone(channel_id) self.assertEqual( len(self.architect.server.subs), 1, "MockServer doesn't see registered channel", ) self.assertIsNotNone( self.architect.server.last_alive_packet, "No alive packet received by server", ) sup.launch_sending_thread() self.assertIsNotNone(sup.sending_thread) self.assertEqual(len(task_runner.running_units), 0) # Fail to register an agent who fails onboarding mock_worker_name = "BAD_WORKER" 
self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker not successfully registered") worker_0 = workers[0] self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker potentially re-registered") worker_id = workers[0].db_id mock_agent_details = "FAKE_ASSIGNMENT" self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding") onboard_agents = self.db.find_onboarding_agents() self.assertEqual(len(onboard_agents), 1, "Onboarding agent should have been created") time.sleep(0.1) last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered") self.architect.server.last_packet = None # Submit onboarding from the agent onboard_data = {"should_pass": False} self.architect.server.register_mock_agent_after_onboarding( worker_id, onboard_agents[0].get_agent_id(), onboard_data) agents = self.db.find_agents() self.assertEqual(len(agents), 0, "Failed agent created after onboarding") # Re-register as if refreshing self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 0, "Failed agent created after onboarding") self.assertEqual(len(sup.agents), 0, "Failed agent registered with supervisor") self.assertEqual( len(task_runner.running_units), 0, "Task should not launch with failed worker", ) # Register a worker mock_worker_name = "MOCK_WORKER" self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker not successfully registered") worker_1 = workers[0] 
self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) self.assertEqual(len(workers), 1, "Worker potentially re-registered") worker_id = workers[0].db_id self.assertEqual(len(task_runner.running_assignments), 0) # Fail to register a blocked agent mock_agent_details = "FAKE_ASSIGNMENT" qualification_id = blueprint.onboarding_qualification_id self.db.grant_qualification(qualification_id, worker_1.db_id, 0) self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 0, "Agent should not be created yet, failed onboarding") time.sleep(0.1) last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) self.assertNotIn( "onboard_data", last_packet["data"], "Onboarding triggered for disqualified worker", ) self.assertIsNone(last_packet["data"]["agent_id"], "worker assigned real agent id") self.architect.server.last_packet = None self.db.revoke_qualification(qualification_id, worker_id) # Register an onboarding agent successfully mock_agent_details = "FAKE_ASSIGNMENT" self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding") onboard_agents = self.db.find_onboarding_agents() self.assertEqual(len(onboard_agents), 2, "Onboarding agent should have been created") time.sleep(0.1) last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered") self.architect.server.last_packet = None # Submit onboarding from the agent onboard_data = {"should_pass": True} self.architect.server.register_mock_agent_after_onboarding( worker_id, onboard_agents[1].get_agent_id(), onboard_data) agents = self.db.find_agents() self.assertEqual(len(agents), 1, "Agent not created after onboarding") # Re-register as if 
refreshing self.architect.server.register_mock_agent(worker_id, mock_agent_details) agents = self.db.find_agents() self.assertEqual(len(agents), 1, "Agent may have been duplicated") agent = agents[0] self.assertIsNotNone(agent) self.assertEqual(len(sup.agents), 1, "Agent not registered with supervisor") self.assertEqual( len(task_runner.running_assignments), 0, "Task was not yet ready, should not launch", ) # Register another worker mock_worker_name = "MOCK_WORKER_2" self.architect.server.register_mock_worker(mock_worker_name) workers = self.db.find_workers(worker_name=mock_worker_name) worker_2 = workers[0] worker_id = worker_2.db_id # Register an agent that is already qualified mock_agent_details = "FAKE_ASSIGNMENT_2" self.db.grant_qualification(qualification_id, worker_2.db_id, 1) self.architect.server.register_mock_agent(worker_id, mock_agent_details) time.sleep(0.1) last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) self.assertNotIn( "onboard_data", last_packet["data"], "Onboarding triggered for qualified agent", ) agents = self.db.find_agents() self.assertEqual(len(agents), 2, "Second agent not created without onboarding") self.assertEqual(len(task_runner.running_assignments), 1, "Task was not launched") self.assertFalse(worker_0.is_qualified(TEST_QUALIFICATION_NAME)) self.assertTrue(worker_0.is_disqualified(TEST_QUALIFICATION_NAME)) self.assertTrue(worker_1.is_qualified(TEST_QUALIFICATION_NAME)) self.assertFalse(worker_1.is_disqualified(TEST_QUALIFICATION_NAME)) self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME)) self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME)) agents = [a.agent for a in sup.agents.values()] # Make both agents act agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id agent_1_data = agents[0].datastore.agent_data[agent_id_1] agent_2_data = agents[1].datastore.agent_data[agent_id_2] self.architect.server.send_agent_act(agent_id_1, {"text": "message1"}) 
# (Closing statements of the preceding test method; its beginning lies above
# this range. Code is unchanged.)
self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})
# Give up to 1 second for the actual operation to occur
start_time = time.time()
TIMEOUT_TIME = 1
while time.time() - start_time < TIMEOUT_TIME:
    if len(agent_1_data["acts"]) > 0:
        break
    time.sleep(0.1)
self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not process messages in time")
# Give up to 1 second for the task to complete afterwards
start_time = time.time()
TIMEOUT_TIME = 1
while time.time() - start_time < TIMEOUT_TIME:
    if len(task_runner.running_assignments) == 0:
        break
    time.sleep(0.1)
self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not complete task in time")
# Give up to 1 second for all messages to propagate
start_time = time.time()
TIMEOUT_TIME = 1
while time.time() - start_time < TIMEOUT_TIME:
    if self.architect.server.actions_observed == 2:
        break
    time.sleep(0.1)
self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Not all actions observed in time")
sup.shutdown()
self.assertTrue(channel_info.channel.is_closed())


def test_register_job_with_onboarding(self):
    """
    Test registering and running a non-concurrent job with onboarding.

    Three mock workers are walked through the registration flow:

    1. ``MOCK_WORKER`` is first granted the onboarding qualification with
       value 0 and is refused both an agent and onboarding. After that
       qualification is revoked it is sent to onboarding, submits
       ``{"should_pass": False}`` and still receives no agent.
    2. ``MOCK_WORKER_2`` is granted the qualification with value 1 up front
       and receives an agent without any onboarding packet.
    3. ``MOCK_WORKER_3`` has no qualification, is sent to onboarding,
       submits ``{"should_pass": True}`` and receives an agent.

    Both resulting agents then act; the test waits for the acts to be
    recorded, for the runner's units to finish and for the mock server to
    observe two actions, then shuts the supervisor down and checks that the
    channel is closed.
    """
    # Handle baseline setup
    sup = Supervisor(self.db)
    # Also stored on the test case; nothing in this method reads self.sup.
    # NOTE(review): presumably used by the fixture teardown - confirm.
    self.sup = sup
    TEST_QUALIFICATION_NAME = "test_onboarding_qualification"
    # Register onboarding arguments for blueprint
    task_run_args = self.task_run.args
    task_run_args.blueprint.use_onboarding = True
    task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
    task_run_args.blueprint.timeout_time = 5
    task_run_args.blueprint.is_concurrent = False
    self.task_run.get_task_config()
    # Supervisor expects that blueprint setup has already occurred
    blueprint = self.task_run.get_blueprint()
    TaskRunnerClass = MockBlueprint.TaskRunnerClass
    task_runner = TaskRunnerClass(self.task_run, task_run_args, EMPTY_STATE)
    sup.register_job(self.architect, task_runner, self.provider)
    # Registering the job should have produced exactly one live channel
    self.assertEqual(len(sup.channels), 1)
    channel_info = list(sup.channels.values())[0]
    self.assertIsNotNone(channel_info)
    self.assertTrue(channel_info.channel.is_alive())
    channel_id = channel_info.channel_id
    # From here on, inspect the task runner held by the registered job
    task_runner = channel_info.job.task_runner
    self.assertIsNotNone(channel_id)
    self.assertEqual(
        len(self.architect.server.subs),
        1,
        "MockServer doesn't see registered channel",
    )
    self.assertIsNotNone(
        self.architect.server.last_alive_packet,
        "No alive packet received by server",
    )
    sup.launch_sending_thread()
    self.assertIsNotNone(sup.sending_thread)
    # Register a worker
    mock_worker_name = "MOCK_WORKER"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    self.assertEqual(len(workers), 1, "Worker not successfully registered")
    worker_1 = workers[0]
    # Registering the same name again must not create a second worker row
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    self.assertEqual(len(workers), 1, "Worker potentially re-registered")
    worker_id = workers[0].db_id
    self.assertEqual(len(task_runner.running_units), 0)
    # Fail to register a blocked agent: the worker holds the onboarding
    # qualification with value 0 (the assertions below treat this worker as
    # disqualified)
    mock_agent_details = "FAKE_ASSIGNMENT"
    qualification_id = blueprint.onboarding_qualification_id
    self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Agent should not be created yet, failed onboarding")
    # NOTE(review): short pause before reading last_packet - presumably lets
    # the response reach the mock server; confirm.
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertNotIn(
        "onboard_data",
        last_packet["data"],
        "Onboarding triggered for disqualified worker",
    )
    self.assertIsNone(last_packet["data"]["agent_id"], "worker assigned real agent id")
    # Clear the recorded packet so the next check cannot see a stale one
    self.architect.server.last_packet = None
    self.db.revoke_qualification(qualification_id, worker_id)
    # Register the same worker again now that the qualification is revoked:
    # it should be sent to onboarding rather than given an agent
    mock_agent_details = "FAKE_ASSIGNMENT"
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding")
    onboard_agents = self.db.find_onboarding_agents()
    self.assertEqual(len(onboard_agents), 1, "Onboarding agent should have been created")
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered")
    self.architect.server.last_packet = None
    # Submit onboarding from the agent, with data marked as not passing
    onboard_data = {"should_pass": False}
    self.architect.server.register_mock_agent_after_onboarding(
        worker_id, onboard_agents[0].get_agent_id(), onboard_data)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Failed agent created after onboarding")
    # Re-register as if refreshing; the failed worker must still get nothing
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Failed agent created after onboarding")
    self.assertEqual(len(sup.agents), 0, "Failed agent registered with supervisor")
    self.assertEqual(
        len(task_runner.running_units),
        0,
        "Task should not launch with failed worker",
    )
    # Register another worker
    mock_worker_name = "MOCK_WORKER_2"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    worker_2 = workers[0]
    worker_id = worker_2.db_id
    # Register an agent that is already qualified (qualification value 1)
    mock_agent_details = "FAKE_ASSIGNMENT_2"
    self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertNotIn(
        "onboard_data",
        last_packet["data"],
        "Onboarding triggered for qualified agent",
    )
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 1, "Second agent not created without onboarding")
    self.assertEqual(len(task_runner.running_units), 1, "Tasks were not launched")
    # Worker 1 failed onboarding, worker 2 was qualified directly
    self.assertFalse(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
    self.assertTrue(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
    self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
    self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))
    # Register another worker, this one with no qualification at all
    mock_worker_name = "MOCK_WORKER_3"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    worker_3 = workers[0]
    worker_id = worker_3.db_id
    mock_agent_details = "FAKE_ASSIGNMENT_3"
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    # Still only worker 2's agent; worker 3 has to onboard first
    self.assertEqual(len(agents), 1, "Agent should not be created yet - need onboarding")
    onboard_agents = self.db.find_onboarding_agents()
    self.assertEqual(len(onboard_agents), 2, "Onboarding agent should have been created")
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered")
    self.architect.server.last_packet = None
    # Submit onboarding from the agent, this time with passing data
    onboard_data = {"should_pass": True}
    self.architect.server.register_mock_agent_after_onboarding(
        worker_id, onboard_agents[1].get_agent_id(), onboard_data)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 2, "Agent not created after onboarding")
    # Re-register as if refreshing; no duplicate agent may appear
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 2, "Duplicate agent created after onboarding")
    agent = agents[1]
    self.assertIsNotNone(agent)
    self.assertEqual(len(sup.agents), 2, "Agent not registered supervisor after onboarding")
    self.assertEqual(len(task_runner.running_units), 2, "Task not launched after onboarding")
    # Switch from DB records to the agent objects tracked by the supervisor
    agents = [a.agent for a in sup.agents.values()]
    # Make both agents act
    agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
    agent_1_data = agents[0].datastore.agent_data[agent_id_1]
    # NOTE(review): agent_2_data is never read after this lookup
    agent_2_data = agents[1].datastore.agent_data[agent_id_2]
    self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
    self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})
    # Give up to 1 second for the actual operation to occur
    # (only the first agent's recorded acts are polled)
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if len(agent_1_data["acts"]) > 0:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not process messages in time")
    # Give up to 1 second for the task to complete afterwards
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if len(task_runner.running_units) == 0:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not complete task in time")
    # Give up to 1 second for all messages to propagate
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if self.architect.server.actions_observed == 2:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Not all actions observed in time")
    # Shutting the supervisor down is expected to close its channel
    sup.shutdown()
    self.assertTrue(channel_info.channel.is_closed())
class TestOperator(unittest.TestCase):
    """
    Unit testing for the Mephisto Operator.

    Each test creates a fresh ``LocalMephistoDB`` inside a temporary
    directory, launches mock task runs through an ``Operator`` and drives
    mock workers and agents through the mock architect's server.
    """

    def setUp(self):
        """Create the temporary database and a test requester; no operator yet."""
        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(database_path)
        self.requester_name, _req_id = get_test_requester(self.db)
        self.operator = None

    def tearDown(self):
        """
        Shut down the operator and database, remove the temporary directory
        and verify that only one thread is still alive.
        """
        # Nested try/finally so a failing operator or DB shutdown can neither
        # leak the temporary directory nor skip the remaining cleanup steps.
        try:
            if self.operator is not None:
                self.operator.shutdown()
        finally:
            try:
                self.db.shutdown()
            finally:
                shutil.rmtree(self.data_dir, ignore_errors=True)
        live_threads = threading.enumerate()
        self.assertEqual(
            len(live_threads),
            1,
            f"Expected only main thread at teardown, found {live_threads}",
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _poll_until(self, condition, timeout: float, interval: float = 0.1) -> bool:
        """
        Poll ``condition`` (a zero-argument callable) until it returns a truthy
        value or ``timeout`` seconds have elapsed.

        Returns True when the condition was met and False on timeout. The
        condition is always evaluated at least once, and once more after the
        final sleep, so success is decided by the condition itself rather than
        by a second reading of the clock.
        """
        deadline = time.monotonic() + timeout
        while not condition():
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
        return True

    def _await_all_runs_complete(self):
        """Wait up to TIMEOUT_TIME seconds for the operator to have no running task runs."""
        finished = self._poll_until(
            lambda: len(self.operator.get_running_task_runs()) == 0,
            TIMEOUT_TIME,
        )
        self.assertTrue(finished, "Task not completed in time")

    def _launch_run(self, config):
        """
        Launch ``config`` through the operator and return the
        ``(task_run_id, tracked_run)`` pair of the single running task run.
        """
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        return next(iter(tracked_runs.items()))

    def _assert_run_fully_tracked(self, task_run_id, tracked_run):
        """Check that every component of a tracked run is populated and ids match."""
        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

    def _register_worker(self, architect, worker_name: str):
        """Register a mock worker on the architect's server and return its db id."""
        architect.server.register_mock_worker(worker_name)
        workers = self.db.find_workers(worker_name=worker_name)
        # Fail with a readable message instead of an IndexError below.
        self.assertTrue(workers, f"Worker {worker_name} was not registered")
        return workers[0].db_id

    def _register_agent(self, architect, worker_id, agent_details: str):
        """
        Request a mock agent for ``worker_id`` and return every agent currently
        stored in the database, so callers can assert on the resulting count.
        """
        architect.server.register_mock_agent(worker_id, agent_details)
        return self.db.find_agents()

    def _restricted_config(self, num_assignments: int):
        """
        Build a concurrent mock config whose task allows one concurrent unit
        and at most two units per worker, under task name "max-unit-test".
        """
        return MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=num_assignments, is_concurrent=True),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name="max-unit-test",
            ),
        )

    def _run_single_assignment_job(self, is_concurrent: bool):
        """
        Launch a one-assignment mock task, register two workers with one agent
        request each, and check that the first assignment ends up COMPLETED.
        """
        self.operator = Operator(self.db)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=1, is_concurrent=is_concurrent),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=MOCK_TASK_ARGS,
        )
        task_run_id, tracked_run = self._launch_run(config)
        self._assert_run_fully_tracked(task_run_id, tracked_run)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect, "Must use mock in testing")

        # Register a worker
        worker_id = self._register_worker(architect, "MOCK_WORKER")
        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        agents = self._register_agent(architect, worker_id, "FAKE_ASSIGNMENT")
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        self.assertIsNotNone(agents[0])

        # Register another worker and request an agent for it; the resulting
        # agent count is intentionally not asserted here.
        worker_id = self._register_worker(architect, "MOCK_WORKER_2")
        architect.server.register_mock_agent(worker_id, "FAKE_ASSIGNMENT_2")

        # Give up to TIMEOUT_TIME seconds for the whole mock task to complete
        self._await_all_runs_complete()

        # Ensure the assignment is completed
        assignment = tracked_run.task_run.get_assignments()[0]
        self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)

    # ------------------------------------------------------------------
    # Public wait helpers
    # ------------------------------------------------------------------

    def wait_for_complete_assignment(self, assignment, timeout: int):
        """Fail unless ``assignment`` reaches COMPLETED within ``timeout`` seconds."""
        completed = self._poll_until(
            lambda: assignment.get_status() == AssignmentState.COMPLETED,
            timeout,
        )
        self.assertTrue(completed, "Assignment not completed in time")

    def await_server_start(self, architect: "MockArchitect", timeout: int = 5):
        """
        Fail unless the architect's mock server has at least one subscriber
        within ``timeout`` seconds (5 by default).
        """
        server = architect.server
        self.assertIsNotNone(server, "Cannot wait on empty server")
        started = self._poll_until(lambda: len(server.subs) > 0, timeout)
        self.assertTrue(started, "Mock server not up in time")

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_initialize_supervisor(self):
        """Quick test to ensure that the operator can be initialized"""
        self.operator = Operator(self.db)

    def test_run_job_concurrent(self):
        """Run a single-assignment mock task with is_concurrent=True to completion"""
        self._run_single_assignment_job(is_concurrent=True)

    def test_run_job_not_concurrent(self):
        """Run a single-assignment mock task with is_concurrent=False to completion"""
        self._run_single_assignment_job(is_concurrent=False)

    def test_run_jobs_with_restrictions(self):
        """Ensure allowed_concurrent and maximum_units_per_worker work"""
        self.operator = Operator(self.db)
        task_run_id, tracked_run = self._launch_run(self._restricted_config(num_assignments=3))
        self._assert_run_fully_tracked(task_run_id, tracked_run)
        self.await_server_start(tracked_run.architect)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect, "Must use mock in testing")

        # Register a worker
        worker_id_1 = self._register_worker(architect, "MOCK_WORKER")
        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        agents = self._register_agent(architect, worker_id_1, "FAKE_ASSIGNMENT")
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        self.assertIsNotNone(agents[0])

        # Try to register a second agent, which should fail due to concurrency
        agents = self._register_agent(architect, worker_id_1, "FAKE_ASSIGNMENT_2")
        self.assertEqual(len(agents), 1, "Second agent was created")

        # Register another worker
        worker_id_2 = self._register_worker(architect, "MOCK_WORKER_2")

        # Register an agent
        agents = self._register_agent(architect, worker_id_2, "FAKE_ASSIGNMENT_2")
        self.assertEqual(len(agents), 2, "Second agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(agents[1].get_unit().get_assignment(), 3)

        # Pass a second task as well
        agents = self._register_agent(architect, worker_id_1, "FAKE_ASSIGNMENT_3")
        self.assertEqual(len(agents), 3, "Agent was not created properly")
        agents = self._register_agent(architect, worker_id_2, "FAKE_ASSIGNMENT_4")
        self.assertEqual(len(agents), 4, "Fourth agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(agents[3].get_unit().get_assignment(), 3)

        # Both workers should have saturated their tasks, and not be granted agents
        agents = self._register_agent(architect, worker_id_1, "FAKE_ASSIGNMENT_5")
        self.assertEqual(len(agents), 4, "Additional agent was created")
        agents = self._register_agent(architect, worker_id_2, "FAKE_ASSIGNMENT_5")
        self.assertEqual(len(agents), 4, "Additional agent was created")

        # new workers should be able to work on these just fine though
        worker_id_3 = self._register_worker(architect, "MOCK_WORKER_3")
        worker_id_4 = self._register_worker(architect, "MOCK_WORKER_4")

        # Register agents from new workers
        agents = self._register_agent(architect, worker_id_3, "FAKE_ASSIGNMENT_5")
        self.assertEqual(len(agents), 5, "Additional agent was not created")
        agents = self._register_agent(architect, worker_id_4, "FAKE_ASSIGNMENT_6")
        self.assertEqual(len(agents), 6, "Additional agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(agents[5].get_unit().get_assignment(), 3)

        # Give up to TIMEOUT_TIME seconds for the whole mock task to complete
        self._await_all_runs_complete()

        # Ensure all assignments are completed
        for assignment in tracked_run.task_run.get_assignments():
            self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)

        # Create a new task under the same task name and restrictions
        task_run_id, tracked_run = self._launch_run(self._restricted_config(num_assignments=1))
        self.await_server_start(tracked_run.architect)
        architect = tracked_run.architect

        # Workers one and two still shouldn't be able to make agents
        agents = self._register_agent(architect, worker_id_1, "FAKE_ASSIGNMENT_7")
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )
        agents = self._register_agent(architect, worker_id_2, "FAKE_ASSIGNMENT_7")
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )

        # Three and four should though
        agents = self._register_agent(architect, worker_id_3, "FAKE_ASSIGNMENT_7")
        self.assertEqual(len(agents), 7, "Additional agent was not created")
        agents = self._register_agent(architect, worker_id_4, "FAKE_ASSIGNMENT_8")
        self.assertEqual(len(agents), 8, "Additional agent was not created")

        # Ensure the task run completed and that all assignments are done
        self._await_all_runs_complete()
        for assignment in tracked_run.task_run.get_assignments():
            self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)