Exemplo n.º 1
0
 def test_initialize_supervisor(self):
     """Ensure that the supervisor object can even be created"""
     sup = Supervisor(self.db)
     try:
         self.assertIsNotNone(sup)
         # A freshly created supervisor is expected to track nothing yet
         self.assertDictEqual(sup.agents, {})
         self.assertDictEqual(sup.channels, {})
     finally:
         # Run shutdown even when an assertion above fails, so a failing
         # test does not skip the supervisor's cleanup
         sup.shutdown()
Exemplo n.º 2
0
 def __init__(self, db: "MephistoDB"):
     """
     Keep a handle to the database, build a Supervisor on top of it,
     and start the thread that runs `_track_and_kill_runs`.
     """
     self.db = db
     self.supervisor = Supervisor(db)
     self._task_runs_tracked: Dict[str, TrackedRun] = {}
     # The shutdown flag is initialised before the thread is started
     self.is_shutdown = False
     tracker = threading.Thread(
         name="Operator-tracking-thread",
         target=self._track_and_kill_runs,
     )
     self._run_tracker_thread = tracker
     tracker.start()
Exemplo n.º 3
0
    def test_channel_operations(self):
        """
        Open the first channel handed out by the architect, check that it
        carries an id, then close it and check that it reports as closed
        """
        supervisor = Supervisor(self.db)
        # Also stored on the test case (presumably for teardown - confirm)
        self.sup = supervisor

        runner_cls = MockBlueprint.TaskRunnerClass
        blueprint_args = MockBlueprint.ArgsClass()
        run_config = OmegaConf.structured(MephistoConfig(blueprint=blueprint_args))
        runner = runner_cls(self.task_run, run_config, EMPTY_STATE)
        # The job is constructed here but never registered with the supervisor
        test_job = Job(
            architect=self.architect,
            task_runner=runner,
            provider=self.provider,
            qualifications=[],
            registered_channel_ids=[],
        )

        # The supervisor's private callbacks are wired straight into the channels
        available_channels = self.architect.get_channels(
            supervisor._on_channel_open,
            supervisor._on_catastrophic_disconnect,
            supervisor._on_message,
        )
        first_channel = available_channels[0]

        first_channel.open()
        opened_id = first_channel.channel_id
        self.assertIsNotNone(opened_id)

        first_channel.close()
        self.assertTrue(first_channel.is_closed())
