def test_initialize_supervisor(self):
    """Check that a Supervisor can be built and starts with nothing registered."""
    supervisor = Supervisor(self.db)
    self.assertIsNotNone(supervisor)
    # A brand-new supervisor should know about no agents and no channels
    self.assertDictEqual(supervisor.agents, {})
    self.assertDictEqual(supervisor.channels, {})
    supervisor.shutdown()
def __init__(self, db: "MephistoDB"):
    """
    Bind this object to ``db`` and start its background tracking thread.

    Creates a Supervisor over the same database, an empty mapping of
    tracked runs, and a thread (named "Operator-tracking-thread") that
    runs ``self._track_and_kill_runs`` until ``is_shutdown`` is set.
    """
    self.db = db
    self.supervisor = Supervisor(db)
    self._task_runs_tracked: Dict[str, TrackedRun] = {}
    self.is_shutdown = False
    tracker = threading.Thread(
        target=self._track_and_kill_runs,
        name="Operator-tracking-thread",
    )
    self._run_tracker_thread = tracker
    tracker.start()
def test_channel_operations(self):
    """
    Open a channel obtained from the architect, check that it is assigned
    a channel id, then close it and check that it reports being closed.
    """
    supervisor = Supervisor(self.db)
    self.sup = supervisor
    runner_cls = MockBlueprint.TaskRunnerClass
    blueprint_args = MockBlueprint.ArgsClass()
    config = OmegaConf.structured(MephistoConfig(blueprint=blueprint_args))
    task_runner = runner_cls(self.task_run, config, EMPTY_STATE)
    # The job itself is not used below; constructing it only checks that a
    # Job can be assembled from these pieces.
    _job = Job(
        architect=self.architect,
        task_runner=task_runner,
        provider=self.provider,
        qualifications=[],
        registered_channel_ids=[],
    )
    channel = self.architect.get_channels(
        supervisor._on_channel_open,
        supervisor._on_catastrophic_disconnect,
        supervisor._on_message,
    )[0]
    channel.open()
    channel_id = channel.channel_id
    self.assertIsNotNone(channel_id)
    channel.close()
    self.assertTrue(channel.is_closed())
