Пример #1
0
 def _set_up_server(self, shared_state: Optional[SharedTaskState] = None):
     """
     Create the operator for this run, launch the configured task, and keep a
     reference to the architect's server.

     :param shared_state: optional shared state handed to the task run.
     """
     operator = Operator(self.db)
     self.operator = operator
     operator.validate_and_run_config(
         self.config.mephisto, shared_state=shared_state
     )
     job = self._get_channel_info().job
     self.server = job.architect.server
Пример #2
0
 def _set_up_server(self, shared_state: Optional[SharedTaskState] = None):
     """
     Create the operator for this run, launch the configured task, and keep a
     reference to the server of the first channel's architect.

     :param shared_state: optional shared state handed to the task run.
     """
     operator = Operator(self.db)
     self.operator = operator
     operator.validate_and_run_config(
         self.config.mephisto, shared_state=shared_state
     )
     # Only the first channel (in dict order) is inspected; indexing raises
     # IndexError if ``supervisor.channels`` is empty.
     first_channel = [*operator.supervisor.channels.values()][0]
     self.server = first_channel.job.architect.server
def main(operator: Operator, cfg: DictConfig) -> None:
    """
    Launch a static task whose onboarding passes only when the worker's
    ``"answer"`` output equals ``cfg.correct_answer``, then block until all
    runs have finished.

    :param operator: operator used to launch the task run.
    :param cfg: Hydra config providing ``correct_answer`` and ``mephisto``.
    """
    correct_config_answer = cfg.correct_answer

    def onboarding_is_valid(onboarding_data):
        # Only the submitted outputs are compared. The onboarding inputs are
        # not needed, so a submission without an "inputs" key is still
        # validated instead of raising KeyError.
        outputs = onboarding_data["outputs"]
        return outputs.get("answer") == correct_config_answer

    shared_state = SharedStaticTaskState(
        onboarding_data={"correct_answer": correct_config_answer},
        validate_onboarding=onboarding_is_valid,
    )

    operator.launch_task_run(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #4
0
def main(operator: Operator, cfg: DictConfig) -> None:
    """
    Build the custom frontend bundle and launch a static task over two
    hard-coded text snippets, accepting every onboarding submission.

    :param operator: operator used to launch the task run.
    :param cfg: Hydra config providing ``task_dir`` and ``mephisto``.
    """
    texts = ["This text is good text!", "This text is bad text!"]

    def onboarding_always_valid(onboarding_data):
        # Every onboarding submission is accepted.
        return True

    shared_state = SharedStaticTaskState(
        static_task_data=[{"text": text} for text in texts],
        validate_onboarding=onboarding_always_valid,
    )

    build_custom_bundle(cfg.task_dir)

    operator.launch_task_run(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #5
0
def process_config_and_get_operator(
    cfg: DictConfig, print_config=False
) -> Tuple["Operator", DictConfig]:
    """
    Turn a Hydra ``DictConfig`` built from a ``TaskConfig`` into a ready
    ``Operator`` plus the validated config.

    :param cfg: raw Hydra config for the task.
    :param print_config: forwarded to ``load_db_and_process_config``, which
        is asked to print the configuration when this is true.
    :return: ``(operator, validated_config)``.
    """
    db, processed = load_db_and_process_config(cfg, print_config=print_config)
    operator = Operator(db)
    return operator, processed
Пример #6
0
def main(cfg: DictConfig) -> None:
    """
    Run the configured task without any shared state and block until every
    launched run has finished, logging at ``cfg.monitoring_log_rate``.
    """
    db, processed_cfg = load_db_and_process_config(cfg)
    runner = Operator(db)
    runner.validate_and_run_config(
        run_config=processed_cfg.mephisto, shared_state=None
    )
    runner.wait_for_runs_then_shutdown(
        skip_input=True, log_rate=processed_cfg.monitoring_log_rate
    )
Пример #7
0
def main(cfg: DictConfig) -> None:
    """
    Launch a ParlAI chat task whose world is driven by a ParlAI teacher built
    from the ``cfg.teacher`` section.

    :param cfg: raw Hydra config (``teacher``, ``turn_timeout``, ``mephisto``).
    :raises FileNotFoundError: if a custom source bundle is configured but
        has not been built yet.
    """
    db, cfg = load_db_and_process_config(cfg)

    # Turn the ``teacher`` config section into ParlAI command-line flags.
    # The argparse-based parser only handles string arguments, so non-string
    # config values (ints, bools) are converted explicitly.
    parser = ParlaiParser(True, False)
    teacher_args = list(
        chain.from_iterable(
            ('--' + k, str(v)) for k, v in cfg.teacher.items()))
    opt = parser.parse_args(teacher_args)
    agent = RepeatLabelAgent(opt)
    teacher = create_task(opt, agent).get_task_agent()

    world_opt = {"turn_timeout": cfg.turn_timeout, "teacher": teacher}

    custom_bundle_path = cfg.mephisto.blueprint.get("custom_source_bundle",
                                                    None)
    if custom_bundle_path is not None:
        # Explicit raise rather than assert so the check survives `python -O`.
        if not os.path.exists(custom_bundle_path):
            raise FileNotFoundError(
                "Must build the custom bundle with `npm install; npm run dev` from within "
                f"the {TASK_DIRECTORY}/webapp directory in order to demo a custom bundle "
            )
        world_opt["send_task_data"] = True

    shared_state = SharedParlAITaskState(world_opt=world_opt,
                                         onboarding_world_opt=world_opt)

    operator = Operator(db)

    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #8
0
def run_static_task(cfg: DictConfig, task_directory: str):
    """
    Run a static task from a Hydra config.

    Prints the resolved config, seeds the global ``random`` module with 42,
    calls ``soft_block_mturk_workers`` with a block-qualification name,
    builds the frontend in ``task_directory``, then launches the run and
    waits for it to finish.
    """
    db, cfg = load_db_and_process_config(cfg)
    print(f'\nHydra config:\n{OmegaConf.to_yaml(cfg)}')

    random.seed(42)

    # Default the block qualification to a task-specific name to avoid
    # soft-block collisions between tasks.
    task_name = cfg.mephisto.task.get('task_name', 'turn_annotations_static')
    fallback_qual_name = f'{task_name}_block'
    soft_block_qual_name = cfg.mephisto.blueprint.get(
        'block_qualification', fallback_qual_name
    )
    soft_block_mturk_workers(
        cfg=cfg, db=db, soft_block_qual_name=soft_block_qual_name
    )

    build_task(task_directory)

    operator = Operator(db)
    operator.validate_and_run_config(run_config=cfg.mephisto, shared_state=None)
    operator.wait_for_runs_then_shutdown(
        skip_input=True, log_rate=cfg.monitoring_log_rate
    )
Пример #9
0
def main(operator: Operator, cfg: DictConfig) -> None:
    """
    Launch a remote-procedure demo task with ``cfg.num_tasks`` tasks whose
    frontend may call ``handle_with_model`` on the server.

    :param operator: operator used to launch the task run.
    :param cfg: Hydra config providing ``num_tasks``, ``task_dir`` and
        ``mephisto``.
    """

    def onboarding_always_valid(onboarding_data):
        # Placeholder: an onboarding task could be validated here.
        print(onboarding_data)
        return True

    # Tasks and context are built locally for now; the original note says a
    # non-local source should eventually be used for real runs.
    tasks = build_tasks(cfg.num_tasks)
    context = build_local_context(cfg.num_tasks)

    def handle_with_model(
        _request_id: str, args: Dict[str, Any], agent_state: RemoteProcedureAgentState
    ) -> Dict[str, Any]:
        """Remote call to process external content using a 'model'"""
        print(f"The parsed args are {args}, you can do what you want with that")
        init_data = agent_state.init_data
        print(f"You can also use {init_data}, to get task keys")
        assert init_data is not None
        local_value = context[init_data["local_value_key"]]
        print(f"And that may let you get local context, like {local_value}")
        return {
            "secret_local_value": local_value,
            "update": f"this was request {args['arg3'] + 1}",
        }

    shared_state = SharedRemoteProcedureTaskState(
        static_task_data=tasks,
        validate_onboarding=onboarding_always_valid,
        function_registry={"handle_with_model": handle_with_model},
    )

    build_custom_bundle(cfg.task_dir)
    operator.launch_task_run(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #10
0
def main(cfg: DictConfig) -> None:
    """
    Launch a ParlAI task with role-qualification checks and frontend options,
    then update the persona use-count file once all runs have finished.
    """
    db, cfg = load_db_and_process_config(cfg)
    world_opt = get_world_opt(cfg)
    onboarding_world_opt = get_onboarding_world_opt(cfg)
    shared_state = SharedParlAITaskState(
        world_opt=world_opt, onboarding_world_opt=onboarding_world_opt
    )

    role_qual_name = world_opt[constants.ROLE_QUALIFICATION_NAME_KEY]
    check_role_training_qualification(
        db=db,
        qname=role_qual_name,
        requester_name=cfg.mephisto.provider.requester_name,
    )

    # Entries in ``task_config`` are presumably forwarded to the frontend.
    task_config = shared_state.task_config
    task_config['minTurns'] = world_opt['min_turns']
    task_config['onboardingPersona'] = constants.ONBOARDING_PERSONA
    shared_state.worker_can_do_unit = get_worker_eval_function(
        role_qual_name,
        onboarding_world_opt['onboarding_qualification'],
    )

    add_banned_words_frontend_conf(
        shared_state, cfg.mephisto.blueprint.banned_words_file
    )

    operator = Operator(db)
    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=300)
    update_persona_use_counts_file(
        cfg.mephisto.blueprint.persona_counts_file,
        world_opt['prev_persona_count'],
    )
def main(cfg: DictConfig) -> None:
    def onboarding_is_valid(onboarding_data):
        outputs = onboarding_data["outputs"]
        answer_str = outputs["answer"]
        # NOTE: depending on which OS Turker uses, there could be carriage returns \r or just newlines \n
        # this python module should handle all cases
        commands = answer_str.splitlines()
        # filter empty commands
        filtered_commands = [x for x in commands if x != ""]
        # Number check: Check that the number of commands >= 3
        if len(commands) < 3:
            return False
        # Length check: Check that the average number of words in commands > 4
        commands_split = [x.split(" ") for x in filtered_commands]
        avg_words_in_commands = sum(map(len,
                                        commands_split)) / len(commands_split)
        if avg_words_in_commands < 2:
            return False
        # Diversity check: Check that commands are reasonably diverse
        first_words = [x[0] for x in commands_split]
        if len(set(first_words)) == 1:
            return False
        # TODO: Grammar check: Check that there is punctuation, capitals
        return True

    shared_state = SharedStaticTaskState(
        onboarding_data={},
        validate_onboarding=onboarding_is_valid,
    )

    db, cfg = load_db_and_process_config(cfg)
    operator = Operator(db)

    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #12
0
def main(operator: Operator, cfg: DictConfig) -> None:
    """
    Launch a remote-procedure task where drawn digits are sent to the server
    and classified by a pretrained MNIST model.

    :param operator: operator used to launch the task run.
    :param cfg: Hydra config providing ``num_tasks``, ``task_dir`` and
        ``mephisto``.
    """
    # One independent dict per task. ``[{}] * n`` would make every entry the
    # same object, so mutating one task's data would mutate all of them.
    tasks: List[Dict[str, Any]] = [{} for _ in range(cfg.num_tasks)]
    mnist_model = mnist(pretrained=True)

    def handle_with_model(
            _request_id: str, args: Dict[str, Any],
            agent_state: RemoteProcedureAgentState) -> Dict[str, Any]:
        """Convert the image to be read by MNIST classifier, then classify"""
        # ``urlData`` is expected to be a base64-encoded PNG data URL.
        img_dat = args["urlData"].split("data:image/png;base64,")[1]
        # Close the decoded image once the grayscale 28x28 copy is made.
        with Image.open(BytesIO(base64.b64decode(img_dat))) as im:
            im_resized = im.convert("L").resize((28, 28))
        im_vals = list(im_resized.getdata())
        # Invert (dark strokes -> high values) and scale pixels to [0, 1].
        norm_vals = [(255 - x) / 255.0 for x in im_vals]
        in_tensor = torch.tensor([norm_vals])
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            output = mnist_model(in_tensor)
        pred = output.max(1)[1]
        print("Predicted digit:", pred.item())
        return {
            "digit_prediction": pred.item(),
        }

    function_registry = {
        "classify_digit": handle_with_model,
    }

    shared_state = SharedRemoteProcedureTaskState(
        static_task_data=tasks,
        function_registry=function_registry,
    )

    task_dir = cfg.task_dir
    build_custom_bundle(task_dir)

    operator.launch_task_run(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #13
0
def main(cfg: DictConfig) -> None:
    """
    Launch a static task whose only qualification requirement is
    ``ALLOWLIST_QUALIFICATION`` with the ``QUAL_EXISTS`` comparator.
    """
    allowlist_requirement = make_qualification_dict(
        ALLOWLIST_QUALIFICATION, QUAL_EXISTS, None
    )
    shared_state = SharedStaticTaskState(qualifications=[allowlist_requirement])

    db, cfg = load_db_and_process_config(cfg)
    operator = Operator(db)

    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #14
0
 def run_acute_eval(self):
     """
     Prepare the ACUTE-Eval setup, launch it through a Mephisto operator with
     no shared state, and block until every run has finished.
     """
     self.set_up_acute_eval()
     db, cfg = load_db_and_process_config(self.args)
     operator = Operator(db)
     operator.validate_and_run_config(run_config=cfg.mephisto, shared_state=None)
     log_rate = cfg.monitoring_log_rate
     operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=log_rate)
Пример #15
0
 def run_acute_eval(self):
     """
     Prepare the ACUTE-Eval setup, print the run id and resolved Hydra
     config, then launch the task and block until it has finished.
     """
     self.set_up_acute_eval()
     db, cfg = load_db_and_process_config(self.args)
     run_id = cfg.mephisto.task.task_name
     print(f'*** RUN ID: {run_id} ***')
     print(f'\nHydra config:\n{OmegaConf.to_yaml(cfg)}')
     operator = Operator(db)
     operator.validate_and_run_config(run_config=cfg.mephisto, shared_state=None)
     operator.wait_for_runs_then_shutdown(
         skip_input=True, log_rate=cfg.monitoring_log_rate
     )
Пример #16
0
def main():
    """
    Build the Flask app for the web UI and wire up its resources.

    Steps:
    - Serve the built webapp from ``webapp/build`` and mount the ``api``
      blueprint under ``/api/v1``.
    - Expose a ``LocalMephistoDB`` and an ``Operator`` through
      ``app.extensions``.
    - Add permissive CORS and no-store cache headers to every response.
    - Install a cleanup hook that shuts the operator and DB down on SIGINT
      or at interpreter exit.
    """
    app = Flask(
        __name__, static_url_path="/static", static_folder="webapp/build/static"
    )
    app.config.from_object(Config)

    app.register_blueprint(api, url_prefix="/api/v1")

    # Register extensions
    db = LocalMephistoDB()
    operator = Operator(db)
    if not hasattr(app, "extensions"):
        app.extensions = {}
    app.extensions["db"] = db
    app.extensions["operator"] = operator

    @app.route("/", defaults={"path": "index.html"})
    @app.route("/<path:path>")
    def index(path):
        # Every path, including the root, serves the built index.html.
        return send_file(os.path.join("webapp", "build", "index.html"))

    @app.after_request
    def after_request(response):
        response.headers.add("Access-Control-Allow-Origin", "*")
        response.headers.add(
            "Access-Control-Allow-Headers", "Content-Type,Authorization"
        )
        response.headers.add(
            "Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS"
        )
        response.headers.add("Cache-Control", "no-store")
        return response

    term_handler = signal.getsignal(signal.SIGINT)
    resources_released = False

    def cleanup_resources(*args, **kwargs):
        # Registered both with atexit and as the SIGINT handler, so it can run
        # twice (Ctrl-C, then interpreter exit). Release resources only once.
        nonlocal resources_released
        if not resources_released:
            resources_released = True
            operator.shutdown()
            db.shutdown()
        # Chain to the previous SIGINT handler only when invoked as a signal
        # handler (atexit passes no arguments). getsignal() may also return
        # SIG_DFL, SIG_IGN or None, none of which is callable.
        if args and callable(term_handler):
            term_handler(*args, **kwargs)

    atexit.register(cleanup_resources)
    signal.signal(signal.SIGINT, cleanup_resources)
Пример #17
0
def main(cfg: DictConfig) -> None:
    """
    Launch a pilot static task that calls ``validate_unit`` on each submitted
    unit and requires ``PILOT_BLOCK_QUAL_NAME`` to be absent for the worker.
    """
    db, cfg = load_db_and_process_config(cfg)
    operator = Operator(db)

    shared_state = SharedStaticTaskState(on_unit_submitted=validate_unit)
    # Do not allow workers to take the pilot task a second time.
    pilot_block = make_qualification_dict(
        PILOT_BLOCK_QUAL_NAME, QUAL_NOT_EXIST, None
    )
    shared_state.qualifications = [pilot_block]

    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
def main(cfg: DictConfig) -> None:
    """
    Launch a static task whose onboarding passes only when the worker's
    ``"answer"`` output equals ``cfg.correct_answer``, then block until all
    runs have finished.

    :param cfg: raw Hydra config providing ``correct_answer`` and ``mephisto``.
    """
    correct_config_answer = cfg.correct_answer

    def onboarding_is_valid(onboarding_data):
        # Only the submitted outputs are compared. The onboarding inputs are
        # not needed, so a submission without an "inputs" key is still
        # validated instead of raising KeyError.
        outputs = onboarding_data["outputs"]
        return outputs.get("answer") == correct_config_answer

    shared_state = SharedStaticTaskState(
        onboarding_data={"correct_answer": correct_config_answer},
        validate_onboarding=onboarding_is_valid,
    )

    db, cfg = load_db_and_process_config(cfg)
    operator = Operator(db)

    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #19
0
def main(cfg: DictConfig) -> None:
    """
    Launch a ParlAI task driven by the teacher from ``get_teacher(cfg)``.

    If a custom frontend bundle is configured but missing on disk, it is
    built first.
    """
    db, cfg = load_db_and_process_config(cfg)

    teacher = get_teacher(cfg)
    world_opt = {"turn_timeout": cfg.turn_timeout, "teacher": teacher}

    bundle_path = cfg.mephisto.blueprint.get("custom_source_bundle", None)
    # Build the bundle on demand when one is configured but not yet on disk.
    if bundle_path is not None and not os.path.exists(bundle_path):
        build_task(TASK_DIRECTORY)

    shared_state = SharedParlAITaskState(
        world_opt=world_opt, onboarding_world_opt=world_opt
    )

    operator = Operator(db)
    operator.validate_and_run_config(
        run_config=cfg.mephisto, shared_state=shared_state
    )
    operator.wait_for_runs_then_shutdown(
        skip_input=True, log_rate=cfg.monitoring_log_rate
    )
Пример #20
0
def main(cfg: DictConfig) -> None:
    """
    Launch a ParlAI chat task, passing ``num_turns`` and ``turn_timeout``
    from the config into the world options.

    :param cfg: raw Hydra config.
    :raises FileNotFoundError: if a custom source bundle is configured but
        has not been built yet.
    """
    db, cfg = load_db_and_process_config(cfg)

    world_opt = {"num_turns": cfg.num_turns, "turn_timeout": cfg.turn_timeout}

    custom_bundle_path = cfg.mephisto.blueprint.get("custom_source_bundle",
                                                    None)
    if custom_bundle_path is not None:
        # Explicit raise rather than assert so the check survives `python -O`.
        if not os.path.exists(custom_bundle_path):
            raise FileNotFoundError(
                "Must build the custom bundle with `npm install; npm run dev` from within "
                f"the {TASK_DIRECTORY}/webapp directory in order to demo a custom bundle "
            )
        world_opt["send_task_data"] = True

    shared_state = SharedParlAITaskState(world_opt=world_opt,
                                         onboarding_world_opt=world_opt)

    operator = Operator(db)

    operator.validate_and_run_config(cfg.mephisto, shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #21
0
def main(cfg: DictConfig) -> None:
    """
    Build the task frontend in ``cfg.task_dir``, then launch a static task
    over two hard-coded texts, accepting every onboarding submission.
    """
    task_dir = cfg.task_dir

    def onboarding_always_valid(onboarding_data):
        # Every onboarding submission is accepted.
        return True

    texts = ["This text is good text!", "This text is bad text!"]
    shared_state = SharedStaticTaskState(
        static_task_data=[{"text": text} for text in texts],
        validate_onboarding=onboarding_always_valid,
    )

    build_task(task_dir)

    db, cfg = load_db_and_process_config(cfg)
    runner = Operator(db)
    runner.validate_and_run_config(cfg.mephisto, shared_state)
    runner.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #22
0
def run_task(cfg: DictConfig, task_directory: str):
    """
    Run the model-chat task from a Hydra config.

    Steps:
    - Suffix the task name with ``_local`` or ``_sandbox`` on those
      architects, so their data is kept apart from other runs.
    - Pass a block-qualification name to ``soft_block_mturk_workers``. The
      name defaults to one derived from the task name.
    - Launch the run and wait for it to finish.

    :param cfg: raw Hydra config. The processed config's
        ``mephisto.task.task_name`` is overwritten with the suffixed name.
    :param task_directory: task directory; currently not used by this
        function and kept for interface compatibility.
    """
    db, cfg = load_db_and_process_config(cfg)
    print(f'\nHydra config:\n{OmegaConf.to_yaml(cfg)}')

    random.seed(42)

    # Update task name when on sandbox or local to ensure data is split.
    task_name = cfg.mephisto.task.get('task_name', 'model_chat')
    architect_type = cfg.mephisto.architect._architect_type
    if architect_type == 'local':
        task_name = f"{task_name}_local"
    elif architect_type == 'mturk_sandbox':
        task_name = f"{task_name}_sandbox"
    cfg.mephisto.task.task_name = task_name

    # Default to a task-specific name to avoid soft-block collisions.
    soft_block_qual_name = cfg.mephisto.blueprint.get('block_qualification',
                                                      f'{task_name}_block')
    soft_block_mturk_workers(cfg=cfg,
                             db=db,
                             soft_block_qual_name=soft_block_qual_name)

    # ``world_module`` is presumably a module-level name defined elsewhere in
    # this file; verify against the module header.
    shared_state = SharedModelChatTaskState(world_module=world_module)

    operator = Operator(db)
    operator.validate_and_run_config(run_config=cfg.mephisto,
                                     shared_state=shared_state)
    operator.wait_for_runs_then_shutdown(skip_input=True,
                                         log_rate=cfg.monitoring_log_rate)
Пример #23
0
def main(cfg: DictConfig) -> None:
    """
    Launch the configured task without passing any shared state, then block
    until every run has finished.
    """
    db, processed_cfg = load_db_and_process_config(cfg)
    runner = Operator(db)
    runner.validate_and_run_config(processed_cfg.mephisto)
    runner.wait_for_runs_then_shutdown(skip_input=True, log_rate=30)
Пример #24
0
    def test_run_job_concurrent(self):
        """
        Launch one concurrent mock assignment, then attach two mock workers
        with one agent each.

        Checks that the operator stops tracking the run within TIMEOUT_TIME
        seconds and that the assignment ends up COMPLETED.
        """
        self.operator = Operator(self.db)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=1, is_concurrent=True),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=MOCK_TASK_ARGS,
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Registering a worker alone must not start any assignment.
        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id, mock_agent_details)

        # Give up to TIMEOUT_TIME seconds for the whole mock task to complete,
        # i.e. until the operator reports no running task runs.
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure the assignment is completed
        task_run = tracked_run.task_run
        assignment = task_run.get_assignments()[0]
        self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)
