コード例 #1
0
    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Evaluate every pre-built world and return the aggregated report.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param int max_exs:
            limits the number of examples if max_exs > 0; the value is
            divided by (number of worlds * number of workers) before being
            passed to each single-world evaluation
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
            (only the primary worker writes)
        :param str extra_log_suffix:
            text inserted between the model_file path and the
            ``.<datatype>`` ending of the metrics file name
        :return:
            the aggregated report after it has been passed through
            ``self._sync_metrics``
        """
        logging.info(f'running eval: {datatype}')
        timer = Timer()

        # share the example budget across every (world, worker) pair
        per_worker_budget = max_exs / (len(valid_worlds) * num_workers())
        reports = [
            self._run_single_eval(opt, world, per_worker_budget)
            for world in valid_worlds
        ]

        # key each report by the ID of the world that produced it
        task_ids = [world.getID() for world in valid_worlds]
        by_task = dict(zip(task_ids, reports))
        use_micro = self.opt.get('aggregate_micro', False)
        combined = aggregate_named_reports(by_task, micro_average=use_micro)

        # get the results from all workers
        combined = self._sync_metrics(combined)

        metrics = f'{datatype}:\n{nice_report(combined)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # append the metrics to <model_file><suffix>.<datatype> when requested
        if write_log and opt.get('model_file') and is_primary_worker():
            log_path = opt['model_file'] + extra_log_suffix + '.' + datatype
            with PathManager.open(log_path, 'a') as f:
                f.write(f'{metrics}\n')

        return combined
コード例 #2
0
def eval_model(opt):
    """
    Evaluate a model on each task named in ``opt['task']``.

    The tasks are evaluated one after another, each per-task report is
    logged, and the reports are then aggregated, printed and saved.

    :param opt: tells the evaluation function how to run
    :return: the report aggregated over all evaluated tasks
    :raises ValueError: if the datatype contains 'train' without 'evalmode'
    """
    random.seed(42)

    datatype = opt['datatype']
    if 'train' in datatype and 'evalmode' not in datatype:
        raise ValueError(
            'You should use --datatype train:evalmode if you want to evaluate on '
            'the training set.')

    # load model and possibly print opt
    agent = create_agent(opt, requireModelExists=True)
    agent.opt.log()

    # evaluate the comma-separated tasks in order, keyed by task name
    tasks = opt['task'].split(',')
    named_reports = {}
    for task in tasks:
        single_report = _eval_single_world(opt, agent, task)
        named_reports[task] = single_report
        logging.report(f"Report for {task}:\n{nice_report(single_report)}")

    use_micro = opt.get('aggregate_micro', False)
    report = aggregate_named_reports(named_reports, micro_average=use_micro)

    # print announcements and report
    print_announcements(opt)
    logging.info(
        f'Finished evaluating tasks {tasks} using datatype {opt.get("datatype")}'
    )

    print(nice_report(report))
    _save_eval_stats(opt, report)
    return report
コード例 #3
0
    def validate(self):
        """
        Run one validation pass and decide whether training should end.

        The validation report is recorded in ``self.valid_reports`` together
        with the current training counters, sent to the tensorboard / wandb
        loggers when the matching options are enabled, and offered to the
        agent if it defines ``receive_metrics``. The configured validation
        metric is then compared against the best value seen so far to update
        ``self.best_valid`` and ``self.impatience``.

        :return: True if training should stop (cutoff reached or patience
            exhausted), False otherwise
        :rtype: bool
        """
        opt = self.opt

        # the validation worlds are created on first use
        if self.valid_worlds is None:
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )

        # keep a plain-dict record of this validation plus training progress
        record = dict_report(valid_report)
        record['train_time'] = self.train_time.time()
        record['parleys'] = self.parleys
        record['train_steps'] = self._train_steps
        record['total_exs'] = self._total_exs
        record['total_epochs'] = self._total_epochs
        self.valid_reports.append(record)

        # external loggers (primary worker only)
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # pull out the metric that drives model selection
        metric_name = opt['validation_metric']
        new_valid = valid_report[metric_name]
        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # valid_optim scales both sides so a single '>' comparison is used
        previous_best = self.best_valid
        improved = (
            previous_best is None
            or self.valid_optim * new_valid > self.valid_optim * previous_best
        )

        if not improved:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    metric_name, round(self.best_valid, 4), self.impatience
                )
            )
        else:
            if previous_best is None:
                previous_note = ''
            else:
                previous_note = ' (previous best was {:.4g})'.format(previous_best)
            logging.success(
                'new best {}: {:.4g}{}'.format(metric_name, new_valid, previous_note)
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True

            # stop right away once the best value reaches the cutoff
            mode = opt['validation_metric_mode']
            if mode == 'max':
                reached_cutoff = self.best_valid >= opt['validation_cutoff']
            elif mode == 'min':
                reached_cutoff = self.best_valid <= opt['validation_cutoff']
            else:
                reached_cutoff = False
            if reached_cutoff:
                logging.info('task solved! stopping.')
                return True

        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience (a patience of 0 or less disables it)
        patience = opt['validation_patience']
        if patience > 0 and self.impatience >= patience:
            logging.info('ran out of patience! stopping training.')
            return True
        return False
コード例 #4
0
ファイル: tod_world_script.py プロジェクト: sagar-spkt/ParlAI
    def _save_outputs(self, opt, world, logger, episode_metrics):
        """
        Log the final report and write the report file and world logs.

        In distributed mode the world report and the per-episode metrics are
        gathered from all workers before the report is built; each worker
        writes its own world-log part file and the primary worker merges the
        parts into a single ``.jsonl`` file.

        :param opt:
            options; reads ``report_filename``, ``world_logs`` and
            ``save_format``
        :param world:
            world whose ``report()`` supplies the final metrics; also passed
            to ``logger.write``
        :param logger:
            object whose ``write(path, world, file_format=...)`` stores the
            world logs
        :param episode_metrics:
            iterable of ``(goal, episode_metric)`` pairs, where each
            ``episode_metric`` provides ``report()``
        :return:
            dict form of the world report with an added ``"tod_metrics"``
            list holding one dict per episode
        """
        if is_distributed():  # flatten everything intelligently if need be
            world_report = aggregate_unnamed_reports(
                all_gather_list(world.report()))
            # the gathered value holds one list per rank; merge into one list
            episode_metrics = [
                elem
                for rank_elems in all_gather_list(episode_metrics)
                for elem in rank_elems
            ]
        else:
            world_report = world.report()
        logging.report("Final report:\n" + nice_report(world_report))

        report = dict_report(world_report)

        def get_episode_report(goal, episode_metric):
            # per-episode metrics as a dict, tagged with the episode's goal
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        report["tod_metrics"] = [
            get_episode_report(g, e) for g, e in episode_metrics
        ]

        if opt.get("report_filename") is not None:
            if len(world_report) == 0:
                # The file is still written (it carries the episode metrics),
                # so the message must not claim that saving is skipped.
                logging.warning(
                    "World report is empty; saving report with episode "
                    "metrics only")

            report_fname = f"{opt['report_filename']}.json"
            # Save report (only once when running distributed)
            if not is_distributed() or is_primary_worker():
                with PathManager.open(report_fname, "w") as f:
                    logging.info(f"Saving model report to {report_fname}")
                    json.dump({"opt": opt, "report": report}, f, indent=4)
                    f.write("\n")  # for jq

        if opt.get("world_logs") is not None:
            log_prefix = f"{opt['world_logs']}_{opt['save_format']}"
            if is_distributed():  # Save separately, then aggregate together
                own_part = f"{log_prefix}_{get_rank()}.jsonl"
                logger.write(own_part, world, file_format=opt["save_format"])
                # NOTE(review): presumably used as a barrier so every rank has
                # finished writing its part before the merge - confirm.
                sync_object(None)
                if is_primary_worker():
                    merged_jsonl = f"{log_prefix}.jsonl"
                    merged_metadata = f"{log_prefix}.metadata"
                    with open(merged_jsonl, "w+") as outfile:
                        for part_rank in range(num_workers()):
                            part_jsonl = f"{log_prefix}_{part_rank}.jsonl"
                            part_metadata = f"{log_prefix}_{part_rank}.metadata"
                            with open(part_jsonl) as infile:
                                for line in infile:
                                    json_blob = json.loads(line.strip())
                                    # skip when we don't have generation
                                    if len(json_blob["dialog"]) < 2:
                                        continue
                                    # point every entry at the merged metadata
                                    json_blob["metadata_path"] = merged_metadata
                                    outfile.write(json.dumps(json_blob))
                                    outfile.write("\n")
                            # rank 0's metadata becomes the merged metadata
                            if part_rank == 0:
                                copyfile(part_metadata, merged_metadata)
                            # part files are removed once merged
                            os.remove(part_jsonl)
                            os.remove(part_metadata)
            else:
                logger.write(f"{log_prefix}.jsonl",
                             world,
                             file_format=opt["save_format"])

        return report