Пример #1
0
    def test_uneven_macro_aggrevation(self):
        """
        Merge two macro-averaged named reports that cover different tasks.

        Task 'a' carries a metric only in the first named report (its entry in
        the second one is empty), 'b' exists only in the first and 'c' only in
        the second. The assertions pin each named report's own 'avg', the
        per-task values after the unnamed merge, and the merged overall 'avg'.
        """
        task_a = {'avg': AverageMetric(1, 1)}
        task_b = {'avg': AverageMetric(0, 1)}
        task_c = {'avg': AverageMetric(0, 1)}

        first = aggregate_named_reports(
            {'a': task_a, 'b': task_b},
            micro_average=False,
        )
        second = aggregate_named_reports(
            {'a': {}, 'c': task_c},
            micro_average=False,
        )
        merged = aggregate_unnamed_reports([first, second])

        # each named report on its own
        assert first['avg'] == 0.5
        assert second['avg'] == 0.0

        # per-task entries and the overall average after merging
        expectations = (
            ('a/avg', 1.0),
            ('b/avg', 0.0),
            ('c/avg', 0.0),
            ('avg', 1.0 / 3),
        )
        for key, expected in expectations:
            assert merged[key] == expected
Пример #2
0
 def _sync_metrics(self, metrics):
     """
     Combine training metrics from all workers into a single report.

     In distributed mode the metrics of every worker are gathered with
     ``all_gather_list`` and merged by ``aggregate_unnamed_reports``.
     Otherwise ``metrics`` is handed back untouched.
     """
     if is_distributed():
         gathered = all_gather_list(metrics)
         return aggregate_unnamed_reports(gathered)
     # single worker: there is nothing to merge
     return metrics