Exemplo n.º 4
0
class Operator:
    """
    Acting as the controller behind the curtain, the Operator class
    is responsible for managing the knobs, switches, and dials
    of the rest of the Mephisto architecture.

    Most convenience scripts for using Mephisto will use an Operator
    to get the job done, though this class itself is also a
    good model to use to understand how the underlying
    architecture works in order to build custom jobs or workflows.
    """

    def __init__(self, db: "MephistoDB"):
        """
        Store the database handle, create a Supervisor for it, and start
        the background thread that runs `_track_and_kill_runs`.
        """
        self.db = db
        self.supervisor = Supervisor(db)
        # Maps a task run's db_id to the TrackedRun bundle created at launch
        self._task_runs_tracked: Dict[str, TrackedRun] = {}
        self.is_shutdown = False
        self._run_tracker_thread = threading.Thread(
            target=self._track_and_kill_runs, name="Operator-tracking-thread"
        )
        self._run_tracker_thread.start()

    @staticmethod
    def _get_baseline_argparser() -> ArgumentParser:
        """Return a parser for the baseline requirements to launch a job"""
        parser = ArgumentParser()
        parser.add_argument(
            "--blueprint-type",
            dest="blueprint_type",
            help="Name of the blueprint to launch",
            required=True,
        )
        parser.add_argument(
            "--architect-type",
            dest="architect_type",
            help="Name of the architect to launch with",
            required=True,
        )
        parser.add_argument(
            "--requester-name",
            dest="requester_name",
            help="Identifier for the requester to launch as",
            required=True,
        )
        return parser

    @staticmethod
    def _parse_args_from_classes(
        BlueprintClass: Type["Blueprint"],
        ArchitectClass: Type["Architect"],
        CrowdProviderClass: Type["CrowdProvider"],
        argument_list: List[str],
    ) -> Tuple[Dict[str, Any], List[str]]:
        """
        Parse the given arguments over the parsers for the given types.

        Each class (plus TaskConfig) contributes its own argument group.
        Returns the parsed arguments as a dict together with the list of
        arguments that no group recognized.

        Raises an Exception if argparse tries to exit the process.
        """
        # Create the parser
        parser = ArgumentParser()
        blueprint_group = parser.add_argument_group("blueprint")
        BlueprintClass.add_args_to_group(blueprint_group)
        provider_group = parser.add_argument_group("crowd_provider")
        CrowdProviderClass.add_args_to_group(provider_group)
        architect_group = parser.add_argument_group("architect")
        ArchitectClass.add_args_to_group(architect_group)
        task_group = parser.add_argument_group("task_config")
        TaskConfig.add_args_to_group(task_group)

        # Return parsed args
        try:
            known, unknown = parser.parse_known_args(argument_list)
        except SystemExit as exit_request:
            # Convert argparse's exit into a regular exception for callers
            raise Exception("Argparse broke - must fix") from exit_request
        return vars(known), unknown

    def get_running_task_runs(self):
        """Return the currently running task runs and their handlers"""
        # A shallow copy, so callers can iterate while the tracker thread
        # keeps mutating the real dict
        return self._task_runs_tracked.copy()

    # TODO(#94) there should be a way to provide default arguments via a config file

    def parse_and_launch_run(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Parse the given arguments and launch a job.

        `arg_list` is handed to argparse (None lets argparse pick its own
        default source). Entries in `extra_args` are merged on top of the
        parsed task arguments.

        Returns the db_id of the newly created task run.

        Raises EntryDoesNotExistException when no requester matches the
        given requester name. Any error raised while the run is being set
        up is logged and re-raised after a best-effort architect shutdown.
        """
        if extra_args is None:
            extra_args = {}
        # Extract the abstractions being used
        parser = self._get_baseline_argparser()
        type_args, task_args_string = parser.parse_known_args(arg_list)

        requesters = self.db.find_requesters(requester_name=type_args.requester_name)
        if len(requesters) == 0:
            raise EntryDoesNotExistException(
                f"No requester found with name {type_args.requester_name}"
            )
        requester = requesters[0]
        requester_id = requester.db_id
        provider_type = requester.provider_type

        # Parse the arguments for the abstractions to ensure
        # everything required is set
        BlueprintClass = get_blueprint_from_type(type_args.blueprint_type)
        ArchitectClass = get_architect_from_type(type_args.architect_type)
        CrowdProviderClass = get_crowd_provider_from_type(provider_type)
        task_args, _unknown = self._parse_args_from_classes(
            BlueprintClass, ArchitectClass, CrowdProviderClass, task_args_string
        )

        task_args.update(extra_args)

        # Load the classes to force argument validation before anything
        # is actually created in the database
        # TODO(#94) perhaps parse the arguments for these things one at a time?
        BlueprintClass.assert_task_args(task_args)
        ArchitectClass.assert_task_args(task_args)
        CrowdProviderClass.assert_task_args(task_args)

        # Find an existing task or create a new one
        task_name = task_args.get("task_name")
        if task_name is None:
            task_name = type_args.blueprint_type
            logger.warning(
                f"Task is using the default blueprint name {task_name} as a name, as no task_name is provided"
            )
        tasks = self.db.find_tasks(task_name=task_name)
        if len(tasks) == 0:
            task_id = self.db.new_task(task_name, type_args.blueprint_type)
        else:
            task_id = tasks[0].db_id

        logger.info(f"Creating a task run under task name: {task_name}")

        # Create a new task run
        new_run_id = self.db.new_task_run(
            task_id,
            requester_id,
            " ".join([shlex.quote(x) for x in task_args_string]),
            provider_type,
            type_args.blueprint_type,
            requester.is_sandbox(),
        )
        task_run = TaskRun(self.db, new_run_id)

        # Bound up front so the error handler can tell whether an architect
        # was ever created before something failed
        architect = None
        try:
            # If anything fails after here, we have to cleanup the architect

            build_dir = os.path.join(task_run.get_run_dir(), "build")
            os.makedirs(build_dir, exist_ok=True)
            architect = ArchitectClass(self.db, task_args, task_run, build_dir)

            # Register the blueprint with args to the task run,
            # ensure cached
            blueprint = BlueprintClass(task_run, task_args)
            task_run.get_blueprint(opts=task_args)

            # Setup and deploy the server
            architect.prepare()
            task_url = architect.deploy()

            # TODO(#102) maybe the cleanup (destruction of the server configuration?) should only
            # happen after everything has already been reviewed, this way it's possible to
            # retrieve the exact build directory to review a task for real
            architect.cleanup()

            # Create the backend runner
            task_runner = BlueprintClass.TaskRunnerClass(task_run, task_args)

            # Small hack for auto appending block qualification
            existing_qualifications = task_args.get("qualifications", [])
            if task_args.get("block_qualification") is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        task_args["block_qualification"], QUAL_NOT_EXIST, None
                    )
                )
            if task_args.get("onboarding_qualification") is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        OnboardingRequired.get_failed_qual(
                            task_args["onboarding_qualification"]
                        ),
                        QUAL_NOT_EXIST,
                        None,
                    )
                )
            task_args["qualifications"] = existing_qualifications

            # Register the task with the provider
            provider = CrowdProviderClass(self.db)
            provider.setup_resources_for_task_run(task_run, task_args, task_url)

            initialization_data_array = blueprint.get_initialization_data()

            # Link the job together
            job = self.supervisor.register_job(
                architect, task_runner, provider, existing_qualifications
            )
            if self.supervisor.sending_thread is None:
                self.supervisor.launch_sending_thread()
        except (KeyboardInterrupt, Exception):
            logger.error(
                "Encountered error while launching run, shutting down", exc_info=True
            )
            # Only attempt the shutdown when an architect actually exists;
            # otherwise the cleanup itself would fail on an unbound name and
            # log a misleading "could not shut down" error
            if architect is not None:
                try:
                    architect.shutdown()
                except (KeyboardInterrupt, Exception) as architect_exception:
                    logger.exception(
                        f"Could not shut down architect: {architect_exception}",
                        exc_info=True,
                    )
            raise

        launcher = TaskLauncher(self.db, task_run, initialization_data_array)
        launcher.create_assignments()
        launcher.launch_units(task_url)

        self._task_runs_tracked[task_run.db_id] = TrackedRun(
            task_run=task_run,
            task_launcher=launcher,
            task_runner=task_runner,
            architect=architect,
            job=job,
        )
        return task_run.db_id

    def _track_and_kill_runs(self):
        """
        Background thread that shuts down servers when a task
        is fully done.

        Polls every two seconds until `is_shutdown` is set.
        """
        while not self.is_shutdown:
            # Iterate a snapshot, entries are deleted from the dict below
            runs_to_check = list(self._task_runs_tracked.values())
            for tracked_run in runs_to_check:
                task_run = tracked_run.task_run
                if task_run.get_is_completed():
                    self.supervisor.shutdown_job(tracked_run.job)
                    tracked_run.architect.shutdown()
                    tracked_run.task_launcher.shutdown()
                    del self._task_runs_tracked[task_run.db_id]
            time.sleep(2)

    def shutdown(self, skip_input=True):
        """
        Expire the units of every tracked run, wait for the runs to
        complete (polling every 30 seconds), then shut down the supervisor
        and join the tracker thread.

        A KeyboardInterrupt or SystemExit while waiting skips the wait and
        shuts down the remaining architects immediately. `skip_input` is
        accepted but not used by this method.
        """
        logger.info("operator shutting down")
        self.is_shutdown = True
        # Work on a snapshot: the tracker thread may still be finishing a
        # pass and deleting entries, and iterating the live dict view while
        # it shrinks raises RuntimeError
        remaining_runs = list(self._task_runs_tracked.values())
        for tracked_run in remaining_runs:
            logger.info("expiring units")
            tracked_run.task_launcher.shutdown()
            tracked_run.task_launcher.expire_units()
        try:
            while len(remaining_runs) > 0:
                next_runs = []
                for tracked_run in remaining_runs:
                    if tracked_run.task_run.get_is_completed():
                        tracked_run.architect.shutdown()
                    else:
                        next_runs.append(tracked_run)
                if len(next_runs) > 0:
                    # Report only the runs that are still outstanding
                    logger.info(
                        f"Waiting on {len(next_runs)} task runs, Ctrl-C ONCE to FORCE QUIT"
                    )
                    time.sleep(30)
                remaining_runs = next_runs
        except Exception as e:
            logger.exception(
                f"Encountered problem during shutting down {e}", exc_info=True
            )
            import traceback

            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.info(
                "Skipping waiting for outstanding task completions, shutting down servers now!"
            )
            for tracked_run in remaining_runs:
                tracked_run.architect.shutdown()
        finally:
            self.supervisor.shutdown()
            self._run_tracker_thread.join()

    def parse_and_launch_run_wrapper(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Wrapper around parse and launch run that prints errors on failure, rather
        than throwing. Generally for use in scripts.

        Returns the task run db_id on success and None on failure.
        """
        try:
            return self.parse_and_launch_run(arg_list=arg_list, extra_args=extra_args)
        except (KeyboardInterrupt, Exception) as e:
            logger.error("Ran into error while launching run: ", exc_info=True)
            return None

    def print_run_details(self):
        """Print details about running tasks"""
        # TODO(#93) parse these tasks and get the full details
        # Iterating the dict yields its keys, i.e. the task run ids
        for task in self.get_running_task_runs():
            logger.info(f"Operator running task ID = {task}")

    def wait_for_runs_then_shutdown(
        self, skip_input=False, log_rate: Optional[int] = None
    ) -> None:
        """
        Wait for task_runs to complete, and then shutdown.

        Set log_rate to get print statements of currently running tasks
        at the specified interval

        When an exception occurs while waiting and skip_input is False, the
        traceback is printed and the user is asked whether to shut down.
        Either way, `shutdown` is always called before returning.
        """
        # Must be imported before first use: a function-local import makes
        # `traceback` a local name for the whole function, so using it
        # ahead of the import raises UnboundLocalError
        import traceback

        try:
            try:
                last_log = 0.0
                while len(self.get_running_task_runs()) > 0:
                    if log_rate is not None:
                        if time.time() - last_log > log_rate:
                            last_log = time.time()
                            self.print_run_details()
                    time.sleep(10)

            except Exception as e:
                if skip_input:
                    raise e

                traceback.print_exc()
                should_quit = input(
                    "The above exception happened while running a task, do "
                    "you want to shut down? (y)/n: "
                )
                # Anything other than an explicit "no" means shut down
                if should_quit not in ["n", "N", "no", "No"]:
                    raise e

        except Exception as e:
            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.exception(
                "Cleaning up after keyboard interrupt, please wait!", exc_info=True
            )
        finally:
            self.shutdown()
Exemplo n.º 5
0
class Operator:
    """
    Acting as the controller behind the curtain, the Operator class
    is responsible for managing the knobs, switches, and dials
    of the rest of the Mephisto architecture.

    Most convenience scripts for using Mephisto will use an Operator
    to get the job done, though this class itself is also a
    good model to use to understand how the underlying
    architecture works in order to build custom jobs or workflows.
    """
    def __init__(self, db: "MephistoDB"):
        """
        Store the database handle, create a Supervisor for it, and start
        the background thread that runs `_track_and_kill_runs`.
        """
        self.db = db
        self.supervisor = Supervisor(db)
        # Maps a task run's db_id to the TrackedRun bundle created at launch
        self._task_runs_tracked: Dict[str, TrackedRun] = {}
        self.is_shutdown = False
        self._run_tracker_thread = threading.Thread(
            target=self._track_and_kill_runs, name="Operator-tracking-thread")
        self._run_tracker_thread.start()

    @staticmethod
    def _get_baseline_argparser() -> ArgumentParser:
        """Return a parser for the baseline requirements to launch a job"""
        parser = ArgumentParser()
        parser.add_argument(
            "--blueprint-type",
            dest="blueprint_type",
            help="Name of the blueprint to launch",
            required=True,
        )
        parser.add_argument(
            "--architect-type",
            dest="architect_type",
            help="Name of the architect to launch with",
            required=True,
        )
        parser.add_argument(
            "--requester-name",
            dest="requester_name",
            help="Identifier for the requester to launch as",
            required=True,
        )
        return parser

    def get_running_task_runs(self):
        """Return the currently running task runs and their handlers"""
        # A shallow copy, so callers can iterate while the tracker thread
        # keeps mutating the real dict
        return self._task_runs_tracked.copy()

    def parse_and_launch_run(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Deprecated entry point kept only to point callers at the Hydra
        based configuration; it always raises an Exception.
        """
        raise Exception(
            'Operator.parse_and_launch_run has been deprecated in favor '
            'of using Hydra for argument configuration. See the docs at '
            'https://github.com/facebookresearch/Mephisto/blob/master/docs/hydra_migration.md '
            'in order to upgrade.')

    def validate_and_run_config_or_die(
        self,
        run_config: DictConfig,
        shared_state: Optional[SharedTaskState] = None,
    ) -> str:
        """
        Validate the given run configuration and launch a job from it.

        A fresh SharedTaskState is created when none is supplied.

        Returns the db_id of the newly created task run.

        Raises EntryDoesNotExistException when no requester matches the
        configured name (the name "MOCK_REQUESTER" falls back to a mock
        requester instead), and AssertionError when the requester's
        provider type differs from the configured one. Any error raised
        while the run is being set up is logged and re-raised after a
        best-effort architect shutdown.
        """
        if shared_state is None:
            shared_state = SharedTaskState()

        # First try to find the requester:
        requester_name = run_config.provider.requester_name
        requesters = self.db.find_requesters(requester_name=requester_name)
        if len(requesters) == 0:
            if run_config.provider.requester_name == "MOCK_REQUESTER":
                requesters = [get_mock_requester(self.db)]
            else:
                raise EntryDoesNotExistException(
                    f"No requester found with name {requester_name}")
        requester = requesters[0]
        requester_id = requester.db_id
        provider_type = requester.provider_type
        assert provider_type == run_config.provider._provider_type, (
            f"Found requester for name {requester_name} is not "
            f"of the specified type {run_config.provider._provider_type}, "
            f"but is instead {provider_type}.")

        # Next get the abstraction classes, and run validation
        # before anything is actually created in the database
        blueprint_type = run_config.blueprint._blueprint_type
        architect_type = run_config.architect._architect_type
        BlueprintClass = get_blueprint_from_type(blueprint_type)
        ArchitectClass = get_architect_from_type(architect_type)
        CrowdProviderClass = get_crowd_provider_from_type(provider_type)

        BlueprintClass.assert_task_args(run_config, shared_state)
        ArchitectClass.assert_task_args(run_config, shared_state)
        CrowdProviderClass.assert_task_args(run_config, shared_state)

        # Find an existing task or create a new one
        task_name = run_config.task.get("task_name", None)
        if task_name is None:
            task_name = blueprint_type
            logger.warning(
                f"Task is using the default blueprint name {task_name} as a name, "
                "as no task_name is provided")
        tasks = self.db.find_tasks(task_name=task_name)
        if len(tasks) == 0:
            task_id = self.db.new_task(task_name, blueprint_type)
        else:
            task_id = tasks[0].db_id

        logger.info(f"Creating a task run under task name: {task_name}")

        # Create a new task run
        new_run_id = self.db.new_task_run(
            task_id,
            requester_id,
            json.dumps(OmegaConf.to_container(run_config, resolve=True)),
            provider_type,
            blueprint_type,
            requester.is_sandbox(),
        )
        task_run = TaskRun(self.db, new_run_id)

        # Bound up front so the error handler can tell whether an architect
        # was ever created before something failed
        architect = None
        try:
            # If anything fails after here, we have to cleanup the architect

            build_dir = os.path.join(task_run.get_run_dir(), "build")
            os.makedirs(build_dir, exist_ok=True)
            architect = ArchitectClass(self.db, run_config, shared_state,
                                       task_run, build_dir)

            # Register the blueprint with args to the task run,
            # ensure cached
            blueprint = BlueprintClass(task_run, run_config, shared_state)
            task_run.get_blueprint(args=run_config, shared_state=shared_state)

            # Setup and deploy the server
            architect.prepare()
            task_url = architect.deploy()

            # TODO(#102) maybe the cleanup (destruction of the server configuration?) should only
            # happen after everything has already been reviewed, this way it's possible to
            # retrieve the exact build directory to review a task for real
            architect.cleanup()

            # Create the backend runner
            task_runner = BlueprintClass.TaskRunnerClass(
                task_run, run_config, shared_state)

            # Small hack for auto appending block qualification
            existing_qualifications = shared_state.qualifications
            if run_config.blueprint.get("block_qualification",
                                        None) is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        run_config.blueprint.block_qualification,
                        QUAL_NOT_EXIST, None))
            if run_config.blueprint.get("onboarding_qualification",
                                        None) is not None:
                existing_qualifications.append(
                    make_qualification_dict(
                        OnboardingRequired.get_failed_qual(
                            run_config.blueprint.onboarding_qualification, ),
                        QUAL_NOT_EXIST,
                        None,
                    ))
            shared_state.qualifications = existing_qualifications

            # Register the task with the provider
            provider = CrowdProviderClass(self.db)
            provider.setup_resources_for_task_run(task_run, run_config,
                                                  task_url)

            initialization_data_array = blueprint.get_initialization_data()

            # Link the job together
            job = self.supervisor.register_job(architect, task_runner,
                                               provider,
                                               existing_qualifications)
            if self.supervisor.sending_thread is None:
                self.supervisor.launch_sending_thread()
        except (KeyboardInterrupt, Exception):
            logger.error(
                "Encountered error while launching run, shutting down",
                exc_info=True)
            # Only attempt the shutdown when an architect actually exists;
            # otherwise the cleanup itself would fail on an unbound name and
            # log a misleading "could not shut down" error
            if architect is not None:
                try:
                    architect.shutdown()
                except (KeyboardInterrupt, Exception) as architect_exception:
                    logger.exception(
                        f"Could not shut down architect: {architect_exception}",
                        exc_info=True,
                    )
            raise

        launcher = TaskLauncher(self.db, task_run, initialization_data_array)
        launcher.create_assignments()
        launcher.launch_units(task_url)

        self._task_runs_tracked[task_run.db_id] = TrackedRun(
            task_run=task_run,
            task_launcher=launcher,
            task_runner=task_runner,
            architect=architect,
            job=job,
        )
        task_run.update_completion_progress(status=False)

        return task_run.db_id

    def _track_and_kill_runs(self):
        """
        Background thread that shuts down servers when a task
        is fully done.

        Polls every two seconds until `is_shutdown` is set, refreshing each
        run's completion progress before checking it.
        """
        while not self.is_shutdown:
            # Iterate a snapshot, entries are deleted from the dict below
            runs_to_check = list(self._task_runs_tracked.values())
            for tracked_run in runs_to_check:
                task_run = tracked_run.task_run
                task_run.update_completion_progress(
                    task_launcher=tracked_run.task_launcher)
                if not task_run.get_is_completed():
                    continue
                self.supervisor.shutdown_job(tracked_run.job)
                tracked_run.architect.shutdown()
                tracked_run.task_launcher.shutdown()
                del self._task_runs_tracked[task_run.db_id]
            time.sleep(2)

    def shutdown(self, skip_input=True):
        """
        Expire the units of every tracked run, wait for the runs to
        complete (polling every 30 seconds), then shut down the supervisor
        and join the tracker thread.

        A KeyboardInterrupt or SystemExit while waiting skips the wait and
        shuts down the remaining architects immediately. `skip_input` is
        accepted but not used by this method.
        """
        logger.info("operator shutting down")
        self.is_shutdown = True
        # Work on a snapshot: the tracker thread may still be finishing a
        # pass and deleting entries, and iterating the live dict view while
        # it shrinks raises RuntimeError
        remaining_runs = list(self._task_runs_tracked.values())
        for tracked_run in remaining_runs:
            logger.info("expiring units")
            tracked_run.task_launcher.shutdown()
            tracked_run.task_launcher.expire_units()
        try:
            while len(remaining_runs) > 0:
                next_runs = []
                for tracked_run in remaining_runs:
                    if tracked_run.task_run.get_is_completed():
                        tracked_run.architect.shutdown()
                    else:
                        next_runs.append(tracked_run)
                if len(next_runs) > 0:
                    # Report only the runs that are still outstanding
                    logger.info(
                        f"Waiting on {len(next_runs)} task runs, Ctrl-C ONCE to FORCE QUIT"
                    )
                    time.sleep(30)
                remaining_runs = next_runs
        except Exception as e:
            logger.exception(f"Encountered problem during shutting down {e}",
                             exc_info=True)
            import traceback

            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.info(
                "Skipping waiting for outstanding task completions, shutting down servers now!"
            )
            for tracked_run in remaining_runs:
                tracked_run.architect.shutdown()
        finally:
            self.supervisor.shutdown()
            self._run_tracker_thread.join()

    def validate_and_run_config(
        self,
        run_config: DictConfig,
        shared_state: Optional[SharedTaskState] = None,
    ) -> Optional[str]:
        """
        Wrapper around validate_and_run_config_or_die that prints errors on
        failure, rather than throwing. Generally for use in scripts.

        Returns the task run db_id on success and None on failure.
        """
        try:
            return self.validate_and_run_config_or_die(
                run_config=run_config,
                shared_state=shared_state,
            )
        except (KeyboardInterrupt, Exception) as e:
            logger.error("Ran into error while launching run: ", exc_info=True)
            return None

    def parse_and_launch_run_wrapper(
        self,
        arg_list: Optional[List[str]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Deprecated entry point kept only to point callers at the Hydra
        based configuration; it always raises an Exception.
        """
        raise Exception(
            'Operator.parse_and_launch_run_wrapper has been deprecated in favor '
            'of using Hydra for argument configuration. See the docs at '
            'https://github.com/facebookresearch/Mephisto/blob/master/docs/hydra_migration.md '
            'in order to upgrade.')

    def print_run_details(self):
        """Print details about running tasks"""
        # TODO(#93) parse these tasks and get the full details
        # Iterating the dict yields its keys, i.e. the task run ids
        for task in self.get_running_task_runs():
            logger.info(f"Operator running task ID = {task}")

    def wait_for_runs_then_shutdown(self,
                                    skip_input=False,
                                    log_rate: Optional[int] = None) -> None:
        """
        Wait for task_runs to complete, and then shutdown.

        Set log_rate to get print statements of currently running tasks
        at the specified interval

        When an exception occurs while waiting and skip_input is False, the
        traceback is printed and the user is asked whether to shut down.
        Either way, `shutdown` is always called before returning.
        """
        # Must be imported before first use: a function-local import makes
        # `traceback` a local name for the whole function, so using it
        # ahead of the import raises UnboundLocalError
        import traceback

        try:
            try:
                last_log = 0.0
                while len(self.get_running_task_runs()) > 0:
                    if log_rate is not None:
                        if time.time() - last_log > log_rate:
                            last_log = time.time()
                            self.print_run_details()
                    time.sleep(10)

            except Exception as e:
                if skip_input:
                    raise e

                traceback.print_exc()
                should_quit = input(
                    "The above exception happened while running a task, do "
                    "you want to shut down? (y)/n: ")
                # Anything other than an explicit "no" means shut down
                if should_quit not in ["n", "N", "no", "No"]:
                    raise e

        except Exception as e:
            traceback.print_exc()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.exception(
                "Cleaning up after keyboard interrupt, please wait!",
                exc_info=True)
        finally:
            self.shutdown()
Exemplo n.º 6
0
    def test_register_job_with_onboarding(self):
        """Test registering and running a job with onboarding"""
        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TEST_QUALIFICATION_NAME = "test_onboarding_qualification"

        # Register onboarding arguments for blueprint
        task_run_args = self.task_run.args
        task_run_args.blueprint.use_onboarding = True
        task_run_args.blueprint.onboarding_qualification = TEST_QUALIFICATION_NAME
        task_run_args.blueprint.timeout_time = 5
        task_run_args.blueprint.is_concurrent = False
        self.task_run.get_task_config()

        # Supervisor expects that blueprint setup has already occurred
        blueprint = self.task_run.get_blueprint()

        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        task_runner = TaskRunnerClass(self.task_run, task_run_args,
                                      EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")
        worker_1 = workers[0]

        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_units), 0)

        # Fail to register a blocked agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        qualification_id = blueprint.onboarding_qualification_id
        self.db.grant_qualification(qualification_id, worker_1.db_id, 0)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet, failed onboarding")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for disqualified worker",
        )
        self.assertIsNone(last_packet["data"]["agent_id"],
                          "worker assigned real agent id")
        self.architect.server.last_packet = None
        self.db.revoke_qualification(qualification_id, worker_id)

        # Register an agent successfully
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 1,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": False}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[0].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 0,
                         "Failed agent created after onboarding")
        self.assertEqual(len(sup.agents), 0,
                         "Failed agent registered with supervisor")

        self.assertEqual(
            len(task_runner.running_units),
            0,
            "Task should not launch with failed worker",
        )

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_2 = workers[0]
        worker_id = worker_2.db_id

        # Register an agent that is already qualified
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.db.grant_qualification(qualification_id, worker_2.db_id, 1)
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertNotIn(
            "onboard_data",
            last_packet["data"],
            "Onboarding triggered for qualified agent",
        )
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Second agent not created without onboarding")

        self.assertEqual(len(task_runner.running_units), 1,
                         "Tasks were not launched")

        self.assertFalse(worker_1.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_1.is_disqualified(TEST_QUALIFICATION_NAME))
        self.assertTrue(worker_2.is_qualified(TEST_QUALIFICATION_NAME))
        self.assertFalse(worker_2.is_disqualified(TEST_QUALIFICATION_NAME))

        # Register another worker
        mock_worker_name = "MOCK_WORKER_3"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_3 = workers[0]
        worker_id = worker_3.db_id
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1,
                         "Agent should not be created yet - need onboarding")
        onboard_agents = self.db.find_onboarding_agents()
        self.assertEqual(len(onboard_agents), 2,
                         "Onboarding agent should have been created")
        time.sleep(0.1)
        last_packet = self.architect.server.last_packet
        self.assertIsNotNone(last_packet)
        self.assertIn("onboard_data", last_packet["data"],
                      "Onboarding not triggered")
        self.architect.server.last_packet = None

        # Submit onboarding from the agent
        onboard_data = {"should_pass": True}
        self.architect.server.register_mock_agent_after_onboarding(
            worker_id, onboard_agents[1].get_agent_id(), onboard_data)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Agent not created after onboarding")

        # Re-register as if refreshing
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2,
                         "Duplicate agent created after onboarding")
        agent = agents[1]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 2,
                         "Agent not registered supervisor after onboarding")

        self.assertEqual(len(task_runner.running_units), 2,
                         "Task not launched after onboarding")

        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 seconds for the actual operation to occur
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(agent_1_data["acts"]) > 0:
                break
            time.sleep(0.1)

        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not process messages in time")

        # Give up to 1 seconds for the task to complete afterwards
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if len(task_runner.running_units) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Did not complete task in time")

        # Give up to 1 seconds for all messages to propogate
        start_time = time.time()
        TIMEOUT_TIME = 1
        while time.time() - start_time < TIMEOUT_TIME:
            if self.architect.server.actions_observed == 2:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Not all actions observed in time")

        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())
