Esempio n. 1
0
def interactive(opt, print_parser=None):
    """Chat interactively between a local human and a model agent.

    Args:
        opt: parsed options. A ``ParlaiParser`` is still accepted (an error is
            logged) and parsed before use.
        print_parser: ``True`` to print the parser's arguments after the model
            is loaded (only possible when ``opt`` is a ``ParlaiParser``), a
            parser object to print, or ``False``/``None`` to print nothing.

    When ``opt['save_world_logs']`` is set, all logged acts are written once,
    to ``<report_filename stem>_<task>_replies.jsonl`` in format
    ``opt['save_format']``. This happens when the chat ends, even if it ends
    with an exception such as ``KeyboardInterrupt``.
    """
    import os

    # Only a real parser can print its args; a bare True/False becomes None.
    if print_parser is True:
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    human_agent = LocalHumanAgent(opt)
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None
    world = create_task(opt, [human_agent, agent])

    try:
        # Show some example dialogs:
        while not world.epoch_done():
            world.parley()
            if world_logger is not None:
                world_logger.log(world)
            if opt.get('display_examples'):
                print("---")
                print(world.display())
    finally:
        # Flush and write once at the end. Resetting after every parley would
        # split each exchange into its own episode and rewrite the file each turn.
        if world_logger is not None:
            world_logger.reset()  # add final acts to logs
            # splitext only strips the extension, unlike split('.') which
            # breaks on dots in directory names.
            base_outfile = os.path.splitext(opt['report_filename'])[0]
            outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'
            world_logger.write(outfile, world, file_format=opt['save_format'])
Esempio n. 2
0
def interactive(opt):
    """Chat interactively between a local human and a model agent.

    Args:
        opt: parsed options. A ``ParlaiParser`` is still accepted (an error is
            logged) and parsed before use.

    Chats reset via [DONE], [EXIT] or EOF close the current logged episode.
    When ``opt['outfile']`` is set, the logged acts are written there in
    format ``opt['save_format']`` once the epoch is done.
    """
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=False)
    agent.opt.log()
    human_agent = LocalHumanAgent(opt)
    # set up world logger
    world_logger = WorldLogger(opt) if opt.get('outfile') else None
    world = create_task(opt, [human_agent, agent])

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        if world.epoch_done() or world.get_total_parleys() <= 0:
            # chat was reset with [DONE], [EXIT] or EOF
            if world_logger is not None:
                world_logger.reset()
            continue

        if world_logger is not None:
            world_logger.log(world)
        # The world is only displayed on request; an unconditional display
        # here would duplicate this output.
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        # dump world acts to file
        world_logger.write(opt['outfile'],
                           world,
                           file_format=opt['save_format'])
Esempio n. 3
0
def self_chat(opt):
    """Run ``opt['num_self_chats']`` self-chat episodes and save them.

    The first agent is built from ``opt``. The second is either a clone of
    the first, or, when ``opt['partner_model_file']`` is given, a separate
    model. That model's opts may be overridden from the JSON file at
    ``opt['partner_opt_file']``, and its ``interactive_mode`` always mirrors
    ``opt``.

    Output goes to ``opt['outfile']`` (or ``/tmp/<model ids>_selfchat``). The
    world's own ``write`` is used for the ``conversations`` format when the
    world provides one; otherwise the logger's ``write`` is used.

    Returns:
        The logs accumulated by the ``WorldLogger``.
    """
    random.seed(opt['seed'])
    partner_model = opt['partner_model_file']
    override_path = opt.get('partner_opt_file')

    first = create_agent(opt, requireModelExists=True)
    if partner_model is not None:
        # A different model plays the other side of the conversation.
        overrides = {}
        if override_path:
            print(f"WARNING: Loading override opts from: {override_path}")
            with open(override_path) as handle:
                overrides = json.load(handle)
        overrides['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {overrides['interactive_mode']}"
        )
        second = create_agent_from_model_file(partner_model, overrides)
    else:
        # The model talks to a copy of itself.
        second = first.clone()

    first.id = first.id + "_1"
    second.id = second.id + "_2"
    model_id = f"{first.id}_{second.id}"

    world = create_task(opt, user_agents=[first, second])

    logger = WorldLogger(opt)
    timer = TimeLogger()

    for finished in range(1, opt['num_self_chats'] + 1):
        _run_self_chat_episode(opt, world, logger)
        text, _ = timer.log(finished, opt['num_self_chats'], world.report())
        logging.info(text)

    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    world_can_write = opt['save_format'] == 'conversations' and hasattr(world, 'write')
    if world_can_write:
        # The self-chat world may add extra context (e.g. personas) to its output.
        world.write(logger, outfile)
    else:
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
Esempio n. 4
0
def self_chat(opt, print_parser=None):
    """Run a model against a clone of itself and log the generated dialogues.

    Args:
        opt: parsed options. A ``ParlaiParser`` is still accepted (with a
            deprecation warning) and parsed before use.
        print_parser: ``True`` to print the parser's arguments after the model
            is loaded (only possible when ``opt`` is a ``ParlaiParser``), a
            parser object to print, or ``False``/``None`` to print nothing.

    Runs parleys until ``opt['num_examples']`` examples have been produced
    (counted in steps of ``opt['batchsize']``). The logs are then written to
    ``opt['outfile']`` in format ``opt['format']``.

    Returns:
        The logs collected by the ``WorldLogger``.
    """
    # Only a real parser can print its args; a bare True/False becomes None.
    if print_parser is True:
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: self_chat should be passed opt not Parser ]')
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    # Check for `selfchat` in the task name
    if 'selfchat' not in opt['task']:
        warn_once(
            'You are using self chat with task {}. '.format(opt['task'])
            + 'If your task has an existing self chat world, then run with '
            '-t {}:selfchat'.format(opt['task'])
        )

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print(world.display())
        if log_time.time() > log_every_n_secs:
            # Elsewhere TimeLogger.log is unpacked as (text, report); print
            # only the text rather than the whole tuple.
            logged = log_time.log(cnt, max_cnt)
            text = logged[0] if isinstance(logged, tuple) else logged
            print(text)

    if opt.get('display_examples'):
        print('-- end of episode --')

    logger.reset_world()  # flush last episode
    logger.write(opt['outfile'], opt['format'])
    return logger.get_logs()