Пример #3
0
def _eval_single_world(opt, agent, task):
    """
    Run one evaluation pass over a single task and return its merged report.

    :param opt:
        option mapping. Reads ``world_logs``, ``num_examples`` and
        ``display_examples``, plus ``save_format`` when world logs are
        written; ``datatype``, ``log_every_n_secs`` and ``batchsize`` are
        optional.
    :param agent:
        agent passed to ``create_task`` together with the task options.
    :param task:
        value stored under ``'task'`` in a copy of ``opt``.

    :return:
        ``aggregate_unnamed_reports`` applied to the gathered
        ``world.report()``.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')

    # a world logger is only created when a log destination is configured
    world_logger = None
    if opt['world_logs']:
        world_logger = WorldLogger(opt)

    # edit a copy so the caller's opt keeps its own task entry
    task_opt = opt.copy()
    task_opt['task'] = task
    world = create_task(task_opt, agent)

    # a non-positive interval turns periodic progress logging off
    interval = opt.get('log_every_n_secs', -1)
    log_interval = float('inf') if interval <= 0 else interval
    timer = TimeLogger()

    # cap on evaluated examples; non-positive means "no cap"
    requested = opt['num_examples']
    if requested > 0:
        max_cnt = requested
    else:
        max_cnt = float('inf')
    seen = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and seen < max_cnt:
        # the counter advances by the batch size once per parley
        seen += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            print(world.display() + '\n~~')
        if timer.time() > log_interval:
            progress = world.report()
            text, _ = timer.log(
                progress.get('exs', 0), min(max_cnt, total_cnt), progress)
            logging.info(text)

    if world_logger is not None:
        # reset first so the final acts are added to the logs, then dump them
        world_logger.reset()
        outfile = opt['world_logs']
        if is_distributed():
            # every worker writes its own file: <base>_<rank><ext>
            rank = get_rank()
            stem, ext = os.path.splitext(outfile)
            outfile = stem + f'_{rank}' + ext
        world_logger.write(outfile, world, file_format=opt['save_format'])

    final_report = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()

    return final_report
Пример #4
0
 def test_unnamed_aggregation(self):
     """
     Merge two unnamed reports holding one metric of each kind.

     The expected values below are what ``aggregate_unnamed_reports`` must
     return for the average, sum, fixed and global-average entries.
     """
     reports = [
         {
             'avg': AverageMetric(3, 4),
             'sum': SumMetric(3),
             'fixed': FixedMetric(4),
             'global_avg': GlobalAverageMetric(3, 4),
         },
         {
             'avg': AverageMetric(1, 3),
             'sum': SumMetric(4),
             'fixed': FixedMetric(4),
             'global_avg': GlobalAverageMetric(1, 3),
         },
     ]
     merged = aggregate_unnamed_reports(reports)

     expectations = (
         ('avg', 4.0 / 7),
         ('sum', 7),
         ('fixed', 4),
         ('global_avg', 4.0 / 7),
     )
     for key, expected in expectations:
         assert merged[key] == expected
Пример #5
0
def _eval_single_world(opt, agent, task):
    """
    Run one evaluation pass over a single task and return its merged report.

    Compared with a plain evaluation loop this variant also rewrites the
    world-log path through ``get_task_world_logs`` and, when the classifier
    agent has ``calc_auc`` enabled, adds one ``AUC_<class>`` entry per tracked
    class to the report.

    :param opt:
        option mapping. Reads ``world_logs``, ``num_examples`` and
        ``display_examples``; ``task`` and ``save_format`` when world logs
        are configured; ``datatype``, ``log_every_n_secs`` and ``batchsize``
        are optional.
    :param agent:
        agent passed to ``create_task``; its ``reset_auc`` is called when AUC
        values were reported.
    :param task:
        value stored under ``'task'`` in a copy of ``opt``.

    :return:
        the aggregated world report, possibly extended with AUC entries.
    """
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')

    # edit a copy so the caller's opt keeps its own task entry
    task_opt = opt.copy()
    task_opt['task'] = task

    if opt['world_logs']:
        # more than one comma-separated task means multi-tasking
        multitask = len(opt['task'].split(',')) > 1
        task_opt['world_logs'] = get_task_world_logs(
            task, task_opt['world_logs'], is_multitask=multitask)

    world_logger = None
    if task_opt['world_logs']:
        world_logger = WorldLogger(task_opt)

    world = create_task(task_opt, agent)

    # a non-positive interval turns periodic progress logging off
    interval = opt.get('log_every_n_secs', -1)
    log_interval = float('inf') if interval <= 0 else interval
    timer = TimeLogger()

    # cap on evaluated examples; non-positive means "no cap"
    requested = opt['num_examples']
    if requested > 0:
        max_cnt = requested
    else:
        max_cnt = float('inf')
    seen = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and seen < max_cnt:
        # the counter advances by the batch size once per parley
        seen += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            print(world.display() + '\n~~')
        if timer.time() > log_interval:
            progress = world.report()
            text, _ = timer.log(
                progress.get('exs', 0), min(max_cnt, total_cnt), progress)
            logging.info(text)

    if world_logger is not None:
        # reset first so the final acts are added to the logs, then dump them
        world_logger.reset()
        outfile = task_opt['world_logs']
        if is_distributed():
            # every worker writes its own file: <base>_<rank><ext>
            rank = get_rank()
            stem, ext = os.path.splitext(outfile)
            outfile = stem + f'_{rank}' + ext
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    if isinstance(world.agents, list) and len(world.agents) > 1:
        classifier = world.agents[CLASSIFIER_AGENT]
        if getattr(classifier, 'calc_auc', False):
            tracked = zip(classifier.auc_class_indices, classifier.aucs)
            for class_index, auc in tracked:
                class_name = classifier.class_list[class_index]
                report[f'AUC_{class_name}'] = auc
            classifier.reset_auc()
            # also reset on the agent that was passed in, for safety
            agent.reset_auc()

    world.reset()
    return report
Пример #6
0
 def report(self):
     """
     Return the world's report merged with this object's own metrics report.
     """
     world_report = self.world.report()
     own_report = self.metrics.report()
     return aggregate_unnamed_reports([world_report, own_report])
Пример #7
0
    def _save_outputs(self, opt, world, logger, episode_metrics):
        """
        Log the final report and optionally write the report and world logs.

        :param opt:
            option mapping. ``report_filename`` and ``world_logs`` are
            optional; ``save_format`` is read whenever world logs are written.
        :param world:
            world whose ``report()`` is summarised and whose acts are written
            by ``logger``.
        :param logger:
            object exposing ``write(path, world, file_format=...)``.
        :param episode_metrics:
            iterable of ``(goal, metric)`` pairs; each metric's ``report()``
            becomes one entry of ``report["tod_metrics"]``.

        :return:
            ``dict_report(world_report)`` extended with a ``"tod_metrics"``
            list holding one dict (including its ``"goal"``) per episode.

        In distributed mode the world report and the episode metrics are
        gathered from all workers first. Each worker writes its own
        ``<world_logs>_<save_format>_<rank>.jsonl`` part, and the primary
        worker then merges the parts into ``<world_logs>_<save_format>.jsonl``
        and deletes them.
        """
        if is_distributed():  # flatten everything intelligently if need be
            world_report = aggregate_unnamed_reports(
                all_gather_list(world.report()))
            # The gathered value is a list of per-worker lists (presumably
            # one per rank -- confirm); concatenate them into one flat list.
            episode_metrics = [
                elem
                for rank_elem in all_gather_list(episode_metrics)
                for elem in rank_elem
            ]
        else:
            world_report = world.report()
        logging.report("Final report:\n" + nice_report(world_report))

        report = dict_report(world_report)

        def get_episode_report(goal, episode_metric):
            # one plain dict per episode, tagged with the goal it belongs to
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        report["tod_metrics"] = [
            get_episode_report(g, e) for g, e in episode_metrics
        ]

        if "report_filename" in opt and opt["report_filename"] is not None:
            if len(world_report) == 0:
                # NOTE(review): the message says the report is not saved, but
                # the code below still writes it. Confirm which of the two is
                # intended before changing either the text or the behaviour.
                logging.warning("Report is empty; not saving report")

            report_fname = f"{opt['report_filename']}.json"
            # Save report (only one worker writes the file)
            if not is_distributed() or is_primary_worker():
                with PathManager.open(report_fname, "w") as f:
                    logging.info(f"Saving model report to {report_fname}")
                    json.dump({"opt": opt, "report": report}, f, indent=4)
                    f.write("\n")  # for jq

        if "world_logs" in opt and opt["world_logs"] is not None:
            # every world-log path shares this prefix
            log_prefix = f"{opt['world_logs']}_{opt['save_format']}"
            if is_distributed():  # Save separately, then aggregate together
                rank = get_rank()
                logger.write(f"{log_prefix}_{rank}.jsonl",
                             world,
                             file_format=opt["save_format"])
                # presumably a barrier so every part file exists before the
                # merge starts -- confirm against sync_object
                sync_object(None)
                if is_primary_worker():
                    log_outfile = f"{log_prefix}.jsonl"
                    log_outfile_metadata = f"{log_prefix}.metadata"
                    with open(log_outfile, "w+") as outfile:
                        for worker_rank in range(num_workers()):
                            part_log = f"{log_prefix}_{worker_rank}.jsonl"
                            # the per-rank .metadata file is expected to sit
                            # next to the part log (presumably produced by
                            # logger.write -- confirm)
                            part_metadata = (
                                f"{log_prefix}_{worker_rank}.metadata")
                            with open(part_log) as infile:
                                for line in infile:
                                    json_blob = json.loads(line.strip())
                                    # skip when we don't have generation
                                    if len(json_blob["dialog"]) < 2:
                                        continue
                                    # point each record at the merged metadata
                                    json_blob[
                                        "metadata_path"] = log_outfile_metadata
                                    outfile.write(json.dumps(json_blob))
                                    outfile.write("\n")
                            # only rank 0's metadata becomes the merged
                            # metadata file
                            if worker_rank == 0:
                                copyfile(part_metadata, log_outfile_metadata)
                            os.remove(part_log)
                            os.remove(part_metadata)
            else:
                logger.write(f"{log_prefix}.jsonl",
                             world,
                             file_format=opt["save_format"])

        return report