Beispiel #1
0
def interactive(opt, print_parser=None):
    """Run an interactive chat between a local human and a model agent.

    When ``opt['save_world_logs']`` is set, the accumulated world log is
    flushed and rewritten to
    ``<report_filename without extension>_<task>_replies.jsonl`` after every
    turn.

    :param opt: parsed options. A ``ParlaiParser`` is still accepted for
        backward compatibility; an error is logged and it is parsed.
    :param print_parser: ``True`` to print the model's options through
        ``opt`` (when ``opt`` is a parser), a parser object to print with,
        or ``False``/``None`` to skip printing.
    """
    import os

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    human_agent = LocalHumanAgent(opt)
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None
    world = create_task(opt, [human_agent, agent])

    outfile = None
    if world_logger is not None:
        # splitext strips only the final extension; split('.')[0] truncated
        # at the first dot anywhere in the path ("./out/x.json" -> "").
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())
        if world_logger is not None:
            # Add the final acts to the logs and rewrite the whole file each
            # turn so the on-disk log stays current.
            # NOTE(review): resetting every turn may split each turn into its
            # own logged episode -- confirm this is intended.
            world_logger.reset()
            world_logger.write(outfile, world, file_format=opt['save_format'])
Beispiel #2
0
def interactive(opt):
    """Chat with a model agent through a local human agent until the epoch ends.

    When ``opt['outfile']`` is set, every completed turn is logged. When the
    loop ends, the log is written to that path in ``opt['save_format']``.

    :param opt: parsed options. A ``ParlaiParser`` is still accepted for
        backward compatibility; an error is logged and it is parsed.
    """
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=False)
    agent.opt.log()
    human_agent = LocalHumanAgent(opt)
    # set up world logger
    world_logger = WorldLogger(opt) if opt.get('outfile') else None
    world = create_task(opt, [human_agent, agent])

    # Show some example dialogs:
    while not world.epoch_done():
        world.parley()
        if world.epoch_done() or world.get_total_parleys() <= 0:
            # chat was reset with [DONE], [EXIT] or EOF
            if world_logger is not None:
                world_logger.reset()
            continue

        if world_logger is not None:
            world_logger.log(world)
        # Only display the world when asked to; the stray debug prints that
        # used to run on every turn duplicated this output.
        if opt.get('display_examples'):
            print("---")
            print(world.display())

    if world_logger is not None:
        # dump world acts to file
        world_logger.write(opt['outfile'],
                           world,
                           file_format=opt['save_format'])
