def _maybe_load_eval_world(agent, opt, datatype):
    """
    Build an evaluation world for ``datatype``, on the primary worker only.

    Validation/test evaluation only needs to run on the main worker, so every
    other worker skips loading and receives ``None``.

    :param agent: agent to evaluate.
    :param opt: options passed through to ``load_eval_world``.
    :param datatype: dataset split name, e.g. ``'valid'`` or ``'test'``.
    :return: the result of ``load_eval_world`` on the primary worker, else None.
    """
    if is_primary_worker():
        return load_eval_world(agent, opt, datatype)
    # non-primary workers never evaluate
    return None
def train(self):
    """
    Run the training loop, then evaluate the final model on valid and test.

    Each iteration performs one ``world.parley()`` (one example or batch) and
    then checks, in order:

    - the epoch budget;
    - the wall-clock budget;
    - whether to log;
    - whether to validate (``self.validate()`` may request a stop);
    - whether to write a periodic ``.checkpoint`` (primary worker only).

    After the loop, the agent is saved if no best-valid model was saved
    during training (primary worker only). Otherwise, when
    ``opt['model_file']`` is set, the best validation model is reloaded via
    ``create_agent(opt)``. Final evaluations on valid and test then run.

    :return: tuple ``(v_report, t_report)`` produced by ``run_eval`` on the
        valid and test sets.
    """
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated."
        )
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1

            # this worker's own estimate of the total epochs completed
            local_epochs = (
                self._preempted_epochs
                + num_workers() * self.world.get_total_epochs()
            )
            exs_per_epoch = self.world.num_examples()

            # Use the primary worker's epoch count and timings for everything.
            # The break / validate decisions below must be identical on every
            # worker. If one worker stops (or enters validate()) while another
            # keeps parleying, the others block forever in the next collective
            # sync. Folding the epoch count into this existing sync keeps
            # that decision consistent without an extra collective.
            (
                self._total_epochs,
                train_time,
                log_time,
                validate_time,
            ) = sync_object((
                local_epochs,
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch
                >= self.val_every_n_epochs
            ):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
            if (
                self.save_time.time() > self.save_every_n_secs
                and opt.get('model_file')
                and is_primary_worker()
            ):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # no best-valid model was saved during training; save the agent now
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
    # -1 means "no limit"; short_final_eval caps the final evals instead
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = run_eval(valid_world, opt, 'valid', max_exs, write_log=True)
    test_world = _maybe_load_eval_world(self.agent, opt, 'test')
    t_report = run_eval(test_world, opt, 'test', max_exs, write_log=True)
    if valid_world:
        valid_world.shutdown()
    if test_world:
        test_world.shutdown()

    print("\n".join([
        "",
        "*" * 80,
        "Thank you for using ParlAI! We are conducting a user survey.",
        "Please consider filling it out at https://forms.gle/uEFbYGP7w6hiuGQT9",
        "*" * 80,
        ""
    ]))

    return v_report, t_report
def train(self):
    """
    Run the training loop, then evaluate the final model on valid and test.

    Each iteration performs one ``world.parley()`` (one example or batch) and
    then checks, in order:

    - the epoch budget;
    - the wall-clock budget;
    - whether to log;
    - whether to validate (``self.validate()`` may request a stop);
    - whether to write a periodic ``.checkpoint`` via ``self.agent.save``
      (primary worker only).

    After the loop, the agent is saved to ``opt['model_file']`` if no
    best-valid model was saved during training (primary worker only).
    Otherwise, when a model file is set, the best validation model is
    reloaded via ``create_agent(opt)``.

    :return: tuple ``(v_report, t_report)`` produced by ``run_eval`` on the
        valid and test sets.
    """
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated."
        )
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1

            # this worker's own estimate of the total epochs completed
            local_epochs = num_workers() * self.world.get_total_epochs()
            exs_per_epoch = self.world.num_examples()

            # Use the primary worker's epoch count and timings for everything.
            # The break / validate decisions below must be identical on every
            # worker. If one worker stops (or enters validate()) while another
            # keeps parleying, the others block forever in the next collective
            # sync. Folding the epoch count into this existing sync keeps
            # that decision consistent without an extra collective.
            (
                self._total_epochs,
                train_time,
                log_time,
                validate_time,
            ) = sync_object((
                local_epochs,
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch
                >= self.val_every_n_epochs
            ):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
            if (
                self.save_time.time() > self.save_every_n_secs
                and opt.get('model_file')
                and is_primary_worker()
            ):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                self.agent.save(opt['model_file'] + '.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.agent.save(opt['model_file'])
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
    v_report = run_eval(valid_world, opt, 'valid', write_log=True)
    test_world = _maybe_load_eval_world(self.agent, opt, 'test')
    t_report = run_eval(test_world, opt, 'test', write_log=True)
    if valid_world:
        valid_world.shutdown()
    if test_world:
        test_world.shutdown()

    return v_report, t_report
def validate(self):
    """
    Evaluate on the validation set and update best-model / patience state.

    The valid report comes from ``run_eval`` passed through ``sync_object``.
    A copy stamped with the current ``train_time`` is appended to
    ``self.valid_reports``. The report is optionally written to tensorboard
    and optionally followed by a ``.checkpoint`` save. It is also handed to
    the agent's ``receive_metrics`` hook when the agent has one.

    The metric named by ``opt['validation_metric']`` is then compared with
    ``self.best_valid``, with the comparison direction set by
    ``self.valid_optim``. A ``subtask/metric`` value reads the metric from
    ``valid_report['tasks'][subtask]``.

    :return: True when training should stop (accuracy reached
        ``opt['validation_cutoff']``, or patience ran out); otherwise False.
    """
    opt = self.opt

    if self.valid_world is None:
        # the valid world is built lazily, on the first validation
        self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

    # evaluate, then share one report across all workers
    local_report = run_eval(
        self.valid_world, opt, 'valid', opt['validation_max_exs'],
    )
    valid_report = sync_object(local_report)

    snapshot = valid_report.copy()
    snapshot['train_time'] = self.train_time.time()
    self.valid_reports.append(snapshot)

    # tensorboard logging
    if opt['tensorboard_log'] is True and is_primary_worker():
        self.writer.add_metrics(
            'valid', int(self.train_time.time()), valid_report,
        )

    # optional checkpoint after every validation
    should_checkpoint = (
        opt.get('model_file')
        and opt.get('save_after_valid')
        and is_primary_worker()
    )
    if should_checkpoint:
        print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]")
        self.save_model('.checkpoint')

    # let the agent react to validation metrics if it supports that
    if hasattr(self.agent, 'receive_metrics'):
        self.agent.receive_metrics(valid_report)

    # pick the tracked metric; "-vmt subtask/metric" reads it from one subtask
    metric_name = opt['validation_metric']
    if '/' in metric_name:
        parts = metric_name.split('/')
        new_valid = valid_report['tasks'][parts[0]][parts[1]]
    else:
        new_valid = valid_report[metric_name]

    is_new_best = (
        self.best_valid is None
        or self.valid_optim * new_valid > self.valid_optim * self.best_valid
    )
    if not is_new_best:
        self.impatience += 1
        print('[ did not beat best {}: {} impatience: {} ]'.format(
            metric_name, round(self.best_valid, 4), self.impatience))
    else:
        if self.best_valid is None:
            previous_note = ''
        else:
            previous_note = ' (previous best was {})'.format(self.best_valid)
        print('[ new best {}: {}{} ]'.format(metric_name, new_valid, previous_note))
        self.best_valid = new_valid
        self.impatience = 0
        if opt.get('model_file') and is_primary_worker():
            print("[ saving best valid model: " + opt['model_file'] + " ]")
            self.save_model()
            print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]")
            save_best_valid(opt['model_file'], self.best_valid)
            self.saved = True
        solved = (
            metric_name == 'accuracy'
            and self.best_valid >= opt['validation_cutoff']
        )
        if solved:
            print('[ task solved! stopping. ]')
            return True

    self.validate_time.reset()

    # stop once too many validations in a row failed to improve
    patience = opt['validation_patience']
    if patience > 0 and self.impatience >= patience:
        print('[ ran out of patience! stopping training. ]')
        return True
    return False