Пример #25
0
 def test_initialize_supervisor(self):
     """Smoke test: constructing an Operator over the test DB must not raise."""
     operator = Operator(self.db)
     # Stored on self, presumably so the test teardown can shut it down.
     self.operator = operator
Пример #26
0
class OperatorBaseTest(object):
    """
    Unit testing for the Mephisto Operator
    """

    DB_CLASS = None

    def setUp(self):
        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        assert self.DB_CLASS is not None, "Did not specify db to use"
        self.db = self.DB_CLASS(database_path)
        self.requester_name, _req_id = get_test_requester(self.db)
        self.operator = None

    def tearDown(self):
        if self.operator is not None:
            self.operator.shutdown()
        self.db.shutdown()
        shutil.rmtree(self.data_dir, ignore_errors=True)
        self.assertTrue(
            len(threading.enumerate()) == 1,
            f"Expected only main thread at teardown, found {threading.enumerate()}",
        )

    def wait_for_complete_assignment(self, assignment, timeout: int):
        """
        Poll ``assignment`` every 0.1s until its status is
        ``AssignmentState.COMPLETED``. Fail the test if that does not happen
        within ``timeout`` seconds.

        Success is decided by whether completion was actually observed. It is
        not decided by re-measuring elapsed time afterwards, so a completion
        seen on the final poll is not misreported as a timeout.
        """
        start_time = time.time()
        completed = False
        while time.time() - start_time < timeout:
            if assignment.get_status() == AssignmentState.COMPLETED:
                completed = True
                break
            time.sleep(0.1)
        self.assertTrue(completed, "Assignment not completed in time")

    def await_server_start(self, architect: "MockArchitect"):
        start_time = time.time()
        assert architect.server is not None, "Cannot wait on empty server"
        while time.time() - start_time < 5:
            if len(architect.server.subs) > 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, 5,
                        "Mock server not up in time")

    def test_initialize_supervisor(self):
        """Smoke test: constructing an Operator over the test DB must not raise."""
        operator = Operator(self.db)
        # Stored on self so tearDown() shuts it down.
        self.operator = operator

    def test_run_job_concurrent(self):
        """
        Launch one concurrent mock assignment, then attach two mock workers
        with one agent each.

        Checks that the operator stops tracking the run within TIMEOUT_TIME
        seconds and that the assignment ends up COMPLETED.
        """
        self.operator = Operator(self.db)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=1, is_concurrent=True),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=MOCK_TASK_ARGS,
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Registering a worker alone must not start any assignment.
        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id, mock_agent_details)

        # Give up to TIMEOUT_TIME seconds for the whole mock task to complete,
        # i.e. until the operator reports no running task runs.
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure the assignment is completed
        task_run = tracked_run.task_run
        assignment = task_run.get_assignments()[0]
        self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)

    def test_run_job_not_concurrent(self):
        """
        Launch one non-concurrent mock assignment, then attach two mock
        workers with one agent each.

        Checks that the operator stops tracking the run within TIMEOUT_TIME
        seconds and that the first assignment ends up COMPLETED.
        """
        self.operator = Operator(self.db)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=1,
                                        is_concurrent=False),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=MOCK_TASK_ARGS,
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Registering a worker alone must not start any assignment.
        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id, mock_agent_details)

        # Give up to TIMEOUT_TIME seconds for both tasks to complete, i.e.
        # until the operator reports no running task runs.
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure the assignment is completed
        task_run = tracked_run.task_run
        assignment = task_run.get_assignments()[0]
        self.assertEqual(assignment.get_status(), AssignmentState.COMPLETED)

    def test_run_jobs_with_restrictions(self):
        """Ensure allowed_concurrent and maximum_units_per_worker work"""
        self.operator = Operator(self.db)
        provider_args = MockProviderArgs(requester_name=self.requester_name)
        architect_args = MockArchitectArgs(should_run_server=True)
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=3, is_concurrent=True),
            provider=provider_args,
            architect=architect_args,
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name="max-unit-test",
            ),
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        self.await_server_start(tracked_run.architect)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_1 = workers[0].db_id

        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Try to register a second agent, which should fail due to concurrency
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Second agent was created")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_2 = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Second agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[1].get_unit().get_assignment(), 3)

        # Pass a second task as well
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 3, "Agent was not created properly")
        mock_agent_details = "FAKE_ASSIGNMENT_4"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Fourth agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[3].get_unit().get_assignment(), 3)

        # Both workers should have saturated their tasks, and not be granted agents
        mock_agent_details = "FAKE_ASSIGNMENT_5"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Additional agent was created")
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Additional agent was created")

        # new workers should be able to work on these just fine though
        mock_worker_name = "MOCK_WORKER_3"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_3 = workers[0].db_id
        mock_worker_name = "MOCK_WORKER_4"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_4 = workers[0].db_id

        # Register agents from new workers
        mock_agent_details = "FAKE_ASSIGNMENT_5"
        architect.server.register_mock_agent(worker_id_3, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 5, "Additional agent was not created")
        mock_agent_details = "FAKE_ASSIGNMENT_6"
        architect.server.register_mock_agent(worker_id_4, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 6, "Additional agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[5].get_unit().get_assignment(), 3)

        # Give up to 5 seconds for whole mock task to complete
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure all assignments are completed
        task_run = tracked_run.task_run
        assignments = task_run.get_assignments()
        for assignment in assignments:
            self.assertEqual(assignment.get_status(),
                             AssignmentState.COMPLETED)

        # Create a new task
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=1, is_concurrent=True),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name="max-unit-test",
            ),
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]
        self.await_server_start(tracked_run.architect)
        architect = tracked_run.architect

        # Workers one and two still shouldn't be able to make agents
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )

        # Three and four should though
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_3, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 7, "Additional agent was not created")
        mock_agent_details = "FAKE_ASSIGNMENT_8"
        architect.server.register_mock_agent(worker_id_4, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 8, "Additional agent was not created")

        # Ensure the task run completed and that all assignments are done
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")
        task_run = tracked_run.task_run
        assignments = task_run.get_assignments()
        for assignment in assignments:
            self.assertEqual(assignment.get_status(),
                             AssignmentState.COMPLETED)
