def _save_train_stats(self, suffix=None):
    fn = self.opt['model_file']
    if suffix:
        fn += suffix
    fn += '.trainstats'
    with open(fn, 'w') as f:
        json.dump(
            {
                'parleys': self.parleys,
                'train_time': self.train_time.time(),
                'total_epochs': (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                ),
                'impatience': self.impatience,
                'valid_reports': [self._safe_report(v) for v in self.valid_reports],
                'best_valid': self.best_valid,
            },
            f,
            indent=4,
        )
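The `.trainstats` file written above is plain JSON, so it can be read back directly when resuming. A minimal sketch of reloading it; the `load_train_stats` helper and the example path are hypothetical, not part of the code above:

import json

def load_train_stats(model_file, suffix=None):
    # Mirrors the filename logic of _save_train_stats above.
    fn = model_file
    if suffix:
        fn += suffix
    fn += '.trainstats'
    with open(fn) as f:
        return json.load(f)

# stats = load_train_stats('/tmp/mymodel')
# stats['total_epochs'], stats['impatience'], stats['best_valid']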
Example #2
def _save_train_stats(self, suffix=None):
    fn = self.opt['model_file']
    if suffix:
        fn += suffix
    fn += '.trainstats'
    if not self.opt['overwrite_checkpoints']:
        fn += f'_{self.checkpoint_counter}'
    with open(fn, 'w') as f:
        json.dump(
            {
                'parleys': self.parleys,
                'train_time': self.train_time.time(),
                'total_epochs': (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                ),
                'impatience': self.impatience,
                'valid_reports': self.valid_reports,
                'best_valid': self.best_valid,
            },
            f,
        )
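When overwrite_checkpoints is disabled, each save above gets a counter-suffixed stats file instead of clobbering the previous one. An illustration of the naming scheme with assumed values:

model_file, suffix = 'mymodel', '.checkpoint'  # hypothetical values
checkpoint_counter = 3
fn = model_file + suffix + '.trainstats' + f'_{checkpoint_counter}'
print(fn)  # mymodel.checkpoint.trainstats_3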
Example #3
    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Eval on validation/test data.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc.
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        :param int max_exs:
            limits the number of examples if max_exs > 0
        :param str extra_log_suffix:
            suffix appended to the metrics log filename
        """

        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        for v_world in valid_worlds:
            task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report
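For intuition, the `max_exs_per_worker` budget above splits the global example cap evenly across every (world, worker) pair. A worked sketch with assumed values:

max_exs = 1000        # global validation cap (-1 means unlimited downstream)
n_valid_worlds = 4    # len(valid_worlds)
n_workers = 2         # num_workers() in distributed mode

max_exs_per_worker = max_exs / (n_valid_worlds * n_workers)
print(max_exs_per_worker)  # 125.0 examples per world per worker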
Example #4
    def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
        """
        Eval on validation/test data.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc.
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        :param int max_exs:
            limits the number of examples if max_exs > 0
        """

        print('[ running eval: ' + datatype + ' ]')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        for v_world in valid_worlds:
            task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(named_reports)
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:{nice_report(report)}'
        print(f'[ eval completed in {timer.time():.2f}s ]')
        print(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics; the context manager closes the file
            with open(opt['model_file'] + '.' + datatype, 'a+') as f:
                f.write(f'{metrics}\n')

        return report
Example #5
    def _get_time(self, world: World) -> Tuple[float, float, float, float]:
        """
        Return train, log, validate, and save timing.

        If relying on the time for validation/logging/max train time purposes,
        we sync and return the primary worker's time.

        Otherwise, it's not especially relevant what we do here.

        **SIDE EFFECT**: Updates _total_epochs trained.

        :param world:
            current running world

        :return (train, log, valid, save):
            return time for each of train, log, validation, and save
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time
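`sync_object` above presumably broadcasts the primary worker's values so every worker acts on identical timings; the exact helper is specific to this codebase. A minimal sketch of the same idea using plain torch.distributed (an illustration under that assumption, not the implementation above):

import torch.distributed as dist

def sync_from_primary(value):
    # Single-process runs have nothing to sync.
    if not (dist.is_available() and dist.is_initialized()):
        return value
    holder = [value]
    # Rank 0's object overwrites every other rank's copy.
    dist.broadcast_object_list(holder, src=0)
    return holder[0]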
Example #6
def __init__(self, opt: Opt, shared=None):
    if not hasattr(self, "fold"):
        self.fold = DatatypeHelper.fold(opt["datatype"])
    super().__init__(opt, shared)
    self.epochDone = False
    self.batchsize = opt.get("batchsize", 1)
    self.max_episodes = len(self.episodes)
    if opt.get("num_episodes", 0) > 0:
        self.max_episodes = min(self.max_episodes, opt.get("num_episodes"))
    self.episode_idx = opt.get("batchindex", 0)
    self._setup_next_episode()
    self.round_idx = 0  # for some downstream utt + sysUttAndApiCallAgents.
    if is_distributed():  # episodes must be sharded across workers manually
        rank = get_rank()
        chunk_size = ceil(self.max_episodes / num_workers())
        self.episode_idx += rank * chunk_size
        self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size)
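The distributed branch above shards episodes by rank: each worker starts at rank * chunk_size and stops before (rank + 1) * chunk_size. A worked example with assumed sizes:

from math import ceil

max_episodes, n_workers = 10, 3
chunk_size = ceil(max_episodes / n_workers)  # 4
for rank in range(n_workers):
    start = rank * chunk_size
    end = min(max_episodes, (rank + 1) * chunk_size)
    print(f'rank {rank}: episodes [{start}, {end})')
# rank 0: episodes [0, 4)
# rank 1: episodes [4, 8)
# rank 2: episodes [8, 10)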
Example #7
    def __init__(self, opt: Opt, shared=None):
        super().__init__(opt, shared)
        self.epochDone = False
        if shared is None:
            self.episodes = self._setup_single_goal_episodes()
        else:
            # Handled fine in _TodDataDumpAgent
            pass

        self.max_episodes = len(self.episodes)
        if opt.get("num_episodes", 0) > 0:
            self.max_episodes = min(self.max_episodes, opt.get("num_episodes"))
        if is_distributed():  # episodes must be sharded across workers manually
            rank = get_rank()
            chunk_size = ceil(self.max_episodes / num_workers())
            self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size)

        self._setup_next_episode()
Example #8
def _save_train_stats(self, suffix=None):
    fn = self.opt['model_file']
    if suffix:
        fn += suffix
    fn += '.trainstats'
    with open(fn, 'w') as f:
        json.dump(
            {
                'train_time': self.train_time.time(),
                'total_epochs': (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                ),
                'impatience': self.impatience,
                'valid_reports': self.valid_reports,
            },
            f,
        )
Example #9
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        count = 0
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not "
                                "supported for distributed mode")
                        break
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
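The validation trigger in the loop above fires on either a wall-clock or an epoch-count condition. A distilled view with assumed values:

validate_time = 65.0            # seconds since the last validation
val_every_n_secs = 60.0
total_epochs, last_valid_epoch = 2.5, 2.0
val_every_n_epochs = 0.5

should_validate = (
    validate_time > val_every_n_secs
    or total_epochs - last_valid_epoch >= val_every_n_epochs
)
print(should_validate)  # True; either condition alone is sufficient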
Example #10
    def _save_outputs(self, opt, world, logger, episode_metrics):
        if is_distributed():  # flatten everything intelligently if need be
            world_report = aggregate_unnamed_reports(all_gather_list(world.report()))
            episode_metrics_unflattened = all_gather_list(episode_metrics)
            flattened = []
            for rank_elem in episode_metrics_unflattened:
                for elem in rank_elem:
                    flattened.append(elem)
            episode_metrics = flattened
        else:
            world_report = world.report()
        logging.report("Final report:\n" + nice_report(world_report))

        report = dict_report(world_report)

        def get_episode_report(goal, episode_metric):
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        report["tod_metrics"] = [
            get_episode_report(g, e) for g, e in episode_metrics
        ]

        if "report_filename" in opt and opt["report_filename"] is not None:
            if len(world_report) == 0:
                logging.warning("Report is empty; not saving report")

            report_fname = f"{opt['report_filename']}.json"
            # Save report
            if not is_distributed() or is_primary_worker():
                with PathManager.open(report_fname, "w") as f:
                    logging.info(f"Saving model report to {report_fname}")
                    json.dump({"opt": opt, "report": report}, f, indent=4)
                    f.write("\n")  # for jq

        if "world_logs" in opt and opt["world_logs"] is not None:
            if is_distributed():  # Save separately, then aggregate together
                rank = get_rank()
                log_outfile_part = f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
                logger.write(log_outfile_part, world, file_format=opt["save_format"])
                sync_object(None)
                if is_primary_worker():
                    log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                    log_outfile_metadata = f"{opt['world_logs']}_{opt['save_format']}.metadata"
                    with open(log_outfile, "w+") as outfile:
                        for rank in range(num_workers()):
                            log_outfile_part = (
                                f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
                            )
                            with open(log_outfile_part) as infile:
                                for line in infile:
                                    json_blob = json.loads(line.strip())
                                    # skip when we don't have generation
                                    if len(json_blob["dialog"]) < 2:
                                        continue
                                    json_blob["metadata_path"] = log_outfile_metadata
                                    outfile.write(json.dumps(json_blob))
                                    outfile.write("\n")
                            log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata"
                            if rank == 0:
                                copyfile(log_output_part_metadata, log_outfile_metadata)
                            os.remove(log_outfile_part)
                            os.remove(log_output_part_metadata)
            else:
                log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                logger.write(log_outfile, world, file_format=opt["save_format"])

        return report
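The merged world-log written above is JSON Lines: one episode per line, each carrying a metadata_path back-pointer to the shared metadata file. A small sketch of consuming it (the filename is an assumed instance of the f-string pattern above):

import json

log_outfile = 'mymodel_logs_conversations.jsonl'  # hypothetical path
with open(log_outfile) as f:
    for line in f:
        episode = json.loads(line)
        print(len(episode['dialog']), episode.get('metadata_path'))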
Example #11
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs

                    # --------------- change by hengyicai -------------------------
                    if opt.get('run_test_after_validation', False):
                        # run evaluation on the test data as well
                        test_opt = copy.deepcopy(self.opt)
                        test_opt['display_examples'] = False
                        test_opt['report_freq'] = 0
                        if self.test_worlds is None:
                            # we need to load the world now
                            self.test_worlds = _maybe_load_eval_worlds(
                                self.agent, test_opt, 'test')
                        run_eval(self.test_worlds,
                                 test_opt,
                                 'test',
                                 -1,
                                 write_log=True)
                    # --------------- change by hengyicai -------------------------
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        # --------------- change by hengyicai -------------------------
        last_model = (opt.get('model_file') or '') + '.checkpoint'
        if os.path.isfile(last_model):
            print(
                '[ Conducting evaluations on valid and test data using the last model. ]'
            )
            last_model_opt = copy.deepcopy(opt)
            last_model_opt['model_file'] = last_model
            last_agent = create_agent(last_model_opt)
            valid_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt, 'valid')
            max_exs = (
                last_model_opt['validation_max_exs']
                if last_model_opt.get('short_final_eval')
                else -1
            )
            run_eval(valid_worlds, last_model_opt, 'valid', max_exs, write_log=True)
            test_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt, 'test')
            run_eval(test_worlds, last_model_opt, 'test', max_exs, write_log=True)
            if valid_worlds:
                for valid_world in valid_worlds:
                    valid_world.shutdown()
            if test_worlds:
                for test_world in test_worlds:
                    test_world.shutdown()
        # --------------- change by hengyicai -------------------------
        print_announcements(opt)

        return v_report, t_report