Beispiel #3
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` in a world built from ``opt`` and return its report.

    :param opt: evaluation options (reads ``datatype``, ``save_world_logs``,
        ``log_every_n_secs``, ``num_examples``, ``batchsize``,
        ``display_examples`` and, when logging, ``report_filename``).
    :param agent: agent(s) passed to ``create_task``.
    :param task: task name; used in the status message and in the world-log
        filename ``<report_filename without extension>_<task>_replies.jsonl``.
    :return: ``world.report()`` taken after the evaluation loop.
    """
    import os

    print('[ Evaluating task {} using datatype {}. ] '.format(
        task, opt.get('datatype', 'N/A')))
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    task_opt = opt.copy()  # the world gets its own copy of the options
    # NOTE(review): the per-task override (task_opt['task'] = task) is
    # disabled, so the world is built from the full opt['task'] value rather
    # than ``task`` alone -- confirm this is intended for multi-task runs.
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging; a non-positive interval disables periodic progress
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')

        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        world.num_examples(), report)
            print(text)

    report = world.report()
    print("Printing Report")
    print(report)
    world.reset()

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        # splitext strips only the final extension; split('.')[0] truncated
        # at the first dot anywhere in the path ("./out/x.json" -> "").
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        print("filename: ", base_outfile)
        outfile = base_outfile + f'_{task}_replies.jsonl'
        world_logger.write_parlai_format(outfile)

    return report
Beispiel #4
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on a single ``task`` and return the aggregated report.

    :param opt: evaluation options (reads ``datatype``, ``world_logs``,
        ``task``, ``log_every_n_secs``, ``num_examples``, ``batchsize``,
        ``display_examples`` and ``save_format``).
    :param agent: agent(s) passed to ``create_task``.
    :param task: the task to evaluate; it replaces ``opt['task']`` in the
        options used to build the world.
    :return: ``world.report()`` gathered from all workers and merged with
        ``aggregate_unnamed_reports``.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    world_logger = WorldLogger(opt) if opt['world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging; a non-positive interval disables periodic progress
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        min(max_cnt, total_cnt), report)
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        outfile = opt['world_logs']
        if len((opt.get('task') or '').split(',')) > 1:
            # Presumably this function is called once per comma-separated
            # task; without a per-task suffix every task would overwrite the
            # same file and only the last task's logs would survive.
            base_outfile, extension = os.path.splitext(outfile)
            outfile = f'{base_outfile}_{task}{extension}'
        if is_distributed():
            # one file per worker so ranks don't clobber each other
            base_outfile, extension = os.path.splitext(outfile)
            outfile = base_outfile + f'_{get_rank()}' + extension
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()

    return report
Beispiel #5
0
def interactive(opt, print_parser=None):
    """Chat with a model until the epoch ends or the user presses Ctrl-C.

    When ``opt['save_world_logs']`` is set, the session log is written to
    ``<report_filename without extension>_interactive_replies.json`` on both
    exit paths. After a Ctrl-C the process exits once the log is written.

    :param opt: parsed options. A ``ParlaiParser`` is still accepted for
        backward compatibility; a deprecation warning is printed and it is
        parsed.
    :param print_parser: ``True`` to print the model's options through
        ``opt`` (when ``opt`` is a parser), a parser object to print with,
        or ``False``/``None`` to skip printing.
    """
    import os
    import sys

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    human_agent = LocalHumanAgent(opt)
    world = create_task(opt, [human_agent, agent])
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Show some example dialogs until the epoch ends or the user interrupts.
    interrupted = False
    try:
        while True:
            world.parley()
            if world_logger is not None:
                world_logger.log(world)
            if opt.get('display_examples'):
                print("---")
                print(world.display())
            if world.epoch_done():
                print("EPOCH DONE")
                break
    except KeyboardInterrupt:
        interrupted = True

    # Previously the log was only written on Ctrl-C and was lost when the
    # epoch ended normally; write it on either exit path.
    if world_logger is not None:
        print("\nWriting out world log.")
        world.reset()

        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        # splitext strips only the final extension; split('.')[0] truncated
        # at the first dot anywhere in the path ("./out/x.json" -> "").
        base_outfile = os.path.splitext(opt['report_filename'])[0]
        outfile = base_outfile + '_interactive_replies.json'
        # write() takes the world as its second argument, as in every other
        # call site; omitting it raised a TypeError.
        world_logger.write(outfile,
                           world,
                           file_format=opt['world_logs_format'])

    if interrupted:
        # Ctrl-C still terminates the process, as before; sys.exit is used
        # instead of quit(), which the site module provides only for
        # interactive use.
        sys.exit()
Beispiel #6
0
def interactive(opt, print_parser=None):
    """Replay queries from a text file against a model and log every turn.

    Queries come from ``AutoQueryAgent`` reading ``<report_filename>.txt``.
    One parley runs per newline character in that file. After each turn the
    world log is flushed and rewritten in ``text`` format to
    ``<report_filename without extension>_<task>_replies.jsonl``.

    :param opt: parsed options. A ``ParlaiParser`` is still accepted for
        backward compatibility; a deprecation warning is printed and it is
        parsed.
    :param print_parser: ``True`` to print the model's options through
        ``opt`` (when ``opt`` is a parser), a parser object to print with,
        or ``False``/``None`` to skip printing.
    """
    import os

    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Queries are fed automatically from a file instead of a human.
    query_path = opt['report_filename'] + ".txt"
    human_agent = AutoQueryAgent(opt, query_txt=query_path)
    world_logger = WorldLogger(opt)
    world = create_task(opt, [human_agent, agent])

    # One turn per newline character (identical to the previous
    # len(text.split("\n")) - 1).
    # NOTE(review): a final query without a trailing newline is not counted --
    # confirm against how AutoQueryAgent reads the file.
    with open(query_path, "r") as f:
        length = f.read().count("\n")

    # splitext strips only the final extension; split('.')[0] truncated at
    # the first dot anywhere in the path ("../data/q" -> "").
    base_outfile = os.path.splitext(opt['report_filename'])[0]
    outfile = f'{base_outfile}_{opt["task"]}_replies.jsonl'

    for _ in range(length):
        world.parley()
        world_logger.log(world)
        if opt.get('display_examples'):
            print("---")
            print(world.display())
        # add final acts to logs and rewrite the whole file after every turn
        world_logger.reset()
        world_logger.write(outfile, world, file_format='text')
Beispiel #7
0
    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report
Beispiel #8
0
def _eval_single_world(opt, agent, task):
    """Evaluate ``agent`` on a single ``task`` and return the aggregated report.

    The world-log path (when ``opt['world_logs']`` is set) is resolved through
    ``get_task_world_logs`` for multi-task runs and gets a ``_<rank>`` suffix
    in distributed runs. If the world holds more than one agent and the agent
    at ``CLASSIFIER_AGENT`` has a truthy ``calc_auc``, one ``AUC_<class>``
    entry per tracked class is added to the report. The AUC state is then
    reset on that agent and on ``agent``.

    :return: ``world.report()`` gathered from all workers, merged with
        ``aggregate_unnamed_reports`` and extended with any AUC entries.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')

    # per-task copy of the options
    task_opt = opt.copy()
    task_opt['task'] = task
    if opt['world_logs']:
        # task suffix on the log path when several tasks are evaluated
        multitask = len(opt['task'].split(',')) > 1
        task_opt['world_logs'] = get_task_world_logs(
            task, task_opt['world_logs'], is_multitask=multitask)

    world_logger = None
    if task_opt['world_logs']:
        world_logger = WorldLogger(task_opt)

    world = create_task(task_opt, agent)

    # periodic progress logging; a non-positive interval disables it
    interval = opt.get('log_every_n_secs', -1)
    if interval <= 0:
        interval = float('inf')
    timer = TimeLogger()

    # example cap; a non-positive num_examples means "no cap"
    limit = float('inf') if opt['num_examples'] <= 0 else opt['num_examples']
    processed = 0
    expected_total = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and processed < limit:
        processed += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            print(world.display() + '\n~~')
        if timer.time() > interval:
            progress = world.report()
            text, progress = timer.log(progress.get('exs', 0),
                                       min(limit, expected_total), progress)
            logging.info(text)

    if world_logger is not None:
        world_logger.reset()  # flush the final acts into the log
        outfile = task_opt['world_logs']
        if is_distributed():
            stem, ext = os.path.splitext(outfile)
            outfile = f'{stem}_{get_rank()}{ext}'
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    agents = world.agents
    if isinstance(agents, list) and len(agents) > 1:
        classifier = agents[CLASSIFIER_AGENT]
        if getattr(classifier, 'calc_auc', False):
            for idx, auc in zip(classifier.auc_class_indices, classifier.aucs):
                report[f'AUC_{classifier.class_list[idx]}'] = auc
            classifier.reset_auc()
            # also reset the agent we were handed, for safety
            agent.reset_auc()
    world.reset()
    return report