コード例 #1
0
 def setUp(self) -> None:
     """
     Create a fresh LocalMephistoDB stored in a new temporary directory.
     """
     tmp_root = tempfile.mkdtemp()
     self.data_dir = tmp_root
     self.db = LocalMephistoDB(os.path.join(tmp_root, "mephisto.db"))
コード例 #2
0
ファイル: test_supervisor.py プロジェクト: wade3han/Mephisto
    def setUp(self):
        """
        Build a mock task pipeline for each supervisor test.

        Creates these pieces in order:
        - a temporary LocalMephistoDB with a mock task and a test task run;
        - a MockArchitect, prepared and deployed with ``should_run_server=True``;
        - a MockProvider set up against the architect's first socket URL;
        - a TaskLauncher whose assignments are created and whose units are
          launched.

        ``self.sup`` starts as None. Presumably each test assigns the
        supervisor under test to it; confirm in the test methods.
        """
        # Fresh on-disk database inside its own temporary directory.
        self.data_dir = tempfile.mkdtemp()
        database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(database_path)
        self.task_id = self.db.new_task("test_mock",
                                        MockBlueprint.BLUEPRINT_TYPE)
        self.task_run_id = get_test_task_run(self.db)
        self.task_run = TaskRun(self.db, self.task_run_id)

        # Mock architect config that requests a (mock) server.
        architect_config = OmegaConf.structured(
            MephistoConfig(architect=MockArchitectArgs(
                should_run_server=True)))

        # prepare() and deploy() run before the socket URLs are read.
        # Presumably the URLs only exist after deployment; confirm.
        self.architect = MockArchitect(self.db, architect_config, EMPTY_STATE,
                                       self.task_run, self.data_dir)
        self.architect.prepare()
        self.architect.deploy()
        self.urls = self.architect._get_socket_urls()  # FIXME: private architect API
        self.url = self.urls[0]
        # Mock provider wired to the first socket URL.
        self.provider = MockProvider(self.db)
        self.provider.setup_resources_for_task_run(self.task_run,
                                                   self.task_run.args,
                                                   EMPTY_STATE, self.url)
        # Create assignments from the test's mock data, then launch their units.
        self.launcher = TaskLauncher(self.db, self.task_run,
                                     self.get_mock_assignment_data_array())
        self.launcher.create_assignments()
        self.launcher.launch_units(self.url)
        self.sup = None
コード例 #3
0
def retrieve_units(run_id: int) -> list:
    """
    Return the units of task run ``run_id`` whose ``db_status`` is "completed".

    Units are returned in the order ``find_units`` yields them, using a
    LocalMephistoDB opened with no explicit path.
    """
    return [
        unit
        for unit in LocalMephistoDB().find_units(task_run_id=run_id)
        if unit.db_status == "completed"
    ]
コード例 #4
0
ファイル: cli.py プロジェクト: vaibhavad/Mephisto
def list_requesters():
    """Lists all registered requesters"""
    # The docstring above is presumably the CLI help text, so it stays as-is.
    from mephisto.abstractions.databases.local_database import LocalMephistoDB
    from tabulate import tabulate

    rows = [requester.to_dict() for requester in LocalMephistoDB().find_requesters()]
    table = tabulate(rows, headers="keys")
    click.echo(table)
コード例 #5
0
 def setUp(self) -> None:
     """
     Setup should put together any requirements for starting the database for a test.

     Prints a reminder about provider test accounts the first time it runs for
     a test class, then creates a LocalMephistoDB in a new temporary directory.
     """
     if not self.warned_about_setup:
         print(
             "Provider tests require using a test account for that crowd provider, "
             "you may need to set this up on your own.")
         # Record the warning on the class. unittest builds a new instance for
         # every test method, so an instance attribute would be lost and the
         # reminder repeated before each test.
         type(self).warned_about_setup = True
     self.data_dir = tempfile.mkdtemp()
     database_path = os.path.join(self.data_dir, "mephisto.db")
     self.db = LocalMephistoDB(database_path)
コード例 #6
0
 def setUp(self) -> None:
     """
     Create the per-test fixtures:
     - temporary data and build directories;
     - a LocalMephistoDB inside the data directory;
     - a test TaskRun;
     - the task runner, agent state and task builder classes of
       ``self.BlueprintClass``.
     """
     self.data_dir = tempfile.mkdtemp()
     self.build_dir = tempfile.mkdtemp()
     self.db = LocalMephistoDB(os.path.join(self.data_dir, "mephisto.db"))
     # TODO(#97) we need to actually pull the task type from the Blueprint
     run_id = get_test_task_run(self.db)
     self.task_run = TaskRun(self.db, run_id)
     # TODO(#97) create a mock agent with the given task type?
     blueprint_cls = self.BlueprintClass
     self.TaskRunnerClass = blueprint_cls.TaskRunnerClass
     self.AgentStateClass = blueprint_cls.AgentStateClass
     self.TaskBuilderClass = blueprint_cls.TaskBuilderClass
