def self_chat(opt):
    """
    Run self-chat episodes between a model and itself or a partner model.

    Loads the primary agent, pairs it with either a clone of itself or the
    model named by --partner-model-file, runs --num-self-chats episodes, and
    writes the chat logs to --outfile.

    :param opt: parsed ParlAI options dict.
    :return: the collected world logs.
    """
    random.seed(opt['seed'])
    partner_model = opt['partner_model_file']
    override_opt_path = opt.get('partner_opt_file')

    # The primary agent always comes from the main model file.
    agent1 = create_agent(opt, requireModelExists=True)
    if partner_model is None:
        # No partner given: the model chats with a clone of itself.
        agent2 = agent1.clone()
    else:
        # Partner given: build its opt, optionally overridden from a JSON file.
        if override_opt_path:
            print(f"WARNING: Loading override opts from: {override_opt_path}")
            with open(override_opt_path) as f:
                partner_opt = json.load(f)
        else:
            partner_opt = {}
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        agent2 = create_agent_from_model_file(partner_model, partner_opt)

    # Disambiguate the two speakers in the logs.
    agent1.id = agent1.id + "_1"
    agent2.id = agent2.id + "_2"
    model_id = agent1.id + "_" + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # World logging plus periodic progress reporting.
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    for episode_idx in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(episode_idx + 1, opt['num_self_chats'], report)
        logging.info(text)

    # Persist the chats; fall back to a /tmp default when no outfile is set.
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']

    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # A self-chat-specific world may log extra context (e.g. personas).
        world.write(logger, outfile)
    else:
        # Use the default logger write function.
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
def setup_args(parser=None):
    """
    Build the command-line parser for generating model self-chats.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Generate self-chats of a model')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('--seed', type=int, default=42)
    add('-d', '--display-examples', type='bool', default=True)
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-st',
        '--selfchat-task',
        type='bool',
        default=True,
        help='Create a self chat version of the task',
    )
    add('--num-self-chats', type=int, default=1, help='Number of self chats to run')
    add(
        '--selfchat-max-turns',
        type=int,
        default=6,
        help='The number of dialogue turns before self chat ends',
    )
    add(
        '--seed-messages-from-task',
        action='store_true',
        help='Automatically seed conversation with messages from task dataset.',
    )
    add('--outfile', type=str, default=None, help='File to save self chat logs')
    add(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
        help='Format to save logs in. conversations is a jsonl format, parlai is a text format.',
    )
    add(
        '-pmf',
        '--partner-model-file',
        default=None,
        help='Define a different partner for self chat',
    )
    add(
        '--partner-opt-file',
        default=None,
        help='Path to file containing opts to override for partner',
    )
    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
def setup_args(parser=None):
    """
    Build the command-line parser for self chat with optional Mongo logging.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Self chat with a model')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('--seed', type=int, default=42)
    add('-d', '--display-examples', type='bool', default=True)
    add('-n', '-ne', '--num-examples', type=int, default=10)
    add('-nd', '--num-dialogues', type=int, default=10)
    add('-ltim', '--log-every-n-secs', type=float, default=2)
    # MongoDB connection options.
    add('-host', '--mongo-host', type=str)
    add('-port', '--mongo-port', type=int)
    add('-user', '--user-name', type=str)
    add('-pw', '--password', type=str)
    add('-col', '--collection-name', type=str)
    add(
        '-mf1',
        '--model-file1',
        default=None,
        help='model file name for loading and saving models',
    )
    add(
        '-mf2',
        '--model-file2',
        default=None,
        help='model file name for loading and saving models',
    )
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    add(
        '--selfchat-max-turns',
        type=int,
        default=10,
        help="The number of dialogue turns before self chat ends.",
    )
    add(
        '--seed-messages-from-task',
        action='store_true',
        help="Automatically seed conversation with messages from task dataset.",
    )
    add('--outfile', type=str, default='/tmp/selfchat.json')
    add('--format', type=str, default='json', choices={'parlai', 'json'})
    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
def setup_args(parser=None):
    """
    Build the parser for command-line interactive chat with a model.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(
            True, True, 'Interactive chat with a model on the command line'
        )
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_defaults(interactive_mode=True, task='interactive')
    # Let the human agent and the world logger register their own options.
    LocalHumanAgent.add_cmdline_args(parser)
    WorldLogger.add_cmdline_args(parser)
    return parser
def setup_args(parser=None):
    """
    Build the command-line parser for running model self-chats.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Self chat with a model')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('--seed', type=int, default=42)
    add('-d', '--display-examples', type='bool', default=True)
    add('-n', '-ne', '--num-examples', type=int, default=10)
    add('-nd', '--num-dialogues', type=int, default=10)
    add('-ltim', '--log-every-n-secs', type=float, default=2)
    # MongoDB connection options.
    add('-host', '--mongo-host', type=str)
    add('-port', '--mongo-port', type=int)
    add('-user', '--user-name', type=str)
    add('-pw', '--password', type=str)
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-st',
        '--selfchat-task',
        type='bool',
        default=True,
        help='Create a self chat version of the task',
    )
    add('--num-self-chats', type=int, default=1, help='Number of self chats to run')
    add(
        '--selfchat-max-turns',
        type=int,
        default=6,
        help='The number of dialogue turns before self chat ends',
    )
    add(
        '--seed-messages-from-task',
        action='store_true',
        help='Automatically seed conversation with messages from task dataset.',
    )
    add('--outfile', type=str, default=None, help='File to save self chat logs')
    add(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
        help='Format to save logs in. conversations is a jsonl format, parlai is a text format.',
    )
    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