Esempio n. 5
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on ``task`` and return a report gathered across workers.

    Evaluation stops when the world's epoch ends or when
    ``opt['num_examples']`` examples have been run (a non-positive value means
    no limit). Progress is logged every ``opt['log_every_n_secs']`` seconds,
    and never when that value is non-positive.

    When ``opt['world_logs']`` is set, world acts are written there. In
    distributed mode each rank writes ``<stem>_<rank><ext>``.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')

    world_logger = None
    if opt['world_logs']:
        world_logger = WorldLogger(opt)

    # Work on a copy: only the task differs from the caller's opt.
    task_opt = opt.copy()
    task_opt['task'] = task
    world = create_task(task_opt, agent)

    report_interval = opt.get('log_every_n_secs', -1)
    if report_interval <= 0:
        report_interval = float('inf')
    timer = TimeLogger()

    limit = float('inf')
    if opt['num_examples'] > 0:
        limit = opt['num_examples']
    seen = 0
    expected_total = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and seen < limit:
        seen += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            print(world.display() + '\n~~')
        if timer.time() > report_interval:
            progress = world.report()
            text, progress = timer.log(progress.get('exs', 0),
                                       min(limit, expected_total), progress)
            logging.info(text)

    if world_logger is not None:
        world_logger.reset()  # add the still-open episodes to the logs
        if is_distributed():
            rank = get_rank()
            stem, ext = os.path.splitext(opt['world_logs'])
            outfile = f'{stem}_{rank}{ext}'
        else:
            outfile = opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    gathered = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()
    return gathered
Esempio n. 6
0
def interactive(opt, print_parser=None):
    """Chat interactively between a local human and a model agent.

    Args:
        opt: parsed options. A ``ParlaiParser`` is still accepted (with a
            deprecation warning) and parsed before use.
        print_parser: ``True`` to print the parser's arguments after the model
            is loaded (only possible when ``opt`` is a ``ParlaiParser``), a
            parser object to print, or ``False``/``None`` to print nothing.

    The chat runs until the epoch is done or the user presses Ctrl-C. In both
    cases, when ``opt['save_world_logs']`` is set, the world log is written to
    ``<report_filename stem>_interactive_replies.json``. On Ctrl-C the process
    then exits.
    """
    import os
    import sys

    # Only a real parser can print its args; a bare True/False becomes None.
    if print_parser is True:
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    human_agent = LocalHumanAgent(opt)
    world = create_task(opt, [human_agent, agent])
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    interrupted = False
    try:
        # Show some example dialogs:
        while True:
            world.parley()
            if world_logger is not None:
                world_logger.log(world)
            if opt.get('display_examples'):
                print("---")
                print(world.display())
            if world.epoch_done():
                print("EPOCH DONE")
                break
    except KeyboardInterrupt:
        interrupted = True

    # Write the log on normal completion too, not only on Ctrl-C.
    if world_logger is not None:
        print("\nWriting out world log.")
        world.reset()

        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        # splitext only strips the extension, unlike split('.').
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = base_outfile + '_interactive_replies.json'
        world_logger.write(outfile,
                           file_format=opt['world_logs_format'])

    if interrupted:
        # sys.exit is always available, unlike the site-provided quit().
        sys.exit()