Пример #27
0
    def test_run_jobs_with_restrictions(self):
        """
        Ensure allowed_concurrent and maximum_units_per_worker work.

        Launches a mock task run (num_assignments=3, is_concurrent=True,
        allowed_concurrent=1, maximum_units_per_worker=2) and checks that:
        - a worker cannot be granted a second agent while its first is active;
        - once two workers have each been granted 2 agents, neither gets more,
          both in this run and in a second run with the same task_name;
        - fresh workers can still be granted agents in either run;
        - each run stops being reported as running within TIMEOUT_TIME and all
          of its assignments end up COMPLETED.
        """
        self.operator = Operator(self.db)
        provider_args = MockProviderArgs(requester_name=self.requester_name)
        architect_args = MockArchitectArgs(should_run_server=True)
        # First run: 3 concurrent assignments, at most 1 concurrent unit and
        # at most 2 units in total per worker for task_name "max-unit-test".
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=3, is_concurrent=True),
            provider=provider_args,
            architect=architect_args,
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name="max-unit-test",
            ),
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]

        # The tracked run must expose every component the rest of the test uses
        self.assertIsNotNone(tracked_run)
        self.assertIsNotNone(tracked_run.task_launcher)
        self.assertIsNotNone(tracked_run.task_runner)
        self.assertIsNotNone(tracked_run.architect)
        self.assertIsNotNone(tracked_run.task_run)
        self.assertEqual(tracked_run.task_run.db_id, task_run_id)

        self.await_server_start(tracked_run.architect)

        # Create two agents to step through the task
        architect = tracked_run.architect
        self.assertIsInstance(architect, MockArchitect,
                              "Must use mock in testing")
        # Register a worker
        mock_worker_name = "MOCK_WORKER"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_1 = workers[0].db_id

        self.assertEqual(len(tracked_run.task_runner.running_assignments), 0)

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Agent was not created properly")
        agent = agents[0]
        self.assertIsNotNone(agent)

        # Try to register a second agent, which should fail due to concurrency
        # (allowed_concurrent=1 and worker 1's first agent is still active)
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 1, "Second agent was created")

        # Register another worker
        mock_worker_name = "MOCK_WORKER_2"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_2 = workers[0].db_id

        # Register an agent
        mock_agent_details = "FAKE_ASSIGNMENT_2"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 2, "Second agent was not created")

        # Wait for the assignment the second agent joined to complete
        # (NOTE(review): the trailing 3 is presumably a timeout - confirm)
        self.wait_for_complete_assignment(
            agents[1].get_unit().get_assignment(), 3)

        # Pass a second task as well
        mock_agent_details = "FAKE_ASSIGNMENT_3"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 3, "Agent was not created properly")
        mock_agent_details = "FAKE_ASSIGNMENT_4"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Fourth agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[3].get_unit().get_assignment(), 3)

        # Both workers should have saturated their tasks, and not be granted agents
        # (each has now been granted maximum_units_per_worker=2 agents)
        mock_agent_details = "FAKE_ASSIGNMENT_5"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Additional agent was created")
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 4, "Additional agent was created")

        # new workers should be able to work on these just fine though
        mock_worker_name = "MOCK_WORKER_3"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_3 = workers[0].db_id
        mock_worker_name = "MOCK_WORKER_4"
        architect.server.register_mock_worker(mock_worker_name)
        workers = self.db.find_workers(worker_name=mock_worker_name)
        worker_id_4 = workers[0].db_id

        # Register agents from new workers
        mock_agent_details = "FAKE_ASSIGNMENT_5"
        architect.server.register_mock_agent(worker_id_3, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 5, "Additional agent was not created")
        mock_agent_details = "FAKE_ASSIGNMENT_6"
        architect.server.register_mock_agent(worker_id_4, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 6, "Additional agent was not created")

        # wait for task to pass
        self.wait_for_complete_assignment(
            agents[5].get_unit().get_assignment(), 3)

        # Poll for up to TIMEOUT_TIME seconds until the operator reports no
        # running task runs, i.e. the whole mock task has completed
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")

        # Ensure all assignments are completed
        task_run = tracked_run.task_run
        assignments = task_run.get_assignments()
        for assignment in assignments:
            self.assertEqual(assignment.get_status(),
                             AssignmentState.COMPLETED)

        # Create a new task: a second run with the same task_name (1 assignment)
        # to check that the per-worker unit limit still applies to workers 1 and 2
        config = MephistoConfig(
            blueprint=MockBlueprintArgs(num_assignments=1, is_concurrent=True),
            provider=MockProviderArgs(requester_name=self.requester_name),
            architect=MockArchitectArgs(should_run_server=True),
            task=TaskConfigArgs(
                task_title="title",
                task_description="This is a description",
                task_reward="0.3",
                task_tags="1,2,3",
                maximum_units_per_worker=2,
                allowed_concurrent=1,
                task_name="max-unit-test",
            ),
        )
        self.operator.validate_and_run_config(OmegaConf.structured(config))
        tracked_runs = self.operator.get_running_task_runs()
        self.assertEqual(len(tracked_runs), 1, "Run not launched")
        task_run_id, tracked_run = list(tracked_runs.items())[0]
        self.await_server_start(tracked_run.architect)
        architect = tracked_run.architect

        # Workers one and two still shouldn't be able to make agents
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_1, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_2, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(
            len(agents),
            6,
            "Additional agent was created for worker exceeding max units",
        )

        # Three and four should though
        mock_agent_details = "FAKE_ASSIGNMENT_7"
        architect.server.register_mock_agent(worker_id_3, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 7, "Additional agent was not created")
        mock_agent_details = "FAKE_ASSIGNMENT_8"
        architect.server.register_mock_agent(worker_id_4, mock_agent_details)
        agents = self.db.find_agents()
        self.assertEqual(len(agents), 8, "Additional agent was not created")

        # Ensure the task run completed and that all assignments are done
        # (same TIMEOUT_TIME polling as for the first run)
        start_time = time.time()
        while time.time() - start_time < TIMEOUT_TIME:
            if len(self.operator.get_running_task_runs()) == 0:
                break
            time.sleep(0.1)
        self.assertLess(time.time() - start_time, TIMEOUT_TIME,
                        "Task not completed in time")
        task_run = tracked_run.task_run
        assignments = task_run.get_assignments()
        for assignment in assignments:
            self.assertEqual(assignment.get_status(),
                             AssignmentState.COMPLETED)