def self_chat(opt, print_parser=None):
    """
    Run self chat with one model conversing with a clone of itself.

    :param opt: options dict (passing a ParlaiParser is deprecated).
    :param print_parser: deprecated; parser used to print args after loading.
    :return: the collected world logs.
    """
    # Back-compat shim: callers used to pass the parser itself.
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: self_chat should be passed opt not Parser ]')
        opt = opt.parse_args()

    random.seed(opt['seed'])

    # Clone the loaded model so it can talk to itself.
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    # Suggest the task's dedicated self-chat world when one likely exists.
    if 'selfchat' not in opt['task']:
        warn_once(
            'You are using self chat with task {}. '.format(opt['task'])
            + 'If your task has an existing self chat world, then run with '
            '-t {}:selfchat'.format(opt['task'])
        )

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # Periodic progress logging; disabled when log_every_n_secs <= 0.
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Parley until num_examples chats have been produced.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print(world.display())
        if log_time.time() > log_every_n_secs:
            text = log_time.log(cnt, max_cnt)
            print(text)

    if opt.get('display_examples'):
        print('-- end of episode --')

    logger.reset_world()  # flush last episode
    logger.write(opt['outfile'], opt['format'])
    return logger.get_logs()
def setup_args(parser=None):
    """
    Build the command-line parser for model evaluation.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('-rp', '--report', type=str, default="/tmp/eval_model.json")
    add(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    add(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Must also specify --report-filename.',
    )
    add('-ne', '--num-examples', type=int, default=-1)
    add('-d', '--display-examples', type='bool', default=False)
    add('-ltim', '--log-every-n-secs', type=float, default=2)
    add(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the '
        'number of examples. If false, averages over the '
        'number of tasks.',
    )
    add(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default,'
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu'
        'the rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    WorldLogger.add_cmdline_args(parser)
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
def setup_args(parser=None):
    """
    Build the command-line parser for model evaluation.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        # FIX: the two concatenated fragments were missing a separating
        # space, rendering as "...world logs.Set to..." in --help output.
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default,'
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu'
        'the rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
def interactive(opt, print_parser=None):
    """
    Interactive chat between a human and a model, optionally logging the chat.

    :param opt: options dict (passing a ParlaiParser is an error; handled
        for backward compatibility).
    :param print_parser: deprecated; parser used to print args after loading.
    """
    import os  # local import: only needed for log-file naming

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    human_agent = LocalHumanAgent(opt)

    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    world = create_task(opt, [human_agent, agent])

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        # FIX: use splitext instead of split('.')[0] so report paths whose
        # directory names contain dots (e.g. './run.v2/report') are not
        # truncated; also matches sibling implementations in this file.
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'
        world_logger.write(outfile, world, file_format=opt['save_format'])
def interactive(opt):
    """
    Interactive chat between a human and a model, with optional chat logging.

    FIX: removed a nonsense string annotation on ``opt`` (it held a pasted
    model-zoo path, not a type) and leftover debug prints inside the loop.

    :param opt: parsed options dict (passing a ParlaiParser is an error;
        handled for backward compatibility).
    """
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task.
    # NOTE(review): requireModelExists=False differs from the other
    # interactive entry points in this file (which pass True) — confirm
    # this is intentional.
    agent = create_agent(opt, requireModelExists=False)
    agent.opt.log()
    human_agent = LocalHumanAgent(opt)

    # Log the chat only when an outfile was requested.
    world_logger = WorldLogger(opt) if opt.get('outfile') else None
    world = create_task(opt, [human_agent, agent])

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        if world.epoch_done() or world.get_total_parleys() <= 0:
            # chat was reset with [DONE], [EXIT] or EOF
            if world_logger is not None:
                world_logger.reset()
            continue

        if world_logger is not None:
            world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        # dump world acts to file
        world_logger.write(opt['outfile'], world, file_format=opt['save_format'])
