def _sync_metrics(self, metrics):
    """
    Sync training metrics across workers.

    A handful of special cases are handled as exceptions, and the remaining
    metrics are simply averaged across workers.
    """
    if not is_distributed():
        # nothing special needed
        return metrics
    all_versions = all_gather_list(metrics)
    return aggregate_unnamed_reports(all_versions)
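
# Minimal sketch of the non-special-cased path above, assuming every worker
# reports the same plain scalar keys. `_average_reports` is a hypothetical
# helper shown only for illustration; ParlAI's aggregate_unnamed_reports()
# additionally handles the special-cased metric types.
def _average_reports(reports):
    keys = reports[0].keys()
    return {k: sum(r[k] for r in reports) / len(reports) for k in keys}

# e.g. _average_reports([{'loss': 2.0}, {'loss': 4.0}]) -> {'loss': 3.0}
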
def _eval_single_world(opt, agent, task):
    logging.info(f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    world_logger = WorldLogger(opt) if opt['world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(
                report.get('exs', 0), min(max_cnt, total_cnt), report
            )
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            rank = get_rank()
            base_outfile, extension = os.path.splitext(opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()

    return report
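
# Worked example of the per-rank output naming above, using only the standard
# library (the path is illustrative): with opt['world_logs'] set to
# '/tmp/eval_logs.jsonl', rank 3 writes to its own shard.
import os

base_outfile, extension = os.path.splitext('/tmp/eval_logs.jsonl')
outfile = base_outfile + f'_{3}' + extension
assert outfile == '/tmp/eval_logs_3.jsonl'
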
def _get_time(self, world: World) -> Tuple[float, float, float, float]:
    """
    Return train, log, validate, and save timing.

    If relying on the time for validation/logging/max train time purposes, we
    sync and return the primary worker's time. Otherwise, it's not super
    relevant what we do here.

    **SIDE EFFECT**: Update _total_epochs trained.

    :param world:
        current running world
    :return (train, log, valid, save):
        return time for each of train, log, validate, and save
    """
    if (
        self.max_train_time < float('inf')
        or self.log_every_n_secs < float('inf')
        or self.val_every_n_secs < float('inf')
        or self.val_every_n_epochs < float('inf')
        or self.max_num_epochs < float('inf')
    ):
        self._total_epochs = self._preempted_epochs + sum(
            all_gather_list(world.get_total_epochs())
        )
        train_time, log_time, validate_time, save_time = sync_object(
            (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
        )
    else:
        train_time, log_time, validate_time, save_time = (
            self.train_time.time(),
            self.log_time.time(),
            self.validate_time.time(),
            self.save_time.time(),
        )
        self._total_epochs = self._preempted_epochs + (
            num_workers() * world.get_total_epochs()
        )

    return train_time, log_time, validate_time, save_time
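
# Rough sketch of the sync_object() behaviour relied on above, written directly
# against torch.distributed (assumption: a process group is already
# initialized). Broadcasting the primary worker's timer readings keeps every
# rank making the same log/validate/save decisions. This is an illustration,
# not ParlAI's implementation.
import torch.distributed as dist

def broadcast_from_primary(value):
    if not dist.is_available() or not dist.is_initialized():
        return value
    buf = [value]
    dist.broadcast_object_list(buf, src=0)
    return buf[0]
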
def train(self):
    """
    Perform a training run.

    :return: tuple of reports (validation_report, test_report)
    """
    logging.info('training...')
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException as e:
                logging.info(f"Stopping from {e}")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            exs_per_epoch = world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                )
            )

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                logging.info(
                    f'num_epochs completed:{self.max_num_epochs} '
                    f'time elapsed:{train_time}s'
                )
                break
            if train_time > self.max_train_time:
                logging.info(f'max_train_time elapsed:{train_time}s')
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch
                >= self.val_every_n_epochs
            ):
                try:
                    # log before we validate
                    self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if (
                self.save_time.time() > self.save_every_n_secs
                and opt.get('model_file')
                and is_primary_worker()
            ):
                logging.info(
                    f"saving model checkpoint: {opt['model_file']}.checkpoint"
                )
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()

    # there's a rare edge case where we never saved the model, and we try
    # to reload it. This sync_object ensures all workers wait for the primary
    # worker to finish flushing before loading from disk.
    sync_object(None)

    if opt.get('model_file'):
        # clean up all our memory, just to make sure we don't OOM on GPU when
        # reloading the world
        del world
        del self.world
        del self.agent
        del self.valid_worlds
        # reload best validation model
        self.agent = create_agent(opt)

    # perform final validation/testing
    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    return v_report, t_report
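
# Worked example of the epoch/example accounting above (illustrative numbers):
# four workers that have each consumed 0.25 epochs of a 10,000-example task,
# with no preempted epochs, yield 1.0 completed epoch and 10,000 total examples.
preempted_epochs = 0.0
per_worker_epochs = [0.25, 0.25, 0.25, 0.25]  # what all_gather_list() might return
total_epochs = preempted_epochs + sum(per_worker_epochs)  # 1.0
total_exs = int(round(total_epochs * 10000))              # 10000
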
def _eval_single_world(opt, agent, task):
    logging.info(f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    # add task suffix in case of multi-tasking
    if opt['world_logs']:
        task_opt['world_logs'] = get_task_world_logs(
            task, task_opt['world_logs'], is_multitask=len(opt['task'].split(',')) > 1
        )
    world_logger = WorldLogger(task_opt) if task_opt['world_logs'] else None

    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(
                report.get('exs', 0), min(max_cnt, total_cnt), report
            )
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            rank = get_rank()
            base_outfile, extension = os.path.splitext(task_opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = task_opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    if isinstance(world.agents, list) and len(world.agents) > 1:
        classifier_agent = world.agents[CLASSIFIER_AGENT]
        if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
            for class_indices, curr_auc in zip(
                classifier_agent.auc_class_indices, classifier_agent.aucs
            ):
                report[f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
            classifier_agent.reset_auc()
            # for safety measures
            agent.reset_auc()
    world.reset()

    return report
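
# Illustration of the AUC keys added to the report above, with hypothetical
# class names and values; the real class_list / auc_class_indices / aucs come
# from the classifier agent itself.
class_list = ['__ok__', '__notok__']
auc_class_indices, aucs = [1], [0.93]
report = {}
for class_indices, curr_auc in zip(auc_class_indices, aucs):
    report[f'AUC_{class_list[class_indices]}'] = curr_auc
# report == {'AUC___notok__': 0.93}
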
def train(self):
    """
    Perform a training run.

    :return: tuple of reports (validation_report, test_report)
    """
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException:
                if is_distributed():
                    raise RuntimeError(
                        "StopTrainException not supported for distributed mode"
                    )
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            exs_per_epoch = world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                )
            )

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print(
                    '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time
                    )
                )
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch
                >= self.val_every_n_epochs
            ):
                try:
                    # log before we validate
                    self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for distributed mode"
                        )
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if (
                self.save_time.time() > self.save_every_n_secs
                and opt.get('model_file')
                and is_primary_worker()
            ):
                print(
                    "[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']
                    )
                )
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    return v_report, t_report
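
# Filename convention assumed by the saves above (path is illustrative):
# periodic checkpoints append a suffix to model_file, while the final save
# after training stops writes to model_file itself.
model_file = '/tmp/model'
checkpoint_path = model_file + '.checkpoint'  # written every save_every_n_secs
final_path = model_file                       # written when training finishes
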
def _save_outputs(self, opt, world, logger, episode_metrics):
    if is_distributed():  # flatten everything intelligently if need be
        world_report = aggregate_unnamed_reports(all_gather_list(world.report()))
        episode_metrics_unflattened = all_gather_list(episode_metrics)
        flattened = []
        for rank_elem in episode_metrics_unflattened:
            for elem in rank_elem:
                flattened.append(elem)
        episode_metrics = flattened
    else:
        world_report = world.report()
    logging.report("Final report:\n" + nice_report(world_report))

    report = dict_report(world_report)

    def get_episode_report(goal, episode_metric):
        metrics_dict = dict_report(episode_metric.report())
        metrics_dict["goal"] = goal
        return metrics_dict

    report["tod_metrics"] = [get_episode_report(g, e) for g, e in episode_metrics]

    if "report_filename" in opt and opt["report_filename"] is not None:
        if len(world_report) == 0:
            logging.warning("Report is empty; not saving report")

        report_fname = f"{opt['report_filename']}.json"
        # Save report
        if not is_distributed() or is_primary_worker():
            with PathManager.open(report_fname, "w") as f:
                logging.info(f"Saving model report to {report_fname}")
                json.dump({"opt": opt, "report": report}, f, indent=4)
                f.write("\n")  # for jq

    if "world_logs" in opt and opt["world_logs"] is not None:
        if is_distributed():  # Save separately, then aggregate together
            rank = get_rank()
            log_outfile_part = f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
            logger.write(log_outfile_part, world, file_format=opt["save_format"])
            sync_object(None)
            if is_primary_worker():
                log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                log_outfile_metadata = (
                    f"{opt['world_logs']}_{opt['save_format']}.metadata"
                )
                with open(log_outfile, "w+") as outfile:
                    for rank in range(num_workers()):
                        log_outfile_part = (
                            f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
                        )
                        with open(log_outfile_part) as infile:
                            for line in infile:
                                json_blob = json.loads(line.strip())
                                if len(json_blob["dialog"]) < 2:
                                    # skip when we don't have generation
                                    continue
                                json_blob["metadata_path"] = log_outfile_metadata
                                outfile.write(json.dumps(json_blob))
                                outfile.write("\n")
                        log_output_part_metadata = (
                            f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata"
                        )
                        if rank == 0:
                            copyfile(log_output_part_metadata, log_outfile_metadata)
                        os.remove(log_outfile_part)
                        os.remove(log_output_part_metadata)
        else:
            log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
            logger.write(log_outfile, world, file_format=opt["save_format"])

    return report
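
# Worked example of the file layout produced above (values are illustrative):
# each rank first writes its own shard, then the primary worker merges the
# shards into a single log and keeps rank 0's metadata next to it.
opt = {'world_logs': '/tmp/tod', 'save_format': 'conversations'}
rank = 1
log_outfile_part = f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
log_outfile_metadata = f"{opt['world_logs']}_{opt['save_format']}.metadata"
# '/tmp/tod_conversations_1.jsonl' is merged into '/tmp/tod_conversations.jsonl',
# with the metadata kept as '/tmp/tod_conversations.metadata'.
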