Esempio n. 7
0
def self_chat(opt, print_parser=None):
    """Run a model against a clone of itself and log the generated dialogues.

    Args:
        opt: parsed options. A ``ParlaiParser`` is still accepted (with a
            deprecation warning) and parsed before use.
        print_parser: ``True`` to print the parser's arguments after the model
            is loaded (only possible when ``opt`` is a ``ParlaiParser``), a
            parser object to print, or ``False``/``None`` to print nothing.

    Runs parleys until ``opt['num_examples']`` examples have been produced.
    The logs are then written to ``opt['outfile']`` in format ``opt['format']``.

    Returns:
        The logs collected by the ``WorldLogger`` (earlier versions of this
        function returned ``None``).
    """
    # Only a real parser can print its args; a bare True/False becomes None.
    if print_parser is True:
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: self_chat should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print("---")
            print(world.display())
        if log_time.time() > log_every_n_secs:
            # Elsewhere TimeLogger.log is unpacked as (text, report); print
            # only the text rather than the whole tuple.
            logged = log_time.log(cnt, max_cnt)
            text = logged[0] if isinstance(logged, tuple) else logged
            print(text)

    # Flush the last (possibly unfinished) episode so it is not dropped.
    logger.reset_world()
    logger.write(opt['outfile'], opt['format'])
    return logger.get_logs()
