Code example #1
    def _save_train_stats(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with PathManager.open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'train_steps': self._train_steps,
                    'total_epochs': self._total_epochs,
                    'train_reports': self.train_reports,
                    'valid_reports': self.valid_reports,
                    'best_valid': self.best_valid,
                    'impatience': self.impatience,
                    'final_valid_report': dict_report(self.final_valid_report),
                    'final_test_report': dict_report(self.final_test_report),
                    'final_extra_valid_report': dict_report(
                        self.final_extra_valid_report
                    ),
                },
                f,
                indent=4,
            )
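The .trainstats file written above is plain JSON, so it can be inspected or reloaded with the standard library. A minimal sketch of reading it back; the path is a placeholder for wherever the model_file option points.

import json

stats_path = '/tmp/my_model.trainstats'  # placeholder path

with open(stats_path) as f:
    stats = json.load(f)

# Keys match what _save_train_stats dumps above.
print(stats['parleys'], stats['best_valid'])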
Code example #2
    def log_final(self, setting, report):
        report = dict_report(report)
        report = {
            f'{k}/{setting}': v
            for k, v in report.items() if isinstance(v, numbers.Number)
        }
        for key, value in report.items():
            self.run.summary[key] = value
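log_final pushes the filtered numeric metrics into the W&B run summary rather than the time-series log, so they show up as final values for the run. A minimal sketch of that call pattern with the public wandb API; the project name and metric values are placeholders.

import wandb

run = wandb.init(project='my-project')  # placeholder project

# Final metrics are assigned to the run summary one key at a time.
run.summary['ppl/test'] = 12.3
run.summary['accuracy/test'] = 0.87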
Code example #3
    def _run_opt_get_report(self, opt):
        script = TestTodWorldScript(opt)
        script.run()

        def get_episode_report(goal, episode_metric):
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        return (
            dict_report(script.world.report()),
            [get_episode_report(g, e) for g, e in script.episode_metrics],
        )
Code example #4
    def _check_metrics_correct(self, script, opt):
        """
        Last argument is only relevant for the max_turn test.
        """
        max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY]
        max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY]
        episode_metrics = script.episode_metrics
        for episode_idx, episode in enumerate(episode_metrics):
            goal, episode_metric = episode
            episode_metric = dict_report(episode_metric.report())
            self.assertAlmostEqual(
                episode_metric["all_goals_hit"],
                not test_agents.episode_has_broken_api_turn(
                    episode_idx, max_rounds),
            )
        broken_episodes = sum([
            test_agents.episode_has_broken_api_turn(i, max_rounds)
            for i in range(max_episodes)
        ])
        report = dict_report(script.world.report())
        self.assertAlmostEqual(
            report["all_goals_hit"],
            float(max_episodes - broken_episodes) / max_episodes,
        )
Code example #5
    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report
Code example #6
    def log_metrics(self, setting, step, report):
        """
        Log all metrics to W&B.

        :param setting:
            One of train/valid/test. Will be used as the title for the graph.
        :param step:
            Number of parleys
        :param report:
            The report to log
        """
        report = dict_report(report)
        report = {
            f'{k}/{setting}': v
            for k, v in report.items() if isinstance(v, numbers.Number)
        }
        report['custom_step'] = step
        self.run.log(report)
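Unlike log_final above, log_metrics sends the filtered report through run.log together with a custom_step field, so each call becomes one point on the run's charts. A minimal sketch of the same pattern outside the logger class; the project name, report, and step value are placeholders.

import numbers

import wandb

run = wandb.init(project='my-project')  # placeholder project

report = {'loss': 0.42, 'exs': 128, 'note': 'non-numeric, dropped'}
payload = {
    f'{k}/train': v for k, v in report.items() if isinstance(v, numbers.Number)
}
payload['custom_step'] = 100  # e.g. the number of parleys so far
run.log(payload)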
Code example #7
    def test_token_stats(self):
        from parlai.scripts.token_stats import TokenStats
        from parlai.core.metrics import dict_report

        results = dict_report(TokenStats.main(task='integration_tests:multiturn'))
        assert results == {
            'exs': 2000,
            'max': 16,
            'mean': 7.5,
            'min': 1,
            'p01': 1,
            'p05': 1,
            'p10': 1,
            'p25': 4,
            'p50': 7.5,
            'p75': 11.5,
            'p90': 16,
            'p95': 16,
            'p99': 16,
            'p@128': 1,
        }
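Every example on this page passes a report of Metric objects through dict_report before comparing or serializing it. The sketch below shows the kind of conversion it presumably performs, assuming each Metric exposes a .value() accessor (as used in code example #8); it is an illustration, not the library's implementation.

from parlai.core.metrics import Metric


def dict_report_sketch(report):
    # Convert Metric objects to plain Python numbers so the dict is JSON-serializable.
    return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}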
Code example #8
    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = dict_report(valid_report)
        v['train_time'] = self.train_time.time()
        v['parleys'] = self.parleys
        v['train_steps'] = self._train_steps
        v['total_exs'] = self._total_exs
        v['total_epochs'] = self._total_epochs
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric_mode'] == 'max'
                and self.best_valid >= opt['validation_cutoff']
            ) or (
                opt['validation_metric_mode'] == 'min'
                and self.best_valid <= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False
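The core of validate() is the patience bookkeeping: valid_optim is +1 or -1 depending on whether the validation metric is maximized or minimized, a new best score resets impatience, and training stops once impatience reaches validation_patience. A standalone sketch of that logic, detached from ParlAI's TrainLoop:

class EarlyStopper:
    """Illustrative sketch of the patience logic above, not ParlAI's class."""

    def __init__(self, mode='max', patience=10):
        self.valid_optim = 1 if mode == 'max' else -1
        self.patience = patience
        self.best = None
        self.impatience = 0

    def step(self, score):
        # Returns True when training should stop.
        if self.best is None or self.valid_optim * score > self.valid_optim * self.best:
            self.best = score
            self.impatience = 0
        else:
            self.impatience += 1
        return self.patience > 0 and self.impatience >= self.patience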
Code example #9
File: conversations.py  Project: johnbent/cortx-1
    def save_conversations(
        cls,
        act_list,
        datapath,
        opt,
        save_keys='all',
        context_ids='context',
        self_chat=False,
        **kwargs,
    ):
        """
        Write Conversations to file from an act list.

        Conversations assume the act list is of the following form: a list of episodes,
        each of which is comprised of a list of act pairs (i.e. a list of dictionaries
        returned from one parley).
        """
        to_save = cls._get_path(datapath)

        context_ids = context_ids.split(',')
        # save conversations
        speakers = []
        with PathManager.open(to_save, 'w') as f:
            for ep in act_list:
                if not ep:
                    continue
                convo = {
                    'dialog': [],
                    'context': [],
                    'metadata_path': Metadata._get_path(to_save),
                }
                for act_pair in ep:
                    new_pair = []
                    for ex in act_pair:
                        ex_id = ex.get('id')
                        if ex_id in context_ids:
                            context = True
                        else:
                            context = False
                            if ex_id not in speakers:
                                speakers.append(ex_id)

                        # set turn
                        turn = {}
                        if save_keys != 'all':
                            save_keys_lst = save_keys.split(',')
                        else:
                            save_keys_lst = ex.keys()
                        for key in save_keys_lst:
                            turn[key] = ex.get(key, '')
                            if key == 'metrics':
                                turn[key] = dict_report(turn[key])
                        turn['id'] = ex_id
                        if not context:
                            new_pair.append(turn)
                        else:
                            convo['context'].append(turn)
                    if new_pair:
                        convo['dialog'].append(new_pair)
                json_convo = json.dumps(convo,
                                        default=lambda v: '<not serializable>')
                f.write(json_convo + '\n')
        logging.info(f'Conversations saved to file: {to_save}')

        # save metadata
        Metadata.save_metadata(to_save,
                               opt,
                               self_chat=self_chat,
                               speakers=speakers,
                               **kwargs)
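save_conversations writes one JSON object per episode, one per line, alongside a separate metadata file. A minimal sketch of reading such a file back; the path is a placeholder for whatever _get_path derives from the datapath.

import json

conversations_path = '/tmp/conversations.jsonl'  # placeholder path

episodes = []
with open(conversations_path) as f:
    for line in f:
        episodes.append(json.loads(line))

# Each episode carries 'dialog' (a list of turn pairs), 'context', and 'metadata_path'.
first_turn = episodes[0]['dialog'][0][0]
print(first_turn['id'])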
Code example #10
File: tod_world_script.py  Project: sagar-spkt/ParlAI
    def get_episode_report(goal, episode_metric):
        metrics_dict = dict_report(episode_metric.report())
        metrics_dict["goal"] = goal
        return metrics_dict
Code example #11
File: tod_world_script.py  Project: sagar-spkt/ParlAI
    def _save_outputs(self, opt, world, logger, episode_metrics):
        if is_distributed():  # flatten everything intelligently if need be
            world_report = aggregate_unnamed_reports(
                all_gather_list(world.report()))
            episode_metrics_unflattened = all_gather_list(episode_metrics)
            flattened = []
            for rank_elem in episode_metrics_unflattened:
                for elem in rank_elem:
                    flattened.append(elem)
            episode_metrics = flattened
        else:
            world_report = world.report()
        logging.report("Final report:\n" + nice_report(world_report))

        report = dict_report(world_report)

        def get_episode_report(goal, episode_metric):
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        report["tod_metrics"] = [
            get_episode_report(g, e) for g, e in episode_metrics
        ]

        if "report_filename" in opt and opt["report_filename"] is not None:
            if len(world_report) == 0:
                logging.warning("Report is empty; not saving report")

            report_fname = f"{opt['report_filename']}.json"
            # Save report
            if not is_distributed() or is_primary_worker():
                with PathManager.open(report_fname, "w") as f:
                    logging.info(f"Saving model report to {report_fname}")
                    json.dump({"opt": opt, "report": report}, f, indent=4)
                    f.write("\n")  # for jq

        if "world_logs" in opt and opt["world_logs"] is not None:
            if is_distributed():  # Save separately, then aggregate together
                rank = get_rank()
                log_outfile_part = (
                    f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl")
                logger.write(log_outfile_part,
                             world,
                             file_format=opt["save_format"])
                sync_object(None)
                if is_primary_worker():
                    log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                    log_outfile_metadata = (
                        f"{opt['world_logs']}_{opt['save_format']}.metadata")
                    with open(log_outfile, "w+") as outfile:
                        for rank in range(num_workers()):
                            log_outfile_part = (
                                f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
                            )
                            with open(log_outfile_part) as infile:
                                for line in infile:
                                    json_blob = json.loads(line.strip())
                                    if (
                                            len(json_blob["dialog"]) < 2
                                    ):  # skip when we don't have generation
                                        continue
                                    json_blob[
                                        "metadata_path"] = log_outfile_metadata
                                    outfile.write(json.dumps(json_blob))
                                    outfile.write("\n")
                            log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata"
                            if rank == 0:
                                copyfile(log_output_part_metadata,
                                         log_outfile_metadata)
                            os.remove(log_outfile_part)
                            os.remove(log_output_part_metadata)
            else:
                log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                logger.write(log_outfile,
                             world,
                             file_format=opt["save_format"])

        return report
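In the distributed branch above, every rank writes its own per-rank .jsonl shard and the primary worker then concatenates them, skipping episodes without a generation. An isolated sketch of that merge step with placeholder file names:

import json
import os

num_ranks = 4  # placeholder; the code above gets this from num_workers()
merged_path = '/tmp/world_logs_conversations.jsonl'  # placeholder name

with open(merged_path, 'w') as outfile:
    for rank in range(num_ranks):
        part_path = f'/tmp/world_logs_conversations_{rank}.jsonl'
        with open(part_path) as infile:
            for line in infile:
                blob = json.loads(line)
                if len(blob.get('dialog', [])) < 2:
                    continue  # skip episodes with no generation, as above
                outfile.write(json.dumps(blob) + '\n')
        os.remove(part_path)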