def train(self):
    """
    Run the training loop, then evaluate both the best and the last model.

    Each iteration performs one ``world.parley()`` (one example or batch) and
    then checks, in order:

    - the epoch budget;
    - the wall-clock budget;
    - whether to log;
    - whether to validate (``self.validate()`` may request a stop);
    - whether to write a periodic ``.checkpoint`` (primary worker only).

    Every validation is followed by a test-set evaluation with example
    display and periodic reporting turned off.

    After the loop, the agent is saved if no best-valid model was saved
    (primary worker only). Otherwise, when ``opt['model_file']`` is set, the
    best validation model is reloaded via ``create_agent(opt)``. That model is
    then evaluated on valid and test. If ``<model_file>.checkpoint`` exists,
    that checkpoint is evaluated on valid and test as well.

    :return: tuple ``(v_report, t_report)`` from ``run_eval`` on valid and
        test for the best/reloaded model (the checkpoint's reports are only
        logged).
    """
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated."
        )
    opt = self.opt
    world = self.world

    def _evaluate_valid_and_test(agent, eval_opt):
        """Run final valid then test evaluation for ``agent``; return both reports."""
        valid_world = _maybe_load_eval_world(agent, eval_opt, 'valid')
        # -1 means "no limit"; short_final_eval caps the final evals instead
        max_exs = (
            eval_opt['validation_max_exs']
            if eval_opt.get('short_final_eval') else -1
        )
        v_report = run_eval(valid_world, eval_opt, 'valid', max_exs, write_log=True)
        test_world = _maybe_load_eval_world(agent, eval_opt, 'test')
        t_report = run_eval(test_world, eval_opt, 'test', max_exs, write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()
        return v_report, t_report

    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1

            # this worker's own estimate of the total epochs completed
            local_epochs = (
                self._preempted_epochs
                + num_workers() * self.world.get_total_epochs()
            )
            exs_per_epoch = self.world.num_examples()

            # Use the primary worker's epoch count and timings for everything.
            # The break / validate decisions below must be identical on every
            # worker. If one worker stops (or enters validate()) while another
            # keeps parleying, the others block forever in the next collective
            # sync. Folding the epoch count into this existing sync keeps
            # that decision consistent without an extra collective.
            (
                self._total_epochs,
                train_time,
                log_time,
                validate_time,
            ) = sync_object((
                local_epochs,
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
            ))
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs
                or self._total_epochs - self.last_valid_epoch
                >= self.val_every_n_epochs
            ):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs

                # Also evaluate on the test data after every validation
                # (change by hengyicai). Example display and periodic
                # reporting are disabled for this pass.
                test_opt = copy.deepcopy(self.opt)
                test_opt['display_examples'] = False
                test_opt['report_freq'] = 0
                if self.test_world is None:
                    # we need to load the world now
                    self.test_world = _maybe_load_eval_world(
                        self.agent, test_opt, 'test')
                run_eval(self.test_world, test_opt, 'test', -1, write_log=True)

                if stop_training:
                    break
            if (
                self.save_time.time() > self.save_every_n_secs
                and opt.get('model_file')
                and is_primary_worker()
            ):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # no best-valid model was saved during training; save the agent now
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    v_report, t_report = _evaluate_valid_and_test(self.agent, opt)

    # Also evaluate the most recent checkpoint (change by hengyicai). The
    # model_file may be unset/None, in which case there is no checkpoint
    # path to build and this step is skipped instead of raising TypeError.
    model_file = opt.get('model_file')
    if model_file and os.path.isfile(model_file + '.checkpoint'):
        print('[ Conducting evaluations on valid and test data using the last model. ]')
        last_model_opt = copy.deepcopy(opt)
        last_model_opt['model_file'] = model_file + '.checkpoint'
        last_agent = create_agent(last_model_opt)
        _evaluate_valid_and_test(last_agent, last_model_opt)

    print_announcements(opt)

    return v_report, t_report