def setup_args(parser=None):
    """
    Build the command-line parser for interactive chat with log saving.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Interactive chat with a model')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('-d', '--display-examples', type='bool', default=False)
    add(
        '--display-prettify',
        type='bool',
        default=False,
        help='Set to use a prettytable when displaying '
        'examples with text candidates',
    )
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    add(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    add(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Must also specify --report-filename.',
    )
    # NOTE(review): 'forever' looks out of place among file formats here —
    # confirm it is a real WorldLogger format before relying on it.
    add(
        '--world-logs-format',
        type=str,
        default='parlai',
        choices=['jsonl', 'parlai', 'forever'],
        help='File format to save chat logs. (default parlai)',
    )
    add('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.set_defaults(interactive_mode=True, task='interactive')
    LocalHumanAgent.add_cmdline_args(parser)
    WorldLogger.add_cmdline_args(parser)
    return parser
def setup_args(parser=None):
    """
    Build the command-line parser for interactive chat with outfile logging.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(
            True, True, 'Interactive chat with a model on the command line')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('-d', '--display-examples', type='bool', default=False)
    add(
        '--display-prettify',
        type='bool',
        default=False,
        help='Set to use a prettytable when displaying '
        'examples with text candidates',
    )
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    add(
        '--outfile',
        type=str,
        default='',
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Set to the empty string to not save at all',
    )
    add(
        '--save-format',
        type=str,
        default='parlai',
        choices=['conversations', 'parlai'],
        help='Format to save logs in. conversations is a jsonl format, parlai is a text format.',
    )
    parser.set_defaults(interactive_mode=True, task='interactive')
    LocalHumanAgent.add_cmdline_args(parser)
    WorldLogger.add_cmdline_args(parser)
    return parser
def create_world(opt):
    """
    Build a multi-client interactive world pairing a human agent with a model.

    The caller is responsible for driving the returned world (the parley
    loop that used to live here was dead, commented-out code and has been
    removed).

    :param opt: parsed options dict (passing a ParlaiParser is an error;
        handled for backward compatibility).
    :return: the constructed MultiClientInteractiveWorld.
    """
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    agent.opt.log()
    human_agent = HumanAgent(opt)

    # set up world logger
    # NOTE(review): the logger is constructed but never used here since the
    # parley/logging loop was commented out — confirm whether it can be
    # removed entirely or should be returned to the caller.
    world_logger = WorldLogger(opt) if opt.get('outfile') else None  # noqa: F841

    world = MultiClientInteractiveWorld(opt, [human_agent, agent])
    return world
