def multiprocess_train(
    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
):
    """
    Subprocess which initializes distributed training, and begins training.

    This should be launched n times for n GPUs; this is handled either in main
    or via srun.

    :param int rank: This process's rank - 1. (Starts at -1 ... n - 2). See comments.
    :param opt: command line options
    :param int port: A TCP port to use. This will need to be changed to run
        multiple distributed training setups on the same machine.
    :param int gpu: Which GPU to use. Defaults to using rank and local devices,
        but must be manually specified when using many-hosts.
    :param str hostname: Hostname of the main server.
    """
    # Set per-host options
    opt = copy.deepcopy(opt)
    # we need to manually adjust the rank differently in multiprocessing
    # and distributed train
    rank = rank + rank_offset
    opt['rank'] = rank
    if gpu is None:
        # default assumption is local GPUs
        gpu = rank % torch.cuda.device_count()
    opt['gpu'] = gpu
    # make sure we don't just use whatever GPU was saved in the model file
    if 'override' not in opt:
        opt['override'] = {}
    opt['override']['gpu'] = gpu

    # Suppress output of workers except the main host.
    if opt.get('verbose') or rank != 0:
        print_prefix = '[rank:{:3d}]'.format(rank)
    else:
        print_prefix = None
    suppress_output = not opt.get('verbose') and rank != 0

    with distributed_utils.override_print(suppress_output, print_prefix):
        # perform distributed setup, ensuring all hosts are ready
        torch.cuda.set_device(opt['gpu'])
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://{}:{}".format(hostname, port),
            world_size=opt['distributed_world_size'],
            rank=rank,
        )
        print("Distributed group initialized")

        # manual_seed can be a noop without this
        torch.cuda.init()
        # make sure all parameters will be in sync
        torch.manual_seed(42)
        # force a sync so that no one gets ahead, and all are seeded together
        distributed_utils.sync_object(None)

        # Run the actual training
        return single_train.TrainLoop(opt).train()
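# A minimal launch sketch (not part of the excerpt above; the helper name and
# opt handling are assumptions): spawn one worker per local GPU and let
# torch.multiprocessing pass the spawn index in as `rank`. With rank_offset=0,
# spawn index i becomes distributed rank i and therefore picks GPU i.
import torch
import torch.multiprocessing as mp


def launch_multiprocess(opt, port=61337):
    nprocs = torch.cuda.device_count()
    # multiprocess_train expects opt['distributed_world_size'] to match nprocs
    opt['distributed_world_size'] = nprocs
    # mp.spawn calls multiprocess_train(i, opt, port) for i in 0..nprocs-1
    mp.spawn(multiprocess_train, args=(opt, port), nprocs=nprocs, join=True)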
def train(self):
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated."
        )
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1
            # print(world.display())

            # get the total training examples done, compute epochs
            self._total_epochs = (
                self._preempted_epochs +
                num_workers() * self.world.get_total_epochs()
            )
            exs_per_epoch = self.world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time()
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs or
                self._total_epochs - self.last_valid_epoch >= self.val_every_n_epochs
            ):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs
                if stop_training:
                    break
            if (
                self.save_time.time() > self.save_every_n_secs and
                opt.get('model_file') and
                is_primary_worker()
            ):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
    v_report = run_eval(valid_world, opt, 'valid', write_log=True)
    test_world = _maybe_load_eval_world(self.agent, opt, 'test')
    t_report = run_eval(test_world, opt, 'test', write_log=True)
    if valid_world:
        valid_world.shutdown()
    if test_world:
        test_world.shutdown()

    return v_report, t_report
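# Rough sketch (an assumption, not the actual ParlAI implementation) of the
# role sync_object plays in the loop above: broadcast the primary worker's
# value to every rank so that timers and stopping decisions stay in lockstep
# across workers.
import torch.distributed as dist


def sync_object_sketch(data, src=0):
    if not (dist.is_available() and dist.is_initialized()):
        # single-process training: nothing to synchronize
        return data
    holder = [data]
    # broadcast_object_list pickles the object on `src` and unpickles it on all ranks
    dist.broadcast_object_list(holder, src=src)
    return holder[0]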
def validate(self):
    opt = self.opt

    if self.valid_world is None:
        # we need to load the world now
        self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

    # run evaluation on valid set
    valid_report = sync_object(run_eval(
        self.valid_world, opt, 'valid', opt['validation_max_exs'], True
    ))

    # logging
    if opt['tensorboard_log'] is True and is_primary_worker():
        self.writer.add_metrics('valid', int(self.train_time.time()), valid_report)

    # saving
    if (
        opt.get('model_file') and opt.get('save_after_valid') and
        is_primary_worker()
    ):
        print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]")
        self.save_model('.checkpoint')

    # send valid metrics to agent if the agent wants them
    if hasattr(self.agent, 'receive_metrics'):
        self.agent.receive_metrics(valid_report)

    # check which metric to look at
    if '/' in opt['validation_metric']:
        # if you are multitasking and want your validation metric to be
        # a metric specific to a subtask, specify your validation metric
        # as -vmt subtask/metric
        subtask = opt['validation_metric'].split('/')[0]
        validation_metric = opt['validation_metric'].split('/')[1]
        new_valid = valid_report['tasks'][subtask][validation_metric]
    else:
        new_valid = valid_report[opt['validation_metric']]

    # check if this is the best validation so far
    if (self.best_valid is None or
            self.valid_optim * new_valid > self.valid_optim * self.best_valid):
        print('[ new best {}: {}{} ]'.format(
            opt['validation_metric'], new_valid,
            ' (previous best was {})'.format(self.best_valid)
            if self.best_valid is not None else ''))
        self.best_valid = new_valid
        self.impatience = 0
        if opt.get('model_file') and is_primary_worker():
            print("[ saving best valid model: " + opt['model_file'] + " ]")
            self.save_model()
            print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]")
            save_best_valid(opt['model_file'], self.best_valid)
            self.saved = True
        if (opt['validation_metric'] == 'accuracy' and
                self.best_valid >= opt['validation_cutoff']):
            print('[ task solved! stopping. ]')
            return True
    else:
        self.impatience += 1
        print('[ did not beat best {}: {} impatience: {} ]'.format(
            opt['validation_metric'], round(self.best_valid, 4),
            self.impatience))
    self.validate_time.reset()

    # check if we are out of patience
    if (opt['validation_patience'] > 0 and
            self.impatience >= opt['validation_patience']):
        print('[ ran out of patience! stopping training. ]')
        return True
    return False
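# Standalone sketch of the sign trick used in the "best validation" comparison
# above: valid_optim is +1 for metrics that should be maximized (e.g. accuracy)
# and -1 for metrics that should be minimized (e.g. ppl), so a single
# inequality covers both directions. The mode argument here is an illustrative
# assumption, not the option name used by the trainer.
def is_new_best(new_valid, best_valid, mode='max'):
    valid_optim = 1 if mode == 'max' else -1
    return best_valid is None or valid_optim * new_valid > valid_optim * best_valid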
def train(self):
    if is_distributed():
        warn_once(
            "Distributed training outputs average-per-worker metrics during "
            "training, and may be slightly distorted. Validation/test are "
            "unadulterated."
        )
    opt = self.opt
    world = self.world
    with world:
        while True:
            # do one example / batch of examples
            world.parley()
            self.parleys += 1

            # get the total training examples done, compute epochs
            self._total_epochs = (
                self._preempted_epochs +
                num_workers() * self.world.get_total_epochs()
            )
            exs_per_epoch = self.world.num_examples()
            self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

            # and use the primary worker's timings for everything
            train_time, log_time, validate_time = sync_object((
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time()
            ))

            # check counters and timers
            if self._total_epochs >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                    self.max_num_epochs, train_time))
                break
            if train_time > self.max_train_time:
                print('[ max_train_time elapsed:{}s ]'.format(train_time))
                break
            if log_time > self.log_every_n_secs:
                self.log()
            if (
                validate_time > self.val_every_n_secs or
                self._total_epochs - self.last_valid_epoch >= self.val_every_n_epochs
            ):
                stop_training = self.validate()
                self.last_valid_epoch = self._total_epochs

                # --------------- change by hengyicai -------------------------
                # run evaluation on the test data as well
                test_opt = copy.deepcopy(self.opt)
                test_opt['display_examples'] = False
                test_opt['report_freq'] = 0
                if self.test_world is None:
                    # we need to load the world now
                    self.test_world = _maybe_load_eval_world(self.agent, test_opt, 'test')
                run_eval(self.test_world, test_opt, 'test', -1, write_log=True)
                # --------------- change by hengyicai -------------------------

                if stop_training:
                    break
            if (
                self.save_time.time() > self.save_every_n_secs and
                opt.get('model_file') and
                is_primary_worker()
            ):
                print("[ saving model checkpoint: {}.checkpoint ]".format(
                    opt['model_file']
                ))
                self.save_model('.checkpoint')
                self.save_time.reset()

    if not self.saved and is_primary_worker():
        # save agent
        self.save_model()
    elif opt.get('model_file'):
        # reload best validation model
        self.agent = create_agent(opt)

    valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
    max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
    v_report = run_eval(valid_world, opt, 'valid', max_exs, write_log=True)
    test_world = _maybe_load_eval_world(self.agent, opt, 'test')
    t_report = run_eval(test_world, opt, 'test', max_exs, write_log=True)
    if valid_world:
        valid_world.shutdown()
    if test_world:
        test_world.shutdown()

    # --------------- change by hengyicai -------------------------
    # model_file may be unset, so guard against None before appending the suffix
    last_model = (opt.get('model_file') or '') + '.checkpoint'
    if os.path.isfile(last_model):
        print('[ Conducting evaluations on valid and test data using the last model. ]')
        last_model_opt = copy.deepcopy(opt)
        last_model_opt['model_file'] = last_model
        last_agent = create_agent(last_model_opt)
        valid_world = _maybe_load_eval_world(last_agent, last_model_opt, 'valid')
        max_exs = last_model_opt['validation_max_exs'] if last_model_opt.get('short_final_eval') else -1
        run_eval(valid_world, last_model_opt, 'valid', max_exs, write_log=True)
        test_world = _maybe_load_eval_world(last_agent, last_model_opt, 'test')
        run_eval(test_world, last_model_opt, 'test', max_exs, write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()
    # --------------- change by hengyicai -------------------------

    print_announcements(opt)

    return v_report, t_report
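# Hedged generalization of the "evaluate the last checkpoint" block above into
# a reusable helper. create_agent, _maybe_load_eval_world and run_eval are the
# names already used in this code; the helper itself and its signature are
# assumptions for illustration only.
import copy
import os


def eval_checkpoint(opt, checkpoint_path, datatypes=('valid', 'test')):
    if not os.path.isfile(checkpoint_path):
        return {}
    ckpt_opt = copy.deepcopy(opt)
    ckpt_opt['model_file'] = checkpoint_path
    agent = create_agent(ckpt_opt)
    reports = {}
    for datatype in datatypes:
        world = _maybe_load_eval_world(agent, ckpt_opt, datatype)
        reports[datatype] = run_eval(world, ckpt_opt, datatype, -1, write_log=True)
        if world:
            world.shutdown()
    return reports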
def validate(self):
    opt = self.opt

    if self.valid_world is None:
        # we need to load the world now
        self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

    # run evaluation on valid set
    valid_report = sync_object(run_eval(
        self.valid_world, opt, 'valid', opt['validation_max_exs'],
    ))
    v = valid_report.copy()
    v['train_time'] = self.train_time.time()
    self.valid_reports.append(v)

    # logging
    if opt['tensorboard_log'] is True and is_primary_worker():
        self.writer.add_metrics('valid', int(self.train_time.time()), valid_report)

    # saving
    if (
        opt.get('model_file') and opt.get('save_after_valid') and
        is_primary_worker()
    ):
        print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]")
        self.save_model('.checkpoint')

    # send valid metrics to agent if the agent wants them
    if hasattr(self.agent, 'receive_metrics'):
        self.agent.receive_metrics(valid_report)

    # --------------- change by hengyicai -------------------------
    teacher_agent = self.return_teacher_agent()
    if teacher_agent:
        teacher_agent.receive_metrics(valid_report)
    # --------------- change by hengyicai -------------------------

    # check which metric to look at
    if '/' in opt['validation_metric']:
        # if you are multitasking and want your validation metric to be
        # a metric specific to a subtask, specify your validation metric
        # as -vmt subtask/metric
        subtask = opt['validation_metric'].split('/')[0]
        validation_metric = opt['validation_metric'].split('/')[1]
        new_valid = valid_report['tasks'][subtask][validation_metric]
    else:
        new_valid = valid_report[opt['validation_metric']]

    # check if this is the best validation so far
    if (self.best_valid is None or
            self.valid_optim * new_valid > self.valid_optim * self.best_valid):
        print('[ new best {}: {}{} ]'.format(
            opt['validation_metric'], new_valid,
            ' (previous best was {})'.format(self.best_valid)
            if self.best_valid is not None else ''))
        self.best_valid = new_valid
        self.impatience = 0
        if opt.get('model_file') and is_primary_worker():
            print("[ saving best valid model: " + opt['model_file'] + " ]")
            self.save_model()
            print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]")
            save_best_valid(opt['model_file'], self.best_valid)
            self.saved = True
        if (opt['validation_metric'] == 'accuracy' and
                self.best_valid >= opt['validation_cutoff']):
            print('[ task solved! stopping. ]')
            return True
    else:
        self.impatience += 1
        print('[ did not beat best {}: {} impatience: {} ]'.format(
            opt['validation_metric'], round(self.best_valid, 4),
            self.impatience))

    # --------------- change by hengyicai -------------------------
    if self.opt.get('cutoff_metric_name', 'none') != 'none':
        cutoff_metric_name = self.opt['cutoff_metric_name']
        cutoff_metric_val = self.opt['cutoff_metric_val']
        if cutoff_metric_name in valid_report and cutoff_metric_val > 0:
            if valid_report[cutoff_metric_name] >= cutoff_metric_val:
                print('[ {} >= {}, stopping. ]'.format(
                    cutoff_metric_name, cutoff_metric_val
                ))
                return True
        elif cutoff_metric_name not in valid_report:
            warn_once('[ {} is not in the validation report! '
                      'can not do metric cutoff stopping! ]'.format(cutoff_metric_name))
        else:
            warn_once('[ you asked to do metric cutoff stopping, '
                      'but the cutoff_metric_val <= 0! ]')
    # --------------- change by hengyicai -------------------------

    self.validate_time.reset()

    # check if we are out of patience
    if (opt['validation_patience'] > 0 and
            self.impatience >= opt['validation_patience']):
        print('[ ran out of patience! stopping training. ]')
        return True
    return False
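# Standalone sketch of the metric-cutoff early-stopping rule added above: stop
# as soon as a named metric in the validation report reaches a threshold. The
# function and the example values below are assumptions for illustration only.
def should_stop_on_cutoff(valid_report, metric_name='none', metric_val=0.0):
    if metric_name == 'none':
        return False
    if metric_val <= 0 or metric_name not in valid_report:
        return False
    return valid_report[metric_name] >= metric_val


# e.g. stop once the validation 'bleu' score reaches 0.30 (hypothetical values):
# should_stop_on_cutoff(valid_report, metric_name='bleu', metric_val=0.30)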