class TrainLoop:
    """
    TrainLoop contains the core training loop logic.

    The loop owns the agent and the training world, keeps the timers and
    counters that decide when to log, validate, checkpoint and stop, and
    restores those counters from a ``.trainstats`` file when one exists.
    """

    def __init__(self, opt):
        """
        Build the agent and world and restore any saved training statistics.

        :param opt: the options for the run. A ``ParlaiParser`` is still
            accepted for backwards compatibility and is parsed on the spot.
        :raises RuntimeError: if neither ``model_file`` nor ``dict_file`` is set.
        """
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print('[ Deprecated Warning: TrainLoop should be passed opt not Parser ]')
            opt = opt.parse_args()

        # Possibly load from checkpoint
        # we might load training statistics from here
        trainstats_suffix = '.trainstats'
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'

        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        print('[ training... ]')
        self.parleys = 0

        def _limit(key):
            # a non-positive option value means "no limit"
            value = opt[key]
            return value if value > 0 else float('inf')

        self.max_num_epochs = _limit('num_epochs')
        self.max_train_time = _limit('max_train_time')
        self.log_every_n_secs = _limit('log_every_n_secs')
        self.val_every_n_secs = _limit('validation_every_n_secs')
        self.save_every_n_secs = _limit('save_every_n_secs')
        self.val_every_n_epochs = _limit('validation_every_n_epochs')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        # +1 when a larger metric is better, -1 when a smaller one is
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None
        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
            self.parleys = obj.get('parleys', 0)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self.impatience = obj.get('impatience', 0)
            self.valid_reports = obj.get('valid_reports', [])
            if 'best_valid' in obj:
                self.best_valid = obj['best_valid']
            elif os.path.isfile(opt['model_file'] + '.best_valid'):
                # old method: the best value lived in its own one-line file.
                # The `with` block closes the file; no explicit close needed.
                with open(opt['model_file'] + ".best_valid", 'r') as f:
                    self.best_valid = float(f.readline())

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.

        :param suffix: optional string appended to ``model_file``.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return

        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving: on interrupt we simply
            # start the save over
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report):
        """
        Return a copy of ``report`` with every ``Metric`` replaced by its value.
        """
        return {
            k: v.value() if isinstance(v, Metric) else v for k, v in report.items()
        }

    def _save_train_stats(self, suffix=None):
        """
        Write the training counters to ``<model_file>[suffix].trainstats`` as JSON.

        :param suffix: optional string inserted before ``.trainstats``.
        """
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        stats = {
            'parleys': self.parleys,
            'train_time': self.train_time.time(),
            'total_epochs': (
                self._preempted_epochs
                + num_workers() * self.world.get_total_epochs()
            ),
            'impatience': self.impatience,
            'valid_reports': [self._safe_report(v) for v in self.valid_reports],
            'best_valid': self.best_valid,
        }
        with open(fn, 'w') as f:
            json.dump(stats, f)

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        # TODO(MW): replace sync_object with self._sync_metrics. You'll need some
        # logic to handle 'validation_max_exs' properly
        valid_report = run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)

        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()

        # saving
        if (
            opt.get('model_file')
            and opt.get('save_after_valid')
            and is_primary_worker()
        ):
            print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]
        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far; multiplying by
        # valid_optim turns "min" metrics into a maximisation problem
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            print(
                '[ new best {}: {}{} ]'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric'] == 'accuracy'
                and self.best_valid >= opt['validation_cutoff']
            ):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print(
                '[ did not beat best {}: {} impatience: {} ]'.format(
                    opt['validation_metric'],
                    round(self.best_valid, 4),
                    self.impatience,
                )
            )
        self.validate_time.reset()

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        Outside distributed mode the metrics are returned untouched; otherwise
        every worker's report is gathered and aggregated into one.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # estimate from the epoch budget, extrapolating the pace so far
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        # estimate from the wall-clock budget. The option is `max_train_time`
        # (that is what __init__ reads); the old `max_training_time` spelling
        # is kept as a fallback for any opt that happens to carry it.
        max_train_time = self.opt.get(
            'max_train_time', self.opt.get('max_training_time', -1)
        )
        if max_train_time > 0:
            time_left = max_train_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.

        Prints the elapsed time, example/epoch counters, the ETA when one can
        be computed, and the synced training report; resets the world metrics
        and the log timer.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []

        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs, self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs), nice_report(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode"
                        )
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (
                        self.train_time.time(),
                        self.log_time.time(),
                        self.validate_time.time(),
                    )
                )

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time
                        )
                    )
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                ):
                    try:
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not "
                                "supported for distributed mode"
                            )
                        break
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (
                    self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file')
                    and is_primary_worker()
                ):
                    # message now carries the closing bracket every other
                    # status line in this class has
                    print(
                        "[ saving model checkpoint: {}.checkpoint ]".format(
                            opt['model_file']
                        )
                    )
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        # final evaluation on the valid and test folds
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.

    The loop owns the agent and the training world, keeps the timers and
    counters that decide when to log, validate, checkpoint and stop, and
    restores those counters from a ``.trainstats`` file when one exists.
    """

    def __init__(self, opt):
        """
        Build the agent and world and restore any saved training statistics.

        :param opt: the options for the run.
        :raises RuntimeError: if neither ``model_file`` nor ``dict_file`` is set.
        """
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        # Possibly load from checkpoint
        # we might load training statistics from here
        trainstats_suffix = '.trainstats'
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'

        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        # +1 when a larger metric is better, -1 when a smaller one is
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None
        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
            self.parleys = obj.get('parleys', 0)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self._train_steps = obj.get('train_steps', 0)
            self.impatience = obj.get('impatience', 0)
            self.valid_reports = obj.get('valid_reports', [])
            if self.valid_reports:
                last_valid = self.valid_reports[-1]
                self.last_valid_epoch = last_valid.get('total_epochs', 0.0)
                # restore the step marker too, so the step-based validation
                # schedule resumes where it left off, like the epoch one
                self._last_valid_steps = last_valid.get('train_steps', 0)
            self.train_reports = obj.get('train_reports', [])
            if 'best_valid' in obj:
                self.best_valid = obj['best_valid']
            elif PathManager.exists(opt['model_file'] + '.best_valid'):
                # old method: the best value lived in its own one-line file.
                # The `with` block closes the file; no explicit close needed.
                with PathManager.open(opt['model_file'] + ".best_valid", 'r') as f:
                    self.best_valid = float(f.readline())

        # These counters are refreshed after every parley, but validate(),
        # log() and _save_train_stats() read them, so give them a value that is
        # valid even if training stops before the first refresh.
        self._total_epochs = self._preempted_epochs
        self._total_exs = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.

        :param suffix: optional string appended to ``model_file``.
        """
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix

        if not is_primary_worker():
            # never do IO as a non-primary worker; agents that expose
            # `save_nonprimary` are given the chance to take part
            if hasattr(self.agent, 'save_nonprimary'):
                self.agent.save_nonprimary(fn)
            return

        while True:
            # don't ever let a ctrl-c interrupt saving: on interrupt we simply
            # start the save over
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        """
        Write the training counters and reports to a ``.trainstats`` JSON file.

        :param suffix: optional string inserted before ``.trainstats``.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
        stats = {
            'parleys': self.parleys,
            'train_time': self.train_time.time(),
            'train_steps': self._train_steps,
            'total_epochs': self._total_epochs,
            'train_reports': self.train_reports,
            'valid_reports': self.valid_reports,
            'best_valid': self.best_valid,
            'impatience': self.impatience,
            'final_valid_report': dict_report(self.final_valid_report),
            'final_test_report': dict_report(self.final_test_report),
            'final_extra_valid_report': dict_report(self.final_extra_valid_report),
        }
        with PathManager.open(fn, 'w') as f:
            json.dump(stats, f, indent=4)

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = dict_report(valid_report)
        v['train_time'] = self.train_time.time()
        v['parleys'] = self.parleys
        v['train_steps'] = self._train_steps
        v['total_exs'] = self._total_exs
        v['total_epochs'] = self._total_epochs
        self.valid_reports.append(v)

        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]
        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far; multiplying by
        # valid_optim turns "min" metrics into a maximisation problem
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric_mode'] == 'max'
                and self.best_valid >= opt['validation_cutoff']
            ) or (
                opt['validation_metric_mode'] == 'min'
                and self.best_valid <= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):
        """
        Run evaluation on a single world and return its report.

        :param opt: the options for the evaluation.
        :param valid_world: the world to evaluate; it is reset before running.
        :param max_exs: stop once the world reports this many examples; a
            non-positive value means no limit.
        :param datatype: the fold being evaluated; world logs are only written
            for ``'test'``.
        :param is_multitask: forwarded to ``get_task_world_logs``.
        :param task: task name, forwarded to ``get_task_world_logs``.
        :return: the world's report after evaluation.
        """
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            # the world's own example counter drives the loop
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                # one output file per rank
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Eval on validation/test data.

        :param valid_worlds: list of the pre-created evaluation worlds.
        :param opt: the options that specify the task, eval_task, etc
        :param datatype: the datatype to use, such as "valid" or "test"
        :param int max_exs: limits the number of examples if max_exs > 0
        :param bool write_log: specifies to write metrics to file if the
            model_file is set
        :param str extra_log_suffix: inserted between ``model_file`` and the
            datatype in the name of the metrics file
        :return: the aggregated report, synced across workers
        """
        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        # the example budget is split evenly over worlds and workers
        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        is_multitask = len(valid_worlds) > 1
        for index, v_world in enumerate(valid_worlds):
            if opt.get('evaltask'):
                task = opt['evaltask'].split(',')[index]
            else:
                task = opt['task'].split(',')[index]
            task_report = self._run_single_eval(
                opt, v_world, max_exs_per_worker, datatype, is_multitask, task
            )
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report

    def _run_final_extra_eval(self, opt):
        """
        Run one more evaluation using the options stored in ``final_extra_opt``.

        The extra options are loaded from the file named by
        ``opt['final_extra_opt']`` and override a deep copy of ``opt``; the
        datatype to evaluate comes from those extra options.

        :param opt: the options of the training run.
        :return: the report of the extra evaluation.
        """
        final_valid_opt = copy.deepcopy(opt)
        final_valid_opt_raw = Opt.load_init(opt['final_extra_opt'])
        final_datatype = final_valid_opt_raw["datatype"]
        for k, v in final_valid_opt_raw.items():
            final_valid_opt[k] = v
        final_max_exs = (
            final_valid_opt['validation_max_exs']
            if final_valid_opt.get('short_final_eval')
            else -1
        )
        final_valid_world = load_eval_worlds(
            self.agent, final_valid_opt, final_datatype
        )
        final_valid_report = self._run_eval(
            final_valid_world,
            final_valid_opt,
            final_datatype,
            final_max_exs,
            write_log=True,
            extra_log_suffix="_extra",
        )
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final(final_datatype, final_valid_report)

        return final_valid_report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        Outside distributed mode the metrics are returned untouched; otherwise
        every worker's report is gathered and aggregated into one.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(
        self, epochs_completed: float, time_elapsed: float, steps_taken: int
    ):
        """
        Compute the estimated seconds remaining in training.

        Each configured budget (epochs, wall-clock time, steps) yields its own
        estimate; training stops at whichever budget runs out first, so the
        smallest estimate is returned.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :param int steps_taken: number of train steps already taken.
        :return: ETA in seconds, or None if not computable
        """
        estimates = []

        # epoch budget, extrapolating the pace so far
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            estimates.append((1 - epoch_progress) * time_elapsed / epoch_progress)

        # wall-clock budget. The option is `max_train_time` (that is what
        # __init__ reads); the old `max_training_time` spelling is kept as a
        # fallback for any opt that happens to carry it.
        max_train_time = self.opt.get(
            'max_train_time', self.opt.get('max_training_time', -1)
        )
        if max_train_time > 0:
            estimates.append(max_train_time - time_elapsed)

        # step budget, extrapolating the pace so far
        max_train_steps = self.opt.get('max_train_steps', -1)
        if max_train_steps > 0 and steps_taken > 0:
            steps_progress = steps_taken / max_train_steps
            estimates.append((1 - steps_progress) * time_elapsed / steps_progress)

        return min(estimates) if estimates else None

    def _get_time(self, world: "World") -> Tuple[float, float, float, float]:
        """
        Return train, log, validate, and save timing.

        If relying on the time for validation/logging/max train time purposes,
        we sync and return primary worker's time. Otherwise, it's not super
        relevant what we do here.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world: current running world

        :return (train, log, valid, save): return time for each of train, log,
            validation, and save
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            # a time- or epoch-based limit is active: gather exact epoch
            # counts from all workers and share one set of timings
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            # no such limit: use local timings and scale this worker's epoch
            # count by the number of workers instead of communicating
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time

    def log(self):
        """
        Output a training log entry.

        Records the synced training report (plus counters) in
        ``train_reports``, logs it, resets the world metrics and the log
        timer/step counter.

        :return: the synced training report.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []

        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report

    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if (
                    log_time > self.log_every_n_secs
                    or self._last_log_steps >= self.log_every_n_steps
                ):
                    yield self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                    or self._train_steps - self._last_valid_steps
                    >= self.val_every_n_steps
                ):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if save_time > self.save_every_n_secs and opt.get('model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        for _train_log in self.train_steps():
            # we've already done what we need in these
            pass

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        self.final_valid_report = self._run_eval(
            valid_worlds, opt, 'valid', max_exs, write_log=True
        )
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        self.final_test_report = self._run_eval(
            test_worlds, opt, 'test', max_exs, write_log=True
        )

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', self.final_valid_report)
            self.wb_logger.log_final('test', self.final_test_report)

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        if opt['final_extra_opt'] != '':
            self.final_extra_valid_report = self._run_final_extra_eval(opt)

        if opt['wandb_log'] and is_primary_worker():
            # finish exactly once, and only after the extra evaluation above
            # has had the chance to log its final report to the same logger
            self.wb_logger.finish()

        self._save_train_stats()

        return self.final_valid_report, self.final_test_report
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.

    On construction it builds the dictionary (if needed), the student agent,
    the training world and a teacher agent that is handed to the student via
    ``set_teacher_agent``. ``train()`` then runs the loop with periodic
    logging, validation and checkpointing, and finishes with a final
    valid/test evaluation.
    """

    def __init__(self, opt):
        """
        Set up agents, world, timers and bookkeeping for a training run.

        :param opt: options dict. Keys read here include
            ``load_from_checkpoint``, ``model_file``, ``dict_file``,
            ``teacher_model_file``, the ``*_every_n_secs`` /
            ``num_epochs`` / ``max_train_time`` limits, the
            ``validation_metric*`` options and ``tensorboard_log``.
            Note that ``opt`` is mutated in place (``init_model``,
            ``dict_file`` and ``validation_metric_mode`` may be set).
        :raises RuntimeError: if neither ``model_file`` nor ``dict_file`` is
            given.
        """
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        # Possibly load from checkpoint
        # we might load training statistics from here
        trainstats_suffix = '.trainstats'
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'

        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            elif opt['dict_file'] is None and opt.get('teacher_model_file'):
                # NOTE(review): this branch looks unreachable -- when dict_file
                # is None and model_file is unset, the guard above has already
                # raised. Confirm whether the guard should also accept
                # teacher_model_file.
                logging.info("using teacher's dictionary...")
                opt['dict_file'] = opt['teacher_model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)

        # Create teacher model and hand it to the (student) agent
        teacher_opt = {
            'datapath': 'blended_skill_talk',
            'model_file': opt['teacher_model_file'],
            # some custom args
            'tie_layers': False,
            'enable_checkpointing': False,
        }
        self.teacher_agent = create_agent(teacher_opt)
        self.agent.set_teacher_agent(self.teacher_agent)
        print(self.agent.model)
        print(self.teacher_agent.model)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        def _limit(key):
            # a non-positive option value means "no limit"
            value = opt[key]
            return value if value > 0 else float('inf')

        self.parleys = 0
        self.max_num_epochs = _limit('num_epochs')
        self.max_train_time = _limit('max_train_time')
        self.log_every_n_secs = _limit('log_every_n_secs')
        self.val_every_n_secs = _limit('validation_every_n_secs')
        self.save_every_n_secs = _limit('save_every_n_secs')
        self.val_every_n_epochs = _limit('validation_every_n_epochs')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        # +1 when larger metric values are better, -1 when smaller are better
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None
        self.impatience = 0
        self.saved = False
        self.valid_worlds = None

        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method: best value stored in its own file
                    if opt.get('model_file') and os.path.isfile(
                        opt['model_file'] + '.best_valid'
                    ):
                        # the with-block closes the file; no explicit close()
                        with open(opt['model_file'] + ".best_valid", 'r') as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.

        Does nothing on non-primary workers or when no ``model_file`` is set.

        :param suffix: optional string appended to ``opt['model_file']``.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving: retry until it succeeds
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report: 'Dict[str, Metric]'):
        """
        Return a copy of ``report`` with every Metric replaced by its value.

        :param report: mapping of metric name to Metric (or plain value).
        :return: dict with the same keys and ``v.value()`` for Metric entries.
        """
        return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}

    def _save_train_stats(self, suffix=None):
        """
        Write the training statistics to ``<model_file>[suffix].trainstats``.

        :param suffix: optional string inserted before ``.trainstats``.
        """
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'total_epochs': (
                        self._preempted_epochs
                        + num_workers() * self.world.get_total_epochs()
                    ),
                    'impatience': self.impatience,
                    'valid_reports': [
                        self._safe_report(v) for v in self.valid_reports
                    ],
                    'best_valid': self.best_valid,
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        Records the report, optionally logs it to tensorboard and saves a
        checkpoint, and updates ``best_valid`` / ``impatience``.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)

        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()

        # saving
        if (
            opt.get('model_file')
            and opt.get('save_after_valid')
            and is_primary_worker()
        ):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]
        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far; multiplying by
        # valid_optim turns "min" metrics into a maximisation problem
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric'] == 'accuracy'
                and self.best_valid >= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'],
                    round(self.best_valid, 4),
                    self.impatience,
                )
            )
        self.validate_time.reset()

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs):
        """
        Run evaluation on a single world and return its report.

        :param opt: options dict; ``display_examples`` is read.
        :param valid_world: the world to evaluate; it is reset before and
            after the run.
        :param max_exs: stop once the reported ``exs`` count reaches this
            value; values <= 0 mean no limit.
        :return: the world's report taken before the final reset.
        """
        valid_world.reset()

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            # the report's 'exs' entry is the running example count
            cnt = valid_world.report().get('exs') or 0

        valid_report = valid_world.report()
        valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
        """
        Eval on validation/test data.

        :param valid_worlds: list of the pre-created validation worlds.
        :param opt: the options that specify the task, eval_task, etc
        :param datatype: the datatype to use, such as "valid" or "test"
        :param int max_exs: limits the number of examples if max_exs > 0;
            the budget is split evenly over worlds and workers.
        :param bool write_log: specifies to write metrics to file if the
            model_file is set
        :return: the aggregated (and, if distributed, synced) report.
        """
        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        for v_world in valid_worlds:
            task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics; the with-block guarantees the handle is
            # closed even if the write fails
            with open(opt['model_file'] + '.' + datatype, 'a+') as f:
                f.write(f'{metrics}\n')

        return report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        In non-distributed mode the metrics are returned unchanged; otherwise
        the reports gathered from all workers are aggregated.

        :param metrics: this worker's report.
        :return: the combined report.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        Two estimates are considered and the smaller one wins: an
        extrapolation from epoch progress (when ``num_epochs`` > 0) and the
        remaining time budget (when ``max_train_time`` > 0).

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        # The option is named 'max_train_time' (see __init__); the previous
        # 'max_training_time' key never existed, so the time budget was
        # silently ignored.
        max_training_time = self.opt.get('max_train_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.

        Syncs and resets the world's metrics, logs elapsed time, example and
        epoch counts plus the ETA, resets the log timer and optionally writes
        the report to tensorboard.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(self._total_epochs, self.train_time.time())
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        Loops over ``world.parley()`` until an epoch/time limit is hit,
        validation asks to stop, or a StopTrainException is raised; then
        reloads the saved model (if ``model_file`` is set) and runs the final
        valid/test evaluation.

        :return: tuple of reports (validation_report, test_report)
        :raises RuntimeError: if a StopTrainException occurs in distributed
            mode.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for " "distributed mode"
                        )
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs())
                )
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (
                        self.train_time.time(),
                        self.log_time.time(),
                        self.validate_time.time(),
                    )
                )

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                ):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for distributed mode"
                            )
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (
                    self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file')
                    and is_primary_worker()
                ):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report