コード例 #7
0
def plot_OS_browser(run_id: int) -> None:
    """
    Plot browser, browser-version, OS and mobile histograms for a task run.

    Only completed units of ``run_id`` are used. Each unit's
    ``outputs["userAgent"]`` must be a JSON string with a "browser" object
    (holding "name", "v" and "os") and a "mobile" flag.
    """
    units = retrieve_units(run_id)
    browser_data = DataBrowser(db=LocalMephistoDB())
    browser_counts, version_counts, os_counts = {}, {}, {}
    mobile = {"yes": 0, "no": 0}
    for unit in units:
        raw_agent = browser_data.get_data_from_unit(unit)["data"]["outputs"]["userAgent"]
        agent_info = json.loads(raw_agent)
        browser_info = agent_info["browser"]
        name = browser_info["name"]
        browser_counts = increment_dict(browser_counts, name)
        version_counts = increment_dict(version_counts, name + str(browser_info["v"]))
        os_counts = increment_dict(os_counts, browser_info["os"])
        mobile["yes" if agent_info["mobile"] else "no"] += 1

    for counts, label in (
        (browser_counts, "Browsers"),
        (version_counts, "Browser Versions"),
        (os_counts, "OS's"),
        (mobile, "On Mobile"),
    ):
        plot_hist(counts, xlabel=label, ylabel=None)
コード例 #8
0
ファイル: test_database.py プロジェクト: wade3han/Mephisto
class TestLocalMephistoDB(BaseDatabaseTests):
    """
    Run the shared BaseDatabaseTests suite against LocalMephistoDB.

    No tests are defined here; every test is inherited. Each test gets its own
    temporary directory holding a fresh database file, removed afterwards.
    """

    def setUp(self):
        tmp_root = tempfile.mkdtemp()
        self.data_dir = tmp_root
        self.db = LocalMephistoDB(os.path.join(tmp_root, "mephisto.db"))

    def tearDown(self):
        self.db.shutdown()
        shutil.rmtree(self.data_dir)
コード例 #9
0
ファイル: run.py プロジェクト: sagar-spkt/ParlAI
def check_role_training_qualification(db: "LocalMephistoDB", qname: str,
                                      requester_name: str):
    """
    Initializes the qualification name in DB, if it does not exist.

    If no qualification named ``qname`` exists, the steps are:
    1. Look up the MTurk requester ``requester_name``; the last match is used.
    2. Create the qualification in the local DB.
    3. Create the qualification on MTurk through that requester.

    :param db: Mephisto database to query and update.
    :param qname: name of the qualification.
    :param requester_name: name of the 'mturk' requester used to create the
        qualification on MTurk.
    :raises ValueError: if the qualification must be created but no matching
        MTurk requester exists. Nothing is written to the DB in that case.
    """
    logging.info(f'Checking for "{qname}" qualification.')
    if db.find_qualifications(qname):
        logging.info('Qualification exists.')
        return

    # Resolve the requester before writing anything. Creating the local
    # qualification first and then failing on a missing requester left the
    # DB half-initialized: later calls saw the qualification as existing and
    # never created it on MTurk.
    reqs = db.find_requesters(requester_name=requester_name,
                              provider_type='mturk')
    if not reqs:
        raise ValueError(
            f'No MTurk requester named "{requester_name}" was found; cannot '
            f'create the "{qname}" qualification.')
    requester = reqs[-1]

    logging.info('Creating the qualification.')
    db.make_qualification(qname)
    requester._create_new_mturk_qualification(qname)
