def log(self):
    """
    Output a training log entry.

    Builds an aggregated training report from the world, prints it together
    with wall-clock/progress counters, then resets the world's metrics and
    (implicitly, via the caller's timer) the logging cadence.
    """
    opt = self.opt
    if opt['display_examples']:
        print(self.world.display() + '\n~~')
    logs = []
    # get report
    train_report = self.world.report()
    # merge/average metrics across distributed workers before printing
    train_report = self._sync_metrics(train_report)
    # clear accumulated metrics so the next log window starts fresh
    self.world.reset_metrics()

    # time elapsed
    logs.append('time:{}s'.format(np.floor(self.train_time.time())))
    logs.append('total_exs:{}'.format(self._total_exs))

    if self._total_epochs >= 0:
        # only if it's unbounded
        logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

    time_left = self._compute_eta(self._total_epochs, self.train_time.time())
    if time_left is not None:
        logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

    log = '[ {} ]\n{}\n'.format(' '.join(logs), nice_report(train_report))
    print(log)
    self.log_time.reset()

    # only the primary worker writes tensorboard events
    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger.log_metrics('train', self.parleys, train_report)
def save_model(self, suffix=None):
    """
    Save the model to disk, possibly with a suffix.

    :param suffix: optional string appended to the model filename,
        e.g. ``'.checkpoint'``; ``None`` saves to the plain model_file.
    """
    if not is_primary_worker():
        # never do IO as a non-primary worker
        return

    if not self.opt.get('model_file'):
        # nothing to save to, just exit
        return

    fn = self.opt['model_file']
    if suffix:
        fn += suffix

    while True:
        # don't ever let a ctrl-c interrupt saving: retry the save until it
        # completes; only KeyboardInterrupt is swallowed, other errors raise
        try:
            self.agent.save(fn)
            self._save_train_stats(suffix)
            # --------------- change by hengyicai -------------------------
            # also persist the teacher agent (if any) next to the model
            teacher_agent = self.return_teacher_agent()
            if teacher_agent:
                teacher_fn = fn + '.teacher'
                teacher_agent.save(teacher_fn)
            # --------------- change by hengyicai -------------------------
            break
        except KeyboardInterrupt:
            pass
def log(self):
    """
    Output a training log entry.

    Aggregates the current training report, appends a plain-dict snapshot
    (plus progress counters) to ``self.train_reports`` for later
    serialization, emits the log line, and resets metric/log state.

    :return: the aggregated training report for this window
    """
    opt = self.opt
    if opt['display_examples']:
        print(self.world.display() + '\n~~')
    logs = []
    # get report
    train_report = self.world.report()
    # merge/average metrics across distributed workers
    train_report = self._sync_metrics(train_report)
    # clear accumulated metrics so the next window starts fresh
    self.world.reset_metrics()

    # snapshot the report as plain values plus progress counters, so it can
    # be saved in the .trainstats file
    train_report_trainstats = dict_report(train_report)
    train_report_trainstats['total_epochs'] = self._total_epochs
    train_report_trainstats['total_exs'] = self._total_exs
    train_report_trainstats['parleys'] = self.parleys
    train_report_trainstats['train_steps'] = self._train_steps
    train_report_trainstats['train_time'] = self.train_time.time()
    self.train_reports.append(train_report_trainstats)

    # time elapsed
    logs.append(f'time:{self.train_time.time():.0f}s')
    logs.append(f'total_exs:{self._total_exs}')
    logs.append(f'total_steps:{self._train_steps}')

    if self._total_epochs >= 0:
        # only if it's unbounded
        logs.append(f'epochs:{self._total_epochs:.2f}')

    time_left = self._compute_eta(
        self._total_epochs, self.train_time.time(), self._train_steps
    )
    if time_left is not None:
        logs.append(f'time_left:{max(0,time_left):.0f}s')

    log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
    logging.info(log)
    self.log_time.reset()
    self._last_log_steps = 0

    # only the primary worker writes to external loggers
    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger.log_metrics('train', self.parleys, train_report)
    if opt['wandb_log'] and is_primary_worker():
        self.wb_logger.log_metrics('train', self.parleys, train_report)

    return train_report
def train(self):
    """
    Perform a training run.

    Drains ``self.train_steps()`` (which does the actual parleys/logging),
    then runs final validation and test evaluations, logs final reports,
    shuts down eval worlds, and persists training statistics.

    :return: tuple of reports (validation_report, test_report)
    """
    opt = self.opt
    for _train_log in self.train_steps():
        # we've already done what we need in these
        pass

    # perform final validation/testing
    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    self.final_valid_report = self._run_eval(
        valid_worlds, opt, 'valid', max_exs, write_log=True
    )
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    self.final_test_report = self._run_eval(
        test_worlds, opt, 'test', max_exs, write_log=True
    )

    if opt['wandb_log'] and is_primary_worker():
        self.wb_logger.log_final('valid', self.final_valid_report)
        self.wb_logger.log_final('test', self.final_test_report)
        self.wb_logger.finish()

    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    if opt['final_extra_opt'] != '':
        # optional extra eval pass with overridden opts (--final-extra-opt)
        self.final_extra_valid_report = self._run_final_extra_eval(opt)
        # NOTE(review): wb_logger.finish() was already called above when
        # wandb_log is on — confirm calling it twice is safe/intended
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.finish()

    self._save_train_stats()
    return self.final_valid_report, self.final_test_report