class Operator:
    """
    Acting as the controller behind the curtain, the Operator class
    is responsible for managing the knobs, switches, and dials
    of the rest of the Mephisto architecture.

    Most convenience scripts for using Mephisto will use an Operator
    to get the job done, though this class itself is also a
    good model to use to understand how the underlying
    architecture works in order to build custom jobs or workflows.
    """

    def __init__(self, db: "MephistoDB"):
        self.db = db
        self.supervisor = Supervisor(db)
        # Runs launched by this operator that are still live, keyed by the
        # task run's db_id
        self._task_runs_tracked: Dict[str, TrackedRun] = {}
        self.is_shutdown = False
        # Background thread that tears down runs once they are completed
        self._run_tracker_thread = threading.Thread(
            target=self._track_and_kill_runs, name="Operator-tracking-thread"
        )
        self._run_tracker_thread.start()

    @staticmethod
    def _get_baseline_argparser() -> ArgumentParser:
        """Return a parser for the baseline requirements to launch a job"""
        parser = ArgumentParser()
        parser.add_argument(
            "--blueprint-type",
            dest="blueprint_type",
            help="Name of the blueprint to launch",
            required=True,
        )
        parser.add_argument(
            "--architect-type",
            dest="architect_type",
            help="Name of the architect to launch with",
            required=True,
        )
        parser.add_argument(
            "--requester-name",
            dest="requester_name",
            help="Identifier for the requester to launch as",
            required=True,
        )
        return parser

    @staticmethod
    def _parse_args_from_classes(
        BlueprintClass: Type["Blueprint"],
        ArchitectClass: Type["Architect"],
        CrowdProviderClass: Type["CrowdProvider"],
        argument_list: List[str],
    ) -> Tuple[Dict[str, Any], List[str]]:
        """
        Parse the given arguments over the parsers for the given types.

        Returns the recognized arguments as a dict, plus the list of argument
        strings that none of the registered groups recognized.
        """
        # Create the parser
        parser = ArgumentParser()
        blueprint_group = parser.add_argument_group("blueprint")
        BlueprintClass.add_args_to_group(blueprint_group)
        provider_group = parser.add_argument_group("crowd_provider")
        CrowdProviderClass.add_args_to_group(provider_group)
        architect_group = parser.add_argument_group("architect")
        ArchitectClass.add_args_to_group(architect_group)
        task_group = parser.add_argument_group("task_config")
        TaskConfig.add_args_to_group(task_group)

        # Return parsed args
        try:
            known, unknown = parser.parse_known_args(argument_list)
        except SystemExit as err:
            # argparse exits the interpreter on bad input; convert that into a
            # regular exception (keeping the cause) so callers can handle it
            raise Exception("Argparse broke - must fix") from err
        return vars(known), unknown

    def get_running_task_runs(self):
        """Return the currently running task runs and their handlers"""
        return self._task_runs_tracked.copy()

    # TODO(#94) there should be a way to provide default arguments via a config file
    def parse_and_launch_run(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Parse the given arguments and launch a job.

        Resolves the requester, blueprint, architect and crowd provider named
        in ``arg_list`` and merges ``extra_args`` over the parsed task
        arguments. It then validates them, creates a new task run, deploys its
        server and launches its units.

        Returns the new task run's db_id. Raises EntryDoesNotExistException
        if no requester has the given name. If anything fails after the task
        run is created, the architect (if one was constructed) is shut down
        and the error is re-raised.
        """
        if extra_args is None:
            extra_args = {}
        # Extract the abstractions being used
        parser = self._get_baseline_argparser()
        type_args, task_args_string = parser.parse_known_args(arg_list)

        requesters = self.db.find_requesters(requester_name=type_args.requester_name)
        if len(requesters) == 0:
            raise EntryDoesNotExistException(
                f"No requester found with name {type_args.requester_name}"
            )
        requester = requesters[0]
        requester_id = requester.db_id
        provider_type = requester.provider_type

        # Parse the arguments for the abstractions to ensure
        # everything required is set
        BlueprintClass = get_blueprint_from_type(type_args.blueprint_type)
        ArchitectClass = get_architect_from_type(type_args.architect_type)
        CrowdProviderClass = get_crowd_provider_from_type(provider_type)
        task_args, _unknown = self._parse_args_from_classes(
            BlueprintClass, ArchitectClass, CrowdProviderClass, task_args_string
        )

        task_args.update(extra_args)

        # Load the classes to force argument validation before anything
        # is actually created in the database
        # TODO(#94) perhaps parse the arguments for these things one at a time?
        BlueprintClass.assert_task_args(task_args)
        ArchitectClass.assert_task_args(task_args)
        CrowdProviderClass.assert_task_args(task_args)

        # Find an existing task or create a new one
        task_name = task_args.get("task_name")
        if task_name is None:
            task_name = type_args.blueprint_type
            logger.warning(
                f"Task is using the default blueprint name {task_name} as a name, as no task_name is provided"
            )
        tasks = self.db.find_tasks(task_name=task_name)
        task_id = None
        if len(tasks) == 0:
            task_id = self.db.new_task(task_name, type_args.blueprint_type)
        else:
            task_id = tasks[0].db_id

        logger.info(f"Creating a task run under task name: {task_name}")

        # Create a new task run
        new_run_id = self.db.new_task_run(
            task_id,
            requester_id,
            " ".join([shlex.quote(x) for x in task_args_string]),
            provider_type,
            type_args.blueprint_type,
            requester.is_sandbox(),
        )
        task_run = TaskRun(self.db, new_run_id)

        # Bound before the try so the cleanup handler can tell whether an
        # architect was ever constructed (it may fail before that point)
        architect = None
        try:
            # If anything fails after here, we have to cleanup the architect
            build_dir = os.path.join(task_run.get_run_dir(), "build")
            os.makedirs(build_dir, exist_ok=True)
            architect = ArchitectClass(self.db, task_args, task_run, build_dir)

            # Register the blueprint with args to the task run,
            # ensure cached
            blueprint = BlueprintClass(task_run, task_args)
            task_run.get_blueprint(opts=task_args)

            # Setup and deploy the server
            architect.prepare()
            task_url = architect.deploy()

            # TODO(#102) maybe the cleanup (destruction of the server configuration?) should only
            # happen after everything has already been reviewed, this way it's possible to
            # retrieve the exact build directory to review a task for real
            architect.cleanup()

            # Create the backend runner
            task_runner = BlueprintClass.TaskRunnerClass(task_run, task_args)

            # Small hack for auto appending block qualification
            existing_qualifications = task_args.get("qualifications", [])
            if task_args.get("block_qualification") is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        task_args["block_qualification"], QUAL_NOT_EXIST, None
                    )
                )
            if task_args.get("onboarding_qualification") is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        OnboardingRequired.get_failed_qual(
                            task_args["onboarding_qualification"]
                        ),
                        QUAL_NOT_EXIST,
                        None,
                    )
                )
            task_args["qualifications"] = existing_qualifications

            # Register the task with the provider
            provider = CrowdProviderClass(self.db)
            provider.setup_resources_for_task_run(task_run, task_args, task_url)

            initialization_data_array = blueprint.get_initialization_data()

            # Link the job together
            job = self.supervisor.register_job(
                architect, task_runner, provider, existing_qualifications
            )
            if self.supervisor.sending_thread is None:
                self.supervisor.launch_sending_thread()
        except (KeyboardInterrupt, Exception) as e:
            logger.error(
                "Encountered error while launching run, shutting down", exc_info=True
            )
            if architect is not None:
                try:
                    architect.shutdown()
                except (KeyboardInterrupt, Exception) as architect_exception:
                    logger.exception(
                        f"Could not shut down architect: {architect_exception}",
                        exc_info=True,
                    )
            raise e

        launcher = TaskLauncher(self.db, task_run, initialization_data_array)
        launcher.create_assignments()
        launcher.launch_units(task_url)

        self._task_runs_tracked[task_run.db_id] = TrackedRun(
            task_run=task_run,
            task_launcher=launcher,
            task_runner=task_runner,
            architect=architect,
            job=job,
        )
        return task_run.db_id

    def _track_and_kill_runs(self):
        """
        Background thread that shuts down servers when a task is fully done.
        """
        while not self.is_shutdown:
            # Snapshot, since entries are deleted from the dict inside the loop
            runs_to_check = list(self._task_runs_tracked.values())
            for tracked_run in runs_to_check:
                task_run = tracked_run.task_run
                if task_run.get_is_completed():
                    self.supervisor.shutdown_job(tracked_run.job)
                    tracked_run.architect.shutdown()
                    tracked_run.task_launcher.shutdown()
                    del self._task_runs_tracked[task_run.db_id]
            time.sleep(2)

    def shutdown(self, skip_input=True):
        """
        Shut down the operator and every run it is still tracking.

        Sets ``is_shutdown`` (ending the tracking thread's loop), then shuts
        down each tracked run's launcher and expires its units. It then polls
        every 30 seconds until each run reports completion, shutting down that
        run's architect as it does. A KeyboardInterrupt/SystemExit while
        waiting skips the wait and shuts the remaining architects down
        immediately. The supervisor is always shut down and the tracking
        thread joined at the end. ``skip_input`` is currently unused.
        """
        logger.info("operator shutting down")
        self.is_shutdown = True
        # Iterate over snapshots: until the tracking thread observes
        # is_shutdown it may still delete completed runs from the dict, and
        # iterating a live dict view during that raises RuntimeError.
        for tracked_run in list(self._task_runs_tracked.values()):
            logger.info("expiring units")
            tracked_run.task_launcher.shutdown()
            tracked_run.task_launcher.expire_units()
        # Bound before the try so the interrupt handler can always use it
        remaining_runs = list(self._task_runs_tracked.values())
        try:
            while len(remaining_runs) > 0:
                next_runs = []
                for tracked_run in remaining_runs:
                    if tracked_run.task_run.get_is_completed():
                        tracked_run.architect.shutdown()
                    else:
                        next_runs.append(tracked_run)
                if len(next_runs) > 0:
                    logger.info(
                        f"Waiting on {len(next_runs)} task runs, Ctrl-C ONCE to FORCE QUIT"
                    )
                    time.sleep(30)
                remaining_runs = next_runs
        except Exception as e:
            logger.exception(
                f"Encountered problem during shutting down {e}", exc_info=True
            )
            import traceback

            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.info(
                "Skipping waiting for outstanding task completions, shutting down servers now!"
            )
            for tracked_run in remaining_runs:
                tracked_run.architect.shutdown()
        finally:
            self.supervisor.shutdown()
            self._run_tracker_thread.join()

    def parse_and_launch_run_wrapper(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Wrapper around parse and launch run that prints errors on failure, rather
        than throwing. Generally for use in scripts.
        """
        try:
            return self.parse_and_launch_run(arg_list=arg_list, extra_args=extra_args)
        except (KeyboardInterrupt, Exception) as e:
            logger.error("Ran into error while launching run: ", exc_info=True)
            return None

    def print_run_details(self):
        """Print details about running tasks"""
        # TODO(#93) parse these tasks and get the full details
        for task in self.get_running_task_runs():
            logger.info(f"Operator running task ID = {task}")

    def wait_for_runs_then_shutdown(
        self, skip_input=False, log_rate: Optional[int] = None
    ) -> None:
        """
        Wait for task_runs to complete, and then shutdown.

        Set log_rate to get print statements of currently running tasks
        at the specified interval.

        If an exception occurs while waiting, it is re-raised when
        ``skip_input`` is True. Otherwise the user is asked whether to shut
        down. In every case ``self.shutdown()`` runs at the end.
        """
        # Imported up front: a function-level import further down would make
        # `traceback` local to the whole function, so using it in the inner
        # handler would raise UnboundLocalError instead of printing.
        import traceback

        try:
            try:
                last_log = 0.0
                while len(self.get_running_task_runs()) > 0:
                    if log_rate is not None:
                        if time.time() - last_log > log_rate:
                            last_log = time.time()
                            self.print_run_details()
                    time.sleep(10)

            except Exception as e:
                if skip_input:
                    raise e

                traceback.print_exc()
                should_quit = input(
                    "The above exception happened while running a task, do "
                    "you want to shut down? (y)/n: "
                )
                # NOTE(review): answering "n" only avoids re-raising; the wait
                # loop has already exited, so the shutdown in `finally` still
                # runs. Resuming the wait may have been the intent.
                if should_quit not in ["n", "N", "no", "No"]:
                    raise e

        except Exception as e:
            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.exception(
                "Cleaning up after keyboard interrupt, please wait!", exc_info=True
            )
        finally:
            self.shutdown()
class Operator:
    """
    Acting as the controller behind the curtain, the Operator class
    is responsible for managing the knobs, switches, and dials
    of the rest of the Mephisto architecture.

    Most convenience scripts for using Mephisto will use an Operator
    to get the job done, though this class itself is also a
    good model to use to understand how the underlying
    architecture works in order to build custom jobs or workflows.
    """

    def __init__(self, db: "MephistoDB"):
        self.db = db
        self.supervisor = Supervisor(db)
        # Runs launched by this operator that are still live, keyed by the
        # task run's db_id
        self._task_runs_tracked: Dict[str, TrackedRun] = {}
        self.is_shutdown = False
        # Background thread that tears down runs once they are completed
        self._run_tracker_thread = threading.Thread(
            target=self._track_and_kill_runs, name="Operator-tracking-thread"
        )
        self._run_tracker_thread.start()

    @staticmethod
    def _get_baseline_argparser() -> ArgumentParser:
        """Return a parser for the baseline requirements to launch a job"""
        parser = ArgumentParser()
        parser.add_argument(
            "--blueprint-type",
            dest="blueprint_type",
            help="Name of the blueprint to launch",
            required=True,
        )
        parser.add_argument(
            "--architect-type",
            dest="architect_type",
            help="Name of the architect to launch with",
            required=True,
        )
        parser.add_argument(
            "--requester-name",
            dest="requester_name",
            help="Identifier for the requester to launch as",
            required=True,
        )
        return parser

    def get_running_task_runs(self):
        """Return the currently running task runs and their handlers"""
        return self._task_runs_tracked.copy()

    def parse_and_launch_run(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Deprecated: argument parsing has moved to Hydra.

        Always raises. Use ``validate_and_run_config`` (or
        ``validate_and_run_config_or_die``) with a Hydra config instead.
        """
        raise Exception(
            "Operator.parse_and_launch_run has been deprecated in favor "
            "of using Hydra for argument configuration. See the docs at "
            "https://github.com/facebookresearch/Mephisto/blob/master/docs/hydra_migration.md "
            "in order to upgrade."
        )

    def validate_and_run_config_or_die(
        self,
        run_config: DictConfig,
        shared_state: Optional[SharedTaskState] = None,
    ) -> str:
        """
        Validate the given run config and launch a job from it.

        Resolves the requester named in ``run_config.provider``. If none
        exists, a mock requester is used when the name is "MOCK_REQUESTER".
        The config is checked against the blueprint, architect and crowd
        provider classes. A new task run is then created, its server deployed
        and its units launched. Returns the new task run's db_id.

        Raises EntryDoesNotExistException if the requester cannot be found.
        An ``assert`` checks that the requester's provider type matches the
        configured one. If anything fails after the task run is created, the
        architect (if one was constructed) is shut down and the error is
        re-raised.
        """
        if shared_state is None:
            shared_state = SharedTaskState()

        # First try to find the requester:
        requester_name = run_config.provider.requester_name
        requesters = self.db.find_requesters(requester_name=requester_name)
        if len(requesters) == 0:
            if run_config.provider.requester_name == "MOCK_REQUESTER":
                requesters = [get_mock_requester(self.db)]
            else:
                raise EntryDoesNotExistException(
                    f"No requester found with name {requester_name}"
                )
        requester = requesters[0]
        requester_id = requester.db_id
        provider_type = requester.provider_type
        assert provider_type == run_config.provider._provider_type, (
            f"Found requester for name {requester_name} is not "
            f"of the specified type {run_config.provider._provider_type}, "
            f"but is instead {provider_type}."
        )

        # Next get the abstraction classes, and run validation
        # before anything is actually created in the database
        blueprint_type = run_config.blueprint._blueprint_type
        architect_type = run_config.architect._architect_type
        BlueprintClass = get_blueprint_from_type(blueprint_type)
        ArchitectClass = get_architect_from_type(architect_type)
        CrowdProviderClass = get_crowd_provider_from_type(provider_type)

        BlueprintClass.assert_task_args(run_config, shared_state)
        ArchitectClass.assert_task_args(run_config, shared_state)
        CrowdProviderClass.assert_task_args(run_config, shared_state)

        # Find an existing task or create a new one
        task_name = run_config.task.get("task_name", None)
        if task_name is None:
            task_name = blueprint_type
            logger.warning(
                f"Task is using the default blueprint name {task_name} as a name, "
                "as no task_name is provided"
            )
        tasks = self.db.find_tasks(task_name=task_name)
        task_id = None
        if len(tasks) == 0:
            task_id = self.db.new_task(task_name, blueprint_type)
        else:
            task_id = tasks[0].db_id

        logger.info(f"Creating a task run under task name: {task_name}")

        # Create a new task run
        new_run_id = self.db.new_task_run(
            task_id,
            requester_id,
            json.dumps(OmegaConf.to_container(run_config, resolve=True)),
            provider_type,
            blueprint_type,
            requester.is_sandbox(),
        )
        task_run = TaskRun(self.db, new_run_id)

        # Bound before the try so the cleanup handler can tell whether an
        # architect was ever constructed (it may fail before that point)
        architect = None
        try:
            # If anything fails after here, we have to cleanup the architect
            build_dir = os.path.join(task_run.get_run_dir(), "build")
            os.makedirs(build_dir, exist_ok=True)
            architect = ArchitectClass(
                self.db, run_config, shared_state, task_run, build_dir
            )

            # Register the blueprint with args to the task run,
            # ensure cached
            blueprint = BlueprintClass(task_run, run_config, shared_state)
            task_run.get_blueprint(args=run_config, shared_state=shared_state)

            # Setup and deploy the server
            architect.prepare()
            task_url = architect.deploy()

            # TODO(#102) maybe the cleanup (destruction of the server configuration?) should only
            # happen after everything has already been reviewed, this way it's possible to
            # retrieve the exact build directory to review a task for real
            architect.cleanup()

            # Create the backend runner
            task_runner = BlueprintClass.TaskRunnerClass(
                task_run, run_config, shared_state
            )

            # Small hack for auto appending block qualification
            # (appends in place to shared_state.qualifications)
            existing_qualifications = shared_state.qualifications
            if run_config.blueprint.get("block_qualification", None) is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        run_config.blueprint.block_qualification, QUAL_NOT_EXIST, None
                    )
                )
            if run_config.blueprint.get("onboarding_qualification", None) is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        OnboardingRequired.get_failed_qual(
                            run_config.blueprint.onboarding_qualification,
                        ),
                        QUAL_NOT_EXIST,
                        None,
                    )
                )
            shared_state.qualifications = existing_qualifications

            # Register the task with the provider
            provider = CrowdProviderClass(self.db)
            provider.setup_resources_for_task_run(task_run, run_config, task_url)

            initialization_data_array = blueprint.get_initialization_data()

            # Link the job together
            job = self.supervisor.register_job(
                architect, task_runner, provider, existing_qualifications
            )
            if self.supervisor.sending_thread is None:
                self.supervisor.launch_sending_thread()
        except (KeyboardInterrupt, Exception) as e:
            logger.error(
                "Encountered error while launching run, shutting down", exc_info=True
            )
            if architect is not None:
                try:
                    architect.shutdown()
                except (KeyboardInterrupt, Exception) as architect_exception:
                    logger.exception(
                        f"Could not shut down architect: {architect_exception}",
                        exc_info=True,
                    )
            raise e

        launcher = TaskLauncher(self.db, task_run, initialization_data_array)
        launcher.create_assignments()
        launcher.launch_units(task_url)

        self._task_runs_tracked[task_run.db_id] = TrackedRun(
            task_run=task_run,
            task_launcher=launcher,
            task_runner=task_runner,
            architect=architect,
            job=job,
        )
        task_run.update_completion_progress(status=False)

        return task_run.db_id

    def _track_and_kill_runs(self):
        """
        Background thread that shuts down servers when a task is fully done.
        """
        while not self.is_shutdown:
            # Snapshot, since entries are deleted from the dict inside the loop
            runs_to_check = list(self._task_runs_tracked.values())
            for tracked_run in runs_to_check:
                task_run = tracked_run.task_run
                task_run.update_completion_progress(
                    task_launcher=tracked_run.task_launcher
                )
                if not task_run.get_is_completed():
                    continue
                self.supervisor.shutdown_job(tracked_run.job)
                tracked_run.architect.shutdown()
                tracked_run.task_launcher.shutdown()
                del self._task_runs_tracked[task_run.db_id]
            time.sleep(2)

    def shutdown(self, skip_input=True):
        """
        Shut down the operator and every run it is still tracking.

        Sets ``is_shutdown`` (ending the tracking thread's loop), then shuts
        down each tracked run's launcher and expires its units. It then polls
        every 30 seconds until each run reports completion, shutting down that
        run's architect as it does. A KeyboardInterrupt/SystemExit while
        waiting skips the wait and shuts the remaining architects down
        immediately. The supervisor is always shut down and the tracking
        thread joined at the end. ``skip_input`` is currently unused.
        """
        logger.info("operator shutting down")
        self.is_shutdown = True
        # Iterate over snapshots: until the tracking thread observes
        # is_shutdown it may still delete completed runs from the dict, and
        # iterating a live dict view during that raises RuntimeError.
        for tracked_run in list(self._task_runs_tracked.values()):
            logger.info("expiring units")
            tracked_run.task_launcher.shutdown()
            tracked_run.task_launcher.expire_units()
        # Bound before the try so the interrupt handler can always use it
        remaining_runs = list(self._task_runs_tracked.values())
        try:
            while len(remaining_runs) > 0:
                next_runs = []
                for tracked_run in remaining_runs:
                    if tracked_run.task_run.get_is_completed():
                        tracked_run.architect.shutdown()
                    else:
                        next_runs.append(tracked_run)
                if len(next_runs) > 0:
                    logger.info(
                        f"Waiting on {len(next_runs)} task runs, Ctrl-C ONCE to FORCE QUIT"
                    )
                    time.sleep(30)
                remaining_runs = next_runs
        except Exception as e:
            logger.exception(
                f"Encountered problem during shutting down {e}", exc_info=True
            )
            import traceback

            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.info(
                "Skipping waiting for outstanding task completions, shutting down servers now!"
            )
            for tracked_run in remaining_runs:
                tracked_run.architect.shutdown()
        finally:
            self.supervisor.shutdown()
            self._run_tracker_thread.join()

    def validate_and_run_config(
        self,
        run_config: DictConfig,
        shared_state: Optional[SharedTaskState] = None,
    ) -> Optional[str]:
        """
        Wrapper around validate_and_run_config_or_die that prints errors on
        failure, rather than throwing. Generally for use in scripts.
        """
        try:
            return self.validate_and_run_config_or_die(
                run_config=run_config, shared_state=shared_state
            )
        except (KeyboardInterrupt, Exception) as e:
            logger.error("Ran into error while launching run: ", exc_info=True)
            return None

    def parse_and_launch_run_wrapper(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Deprecated: argument parsing has moved to Hydra.

        Always raises. Use ``validate_and_run_config`` with a Hydra config
        instead.
        """
        raise Exception(
            "Operator.parse_and_launch_run_wrapper has been deprecated in favor "
            "of using Hydra for argument configuration. See the docs at "
            "https://github.com/facebookresearch/Mephisto/blob/master/docs/hydra_migration.md "
            "in order to upgrade."
        )

    def print_run_details(self):
        """Print details about running tasks"""
        # TODO(#93) parse these tasks and get the full details
        for task in self.get_running_task_runs():
            logger.info(f"Operator running task ID = {task}")

    def wait_for_runs_then_shutdown(
        self, skip_input=False, log_rate: Optional[int] = None
    ) -> None:
        """
        Wait for task_runs to complete, and then shutdown.

        Set log_rate to get print statements of currently running tasks
        at the specified interval.

        If an exception occurs while waiting, it is re-raised when
        ``skip_input`` is True. Otherwise the user is asked whether to shut
        down. In every case ``self.shutdown()`` runs at the end.
        """
        # Imported up front: a function-level import further down would make
        # `traceback` local to the whole function, so using it in the inner
        # handler would raise UnboundLocalError instead of printing.
        import traceback

        try:
            try:
                last_log = 0.0
                while len(self.get_running_task_runs()) > 0:
                    if log_rate is not None:
                        if time.time() - last_log > log_rate:
                            last_log = time.time()
                            self.print_run_details()
                    time.sleep(10)

            except Exception as e:
                if skip_input:
                    raise e

                traceback.print_exc()
                should_quit = input(
                    "The above exception happened while running a task, do "
                    "you want to shut down? (y)/n: "
                )
                # NOTE(review): answering "n" only avoids re-raising; the wait
                # loop has already exited, so the shutdown in `finally` still
                # runs. Resuming the wait may have been the intent.
                if should_quit not in ["n", "N", "no", "No"]:
                    raise e

        except Exception as e:
            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.exception(
                "Cleaning up after keyboard interrupt, please wait!", exc_info=True
            )
        finally:
            self.shutdown()
def test_register_job_with_onboarding(self):
    """
    Test registering and running a job with onboarding

    Walks three mock workers through a blueprint that requires onboarding:
    worker 1 is blocked and then fails onboarding, worker 2 already holds the
    qualification and skips onboarding, and worker 3 passes onboarding. The
    test then checks that both admitted agents can act and the task completes.
    """
    # Handle baseline setup
    sup = Supervisor(self.db)
    self.sup = sup
    TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

    # Register onboarding arguments for blueprint
    task_run_args = self.task_run.args
    task_run_args.blueprint.use_onboarding = True
    task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
    task_run_args.blueprint.timeout_time = 5
    task_run_args.blueprint.is_concurrent = False
    # Return value unused; presumably called for its side effect on the
    # task run's cached config — TODO confirm
    self.task_run.get_task_config()

    # Supervisor expects that blueprint setup has already occurred
    blueprint = self.task_run.get_blueprint()

    TaskRunnerClass = MockBlueprint.TaskRunnerClass
    task_runner = TaskRunnerClass(self.task_run, task_run_args, EMPTY_STATE)
    sup.register_job(self.architect, task_runner, self.provider)
    self.assertEqual(len(sup.channels), 1)
    channel_info = list(sup.channels.values())[0]
    self.assertIsNotNone(channel_info)
    self.assertTrue(channel_info.channel.is_alive())
    channel_id = channel_info.channel_id
    # From here on, use the task runner the supervisor attached to the job
    task_runner = channel_info.job.task_runner
    self.assertIsNotNone(channel_id)
    self.assertEqual(
        len(self.architect.server.subs),
        1,
        "MockServer doesn't see registered channel",
    )
    self.assertIsNotNone(
        self.architect.server.last_alive_packet,
        "No alive packet received by server",
    )
    sup.launch_sending_thread()
    self.assertIsNotNone(sup.sending_thread)

    # Register a worker
    mock_worker_name = "MOCK_WORKER"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    self.assertEqual(len(workers), 1, "Worker not successfully registered")
    worker_1 = workers[0]

    # Registering the same worker name again must not create a second worker
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    self.assertEqual(len(workers), 1, "Worker potentially re-registered")
    worker_id = workers[0].db_id

    self.assertEqual(len(task_runner.running_units), 0)

    # Fail to register a blocked agent
    # Granting the onboarding qualification with value 0 is expected (per the
    # assertions below) to block this worker without triggering onboarding
    mock_agent_details = "FAKE_ASSIGNMENT"
    qualification_id = blueprint.onboarding_qualification_id
    self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Agent should not be created yet, failed onboarding")
    # Brief pause, presumably so the response packet reaches the mock server
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertNotIn(
        "onboard_data",
        last_packet["data"],
        "Onboarding triggered for disqualified worker",
    )
    self.assertIsNone(last_packet["data"]["agent_id"], "worker assigned real agent id")
    self.architect.server.last_packet = None
    # Remove the block so the same worker can attempt onboarding next
    self.db.revoke_qualification(qualification_id, worker_id)

    # Register an agent successfully
    mock_agent_details = "FAKE_ASSIGNMENT"
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding")
    onboard_agents = self.db.find_onboarding_agents()
    self.assertEqual(len(onboard_agents), 1, "Onboarding agent should have been created")
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered")
    self.architect.server.last_packet = None

    # Submit onboarding from the agent
    # should_pass=False: this worker fails onboarding
    onboard_data = {"should_pass": False}
    self.architect.server.register_mock_agent_after_onboarding(
        worker_id, onboard_agents[0].get_agent_id(), onboard_data)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Failed agent created after onboarding")

    # Re-register as if refreshing
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 0, "Failed agent created after onboarding")
    self.assertEqual(len(sup.agents), 0, "Failed agent registered with supervisor")

    self.assertEqual(
        len(task_runner.running_units),
        0,
        "Task should not launch with failed worker",
    )

    # Register another worker
    mock_worker_name = "MOCK_WORKER_2"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    worker_2 = workers[0]
    worker_id = worker_2.db_id

    # Register an agent that is already qualified
    # (granted the onboarding qualification with value 1 beforehand)
    mock_agent_details = "FAKE_ASSIGNMENT_2"
    self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertNotIn(
        "onboard_data",
        last_packet["data"],
        "Onboarding triggered for qualified agent",
    )
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 1, "Second agent not created without onboarding")

    self.assertEqual(len(task_runner.running_units), 1, "Tasks were not launched")

    # Worker 1 failed onboarding; worker 2 was granted the qualification
    self.assertFalse(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
    self.assertTrue(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
    self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
    self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))

    # Register another worker
    mock_worker_name = "MOCK_WORKER_3"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    worker_3 = workers[0]
    worker_id = worker_3.db_id
    mock_agent_details = "FAKE_ASSIGNMENT_3"
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 1, "Agent should not be created yet - need onboarding")
    onboard_agents = self.db.find_onboarding_agents()
    self.assertEqual(len(onboard_agents), 2, "Onboarding agent should have been created")
    time.sleep(0.1)
    last_packet = self.architect.server.last_packet
    self.assertIsNotNone(last_packet)
    self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered")
    self.architect.server.last_packet = None

    # Submit onboarding from the agent
    # should_pass=True: this worker passes onboarding
    onboard_data = {"should_pass": True}
    self.architect.server.register_mock_agent_after_onboarding(
        worker_id, onboard_agents[1].get_agent_id(), onboard_data)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 2, "Agent not created after onboarding")

    # Re-register as if refreshing
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 2, "Duplicate agent created after onboarding")
    agent = agents[1]
    self.assertIsNotNone(agent)
    self.assertEqual(len(sup.agents), 2, "Agent not registered supervisor after onboarding")

    self.assertEqual(len(task_runner.running_units), 2, "Task not launched after onboarding")

    agents = [a.agent for a in sup.agents.values()]

    # Make both agents act
    agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
    agent_1_data = agents[0].datastore.agent_data[agent_id_1]
    # NOTE(review): agent_2_data is not used below
    agent_2_data = agents[1].datastore.agent_data[agent_id_2]
    self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
    self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

    # Give up to 1 seconds for the actual operation to occur
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if len(agent_1_data["acts"]) > 0:
            break
        time.sleep(0.1)

    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not process messages in time")

    # Give up to 1 seconds for the task to complete afterwards
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if len(task_runner.running_units) == 0:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not complete task in time")

    # Give up to 1 seconds for all messages to propagate
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if self.architect.server.actions_observed == 2:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Not all actions observed in time")

    # Shutting the supervisor down should also close its channel
    sup.shutdown()
    self.assertTrue(channel_info.channel.is_closed())