コード例 #10
0
    def _set_up_config(
        self,
        blueprint_type: str,
        task_directory: str,
        overrides: Optional[List[str]] = None,
    ):
        """
        Compose the test config with Hydra and create a throwaway database.

        The ``example`` config under ``<task_directory>/conf`` is composed with
        Hydra's compose() API. The blueprint type is injected, and the
        architect and provider are both forced to ``mock``. A LocalMephistoDB
        is then created in a new temporary directory, the config is augmented
        from it, and the architect is told to run its server.

        :param blueprint_type: string uniquely specifying Blueprint class
        :param task_directory: directory containing the `conf/` configuration folder.
          Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, appended after the
          built-in ones
        """
        # Hydra's initialize() takes config_path relative to the calling file.
        this_dir = os.path.dirname(__file__)
        config_dir = os.path.join(os.path.relpath(task_directory, this_dir), 'conf')
        extra_overrides = [] if overrides is None else overrides
        with initialize(config_path=config_dir):
            builtin_overrides = [
                f'+mephisto.blueprint._blueprint_type={blueprint_type}',
                '+mephisto/architect=mock',
                '+mephisto/provider=mock',
                f'+task_dir={task_directory}',
                f'+current_time={int(time.time())}',
            ]
            # TODO: when Hydra 1.1 is released with support for recursive
            #  defaults, stop manually specifying the missing blueprint args
            #  and define the blueprint in the defaults list directly. Today
            #  it cannot be set there without overriding params in the YAML
            #  file; see https://github.com/facebookresearch/hydra/issues/326
            #  and the fix in https://github.com/facebookresearch/hydra/pull/1044.
            self.config = compose(
                config_name="example",
                overrides=builtin_overrides + extra_overrides,
            )

        self.data_dir = tempfile.mkdtemp()
        self.database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(self.database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True
コード例 #11
0
def timing_charts(run_id: int) -> None:
    """
    Print and plot timing, pass-rate and feedback statistics for a task run.

    Uses the completed units of ``run_id``. Each unit's outputs must contain
    "feedback", "bug" and "q1Answer".."q4Answer". The "bug" flag and the
    answers are compared against the string "true".

    Prints a message and returns early when the run has no completed units. A
    question with no "true"/"false" answers gets a pass rate of 0.0, reported
    as n/a.
    """
    completed_units = retrieve_units(run_id)
    if not completed_units:
        # With no units, starttime/endtime would stay at +/-inf, which
        # datetime.fromtimestamp() rejects, and every pass rate would divide
        # by zero.
        print(f"No completed units found for task run {run_id}.")
        return

    db = LocalMephistoDB()
    data_browser = DataBrowser(db=db)
    workers = {"total": []}
    unit_timing = {"total": [], "end": []}
    question_results = {1: [], 2: [], 3: [], 4: []}
    pass_rates = {1: [], 2: [], 3: [], 4: []}
    starttime = math.inf
    endtime = -math.inf
    feedback = []
    num_correct_hist = []
    bug_count = 0
    for unit in completed_units:
        data = data_browser.get_data_from_unit(unit)
        worker = Worker(db, data["worker_id"]).worker_name
        workers["total"].append(worker)
        # hit_timing (defined elsewhere) returns the updated start/end bounds
        # and timing lists.
        starttime, endtime, unit_timing = hit_timing(data["data"], starttime, endtime, unit_timing)

        outputs = data["data"]["outputs"]
        feedback.append(outputs["feedback"])
        if outputs["bug"] == "true":
            bug_count += 1
        num_correct = 0
        for q in question_results:
            answer = outputs["q" + str(q) + "Answer"]
            question_results[q].append(answer)
            if answer == "true":
                num_correct += 1
        num_correct_hist.append(num_correct)

    print(f"Job start time: {datetime.fromtimestamp(starttime)}")
    print(f"Job end time: {datetime.fromtimestamp(endtime)}")

    plot_hist_sorted(
        unit_timing["total"], cutoff=1200, target_val=600, xlabel="", ylabel="Total HIT Time (sec)"
    )
    calc_percentiles(unit_timing["total"], "HIT Length")

    for q, answers in question_results.items():
        results_dict = Counter(answers)
        answered = results_dict["true"] + results_dict["false"]
        if answered == 0:
            # Nothing to rate for this question; avoid dividing by zero.
            pass_rates[q] = 0.0
            print(f"Question #{q} pass rate: n/a (no true/false answers)")
            continue
        # Compute once and reuse for both the plot and the printout.
        pass_rates[q] = (results_dict["true"] / answered) * 100
        print(f"Question #{q} pass rate: {pass_rates[q]:.1f}%")
    plot_hist(pass_rates, xlabel="Question #", ylabel="Pass Rate %")
    print(
        f"Number of workers who didn't get any right: {num_correct_hist.count(0)}"
    )

    vals_dict = dict(enumerate(num_correct_hist))
    plot_hist(vals_dict, xlabel="HIT #", ylabel="# Correct", ymax=4)

    print(f"Number of workers who experienced a window crash: {bug_count}")
    print(feedback)
コード例 #12
0
ファイル: cli.py プロジェクト: vaibhavad/Mephisto
def register_provider(args):
    """Register a requester with a crowd provider"""
    # The docstring is presumably the CLI help text, so details live here.
    # args[0] is the provider type and the rest are key=value requester
    # options. With no options, the provider's accepted options are printed
    # instead of registering anything.
    if len(args) == 0:
        click.echo(
            "Usage: mephisto register <provider_type> arg1=value arg2=value")
        return

    from mephisto.abstractions.databases.local_database import LocalMephistoDB
    from mephisto.operations.registry import get_crowd_provider_from_type
    from mephisto.operations.utils import parse_arg_dict, get_extra_argument_dicts

    provider_type, requester_args = args[0], args[1:]
    malformed = [arg for arg in requester_args if "=" not in arg]
    if malformed:
        # dict() over a split with no "=" would crash with an opaque ValueError.
        click.echo(
            f"Arguments must be given as key=value, got: {', '.join(malformed)}")
        return
    args_dict = dict(arg.split("=", 1) for arg in requester_args)

    crowd_provider = get_crowd_provider_from_type(provider_type)
    RequesterClass = crowd_provider.RequesterClass

    if len(requester_args) == 0:
        from tabulate import tabulate

        params = get_extra_argument_dicts(RequesterClass)
        for param in params:
            click.echo(param["desc"])
            click.echo(tabulate(param["args"].values(), headers="keys"))
        return

    try:
        parsed_options = parse_arg_dict(RequesterClass, args_dict)
    except Exception as e:
        click.echo(str(e))
        # Stop here: parsed_options is unbound after a failed parse.
        return

    if parsed_options.name is None:
        click.echo("No name was specified for the requester.")
        # Stop here rather than looking requesters up with name=None, which
        # could match and register an unrelated requester.
        return

    db = LocalMephistoDB()
    requesters = db.find_requesters(requester_name=parsed_options.name)
    if len(requesters) == 0:
        requester = RequesterClass.new(db, parsed_options.name)
    else:
        requester = requesters[0]
    try:
        requester.register(parsed_options)
        click.echo("Registered successfully.")
    except Exception as e:
        click.echo(str(e))
コード例 #13
0
class TestMTurkComponents(unittest.TestCase):
    """
    Tests for the MTurk crowd provider's data-model components, each run
    against a throwaway LocalMephistoDB.
    """
    def setUp(self) -> None:
        """
        Create a LocalMephistoDB in a new temporary directory.
        """
        tmp_root = tempfile.mkdtemp()
        self.data_dir = tmp_root
        self.db = LocalMephistoDB(os.path.join(tmp_root, "mephisto.db"))

    def tearDown(self) -> None:
        """
        Shut the database down and remove its temporary directory.
        """
        self.db.shutdown()
        shutil.rmtree(self.data_dir)

    @pytest.mark.req_creds
    def test_create_and_find_worker(self) -> None:
        """Ensure we can find a worker by MTurk id"""
        mturk_id = "ABCDEFGHIJ"
        mismatch_msg = "Worker gotten from db not same as first init"

        created = MTurkWorker.new(self.db, mturk_id)
        by_db_id = Worker(self.db, created.db_id)
        self.assertEqual(created.worker_name, by_db_id.worker_name, mismatch_msg)

        by_mturk_id = MTurkWorker.get_from_mturk_worker_id(self.db, mturk_id)
        self.assertEqual(created.worker_name, by_mturk_id.worker_name, mismatch_msg)

        missing = MTurkWorker.get_from_mturk_worker_id(self.db, "FAKE_ID")
        self.assertIsNone(missing, f"Found worker {missing} from a fake id")
コード例 #14
0
def main():
    """
    Interactively soft-block a list of MTurk workers.

    Steps:
    1. List the MTurk requesters in the local DB.
    2. Ask which requester to use and which soft-block qualification name
       to apply.
    3. Read worker ids one per line until a blank line.
    4. Pass everything to direct_soft_block_mturk_workers.

    Surrounding whitespace is stripped from every answer.
    """
    db = LocalMephistoDB()
    reqs = db.find_requesters(provider_type="mturk")
    names = [r.requester_name for r in reqs]
    print("Available Requesters: ", names)

    # Strip input so stray spaces from typing/pasting don't produce names or
    # ids that fail to match exactly.
    requester_name = input("Select a requester to soft block from: ").strip()
    soft_block_qual_name = input("Provide a soft blocking qualification name: ").strip()

    workers_to_block = []
    while True:
        new_id = input("MTurk Worker Id to soft block (blank to block all entered): ").strip()
        if not new_id:
            break
        workers_to_block.append(new_id)

    direct_soft_block_mturk_workers(
        db, workers_to_block, soft_block_qual_name, requester_name
    )
コード例 #15
0
    def mephistoDBReader():
        """
        Yield the data dict of every unit of the task named
        ``database_task_name``, a name taken from the enclosing scope.
        """
        from mephisto.abstractions.databases.local_database import LocalMephistoDB
        from mephisto.tools.data_browser import DataBrowser as MephistoDataBrowser

        browser = MephistoDataBrowser(db=LocalMephistoDB())
        for task_unit in browser.get_units_for_task_name(database_task_name):
            yield browser.get_data_from_unit(task_unit)
コード例 #16
0
 def setUp(self) -> None:
     """
     Setup should put together any requirements for starting the database for a test.

     Skips the test when no ``ArchitectClass`` is set. Otherwise it:
     - prints a one-per-class reminder about server-provider accounts;
     - creates a temporary LocalMephistoDB and a test TaskRun;
     - builds a mock task into a fresh build directory.
     """
     try:
         _ = self.ArchitectClass
     except AttributeError:
         # Only a missing ArchitectClass means "skip". The former bare
         # ``except:`` also swallowed KeyboardInterrupt and genuine errors.
         raise unittest.SkipTest("Skipping test as no ArchitectClass set")
     if not self.warned_about_setup:
         print(
             "Architect tests may require using an account with the server provider "
             "in order to function properly. Make sure these are configured before testing."
         )
         # Record the warning on the class. unittest builds a new instance for
         # every test method, so an instance attribute would be lost and the
         # reminder repeated before each test.
         type(self).warned_about_setup = True
     self.data_dir = tempfile.mkdtemp()
     database_path = os.path.join(self.data_dir, "mephisto.db")
     self.db = LocalMephistoDB(database_path)
     self.build_dir = tempfile.mkdtemp()
     self.task_run = TaskRun(self.db, get_test_task_run(self.db))
     builder = MockTaskBuilder(self.task_run, {})
     builder.build_in_dir(self.build_dir)
コード例 #17
0
ファイル: scripts.py プロジェクト: wade3han/Mephisto
def get_db_from_config(cfg: DictConfig) -> "MephistoDB":
    """
    Build the MephistoDB described by ``cfg``.

    Currently this is always a LocalMephistoDB stored at
    ``<datapath>/database.db``. ``datapath`` is ``cfg.mephisto.datapath``
    when present and not None, otherwise ``get_root_data_dir()``.
    """
    configured = cfg.mephisto.get("datapath", None)
    base_dir = get_root_data_dir() if configured is None else configured
    return LocalMephistoDB(database_path=os.path.join(base_dir, "database.db"))
コード例 #18
0
ファイル: tests.py プロジェクト: sagar-spkt/ParlAI
    def _set_up_config(
        self,
        task_directory: str,
        overrides: Optional[List[str]] = None,
        config_name: str = "example",
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory to store
        the test database.

        :param task_directory: directory containing the `hydra_configs/conf/`
          configuration folder. Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, applied after the mock
          architect/provider overrides
        :param config_name: name of the config in that folder to compose
        """

        # Hydra's initialize() takes config_path relative to the calling file.
        relative_task_directory = os.path.relpath(task_directory,
                                                  os.path.dirname(__file__))
        relative_config_path = os.path.join(relative_task_directory,
                                            'hydra_configs', 'conf')
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name=config_name,
                overrides=[
                    'mephisto/architect=mock',
                    'mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ] + overrides,
            )

        # Fresh database in its own temporary directory for this test.
        self.data_dir = tempfile.mkdtemp()
        self.database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(self.database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True
コード例 #19
0
ファイル: cli.py プロジェクト: vaibhavad/Mephisto
def check():
    """Checks that mephisto is setup correctly"""
    from mephisto.abstractions.databases.local_database import LocalMephistoDB
    from mephisto.operations.utils import get_mock_requester

    # Opening the local DB and fetching the mock requester exercises the basic
    # setup. Any failure is reported to the user rather than raised.
    try:
        get_mock_requester(LocalMephistoDB())
    except Exception as err:
        click.echo("Something went wrong.")
        click.echo(err)
    else:
        click.echo("Mephisto seems to be set up correctly.")
コード例 #20
0
ファイル: analysis.py プロジェクト: sagar-spkt/ParlAI
    def __init__(self, opt: Dict, remove_failed: bool = True):
        """
        Initialize the analyzer.

        Builds up the dataframe

        :param opt:
            opt dict. Reads 'run_ids' (exactly one run ID), 'pairings_filepath',
            'outdir', 'root_dir' and 'mephisto_root'. Optionally reads 'task'
            (default 'q') and 'model_ordering' (comma-separated model names).

        :param remove_failed:
            Whether to remove ratings from turkers who failed onboarding

        :raises ValueError:
            if 'run_ids' contains more than one run ID, if ``root_dir`` is
            needed but is not an existing directory, or if no valid results
            remain after loading.
        """
        # Validate with explicit raises rather than ``assert``, which is
        # stripped under ``python -O``. This matches the ValueError raised
        # below for empty results.
        if ',' in opt['run_ids']:
            raise ValueError("AcuteAnalyzer can only handle one run ID!")
        self.run_id = opt['run_ids']
        self.pairings_filepath = opt['pairings_filepath']
        self.outdir = opt['outdir']
        self.root_dir = opt['root_dir']
        # Get task for loading pairing files
        self.task = opt.get('task', 'q')
        if opt.get('model_ordering') is not None:
            self.custom_model_ordering = opt['model_ordering'].split(',')
        else:
            self.custom_model_ordering = None
        if not self.outdir or not self.pairings_filepath:
            # Default to using self.root_dir as the root directory for outputs
            if self.root_dir is None or not os.path.isdir(self.root_dir):
                raise ValueError('--root-dir must be a real directory!')
        if not self.pairings_filepath:
            # Will be set to a non-empty path later
            self.pairings_filepath = ''
        if not self.outdir:
            self.outdir = os.path.join(self.root_dir, f'{self.run_id}-results')
        if not os.path.exists(self.outdir):
            os.makedirs(self.outdir, exist_ok=True)
        mephisto_root_path = opt['mephisto_root']
        if not mephisto_root_path:
            # Normalize empty values to None, presumably so LocalMephistoDB
            # falls back to its default location.
            mephisto_root_path = None
        self.mephisto_db = LocalMephistoDB(database_path=mephisto_root_path)
        self.mephisto_data_browser = MephistoDataBrowser(db=self.mephisto_db)
        # Prepended to checkbox columns in self.dataframe
        self.checkbox_prefix = self.CHECKBOX_PREFIX
        self.dataframe = self._extract_to_dataframe()
        self._check_eval_question()
        if remove_failed:
            self._remove_failed_onboarding()
        if self.dataframe.index.size == 0:
            raise ValueError('No valid results found!')
        self._get_model_nick_names()
        self._load_pairing_files()
コード例 #21
0
def main():
    """
    Print MTurk HIT details for the unassigned HITs of a task run.

    Steps:
    1. Prompt for a task_run_id.
    2. Exit with status 0 if the run's requester is not an MTurkRequester.
    3. Otherwise list the run's unassigned HIT ids from the "mturk" datastore
       and print the MTurk data for each of them.
    """
    task_run_id = input("Please enter the task_run_id you'd like to check: ")
    db = LocalMephistoDB()
    task_run = TaskRun.get(db, task_run_id)
    requester = task_run.get_requester()
    if not isinstance(requester, MTurkRequester):
        print(
            "Must be checking a task launched on MTurk, this one uses the following requester:"
        )
        print(requester)
        exit(0)

    turk_db = db.get_datastore_for_provider("mturk")
    hits = turk_db.get_unassigned_hit_ids(task_run_id)

    print(f"Found the following HIT ids unassigned: {hits}")

    # print all of the HITs found above
    from mephisto.abstractions.providers.mturk.mturk_utils import get_hit

    for hit_id in hits:
        # Look up each HIT by its own id; this previously always passed hits[0].
        hit_info = get_hit(requester._get_client(requester._requester_name),
                           hit_id)
        print(f"MTurk HIT data for {hit_id}:\n{hit_info}\n")
コード例 #22
0
    def mephistoDBReader():
        """
        Yield one review string per unit of the task named
        ``database_task_name``, a name taken from the enclosing scope.
        """
        from mephisto.abstractions.databases.local_database import LocalMephistoDB
        from mephisto.tools.data_browser import DataBrowser as MephistoDataBrowser

        db = LocalMephistoDB()
        mephisto_data_browser = MephistoDataBrowser(db=db)

        def format_data_for_review(data):
            # Render the whole unit data dict as its string form.
            return f"{data}"

        units = mephisto_data_browser.get_units_for_task_name(
            database_task_name)
        for unit in units:
            yield format_data_for_review(
                mephisto_data_browser.get_data_from_unit(unit))
コード例 #23
0
ファイル: server.py プロジェクト: facebookresearch/Mephisto
def main():
    """
    Build the Flask app for the web client and register its hooks.

    - Mounts the ``api`` blueprint at /api/v1.
    - Serves webapp/build/index.html for "/" and any other path.
    - Stores a LocalMephistoDB and an Operator in ``app.extensions["db"]``
      and ``app.extensions["operator"]``.
    - Adds permissive CORS headers and ``Cache-Control: no-store`` to every
      response.
    - Shuts the operator and DB down at interpreter exit or on SIGINT.
    """
    app = Flask(
        __name__, static_url_path="/static", static_folder="webapp/build/static"
    )
    app.config.from_object(Config)

    app.register_blueprint(api, url_prefix="/api/v1")

    # Register extensions
    db = LocalMephistoDB()
    operator = Operator(db)
    if not hasattr(app, "extensions"):
        app.extensions = {}
    app.extensions["db"] = db
    app.extensions["operator"] = operator

    @app.route("/", defaults={"path": "index.html"})
    @app.route("/<path:path>")
    def index(path):
        return send_file(os.path.join("webapp", "build", "index.html"))

    @app.after_request
    def after_request(response):
        response.headers.add("Access-Control-Allow-Origin", "*")
        response.headers.add(
            "Access-Control-Allow-Headers", "Content-Type,Authorization"
        )
        response.headers.add(
            "Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS"
        )
        response.headers.add("Cache-Control", "no-store")
        return response

    previous_sigint_handler = signal.getsignal(signal.SIGINT)
    resources_released = False

    def cleanup_resources(*args, **kwargs):
        nonlocal resources_released
        # Both atexit and SIGINT route here. The chained SIGINT handler
        # typically raises KeyboardInterrupt, after which atexit calls this
        # again, so release resources only once.
        if not resources_released:
            resources_released = True
            operator.shutdown()
            db.shutdown()
        # atexit calls this with no arguments, while a signal handler gets
        # (signum, frame). Only chain to the previous handler in the latter
        # case, and only when it is callable (SIG_DFL, SIG_IGN and None are
        # not).
        if args and callable(previous_sigint_handler):
            previous_sigint_handler(*args, **kwargs)

    atexit.register(cleanup_resources)
    signal.signal(signal.SIGINT, cleanup_resources)
コード例 #24
0
def get_db_from_config(cfg: DictConfig) -> "MephistoDB":
    """
    Build the MephistoDB selected by ``cfg.mephisto.database._database_type``.

    The database file is ``<datapath>/database.db``. ``datapath`` is
    ``cfg.mephisto.datapath`` when present and not None, otherwise
    ``get_root_data_dir()``. The type "local" gives a LocalMephistoDB and
    "singleton" gives a MephistoSingletonDB. Any other value raises
    AssertionError.
    """
    configured = cfg.mephisto.get("datapath", None)
    base_dir = get_root_data_dir() if configured is None else configured
    database_path = os.path.join(base_dir, "database.db")

    kind = cfg.mephisto.database._database_type
    if kind == "local":
        db_class = LocalMephistoDB
    elif kind == "singleton":
        db_class = MephistoSingletonDB
    else:
        raise AssertionError(f"Provided database_type {kind} is not valid")
    return db_class(database_path=database_path)
コード例 #25
0
ファイル: cli.py プロジェクト: fusesagar/Mephisto
def review(
    review_app_dir,
    port,
    output,
    output_method,
    csv_headers,
    json,
    database_task_name,
    all_data,
    debug,
):
    """Launch a local review UI server. Reads in rows from stdin and outputs to either a file or stdout."""
    # The docstring is presumably the CLI help text (typo "froms" fixed), so
    # further notes live in comments.
    from mephisto.client.review.review_server import run

    if output == "" and output_method == "file":
        raise click.UsageError(
            "You must specify an output file via --output=<filename>, unless the --stdout flag is set."
        )
    if database_task_name is not None:
        # Validate the task name against the local DB before starting the
        # server, suggesting the known names on a miss.
        from mephisto.abstractions.databases.local_database import LocalMephistoDB
        from mephisto.tools.data_browser import DataBrowser as MephistoDataBrowser

        db = LocalMephistoDB()
        mephisto_data_browser = MephistoDataBrowser(db=db)
        name_list = mephisto_data_browser.get_task_name_list()
        if database_task_name not in name_list:
            raise click.BadParameter(
                f'The task name "{database_task_name}" did not exist in MephistoDB.\n\nPerhaps you meant one of these? {", ".join(name_list)}\n\nFlag usage: mephisto review --db [task_name]\n'
            )

    run(
        review_app_dir,
        port,
        output,
        csv_headers,
        json,
        database_task_name,
        all_data,
        debug,
    )
コード例 #26
0
 def setUp(self):
     """Give each test a brand-new database file in a fresh temporary directory."""
     tmp_root = tempfile.mkdtemp()
     self.data_dir = tmp_root
     self.db = LocalMephistoDB(os.path.join(tmp_root, "mephisto.db"))
コード例 #27
0
ファイル: tests.py プロジェクト: sagar-spkt/ParlAI
class AbstractCrowdsourcingTest:
    """
    Abstract class for end-to-end tests of Mephisto-based crowdsourcing tasks.

    Allows for setup and teardown of the operator, as well as for config specification
    and agent registration.
    """
    def _setup(self):
        """
        To be run before a test.

        Seeds ``random``, NumPy and PyTorch with 0 for reproducibility and
        resets the operator and mock-server handles.

        Should be called in a pytest setup/teardown fixture.
        """

        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)

        self.operator = None
        self.server = None

    def _teardown(self):
        """
        To be run after a test.

        Force-shuts-down the operator and shuts down the mock server, if either
        was created.

        Should be called in a pytest setup/teardown fixture.
        """

        if self.operator is not None:
            self.operator.force_shutdown()

        if self.server is not None:
            self.server.shutdown_mock()

    def _set_up_config(
        self,
        task_directory: str,
        overrides: Optional[List[str]] = None,
        config_name: str = "example",
    ):
        """
        Set up the config and database.

        Uses the Hydra compose() API for unit testing and a temporary directory to store
        the test database.

        :param task_directory: directory containing the `hydra_configs/conf/`
          configuration folder. Will be injected as `${task_dir}` in YAML files.
        :param overrides: additional config overrides, applied after the mock
          architect/provider overrides
        :param config_name: name of the config in that folder to compose
        """

        # Hydra's initialize() takes config_path relative to the calling file.
        relative_task_directory = os.path.relpath(task_directory,
                                                  os.path.dirname(__file__))
        relative_config_path = os.path.join(relative_task_directory,
                                            'hydra_configs', 'conf')
        if overrides is None:
            overrides = []
        with initialize(config_path=relative_config_path):
            self.config = compose(
                config_name=config_name,
                overrides=[
                    'mephisto/architect=mock',
                    'mephisto/provider=mock',
                    f'+task_dir={task_directory}',
                    f'+current_time={int(time.time())}',
                ] + overrides,
            )

        self.data_dir = tempfile.mkdtemp()
        self.database_path = os.path.join(self.data_dir, "mephisto.db")
        self.db = LocalMephistoDB(self.database_path)
        self.config = augment_config_from_db(self.config, self.db)
        self.config.mephisto.architect.should_run_server = True

    def _set_up_server(self, shared_state: Optional[SharedTaskState] = None):
        """
        Set up the operator and server.

        Runs ``self.config.mephisto`` through a new Operator and keeps a handle
        to the architect's server for the first channel.
        """
        self.operator = Operator(self.db)
        self.operator.validate_and_run_config(self.config.mephisto,
                                              shared_state=shared_state)
        self.server = self._get_channel_info().job.architect.server

    def _get_channel_info(self):
        """
        Return channel info for the currently running job.

        :raises ValueError: if the operator's supervisor has no channels.
        """
        channels = list(self.operator.supervisor.channels.values())
        if not channels:
            raise ValueError('No channel could be detected!')
        return channels[0]

    def _register_mock_agents(self,
                              num_agents: int = 1,
                              assume_onboarding: bool = False) -> List[str]:
        """
        Register mock agents for testing and onboard them if needed, taking the place of
        crowdsourcing workers.

        Specify the number of agents to register. Return the agents' IDs after creation.

        Each registration is retried while the agent cannot be found yet
        (IndexError). There are up to 6 attempts, waiting 0.5s after the first
        failure and doubling the wait each time.

        :raises ValueError: if a worker cannot be registered within the retry
            budget, or if the final number of agents differs from
            ``num_agents``.
        """

        for idx in range(num_agents):

            mock_worker_name = f"MOCK_WORKER_{idx:d}"
            max_num_tries = 6
            initial_wait_time = 0.5  # In seconds
            num_tries = 0
            wait_time = initial_wait_time
            while num_tries < max_num_tries:
                try:

                    # Register the worker
                    self.server.register_mock_worker(mock_worker_name)
                    workers = self.db.find_workers(
                        worker_name=mock_worker_name)
                    worker_id = workers[0].db_id

                    # Register the agent
                    mock_agent_details = f"FAKE_ASSIGNMENT_{idx:d}"
                    self.server.register_mock_agent(worker_id,
                                                    mock_agent_details)

                    if assume_onboarding:
                        # Submit onboarding from the agent.
                        # NOTE(review): always uses the first onboarding agent
                        # in the DB, even for idx > 0; confirm this is
                        # intended when registering several agents.
                        onboard_agents = self.db.find_onboarding_agents()
                        onboard_data = {"onboarding_data": {"success": True}}
                        self.server.register_mock_agent_after_onboarding(
                            worker_id, onboard_agents[0].get_agent_id(),
                            onboard_data)
                    # Make sure the agent can be found, or else raise an IndexError
                    _ = self.db.find_agents()[idx]

                    break
                except IndexError:
                    num_tries += 1
                    print(
                        f'The agent could not be registered after {num_tries:d} '
                        f'attempt(s), out of {max_num_tries:d} attempts total. Waiting '
                        f'for {wait_time:0.1f} seconds...')
                    time.sleep(wait_time)
                    wait_time *= 2  # Wait for longer next time
            else:
                # The while loop exhausted its attempts without a break.
                raise ValueError('The worker could not be registered!')

        # Get all agents' IDs
        agents = self.db.find_agents()
        if len(agents) != num_agents:
            raise ValueError(
                f'The actual number of agents is {len(agents):d} instead of the '
                f'desired {num_agents:d}!')
        agent_ids = [agent.db_id for agent in agents]

        return agent_ids
コード例 #28
0
    StaticReactBlueprint,
    BLUEPRINT_TYPE,
)
from mephisto.abstractions.blueprint import AgentState
from mephisto.abstractions.providers.mock.mock_requester import MockRequester
from mephisto.abstractions.providers.mock.mock_worker import MockWorker
from mephisto.abstractions.providers.mock.mock_agent import MockAgent
from mephisto.abstractions.databases.local_database import LocalMephistoDB
from mephisto.data_model.assignment import Assignment, InitializationData
from mephisto.data_model.unit import Unit
from mephisto.data_model.agent import Agent
from mephisto.tools.data_browser import DataBrowser as MephistoDataBrowser

import json

# Open a LocalMephistoDB with no explicit path.
db = LocalMephistoDB()

# Get the requester that the run will be requested from
all_requesters = db.find_requesters(provider_type="mock")

# Show the names of the existing mock requesters, sorted for readability.
print("You have the following requesters available for use on mock:")
r_names = [r.requester_name for r in all_requesters]
print(sorted(r_names))

# A name not in r_names is handled by the confirmation loop that follows.
use_name = input("Enter the name of the requester to use, or a new requester:\n>> ")
while use_name not in r_names:
    confirm = input(
        f"{use_name} is not in the requester list. "
        f"Would you like to create a new MockRequester with this name? (y)/n > "
    )
    if confirm.lower().startswith("n"):
コード例 #29
0
 def get_mephisto_db(self) -> LocalMephistoDB:
     """
     Return the cached LocalMephistoDB, creating one lazily whenever the
     cached value is falsy (e.g. not created yet).
     """
     if self._mephisto_db:
         return self._mephisto_db
     self._mephisto_db = LocalMephistoDB()
     return self._mephisto_db
コード例 #30
0
def main():
    """
    Open a LocalMephistoDB, publish it as the module-level ``db``, and run
    run_examine_or_review with format_for_printing_data.
    """
    # Assign the module-global ``db``. Presumably format_for_printing_data
    # (defined elsewhere) reads it; verify before removing the global.
    global db
    db = LocalMephistoDB()
    run_examine_or_review(db, format_for_printing_data)