def _run_eval(
    self,
    valid_worlds,
    opt,
    datatype,
    max_exs=-1,
    write_log=False,
    extra_log_suffix="",
):
    """
    Eval on validation/test data.

    :param valid_worlds: list of the pre-created validation worlds.
    :param opt: the options that specific the task, eval_task, etc
    :param datatype: the datatype to use, such as "valid" or "test"
    :param int max_exs: limits the number of examples if max_exs > 0
    :param bool write_log: specifies to write metrics to file if the
        model_file is set
    :param str extra_log_suffix: inserted into the metrics filename before
        the datatype extension (used by the final-extra eval pass)
    :return: the aggregated, worker-synced report
    """
    logging.info(f'running eval: {datatype}')
    timer = Timer()
    reports = []

    # split the example budget evenly over worlds and distributed workers
    # (note: float division; negative when max_exs is the -1 "unlimited"
    # sentinel — _run_single_eval is expected to handle both)
    max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
    for v_world in valid_worlds:
        task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    named_reports = dict(zip(tasks, reports))
    report = aggregate_named_reports(
        named_reports, micro_average=self.opt.get('aggregate_micro', False)
    )
    # get the results from all workers
    report = self._sync_metrics(report)

    metrics = f'{datatype}:\n{nice_report(report)}\n'

    logging.info(f'eval completed in {timer.time():.2f}s')
    logging.report(metrics)

    # write to file
    if write_log and opt.get('model_file') and is_primary_worker():
        # Write out metrics
        with PathManager.open(
            opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
        ) as f:
            f.write(f'{metrics}\n')

    return report