Пример #28
0
class AbstractCrowdsourcingTest(unittest.TestCase):
    """
    Abstract class for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config specification
    and agent registration. The temporary directory holding the test database (created
    by _set_up_config()) is deleted in tearDown().
    """
    def setUp(self):
        super().setUp()
        self.operator = None

    def tearDown(self):
        """
        Shut down the operator, if one was started, and delete the temporary database
        directory created by _set_up_config(), if that was called.
        """
        try:
            if self.operator is not None:
                self.operator.shutdown()
        finally:
            # tempfile.mkdtemp() directories are never removed automatically, so clean
            # up here even if the operator shutdown above raised.
            data_dir = getattr(self, 'data_dir', None)
            if data_dir is not None:
                import shutil
                shutil.rmtree(data_dir, ignore_errors=True)
            super().tearDown()

    def _set_up_config(
        self,
        blueprint_type: str,
        task_directory: str,
        overrides: Optional[List[str]] = None,
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory to store
        the test database.
        :param blueprint_type: string uniquely specifying Blueprint class
        :param task_directory: directory containing the `conf/` configuration folder.
          Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, appended after the built-in
          blueprint/mock-architect/mock-provider overrides
        """

        # Define the configuration settings. Hydra's initialize() expects a config path
        # relative to this file.
        relative_task_directory = os.path.relpath(task_directory,
                                                  os.path.dirname(__file__))
        relative_config_path = os.path.join(relative_task_directory, 'conf')
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name="example",
                overrides=[
                    f'+mephisto.blueprint._blueprint_type={blueprint_type}',
                    '+mephisto/architect=mock',
                    '+mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ] + overrides,
            )
            # TODO: when Hydra 1.1 is released with support for recursive defaults,
            #  don't manually specify all missing blueprint args anymore, but
            #  instead define the blueprint in the defaults list directly.
            #  Currently, the blueprint can't be set in the defaults list without
            #  overriding params in the YAML file, as documented at
            #  https://github.com/facebookresearch/hydra/issues/326 and as fixed in
            #  https://github.com/facebookresearch/hydra/pull/1044.

        # Fresh database in a temporary directory; removed again in tearDown()
        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional[SharedTaskState] = None):
        """
        Set up the operator and server.

        Launches the configured task run and stores the architect's server from the
        first of the supervisor's channels in `self.server`. Raises IndexError if the
        supervisor has no channels.
        """
        self.operator = Operator(self.db)
        self.operator.validate_and_run_config(self.config.mephisto,
                                              shared_state=shared_state)
        channel_info = list(self.operator.supervisor.channels.values())[0]
        self.server = channel_info.job.architect.server

    def _register_mock_agents(self, num_agents: int = 1) -> List[str]:
        """
        Register mock agents for testing, taking the place of crowdsourcing workers.

        Specify the number of agents to register. Return the agents' IDs after creation.
        Note that the returned IDs are those of every agent found in the database, not
        only the ones registered by this call.
        """

        for idx in range(num_agents):

            # Register the worker
            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            self.server.register_mock_worker(mock_worker_name)
            workers = self.db.find_workers(worker_name=mock_worker_name)
            worker_id = workers[0].db_id

            # Register the agent
            mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
            self.server.register_mock_agent(worker_id, mock_agent_details)

        # Get all agents' IDs
        agents = self.db.find_agents()
        agent_ids = [agent.db_id for agent in agents]

        return agent_ids
Пример #29
0
class AbstractCrowdsourcingTest:
    """
    Abstract class for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config specification
    and agent registration.
    """
    def _setup(self):
        """
        To be run before a test.

        Should be called in a pytest setup/teardown fixture. Seeds the Python, NumPy and
        PyTorch random number generators with 0 and clears the operator/server handles.
        """

        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)

        self.operator = None
        self.server = None

    def _teardown(self):
        """
        To be run after a test.

        Should be called in a pytest setup/teardown fixture. Force-shuts-down the
        operator and shuts down the mock server if they were started, then deletes the
        temporary database directory created by _set_up_config() (if it was called).
        """

        try:
            if self.operator is not None:
                self.operator.force_shutdown()

            if self.server is not None:
                self.server.shutdown_mock()
        finally:
            # tempfile.mkdtemp() directories are never removed automatically; clean up
            # even if one of the shutdown calls above raised.
            data_dir = getattr(self, 'data_dir', None)
            if data_dir is not None:
                import shutil
                shutil.rmtree(data_dir, ignore_errors=True)

    def _set_up_config(
        self,
        task_directory: str,
        overrides: Optional[List[str]] = None,
        config_name: str = "example",
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory to store
        the test database.
        :param task_directory: directory containing the `hydra_configs/conf/`
          configuration folder. Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, appended after the built-in
          mock architect/provider overrides
        :param config_name: name of the config to compose from that folder
        """

        # Define the configuration settings. Hydra's initialize() expects a config path
        # relative to this file.
        relative_task_directory = os.path.relpath(task_directory,
                                                  os.path.dirname(__file__))
        relative_config_path = os.path.join(relative_task_directory,
                                            'hydra_configs', 'conf')
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name=config_name,
                overrides=[
                    'mephisto/architect=mock',
                    'mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ] + overrides,
            )

        # Fresh database in a temporary directory; removed again in _teardown()
        self.data_dir = tempfile.mkdtemp()
        self.database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(self.database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional[SharedTaskState] = None):
        """
        Set up the operator and server.
        """
        self.operator = Operator(self.db)
        self.operator.validate_and_run_config(self.config.mephisto,
                                              shared_state=shared_state)
        self.server = self._get_channel_info().job.architect.server

    def _get_channel_info(self):
        """
        Return channel info for the currently running job.

        :raises ValueError: if the supervisor has no channels
        """
        channels = list(self.operator.supervisor.channels.values())
        if len(channels) > 0:
            return channels[0]
        else:
            raise ValueError('No channel could be detected!')

    def _register_mock_agents(self,
                              num_agents: int = 1,
                              assume_onboarding: bool = False) -> List[str]:
        """
        Register mock agents for testing and onboard them if needed, taking the place of
        crowdsourcing workers.

        Specify the number of agents to register. Return the agents' IDs after creation.

        Each registration is attempted up to 6 times, sleeping 0.5 s before the first
        retry and doubling the wait before each further retry, until the new agent
        can be found in the database.

        :param num_agents: number of mock workers/agents to register
        :param assume_onboarding: if True, submit successful onboarding data from each
          worker's onboarding agent after registering it
        :raises ValueError: if an agent cannot be registered within the allowed
          attempts, or if the final number of agents differs from `num_agents`
        """

        max_num_tries = 6
        initial_wait_time = 0.5  # In seconds

        for idx in range(num_agents):

            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            num_tries = 0
            wait_time = initial_wait_time
            while num_tries < max_num_tries:
                try:

                    # Register the worker
                    self.server.register_mock_worker(mock_worker_name)
                    workers = self.db.find_workers(
                        worker_name=mock_worker_name)
                    worker_id = workers[0].db_id

                    # Register the agent
                    mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
                    self.server.register_mock_agent(worker_id,
                                                    mock_agent_details)

                    if assume_onboarding:
                        # Submit onboarding from *this* worker's onboarding agent.
                        # Taking the first of all onboarding agents could pick another
                        # worker's onboarding agent when num_agents > 1.
                        onboard_agents = self.db.find_onboarding_agents(
                            worker_id=worker_id)
                        onboard_data = {"onboarding_data": {"success": True}}
                        self.server.register_mock_agent_after_onboarding(
                            worker_id, onboard_agents[0].get_agent_id(),
                            onboard_data)

                    # Make sure the agent can be found, or else raise an IndexError
                    _ = self.db.find_agents()[idx]

                    break
                except IndexError:
                    num_tries += 1
                    # Only back off when another attempt will follow; sleeping after
                    # the final failure would merely delay the error below.
                    if num_tries < max_num_tries:
                        print(
                            f'The agent could not be registered after {num_tries:d} '
                            f'attempt(s), out of {max_num_tries:d} attempts total. Waiting '
                            f'for {wait_time:0.1f} seconds...')
                        time.sleep(wait_time)
                        wait_time *= 2  # Wait for longer next time
            else:
                raise ValueError('The worker could not be registered!')

        # Get all agents' IDs
        agents = self.db.find_agents()
        if len(agents) != num_agents:
            raise ValueError(
                f'The actual number of agents is {len(agents):d} instead of the '
                f'desired {num_agents:d}!')
        agent_ids = [agent.db_id for agent in agents]

        return agent_ids
Пример #30
0
from mephisto.abstractions.databases.local_database import LocalMephistoDB

import os
import atexit
import signal

app = Flask(
    __name__,
    static_url_path="/static",
    static_folder="webapp/build/static",
)
app.config.from_object(Config)

# REST API lives under a versioned prefix.
app.register_blueprint(api, url_prefix="/api/v1")

# Register extensions: one Mephisto database plus an Operator bound to it, exposed
# through ``app.extensions`` (presumably for the API handlers -- confirm in `api`).
db = LocalMephistoDB()
operator = Operator(db)
if not hasattr(app, "extensions"):
    app.extensions = {}
app.extensions.update(db=db, operator=operator)


@app.route("/", defaults={"path": "index.html"})
@app.route("/<path:path>")
def index(path):
    """
    Serve the built web app's ``index.html`` for ``/`` and the catch-all route.

    ``path`` is captured by the routes but deliberately unused: every request routed
    here receives the same entry page.
    """
    entry_page = os.path.join("webapp", "build", "index.html")
    return send_file(entry_page)


@app.after_request
def after_request(response):
    response.headers.add("Access-Control-Allow-Origin", "*")