def _run_eval(
    self,
    valid_worlds,
    opt,
    datatype,
    max_exs=-1,
    write_log=False,
    extra_log_suffix="",
):
    """
    Eval on validation/test data.

    :param valid_worlds: list of the pre-created validation worlds.
    :param opt: the options that specify the task, eval_task, etc.
    :param datatype: the datatype to use, such as "valid" or "test"
    :param bool write_log: specifies to write metrics to file if the model_file is set
    :param int max_exs: limits the number of examples if max_exs > 0
    """
    logging.info(f'running eval: {datatype}')
    timer = Timer()
    reports = []

    max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
    for v_world in valid_worlds:
        task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    named_reports = dict(zip(tasks, reports))
    report = aggregate_named_reports(
        named_reports, micro_average=self.opt.get('aggregate_micro', False)
    )
    # get the results from all workers
    report = self._sync_metrics(report)

    metrics = f'{datatype}:\n{nice_report(report)}\n'
    logging.info(f'eval completed in {timer.time():.2f}s')
    logging.report(metrics)

    # write to file
    if write_log and opt.get('model_file') and is_primary_worker():
        # Write out metrics
        with PathManager.open(
            opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
        ) as f:
            f.write(f'{metrics}\n')

    return report
def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds: list of the pre-created validation worlds.
    :param opt: the options that specify the task, eval_task, etc.
    :param datatype: the datatype to use, such as "valid" or "test"
    :param bool write_log: specifies to write metrics to file if the model_file is set
    :param int max_exs: limits the number of examples if max_exs > 0
    """
    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []

    max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
    for v_world in valid_worlds:
        task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    named_reports = dict(zip(tasks, reports))
    report = aggregate_named_reports(named_reports)
    # get the results from all workers
    report = self._sync_metrics(report)

    metrics = f'{datatype}:{nice_report(report)}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write to file
    if write_log and opt.get('model_file') and is_primary_worker():
        # Write out metrics
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(f'{metrics}\n')

    return report
def run_eval(valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds: list of the pre-created validation worlds.
    :param opt: the options that specify the task, eval_task, etc.
    :param datatype: the datatype to use, such as "valid" or "test"
    :param bool write_log: specifies to write metrics to file if the model_file is set
    :param int max_exs: limits the number of examples if max_exs > 0
    """
    if valid_worlds is None:
        # This isn't the primary worker, so we can just skip evaluation
        return None

    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []
    for v_world in valid_worlds:
        task_report = _run_single_eval(opt, v_world, max_exs / len(valid_worlds))
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    report = aggregate_task_reports(
        reports, tasks, micro=opt.get('aggregate_micro', True)
    )

    metrics = f'{datatype}:{report}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write to file
    if write_log and opt.get('model_file'):
        # Write out metrics
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(metrics + '\n')

    return report
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """

    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'

        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'
                    ):
                        with PathManager.open(
                            opt['model_file'] + ".best_valid", 'r'
                        ) as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix

        if not is_primary_worker():
            # never do IO as a non-primary worker
            if hasattr(self.agent, 'save_nonprimary'):
                self.agent.save_nonprimary(fn)
            return

        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with PathManager.open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'train_steps': self._train_steps,
                    'total_epochs': self._total_epochs,
                    'train_reports': self.train_reports,
                    'valid_reports': self.valid_reports,
                    'best_valid': self.best_valid,
                    'impatience': self.impatience,
                    'final_valid_report': dict_report(self.final_valid_report),
                    'final_test_report': dict_report(self.final_test_report),
                    'final_extra_valid_report': dict_report(
                        self.final_extra_valid_report
                    ),
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = dict_report(valid_report)
        v['train_time'] = self.train_time.time()
        v['parleys'] = self.parleys
        v['train_steps'] = self._train_steps
        v['total_exs'] = self._total_exs
        v['total_epochs'] = self._total_epochs
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric_mode'] == 'max'
                and self.best_valid >= opt['validation_cutoff']
            ) or (
                opt['validation_metric_mode'] == 'min'
                and self.best_valid <= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):
        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Eval on validation/test data.

        :param valid_worlds: list of the pre-created validation worlds.
        :param opt: the options that specify the task, eval_task, etc.
        :param datatype: the datatype to use, such as "valid" or "test"
        :param bool write_log: specifies to write metrics to file if the model_file is set
        :param int max_exs: limits the number of examples if max_exs > 0
        """
        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        is_multitask = len(valid_worlds) > 1
        for index, v_world in enumerate(valid_worlds):
            if opt.get('evaltask'):
                task = opt['evaltask'].split(',')[index]
            else:
                task = opt['task'].split(',')[index]
            task_report = self._run_single_eval(
                opt, v_world, max_exs_per_worker, datatype, is_multitask, task
            )
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report

    def _run_final_extra_eval(self, opt):
        final_valid_opt = copy.deepcopy(opt)
        final_valid_opt_raw = Opt.load_init(opt['final_extra_opt'])
        final_datatype = final_valid_opt_raw["datatype"]
        for k, v in final_valid_opt_raw.items():
            final_valid_opt[k] = v
        final_max_exs = (
            final_valid_opt['validation_max_exs']
            if final_valid_opt.get('short_final_eval')
            else -1
        )
        final_valid_world = load_eval_worlds(
            self.agent, final_valid_opt, final_datatype
        )
        final_valid_report = self._run_eval(
            final_valid_world,
            final_valid_opt,
            final_datatype,
            final_max_exs,
            write_log=True,
            extra_log_suffix="_extra",
        )
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final(final_datatype, final_valid_report)

        return final_valid_report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(
        self, epochs_completed: float, time_elapsed: float, steps_taken: int
    ):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        max_train_steps = self.opt.get('max_train_steps', -1)
        if max_train_steps > 0 and steps_taken > 0:
            steps_progress = steps_taken / max_train_steps
            eta = (1 - steps_progress) * time_elapsed / steps_progress

        return eta

    def _get_time(self, world: World) -> Tuple[float, float, float, float]:
        """
        Return train, log, and validate timing.

        If relying on the time for validation/logging/max train time purposes, we
        sync and return primary worker's time. Otherwise, it's not super relevant what
        we do here.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world: current running world
        :return (train, log, valid): return time for each of train, log, and validation
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0, time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report

    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s'
                    )
                    break

                if (
                    log_time > self.log_every_n_secs
                    or self._last_log_steps >= self.log_every_n_steps
                ):
                    yield self.log()

                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                    or self._train_steps - self._last_valid_steps
                    >= self.val_every_n_steps
                ):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()

                if save_time > self.save_every_n_secs and opt.get('model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds

            # reload best validation model
            self.agent = create_agent(opt)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        for _train_log in self.train_steps():
            # we've already done what we need in these
            pass

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        self.final_valid_report = self._run_eval(
            valid_worlds, opt, 'valid', max_exs, write_log=True
        )
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        self.final_test_report = self._run_eval(
            test_worlds, opt, 'test', max_exs, write_log=True
        )

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', self.final_valid_report)
            self.wb_logger.log_final('test', self.final_test_report)
            self.wb_logger.finish()

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        if opt['final_extra_opt'] != '':
            self.final_extra_valid_report = self._run_final_extra_eval(opt)

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.finish()

        self._save_train_stats()
        return self.final_valid_report, self.final_test_report
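# Usage sketch (an assumption for illustration, not part of the original file):
# TrainLoop is normally constructed from a fully populated opt dict produced by
# ParlAI's argument parser, then driven via train(), which returns the final
# validation and test reports. The import path below is assumed.
if __name__ == '__main__':
    from parlai.scripts.train_model import setup_args  # assumed entry-point helper

    parser = setup_args()
    opt = parser.parse_args()
    valid_report, test_report = TrainLoop(opt).train()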
class World(object):
    """
    Empty parent providing null definitions of API functions for Worlds.

    All children can override these to provide more detailed functionality.
    """

    def __init__(self, opt: Opt, agents=None, shared=None):
        self.id = opt['task']
        self.opt = copy.deepcopy(opt)
        if shared:
            # Create agents based on shared data.
            self.agents = create_agents_from_shared(shared['agents'])
        else:
            # Add passed in agents to world directly.
            self.agents = agents
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time = Timer()

    def parley(self):
        """
        Perform one step of actions for the agents in the world.

        This is empty in the base class.
        """
        # TODO: mark as abstract?
        pass

    def getID(self):
        """
        Return the name of the world, typically the task the world encodes.
        """
        return self.id

    def display(self):
        """
        Return a string describing the current state of the world.

        Useful for monitoring and debugging. By default, display the messages between
        the agents.
        """
        if not hasattr(self, 'acts'):
            return ''
        return display_messages(
            self.acts,
            ignore_fields=self.opt.get('display_ignore_fields', ''),
            prettify=self.opt.get('display_prettify', False),
            max_len=self.opt.get('max_display_len', 1000),
        )

    def episode_done(self):
        """
        Whether the episode is done or not.
        """
        return False

    def epoch_done(self):
        """
        Whether the epoch is done or not.

        Not all worlds have the notion of an epoch, but this is useful for fixed
        training, validation or test sets.
        """
        return False

    def share(self):
        """
        Share the world.
        """
        shared_data = {}
        shared_data['world_class'] = type(self)
        shared_data['opt'] = self.opt
        shared_data['agents'] = self._share_agents()
        return shared_data

    def _share_agents(self):
        """
        Create shared data for agents.

        Allows other classes to create the same agents without duplicating the data
        (i.e. sharing parameters).
        """
        if not hasattr(self, 'agents'):
            return None
        shared_agents = [a.share() for a in self.agents]
        return shared_agents

    def get_agents(self):
        """
        Return the list of agents.
        """
        return self.agents

    def get_task_agent(self):
        """
        Return task agent, if applicable.
        """
        raise NotImplementedError('Implement in subworld')

    def get_acts(self):
        """
        Return the last act of each agent.
        """
        return self.acts

    def get_time(self):
        """
        Return total training time.
        """
        return self.time.time()

    def get_total_exs(self):
        """
        Return total amount of examples seen by world.
        """
        return self.total_exs

    def get_total_epochs(self):
        """
        Return total amount of epochs on which the world has trained.
        """
        return self.total_epochs

    def __enter__(self):
        """
        Empty enter provided for use with ``with`` statement.

        e.g:

        .. code-block:: python

            with World() as world:
                for n in range(10):
                    world.parley()
        """
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        After ``with`` statement, call shutdown.
        """
        self.shutdown()
        return False

    def num_examples(self):
        """
        Return the number of examples.

        Always 0 in the abstract world.
        """
        # TODO: mark as abstract?
        return 0

    def num_episodes(self):
        """
        Return the number of episodes.

        Always 0 in the abstract world.
        """
        # TODO: mark as abstract?
        return 0

    def reset(self):
        """
        Reset all agents in the world, and world statistics.
        """
        for a in self.agents:
            a.reset()
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time.reset()

    def reset_metrics(self):
        """
        Reset metrics for all agents.
        """
        for a in self.agents:
            a.reset_metrics()

    def shutdown(self):
        """
        Perform any cleanup, if appropriate.
        """
        pass

    def update_counters(self):
        """
        Update how many epochs have completed.
        """
        self.total_parleys += 1
        if self.max_exs is None:
            if 'num_epochs' in self.opt and self.opt['num_epochs'] > 0:
                if self.num_examples:
                    self.max_exs = self.num_examples() * self.opt['num_epochs']
                else:
                    self.max_exs = -1
            else:
                self.max_exs = -1
        # when we know the size of the data
        if self.max_exs > 0 or self.num_examples():
            self.total_epochs = (
                self.total_parleys * self.opt.get('batchsize', 1) / self.num_examples()
            )
        # when we do not know the size of the data
        else:
            if self.epoch_done():
                self.total_epochs += 1
def test_timer(self):
    t = Timer()
    time.sleep(1e-6)
    elapsed = t.stop().time()
    assert elapsed > 0

    same = t.time()
    assert elapsed == same

    t.resume()
    time.sleep(1e-6)
    more = t.time()
    assert more > elapsed

    rabbit = Timer()
    time.sleep(1e-6)
    turtle = Timer()
    time.sleep(1e-6)
    assert turtle.time() > 0
    assert turtle.time() < rabbit.time()
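# A minimal Timer sketch consistent with the test above (an assumption; the
# real implementation lives elsewhere in the codebase). It is a resumable
# stopwatch: time() reports accumulated seconds, stop()/resume() pause and
# continue, and the mutable `total` attribute matches the
# `train_time.total = ...` preemption restore seen in the TrainLoop snippets.
import time


class Timer:
    def __init__(self):
        self.running = True
        self.total = 0.0
        self.start = time.time()

    def time(self):
        """Return accumulated elapsed seconds."""
        if self.running:
            return self.total + time.time() - self.start
        return self.total

    def stop(self):
        """Pause the stopwatch; returns self so calls can be chained."""
        if self.running:
            self.running = False
            self.total += time.time() - self.start
        return self

    def resume(self):
        """Continue timing from the accumulated total."""
        if not self.running:
            self.running = True
            self.start = time.time()
        return self

    def reset(self):
        """Zero the stopwatch and restart it."""
        self.total = 0.0
        self.start = time.time()
        return self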
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """

    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print('[ Deprecated Warning: TrainLoop should be passed opt not Parser ]')
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (
            opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        )
        self.max_train_time = (
            opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        )
        self.log_every_n_secs = (
            opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        )
        self.val_every_n_secs = (
            opt['validation_every_n_secs']
            if opt['validation_every_n_secs'] > 0
            else float('inf')
        )
        self.save_every_n_secs = (
            opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        )
        self.val_every_n_epochs = (
            opt['validation_every_n_epochs']
            if opt['validation_every_n_epochs'] > 0
            else float('inf')
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None
        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and os.path.isfile(
                        opt['model_file'] + '.best_valid'
                    ):
                        with open(opt['model_file'] + ".best_valid", 'r') as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report):
        return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'total_epochs': (
                        self._preempted_epochs
                        + num_workers() * self.world.get_total_epochs()
                    ),
                    'impatience': self.impatience,
                    'valid_reports': [self._safe_report(v) for v in self.valid_reports],
                    'best_valid': self.best_valid,
                },
                f,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        # TODO(MW): replace sync_object with self._sync_metrics. You'll need some
        # logic to handle 'validation_max_exs' properly
        valid_report = run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)

        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()

        # saving
        if (
            opt.get('model_file')
            and opt.get('save_after_valid')
            and is_primary_worker()
        ):
            print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            print(
                '[ new best {}: {}{} ]'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric'] == 'accuracy'
                and self.best_valid >= opt['validation_cutoff']
            ):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print(
                '[ did not beat best {}: {} impatience: {} ]'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining metrics
        are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs, self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs), nice_report(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for distributed mode"
                        )
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (
                        self.train_time.time(),
                        self.log_time.time(),
                        self.validate_time.time(),
                    )
                )

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time
                        )
                    )
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                ):
                    try:
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for distributed mode"
                            )
                        break
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (
                    self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file')
                    and is_primary_worker()
                ):
                    print(
                        "[ saving model checkpoint: {}.checkpoint ]".format(
                            opt['model_file']
                        )
                    )
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
class TrainLoop: """TrainLoop contains the core training loop logic.""" def __init__(self, opt): # if python is called from a non-interactive shell, like a bash script, # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are # not produced. This line brings them back signal.signal(signal.SIGINT, signal.default_int_handler) if isinstance(opt, ParlaiParser): print( '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]' ) opt = opt.parse_args() # Possibly load from checkpoint trainstats_suffix = '.trainstats' # we might load training statistics from here if (opt['load_from_checkpoint'] and opt.get('model_file') and os.path.isfile(opt['model_file'] + '.checkpoint')): opt['init_model'] = opt['model_file'] + '.checkpoint' trainstats_suffix = '.checkpoint.trainstats' # Possibly build a dictionary (not all models do this). if not (opt.get('dict_file') or opt.get('model_file')): raise RuntimeError( 'WARNING: For train_model, please specify either a ' 'model_file or dict_file.') if 'dict_file' in opt: # If data built via pytorch data teacher, we need to load prebuilt dict if opt.get('pytorch_teacher_task'): opt['dict_file'] = get_pyt_dict_file(opt) elif opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' print("[ building dictionary first... ]") build_dict(opt, skip_if_built=True) # Create model and assign it to the specified task self.agent = create_agent(opt) self.world = create_task(opt, self.agent) # set up timers self.train_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() print('[ training... ]') self.parleys = 0 self.max_num_epochs = (opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')) self.max_train_time = (opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')) self.log_every_n_secs = (opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')) self.val_every_n_secs = (opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')) self.save_every_n_secs = (opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')) self.val_every_n_epochs = (opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 else float('inf')) # smart defaults for --validation-metric-mode if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}: opt['validation_metric_mode'] = 'min' elif opt['validation_metric'] in { 'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu' }: opt['validation_metric_mode'] = 'max' if opt.get('validation_metric_mode') is None: opt['validation_metric_mode'] = 'max' self.last_valid_epoch = 0 self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.valid_reports = [] self.best_valid = None if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'): with open(opt['model_file'] + ".best_valid", 'r') as f: x = f.readline() self.best_valid = float(x) f.close() self.impatience = 0 self.saved = False self.valid_worlds = None self.opt = opt # we may have been preempted, make sure we note that amount self._preempted_epochs = 0.0 if opt.get('model_file') and os.path.isfile(opt['model_file'] + trainstats_suffix): # looks like we were preempted. 
make sure we load up our total # training stats, etc with open(opt['model_file'] + trainstats_suffix) as ts: obj = json.load(ts) self._preempted_epochs = obj.get('total_epochs', 0) self.train_time.total = obj.get('train_time', 0) self.impatience = obj.get('impatience', 0) self.valid_reports = obj.get('valid_reports', []) if opt['tensorboard_log'] is True: self.tb_logger = TensorboardLogger(opt) def save_model(self, suffix=None): """Save the model to disk, possibly with a suffix.""" if not is_primary_worker(): # never do IO as a non-primary worker return if not self.opt.get('model_file'): # nothing to save to, just exit return fn = self.opt['model_file'] if suffix: fn += suffix while True: # don't ever let a ctrl-c interrupt saving try: self.agent.save(fn) self._save_train_stats(suffix) break except KeyboardInterrupt: pass def _save_train_stats(self, suffix=None): fn = self.opt['model_file'] if suffix: fn += suffix fn += '.trainstats' with open(fn, 'w') as f: json.dump( { 'train_time': self.train_time.time(), 'total_epochs': (self._preempted_epochs + num_workers() * self.world.get_total_epochs()), 'impatience': self.impatience, 'valid_reports': self.valid_reports, }, f, ) def validate(self): """ Perform a validation run, checking whether we should stop training. :return: boolean indicating whether training should stop :rtype: bool """ opt = self.opt if self.valid_worlds is None: # we need to load the world now self.valid_worlds = _maybe_load_eval_worlds( self.agent, opt, 'valid') # run evaluation on valid set valid_report = sync_object( run_eval(self.valid_worlds, opt, 'valid', opt['validation_max_exs'])) v = valid_report.copy() v['train_time'] = self.train_time.time() self.valid_reports.append(v) # logging if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger.log_metrics('valid', self.parleys, valid_report) # saving if (opt.get('model_file') and opt.get('save_after_valid') and is_primary_worker()): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.save_model('.checkpoint') # send valid metrics to agent if the agent wants them if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) # check which metric to look at if 'tasks' in valid_report and '/' in opt['validation_metric']: # if you are multitasking and want your validation metric to be # a metric specific to a subtask, specify your validation metric # as -vmt subtask/metric subtask = opt['validation_metric'].split('/')[0] validation_metric = opt['validation_metric'].split('/')[1] new_valid = valid_report['tasks'][subtask][validation_metric] else: new_valid = valid_report[opt['validation_metric']] # check if this is the best validation so far if (self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid): print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '', )) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file') and is_primary_worker(): print("[ saving best valid model: " + opt['model_file'] + " ]") self.save_model() print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]") _save_best_valid(opt['model_file'], self.best_valid) self.saved = True if (opt['validation_metric'] == 'accuracy' and self.best_valid >= opt['validation_cutoff']): print('[ task solved! stopping. 
]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() # check if we are out of patience if (opt['validation_patience'] > 0 and self.impatience >= opt['validation_patience']): print('[ ran out of patience! stopping training. ]') return True return False def _average_dicts(self, all_versions): # instead of a list-of-dicts with like keys, make a dict-of-lists with # keys to reduce to_reduce = {} for d in all_versions: for k, v in d.items(): to_reduce.setdefault(k, []).append(v) # now perform the reduction finalized = {} for k, values in to_reduce.items(): if k == 'exs' or k == 'total_skipped_batches': # sum across workers finalized[k] = np.sum(values) elif isinstance(values[0], dict): # do the same procedure recursively finalized[k] = self._average_dicts(values) elif isinstance(values[0], str): finalized[k] = values[0] else: # all other cases, take the mean across the workers finalized[k] = np.mean(values) if all(isinstance(v, int) for v in values): finalized[k] = int(finalized[k]) return finalized def _cleanup_inaccurate_metrics(self, metrics): """ Remove inaccurate multiworld metrics. When training in multitask mode, agent-level metrics may be shown, but are actually averages not distinguished across the worlds. This method adds a warning. Issue: https://github.com/facebookresearch/ParlAI/issues/1750 """ # TODO: fix the root issue if 'tasks' in metrics: metrics[ 'warning'] = 'agent level metrics (e.g. loss, mean_loss, ppl) are averaged over tasks' def _sync_training_metrics(self, metrics): """ Sync training metrics across workers. A handful of special cases are handled as exceptions, and the remaining metrics are simply averaged across workers. """ if not is_distributed(): # nothing special needed return metrics all_versions = all_gather_list(metrics) return self._average_dicts(all_versions) def _nice_format(self, dictionary): rounded = {} for k, v in dictionary.items(): if isinstance(v, dict): rounded[k] = self._nice_format(v) elif isinstance(v, float): rounded[k] = round_sigfigs(v, 4) else: rounded[k] = v return rounded def _compute_eta(self, epochs_completed, time_elapsed): """ Compute the estimated seconds remaining in training. :param float epochs_completed: number of epochs already completed. :param float time_elapsed: total time spent already, in seconds. 
:return: ETA in seconds, or None if not computable """ # start off with no estimate eta = None # Determine time_left and num_epochs max_epochs = self.opt.get('num_epochs', 0) if max_epochs > 0 and epochs_completed > 0: epoch_progress = epochs_completed / max_epochs eta = (1 - epoch_progress) * time_elapsed / epoch_progress max_training_time = self.opt.get('max_training_time', -1) if max_training_time > 0: time_left = max_training_time - time_elapsed if eta is None or time_left < eta: eta = time_left return eta def log(self): """Output a training log entry.""" opt = self.opt if opt['display_examples']: print(self.world.display() + '\n~~') logs = [] # get report train_report = self.world.report() self._cleanup_inaccurate_metrics(train_report) train_report = self._sync_training_metrics(train_report) self.world.reset_metrics() # time elapsed logs.append('time:{}s'.format(np.floor(self.train_time.time()))) logs.append('total_exs:{}'.format(self._total_exs)) if self._total_epochs >= 0: # only if it's unbounded logs.append('epochs:{}'.format(round(self._total_epochs, 2))) time_left = self._compute_eta(self._total_epochs, self.train_time.time()) if time_left is not None: logs.append('time_left:{}s'.format(max(0, np.ceil(time_left)))) log = '[ {} ] {}'.format(' '.join(logs), self._nice_format(train_report)) print(log) self.log_time.reset() if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger.log_metrics('train', self.parleys, train_report) def train(self): """ Perform a training run. :return: tuple of reports (validation_report, test_report) """ if is_distributed(): warn_once( "Distributed training outputs average-per-worker metrics during " "training, and may be slightly distorted. Validation/test are " "unadulterated.") opt = self.opt world = self.world with world: while True: # do one example / batch of examples world.parley() self.parleys += 1 # get the total training examples done, compute epochs self._total_epochs = ( self._preempted_epochs + num_workers() * self.world.get_total_epochs()) exs_per_epoch = self.world.num_examples() self._total_exs = int( np.round(self._total_epochs * exs_per_epoch)) # and use the primary worker's timings for everything train_time, log_time, validate_time = sync_object(( self.train_time.time(), self.log_time.time(), self.validate_time.time(), )) # check counters and timers if self._total_epochs >= self.max_num_epochs: self.log() print( '[ num_epochs completed:{} time elapsed:{}s ]'.format( self.max_num_epochs, train_time)) break if train_time > self.max_train_time: print('[ max_train_time elapsed:{}s ]'.format(train_time)) break if log_time > self.log_every_n_secs: self.log() if (validate_time > self.val_every_n_secs or self._total_epochs - self.last_valid_epoch >= self.val_every_n_epochs): stop_training = self.validate() self.last_valid_epoch = self._total_epochs if stop_training: break if (self.save_time.time() > self.save_every_n_secs and opt.get('model_file') and is_primary_worker()): print("[ saving model checkpoint: {}.checkpoint".format( opt['model_file'])) self.save_model('.checkpoint') self.save_time.reset() if not self.saved and is_primary_worker(): # save agent self.save_model() elif opt.get('model_file'): # reload best validation model self.agent = create_agent(opt) valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid') max_exs = opt['validation_max_exs'] if opt.get( 'short_final_eval') else -1 v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True) test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test') 
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
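# A quick sanity check of the ETA arithmetic used in _compute_eta above. This
# is a standalone sketch, not part of the training loop; the helper name
# _eta_example and the numbers are made up for illustration. With
# num_epochs=4 and 1 epoch finished after 600s, the epoch-based estimate is
# (1 - 0.25) * 600 / 0.25 = 1800s; a max_training_time of 1000s then caps the
# ETA at 1000 - 600 = 400s, because the smaller of the two bounds wins.
def _eta_example(epochs_completed=1.0, time_elapsed=600.0,
                 max_epochs=4, max_training_time=1000):
    eta = None
    if max_epochs > 0 and epochs_completed > 0:
        epoch_progress = epochs_completed / max_epochs
        eta = (1 - epoch_progress) * time_elapsed / epoch_progress
    if max_training_time > 0:
        time_left = max_training_time - time_elapsed
        if eta is None or time_left < eta:
            eta = time_left
    return eta

assert _eta_example() == 400.0  # the wall-clock cap beats the epoch estimate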
def eval_ppl(opt, build_dict=None, dict_file=None):
    """
    Evaluate the perplexity of a model.

    This uses a dictionary which implements the following functions:
    - tokenize(text): splits string up into list of tokens
    - __contains__(text): checks whether dictionary contains a token
    - keys(): returns an iterator over all tokens in the dictionary

    :param opt: option dict
    :param build_dict: function which returns a dictionary class implementing
        the functions above.
    :param dict_file: file used when loading the dictionary class set via the
        "dictionary_class" argument (defaults to
        parlai.core.dict:DictionaryAgent).

    Either build_dict or dict_file must be set (both default to None) to
    determine the dictionary used for the evaluation.
    """
    if not build_dict and not dict_file:
        raise RuntimeError(
            'eval_ppl script either needs a dictionary build '
            'function or a dictionary file.'
        )

    if build_dict:
        dict_agent = build_dict()
    else:
        dict_opt = copy.deepcopy(opt)
        dict_opt['model'] = dict_opt.get(
            'dictionary_class', 'parlai.core.dict:DictionaryAgent'
        )
        dict_opt['model_file'] = dict_file
        if 'override' in dict_opt:
            del dict_opt['override']
        dict_agent = create_agent(dict_opt, requireModelExists=True)

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent], default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example
        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['exs'] / world.num_examples() * 100, 3),
                report,
            ))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print(
            'Note: you got inf perplexity. Consider adding (or raising) the '
            'minimum probability you assign to each possible word. If you '
            'assign zero probability to the correct token in the evaluation '
            'vocabulary, you get inf perplexity immediately.'
        )
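# Minimal sketch of an object satisfying the dictionary protocol that
# eval_ppl's docstring describes (tokenize / __contains__ / keys). The class
# name SimpleEvalDict and the whitespace tokenizer are illustrative
# assumptions, not ParlAI's DictionaryAgent; a real build_dict would normally
# return a trained DictionaryAgent, which additionally behaves as an agent
# inside the PerplexityWorld.
class SimpleEvalDict:
    def __init__(self, corpus):
        # build the vocabulary from whitespace-tokenized corpus lines
        self._tokens = set()
        for line in corpus:
            self._tokens.update(self.tokenize(line))

    def tokenize(self, text):
        # splits a string up into a list of tokens
        return text.lower().split()

    def __contains__(self, token):
        # checks whether the dictionary contains a token
        return token in self._tokens

    def keys(self):
        # returns an iterator over all tokens in the dictionary
        return iter(self._tokens)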
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """

    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions
        # are not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'

        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            elif opt['dict_file'] is None and opt.get('teacher_model_file'):
                logging.info("using teacher's dictionary...")
                opt['dict_file'] = opt['teacher_model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # print(opt)

        # Create teacher model
        teacher_opt = {
            'datapath': 'blended_skill_talk',
            'model_file': opt['teacher_model_file'],
            # some custom args
            'tie_layers': False,
            'enable_checkpointing': False,
        }
        self.teacher_agent = create_agent(teacher_opt)
        self.agent.set_teacher_agent(self.teacher_agent)
        print(self.agent.model)
        print(self.teacher_agent.model)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self.max_num_epochs = (
            opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        )
        self.max_train_time = (
            opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        )
        self.log_every_n_secs = (
            opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        )
        self.val_every_n_secs = (
            opt['validation_every_n_secs']
            if opt['validation_every_n_secs'] > 0
            else float('inf')
        )
        self.save_every_n_secs = (
            opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        )
        self.val_every_n_epochs = (
            opt['validation_every_n_epochs']
            if opt['validation_every_n_epochs'] > 0
            else float('inf')
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.valid_reports = []
        self.best_valid = None
        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and os.path.isfile(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and os.path.isfile(
                        opt['model_file'] + '.best_valid'
                    ):
                        with open(opt['model_file'] + '.best_valid', 'r') as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _safe_report(self, report: Dict[str, Metric]):
        return {
            k: v.value() if isinstance(v, Metric) else v
            for k, v in report.items()
        }

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'total_epochs': (
                        self._preempted_epochs
                        + num_workers() * self.world.get_total_epochs()
                    ),
                    'impatience': self.impatience,
                    'valid_reports': [
                        self._safe_report(v) for v in self.valid_reports
                    ],
                    'best_valid': self.best_valid,
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.
        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)

        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()

        # saving
        if (
            opt.get('model_file')
            and opt.get('save_after_valid')
            and is_primary_worker()
        ):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric'] == 'accuracy'
                and self.best_valid >= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'],
                    round(self.best_valid, 4),
                    self.impatience,
                )
            )
        self.validate_time.reset()

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs):
        # run evaluation on a single world
        valid_world.reset()
        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0
        valid_report = valid_world.report()
        valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
        """
        Eval on validation/test data.

        :param valid_worlds: list of the pre-created validation worlds.
        :param opt: the options that specify the task, eval_task, etc
        :param datatype: the datatype to use, such as "valid" or "test"
        :param bool write_log: specifies to write metrics to file if the
            model_file is set
        :param int max_exs: limits the number of examples if max_exs > 0
        """
        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        for v_world in valid_worlds:
            task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with open(opt['model_file'] + '.' + datatype, 'a+') as f:
                f.write(f'{metrics}\n')

        return report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the
        remaining metrics are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')

        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(self._total_epochs, self.train_time.time())
        if time_left is not None:
            logs.append(f'time_left:{max(0, time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)

    def train(self):
        """
        Perform a training run.
        :return: tuple of reports (validation_report, test_report)
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode"
                        )
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs())
                )
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                ):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for "
                                "distributed mode"
                            )
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (
                    self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file')
                    and is_primary_worker()
                ):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the
        # primary worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU
            # when reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
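# Standalone sketch of the early-stopping bookkeeping used in validate()
# above: valid_optim flips the comparison so 'min' metrics (loss/ppl) and
# 'max' metrics (accuracy/f1) share one code path, and impatience counts
# consecutive validations without a new best. The helper name should_stop
# and the metric history are illustrative assumptions, not ParlAI API.
def should_stop(history, mode='min', patience=3):
    valid_optim = 1 if mode == 'max' else -1
    best, impatience = None, 0
    for new_valid in history:
        if best is None or valid_optim * new_valid > valid_optim * best:
            # new best validation: reset patience
            best, impatience = new_valid, 0
        else:
            impatience += 1
        if patience > 0 and impatience >= patience:
            return True  # ran out of patience at this validation
    return False

# ppl improves, then fails to beat 18.4 for three validations -> stop
assert should_stop([20.0, 18.5, 18.4, 18.6, 18.5, 18.7], mode='min') is True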