def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds: list of the pre-created validation worlds.
    :param opt: the options that specific the task, eval_task, etc
    :param datatype: the datatype to use, such as "valid" or "test"
    :param int max_exs: limits the number of examples if max_exs > 0
    :param bool write_log: specifies to write metrics to file if the
        model_file is set
    :return: the aggregated, worker-synced report
    """
    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []

    # split the example budget evenly over worlds and distributed workers
    max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
    for v_world in valid_worlds:
        task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    named_reports = dict(zip(tasks, reports))
    report = aggregate_named_reports(named_reports)
    # get the results from all workers
    report = self._sync_metrics(report)

    metrics = f'{datatype}:{nice_report(report)}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write to file
    if write_log and opt.get('model_file') and is_primary_worker():
        # Write out metrics. Use a context manager so the file handle is
        # always closed even if the write raises (the original opened and
        # closed the file manually, leaking the handle on error).
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(f'{metrics}\n')

    return report
def load_eval_worlds(agent, opt, datatype):
    """
    Create a new eval world for the agent and the given opt.

    Works on a copy of ``opt`` with its datatype switched to ``datatype``
    (keeping streaming mode if training used it), applies the eval-time
    overrides (--evaltask, --eval-batchsize), and builds one world per
    comma-separated task.

    :param Agent agent: The model being trained.
    :param Opt opt: The global CLI opts.
    :param string datatype: The new datatype.
    """
    if not is_primary_worker():
        # don't load worlds in workers
        # TODO(MW): this block will need to be removed
        return None

    if 'stream' in opt['datatype']:
        datatype += ':stream'

    eval_opt = opt.copy()
    eval_opt['datatype'] = datatype
    if eval_opt.get('evaltask'):
        # if a different eval task is specified, use it.
        eval_opt['task'] = eval_opt['evaltask']
    if eval_opt.get('eval_batchsize'):
        # override eval time batchsize
        eval_opt['batchsize'] = eval_opt['eval_batchsize']

    # possibly share the model instead of evaluating the training agent itself
    eval_agent = (
        create_agent_from_shared(agent.share())
        if eval_opt.get('validation_share_agent', False)
        else agent
    )

    # create one world per subtask
    worlds = []
    for subtask in eval_opt['task'].split(','):
        subtask_opt = eval_opt.copy()  # copy opt since we edit the task
        subtask_opt['task'] = subtask
        worlds.append(create_task(subtask_opt, eval_agent))
    return worlds
def _run_final_extra_eval(self, opt):
    """
    Run the extra final evaluation requested via --final-extra-opt.

    Loads the override opts from the file named by
    ``opt['final_extra_opt']``, merges them over a deep copy of ``opt``,
    evaluates on the overridden datatype (writing a "_extra" metrics log),
    and optionally logs the report to wandb.

    :return: the report produced by the extra evaluation
    """
    merged_opt = copy.deepcopy(opt)
    override_opt = Opt.load_init(opt['final_extra_opt'])
    eval_datatype = override_opt["datatype"]
    # apply overrides key by key so Opt's per-key bookkeeping is preserved
    for key, value in override_opt.items():
        merged_opt[key] = value

    if merged_opt.get('short_final_eval'):
        eval_max_exs = merged_opt['validation_max_exs']
    else:
        eval_max_exs = -1

    eval_worlds = load_eval_worlds(self.agent, merged_opt, eval_datatype)
    report = self._run_eval(
        eval_worlds,
        merged_opt,
        eval_datatype,
        eval_max_exs,
        write_log=True,
        extra_log_suffix="_extra",
    )
    if opt['wandb_log'] and is_primary_worker():
        self.wb_logger.log_final(eval_datatype, report)

    return report
def train(self):
    """
    Perform a training run.

    Runs the main parley loop until an epoch/time limit is hit, validating
    and checkpointing along the way, then evaluates the best (and, if
    present, the last checkpointed) model on valid and test.

    :return: tuple of reports (validation_report, test_report)
    """
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated.")
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = (
                self._preempted_epochs +
                num_workers() * self.world.get_total_epochs())
            exs_per_epoch = self.world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))
            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print(
                    '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch >=
                    self.val_every_n_epochs):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs

                # --------------- change by hengyicai -------------------------
                if opt.get('run_test_after_validation', False):
                    # run evaluation on the test data as well
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_worlds is None:
                        # we need to load the world now
                        self.test_worlds = _maybe_load_eval_worlds(
                            self.agent, test_opt, 'test')
                    run_eval(self.test_worlds, test_opt, 'test', -1,
                             write_log=True)
                # --------------- change by hengyicai -------------------------
                if stop_training:
                    break
            if (self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and is_primary_worker()):
                print("[ saving model checkpoint: {}.checkpoint".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    # final evaluation of the (best) model on valid and test
    valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get(
        'short_final_eval') else -1
    v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test')
    t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    # --------------- change by hengyicai -------------------------
    # also evaluate the last checkpoint (not just the best-valid model)
    # NOTE(review): opt.get('model_file') may be None here, which would make
    # the '+' raise TypeError — confirm model_file is always set on this path
    last_model = opt.get('model_file') + '.checkpoint'
    if os.path.isfile(last_model):
        print(
            '[ Conducting evaluations on valid and test data using the last model. ]'
        )
        last_model_opt = copy.deepcopy(opt)
        last_model_opt['model_file'] = last_model
        last_agent = create_agent(last_model_opt)
        valid_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt,
                                               'valid')
        max_exs = last_model_opt[
            'validation_max_exs'] if last_model_opt.get(
                'short_final_eval') else -1
        run_eval(valid_worlds, last_model_opt, 'valid', max_exs,
                 write_log=True)
        test_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt,
                                              'test')
        run_eval(test_worlds, last_model_opt, 'test', max_exs,
                 write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()
    # --------------- change by hengyicai -------------------------

    print_announcements(opt)

    return v_report, t_report
def validate(self):
    """
    Perform a validation run, checking whether we should stop training.

    :return: boolean indicating whether training should stop
    :rtype: bool
    """
    opt = self.opt

    if self.valid_worlds is None:
        # we need to load the world now
        self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

    # run evaluation on valid set
    valid_report = self._run_eval(
        self.valid_worlds, opt, 'valid', opt['validation_max_exs']
    )
    # snapshot a plain-dict copy with progress counters for .trainstats
    v = dict_report(valid_report)
    v['train_time'] = self.train_time.time()
    v['parleys'] = self.parleys
    v['train_steps'] = self._train_steps
    v['total_exs'] = self._total_exs
    v['total_epochs'] = self._total_epochs
    self.valid_reports.append(v)

    # logging
    if opt['tensorboard_log'] and is_primary_worker():
        valid_report['total_exs'] = self._total_exs
        self.tb_logger.log_metrics('valid', self.parleys, valid_report)
        # flush on a validation
        self.tb_logger.flush()
    if opt['wandb_log'] and is_primary_worker():
        valid_report['total_exs'] = self._total_exs
        self.wb_logger.log_metrics('valid', self.parleys, valid_report)

    # send valid metrics to agent if the agent wants them
    if hasattr(self.agent, 'receive_metrics'):
        self.agent.receive_metrics(valid_report)

    # check which metric to look at
    new_valid = valid_report[opt['validation_metric']]

    # unwrap Metric objects to their raw value for comparison
    if isinstance(new_valid, Metric):
        new_valid = new_valid.value()

    # check if this is the best validation so far
    # (valid_optim is +1 for 'max' mode, -1 for 'min' mode)
    if (
        self.best_valid is None
        or self.valid_optim * new_valid > self.valid_optim * self.best_valid
    ):
        logging.success(
            'new best {}: {:.4g}{}'.format(
                opt['validation_metric'],
                new_valid,
                ' (previous best was {:.4g})'.format(self.best_valid)
                if self.best_valid is not None
                else '',
            )
        )
        self.best_valid = new_valid
        self.impatience = 0
        if opt.get('model_file'):
            logging.info(f"saving best valid model: {opt['model_file']}")
            self.save_model()
            self.saved = True
        # early-exit when the metric crosses the configured cutoff
        if (
            opt['validation_metric_mode'] == 'max'
            and self.best_valid >= opt['validation_cutoff']
        ) or (
            opt['validation_metric_mode'] == 'min'
            and self.best_valid <= opt['validation_cutoff']
        ):
            logging.info('task solved! stopping.')
            return True
    else:
        self.impatience += 1
        logging.report(
            'did not beat best {}: {} impatience: {}'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience
            )
        )
    self.validate_time.reset()

    # saving
    if opt.get('model_file') and opt.get('save_after_valid'):
        logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
        self.save_model('.checkpoint')

    # check if we are out of patience
    if (
        opt['validation_patience'] > 0
        and self.impatience >= opt['validation_patience']
    ):
        logging.info('ran out of patience! stopping training.')
        return True
    return False
def __init__(self, opt):
    """
    Set up the training loop: agent, world, timers, stopping criteria, and
    (if present) previously saved training statistics from a checkpoint.
    """
    # if python is called from a non-interactive shell, like a bash script,
    # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
    # not produced. This line brings them back
    signal.signal(signal.SIGINT, signal.default_int_handler)
    # Possibly load from checkpoint
    trainstats_suffix = '.trainstats'  # we might load training statistics from here
    if (
        opt['load_from_checkpoint']
        and opt.get('model_file')
        and PathManager.exists(opt['model_file'] + '.checkpoint')
    ):
        opt['init_model'] = opt['model_file'] + '.checkpoint'
        trainstats_suffix = '.checkpoint.trainstats'

    # Possibly build a dictionary (not all models do this).
    if not (opt.get('dict_file') or opt.get('model_file')):
        raise RuntimeError(
            'WARNING: For train_model, please specify either a '
            'model_file or dict_file.'
        )
    if 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        logging.info("building dictionary first...")
        build_dict(opt, skip_if_built=True)

    # Create model and assign it to the specified task
    self.agent = create_agent(opt)
    self.agent.opt.log()
    self.world = create_task(opt, self.agent)
    # set up timers
    self.train_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()
    self.parleys = 0
    self._train_steps = 0
    self._last_log_steps = 0
    self.update_freq = opt.get('update_freq', 1)

    # stopping / cadence criteria; _num_else_inf presumably maps unset or
    # non-positive opts to infinity — confirm against its definition
    self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
    self.max_train_time = _num_else_inf(
        opt, 'max_train_time', distributed_warn=True
    )
    self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
    self.log_every_n_secs = _num_else_inf(
        opt, 'log_every_n_secs', distributed_warn=True
    )
    self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
    self.val_every_n_secs = _num_else_inf(
        opt, 'validation_every_n_secs', distributed_warn=True
    )
    self.val_every_n_epochs = _num_else_inf(
        opt, 'validation_every_n_epochs', distributed_warn=True
    )
    self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
    self.save_every_n_secs = _num_else_inf(
        opt, 'save_every_n_secs', distributed_warn=True
    )

    # smart defaults for --validation-metric-mode
    if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
        opt['validation_metric_mode'] = 'min'
    elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
        opt['validation_metric_mode'] = 'max'
    if opt.get('validation_metric_mode') is None:
        opt['validation_metric_mode'] = 'max'

    self.last_valid_epoch = 0
    self._last_valid_steps = 0
    # +1 when a larger metric is better, -1 when a smaller one is
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.train_reports = []
    self.valid_reports = []
    self.final_valid_report = {}
    self.final_test_report = {}
    self.final_extra_valid_report = {}
    self.best_valid = None

    self.impatience = 0
    self.saved = False
    self.valid_worlds = None
    self.opt = opt

    # we may have been preempted, make sure we note that amount
    self._preempted_epochs = 0.0
    if opt.get('model_file') and PathManager.exists(
        opt['model_file'] + trainstats_suffix
    ):
        # looks like we were preempted. make sure we load up our total
        # training stats, etc
        with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
            obj = json.load(ts)
            self.parleys = obj.get('parleys', 0)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self._train_steps = obj.get('train_steps', 0)
            self.impatience = obj.get('impatience', 0)
            self.valid_reports = obj.get('valid_reports', [])
            if self.valid_reports:
                self.last_valid_epoch = self.valid_reports[-1].get(
                    'total_epochs', 0.0
                )
            self.train_reports = obj.get('train_reports', [])
            if 'best_valid' in obj:
                self.best_valid = obj['best_valid']
            else:
                # old method
                if opt.get('model_file') and PathManager.exists(
                    opt['model_file'] + '.best_valid'
                ):
                    with PathManager.open(
                        opt['model_file'] + ".best_valid", 'r'
                    ) as f:
                        x = f.readline()
                        self.best_valid = float(x)
                        # NOTE(review): redundant close — the with-block
                        # already closes f
                        f.close()

    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger = TensorboardLogger(opt)
    if opt['wandb_log'] and is_primary_worker():
        model = self.agent.model if hasattr(self.agent, 'model') else None
        self.wb_logger = WandbLogger(opt, model)
def train(self):
    """
    Perform a training run.

    Runs the parley loop until a stopping condition fires (epochs, time,
    StopTrainException, or early stopping from validation), then frees the
    training world/agent, reloads the best model, and runs final valid and
    test evaluations.

    :return: tuple of reports (validation_report, test_report)
    """
    logging.info('training...')
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException as e:
                logging.info(f"Stopping from {e}")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs()))
            exs_per_epoch = world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                logging.info(
                    f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                )
                break
            if train_time > self.max_train_time:
                logging.info(f'max_train_time elapsed:{train_time}s')
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch >=
                    self.val_every_n_epochs):
                try:
                    # log before we validate
                    self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if (self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and is_primary_worker()):
                logging.info(
                    f"saving model checkpoint: {opt['model_file']}.checkpoint"
                )
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()

    # there's a rare edge case where we never saved the model, and we try
    # to reload it. This sync_object ensures all workers wait for the primary
    # worker to finish flushing before loading from disk.
    sync_object(None)
    if opt.get('model_file'):
        # clean up all our memory, just to make sure we don't OOM on GPU when
        # reloading the world
        del world
        del self.world
        del self.agent
        del self.valid_worlds

        # reload best validation model
        self.agent = create_agent(opt)

    # perform final validation/testing
    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get(
        'short_final_eval') else -1
    v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs,
                              write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = self._run_eval(test_worlds, opt, 'test', max_exs,
                              write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    return v_report, t_report
def train(self):
    """
    Perform a training run.

    Runs the parley loop until a stopping condition fires (epochs, time,
    StopTrainException, or early stopping from validation), saves or reloads
    the model as appropriate, then runs final valid and test evaluations.

    :return: tuple of reports (validation_report, test_report)
    """
    opt = self.opt
    world = self.world
    # (fix: removed unused local `count = 0` — it was never read)
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException:
                if is_distributed():
                    raise RuntimeError(
                        "StopTrainException not supported for "
                        "distributed mode")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = (
                self._preempted_epochs +
                num_workers() * self.world.get_total_epochs())
            exs_per_epoch = self.world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print(
                    '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch >=
                    self.val_every_n_epochs):
                try:
                    stop_training = self.validate()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not "
                            "supported for distributed mode")
                    break
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
            if (self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and is_primary_worker()):
                print("[ saving model checkpoint: {}.checkpoint".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    # final evaluation on valid and test
    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get(
        'short_final_eval') else -1
    v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    return v_report, t_report
def validate(self):
    """
    Perform a validation run, checking whether we should stop training.

    :return: boolean indicating whether training should stop
    :rtype: bool
    """
    opt = self.opt

    if self.valid_worlds is None:
        # we need to load the world now
        self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

    # run evaluation on valid set
    # TODO(MW): replace sync_object with self._sync_metrics. You'll need some
    # logic to handle 'validation_max_exs' properly
    valid_report = run_eval(self.valid_worlds, opt, 'valid',
                            opt['validation_max_exs'])
    # keep a copy with the current train time for the history of reports
    v = valid_report.copy()
    v['train_time'] = self.train_time.time()
    self.valid_reports.append(v)

    # logging
    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger.log_metrics('valid', self.parleys, valid_report)
        # flush on a validation
        self.tb_logger.flush()

    # saving
    if (opt.get('model_file') and opt.get('save_after_valid') and
            is_primary_worker()):
        print("[ saving model checkpoint: " + opt['model_file'] +
              ".checkpoint ]")
        self.save_model('.checkpoint')

    # send valid metrics to agent if the agent wants them
    if hasattr(self.agent, 'receive_metrics'):
        self.agent.receive_metrics(valid_report)

    # check which metric to look at
    new_valid = valid_report[opt['validation_metric']]

    # unwrap Metric objects to their raw value for comparison
    if isinstance(new_valid, Metric):
        new_valid = new_valid.value()

    # check if this is the best validation so far
    # (valid_optim is +1 for 'max' mode, -1 for 'min' mode)
    if (self.best_valid is None or
            self.valid_optim * new_valid >
            self.valid_optim * self.best_valid):
        print('[ new best {}: {}{} ]'.format(
            opt['validation_metric'], new_valid,
            ' (previous best was {})'.format(self.best_valid)
            if self.best_valid is not None else '',
        ))
        self.best_valid = new_valid
        self.impatience = 0
        if opt.get('model_file') and is_primary_worker():
            print("[ saving best valid model: " + opt['model_file'] + " ]")
            self.save_model()
            self.saved = True
        # cutoff-based early exit only applies to the accuracy metric here
        if (opt['validation_metric'] == 'accuracy' and
                self.best_valid >= opt['validation_cutoff']):
            print('[ task solved! stopping. ]')
            return True
    else:
        self.impatience += 1
        print('[ did not beat best {}: {} impatience: {} ]'.format(
            opt['validation_metric'], round(self.best_valid, 4),
            self.impatience))
    self.validate_time.reset()

    # check if we are out of patience
    if (opt['validation_patience'] > 0 and
            self.impatience >= opt['validation_patience']):
        print('[ ran out of patience! stopping training. ]')
        return True
    return False
def train(self):
    """
    Perform a training run.

    Runs the parley loop until a stopping condition fires, optionally
    evaluating on test after each validation (hengyicai change), then runs
    final evaluations on the best model and — if a checkpoint exists — on
    the last checkpointed model as well.

    :return: tuple of reports (validation_report, test_report)
    """
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException:
                if is_distributed():
                    raise RuntimeError(
                        "StopTrainException not supported for "
                        "distributed mode")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs()))
            exs_per_epoch = world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print(
                    '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch >=
                    self.val_every_n_epochs):
                try:
                    # log before we validate
                    self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for distributed mode"
                        )
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs

                # --------------- change by hengyicai -------------------------
                if opt.get('run_test_after_validation', False):
                    # run evaluation on the test data as well
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_worlds is None:
                        # we need to load the world now
                        self.test_worlds = load_eval_worlds(
                            self.agent, test_opt, 'test')
                    run_eval(self.test_worlds, test_opt, 'test', -1,
                             write_log=True)
                # --------------- change by hengyicai -------------------------
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if (self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and is_primary_worker()):
                print("[ saving model checkpoint: {}.checkpoint".format(
                    opt['model_file']))
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    # final evaluation on valid and test with the (best) model
    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get(
        'short_final_eval') else -1
    v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    # --------------- change by hengyicai -------------------------
    # also evaluate the last checkpoint (not just the best-valid model)
    # NOTE(review): opt.get('model_file') may be None here, which would make
    # the '+' raise TypeError — confirm model_file is always set on this path
    last_model = opt.get('model_file') + '.checkpoint'
    if os.path.isfile(last_model):
        print(
            '[ Conducting evaluations on valid and test data using the last model. ]'
        )
        last_model_opt = copy.deepcopy(opt)
        last_model_opt['model_file'] = last_model
        last_agent = create_agent(last_model_opt)
        valid_worlds = load_eval_worlds(last_agent, last_model_opt, 'valid')
        max_exs = last_model_opt[
            'validation_max_exs'] if last_model_opt.get(
                'short_final_eval') else -1
        run_eval(valid_worlds, last_model_opt, 'valid', max_exs,
                 write_log=True)
        test_worlds = load_eval_worlds(last_agent, last_model_opt, 'test')
        run_eval(test_worlds, last_model_opt, 'test', max_exs,
                 write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()
    # --------------- change by hengyicai -------------------------

    print_announcements(opt)

    return v_report, t_report
def validate(self):
    """
    Perform a validation run, checking whether we should stop training.

    Runs evaluation on the valid set, records the report (with elapsed
    train time) in ``self.valid_reports``, optionally checkpoints/saves
    the model, and updates best-valid / impatience bookkeeping.

    :return: boolean indicating whether training should stop
    :rtype: bool
    """
    opt = self.opt

    if self.valid_worlds is None:
        # we need to load the world now (lazily created on first validation)
        self.valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')

    # run evaluation on valid set; sync_object broadcasts the report so all
    # distributed workers agree on the same metrics
    valid_report = sync_object(
        run_eval(self.valid_worlds, opt, 'valid', opt['validation_max_exs']))
    # keep a copy of every validation report, stamped with train time
    v = valid_report.copy()
    v['train_time'] = self.train_time.time()
    self.valid_reports.append(v)

    # logging (primary worker only, to avoid duplicate tensorboard entries)
    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger.log_metrics('valid', self.parleys, valid_report)

    # saving a checkpoint after every validation, if requested
    if (opt.get('model_file') and opt.get('save_after_valid')
            and is_primary_worker()):
        print("[ saving model checkpoint: " + opt['model_file'] +
              ".checkpoint ]")
        self.save_model('.checkpoint')

    # send valid metrics to agent if the agent wants them (e.g. for
    # lr-scheduler steps keyed on validation metrics)
    if hasattr(self.agent, 'receive_metrics'):
        self.agent.receive_metrics(valid_report)

    # check which metric to look at
    if 'tasks' in valid_report and '/' in opt['validation_metric']:
        # if you are multitasking and want your validation metric to be
        # a metric specific to a subtask, specify your validation metric
        # as -vmt subtask/metric
        subtask = opt['validation_metric'].split('/')[0]
        validation_metric = opt['validation_metric'].split('/')[1]
        new_valid = valid_report['tasks'][subtask][validation_metric]
    else:
        new_valid = valid_report[opt['validation_metric']]

    # check if this is the best validation so far
    # self.valid_optim is +1 for metrics to maximize, -1 for minimize,
    # so this single comparison works in both directions
    if (self.best_valid is None or
            self.valid_optim * new_valid > self.valid_optim * self.best_valid):
        print('[ new best {}: {}{} ]'.format(
            opt['validation_metric'],
            new_valid,
            ' (previous best was {})'.format(self.best_valid)
            if self.best_valid is not None else '',
        ))
        self.best_valid = new_valid
        self.impatience = 0
        if opt.get('model_file') and is_primary_worker():
            print("[ saving best valid model: " + opt['model_file'] + " ]")
            self.save_model()
            print("[ saving best valid metric: " + opt['model_file'] +
                  ".best_valid ]")
            _save_best_valid(opt['model_file'], self.best_valid)
            self.saved = True
        # early stop once an accuracy cutoff is reached, if configured
        if (opt['validation_metric'] == 'accuracy'
                and self.best_valid >= opt['validation_cutoff']):
            print('[ task solved! stopping. ]')
            return True
    else:
        self.impatience += 1
        print('[ did not beat best {}: {} impatience: {} ]'.format(
            opt['validation_metric'], round(self.best_valid, 4),
            self.impatience))
    self.validate_time.reset()

    # check if we are out of patience (0 or negative patience disables this)
    if (opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']):
        print('[ ran out of patience! stopping training. ]')
        return True
    return False
def _maybe_load_eval_worlds(agent, opt, datatype):
    """
    Build evaluation worlds for ``datatype``, but only on the primary worker.

    Distributed non-primary workers do not run evaluation, so they get
    ``None`` instead of instantiated worlds.
    """
    worlds = None
    if is_primary_worker():
        worlds = load_eval_worlds(agent, opt, datatype)
    return worlds
def __init__(self, opt):
    """
    Set up the training loop: agent, task world, timers, stopping criteria,
    and (if present) previously-saved training statistics.

    :param opt: options dict (or, deprecated, a ParlaiParser to be parsed).
    """
    # if python is called from a non-interactive shell, like a bash script,
    # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
    # not produced. This line brings them back
    signal.signal(signal.SIGINT, signal.default_int_handler)

    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
        )
        opt = opt.parse_args()
    # Possibly load from checkpoint
    trainstats_suffix = '.trainstats'  # we might load training statistics from here
    if (opt['load_from_checkpoint'] and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')):
        # resume from the checkpoint model and its matching train stats
        opt['init_model'] = opt['model_file'] + '.checkpoint'
        trainstats_suffix = '.checkpoint.trainstats'
    # Possibly build a dictionary (not all models do this).
    if not (opt.get('dict_file') or opt.get('model_file')):
        raise RuntimeError(
            'WARNING: For train_model, please specify either a '
            'model_file or dict_file.')
    if 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            # default the dict file to sit next to the model file
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict(opt, skip_if_built=True)

    # Create model and assign it to the specified task
    self.agent = create_agent(opt)
    self.world = create_task(opt, self.agent)
    # set up timers
    self.train_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()
    self.checkpoint_counter = 0
    print('[ training... ]')

    self.parleys = 0
    # non-positive option values mean "unbounded" for all of these limits
    self.max_num_epochs = (opt['num_epochs']
                           if opt['num_epochs'] > 0 else float('inf'))
    self.max_train_time = (opt['max_train_time']
                           if opt['max_train_time'] > 0 else float('inf'))
    self.log_every_n_secs = (opt['log_every_n_secs']
                             if opt['log_every_n_secs'] > 0 else float('inf'))
    self.val_every_n_secs = (opt['validation_every_n_secs']
                             if opt['validation_every_n_secs'] > 0 else
                             float('inf'))
    self.save_every_n_secs = (opt['save_every_n_secs']
                              if opt['save_every_n_secs'] > 0 else
                              float('inf'))
    self.val_every_n_epochs = (opt['validation_every_n_epochs']
                               if opt['validation_every_n_epochs'] > 0 else
                               float('inf'))

    # smart defaults for --validation-metric-mode: known loss-like metrics
    # are minimized, known accuracy-like metrics maximized, else max
    if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
        opt['validation_metric_mode'] = 'min'
    elif opt['validation_metric'] in {
            'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
    }:
        opt['validation_metric_mode'] = 'max'
    if opt.get('validation_metric_mode') is None:
        opt['validation_metric_mode'] = 'max'

    self.last_valid_epoch = 0
    # +1 to maximize the validation metric, -1 to minimize it
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.valid_reports = []
    self.best_valid = None
    self.impatience = 0
    self.saved = False
    self.valid_worlds = None
    self.opt = opt

    # we may have been preempted, make sure we note that amount
    self._preempted_epochs = 0.0
    if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                trainstats_suffix):
        # looks like we were preempted. make sure we load up our total
        # training stats, etc
        with open(opt['model_file'] + trainstats_suffix) as ts:
            obj = json.load(ts)
            self.parleys = obj.get('parleys', 0)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self.impatience = obj.get('impatience', 0)
            self.valid_reports = obj.get('valid_reports', [])
            if 'best_valid' in obj:
                self.best_valid = obj['best_valid']
            else:
                # old method: best valid stored in a separate sidecar file
                if opt.get('model_file') and os.path.isfile(
                        opt['model_file'] + '.best_valid'):
                    with open(opt['model_file'] + ".best_valid", 'r') as f:
                        x = f.readline()
                        self.best_valid = float(x)
                        # NOTE(review): f.close() is redundant inside the
                        # with-block; harmless but could be dropped
                        f.close()

    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger = TensorboardLogger(opt)
def train_steps(self):
    """
    Core training loop.

    Yields a metrics dict with each log (see :meth:`log`), so callers can
    consume intermediate training reports as they are produced.

    Stops on: StopTrainException, max epochs, max train time, or max train
    steps; on exit, ensures the model is saved and reloads the best
    validation model.
    """
    logging.info('training...')
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException as e:
                logging.info(f"Stopping from {e}")
                break

            self.parleys += 1
            # a "train step" is one optimizer update; with gradient
            # accumulation (update_freq > 1) several parleys make one step
            self._train_steps = self.parleys // self.update_freq
            self._last_log_steps += 1 / self.update_freq

            # the following additionally updates self._total_epochs
            train_time, log_time, validate_time, save_time = self._get_time(
                world)
            # get the total training examples done, compute epochs
            exs_per_epoch = world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                # emit a final log entry before stopping
                yield self.log()
                logging.info(
                    f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                )
                break
            if train_time > self.max_train_time:
                logging.info(f'max_train_time elapsed:{train_time}s')
                break
            if self._train_steps >= self.max_train_steps:
                logging.info(
                    f'max_train_steps elapsed:{self._train_steps} '
                    f'time elapsed:{train_time}s'
                )
                break
            # log on either a wall-clock or a step cadence
            if (
                log_time > self.log_every_n_secs
                or self._last_log_steps >= self.log_every_n_steps
            ):
                yield self.log()
            # validate on wall-clock, epoch, or step cadence
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch
                >= self.val_every_n_epochs
                or self._train_steps - self._last_valid_steps
                >= self.val_every_n_steps
            ):
                try:
                    # log before we validate
                    if self._last_log_steps:
                        yield self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs
                self._last_valid_steps = self._train_steps
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if save_time > self.save_every_n_secs and opt.get('model_file'):
                logging.info(
                    f"saving model checkpoint: {opt['model_file']}.checkpoint"
                )
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not sync_object(self.saved):
        # save agent
        self.save_model()

    # there's a rare edge case where we never saved the model, and we try
    # to reload it. This sync_object ensures all workers wait for the primary
    # worker to finish flushing before loading from disk.
    sync_object(None)
    if opt.get('model_file'):
        # clean up all our memory, just to make sure we don't OOM on GPU when
        # reloading the world
        del world
        del self.world
        del self.agent
        del self.valid_worlds
        # reload best validation model
        self.agent = create_agent(opt)
def batch_act(self, observations):
    """
    Process a batch of observations (batchsize list of message dicts).

    These observations have been preprocessed by the observe method.

    Subclasses can override this for special functionality, but if the
    default behaviors are fine then just override the ``train_step`` and
    ``eval_step`` methods instead. The former is called when labels are
    present in the observations batch; otherwise, the latter is called.

    :param observations: list of observation Messages, one per batch row.
    :return: list of reply Messages, one per batch row, with per-example
        metrics attached under the ``'metrics'`` key.
    """
    # clear local metrics before anything else
    self._local_metrics.clear()

    # initialize a list of replies with this agent's id
    batch_reply = [
        Message({'id': self.getID(), 'episode_done': False})
        for _ in observations
    ]

    # check if there are any labels available, if so we will train on them
    self.is_training = any('labels_1' in obs for obs in observations)

    # create a batch from the vectors
    batch = self.batchify(observations)

    # BUGFIX: the non-None guard previously read ``batch.text_1_vec`` even
    # though membership was tested for (and tokens are counted from)
    # ``text_0_vec``; guard the same field that is actually used.
    if ('label_1_vec' in batch and 'text_0_vec' in batch
            and batch.label_1_vec is not None
            and batch.text_0_vec is not None):
        # tokens per batch
        # we divide by the binary is_primary_worker() so that the numerator is
        # num_tokens in all workers, and the denominator is 1.
        # TODO HRED: add to this
        tbp = AverageMetric(
            (batch.label_1_vec != self.NULL_IDX).sum().item()
            + (batch.text_0_vec != self.NULL_IDX).sum().item(),
            float(is_primary_worker()),
        )
        self.global_metrics.add('tokens_per_batch', tbp)

    if self.is_training:
        output = self.train_step(batch)
    else:
        with torch.no_grad():
            # save memory and compute by disabling autograd.
            # use `with torch.enable_grad()` to gain back gradients.
            output = self.eval_step(batch)

    if output is not None:
        # local metrics are automatically matched up
        self.match_batch(batch_reply, batch.valid_indices, output)

    # broadcast the metrics back
    for k, values in self._local_metrics.items():
        if len(values) != len(batch.valid_indices):
            # (also fixed the unbalanced paren in this message)
            raise IndexError(
                f"Batchsize mismatch on metric {k} (got {len(values)}, "
                f"expected {len(batch.valid_indices)})")
        for i, value in zip(batch.valid_indices, values):
            if 'metrics' not in batch_reply[i]:
                batch_reply[i]['metrics'] = {}
            batch_reply[i]['metrics'][k] = value

    # Make sure we push all the metrics to main thread in hogwild/workers
    self.global_metrics.flush()

    return batch_reply
def _save_outputs(self, opt, world, logger, episode_metrics):
    """
    Aggregate final metrics (across distributed workers if need be) and
    write them, plus optional world logs, to disk.

    :param opt: options dict; reads 'report_filename', 'world_logs',
        'save_format'.
    :param world: the evaluated world; its report() provides the metrics.
    :param logger: world logger with a ``write(path, world, file_format)``
        method.
    :param episode_metrics: list of (goal, episode_metric) pairs.
    :return: the final report dict (including per-episode 'tod_metrics').
    """
    if is_distributed():
        # flatten everything intelligently if need be
        world_report = aggregate_unnamed_reports(
            all_gather_list(world.report()))
        # all_gather_list returns one list per rank; flatten to one list
        episode_metrics_unflattened = all_gather_list(episode_metrics)
        flattened = []
        for rank_elem in episode_metrics_unflattened:
            for elem in rank_elem:
                flattened.append(elem)
        episode_metrics = flattened
    else:
        world_report = world.report()
    logging.report("Final report:\n" + nice_report(world_report))

    report = dict_report(world_report)

    def get_episode_report(goal, episode_metric):
        # one serializable metrics dict per episode, tagged with its goal
        metrics_dict = dict_report(episode_metric.report())
        metrics_dict["goal"] = goal
        return metrics_dict

    report["tod_metrics"] = [
        get_episode_report(g, e) for g, e in episode_metrics
    ]

    if "report_filename" in opt and opt["report_filename"] is not None:
        if len(world_report) == 0:
            logging.warning("Report is empty; not saving report")
        report_fname = f"{opt['report_filename']}.json"
        # Save report (only once: primary worker, or the sole worker)
        if not is_distributed() or is_primary_worker():
            with PathManager.open(report_fname, "w") as f:
                logging.info(f"Saving model report to {report_fname}")
                json.dump({"opt": opt, "report": report}, f, indent=4)
                f.write("\n")  # for jq

    if "world_logs" in opt and opt["world_logs"] is not None:
        if is_distributed():  # Save separately, then aggregate together
            rank = get_rank()
            log_outfile_part = (
                f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl")
            logger.write(log_outfile_part, world,
                         file_format=opt["save_format"])
            # barrier: wait until every rank has written its part
            sync_object(None)
            if is_primary_worker():
                log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                log_outfile_metadata = (
                    f"{opt['world_logs']}_{opt['save_format']}.metadata")
                with open(log_outfile, "w+") as outfile:
                    for rank in range(num_workers()):
                        log_outfile_part = (
                            f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
                        )
                        with open(log_outfile_part) as infile:
                            for line in infile:
                                json_blob = json.loads(line.strip())
                                if (
                                        len(json_blob["dialog"]) < 2
                                ):  # skip when we don't have generation
                                    continue
                                json_blob[
                                    "metadata_path"] = log_outfile_metadata
                                outfile.write(json.dumps(json_blob))
                                outfile.write("\n")
                        log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata"
                        if rank == 0:
                            # NOTE(review): trailing comma makes this a
                            # discarded 1-tuple; harmless but likely a typo
                            copyfile(log_output_part_metadata,
                                     log_outfile_metadata),
                        # remove per-rank parts once merged
                        os.remove(log_outfile_part)
                        os.remove(log_output_part_metadata)
        else:
            log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
            logger.write(log_outfile, world, file_format=opt["save_format"])

    return report