Пример #1
0
def self_chat(opt):
    """Run a series of self-chat episodes and save the resulting logs.

    The conversation is between the model from ``opt`` and either a clone of
    it or, when ``opt['partner_model_file']`` is set, a separately loaded
    partner model whose opts may be overridden from the JSON file at
    ``opt['partner_opt_file']``.

    :param opt: parsed options dict.
    :return: the logs collected by the ``WorldLogger``.
    """
    random.seed(opt['seed'])
    partner_model_file = opt['partner_model_file']
    override_path = opt.get('partner_opt_file')

    first = create_agent(opt, requireModelExists=True)
    if partner_model_file is None:
        # Same model on both sides of the conversation.
        second = first.clone()
    else:
        # A different model plays the partner; start from optional overrides.
        partner_opt = {}
        if override_path:
            print(f"WARNING: Loading override opts from: {override_path}")
            with open(override_path) as fh:
                partner_opt = json.load(fh)
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        second = create_agent_from_model_file(partner_model_file, partner_opt)

    # Suffix the ids so the two speakers can be told apart in the logs.
    first.id += "_1"
    second.id += "_2"
    model_id = f"{first.id}_{second.id}"

    world = create_task(opt, user_agents=[first, second])

    logger = WorldLogger(opt)
    timer = TimeLogger()

    total = opt['num_self_chats']
    for finished in range(1, total + 1):
        _run_self_chat_episode(opt, world, logger)
        text, _ = timer.log(finished, total, world.report())
        logging.info(text)

    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    # A self-chat world may provide its own writer (e.g. to include extra
    # context such as personas); otherwise fall back to the logger's writer.
    use_world_writer = opt['save_format'] == 'conversations' and hasattr(world, 'write')
    if use_world_writer:
        world.write(logger, outfile)
    else:
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
Пример #2
0
def setup_args(parser=None):
    """Return a parser configured for generating model self-chats.

    :param parser: optional existing ``ParlaiParser`` to extend; a new one is
        created when ``None``.
    :return: the parser with self-chat, output-format and partner-model
        options registered, and defaults ``interactive_mode=True`` and
        ``task='self_chat'``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Generate self-chats of a model')
    add = parser.add_argument

    add('--seed', type=int, default=42)
    add('-d', '--display-examples', type='bool', default=True)
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )

    # Self-chat episode control.
    add(
        '-st',
        '--selfchat-task',
        type='bool',
        default=True,
        help='Create a self chat version of the task',
    )
    add(
        '--num-self-chats',
        type=int,
        default=1,
        help='Number of self chats to run',
    )
    add(
        '--selfchat-max-turns',
        type=int,
        default=6,
        help='The number of dialogue turns before self chat ends',
    )
    add(
        '--seed-messages-from-task',
        action='store_true',
        help='Automatically seed conversation with messages from task dataset.',
    )

    # Output location and format.
    add('--outfile', type=str, default=None, help='File to save self chat logs')
    add(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
        help=(
            'Format to save logs in. conversations is a jsonl format, '
            'parlai is a text format.'
        ),
    )

    # Optional second model to chat with.
    add(
        '-pmf',
        '--partner-model-file',
        default=None,
        help='Define a different partner for self chat',
    )
    add(
        '--partner-opt-file',
        default=None,
        help='Path to file containing opts to override for partner',
    )

    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
Пример #3
0
def setup_args(parser=None):
    """Return a parser for self-chat runs, including connection settings.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with display, episode, connection (``--mongo-host``,
        ``--mongo-port``, user, password, collection), two model-file and
        output options, plus defaults ``interactive_mode=True`` /
        ``task='self_chat'``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Self chat with a model')
    arg = parser.add_argument

    arg('--seed', type=int, default=42)
    arg('-d', '--display-examples', type='bool', default=True)
    arg('-n', '-ne', '--num-examples', type=int, default=10)
    arg('-nd', '--num-dialogues', type=int, default=10)
    arg('-ltim', '--log-every-n-secs', type=float, default=2)

    # Connection settings (presumably MongoDB, given --mongo-host/--mongo-port);
    # none of them has a default.
    for short_flag, long_flag, value_type in (
        ('-host', '--mongo-host', str),
        ('-port', '--mongo-port', int),
        ('-user', '--user-name', str),
        ('-pw', '--password', str),
        ('-col', '--collection-name', str),
    ):
        arg(short_flag, long_flag, type=value_type)

    # Two model-file options sharing the same help text.
    for short_flag, long_flag in (('-mf1', '--model-file1'), ('-mf2', '--model-file2')):
        arg(
            short_flag,
            long_flag,
            default=None,
            help='model file name for loading and saving models',
        )

    arg(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    arg(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    arg(
        '--selfchat-max-turns',
        type=int,
        default=10,
        help='The number of dialogue turns before self chat ends.',
    )
    arg(
        '--seed-messages-from-task',
        action='store_true',
        help='Automatically seed conversation with messages from task dataset.',
    )

    # Output location and format.
    arg('--outfile', type=str, default='/tmp/selfchat.json')
    arg('--format', type=str, default='json', choices={'parlai', 'json'})

    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
Пример #4
0
def setup_args(parser=None):
    """Return a minimal parser for chatting with a model on the command line.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with ``--display-examples`` (off by default), the
        ``LocalHumanAgent`` and ``WorldLogger`` options, and defaults
        ``interactive_mode=True`` / ``task='interactive'``.
    """
    if parser is None:
        description = 'Interactive chat with a model on the command line'
        parser = ParlaiParser(True, True, description)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_defaults(interactive_mode=True, task='interactive')
    for component in (LocalHumanAgent, WorldLogger):
        component.add_cmdline_args(parser)
    return parser
Пример #5
0
def setup_args(parser=None):
    """Return a parser for self-chat runs with connection settings.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with display, connection (host, port, user,
        password), self-chat episode and output options, plus defaults
        ``interactive_mode=True`` / ``task='self_chat'``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Self chat with a model')
    add = parser.add_argument

    add('--seed', type=int, default=42)
    add('-d', '--display-examples', type='bool', default=True)
    add('-n', '-ne', '--num-examples', type=int, default=10)
    add('-nd', '--num-dialogues', type=int, default=10)
    add('-ltim', '--log-every-n-secs', type=float, default=2)

    # Connection settings (presumably MongoDB); no defaults.
    for short_flag, long_flag, value_type in (
        ('-host', '--mongo-host', str),
        ('-port', '--mongo-port', int),
        ('-user', '--user-name', str),
        ('-pw', '--password', str),
    ):
        add(short_flag, long_flag, type=value_type)

    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )

    # Self-chat episode control.
    add(
        '-st',
        '--selfchat-task',
        type='bool',
        default=True,
        help='Create a self chat version of the task',
    )
    add(
        '--num-self-chats',
        type=int,
        default=1,
        help='Number of self chats to run',
    )
    add(
        '--selfchat-max-turns',
        type=int,
        default=6,
        help='The number of dialogue turns before self chat ends',
    )
    add(
        '--seed-messages-from-task',
        action='store_true',
        help='Automatically seed conversation with messages from task dataset.',
    )

    # Output location and format.
    add('--outfile', type=str, default=None, help='File to save self chat logs')
    add(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
        help=(
            'Format to save logs in. conversations is a jsonl format, '
            'parlai is a text format.'
        ),
    )

    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
Пример #6
0
def self_chat(opt, print_parser=None):
    """Let a model chat with a clone of itself for ``opt['num_examples']`` examples.

    :param opt: parsed options; passing a ``ParlaiParser`` is deprecated but
        still accepted (it is parsed here after a warning is printed).
    :param print_parser: ``True`` to print arguments via ``opt`` when it is a
        parser, a parser object to print with, or ``False``/``None`` to print
        nothing.
    :return: the logs collected by the ``WorldLogger``; they are also written
        to ``opt['outfile']`` in ``opt['format']``.
    """
    # Normalise print_parser: True -> the parser passed as opt, False -> None.
    if print_parser is True and isinstance(opt, ParlaiParser):
        print_parser = opt
    elif print_parser is False:
        print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: self_chat should be passed opt not Parser ]')
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Both speakers share the same underlying model.
    speaker = create_agent(opt, requireModelExists=True)
    echo = speaker.clone()
    if hasattr(echo, 'id'):
        echo.id += "2"

    task_name = opt['task']
    if 'selfchat' not in task_name:
        warn_once(
            'You are using self chat with task {}. '.format(task_name)
            + 'If your task has an existing self chat world, then run with '
            '-t {}:selfchat'.format(task_name)
        )

    world = create_task(opt, [speaker, echo])

    if print_parser:
        # Print the arguments as they stand after the model has been loaded.
        print_parser.opt = speaker.opt
        print_parser.print_args()

    interval = opt.get('log_every_n_secs', -1)
    if interval <= 0:
        interval = float('inf')  # non-positive disables progress logging
    timer = TimeLogger()
    logger = WorldLogger(opt)

    limit = opt['num_examples']
    done = 0
    while done < limit:
        done += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print(world.display())
        if timer.time() > interval:
            print(timer.log(done, limit))

    if opt.get('display_examples'):
        print('-- end of episode --')

    logger.reset_world()  # flush last episode
    logger.write(opt['outfile'], opt['format'])
    return logger.get_logs()
Пример #7
0
def setup_args(parser=None):
    """Build the argument parser for evaluating a model.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with PyTorch data-teacher args, report/world-log
        options, example limits, metric selection, ``WorldLogger`` and
        ``TensorboardLogger`` options, and ``datatype='valid'`` as default.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-rp',
                        '--report',
                        type=str,
                        default="/tmp/eval_model.json")
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Must also specify --report-filename.',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the '
        'number of examples. If false, averages over the '
        'number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        # The fragments below are implicitly concatenated; each one ends with
        # the space/punctuation needed so the rendered help reads correctly.
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    WorldLogger.add_cmdline_args(parser)
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
Пример #8
0
def setup_args(parser=None):
    """Build the argument parser for evaluating a model.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with report/world-log output options, example limits,
        metric selection, micro-averaging, ``WorldLogger`` and
        ``TensorboardLogger`` options, and ``datatype='valid'`` set via
        ``set_params``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        # Implicitly concatenated fragments: keep the trailing space.
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
Пример #9
0
def interactive(opt, print_parser=None):
    """Chat with a model from the command line until the task's epoch ends.

    :param opt: parsed options; a ``ParlaiParser`` is also accepted (an error
        is logged) and parsed here.
    :param print_parser: ``True`` to print arguments via ``opt`` when it is a
        parser, a parser object to print with, or ``False``/``None`` to skip.

    When ``opt['save_world_logs']`` is set, every turn is logged and, once the
    chat is over, the logs are written to
    ``<report_filename without extension>_<task>_replies.jsonl`` in
    ``opt['save_format']``.
    """
    import os

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    human_agent = LocalHumanAgent(opt)
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None
    world = create_task(opt, [human_agent, agent])

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        # Dump world acts once, after the chat ends. Previously this ran inside
        # the loop, resetting (and so splitting) the episode after every turn.
        world_logger.reset()  # add final acts to logs
        # splitext only strips the extension; split('.')[0] truncated at the
        # first dot anywhere in the path (e.g. './x.json' -> '').
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'
        world_logger.write(outfile, world, file_format=opt['save_format'])
Пример #10
0
def interactive(opt: 'zoo:tutorial_transformer_generator/model'):
    """Run an interactive chat between a local human and a model.

    :param opt: parsed options; a ``ParlaiParser`` is also accepted (an error
        is logged) and parsed here.

    When ``opt['outfile']`` is set, turns are logged and written to that file
    in ``opt['save_format']`` after the loop ends.
    """
    # NOTE(review): the annotation on ``opt`` is a plain string, not a type;
    # it is kept unchanged for compatibility.
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=False)
    agent.opt.log()
    human_agent = LocalHumanAgent(opt)
    # set up world logger
    world_logger = WorldLogger(opt) if opt.get('outfile') else None
    world = create_task(opt, [human_agent, agent])

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        # Debug output removed: an unconditional "done by me!" print and a
        # second, unconditional world.display() duplicated the
        # display_examples output below.
        if world.epoch_done() or world.get_total_parleys() <= 0:
            # chat was reset with [DONE], [EXIT] or EOF
            if world_logger is not None:
                world_logger.reset()
            continue

        if world_logger is not None:
            world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        # dump world acts to file
        world_logger.write(opt['outfile'],
                           world,
                           file_format=opt['save_format'])
Пример #11
0
def setup_args(parser=None):
    """Return a parser for chatting interactively with a model.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with display, report-file, world-log and logging
        interval options, the ``LocalHumanAgent`` and ``WorldLogger``
        options, and defaults ``interactive_mode=True`` /
        ``task='interactive'``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Interactive chat with a model')
    add = parser.add_argument

    # How examples are shown.
    add('-d', '--display-examples', type='bool', default=False)
    add(
        '--display-prettify',
        type='bool',
        default=False,
        help='Set to use a prettytable when displaying '
        'examples with text candidates',
    )
    add(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    add(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )

    # Report and world-log output.
    add(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    add(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
        'model replies. Must also specify --report-filename.',
    )
    add(
        '--world-logs-format',
        type=str,
        default='parlai',
        choices=['jsonl', 'parlai', 'forever'],
        help='File format to save chat logs. (default parlai)',
    )
    add('-ltim', '--log-every-n-secs', type=float, default=2)

    parser.set_defaults(interactive_mode=True, task='interactive')
    for component in (LocalHumanAgent, WorldLogger):
        component.add_cmdline_args(parser)
    return parser
Пример #12
0
 def setup_args(parser=None):
     """Return a parser for chatting with a model on the command line.

     Registers display options, ``--interactive-task``, and ``--outfile`` /
     ``--save-format`` for saving the chat, then the ``LocalHumanAgent`` and
     ``WorldLogger`` options. Defaults ``interactive_mode=True`` and
     ``task='interactive'``.
     """
     if parser is None:
         parser = ParlaiParser(
             True, True, 'Interactive chat with a model on the command line'
         )
     add = parser.add_argument

     # How examples are shown.
     add('-d', '--display-examples', type='bool', default=False)
     add(
         '--display-prettify',
         type='bool',
         default=False,
         help='Set to use a prettytable when displaying '
         'examples with text candidates',
     )
     add(
         '--display-ignore-fields',
         type=str,
         default='label_candidates,text_candidates',
         help='Do not display these fields',
     )
     add(
         '-it',
         '--interactive-task',
         type='bool',
         default=True,
         help='Create interactive version of task',
     )

     # Where and how the chat is saved ('' disables saving).
     add(
         '--outfile',
         type=str,
         default='',
         help='Saves a jsonl file containing all of the task examples and '
         'model replies. Set to the empty string to not save at all',
     )
     add(
         '--save-format',
         type=str,
         default='parlai',
         choices=['conversations', 'parlai'],
         help=(
             'Format to save logs in. conversations is a jsonl format, '
             'parlai is a text format.'
         ),
     )

     parser.set_defaults(interactive_mode=True, task='interactive')
     for component in (LocalHumanAgent, WorldLogger):
         component.add_cmdline_args(parser)
     return parser
Пример #13
0
def create_world(opt):
    """Build a ``MultiClientInteractiveWorld`` pairing a human with a model.

    :param opt: parsed options; a ``ParlaiParser`` is also accepted (an error
        is logged) and parsed here.
    :return: the constructed world. No dialogue is run here; presumably the
        caller drives ``world.parley()``.
    """
    if isinstance(opt, ParlaiParser):
        # Message names this function (it previously said "interactive").
        logging.error('create_world should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    agent.opt.log()
    human_agent = HumanAgent(opt)

    # NOTE(review): world_logger is built when opt['outfile'] is set but is
    # neither attached to the world nor returned, so nothing is logged from
    # here. Kept so construction side effects (if any) are unchanged.
    world_logger = WorldLogger(opt) if opt.get('outfile') else None  # noqa: F841
    world = MultiClientInteractiveWorld(opt, [human_agent, agent])

    return world
Пример #14
0
def _eval_single_world(opt, agent, task):
    """Set up evaluation of ``agent`` on a single ``task``.

    NOTE(review): this definition appears truncated. The loop below only
    advances ``cnt``; it never calls ``world.parley()``, never logs, and the
    function returns nothing. When ``opt['num_examples'] <= 0`` (``max_cnt``
    is infinity), termination depends entirely on ``world.epoch_done()``,
    which nothing in this body advances.
    """
    print(
        '[ Evaluating task {} using datatype {}. ] '.format(
            task, opt.get('datatype', 'N/A')
        )
    )
    # set up world logger (only when opt['save_world_logs'] is truthy)
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging; a non-positive interval disables periodic logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate (non-positive means no limit)
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    while not world.epoch_done() and cnt < max_cnt:
        # each iteration is counted as one batch of opt['batchsize'] examples
        cnt += opt.get('batchsize', 1)
Пример #15
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on a single ``task`` and return the world report.

    Runs ``world.parley()`` until the epoch is done or ``opt['num_examples']``
    examples (when positive) have been processed, printing progress every
    ``opt['log_every_n_secs']`` seconds (disabled when non-positive).

    If ``opt['save_world_logs']`` is set, logged acts are written in ParlAI
    text format to ``<report_filename without extension>_<task>_replies.jsonl``.

    :return: the final ``world.report()``.
    """
    import os

    print('[ Evaluating task {} using datatype {}. ] '.format(
        task, opt.get('datatype', 'N/A')))
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    # Evaluate only the requested task. This assignment had been commented
    # out, which made the world ignore ``task`` and use opt['task'] instead.
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')

        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        world.num_examples(), report)
            print(text)

    report = world.report()
    print("Printing Report")
    print(report)
    world.reset()

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        # splitext strips only the extension; split('.')[0] cut the path at
        # its first dot (e.g. './eval.json' -> '').
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        print("filename: ", base_outfile)
        outfile = base_outfile + f'_{task}_replies.jsonl'
        # NOTE(review): the file is named *.jsonl but written in ParlAI text
        # format; confirm which format is intended.
        world_logger.write_parlai_format(outfile)

    return report
Пример #16
0
def setup_args(parser=None):
    """Return a parser for self-chat runs saved as ``jsonl`` or ``parlai`` logs.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with display, episode-length and output options
        (``--outfile``, ``--format``, ``--indent``), plus defaults
        ``interactive_mode=True`` / ``task='self_chat'``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Self chat with a model')
    arg = parser.add_argument

    arg('--seed', type=int, default=42)
    arg('-d', '--display-examples', type='bool', default=True)
    arg('-n', '-ne', '--num-examples', type=int, default=10)
    arg('-ltim', '--log-every-n-secs', type=float, default=60)
    arg(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    arg(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    arg(
        '--selfchat-max-turns',
        type=int,
        default=10,
        help='The number of dialogue turns before self chat ends.',
    )
    arg(
        '--seed-messages-from-task',
        action='store_true',
        help='Automatically seed conversation with messages from task dataset.',
    )

    # Where and how the chats are written.
    arg('--outfile', type=str, default='/tmp/selfchat.json')
    arg('--format', type=str, default='jsonl', choices={'parlai', 'jsonl'})
    arg('--indent', type=int, default=4, help='how much to indent jsonl string')

    parser.set_defaults(interactive_mode=True, task='self_chat')
    WorldLogger.add_cmdline_args(parser)
    return parser
Пример #17
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on one ``task`` and return the aggregated report.

    Runs ``world.parley()`` until the epoch is done or ``opt['num_examples']``
    examples (when positive) have been processed. If ``opt['world_logs']`` is
    set, logged acts are written there in ``opt['save_format']``; in
    distributed mode each rank writes its own file, suffixed ``_<rank>``
    before the extension.

    :return: the report gathered from all workers with ``all_gather_list``
        and combined by ``aggregate_unnamed_reports``.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    world_logger = None
    if opt['world_logs']:
        world_logger = WorldLogger(opt)

    task_opt = opt.copy()  # edited below, so leave the caller's opt alone
    task_opt['task'] = task
    world = create_task(task_opt, agent)

    interval = opt.get('log_every_n_secs', -1)
    if interval <= 0:
        interval = float('inf')  # non-positive disables progress logging
    timer = TimeLogger()

    # A non-positive num_examples means "no limit".
    if opt['num_examples'] > 0:
        limit = opt['num_examples']
    else:
        limit = float('inf')
    seen = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and seen < limit:
        seen += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            print(world.display() + '\n~~')
        if timer.time() > interval:
            progress = world.report()
            text, _ = timer.log(progress.get('exs', 0),
                                min(limit, total_cnt), progress)
            logging.info(text)

    if world_logger is not None:
        world_logger.reset()  # add final acts to logs
        outfile = opt['world_logs']
        if is_distributed():
            stem, extension = os.path.splitext(outfile)
            outfile = f'{stem}_{get_rank()}{extension}'
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()
    return report
Пример #18
0
def setup_args(parser=None):
    """Return a parser for displaying data from a task.

    :param parser: optional ``ParlaiParser`` to extend; created when ``None``.
    :return: the parser with display options, ``datatype='train:stream'`` as
        default, connection settings (host, port, user, password,
        collection) and the ``WorldLogger`` options.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Display data from a task')
    add = parser.add_argument

    add('--seed', type=int, default=42)
    add('-n', '-ne', '--num-examples', type=int, default=10)
    add('-ns', '--num-stored', type=int, default=10)
    add('-mdl', '--max-display-len', type=int, default=1000)
    add('--display-ignore-fields', type=str, default='agent_reply')
    parser.set_defaults(datatype='train:stream')

    # Connection settings (presumably MongoDB, given --mongo-host/--mongo-port).
    for short_flag, long_flag, value_type in (
        ('-host', '--mongo-host', str),
        ('-port', '--mongo-port', int),
        ('-user', '--user-name', str),
        ('-pw', '--password', str),
        ('-col', '--collection-name', str),
    ):
        add(short_flag, long_flag, type=value_type)

    WorldLogger.add_cmdline_args(parser)
    return parser
Пример #19
0
def self_chat(opt):
    """Run ``opt['num_self_chats']`` self-chat episodes with a cloned model.

    :param opt: parsed options dict.
    :return: the logs collected by the ``WorldLogger``. They are also saved to
        ``opt['outfile']``, or to ``/tmp/<model id>_selfchat`` when that is
        ``None``.
    """
    random.seed(opt['seed'])

    primary = create_agent(opt, requireModelExists=True)
    mirror = primary.clone()

    # Both ids derive from the original model id so the speakers are distinct.
    model_id = primary.id
    primary.id, mirror.id = model_id + "_1", model_id + "_2"

    world = create_task(opt, user_agents=[primary, mirror])

    logger = WorldLogger(opt)
    timer = TimeLogger()

    total = opt['num_self_chats']
    for finished in range(1, total + 1):
        _run_self_chat_episode(opt, world, logger)
        text, _ = timer.log(finished, total, world.report())
        logging.info(text)

    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    # A self-chat world may provide its own writer (e.g. to include extra
    # context such as personas); otherwise use the logger's writer.
    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        world.write(logger, outfile)
    else:
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
Пример #20
0
def interactive(opt, print_parser=None):
    """Chat with a model until the epoch ends or the user presses Ctrl-C.

    :param opt: parsed options; passing a ``ParlaiParser`` is deprecated but
        still accepted (it is parsed here).
    :param print_parser: ``True`` to print arguments via ``opt`` when it is a
        parser, a parser object to print with, or ``False``/``None`` to skip.

    When ``opt['save_world_logs']`` is set, turns are logged and written to
    ``<report_filename without extension>_interactive_replies.json`` in
    ``opt['world_logs_format']``. This happens both when the epoch finishes
    and on KeyboardInterrupt; the latter then exits the process.
    """
    import os
    import sys

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    human_agent = LocalHumanAgent(opt)
    world = create_task(opt, [human_agent, agent])
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    def _dump_world_logs():
        # Flush and write all logged acts; shared by both exit paths.
        print("\nWriting out world log.")
        # The report value is unused; the call is kept from the original in
        # case report() has side effects before reset().
        world.report()
        world.reset()
        world_logger.reset()  # add final acts to logs
        # splitext strips only the extension; split('.')[0] cut the path at
        # its first dot (e.g. './x.json' -> '').
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = base_outfile + '_interactive_replies.json'
        world_logger.write(outfile, file_format=opt['world_logs_format'])

    # Show some example dialogs:
    while True:
        try:
            world.parley()
            if world_logger is not None:
                world_logger.log(world)
            if opt.get('display_examples'):
                print("---")
                print(world.display())
            if world.epoch_done():
                print("EPOCH DONE")
                break
        except KeyboardInterrupt:
            if world_logger is not None:
                _dump_world_logs()
            # sys.exit is the programmatic exit; quit() is an interactive
            # helper injected by the site module and may be absent.
            sys.exit()

    # Previously logs were only saved on Ctrl-C and lost on a normal finish.
    if world_logger is not None:
        _dump_world_logs()
Пример #21
0
def self_chat(opt, print_parser=None):
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: self_chat should be passed opt not Parser ]')
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
Пример #22
0
def interactive(opt, print_parser=None):
    """
    Replay scripted queries against a model and log the replies.

    An ``AutoQueryAgent`` is built with ``query_txt=<report_filename>.txt``.
    One ``world.parley()`` runs per newline character in that file. After
    every turn, the accumulated world log is rewritten, in ``'text'`` format,
    to ``<report_filename without extension>_<task>_replies.jsonl``. This
    keeps partial progress on disk if the run is interrupted.

    :param opt:
        Parsed options. Passing a ``ParlaiParser`` is deprecated; it is parsed
        here for backwards compatibility.
    :param print_parser:
        ``True`` to print the arguments of ``opt`` (when it is a parser), a
        parser instance to print with, or ``False``/``None`` to skip printing.
    """
    import os

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    query_file = opt['report_filename'] + ".txt"
    # Scripted queries replace the LocalHumanAgent used by plain interactive.
    human_agent = AutoQueryAgent(opt, query_txt=query_file)
    world_logger = WorldLogger(opt)
    world = create_task(opt, [human_agent, agent])

    # One parley per newline in the query file. This equals the old
    # len(text.split("\n")) - 1, so a final line without a trailing newline
    # is still not counted.
    with open(query_file, "r") as f:
        num_queries = f.read().count("\n")

    # Loop-invariant output path. splitext (rather than split('.')) keeps
    # dotted directories and a leading './' from truncating the base name.
    base_outfile = os.path.splitext(opt['report_filename'])[0]
    outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'

    for _ in range(num_queries):
        world.parley()
        world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())
        # Flush buffered acts into the log, then rewrite the file so it is
        # always current.
        world_logger.reset()
        world_logger.write(outfile, world, file_format='text')
Пример #23
0
    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report
Пример #24
0
    def generate_data(self):
        """
        Generate the LM Data.

        Seeds ``random`` with 42 and loads the model from ``self.opt``.
        ``self.opt['task']`` must name exactly one task. The method then:

        - builds ``self.task_opt`` (a copy of ``self.opt`` with that task and
          the ``flatten`` mutator appended);
        - creates ``self.world_logger`` and ``self.world``;
        - calls ``self.run_generation()`` followed by ``self.log()``;
        - finally resets the world.
        """
        random.seed(42)
        # load model and possibly print opt
        model = create_agent(self.opt, requireModelExists=True)
        model.opt.log()

        task_list = self.opt['task'].split(',')
        assert len(task_list) == 1
        task = task_list[0]
        logging.info(
            f'Generating data for task {task} using datatype {self.opt.get("datatype")}.'
        )
        logging.warning('Appending `flatten` to mutators.')

        # work on a copy, because the task and mutators are edited below
        task_opt = self.opt.copy()
        self.task_opt = task_opt
        task_opt['task'] = task
        if task_opt['mutators']:
            task_opt['mutators'] += '+flatten'
        else:
            task_opt['mutators'] = 'flatten'
        # derive the world-log path for this (single) task
        task_opt['world_logs'] = get_task_world_logs(
            task, task_opt['world_logs'], is_multitask=False
        )

        self.world_logger = WorldLogger(task_opt)
        self.world = create_task(task_opt, model)  # create worlds for tasks

        self.run_generation()
        # dump world acts to file
        self.log()
        self.world.reset()
Пример #25
0
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument(
        '-et',
        '--evaltask',
        help='task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--final-extra-opt',
        type=str,
        default='',
        help="A '.opt' file that is used for final eval. Useful for setting skip-generation to false. 'datatype' must be included as part of the opt.",
    )
    train.add_argument(
        '--eval-batchsize',
        type=int,
        hidden=True,
        help='Eval time batch size (defaults to same as -bs)',
    )
    train.add_argument(
        '--eval-dynamic-batching',  # FIXME: see https://github.com/facebookresearch/ParlAI/issues/3367
        default=None,
        type='nonestr',
        choices={None, 'off', 'full', 'batchsort'},
        help=(
            'Set dynamic batching at evaluation time. Set to off for '
            'train-only dynamic batching. Set to none (default) to use same '
            'setting as --dynamic-batching.'
        ),
    )
    train.add_argument(
        '--num-workers',
        default=0,
        type=int,
        help='Number of background workers (training only)',
    )
    train.add_argument('--display-examples', type='bool', default=False, hidden=True)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument(
        '-tstep',
        '--max-train-steps',
        '--max-lr-steps',
        type=int,
        default=-1,
        help='End training after n model updates',
    )
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
    train.add_argument(
        '-lstep',
        '--log-every-n-steps',
        type=int,
        default=50,
        help='Log every n training steps',
    )
    train.add_argument(
        '-vtim',
        '--validation-every-n-secs',
        type=float,
        default=-1,
        help='Validate every n seconds. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vstep',
        '--validation-every-n-steps',
        type=int,
        default=-1,
        help='Validate every n training steps. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-stim',
        '--save-every-n-secs',
        type=float,
        default=-1,
        help='Saves the model to model_file.checkpoint after '
        'every n seconds (default -1, never).',
    )
    train.add_argument(
        '-sval',
        '--save-after-valid',
        type='bool',
        default=False,
        help='Saves the model to model_file.checkpoint after '
        'every validation (default %(default)s).',
    )
    train.add_argument(
        '-veps',
        '--validation-every-n-epochs',
        type=float,
        default=-1,
        help='Validate every n epochs. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vme',
        '--validation-max-exs',
        type=int,
        default=-1,
        hidden=True,
        help='max examples to use during validation (default -1 uses all)',
    )
    train.add_argument(
        '--short-final-eval',
        default=False,
        hidden=True,
        type='bool',
        help='If true, obeys --validation-max-exs in the final '
        'validation and test evaluations.',
    )
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=10,
        help=(
            'number of iterations of validation where result'
            ' does not improve before we stop training'
        ),
    )
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation',
    )
    train.add_argument(
        '-vmm',
        '--validation-metric-mode',
        type=str,
        choices=['max', 'min'],
        help='the direction in which to optimize the validation metric, i.e. maximize or minimize',
    )
    train.add_argument(
        '-vcut',
        '--validation-cutoff',
        type=float,
        default=1.0,
        hidden=True,
        help='value at which training will stop if exceeded by metric',
    )
    train.add_argument(
        '-lfc',
        '--load-from-checkpoint',
        type='bool',
        default=True,
        hidden=True,
        help='load model from checkpoint if available',
    )
    train.add_argument(
        '-vshare',
        '--validation-share-agent',
        # Without a bool type, "-vshare False" yields the truthy string "False".
        type='bool',
        default=False,
        hidden=True,
        help='use a shared copy of the agent for validation. '
        'this will eventually default to True, but '
        'currently defaults to False.',
    )
    train.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    train.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    train.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    train.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    WandbLogger.add_cmdline_args(parser, partial_opt=None)

    parser = setup_dict_args(parser)
    return parser
Пример #26
0
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default,'
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu'
        'the rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser)
>>>>>>> b49eba4519856f6ab83b869b168c6af99863df47
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_params(datatype='valid')
    return parser


def _save_eval_stats(opt, report):
    report_fname = opt['report_filename']
    if report_fname == '':
        return
    if report_fname.startswith('.'):
        report_fname = opt['model_file'] + report_fname

    json_serializable_report = report
    for k, v in report.items():
Пример #27
0
def self_chat(opt):
    """
    Run self-chat episodes between the main agent and a partner model.

    A partner model (``opt['partner_model_file']``) is mandatory. Its options
    are loaded from ``<partner>.opt``; ``opt['partner_opt_file']``, if given,
    must name that same file. If ``opt['model_file']`` is ``human`` or starts
    with ``human:``, the first agent is a ``MyLocalHumanAgent`` instead of a
    model.

    A negative ``opt['num_self_chats']`` is replaced, in ``opt``, by
    ``len(world.messages)``.

    Output goes to ``opt['outfile']`` or, when that is None, to
    ``/tmp/<model_id>_selfchat``:

    - With ``save_format == 'conversations'`` and a world that has a
      ``write`` method, the world's writer is used. It also runs after every
      episode as a checkpoint.
    - Otherwise ``WorldLogger.write`` is used once at the end.

    :return: ``logger.get_logs()``
    """
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    assert partner is not None
    expected_opt_file = partner + '.opt'
    partner_opt_file = opt.get('partner_opt_file')
    if partner_opt_file:
        assert partner_opt_file == expected_opt_file, (
            f'partner_opt_file ({partner_opt_file}) is not {expected_opt_file};'
            ' if you think that is safe, you can remove this assert'
        )
    else:
        partner_opt_file = expected_opt_file

    # Create agents
    if opt['model_file'].split(':')[0] == 'human':
        agent1 = MyLocalHumanAgent(opt)
    else:
        agent1 = create_agent(opt, requireModelExists=True)

    # Self chat with different models. partner_opt_file is always set here.
    print(f"WARNING: Loading override opts from: {partner_opt_file}")
    with open(partner_opt_file) as f:
        partner_opt = json.load(f)
    partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
    print(
        f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
    )
    agent2 = create_agent_from_model_file(partner, partner_opt)

    # Set IDs
    agent1.id = agent1.id + '_1'
    agent2.id = agent2.id + '_2'

    model_id = agent1.id + '_' + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Resolve the output path up front, so per-episode checkpoints and the
    # final save use the same file even when opt['outfile'] is None.
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']
    # Use the self-chat world's own writer (it may log extra context such as
    # personas) only when conversations format is requested and it exists.
    use_world_writer = (
        opt['save_format'] == 'conversations' and hasattr(world, 'write')
    )

    # Run some self chats.
    all_report = []
    if opt['num_self_chats'] < 0:
        opt['num_self_chats'] = len(world.messages)

    for i in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(i + 1, opt['num_self_chats'], report)
        logging.info(text)
        all_report.append(report)
        if use_world_writer:
            # checkpoint after each episode so an interrupted run keeps progress
            world.write(logger, all_report, outfile)

    # Save chats
    if use_world_writer:
        world.write(logger, all_report, outfile)
    else:
        # use default logger write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
Пример #28
0
def _eval_single_world(opt, agent, task):
    """
    Evaluate ``agent`` on a single task and return the aggregated report.

    World logs are written only when ``opt['world_logs']`` is non-empty. In
    distributed runs each rank writes its own file, with ``_<rank>`` inserted
    before the extension.

    :param opt: options dict; a copy is edited, the original is not
    :param agent: the agent to evaluate
    :param task: one task name taken from ``opt['task']``
    :return: ``aggregate_unnamed_reports(all_gather_list(world.report()))``.
        When the classifier agent has ``calc_auc`` enabled, ``AUC_<class>``
        entries are added.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')

    # edit a copy, since the task (and possibly the log path) changes
    task_opt = opt.copy()
    task_opt['task'] = task
    if opt['world_logs']:
        multitask = len(opt['task'].split(',')) > 1
        task_opt['world_logs'] = get_task_world_logs(
            task, task_opt['world_logs'], is_multitask=multitask)

    world_logger = None
    if task_opt['world_logs']:
        world_logger = WorldLogger(task_opt)

    world = create_task(task_opt, agent)  # create worlds for tasks

    # periodic progress logging; non-positive interval disables it
    raw_interval = opt.get('log_every_n_secs', -1)
    log_interval = float('inf') if raw_interval <= 0 else raw_interval
    log_time = TimeLogger()

    # example budget; non-positive num_examples means "everything"
    limit = float('inf') if opt['num_examples'] <= 0 else opt['num_examples']
    step = opt.get('batchsize', 1)
    processed = 0
    progress_total = min(limit, world.num_examples())

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and processed < limit:
        processed += step
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            print(world.display() + '\n~~')
        if log_time.time() > log_interval:
            snapshot = world.report()
            text, snapshot = log_time.log(
                snapshot.get('exs', 0), progress_total, snapshot)
            logging.info(text)

    if world_logger is not None:
        world_logger.reset()  # flush the in-progress episode into the log
        outfile = task_opt['world_logs']
        if is_distributed():
            stem, ext = os.path.splitext(outfile)
            outfile = f'{stem}_{get_rank()}{ext}'
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    agents = world.agents
    if isinstance(agents, list) and len(agents) > 1:
        clf = agents[CLASSIFIER_AGENT]
        if getattr(clf, 'calc_auc', False):
            for class_idx, auc in zip(clf.auc_class_indices, clf.aucs):
                report[f'AUC_{clf.class_list[class_idx]}'] = auc
            clf.reset_auc()
            # also reset on the evaluated agent, for safety
            agent.reset_auc()
    world.reset()
    return report
Пример #29
0
def setup_args(parser=None):
    """
    Build the evaluation parser, adding command line args if necessary.

    :param parser: preexisting ParlaiParser to extend; created if ``None``.
    :returns: the parser, with ``datatype`` defaulting to ``'valid'``.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument(
        '--area-under-curve-digits',
        '-auc',
        type=int,
        default=-1,
        help=
        'a positive number indicates to calculate the area under the roc curve and it also determines how many decimal digits of the predictions to keep (higher numbers->more precise); also used to determine whether or not to calculate the AUC metric',
    )
    parser.add_argument(
        '--area-under-curve-class',
        '-auclass',
        type=str,
        default=None,
        nargs='*',
        help='the name(s) of the class to calculate the auc for',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
Пример #30
0
def self_chat(opt, print_parser=None):
    client = MongoClient(
        opt['mongo_host'],
        opt['mongo_port'],
        username=opt['user_name'],
        password=opt['password'],
        #authSource=DATABASE_NAME
    )

    db = client[DATABASE_NAME]

    collection = db[opt['collection_name']]

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: self_chat should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models
    opt['model_file'] = opt['model_file1']
    if opt['model_file'] == 'tmp/convai2/lost_in_conversation/last_checkpoint':
        parser.set_defaults(
            model=
            'projects.convai2.baselines.transformer_chatbot.agent:TransformerAgent',
            sample=False,
            wild_mode=False,
            replace_repeat=False,
            replace_ngram=False,
            detokenize=False,
            emoji_prob=0,
            add_questions=0,
            clean_emoji=False,
            check_grammar=False,
            correct_generative=False,
            split_into_sentences=False,
            max_seq_len=256,
            beam_size=3,
            annealing_topk=None,
            annealing=0.6,
            length_penalty=0.6)
        opt = parser.parse_args()
        agent1 = create_agent(opt, requireModelExists=True)
    elif opt['model_file'] == 'tmp/convai2/huggingface/model':
        parser.set_params(
            model=
            'projects.convai2.baselines.huggingface.convai_evaluation:TransformerAgent'
        )
        opt = parser.parse_args()
        agent1 = create_agent(opt, requireModelExists=True)
    else:
        agent1 = create_agent(opt, requireModelExists=True)

    opt['model_file'] = opt['model_file2']
    if opt['model_file'] == 'tmp/convai2/lost_in_conversation/last_checkpoint':
        parser.set_defaults(
            model=
            'projects.convai2.baselines.transformer_chatbot.agent:TransformerAgent',
            sample=False,
            wild_mode=False,
            replace_repeat=False,
            replace_ngram=False,
            detokenize=False,
            emoji_prob=0,
            add_questions=0,
            clean_emoji=False,
            check_grammar=False,
            correct_generative=False,
            split_into_sentences=False,
            max_seq_len=256,
            beam_size=3,
            annealing_topk=None,
            annealing=0.6,
            length_penalty=0.6)
        opt = parser.parse_args()
        agent2 = create_agent(opt, requireModelExists=True)
    elif opt['model_file'] == 'tmp/convai2/huggingface/model':
        parser.set_params(
            model=
            'projects.convai2.baselines.huggingface.convai_evaluation:TransformerAgent'
        )
        opt = parser.parse_args()
        agent2 = create_agent(opt, requireModelExists=True)
    else:
        agent2 = create_agent(opt, requireModelExists=True)

    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    opt['random_order'] = False
    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_dial_cnt = opt['num_dialogues']
    #dial_cnt = 0
    world.max_turn_cnt = world.sample_episode_length()
    for dial_cnt in tqdm(range(max_dial_cnt)):
        #while dial_cnt < max_dial_cnt:
        #world.max_turn_cnt = world.sample_episode_length()
        #world.turn_cnt = 0
        #print('Dialogue Number: {}, Max Turn: {}\n'.format(dial_cnt, world.max_turn_cnt))
        while True:
            world.parley()
            logger.log(world)

            if world.episode_done():
                break
        #dial_cnt += 1
        if dial_cnt % 20 == 0:
            store_logger(opt, collection, logger)
            logger = WorldLogger(opt)
    store_logger(opt, collection, logger)