def multiprocess_train( rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost' ): """ Subprocess which initializes distributed training, and begins training. This should be launched n times for n GPUs; this is handled either in main or via srun. :param int rank: This process's rank - 1. (Starts at -1 ... n - 2). See comments. :param opt: command line options :param int port: A TCP port to use. This will need to be changed to run multiple distributed training setups on the same machine. :param int gpu: Which GPU to use. Defaults to using rank and local devices, but must be manually specified when using many-hosts. :param str hostname: Hostname of the main server. """ # Set per-host options opt = copy.deepcopy(opt) # we need to manually adjust the rank differently in multiprocessing # and distributed train rank = rank + rank_offset opt['rank'] = rank if gpu is None: # default assumption is local GPUs gpu = rank % torch.cuda.device_count() opt['gpu'] = gpu # make sure we don't just use whatever GPU was saved in the model file if 'override' not in opt: opt['override'] = {} opt['override']['gpu'] = gpu # Suppress output of workers except the main host. if opt.get('verbose') or rank != 0: print_prefix = '[rank:{:3d}]'.format(rank) else: print_prefix = None suppress_output = not opt.get('verbose') and rank != 0 with distributed_utils.override_print(suppress_output, print_prefix): # perform distributed setup, ensuring all hosts are ready if opt['gpu'] != -1: torch.cuda.set_device(opt['gpu']) dist.init_process_group( backend="nccl", init_method="tcp://{}:{}".format(hostname, port), world_size=opt['distributed_world_size'], rank=rank, ) logging.info("Distributed group initialized") # manual_seed can be a noop without this torch.cuda.init() # make sure all parameters will be in sync torch.manual_seed(42) # force a sync so that no one gets ahead, and all are seeded together distributed_utils.sync_object(None) # Run the actual training return single_train.TrainLoop(opt).train()
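
# --- Illustrative launcher (a hedged sketch, not this project's actual entry
# point): one way multiprocess_train could be spawned once per local GPU with
# torch.multiprocessing. The real caller may instead fork n - 1 subprocesses
# and reuse the parent process as rank -1 (see the docstring above), so treat
# the rank bookkeeping here as an assumption.
def _example_launch_multiprocess_train(opt, port=61337):
    import torch
    import torch.multiprocessing as mp

    # mp.spawn passes the process index (0 .. nprocs - 1) as the first
    # positional argument, which multiprocess_train receives as `rank`.
    mp.spawn(
        multiprocess_train,
        args=(opt, port),
        nprocs=torch.cuda.device_count(),
        join=True,
    )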
def run_eval(valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds: list of the pre-created validation worlds.
    :param opt: the options that specify the task, eval_task, etc.
    :param datatype: the datatype to use, such as "valid" or "test"
    :param int max_exs: limits the number of examples if max_exs > 0
    :param bool write_log: specifies to write metrics to file if the model_file is set
    """
    if valid_worlds is None:
        # This isn't the primary worker, so we can just skip evaluation
        return sync_object(None)

    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []
    for v_world in valid_worlds:
        task_report = _run_single_eval(opt, v_world, max_exs / len(valid_worlds))
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    named_reports = dict(zip(tasks, reports))
    report = aggregate_named_reports(named_reports)

    metrics = f'{datatype}:{nice_report(report)}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write to file
    if write_log and opt.get('model_file'):
        # Write out metrics
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(f'{metrics}\n')

    return sync_object(report)
def _get_time(self, world: World) -> Tuple[float, float, float]: """ Return train, log, and validate timing. If relying on the time for validation/logging/max train time purposes, we sync and return primary worker's time. Otherwise, it's not super relevant what we do here. **SIDE EFFECT**: Update _total_epochs trained. :param world: current running world :return (train, log, valid): return time for each of train, log, and validation """ if ( self.max_train_time < float('inf') or self.log_every_n_secs < float('inf') or self.val_every_n_secs < float('inf') or self.val_every_n_epochs < float('inf') or self.max_num_epochs < float('inf') ): self._total_epochs = self._preempted_epochs + sum( all_gather_list(world.get_total_epochs()) ) train_time, log_time, validate_time, save_time = sync_object( ( self.train_time.time(), self.log_time.time(), self.validate_time.time(), self.save_time.time(), ) ) else: train_time, log_time, validate_time, save_time = ( self.train_time.time(), self.log_time.time(), self.validate_time.time(), self.save_time.time(), ) self._total_epochs = self._preempted_epochs + ( num_workers() * world.get_total_epochs() ) return train_time, log_time, validate_time, save_time
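
# --- Sketch of the "trust the primary worker's clock" pattern used by
# _get_time above. This is an assumption about what sync_object provides here
# (not its actual implementation): rank 0's timer values are broadcast so every
# worker compares against identical numbers and takes the log/validate/save
# branches in lockstep.
def _example_sync_times(local_times):
    import torch.distributed as dist

    if not (dist.is_available() and dist.is_initialized()):
        # single-process training: nothing to synchronize
        return local_times
    obj = [local_times]
    dist.broadcast_object_list(obj, src=0)  # overwritten with rank 0's tuple
    return obj[0]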
def train_steps(self):
    """
    Core training loop.

    Yields a metrics dict with each log.
    """
    logging.info('training...')
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException as e:
                logging.info(f"Stopping from {e}")
                break

            self.parleys += 1
            self._train_steps = self.parleys // self.update_freq
            self._last_log_steps += 1 / self.update_freq

            # the following additionally updates self._total_epochs
            train_time, log_time, validate_time, save_time = self._get_time(world)
            # get the total training examples done, compute epochs
            exs_per_epoch = world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                yield self.log()
                logging.info(
                    f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                )
                break
            if train_time > self.max_train_time:
                logging.info(f'max_train_time elapsed:{train_time}s')
                break
            if self._train_steps >= self.max_train_steps:
                logging.info(
                    f'max_train_steps elapsed:{self._train_steps} '
                    f'time elapsed:{train_time}s'
                )
                break
            if (
                log_time > self.log_every_n_secs
                or self._last_log_steps >= self.log_every_n_steps
            ):
                yield self.log()
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch >= self.val_every_n_epochs
                or self._train_steps - self._last_valid_steps >= self.val_every_n_steps
            ):
                try:
                    # log before we validate
                    if self._last_log_steps:
                        yield self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs
                self._last_valid_steps = self._train_steps
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if save_time > self.save_every_n_secs and opt.get('model_file'):
                logging.info(
                    f"saving model checkpoint: {opt['model_file']}.checkpoint"
                )
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not sync_object(self.saved):
        # save agent
        self.save_model()

    # there's a rare edge case where we never saved the model, and we try
    # to reload it. This sync_object ensures all workers wait for the primary
    # worker to finish flushing before loading from disk.
    sync_object(None)
    if opt.get('model_file'):
        # clean up all our memory, just to make sure we don't OOM on GPU when
        # reloading the world
        del world
        del self.world
        del self.agent
        del self.valid_worlds

        # reload best validation model
        self.agent = create_agent(opt)
def train(self):
    """
    Perform a training run.

    :return: tuple of reports (validation_report, test_report)
    """
    logging.info('training...')
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException as e:
                logging.info(f"Stopping from {e}")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            exs_per_epoch = world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                logging.info(
                    f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                )
                break
            if train_time > self.max_train_time:
                logging.info(f'max_train_time elapsed:{train_time}s')
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs):
                try:
                    # log before we validate
                    self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if (self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file') and is_primary_worker()):
                logging.info(
                    f"saving model checkpoint: {opt['model_file']}.checkpoint"
                )
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()

    # there's a rare edge case where we never saved the model, and we try
    # to reload it. This sync_object ensures all workers wait for the primary
    # worker to finish flushing before loading from disk.
    sync_object(None)
    if opt.get('model_file'):
        # clean up all our memory, just to make sure we don't OOM on GPU when
        # reloading the world
        del world
        del self.world
        del self.agent
        del self.valid_worlds

        # reload best validation model
        self.agent = create_agent(opt)

    # perform final validation/testing
    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    return v_report, t_report
def train(self):
    """
    Perform a training run.

    :return: tuple of reports (validation_report, test_report)
    """
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException:
                if is_distributed():
                    raise RuntimeError(
                        "StopTrainException not supported for "
                        "distributed mode")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = (
                self._preempted_epochs +
                num_workers() * self.world.get_total_epochs())
            exs_per_epoch = self.world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs):
                try:
                    stop_training = self.validate()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
            if (self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file') and is_primary_worker()):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)

    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    print_announcements(opt)

    return v_report, t_report
def validate(self): """ Perform a validation run, checking whether we should stop training. :return: boolean indicating whether training should stop :rtype: bool """ opt = self.opt if self.valid_worlds is None: # we need to load the world now self.valid_worlds = _maybe_load_eval_worlds( self.agent, opt, 'valid') # run evaluation on valid set valid_report = sync_object( run_eval(self.valid_worlds, opt, 'valid', opt['validation_max_exs'])) v = valid_report.copy() v['train_time'] = self.train_time.time() self.valid_reports.append(v) # logging if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger.log_metrics('valid', self.parleys, valid_report) # saving if (opt.get('model_file') and opt.get('save_after_valid') and is_primary_worker()): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.save_model('.checkpoint') # send valid metrics to agent if the agent wants them if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) # check which metric to look at if 'tasks' in valid_report and '/' in opt['validation_metric']: # if you are multitasking and want your validation metric to be # a metric specific to a subtask, specify your validation metric # as -vmt subtask/metric subtask = opt['validation_metric'].split('/')[0] validation_metric = opt['validation_metric'].split('/')[1] new_valid = valid_report['tasks'][subtask][validation_metric] else: new_valid = valid_report[opt['validation_metric']] # check if this is the best validation so far if (self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid): print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '', )) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file') and is_primary_worker(): print("[ saving best valid model: " + opt['model_file'] + " ]") self.save_model() print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]") _save_best_valid(opt['model_file'], self.best_valid) self.saved = True if (opt['validation_metric'] == 'accuracy' and self.best_valid >= opt['validation_cutoff']): print('[ task solved! stopping. ]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() # check if we are out of patience if (opt['validation_patience'] > 0 and self.impatience >= opt['validation_patience']): print('[ ran out of patience! stopping training. ]') return True return False
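
# --- Minimal illustration (hypothetical report shape, not produced by the code
# above) of how validate() resolves a '-vmt subtask/metric' setting against a
# multitask validation report.
def _example_pick_validation_metric():
    valid_report = {
        'accuracy': 0.71,
        'tasks': {'squad': {'accuracy': 0.65}, 'convai2': {'accuracy': 0.77}},
    }
    validation_metric = 'squad/accuracy'  # as passed via -vmt
    if 'tasks' in valid_report and '/' in validation_metric:
        subtask, metric = validation_metric.split('/', 1)
        new_valid = valid_report['tasks'][subtask][metric]
    else:
        new_valid = valid_report[validation_metric]
    assert new_valid == 0.65
    return new_valid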
def _save_outputs(self, opt, world, logger, episode_metrics): if is_distributed(): # flatten everything intelligently if need be world_report = aggregate_unnamed_reports( all_gather_list(world.report())) episode_metrics_unflattened = all_gather_list(episode_metrics) flattened = [] for rank_elem in episode_metrics_unflattened: for elem in rank_elem: flattened.append(elem) episode_metrics = flattened else: world_report = world.report() logging.report("Final report:\n" + nice_report(world_report)) report = dict_report(world_report) def get_episode_report(goal, episode_metric): metrics_dict = dict_report(episode_metric.report()) metrics_dict["goal"] = goal return metrics_dict report["tod_metrics"] = [ get_episode_report(g, e) for g, e in episode_metrics ] if "report_filename" in opt and opt["report_filename"] is not None: if len(world_report) == 0: logging.warning("Report is empty; not saving report") report_fname = f"{opt['report_filename']}.json" # Save report if not is_distributed() or is_primary_worker(): with PathManager.open(report_fname, "w") as f: logging.info(f"Saving model report to {report_fname}") json.dump({"opt": opt, "report": report}, f, indent=4) f.write("\n") # for jq if "world_logs" in opt and opt["world_logs"] is not None: if is_distributed(): # Save separately, then aggregate together rank = get_rank() log_outfile_part = ( f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl") logger.write(log_outfile_part, world, file_format=opt["save_format"]) sync_object(None) if is_primary_worker(): log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" log_outfile_metadata = ( f"{opt['world_logs']}_{opt['save_format']}.metadata") with open(log_outfile, "w+") as outfile: for rank in range(num_workers()): log_outfile_part = ( f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" ) with open(log_outfile_part) as infile: for line in infile: json_blob = json.loads(line.strip()) if ( len(json_blob["dialog"]) < 2 ): # skip when we don't have generation continue json_blob[ "metadata_path"] = log_outfile_metadata outfile.write(json.dumps(json_blob)) outfile.write("\n") log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata" if rank == 0: copyfile(log_output_part_metadata, log_outfile_metadata), os.remove(log_outfile_part) os.remove(log_output_part_metadata) else: log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" logger.write(log_outfile, world, file_format=opt["save_format"]) return report
def train(self):
    """
    Perform a training run.

    :return: tuple of reports (validation_report, test_report)
    """
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated.")
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = (
                self._preempted_epochs +
                num_workers() * self.world.get_total_epochs())
            exs_per_epoch = self.world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs

                # --------------- change by hengyicai -------------------------
                if opt.get('run_test_after_validation', False):
                    # run evaluation on the test data as well
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_worlds is None:
                        # we need to load the world now
                        self.test_worlds = _maybe_load_eval_worlds(
                            self.agent, test_opt, 'test')
                    run_eval(self.test_worlds, test_opt, 'test', -1,
                             write_log=True)
                # --------------- change by hengyicai -------------------------
                if stop_training:
                    break
            if (self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file') and is_primary_worker()):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test')
    t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)

    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    # --------------- change by hengyicai -------------------------
    last_model = opt.get('model_file') + '.checkpoint'
    if os.path.isfile(last_model):
        print('[ Conducting evaluations on valid and test data using '
              'the last model. ]')
        last_model_opt = copy.deepcopy(opt)
        last_model_opt['model_file'] = last_model
        last_agent = create_agent(last_model_opt)
        valid_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt,
                                               'valid')
        max_exs = last_model_opt['validation_max_exs'] if last_model_opt.get(
            'short_final_eval') else -1
        run_eval(valid_worlds, last_model_opt, 'valid', max_exs,
                 write_log=True)
        test_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt,
                                              'test')
        run_eval(test_worlds, last_model_opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()
    # --------------- change by hengyicai -------------------------

    print_announcements(opt)

    return v_report, t_report
def train(self): """ Perform ftml training run. :return: tuple of reports (validation_report, test_report) """ logging.info('training...') opt = self.opt world = self.world teacher = world.agents[0] student = world.agents[1] more_data_in_domain = True eval_data = {x: [] for x in teacher.domains} with world: shuffled_domains = [ x for x in teacher.domains if x not in ['police', 'hospital'] ] random.shuffle(shuffled_domains) for d, domain in enumerate(shuffled_domains): N = len(teacher.domain_convo_inds[domain]) teacher.add_domain(domain) teacher.add_all_domain_data(domain) self.best_valid = None stop_training = False if opt['no_multi_task']: # only fine-tune to each domain, so don't enter the multi-task while loop stop_training = True self._total_epochs = 0 self._total_exs = 0 while not stop_training: # This is for multi-tasking a global model for _ in range( int(teacher.num_episodes() / opt['num_episode_batch'])): world.batch_parley() self.parleys += 1 # TODO: I think this should be set correctly. Helps tracking, needed for termination? self._total_epochs = 0 self._total_exs = 0 train_time, log_time, validate_time = sync_object(( self.train_time.time(), self.log_time.time(), self.validate_time.time(), )) if log_time > self.log_every_n_secs: self.log() self.write_log('finished %s parleys' % self.parleys) self.write_log('Learning rate before valid: %s ' % world.agents[1].optimizer.state_dict() ['param_groups'][0]['lr']) # validation_decreasing = # todo # todo: add validation here to tell when to stop updating the meta model. # This is harder to do, as the validation is of the meta model.... for w in self.valid_worlds: for dd in teacher.added_domains(): w.reset( ) # Should also reset the teacher.index.value --> -1, but keep the domain fixed. w.agents[0].add_domain(dd) w.agents[0].add_all_domain_data(dd) # Fix validation teacher to domains training teacher has seen. w.agents[0].fix_teacher_domain(teacher.added_domains()) w.agents[ 0].index.value = -1 # reset index because we'll stream through the training data. w.agents[0].entry_idx = 0 if not opt['validation_metric'].startswith('bleu'): student.skip_generation = True stop_training = self.validate() logging.info('Multi-task model validation value: %s ' % self.best_valid) # After the multi-task model is trained, fine tune model for each domain. M = copy.deepcopy(world.agents[1].model.state_dict()) optim_state = copy.deepcopy(world.agents[1].optimizer.state_dict()) for dd in teacher.domains: # Restrict valid_world teachers to chosen domain for fine-tuning for w in self.valid_worlds: w.reset( ) # Should also reset the teacher.index.value --> -1, but keep the domain fixed. w.agents[0].add_domain(dd) w.agents[0].add_all_domain_data(dd) w.agents[0].fix_teacher_domain([dd]) w.agents[ 0].index.value = -1 # reset index because we'll stream through the data. w.agents[0].entry_idx = 0 # Restrict test worlds to chosen domain for fine-tuning for w in self.test_worlds: w.reset( ) # Should also reset the teacher.index.value --> -1, but keep the domain fixed. w.agents[0].add_domain(dd) w.agents[0].add_all_domain_data(dd) w.agents[0].fix_teacher_domain([dd]) w.agents[ 0].index.value = -1 # reset index because we'll stream through the data. w.agents[0].entry_idx = 0 # note the appropriate state_dict should be loaded, as the agent should # be shared by reference in the training and the testing worlds. 
self.write_log("STARTING Fine-tuning OF STUDENT HERE on %s " % dd) if self.test_worlds[0].agents[ 0].num_episodes_in_restricted_domain() > 0: # make sure the meta parameters are loaded before evaluating another training domain world.agents[1].model.load_state_dict(M) world.agents[1].optimizer.load_state_dict(optim_state) # Restrict training world to fine-tuning domain teacher.fix_teacher_domain([dd]) teacher.index.value = -1 # reset index because we'll stream through the training data. teacher.entry_idx = 0 logging.info('Fine-tuning to: %s' % dd) self.write_log('fine-tuning epoch size: %s' % teacher.num_episodes_in_restricted_domain()) self.best_valid = None stop_training = False self.tune_parley_epochs = 0 # Fine-tune model to single domain while not stop_training: # fine-tune for one epoch over training self.write_log('Learning rate : %s ' % world.agents[1].optimizer.state_dict() ['param_groups'][0]['lr']) # while not world.epoch_done(): # HERE: loop for an epoch over domain training data. for n in range( int(teacher.num_episodes_in_restricted_domain( ) / opt['num_episode_batch']) ): # epoch episodes, as each full episode processed. # print('\n\n TRAINING PARLEY') world.batch_parley( ) # Note the updating is fixed to the domain training data only. # import pdb; pdb.set_trace() # fine-tune until validation on domain stops decreasing. # print('\n\n VALIDATION PARLEYS') if not opt['validation_metric'].startswith('bleu'): student.skip_generation = True stop_training = self.validate() # import pdb; pdb.set_trace() # print('\t\t\t\t\tValid: %s Learning rate : %s ' % (self.best_valid, world.agents[1].optimizer.state_dict()['param_groups'][0]['lr'])) # import sys; sys.exit() # print('num training exs: ', world.agents[0].num_episodes_in_restricted_domain()) # print('num valid exs: ', self.valid_worlds[0].agents[0].num_episodes_in_restricted_domain()) # print('valid is rand: ', self.valid_worlds[0].agents[0].random) # print('episode index: ', self.valid_worlds[0].agents[0].index.value) # print('Training: ', world.agents[0].messages[0]) # print('Validation: ', self.valid_worlds[0].agents[0].messages[0]) self.tune_parley_epochs += 1 logging.info('Best valid: %s' % self.best_valid) self.write_log('Best fine-tune valid: %s' % self.best_valid) self.write_log('Finished %s tune_parley epochs' % self.tune_parley_epochs) # Evaluate on domain test set. student.skip_generation = False max_exs = -1 t_report = self._run_eval(self.test_worlds, opt, 'test', max_exs, write_log=True) logging.info('on domain %s: test report: ' % dd) logging.info(t_report) eval_data[dd] = { 'domain': dd, 'test_report': t_report, 'num_parleys': self.parleys, 'tune_epochs': self.tune_parley_epochs, 'mt_epoch_size': teacher.num_episodes(), 'domain_epoch_size': teacher.num_episodes_in_restricted_domain() } import datetime stamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M') if opt['no_multi_task']: locationname = '/home/oademasi/transfer-learning-conv-ai/ParlAI/parlai_internal/eval_data_ft_%s.pkl' % stamp else: locationname = '/home/oademasi/transfer-learning-conv-ai/ParlAI/parlai_internal/eval_data_mtft_%s.pkl' % stamp pickle.dump(eval_data, open(locationname, 'wb')) print('wrote to: ', locationname) v_report = None t_report = None return v_report, t_report
def train(self):
    """
    Perform a training run.

    :return: tuple of reports (validation_report, test_report)
    """
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            try:
                world.parley()
            except StopTrainException:
                if is_distributed():
                    raise RuntimeError(
                        "StopTrainException not supported for "
                        "distributed mode")
                break

            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs()))
            exs_per_epoch = world.num_examples()
            self._total_exs = int(
                np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs):
                try:
                    # log before we validate
                    self.log()
                    world.reset_metrics()
                    stop_training = self.validate()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break
                # reset the log time because we logged right before validating
                self.log_time.reset()
                self.last_valid_epoch = self._total_epochs

                # --------------- change by hengyicai -------------------------
                if opt.get('run_test_after_validation', False):
                    # run evaluation on the test data as well
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_worlds is None:
                        # we need to load the world now
                        self.test_worlds = load_eval_worlds(
                            self.agent, test_opt, 'test')
                    run_eval(self.test_worlds, test_opt, 'test', -1,
                             write_log=True)
                # --------------- change by hengyicai -------------------------
                if stop_training:
                    break
                # make sure metrics are clean before we log
                world.reset_metrics()
            if (self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file') and is_primary_worker()):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                if opt['tensorboard_log'] and is_primary_worker():
                    self.tb_logger.flush()
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
    test_worlds = load_eval_worlds(self.agent, opt, 'test')
    t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)

    if valid_worlds:
        for valid_world in valid_worlds:
            valid_world.shutdown()
    if test_worlds:
        for test_world in test_worlds:
            test_world.shutdown()

    # --------------- change by hengyicai -------------------------
    last_model = opt.get('model_file') + '.checkpoint'
    if os.path.isfile(last_model):
        print('[ Conducting evaluations on valid and test data using '
              'the last model. ]')
        last_model_opt = copy.deepcopy(opt)
        last_model_opt['model_file'] = last_model
        last_agent = create_agent(last_model_opt)
        valid_worlds = load_eval_worlds(last_agent, last_model_opt, 'valid')
        max_exs = last_model_opt['validation_max_exs'] if last_model_opt.get(
            'short_final_eval') else -1
        run_eval(valid_worlds, last_model_opt, 'valid', max_exs,
                 write_log=True)
        test_worlds = load_eval_worlds(last_agent, last_model_opt, 'test')
        run_eval(test_worlds, last_model_opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()
    # --------------- change by hengyicai -------------------------

    print_announcements(opt)

    return v_report, t_report