Example #1
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """

    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'
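        # e.g. '-vmt ppl' forces the mode to 'min', '-vmt f1' forces 'max',
        # and an unrecognized metric with no explicit mode falls back to 'max'.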

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'
                    ):
                        with PathManager.open(
                            opt['model_file'] + ".best_valid", 'r'
                        ) as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix

        if not is_primary_worker():
            # never do IO as a non-primary worker
            if hasattr(self.agent, 'save_nonprimary'):
                self.agent.save_nonprimary(fn)
            return

        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with PathManager.open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'train_steps': self._train_steps,
                    'total_epochs': self._total_epochs,
                    'train_reports': self.train_reports,
                    'valid_reports': self.valid_reports,
                    'best_valid': self.best_valid,
                    'impatience': self.impatience,
                    'final_valid_report': dict_report(self.final_valid_report),
                    'final_test_report': dict_report(self.final_test_report),
                    'final_extra_valid_report': dict_report(
                        self.final_extra_valid_report
                    ),
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = dict_report(valid_report)
        v['train_time'] = self.train_time.time()
        v['parleys'] = self.parleys
        v['train_steps'] = self._train_steps
        v['total_exs'] = self._total_exs
        v['total_epochs'] = self._total_epochs
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric_mode'] == 'max'
                and self.best_valid >= opt['validation_cutoff']
            ) or (
                opt['validation_metric_mode'] == 'min'
                and self.best_valid <= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False
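
    # Early-stopping example (illustrative): with '--validation-patience 10',
    # impatience increments on each validation that fails to beat best_valid,
    # and validate() returns True (stop) once ten such runs occur in a row.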

    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Eval on validation/test data.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc.
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        :param int max_exs:
            limits the number of examples if max_exs > 0
        """

        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        is_multitask = len(valid_worlds) > 1
        for index, v_world in enumerate(valid_worlds):
            if opt.get('evaltask'):
                task = opt['evaltask'].split(',')[index]
            else:
                task = opt['task'].split(',')[index]
            task_report = self._run_single_eval(
                opt, v_world, max_exs_per_worker, datatype, is_multitask, task
            )
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report

    def _run_final_extra_eval(self, opt):
        final_valid_opt = copy.deepcopy(opt)
        final_valid_opt_raw = Opt.load_init(opt['final_extra_opt'])
        final_datatype = final_valid_opt_raw["datatype"]
        for k, v in final_valid_opt_raw.items():
            final_valid_opt[k] = v
        final_max_exs = (
            final_valid_opt['validation_max_exs']
            if final_valid_opt.get('short_final_eval')
            else -1
        )
        final_valid_world = load_eval_worlds(
            self.agent, final_valid_opt, final_datatype
        )
        final_valid_report = self._run_eval(
            final_valid_world,
            final_valid_opt,
            final_datatype,
            final_max_exs,
            write_log=True,
            extra_log_suffix="_extra",
        )
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final(final_datatype, final_valid_report)

        return final_valid_report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(
        self, epochs_completed: float, time_elapsed: float, steps_taken: int
    ):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_train_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        max_train_steps = self.opt.get('max_train_steps', -1)
        if max_train_steps > 0 and steps_taken > 0:
            steps_progress = steps_taken / max_train_steps
            eta = (1 - steps_progress) * time_elapsed / steps_progress

        return eta
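
    # Worked example (illustrative numbers): with num_epochs=10,
    # epochs_completed=2.5, and time_elapsed=600s, epoch_progress is 0.25, so
    # eta = (1 - 0.25) * 600 / 0.25 = 1800s. A max_train_time budget can cap
    # this, and when max_train_steps is set the steps-based estimate, computed
    # last, replaces the earlier value.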

    def _get_time(self, world: World) -> Tuple[float, float, float, float]:
        """
        Return train, log, validate, and save timing.

        If relying on the time for validation/logging/max train time purposes,
        we sync and return primary worker's time.

        Otherwise, it's not super relevant what we do here.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world:
            current running world

        :return (train, log, valid, save):
            return time for each of train, log, validation, and save
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # log the fractional epoch count
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report
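
    # An illustrative (fabricated) log line produced by log():
    #   time:120s total_exs:4096 total_steps:512 epochs:0.25 time_left:3480s
    # followed by the nice_report() metrics table.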

    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if (
                    log_time > self.log_every_n_secs
                    or self._last_log_steps >= self.log_every_n_steps
                ):
                    yield self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                    or self._train_steps - self._last_valid_steps
                    >= self.val_every_n_steps
                ):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if save_time > self.save_every_n_secs and opt.get('model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the
        # primary worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        for _train_log in self.train_steps():
            # we've already done what we need in these
            pass

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        self.final_valid_report = self._run_eval(
            valid_worlds, opt, 'valid', max_exs, write_log=True
        )
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        self.final_test_report = self._run_eval(
            test_worlds, opt, 'test', max_exs, write_log=True
        )

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', self.final_valid_report)
            self.wb_logger.log_final('test', self.final_test_report)
            self.wb_logger.finish()

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        if opt['final_extra_opt'] != '':
            self.final_extra_valid_report = self._run_final_extra_eval(opt)

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.finish()

        self._save_train_stats()

        return self.final_valid_report, self.final_test_report
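
A minimal usage sketch for the loop above, assuming ParlAI's setup_args helper
from the same module (the task, model, and file paths are illustrative only):

from parlai.scripts.train_model import setup_args, TrainLoop

if __name__ == '__main__':
    parser = setup_args()
    opt = parser.parse_args(
        ['--task', 'babi:task10k:1', '--model', 'seq2seq',
         '--model-file', '/tmp/babi_s2s', '--num-epochs', '5']
    )
    valid_report, test_report = TrainLoop(opt).train()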
Example #2
class World(object):
    """
    Empty parent providing null definitions of API functions for Worlds.

    All children can override these to provide more detailed functionality.
    """
    def __init__(self, opt: Opt, agents=None, shared=None):
        self.id = opt['task']
        self.opt = copy.deepcopy(opt)
        if shared:
            # Create agents based on shared data.
            self.agents = create_agents_from_shared(shared['agents'])
        else:
            # Add passed in agents to world directly.
            self.agents = agents
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time = Timer()

    def parley(self):
        """
        Perform one step of actions for the agents in the world.

        This is empty in the base class.
        """
        # TODO: mark as abstract?
        pass

    def getID(self):
        """
        Return the name of the world, typically the task the world encodes.
        """
        return self.id

    def display(self):
        """
        Return a string describing the current state of the world.

        Useful for monitoring and debugging. By default, display the messages between
        the agents.
        """
        if not hasattr(self, 'acts'):
            return ''
        return display_messages(
            self.acts,
            ignore_fields=self.opt.get('display_ignore_fields', ''),
            prettify=self.opt.get('display_prettify', False),
            max_len=self.opt.get('max_display_len', 1000),
        )

    def episode_done(self):
        """
        Whether the episode is done or not.
        """
        return False

    def epoch_done(self):
        """
        Whether the epoch is done or not.

        Not all worlds have the notion of an epoch, but this is useful for fixed
        training, validation or test sets.
        """
        return False

    def share(self):
        """
        Share the world.
        """
        shared_data = {}
        shared_data['world_class'] = type(self)
        shared_data['opt'] = self.opt
        shared_data['agents'] = self._share_agents()
        return shared_data

    def _share_agents(self):
        """
        Create shared data for agents.

        Allows other classes to create the same agents without duplicating the data
        (i.e. sharing parameters).
        """
        if not hasattr(self, 'agents'):
            return None
        shared_agents = [a.share() for a in self.agents]
        return shared_agents
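
    # Typical sharing pattern (sketch): a cloned world is rebuilt from the
    # shared data without duplicating agent parameters, e.g.:
    #
    #   shared = world.share()
    #   clone = shared['world_class'](shared['opt'], shared=shared)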

    def get_agents(self):
        """
        Return the list of agents.
        """
        return self.agents

    def get_task_agent(self):
        """
        Return task agent, if applicable.
        """
        raise NotImplementedError('Implement in subworld')

    def get_acts(self):
        """
        Return the last act of each agent.
        """
        return self.acts

    def get_time(self):
        """
        Return total training time.
        """
        return self.time.time()

    def get_total_exs(self):
        """
        Return the total number of examples seen by the world.
        """
        return self.total_exs

    def get_total_epochs(self):
        """
        Return the total number of epochs on which the world has trained.
        """
        return self.total_epochs

    def __enter__(self):
        """
        Empty enter provided for use with ``with`` statement.

        e.g.:

        .. code-block:: python

            with World() as world:
                for _ in range(10):
                    world.parley()
        """
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        After ``with`` statement, call shutdown.
        """
        self.shutdown()
        return False

    def num_examples(self):
        """
        Return the number of examples.

        Always 0 in the abstract world.
        """
        # TODO: mark as abstract?
        return 0

    def num_episodes(self):
        """
        Return the number of episodes.

        Always 0 in the abstract world.
        """
        # TODO: mark as abstract?
        return 0

    def reset(self):
        """
        Reset all agents in the world, and world statistics.
        """
        for a in self.agents:
            a.reset()
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time.reset()

    def reset_metrics(self):
        """
        Reset metrics for all agents.
        """
        for a in self.agents:
            a.reset_metrics()

    def shutdown(self):
        """
        Perform any cleanup, if appropriate.
        """
        pass

    def update_counters(self):
        """
        Update how many epochs have completed.
        """
        self.total_parleys += 1
        if self.max_exs is None:
            if 'num_epochs' in self.opt and self.opt['num_epochs'] > 0:
                if self.num_examples():
                    self.max_exs = self.num_examples() * self.opt['num_epochs']
                else:
                    self.max_exs = -1
            else:
                self.max_exs = -1
        # when we know the size of the data
        if self.max_exs > 0 or self.num_examples():
            self.total_epochs = (self.total_parleys *
                                 self.opt.get('batchsize', 1) /
                                 self.num_examples())
        # when we do not know the size of the data
        else:
            if self.epoch_done():
                self.total_epochs += 1
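
The epoch bookkeeping above can be exercised directly. A minimal sketch, where
ToyWorld and its fixed dataset size are illustrative rather than part of the
API:

class ToyWorld(World):
    def num_examples(self):
        return 100  # pretend the task holds 100 examples

world = ToyWorld({'task': 'toy', 'num_epochs': 2, 'batchsize': 4})
for _ in range(50):
    world.update_counters()
print(world.get_total_epochs())  # 50 parleys * batchsize 4 / 100 exs = 2.0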
Example #3
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.')
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))
        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    trainstats_suffix):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and os.path.isfile(
                            opt['model_file'] + '.best_valid'):
                        with open(opt['model_file'] + ".best_valid", 'r') as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report):
        return {
            k: v.value() if isinstance(v, Metric) else v
            for k, v in report.items()
        }

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'parleys':
                    self.parleys,
                    'train_time':
                    self.train_time.time(),
                    'total_epochs':
                    (self._preempted_epochs +
                     num_workers() * self.world.get_total_epochs()),
                    'impatience':
                    self.impatience,
                    'valid_reports':
                    [self._safe_report(v) for v in self.valid_reports],
                    'best_valid':
                    self.best_valid,
                },
                f,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        # TODO(MW): replace sync_object with self._sync_metrics. You'll need some
        # logic to handle 'validation_max_exs' properly
        valid_report = run_eval(self.valid_worlds, opt, 'valid',
                                opt['validation_max_exs'])
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'],
                new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else '',
            ))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_train_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # log the fractional epoch count
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs,
                                      self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs), nice_report(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        count = 0
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not "
                                "supported for distributed mode")
                        break
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
Example #4
class TrainLoop:
    """TrainLoop contains the core training loop logic."""
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.')
        if 'dict_file' in opt:
            # If data was built via the pytorch data teacher, load the prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))
        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    trainstats_suffix):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])

        if opt['tensorboard_log'] is True:
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """Save the model to disk, possibly with a suffix."""
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'train_time':
                    self.train_time.time(),
                    'total_epochs':
                    (self._preempted_epochs +
                     num_workers() * self.world.get_total_epochs()),
                    'impatience':
                    self.impatience,
                    'valid_reports':
                    self.valid_reports,
                },
                f,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = _maybe_load_eval_worlds(
                self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(
            run_eval(self.valid_worlds, opt, 'valid',
                     opt['validation_max_exs']))
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if 'tasks' in valid_report and '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'],
                new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else '',
            ))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                _save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False
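    # Note on the subtask selection above: when multitasking, passing e.g.
    # `-vmt squad/f1` (an illustrative subtask/metric pair) makes validate()
    # track valid_report['tasks']['squad']['f1'] rather than a top-level key.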

    def _average_dicts(self, all_versions):
        # instead of a list of dicts with matching keys, build a dict of
        # lists keyed by metric name so each key can be reduced across workers
        to_reduce = {}
        for d in all_versions:
            for k, v in d.items():
                to_reduce.setdefault(k, []).append(v)
        # now perform the reduction
        finalized = {}
        for k, values in to_reduce.items():
            if k == 'exs' or k == 'total_skipped_batches':
                # sum across workers
                finalized[k] = np.sum(values)
            elif isinstance(values[0], dict):
                # do the same procedure recursively
                finalized[k] = self._average_dicts(values)
            elif isinstance(values[0], str):
                finalized[k] = values[0]
            else:
                # all other cases, take the mean across the workers
                finalized[k] = np.mean(values)
                if all(isinstance(v, int) for v in values):
                    finalized[k] = int(finalized[k])
        return finalized
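    # Illustrative reduction across two workers:
    # _average_dicts([{'exs': 10, 'loss': 2.0}, {'exs': 12, 'loss': 3.0}])
    # -> {'exs': 22, 'loss': 2.5}  ('exs' is summed, other numerics averaged)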

    def _cleanup_inaccurate_metrics(self, metrics):
        """
        Flag inaccurate multiworld metrics.

        When training in multitask mode, agent-level metrics may be shown, but
        they are actually averages that are not broken down per world. This
        method adds a warning to the report.

        Issue: https://github.com/facebookresearch/ParlAI/issues/1750
        """
        # TODO: fix the root issue
        if 'tasks' in metrics:
            metrics['warning'] = (
                'agent level metrics (e.g. loss, mean_loss, ppl) '
                'are averaged over tasks'
            )

    def _sync_training_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining
        metrics are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return self._average_dicts(all_versions)

    def _nice_format(self, dictionary):
        rounded = {}
        for k, v in dictionary.items():
            if isinstance(v, dict):
                rounded[k] = self._nice_format(v)
            elif isinstance(v, float):
                rounded[k] = round_sigfigs(v, 4)
            else:
                rounded[k] = v
        return rounded
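    # Illustrative: _nice_format({'ppl': 12.34567, 'lr': 0.00025}) rounds
    # floats to 4 significant figures -> {'ppl': 12.35, 'lr': 0.00025}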

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta
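    # Worked example (illustrative numbers): with num_epochs=10 and 2.0 epochs
    # done in 600s, epoch_progress = 0.2, so eta = (1 - 0.2) * 600 / 0.2 =
    # 2400s; if max_training_time=1800 were also set, the tighter bound
    # 1800 - 600 = 1200s would win.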

    def log(self):
        """Output a training log entry."""
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        self._cleanup_inaccurate_metrics(train_report)
        train_report = self._sync_training_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # only if the world reports an epoch count
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs,
                                      self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs),
                                 self._nice_format(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
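
For reference, a loop like the one above is usually driven from a small
script. A minimal sketch, assuming a ParlAI-style setup_args() parser from the
surrounding train_model script (flag values here are purely illustrative):

def main():
    parser = setup_args()
    opt = parser.parse_args([
        '--task', 'babi:task10k:1',
        '--model', 'seq2seq',
        '--model-file', '/tmp/model',
        '--num-epochs', '10',
        '--validation-every-n-secs', '300',
    ])
    # TrainLoop.train() returns (validation_report, test_report)
    valid_report, test_report = TrainLoop(opt).train()

if __name__ == '__main__':
    main()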
Example #5
def eval_ppl(opt, build_dict=None, dict_file=None):
    """
    Evaluates the perplexity of a model.

    This uses a dictionary which implements the following functions:
    - tokenize(text): splits string up into list of tokens
    - __contains__(text): checks whether dictionary contains a token
    - keys(): returns an iterator over all tokens in the dictionary

    :param opt: option dict
    :param build_dict: function which returns a dictionary class implementing
        the functions above.
    :param dict_file: file used when loading the dictionary class set via the
        "dictionary_class" argument (defaults to
        parlai.core.dict:DictionaryAgent).

    Either build_dict or dict_file must be set (both default to None) to
    determine the dictionary used for the evaluation.
    """
    if not build_dict and not dict_file:
        raise RuntimeError('eval_ppl script either needs a dictionary build '
                           'function or a dictionary file.')

    if build_dict:
        dict_agent = build_dict()
    else:
        dict_opt = copy.deepcopy(opt)
        dict_opt['model'] = dict_opt.get('dictionary_class',
                                         'parlai.core.dict:DictionaryAgent')
        dict_opt['model_file'] = dict_file
        if 'override' in dict_opt:
            del dict_opt['override']
        dict_agent = create_agent(dict_opt, requireModelExists=True)

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent],
                        default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['exs'] / world.num_examples() * 100, 3),
                report,
            ))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf perplexity immediately.')
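
The dictionary interface that eval_ppl documents above is small enough to
sketch. WhitespaceDict below is a hypothetical stand-in (not a ParlAI class)
showing the three required methods:

class WhitespaceDict:
    """Hypothetical minimal dictionary for eval_ppl's build_dict hook."""

    def __init__(self, vocab):
        self._vocab = set(vocab)

    def tokenize(self, text):
        # splits a string up into a list of tokens
        return text.split()

    def __contains__(self, token):
        # checks whether the dictionary contains a token
        return token in self._vocab

    def keys(self):
        # returns an iterator over all tokens in the dictionary
        return iter(self._vocab)

# usage sketch: eval_ppl(opt, build_dict=lambda: WhitespaceDict(vocab_tokens))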
Example #6
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.')
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            elif opt['dict_file'] is None and opt.get('teacher_model_file'):
                logging.info("using teacher's dictionary...")
                opt['dict_file'] = opt['teacher_model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # Create teacher model
        teacher_opt = {
            'datapath': 'blended_skill_talk',
            'model_file': opt['teacher_model_file'],
            # some custom args
            'tie_layers': False,
            'enable_checkpointing': False,
        }
        self.teacher_agent = create_agent(teacher_opt)
        self.agent.set_teacher_agent(self.teacher_agent)
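        # note: set_teacher_agent() is assumed to be a custom distillation
        # hook on this agent; it is not part of the stock ParlAI agent API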
        print(self.agent.model)
        print(self.teacher_agent.model)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))
        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    trainstats_suffix):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                elif os.path.isfile(opt['model_file'] + '.best_valid'):
                    # old method: best valid metric stored in a separate file
                    with open(opt['model_file'] + '.best_valid', 'r') as f:
                        self.best_valid = float(f.readline())

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return

        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report: Dict[str, Metric]):
        return {
            k: v.value() if isinstance(v, Metric) else v
            for k, v in report.items()
        }
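    # Illustrative: a report like {'ppl': AverageMetric(...)} becomes
    # {'ppl': 12.3} (plain JSON-serializable values) before json.dump below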

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'total_epochs': (
                        self._preempted_epochs
                        + num_workers() * self.world.get_total_epochs()
                    ),
                    'impatience': self.impatience,
                    'valid_reports': [
                        self._safe_report(v) for v in self.valid_reports
                    ],
                    'best_valid': self.best_valid,
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(self.valid_worlds, opt, 'valid',
                                      opt['validation_max_exs'])
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            logging.info(
                f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            logging.success('new best {}: {:.4g}{}'.format(
                opt['validation_metric'],
                new_valid,
                ' (previous best was {:.4g})'.format(self.best_valid)
                if self.best_valid is not None else '',
            ))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report('did not beat best {}: {} impatience: {}'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs):
        # run evaluation on a single world
        valid_world.reset()

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        valid_report = valid_world.report()
        valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(self,
                  valid_worlds,
                  opt,
                  datatype,
                  max_exs=-1,
                  write_log=False):
        """
        Eval on validation/test data.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc.
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param int max_exs:
            limits the number of examples if max_exs > 0
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        """

        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
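        # e.g. max_exs=1000 with 2 valid worlds and 4 distributed workers
        # gives 1000 / (2 * 4) = 125 examples per world on each worker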
        for v_world in valid_worlds:
            task_report = self._run_single_eval(opt, v_world,
                                                max_exs_per_worker)
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(named_reports,
                                         micro_average=self.opt.get(
                                             'aggregate_micro', False))
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # write out metrics
            with open(opt['model_file'] + '.' + datatype, 'a+') as f:
                f.write(f'{metrics}\n')

        return report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')

        if self._total_epochs >= 0:
            # only if the world reports an epoch count
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(self._total_epochs,
                                      self.train_time.time())
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs()))
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for distributed mode"
                            )
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model and then try
        # to reload it. This sync_object ensures all workers wait for the
        # primary worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = self._run_eval(valid_worlds,
                                  opt,
                                  'valid',
                                  max_exs,
                                  write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds,
                                  opt,
                                  'test',
                                  max_exs,
                                  write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report