def test_base_task(self):
    """Run the base task with a single agent and check its recorded state.

    Builds the task from ``TASK_DIRECTORY`` and sets up the config with the
    blueprint type and Hydra overrides below. It then starts a server whose
    shared state carries the teacher and turn timeout. Finally the agent
    states are checked against ``expected_states/state.json`` next to this
    file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    state_file = os.path.join(here, 'expected_states', 'state.json')

    # Setup: build the task first.
    build_task(task_directory=TASK_DIRECTORY)

    # Config and database.
    # TODO: remove all of these params once Hydra 1.1 is released with
    # support for recursive defaults
    self._set_up_config(
        blueprint_type=BLUEPRINT_TYPE,
        task_directory=TASK_DIRECTORY,
        overrides=[
            'mephisto.blueprint.num_conversations=1',
            'mephisto.task.allowed_concurrent=0',
            '+turn_timeout=300',
        ],
    )

    # Operator and server. The teacher is resolved before turn_timeout is
    # read, matching the order the config is consulted in.
    task_teacher = get_teacher(self.config)
    shared_opt = {
        "turn_timeout": self.config.turn_timeout,
        "teacher": task_teacher,
    }
    # The same options dict serves both the main and the onboarding world.
    self._set_up_server(
        shared_state=SharedParlAITaskState(
            world_opt=shared_opt, onboarding_world_opt=shared_opt
        )
    )

    # Verify the agent states against the stored expectation.
    with open(state_file) as handle:
        saved_state = json.load(handle)
    self._test_agent_states(
        num_agents=1,
        agent_display_ids=AGENT_DISPLAY_IDS,
        agent_messages=AGENT_MESSAGES,
        form_messages=FORM_MESSAGES,
        form_task_data=FORM_TASK_DATA,
        expected_states=(saved_state,),
    )
def main(cfg: DictConfig) -> None:
    """Launch this task through a Mephisto ``Operator``.

    Loads the database and processed config, then builds the world options
    (turn timeout and teacher) shared by the main and onboarding worlds. If
    ``mephisto.blueprint.custom_source_bundle`` is set but that path does not
    exist, the task is built from ``TASK_DIRECTORY``. The run config is then
    validated and run, and the call blocks in
    ``wait_for_runs_then_shutdown`` without prompting for input.

    Args:
        cfg: The Hydra config for this run.
    """
    db, cfg = load_db_and_process_config(cfg)

    teacher = get_teacher(cfg)
    shared_world_opt = {"turn_timeout": cfg.turn_timeout, "teacher": teacher}

    # Build only when a custom bundle is configured but missing on disk.
    # NOTE(review): presumably build_task produces that bundle; confirm.
    bundle_path = cfg.mephisto.blueprint.get("custom_source_bundle", None)
    if bundle_path is not None and not os.path.exists(bundle_path):
        build_task(TASK_DIRECTORY)

    shared_state = SharedParlAITaskState(
        world_opt=shared_world_opt, onboarding_world_opt=shared_world_opt
    )
    operator = Operator(db)
    operator.validate_and_run_config(
        run_config=cfg.mephisto, shared_state=shared_state
    )
    operator.wait_for_runs_then_shutdown(
        skip_input=True, log_rate=cfg.monitoring_log_rate
    )
def test_base_task(self):
    """Run the base task with a single agent and check its recorded state.

    Builds the task from ``TASK_DIRECTORY`` and sets up the config with a
    ``+turn_timeout=300`` override. It then starts a server whose shared
    state carries the teacher and turn timeout. Finally the agent states are
    checked against ``expected_states/state.json`` next to this file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    state_file = os.path.join(here, 'expected_states', 'state.json')

    # Setup: build the task first.
    build_task(task_directory=TASK_DIRECTORY)

    # Config and database.
    self._set_up_config(
        task_directory=TASK_DIRECTORY, overrides=['+turn_timeout=300']
    )

    # Operator and server. The teacher is resolved before turn_timeout is
    # read, matching the order the config is consulted in.
    task_teacher = get_teacher(self.config)
    shared_opt = {
        "turn_timeout": self.config.turn_timeout,
        "teacher": task_teacher,
    }
    # The same options dict serves both the main and the onboarding world.
    self._set_up_server(
        shared_state=SharedParlAITaskState(
            world_opt=shared_opt, onboarding_world_opt=shared_opt
        )
    )

    # Verify the agent states against the stored expectation.
    with open(state_file) as handle:
        saved_state = json.load(handle)
    self._test_agent_states(
        num_agents=1,
        agent_display_ids=AGENT_DISPLAY_IDS,
        agent_messages=AGENT_MESSAGES,
        form_messages=FORM_MESSAGES,
        form_task_data=FORM_TASK_DATA,
        expected_states=(saved_state,),
    )