def create_agent(opt, requireModelExists=False): """ Create an agent from the options ``model``, ``model_params`` and ``model_file``. The input is either of the form ``parlai.agents.ir_baseline.agents:IrBaselineAgent`` (i.e. the path followed by the class name) or else just ``ir_baseline`` which assumes the path above, and a class name suffixed with 'Agent'. If ``model-file`` is available in the options this function can also attempt to load the model from that location instead. This avoids having to specify all the other options necessary to set up the model including its name as they are all loaded from the options file if it exists (the file opt['model_file'] + '.opt' must exist and contain a pickled or json dict containing the model's options). """ if opt.get('datapath', None) is None: # add datapath, it is missing from parlai.core.params import ParlaiParser, get_model_name parser = ParlaiParser(add_parlai_args=False) parser.add_parlai_data_path() # add model args if they are missing model = get_model_name(opt) if model is not None: parser.add_model_subargs(model) opt_parser = parser.parse_args("", print_args=False) for k, v in opt_parser.items(): if k not in opt: opt[k] = v if opt.get('model_file'): opt['model_file'] = modelzoo_path(opt.get('datapath'), opt['model_file']) if requireModelExists and not os.path.isfile(opt['model_file']): raise RuntimeError( 'WARNING: Model file does not exist, check to make ' 'sure it is correct: {}'.format(opt['model_file'])) # Attempt to load the model from the model file first (this way we do # not even have to specify the model name as a parameter) model = load_agent_module(opt) if model is not None: return model else: print("[ no model with opt yet at: " + opt.get('model_file') + "(.opt) ]") if opt.get('model'): model_class = get_agent_module(opt['model']) # if we want to load weights from --init-model, compare opts with # loaded ones compare_init_model_opts(opt, opt) model = model_class(opt) if requireModelExists and hasattr( model, 'load') and not opt.get('model_file'): # double check that we didn't forget to set model_file on loadable model print('WARNING: model_file unset but model has a `load` function.') return model else: raise RuntimeError('Need to set `model` argument to use create_agent.')
def add_datapath_and_model_args(opt: Opt): # add datapath, it is missing from parlai.core.params import ParlaiParser, get_model_name parser = ParlaiParser(add_parlai_args=False) parser.add_parlai_data_path() # add model args if they are missing model = get_model_name(opt) if model is not None: parser.add_model_subargs(model, opt) opt_parser = parser.parse_args("") for k, v in opt_parser.items(): if k not in opt: opt[k] = v
def main(): """This task consists of an MTurk agent evaluating a wizard model. They are assigned a topic and asked to chat. """ start_time = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M') argparser = ParlaiParser(False, add_model_args=True) argparser.add_parlai_data_path() argparser.add_mturk_args() argparser.add_argument('-mt', '--max-turns', default=10, type=int, help='maximal number of chat turns') argparser.add_argument('--max-resp-time', default=240, type=int, help='time limit for entering a dialog message') argparser.add_argument('--max-choice-time', type=int, default=300, help='time limit for turker' 'choosing the topic') argparser.add_argument('--ag-shutdown-time', default=120, type=int, help='time limit for entering a dialog message') argparser.add_argument('-rt', '--range-turn', default='3,5', help='sample range of number of turns') argparser.add_argument('--human-eval', type='bool', default=False, help='human vs human eval, no models involved') argparser.add_argument('--auto-approve-delay', type=int, default=3600 * 24 * 1, help='how long to wait for auto approval') argparser.add_argument('--only-masters', type='bool', default=False, help='Set to true to use only master turks for ' 'this test eval') argparser.add_argument('--unique-workers', type='bool', default=False, help='Each worker must be unique') argparser.add_argument('--mturk-log', type=str, default='data/mturklogs/{}.log'.format(start_time)) def inject_override(opt, override_dict): opt['override'] = override_dict for k, v in override_dict.items(): opt[k] = v def get_logger(opt): logger = logging.getLogger() logger.setLevel(logging.INFO) fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') console = logging.StreamHandler() console.setFormatter(fmt) logger.addHandler(console) if 'mturk_log' in opt: logfile = logging.FileHandler(opt['mturk_log'], 'a') logfile.setFormatter(fmt) logger.addHandler(logfile) logger.info('COMMAND: %s' % ' '.join(sys.argv)) logger.info('-' * 100) logger.info('CONFIG:\n%s' % json.dumps(opt, indent=4, sort_keys=True)) return logger # MODEL CONFIG # NOTE: please edit this to test your own models config = { 'model': 'projects:wizard_of_wikipedia:interactive_retrieval', 'retriever_model_file': 'models:wikipedia_full/tfidf_retriever/model', 'responder_model_file': 'models:wizard_of_wikipedia/full_dialogue_retrieval_model/model', } argparser.add_model_subargs(config['model']) # add model args to opt start_opt = argparser.parse_args() inject_override(start_opt, config) if not start_opt.get('human_eval'): bot = create_agent(start_opt) shared_bot_params = bot.share() else: shared_bot_params = None if not start_opt['human_eval']: get_logger(bot.opt) else: get_logger(start_opt) if start_opt['human_eval']: folder_name = 'human_eval-{}'.format(start_time) else: folder_name = '{}-{}'.format(start_opt['model'], start_time) start_opt['task'] = os.path.basename( os.path.dirname(os.path.abspath(__file__))) if 'data_path' not in start_opt: start_opt['data_path'] = os.path.join(os.getcwd(), 'data', 'wizard_eval', folder_name) start_opt.update(task_config) if not start_opt.get('human_eval'): mturk_agent_ids = ['PERSON_1'] else: mturk_agent_ids = ['PERSON_1', 'PERSON_2'] mturk_manager = MTurkManager(opt=start_opt, mturk_agent_ids=mturk_agent_ids) topics_generator = TopicsGenerator(start_opt) directory_path = os.path.dirname(os.path.abspath(__file__)) mturk_manager.setup_server(task_directory_path=directory_path) worker_roles = {} connect_counter = AttrDict(value=0) try: mturk_manager.start_new_run() agent_qualifications = [] if not start_opt['is_sandbox']: # assign qualifications if start_opt['only_masters']: agent_qualifications.append(MASTER_QUALIF) if start_opt['unique_workers']: qual_name = 'UniqueChatEval' qual_desc = ( 'Qualification to ensure each worker completes a maximum ' 'of one of these chat/eval HITs') qualification_id = \ mturk_utils.find_or_create_qualification(qual_name, qual_desc, False) print('Created qualification: ', qualification_id) UNIQUE_QUALIF = { 'QualificationTypeId': qualification_id, 'Comparator': 'DoesNotExist', 'RequiredToPreview': True } start_opt['unique_qualif_id'] = qualification_id agent_qualifications.append(UNIQUE_QUALIF) mturk_manager.create_hits(qualifications=agent_qualifications) def run_onboard(worker): if start_opt['human_eval']: role = mturk_agent_ids[connect_counter.value % len(mturk_agent_ids)] connect_counter.value += 1 worker_roles[worker.worker_id] = role else: role = 'PERSON_1' worker.topics_generator = topics_generator world = TopicChooseWorld(start_opt, worker, role=role) world.parley() world.shutdown() mturk_manager.set_onboard_function(onboard_function=run_onboard) mturk_manager.ready_to_accept_workers() def check_single_worker_eligibility(worker): return True def check_multiple_workers_eligibility(workers): valid_workers = {} for worker in workers: worker_id = worker.worker_id if worker_id not in worker_roles: print('Something went wrong') continue role = worker_roles[worker_id] if role not in valid_workers: valid_workers[role] = worker if len(valid_workers) == 2: break return valid_workers.values() if len(valid_workers) == 2 else [] if not start_opt['human_eval']: eligibility_function = { 'func': check_single_worker_eligibility, 'multiple': False, } else: eligibility_function = { 'func': check_multiple_workers_eligibility, 'multiple': True, } def assign_worker_roles(workers): if start_opt['human_eval']: for worker in workers: worker.id = worker_roles[worker.worker_id] else: for index, worker in enumerate(workers): worker.id = mturk_agent_ids[index % len(mturk_agent_ids)] def run_conversation(mturk_manager, opt, workers): conv_idx = mturk_manager.conversation_index world = WizardEval( opt=start_opt, agents=workers, range_turn=[ int(s) for s in start_opt['range_turn'].split(',') ], max_turn=start_opt['max_turns'], max_resp_time=start_opt['max_resp_time'], model_agent_opt=shared_bot_params, world_tag='conversation t_{}'.format(conv_idx), agent_timeout_shutdown=opt['ag_shutdown_time'], ) while not world.episode_done(): world.parley() world.save_data() world.shutdown() gc.collect() mturk_manager.start_task(eligibility_function=eligibility_function, assign_role_function=assign_worker_roles, task_function=run_conversation) except BaseException: raise finally: mturk_manager.expire_all_unassigned_hits() mturk_manager.shutdown()