def _eval_single_world(opt, agent, task): print( '[ Evaluating task {} using datatype {}. ] '.format( task, opt.get('datatype', 'N/A') ) ) # set up world logger world_logger = WorldLogger(opt) if opt['save_world_logs'] else None task_opt = opt.copy() # copy opt since we're editing the task task_opt['task'] = task world = create_task(task_opt, agent) # create worlds for tasks # set up logging log_every_n_secs = opt.get('log_every_n_secs', -1) if log_every_n_secs <= 0: log_every_n_secs = float('inf') log_time = TimeLogger() # max number of examples to evaluate max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf') cnt = 0 while not world.epoch_done() and cnt < max_cnt: cnt += opt.get('batchsize', 1)
def _eval_single_world(opt, agent, task): print('[ Evaluating task {} using datatype {}. ] '.format( task, opt.get('datatype', 'N/A'))) # set up world logger world_logger = WorldLogger(opt) if opt['save_world_logs'] else None task_opt = opt.copy() # copy opt since we're editing the task # task_opt['task'] = task world = create_task(task_opt, agent) # create worlds for tasks # set up logging log_every_n_secs = opt.get('log_every_n_secs', -1) if log_every_n_secs <= 0: log_every_n_secs = float('inf') log_time = TimeLogger() # max number of examples to evaluate max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf') cnt = 0 while not world.epoch_done() and cnt < max_cnt: cnt += opt.get('batchsize', 1) world.parley() if world_logger is not None: world_logger.log(world) if opt['display_examples']: # display examples print(world.display() + '\n~~') # for a in world.acts: # print (a) # print (world.get_acts()) # print (world.acts) if log_time.time() > log_every_n_secs: report = world.report() text, report = log_time.log(report.get('exs', 0), world.num_examples(), report) print(text) report = world.report() print("Printing Report") print(report) world.reset() if world_logger is not None: # dump world acts to file world_logger.reset() # add final acts to logs base_outfile = opt['report_filename'].split('.')[0] print("filename: ", base_outfile) outfile = base_outfile + f'_{task}_replies.jsonl' # world_logger.write_jsonl_format(outfile) world_logger.write_parlai_format(outfile) return report
def setup_args(parser=None):
    """
    Build the command-line parser for self chat with jsonl output options.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Self chat with a model')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('--seed', type=int, default=42)
    add('-d', '--display-examples', type='bool', default=True)
    add('-n', '-ne', '--num-examples', type=int, default=10)
    add('-ltim', '--log-every-n-secs', type=float, default=60)
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    add(
        '--selfchat-max-turns',
        type=int,
        default=10,
        help="The number of dialogue turns before self chat ends.",
    )
    add(
        '--seed-messages-from-task',
        action='store_true',
        help="Automatically seed conversation with messages from task dataset.",
    )
    add('--outfile', type=str, default='/tmp/selfchat.json')
    add('--format', type=str, default='jsonl', choices={'parlai', 'jsonl'})
    add('--indent', type=int, default=4, help='how much to indent jsonl string')
    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
def _eval_single_world(opt, agent, task):
    """
    Evaluate ``agent`` on one task and return the aggregated report.

    Optionally dumps every example/reply pair to --world-logs (one file per
    rank under distributed evaluation).

    :param opt: parsed options dict.
    :param agent: the agent to evaluate.
    :param task: name of the task to evaluate on.
    :return: the aggregated report dict.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')

    # Only log the world when an output path was requested.
    world_logger = WorldLogger(opt) if opt['world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # Periodic progress logging; disabled when log_every_n_secs <= 0.
    interval = opt.get('log_every_n_secs', -1)
    if interval <= 0:
        interval = float('inf')
    timer = TimeLogger()

    # Cap on examples to evaluate (-1 / 0 means "all").
    cap = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    seen = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and seen < cap:
        seen += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if timer.time() > interval:
            report = world.report()
            text, report = timer.log(report.get('exs', 0),
                                     min(cap, total_cnt), report)
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            # One log file per rank, suffixed before the extension.
            rank = get_rank()
            base_outfile, extension = os.path.splitext(opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()

    return report
def setup_args(parser=None):
    """
    Build the command-line parser for displaying task data.

    :param parser: optional preexisting ParlaiParser to extend; created if None.
    :return: the configured parser.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Display data from a task')
    # Local alias keeps the long option list readable.
    add = parser.add_argument
    add('--seed', type=int, default=42)
    add('-n', '-ne', '--num-examples', type=int, default=10)
    add('-ns', '--num-stored', type=int, default=10)
    add('-mdl', '--max-display-len', type=int, default=1000)
    add('--display-ignore-fields', type=str, default='agent_reply')
    parser.set_defaults(datatype='train:stream')
    # MongoDB connection options.
    add('-host', '--mongo-host', type=str)
    add('-port', '--mongo-port', type=int)
    add('-user', '--user-name', type=str)
    add('-pw', '--password', type=str)
    add('-col', '--collection-name', type=str)
    WorldLogger.add_cmdline_args(parser)
    return parser
def self_chat(opt):
    """
    Self chat between two clones of the same model.

    Runs --num-self-chats episodes and writes the chat logs to --outfile
    (or a /tmp default derived from the model id).

    :param opt: parsed ParlAI options dict.
    :return: the collected world logs.
    """
    random.seed(opt['seed'])

    # Two copies of the same model talk to each other.
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()

    # Distinguish the two speakers in the logs.
    model_id = agent1.id
    agent1.id = model_id + "_1"
    agent2.id = model_id + "_2"

    world = create_task(opt, user_agents=[agent1, agent2])

    # World logging plus periodic progress reporting.
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    for chat_idx in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(chat_idx + 1, opt['num_self_chats'], report)
        logging.info(text)

    # Default to a /tmp path when no outfile was given.
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']

    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # A self-chat-specific world may log extra context (like personas).
        world.write(logger, outfile)
    else:
        # Use the default logger write function.
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
def interactive(opt, print_parser=None):
    """
    Interactive chat that logs the conversation and writes it out when the
    user interrupts with Ctrl-C.

    :param opt: options dict (passing a ParlaiParser is deprecated).
    :param print_parser: deprecated; parser used to print args after loading.
    """
    import os  # local import: only needed for log-file naming

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    human_agent = LocalHumanAgent(opt)
    world = create_task(opt, [human_agent, agent])

    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Chat until the epoch ends or the user hits Ctrl-C.
    while True:
        try:
            world.parley()
            if world_logger is not None:
                world_logger.log(world)
            if opt.get('display_examples'):
                print("---")
                print(world.display())
            if world.epoch_done():
                print("EPOCH DONE")
                break
        except KeyboardInterrupt:
            if world_logger is not None:
                # FIX: dropped pointless f-string prefixes on literals that
                # contained no placeholders.
                print("\nWriting out world log.")
                # Save report
                report = world.report()  # noqa: F841  kept: original computed it
                world.reset()
                # dump world acts to file
                world_logger.reset()  # add final acts to logs
                # FIX: splitext instead of split('.')[0] so dotted directory
                # names don't truncate the output path.
                base_outfile = os.path.splitext(opt['report_filename'])[0]
                outfile = base_outfile + '_interactive_replies.json'
                world_logger.write(outfile, file_format=opt['world_logs_format'])
            quit()
def self_chat(opt, print_parser=None): if print_parser is not None: if print_parser is True and isinstance(opt, ParlaiParser): print_parser = opt elif print_parser is False: print_parser = None if isinstance(opt, ParlaiParser): print('[ Deprecated Warning: self_chat should be passed opt not Parser ]') opt = opt.parse_args() random.seed(opt['seed']) # Create models agent1 = create_agent(opt, requireModelExists=True) agent2 = agent1.clone() if hasattr(agent2, 'id'): agent2.id = agent2.id + "2" world = create_task(opt, [agent1, agent2]) if print_parser: # Show arguments after loading model print_parser.opt = agent1.opt print_parser.print_args() # set up logging log_every_n_secs = opt.get('log_every_n_secs', -1) if log_every_n_secs <= 0: log_every_n_secs = float('inf') log_time = TimeLogger() logger = WorldLogger(opt) # Run some self chats. max_cnt = opt['num_examples'] cnt = 0 while cnt < max_cnt: cnt += opt.get('batchsize', 1) world.parley() logger.log(world) if opt.get('display_examples'):
def interactive(opt, print_parser=None):
    """
    Drive a scripted chat: an AutoQueryAgent replays queries from a text
    file against the model, and the exchanges are written out as text.

    :param opt: options dict (passing a ParlaiParser is deprecated).
    :param print_parser: deprecated; parser used to print args after loading.
    """
    # Back-compat shim: callers used to pass the parser itself.
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Queries are replayed from <report_filename>.txt instead of a human.
    human_agent = AutoQueryAgent(opt, query_txt=opt['report_filename'] + ".txt")
    world_logger = WorldLogger(opt)
    world = create_task(opt, [human_agent, agent])

    # One parley per line in the query file (trailing newline excluded).
    with open(opt['report_filename'] + ".txt", "r") as f:
        num_queries = len(f.read().split("\n")) - 1

    for _ in range(num_queries):
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        world_logger.reset()
        base_outfile = opt['report_filename'].split('.')[0]
        outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'
        world_logger.write(outfile, world, file_format='text')
def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):
    """
    Run evaluation on a single world and return its report.

    World logs are written only for the "test" fold (per-rank files under
    distributed evaluation).

    :param opt: parsed options dict.
    :param valid_world: the world to evaluate.
    :param max_exs: cap on examples to evaluate (<= 0 means "all").
    :param datatype: fold being evaluated ('valid' / 'test' / ...).
    :param is_multitask: whether this world is part of a multitask eval.
    :param task: task name, used for world-log file naming.
    :return: the world's report dict.
    """
    valid_world.reset()

    world_logger = None
    task_opt = opt.copy()
    # set up world logger for the "test" fold
    if opt['world_logs'] and datatype == 'test':
        task_opt['world_logs'] = get_task_world_logs(
            task, opt['world_logs'], is_multitask
        )
        world_logger = WorldLogger(task_opt)

    seen = 0
    cap = max_exs if max_exs > 0 else float('inf')
    while not valid_world.epoch_done() and seen < cap:
        valid_world.parley()
        if world_logger is not None:
            world_logger.log(valid_world)
        if seen == 0 and opt['display_examples']:
            print(valid_world.display() + '\n~~')
            print(valid_world.report())
        seen = valid_world.report().get('exs') or 0

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            # One log file per rank, suffixed before the extension.
            rank = get_rank()
            base_outfile, extension = os.path.splitext(task_opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = task_opt['world_logs']
        world_logger.write(outfile, valid_world, file_format=opt['save_format'])

    valid_report = valid_world.report()
    if opt.get('validation_share_agent', False):
        valid_world.reset()  # make sure world doesn't remember valid data

    return valid_report
def generate_data(self):
    """
    Generate the LM Data.

    Builds a flattened copy of the (single) configured task, runs generation
    with the loaded agent, and dumps the resulting world logs.
    """
    random.seed(42)

    # load model and possibly print opt
    agent = create_agent(self.opt, requireModelExists=True)
    agent.opt.log()

    # Exactly one task is supported here.
    tasks = self.opt['task'].split(',')
    assert len(tasks) == 1
    task = tasks[0]
    logging.info(
        f'Generating data for task {task} using datatype {self.opt.get("datatype")}.'
    )
    logging.warning('Appending `flatten` to mutators.')

    # set up world logger; copy opt since we're editing the task
    self.task_opt = self.opt.copy()
    self.task_opt['task'] = task
    if not self.task_opt['mutators']:
        self.task_opt['mutators'] = 'flatten'
    else:
        self.task_opt['mutators'] += '+flatten'

    # add task suffix in case of multi-tasking
    self.task_opt['world_logs'] = get_task_world_logs(
        task, self.task_opt['world_logs'], is_multitask=False
    )

    self.world_logger = WorldLogger(self.task_opt)
    self.world = create_task(self.task_opt, agent)  # create worlds for tasks

    self.run_generation()

    # dump world acts to file
    self.log()
    self.world.reset()
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument(
        '-et',
        '--evaltask',
        help='task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--final-extra-opt',
        type=str,
        default='',
        help="A '.opt' file that is used for final eval. Useful for setting skip-generation to false. 'datatype' must be included as part of the opt.",
    )
    train.add_argument(
        '--eval-batchsize',
        type=int,
        hidden=True,
        help='Eval time batch size (defaults to same as -bs)',
    )
    train.add_argument(
        '--eval-dynamic-batching',  # FIXME: see https://github.com/facebookresearch/ParlAI/issues/3367
        default=None,
        type='nonestr',
        choices={None, 'off', 'full', 'batchsort'},
        help=(
            'Set dynamic batching at evaluation time. Set to off for '
            'train-only dynamic batching. Set to none (default) to use same '
            'setting as --dynamic-batching.'
        ),
    )
    train.add_argument(
        '--num-workers',
        default=0,
        type=int,
        help='Number of background workers (training only)',
    )
    train.add_argument('--display-examples', type='bool', default=False, hidden=True)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument(
        '-tstep',
        '--max-train-steps',
        '--max-lr-steps',
        type=int,
        default=-1,
        help='End training after n model updates',
    )
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
    train.add_argument(
        '-lstep',
        '--log-every-n-steps',
        type=int,
        default=50,
        help='Log every n training steps',
    )
    train.add_argument(
        '-vtim',
        '--validation-every-n-secs',
        type=float,
        default=-1,
        help='Validate every n seconds. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vstep',
        '--validation-every-n-steps',
        type=int,
        default=-1,
        help='Validate every n training steps. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-stim',
        '--save-every-n-secs',
        type=float,
        default=-1,
        help='Saves the model to model_file.checkpoint after '
        'every n seconds (default -1, never).',
    )
    train.add_argument(
        '-sval',
        '--save-after-valid',
        type='bool',
        default=False,
        help='Saves the model to model_file.checkpoint after '
        'every validation (default %(default)s).',
    )
    train.add_argument(
        '-veps',
        '--validation-every-n-epochs',
        type=float,
        default=-1,
        help='Validate every n epochs. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vme',
        '--validation-max-exs',
        type=int,
        default=-1,
        hidden=True,
        help='max examples to use during validation (default -1 uses all)',
    )
    train.add_argument(
        '--short-final-eval',
        default=False,
        hidden=True,
        type='bool',
        help='If true, obeys --validation-max-exs in the final '
        'validation and test evaluations.',
    )
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=10,
        help=(
            'number of iterations of validation where result'
            ' does not improve before we stop training'
        ),
    )
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation',
    )
    train.add_argument(
        '-vmm',
        '--validation-metric-mode',
        type=str,
        choices=['max', 'min'],
        help='the direction in which to optimize the validation metric, i.e. maximize or minimize',
    )
    train.add_argument(
        '-vcut',
        '--validation-cutoff',
        type=float,
        default=1.0,
        hidden=True,
        help='value at which training will stop if exceeded by metric',
    )
    train.add_argument(
        '-lfc',
        '--load-from-checkpoint',
        type='bool',
        default=True,
        hidden=True,
        help='load model from checkpoint if available',
    )
    train.add_argument(
        '-vshare',
        '--validation-share-agent',
        default=False,
        hidden=True,
        help='use a shared copy of the agent for validation. '
        'this will eventually default to True, but '
        'currently defaults to False.',
    )
    # BUGFIX: the concatenated help fragments were missing separating spaces,
    # which rendered as "default,or give" and "bleuthe rouge" in --help output.
    train.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    train.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    # BUGFIX: missing space between the two help fragments ("logs.Set to").
    train.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    train.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    WandbLogger.add_cmdline_args(parser, partial_opt=None)
    parser = setup_dict_args(parser)
    return parser
type=str, default='default', help='list of metrics to show/compute, e.g. all, default,' 'or give a list split by , like ' 'ppl,f1,accuracy,hits@1,rouge,bleu' 'the rouge metrics will be computed as rouge-1, rouge-2 and rouge-l', ) parser.add_argument( '-micro', '--aggregate-micro', type='bool', default=False, help='Report micro-averaged metrics instead of macro averaged metrics.', recommended=False, ) WorldLogger.add_cmdline_args(parser) >>>>>>> b49eba4519856f6ab83b869b168c6af99863df47 TensorboardLogger.add_cmdline_args(parser) parser.set_params(datatype='valid') return parser def _save_eval_stats(opt, report): report_fname = opt['report_filename'] if report_fname == '': return if report_fname.startswith('.'): report_fname = opt['model_file'] + report_fname json_serializable_report = report for k, v in report.items():
def self_chat(opt):
    """
    Run self chats between a primary agent and a required partner model.

    The primary agent is either a local human (when the model file prefix is
    'human') or a model loaded from opt['model_file']; the partner is always
    loaded from opt['partner_model_file'] with opts overridden from its
    '<model_file>.opt' file.

    :param opt: full option dict; reads 'seed', 'partner_model_file',
        'partner_opt_file', 'model_file', 'interactive_mode',
        'num_self_chats', 'outfile', and 'save_format'.
    :returns: the logger's collected logs (logger.get_logs()).
    """
    random.seed(opt['seed'])

    # A partner model is mandatory in this variant of self chat.
    partner = opt['partner_model_file']
    assert partner is not None

    # Default the partner opt file to '<partner_model_file>.opt'; if one was
    # passed explicitly, insist it follows that convention.
    partner_opt_file = opt.get('partner_opt_file')
    if partner_opt_file:
        # BUGFIX: typo "save" -> "safe" in the assertion message.
        assert partner_opt_file == partner + '.opt', (
            'Unless you think it is safe, you can remove assert'
        )
    else:
        partner_opt_file = partner + '.opt'

    # Create agents
    if opt['model_file'].split(':')[0] == 'human':
        agent1 = MyLocalHumanAgent(opt)
    else:
        agent1 = create_agent(opt, requireModelExists=True)

    # partner is guaranteed non-None and partner_opt_file is always set above,
    # so the original dead "clone agent1" branch has been removed: always load
    # the second agent from its own model file with the override opts.
    print(f"WARNING: Loading override opts from: {partner_opt_file}")
    with open(partner_opt_file) as f:
        partner_opt = json.load(f)
    partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
    print(
        f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
    )
    agent2 = create_agent_from_model_file(partner, partner_opt)

    # Set IDs
    agent1.id = agent1.id + '_1'
    agent2.id = agent2.id + '_2'
    model_id = agent1.id + '_' + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run some self chats.
    all_report = []
    if opt['num_self_chats'] < 0:
        # negative means: one self chat per available seed message
        opt['num_self_chats'] = len(world.messages)
    for i in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(i + 1, opt['num_self_chats'], report)
        logging.info(text)
        all_report.append(report)

    # Save chats.
    # BUGFIX: the original called world.write(logger, all_report,
    # opt['outfile']) here, BEFORE resolving a None outfile to the /tmp
    # fallback, and then wrote again below -- a duplicate write with a
    # possibly-None path. Resolve the path first and write exactly once.
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']
    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # use self chat specific world to write conversation
        # this might be useful for logging extra contextual
        # information (like personas)
        world.write(logger, all_report, outfile)
    else:
        # use default logger write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
def _eval_single_world(opt, agent, task):
    """
    Evaluate a single task world with the given agent and return the report.

    :param opt: full option dict; reads 'world_logs', 'task', 'num_examples',
        'display_examples', 'log_every_n_secs', 'batchsize', and 'save_format'.
    :param agent: the agent to evaluate.
    :param task: the task name to build the world from (may be one of several
        comma-separated tasks in opt['task']).
    :returns: the aggregated report dict (gathered across ranks when
        distributed), with per-class AUC entries added when available.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    # add task suffix in case of multi-tasking
    if opt['world_logs']:
        task_opt['world_logs'] = get_task_world_logs(
            task,
            task_opt['world_logs'],
            is_multitask=len(opt['task'].split(',')) > 1)
    world_logger = WorldLogger(task_opt) if task_opt['world_logs'] else None

    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        # non-positive disables periodic progress logging
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        # NOTE: counts one batch per parley; with dynamic batching this is an
        # approximation of the true example count.
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        min(max_cnt, total_cnt), report)
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            # each rank writes its own shard: <base>_<rank><ext>
            rank = get_rank()
            base_outfile, extension = os.path.splitext(task_opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = task_opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    # gather and merge reports from all ranks (no-op when not distributed)
    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    if isinstance(world.agents, list) and len(world.agents) > 1:
        classifier_agent = world.agents[CLASSIFIER_AGENT]
        if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
            # one AUC entry per tracked class, keyed by class name
            for class_indices, curr_auc in zip(
                    classifier_agent.auc_class_indices, classifier_agent.aucs):
                report[
                    f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
            classifier_agent.reset_auc()
            # for safety measures
            agent.reset_auc()
    world.reset()

    return report
def setup_args(parser=None):
    """
    Build the ParlAI parser for model evaluation.

    :param parser: preexisting ParlaiParser to append options to; created if
        None.
    :returns: the parser with evaluation CLI options added and
        datatype defaulted to 'valid'.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    # BUGFIX: missing space between the two help fragments ("logs.Set to").
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument(
        '--area-under-curve-digits',
        '-auc',
        type=int,
        default=-1,
        help=
        'a positive number indicates to calculate the area under the roc curve and it also determines how many decimal digits of the predictions to keep (higher numbers->more precise); also used to determine whether or not to calculate the AUC metric',
    )
    parser.add_argument(
        '--area-under-curve-class',
        '-auclass',
        type=str,
        default=None,
        nargs='*',
        help='the name(s) of the class to calculate the auc for',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    # BUGFIX: the concatenated help fragments were missing separating spaces,
    # which rendered as "default,or give" and "bleuthe rouge" in --help output.
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
def self_chat(opt, print_parser=None):
    """
    Run self chats between two models and store the logs in MongoDB.

    The two agents are built from opt['model_file1'] and opt['model_file2'];
    two known convai2 baselines get special-cased option overrides. Logs are
    flushed to the Mongo collection every 20 dialogues and once at the end.

    :param opt: option dict (or, deprecated, a ParlaiParser); reads mongo
        connection settings, 'collection_name', 'seed', 'model_file1',
        'model_file2', 'num_dialogues', and 'log_every_n_secs'.
    :param print_parser: optional parser used to print the loaded arguments.

    NOTE(review): the baseline branches reference a module-level ``parser``
    that is not defined in this function -- confirm it exists at module scope,
    otherwise those branches raise NameError.
    """
    client = MongoClient(
        opt['mongo_host'],
        opt['mongo_port'],
        username=opt['user_name'],
        password=opt['password'],
        # authSource=DATABASE_NAME
    )
    db = client[DATABASE_NAME]
    collection = db[opt['collection_name']]
    # normalize the deprecated print_parser conventions (True -> the parser,
    # False -> disabled)
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: self_chat should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(opt['seed'])

    # Create models
    # NOTE(review): opt['model_file'] is mutated in place to build each agent.
    opt['model_file'] = opt['model_file1']
    if opt['model_file'] == 'tmp/convai2/lost_in_conversation/last_checkpoint':
        # special-cased opts for the lost_in_conversation baseline
        parser.set_defaults(
            model=
            'projects.convai2.baselines.transformer_chatbot.agent:TransformerAgent',
            sample=False,
            wild_mode=False,
            replace_repeat=False,
            replace_ngram=False,
            detokenize=False,
            emoji_prob=0,
            add_questions=0,
            clean_emoji=False,
            check_grammar=False,
            correct_generative=False,
            split_into_sentences=False,
            max_seq_len=256,
            beam_size=3,
            annealing_topk=None,
            annealing=0.6,
            length_penalty=0.6)
        opt = parser.parse_args()
        agent1 = create_agent(opt, requireModelExists=True)
    elif opt['model_file'] == 'tmp/convai2/huggingface/model':
        # special-cased opts for the huggingface baseline
        parser.set_params(
            model=
            'projects.convai2.baselines.huggingface.convai_evaluation:TransformerAgent'
        )
        opt = parser.parse_args()
        agent1 = create_agent(opt, requireModelExists=True)
    else:
        agent1 = create_agent(opt, requireModelExists=True)

    opt['model_file'] = opt['model_file2']
    if opt['model_file'] == 'tmp/convai2/lost_in_conversation/last_checkpoint':
        parser.set_defaults(
            model=
            'projects.convai2.baselines.transformer_chatbot.agent:TransformerAgent',
            sample=False,
            wild_mode=False,
            replace_repeat=False,
            replace_ngram=False,
            detokenize=False,
            emoji_prob=0,
            add_questions=0,
            clean_emoji=False,
            check_grammar=False,
            correct_generative=False,
            split_into_sentences=False,
            max_seq_len=256,
            beam_size=3,
            annealing_topk=None,
            annealing=0.6,
            length_penalty=0.6)
        opt = parser.parse_args()
        agent2 = create_agent(opt, requireModelExists=True)
    elif opt['model_file'] == 'tmp/convai2/huggingface/model':
        parser.set_params(
            model=
            'projects.convai2.baselines.huggingface.convai_evaluation:TransformerAgent'
        )
        opt = parser.parse_args()
        agent2 = create_agent(opt, requireModelExists=True)
    else:
        agent2 = create_agent(opt, requireModelExists=True)

    # disambiguate the second agent's id (the models may be identical)
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"
    opt['random_order'] = False
    world = create_task(opt, [agent1, agent2])
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    # NOTE(review): log_time is created but never used in this function.
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_dial_cnt = opt['num_dialogues']
    #dial_cnt = 0
    # NOTE(review): episode length is sampled once and reused for every
    # dialogue; the per-dialogue resampling below is commented out -- confirm
    # this is intended.
    world.max_turn_cnt = world.sample_episode_length()
    for dial_cnt in tqdm(range(max_dial_cnt)):
        #while dial_cnt < max_dial_cnt:
        #world.max_turn_cnt = world.sample_episode_length()
        #world.turn_cnt = 0
        #print('Dialogue Number: {}, Max Turn: {}\n'.format(dial_cnt, world.max_turn_cnt))
        # run one full episode, logging each parley
        while True:
            world.parley()
            logger.log(world)
            if world.episode_done():
                break
        #dial_cnt += 1
        # flush logs to Mongo every 20 dialogues and start a fresh logger
        if dial_cnt % 20 == 0:
            store_logger(opt, collection, logger)
            logger = WorldLogger(opt)
    # flush whatever remains after the loop
    store_logger(opt, collection, logger)