Esempio n. 8
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on a single ``task`` and return the world's report.

    Evaluation stops when the epoch ends or after ``opt['num_examples']``
    examples (a non-positive value means no limit). When
    ``opt['save_world_logs']`` is set, acts are written to
    ``<report_filename stem>_<task>_replies.jsonl`` in ``opt['save_format']``.
    """
    import os

    print('[ Evaluating task {} using datatype {}. ] '.format(
        task, opt.get('datatype', 'N/A')))
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    # Computed once rather than on every progress log.
    total_cnt = world.num_examples()

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        min(max_cnt, total_cnt),
                                        report)
            print(text)

    report = world.report()
    world.reset()

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        # splitext strips only the extension; split('.') would cut at the
        # first dot anywhere in the path (e.g. './out/report.json' -> '').
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = base_outfile + f'_{task}_replies.jsonl'
        world_logger.write(outfile, world, file_format=opt['save_format'])

    return report
Esempio n. 9
0
def interactive(opt, print_parser=None):
    """Feed canned queries to a model and log its replies.

    ``AutoQueryAgent`` reads queries from ``opt['report_filename'] + ".txt"``.
    One parley is run per newline character in that file. Each query/reply
    turn is closed as its own logged episode. All logs are written once, in
    ``'text'`` format, to ``<report_filename stem>_<task>_replies.jsonl``.

    Args:
        opt: parsed options. A ``ParlaiParser`` is still accepted (with a
            deprecation warning) and parsed before use.
        print_parser: ``True`` to print the parser's arguments after the model
            is loaded (only possible when ``opt`` is a ``ParlaiParser``), a
            parser object to print, or ``False``/``None`` to print nothing.
    """
    import os

    # Only a real parser can print its args; a bare True/False becomes None.
    if print_parser is True:
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    query_path = opt['report_filename'] + ".txt"
    human_agent = AutoQueryAgent(opt, query_txt=query_path)
    world_logger = WorldLogger(opt)
    world = create_task(opt, [human_agent, agent])

    # Same count as len(text.split("\n")) - 1: the number of newlines.
    # NOTE(review): a final query without a trailing newline is not parleyed.
    with open(query_path, "r") as f:
        num_queries = f.read().count("\n")

    for _ in range(num_queries):
        world.parley()
        world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())
        # Close each query/reply as its own logged episode.
        world_logger.reset()

    # Written once, instead of rewriting the whole file after every query.
    base_outfile = os.path.splitext(opt['report_filename'])[0]
    # NOTE(review): the extension says .jsonl but the format is 'text'.
    outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'
    world_logger.write(outfile, world, file_format='text')
Esempio n. 10
0
    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report
Esempio n. 11
0
def self_chat(opt):
    """Run ``opt['num_self_chats']`` episodes of a model chatting with its own clone.

    The two agents get ids ``<model id>_1`` and ``<model id>_2``. Output goes
    to ``opt['outfile']``, or ``/tmp/<model id>_selfchat`` when that is
    ``None``. The world's own ``write`` is used for the ``conversations``
    format when available; otherwise the logger's ``write`` is used.

    Returns:
        The logs accumulated by the ``WorldLogger``.
    """
    random.seed(opt['seed'])

    original = create_agent(opt, requireModelExists=True)
    twin = original.clone()

    base_id = original.id
    original.id, twin.id = base_id + "_1", base_id + "_2"

    world = create_task(opt, user_agents=[original, twin])

    logger = WorldLogger(opt)
    timer = TimeLogger()

    for finished in range(1, opt['num_self_chats'] + 1):
        _run_self_chat_episode(opt, world, logger)
        text, _ = timer.log(finished, opt['num_self_chats'], world.report())
        logging.info(text)

    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(base_id)

    if opt['save_format'] != 'conversations' or not hasattr(world, 'write'):
        logger.write(outfile, world, opt['save_format'])
    else:
        # The self-chat world may add extra context (e.g. personas) to its output.
        world.write(logger, outfile)

    return logger.get_logs()
Esempio n. 12
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on ``task`` and return a report gathered across workers.

    Evaluation stops when the epoch ends or after ``opt['num_examples']``
    examples (a non-positive value means no limit). When ``opt['world_logs']``
    is set, the path gets a task suffix for multi-task runs. In distributed
    mode it also gets a per-rank suffix.

    If the world has more than one agent, and the agent at ``CLASSIFIER_AGENT``
    has a truthy ``calc_auc``, its per-class AUCs are added to the report as
    ``AUC_<class name>`` and its AUC state is reset.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    # add task suffix in case of multi-tasking
    if opt['world_logs']:
        task_opt['world_logs'] = get_task_world_logs(
            task,
            task_opt['world_logs'],
            is_multitask=len(opt['task'].split(',')) > 1)

    world_logger = WorldLogger(task_opt) if task_opt['world_logs'] else None

    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        min(max_cnt, total_cnt), report)
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            rank = get_rank()
            base_outfile, extension = os.path.splitext(task_opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = task_opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    if isinstance(world.agents, list) and len(world.agents) > 1:
        classifier_agent = world.agents[CLASSIFIER_AGENT]
        if getattr(classifier_agent, 'calc_auc', False):
            for class_indices, curr_auc in zip(
                    classifier_agent.auc_class_indices, classifier_agent.aucs):
                report[
                    f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
            classifier_agent.reset_auc()
            # for safety measures: also reset the caller's agent, but only if
            # it supports AUC tracking (otherwise this raised AttributeError)
            if hasattr(agent, 'reset_auc'):
                agent.reset_auc()
    world.reset()
    return report
Esempio n. 13
0
def self_chat(opt):
    """Run self-chats between a primary agent and a partner model, then save them.

    ``opt['partner_model_file']`` is required. The partner's opts are always
    loaded from ``<partner>.opt``; a different ``opt['partner_opt_file']`` is
    rejected by an assertion. The primary agent is ``MyLocalHumanAgent`` when
    ``opt['model_file']`` starts with ``human:``; otherwise it is a model.

    When ``opt['num_self_chats']`` is negative, it is set (in ``opt``) to
    ``len(world.messages)``. After every episode the conversations are
    checkpointed with ``world.write`` (when the world has one). A final save
    follows, to ``opt['outfile']`` or ``/tmp/<ids>_selfchat``.

    Returns:
        The logs accumulated by the ``WorldLogger``.
    """
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    assert partner is not None, "opt['partner_model_file'] must be set"
    partner_opt_file = opt.get('partner_opt_file')
    if partner_opt_file:
        assert partner_opt_file == partner + '.opt', (
            'Unless you think it is safe,'
            ' you can remove assert')
    else:
        partner_opt_file = partner + '.opt'

    # Create agents
    if opt['model_file'].split(':')[0] == 'human':
        agent1 = MyLocalHumanAgent(opt)
    else:
        agent1 = create_agent(opt, requireModelExists=True)

    # Self chat with a different (partner) model; partner_opt_file is always set here
    print(f"WARNING: Loading override opts from: {partner_opt_file}")
    with open(partner_opt_file) as f:
        partner_opt = json.load(f)
    partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
    print(
        f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
    )
    agent2 = create_agent_from_model_file(partner, partner_opt)

    # Set IDs
    agent1.id = agent1.id + '_1'
    agent2.id = agent2.id + '_2'

    model_id = agent1.id + '_' + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Resolve the output path up front so the per-episode checkpoints never
    # receive a None outfile and agree with the final save.
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']
    world_can_write = hasattr(world, 'write')

    # Run some self chats.
    all_report = []
    if opt['num_self_chats'] < 0:
        opt['num_self_chats'] = len(world.messages)

    for i in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(i + 1, opt['num_self_chats'], report)
        logging.info(text)
        all_report.append(report)

        if world_can_write:
            # checkpoint after every episode
            world.write(logger, all_report, outfile)

    # Save chats
    if opt['save_format'] == 'conversations' and world_can_write:
        # use self chat specific world to write conversation
        # this might be useful for logging extra contextual
        # information (like personas)
        world.write(logger, all_report, outfile)
    else:
        # use default logger write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
Esempio n. 14
0
def self_chat(opt, print_parser=None):
    """Generate self-chat dialogues with one model and store them in MongoDB.

    The model is paired with a clone of itself. ``opt['num_dialogues']``
    dialogues are generated, each with a maximum length drawn from
    ``world.sample_episode_length()``. The logs are written to
    ``opt['outfile']`` (format ``opt['format']``). Each conversation is also
    inserted as one document into ``DATABASE_NAME``/``COLLECTION_NAME``.

    Args:
        opt: parsed options. ``mongo_host``, ``mongo_port``, ``user_name`` and
            ``password`` are read. A ``ParlaiParser`` is still accepted (with a
            deprecation warning) and parsed before any option is read.
        print_parser: accepted for interface compatibility; normalised but
            not otherwise used by this function.
    """
    if print_parser is True:
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    # Convert a parser into opt *before* reading any option from it.
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: self_chat should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    client = MongoClient(
        opt['mongo_host'],
        opt['mongo_port'],
        username=opt['user_name'],
        password=opt['password'],
        #authSource=DATABASE_NAME
    )
    try:
        collection = client[DATABASE_NAME][COLLECTION_NAME]

        # Create agents
        agent1 = create_agent(opt, requireModelExists=True)
        agent2 = agent1.clone()

        # Set IDs
        model_id = agent1.id
        agent1.id = model_id + "_1"
        agent2.id = model_id + "_2"

        world = create_task(opt, user_agents=[agent1, agent2])

        # Set up world logging
        logger = WorldLogger(opt)

        # Run some self chats.
        max_dial_cnt = opt['num_dialogues']
        dial_cnt = 0
        while dial_cnt < max_dial_cnt:
            world.max_turn_cnt = world.sample_episode_length()
            world.turn_cnt = 0
            print('Dialogue Number: {}, Max Turn: {}\n'.format(
                dial_cnt, world.max_turn_cnt))
            while True:
                world.parley()
                logger.log(world)

                if opt.get('display_examples'):
                    print(world.display())
                if world.episode_done():
                    break

            print('\n\n')
            dial_cnt += 1

        if opt.get('display_examples'):
            print('-- end of episode --')

        logger.write(opt['outfile'], opt['format'])
        _insert_self_chat_logs(collection, logger.get_logs(), opt)
    finally:
        # Always release the MongoDB connection pool.
        client.close()


def _insert_self_chat_logs(collection, logs, opt):
    """Insert one MongoDB document per logged self-chat conversation."""
    for convo in logs:
        convo_data = {}
        convo_data['system_name0'] = opt['model_file']
        convo_data['system_name1'] = opt['model_file']

        # NOTE(review): assumes model_file has at least three '/'-separated parts
        convo_data['system_type0'] = opt['model_file'].split('/')[2]
        convo_data['system_type1'] = opt['model_file'].split('/')[2]

        convo_data['is_human0'] = False
        convo_data['is_human1'] = False

        convo_data['domain_name'] = opt['task'].split(':')[0]
        turn_list = []

        for eid, exchange in enumerate(convo):
            turn0 = exchange[0]
            turn1 = exchange[1]
            turn0['exchange_nr'] = eid
            turn1['exchange_nr'] = eid
            # Each turn is checked on its own type (the original tested turn0
            # for both). Message values are updated via force_set, presumably
            # because plain assignment cannot overwrite an existing key.
            for turn in (turn0, turn1):
                if isinstance(turn, Message):
                    turn.force_set('episode_done', bool(turn['episode_done']))
                else:
                    turn['episode_done'] = bool(turn['episode_done'])
            turn_list.append(turn0)
            turn_list.append(turn1)

        convo_data['convo'] = cap_context(turn_list, convo_data['domain_name'])
        collection.insert_one(convo_data)
        print(len(convo_data['convo']))