def test_register_job(self):
    """
    Test registering and running a job run asynchronously

    Registers a mock job, then two workers and an agent for each. It checks
    that the task launches only once the second agent arrives, and that both
    agents can act before the task completes.
    """
    # Handle baseline setup
    sup = Supervisor(self.db)
    self.sup = sup
    TaskRunnerClass = MockBlueprint.TaskRunnerClass
    args = MockBlueprint.ArgsClass()
    args.timeout_time = 5
    config = OmegaConf.structured(MephistoConfig(blueprint=args))
    task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
    sup.register_job(self.architect, task_runner, self.provider)
    self.assertEqual(len(sup.channels), 1)
    channel_info = list(sup.channels.values())[0]
    self.assertIsNotNone(channel_info)
    self.assertTrue(channel_info.channel.is_alive())
    channel_id = channel_info.channel_id
    # From here on, use the task runner the supervisor attached to the job
    task_runner = channel_info.job.task_runner
    self.assertIsNotNone(channel_id)
    self.assertEqual(
        len(self.architect.server.subs),
        1,
        "MockServer doesn't see registered channel",
    )
    self.assertIsNotNone(
        self.architect.server.last_alive_packet,
        "No alive packet received by server",
    )
    sup.launch_sending_thread()
    self.assertIsNotNone(sup.sending_thread)

    # Register a worker
    mock_worker_name = "MOCK_WORKER"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    self.assertEqual(len(workers), 1, "Worker not successfully registered")
    # NOTE(review): `worker` is not used below
    worker = workers[0]

    # Registering the same worker name again must not create a second worker
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    self.assertEqual(len(workers), 1, "Worker potentially re-registered")
    worker_id = workers[0].db_id

    self.assertEqual(len(task_runner.running_assignments), 0)

    # Register an agent
    mock_agent_details = "FAKE_ASSIGNMENT"
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 1, "Agent was not created properly")

    # Registering the same agent details again must not duplicate the agent
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)
    agents = self.db.find_agents()
    self.assertEqual(len(agents), 1, "Agent may have been duplicated")
    agent = agents[0]
    self.assertIsNotNone(agent)
    self.assertEqual(len(sup.agents), 1, "Agent not registered with supervisor")

    # One agent alone is not enough for the mock assignment to start
    self.assertEqual(len(task_runner.running_assignments), 0, "Task was not yet ready")

    # Register another worker
    mock_worker_name = "MOCK_WORKER_2"
    self.architect.server.register_mock_worker(mock_worker_name)
    workers = self.db.find_workers(worker_name=mock_worker_name)
    worker_id = workers[0].db_id

    # Register an agent
    mock_agent_details = "FAKE_ASSIGNMENT_2"
    self.architect.server.register_mock_agent(worker_id, mock_agent_details)

    # With the second agent registered the assignment is expected to launch
    self.assertEqual(len(task_runner.running_assignments), 1, "Task was not launched")
    agents = [a.agent for a in sup.agents.values()]

    # Make both agents act
    agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
    agent_1_data = agents[0].datastore.agent_data[agent_id_1]
    # NOTE(review): agent_2_data is not used below
    agent_2_data = agents[1].datastore.agent_data[agent_id_2]
    self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
    self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

    # Give up to 1 seconds for the actual operation to occur
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if len(agent_1_data["acts"]) > 0:
            break
        time.sleep(0.1)

    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not process messages in time")

    # Give up to 1 seconds for the task to complete afterwards
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if len(task_runner.running_assignments) == 0:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Did not complete task in time")

    # Give up to 1 seconds for all messages to propagate
    start_time = time.time()
    TIMEOUT_TIME = 1
    while time.time() - start_time < TIMEOUT_TIME:
        if self.architect.server.actions_observed == 2:
            break
        time.sleep(0.1)
    self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Not all actions observed in time")

    # Shutting the supervisor down should also close its channel
    sup.shutdown()
    self.assertTrue(channel_info.channel.is_closed())