Exemplo n.º 7
0
    def test_register_job(self):
        """Test registering and running a job run asynchronously.

        Registers a job with a fresh ``Supervisor``, connects two mock
        workers (one agent each) through the architect's mock server, makes
        both agents act, and then checks that the messages are processed,
        the running assignment finishes, and the channel is closed after
        ``shutdown()``.
        """

        def wait_until(condition, failure_message, timeout=1.0, interval=0.1):
            """Poll ``condition`` until it is true or ``timeout`` seconds pass."""
            # A monotonic clock is immune to wall-clock adjustments.
            deadline = time.monotonic() + timeout
            while not condition() and time.monotonic() < deadline:
                time.sleep(interval)
            # Assert on the condition itself rather than on elapsed time, so
            # a condition that becomes true during the final sleep is not
            # misreported as a timeout.
            self.assertTrue(condition(), failure_message)

        # Handle baseline setup
        sup = Supervisor(self.db)
        self.sup = sup
        TaskRunnerClass = MockBlueprint.TaskRunnerClass
        args = MockBlueprint.ArgsClass()
        args.timeout_time = 5
        config = OmegaConf.structured(MephistoConfig(blueprint=args))
        task_runner = TaskRunnerClass(self.task_run, config, EMPTY_STATE)
        sup.register_job(self.architect, task_runner, self.provider)

        # Registering the job should have opened exactly one live channel
        self.assertEqual(len(sup.channels), 1)
        channel_info = list(sup.channels.values())[0]
        self.assertIsNotNone(channel_info)
        self.assertTrue(channel_info.channel.is_alive())
        channel_id = channel_info.channel_id
        # From here on, use the task runner held by the registered job
        task_runner = channel_info.job.task_runner
        self.assertIsNotNone(channel_id)
        self.assertEqual(
            len(self.architect.server.subs),
            1,
            "MockServer doesn't see registered channel",
        )
        self.assertIsNotNone(
            self.architect.server.last_alive_packet,
            "No alive packet received by server",
        )
        sup.launch_sending_thread()
        self.assertIsNotNone(sup.sending_thread)

        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker not successfully registered")

        # Registering the same name again must not create a second worker
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        self.assertEqual(len(workers), 1, "Worker potentially re-registered")
        worker_id = workers[0].db_id

        self.assertEqual(len(task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")

        # Registering the same agent again must not duplicate it
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent may have been duplicated")
        agent = agents[0]
        self.assertIsNotNone(agent)
        self.assertEqual(len(sup.agents), 1,
                         "Agent not registered with supervisor")

        self.assertEqual(len(task_runner.running_assignments), 0,
                         "Task was not yet ready")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        self.architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        self.architect.server.register_mock_agent(worker_id,
                                                  mock_agent_details)

        self.assertEqual(len(task_runner.running_assignments), 1,
                         "Task was not launched")
        agents = [a.agent for a in sup.agents.values()]

        # Make both agents act
        agent_id_1, agent_id_2 = agents[0].db_id, agents[1].db_id
        agent_1_data = agents[0].datastore.agent_data[agent_id_1]
        # NOTE(review): agent_2_data is not read below; the lookup is kept
        # exactly as in the original test.
        agent_2_data = agents[1].datastore.agent_data[agent_id_2]
        self.architect.server.send_agent_act(agent_id_1, {"text": "message1"})
        self.architect.server.send_agent_act(agent_id_2, {"text": "message2"})

        # Give up to 1 second for the actual operation to occur
        wait_until(
            lambda: len(agent_1_data["acts"]) > 0,
            "Did not process messages in time",
        )

        # Give up to 1 second for the task to complete afterwards
        wait_until(
            lambda: len(task_runner.running_assignments) == 0,
            "Did not complete task in time",
        )

        # Give up to 1 second for all messages to propagate
        wait_until(
            lambda: self.architect.server.actions_observed == 2,
            "Not all actions observed in time",
        )

        sup.shutdown()
        self.assertTrue(channel_info.channel.is_closed())