class TrainLoop(): def __init__(self, opt): # if python is called from a non-interactive shell, like a bash script, # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are # not produced. This line brings them back signal.signal(signal.SIGINT, signal.default_int_handler) if isinstance(opt, ParlaiParser): print( '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]' ) opt = opt.parse_args() # Possibly load from checkpoint trainstats_suffix = '.trainstats' # we might load training statistics from here if opt['load_from_checkpoint'] and opt.get( 'model_file') and os.path.isfile(opt['model_file'] + '.checkpoint'): opt['init_model'] = opt['model_file'] + '.checkpoint' trainstats_suffix = '.checkpoint.trainstats' # Possibly build a dictionary (not all models do this). if opt['dict_build_first'] and 'dict_file' in opt: # If data built via pytorch data teacher, we need to load prebuilt dict if opt.get('pytorch_teacher_task'): opt['dict_file'] = get_pyt_dict_file(opt) elif opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' print("[ building dictionary first... ]") build_dict(opt, skip_if_built=True) # Create model and assign it to the specified task self.agent = create_agent(opt) #specify model such as seq2seq self.world = create_task(opt, self.agent) # bacthworld or other world # set up timers self.train_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() print('[ training... ]') self.parleys = 0 self.max_num_epochs = opt[ 'num_epochs'] if opt['num_epochs'] > 0 else float('inf') self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \ else float('inf') self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \ else float('inf') self.val_every_n_secs = \ opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \ else float('inf') self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \ > 0 else float('inf') self.val_every_n_epochs = \ opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \ else float('inf') # smart defaults for --validation-metric-mode if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}: opt['validation_metric_mode'] = 'min' elif opt['validation_metric'] in { 'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu' }: opt['validation_metric_mode'] = 'max' if opt.get('validation_metric_mode') is None: opt['validation_metric_mode'] = 'max' self.last_valid_epoch = 0 self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.best_valid = None if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'): with open(opt['model_file'] + ".best_valid", 'r') as f: x = f.readline() self.best_valid = float(x) f.close() self.impatience = 0 self.saved = False self.valid_world = None self.opt = opt # we may have been preempted, make sure we note that amount self._preempted_epochs = 0.0 if (opt.get('model_file') and os.path.isfile(opt['model_file'] + trainstats_suffix)): # looks like we were preempted. make sure we load up our total # training stats, etc with open(opt['model_file'] + trainstats_suffix) as ts: obj = json.load(ts) self._preempted_epochs = obj.get('total_epochs', 0) self.train_time.total = obj.get('train_time', 0) self.impatience = obj.get('impatience', 0) if opt['tensorboard_log'] is True: self.writer = TensorboardLogger(opt) def save_model(self, suffix=None): if not is_primary_worker(): # never do IO as a non-primary worker return if not self.opt.get('model_file'): # nothing to save to, just exit return fn = self.opt['model_file'] if suffix: fn += suffix while True: # don't ever let a ctrl-c interrupt saving try: self.agent.save(fn) self._save_train_stats(suffix) break except KeyboardInterrupt: pass def _save_train_stats(self, suffix=None): fn = self.opt['model_file'] if suffix: fn += suffix fn += '.trainstats' with open(fn, 'w') as f: json.dump( { 'train_time': self.train_time.time(), 'total_epochs': (self._preempted_epochs + num_workers() * self.world.get_total_epochs()), 'impatience': self.impatience, }, f) def validate(self): opt = self.opt if self.valid_world is None: # we need to load the world now self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid') # run evaluation on valid set valid_report = sync_object( run_eval(self.valid_world, opt, 'valid', opt['validation_max_exs'], True)) # logging if opt['tensorboard_log'] is True and is_primary_worker(): self.writer.add_metrics('valid', int(self.train_time.time()), valid_report) # saving if (opt.get('model_file') and opt.get('save_after_valid') and is_primary_worker()): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.save_model('.checkpoint') # send valid metrics to agent if the agent wants them if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) # check which metric to look at if '/' in opt['validation_metric']: # if you are multitasking and want your validation metric to be # a metric specific to a subtask, specify your validation metric # as -vmt subtask/metric subtask = opt['validation_metric'].split('/')[0] validation_metric = opt['validation_metric'].split('/')[1] new_valid = valid_report['tasks'][subtask][validation_metric] else: new_valid = valid_report[opt['validation_metric']] # check if this is the best validation so far if (self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid): print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '')) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file') and is_primary_worker(): print("[ saving best valid model: " + opt['model_file'] + " ]") self.save_model() print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]") save_best_valid(opt['model_file'], self.best_valid) self.saved = True if (opt['validation_metric'] == 'accuracy' and self.best_valid >= opt['validation_cutoff']): print('[ task solved! stopping. ]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() # check if we are out of patience if (opt['validation_patience'] > 0 and self.impatience >= opt['validation_patience']): print('[ ran out of patience! stopping training. ]') return True return False def _average_dicts(self, all_versions): # instead of a list-of-dicts with like keys, make a dict-of-lists with # keys to reduce to_reduce = {} for d in all_versions: for k, v in d.items(): to_reduce.setdefault(k, []).append(v) # now perform the reduction finalized = {} for k, values in to_reduce.items(): if k == 'exs' or k == 'total_skipped_batches': # sum across workers finalized[k] = np.sum(values) elif isinstance(values[0], dict): # do the same procedure recursively finalized[k] = self._average_dicts(values) else: # all other cases, take the mean across the workers finalized[k] = np.mean(values) return finalized def _sync_training_metrics(self, metrics): """ Sync training metrics across workers. A handful of special cases are handled as exceptions, and the remaining metrics are simply averaged across workers. """ if not is_distributed(): # nothing special needed return metrics all_versions = all_gather_list(metrics) return self._average_dicts(all_versions) def _nice_format(self, dictionary): rounded = {} for k, v in dictionary.items(): if isinstance(v, dict): rounded[k] = self._nice_format(v) elif isinstance(v, float): rounded[k] = round_sigfigs(v, 4) else: rounded[k] = v return rounded def _compute_eta(self, epochs_completed, time_elapsed): """ Computes the estimated seconds remaining in training. :param float epochs_completed: number of epochs already completed. :param float time_elapsed: total time spent already, in seconds. :return: ETA in seconds, or None if not computable """ # start off with no estimate eta = None # Determine time_left and num_epochs max_epochs = self.opt.get('num_epochs', 0) if max_epochs > 0 and epochs_completed > 0: epoch_progress = epochs_completed / max_epochs eta = (1 - epoch_progress) * time_elapsed / epoch_progress max_training_time = self.opt.get('max_training_time', -1) if max_training_time > 0: time_left = max_training_time - time_elapsed if eta is None or time_left < eta: eta = time_left return eta def log(self): opt = self.opt if opt['display_examples']: print(self.world.display() + '\n~~') logs = [] # get report train_report = self._sync_training_metrics(self.world.report()) self.world.reset_metrics() # time elapsed logs.append('time:{}s'.format(np.floor(self.train_time.time()))) logs.append('total_exs:{}'.format(self._total_exs)) if self._total_epochs >= 0: # only if it's unbounded logs.append('epochs:{}'.format(round(self._total_epochs, 2))) time_left = self._compute_eta(self._total_epochs, self.train_time.time()) if time_left is not None: logs.append('time_left:{}s'.format(max(0, np.ceil(time_left)))) log = '[ {} ] {}'.format(' '.join(logs), self._nice_format(train_report)) print(log) self.log_time.reset() if opt['tensorboard_log'] is True and is_primary_worker(): self.writer.add_metrics('train', self._total_exs, train_report) def train(self): if is_distributed(): warn_once( "Distributed training outputs average-per-worker metrics during " "training, and may be slightly distorted. Validation/test are " "unadulterated.") opt = self.opt world = self.world with world: while True: # do one example / batch of examples world.parley() self.parleys += 1 # print(world.display()) # get the total training examples done, compute epochs self._total_epochs = ( self._preempted_epochs + num_workers() * self.world.get_total_epochs()) exs_per_epoch = self.world.num_examples() self._total_exs = int( np.round(self._total_epochs * exs_per_epoch)) # and use the primary worker's timings for everything train_time, log_time, validate_time = sync_object( (self.train_time.time(), self.log_time.time(), self.validate_time.time())) # check counters and timers if self._total_epochs >= self.max_num_epochs: self.log() print( '[ num_epochs completed:{} time elapsed:{}s ]'.format( self.max_num_epochs, train_time)) break if train_time > self.max_train_time: print('[ max_train_time elapsed:{}s ]'.format(train_time)) break if log_time > self.log_every_n_secs: self.log() if (validate_time > self.val_every_n_secs or self._total_epochs - self.last_valid_epoch >= self.val_every_n_epochs): stop_training = self.validate() self.last_valid_epoch = self._total_epochs if stop_training: break if (self.save_time.time() > self.save_every_n_secs and opt.get('model_file') and is_primary_worker()): print("[ saving model checkpoint: {}.checkpoint".format( opt['model_file'])) self.save_model('.checkpoint') self.save_time.reset() if not self.saved and is_primary_worker(): # save agent self.save_model() elif opt.get('model_file'): # reload best validation model self.agent = create_agent(opt) valid_world = _maybe_load_eval_world(self.agent, opt, 'valid') v_report = run_eval(valid_world, opt, 'valid', write_log=True) test_world = _maybe_load_eval_world(self.agent, opt, 'test') t_report = run_eval(test_world, opt, 'test', write_log=True) if valid_world: valid_world.shutdown() if test_world: test_world.shutdown() return v_report, t_report
class TrainLoop(): def __init__(self, opt): if isinstance(opt, ParlaiParser): print( '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]' ) opt = opt.parse_args() # Possibly load from checkpoint if opt['load_from_checkpoint'] and opt.get( 'model_file') and os.path.isfile(opt['model_file'] + '.checkpoint'): opt['init_model'] = opt['model_file'] + '.checkpoint' # Possibly build a dictionary (not all models do this). if opt['dict_build_first'] and 'dict_file' in opt: if opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' print("[ building dictionary first... ]") build_dict(opt, skip_if_built=True) # Create model and assign it to the specified task self.agent = create_agent(opt) self.world = create_task(opt, self.agent) self.train_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() print('[ training... ]') self.parleys = 0 self.max_num_epochs = opt[ 'num_epochs'] if opt['num_epochs'] > 0 else float('inf') self.max_train_time = opt[ 'max_train_time'] if opt['max_train_time'] > 0 else float('inf') self.log_every_n_secs = opt['log_every_n_secs'] if opt[ 'log_every_n_secs'] > 0 else float('inf') self.val_every_n_secs = opt['validation_every_n_secs'] if opt[ 'validation_every_n_secs'] > 0 else float('inf') self.save_every_n_secs = opt['save_every_n_secs'] if opt[ 'save_every_n_secs'] > 0 else float('inf') self.val_every_n_epochs = opt['validation_every_n_epochs'] if opt[ 'validation_every_n_epochs'] > 0 else float('inf') self.last_valid_epoch = 0 self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.best_valid = None if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'): with open(opt['model_file'] + ".best_valid", 'r') as f: x = f.readline() self.best_valid = float(x) f.close() self.impatience = 0 self.saved = False self.valid_world = None self.opt = opt if opt['tensorboard_log'] is True: self.writer = TensorboardLogger(opt) def validate(self): opt = self.opt # run evaluation on valid set valid_report, self.valid_world = run_eval(self.agent, opt, 'valid', opt['validation_max_exs'], valid_world=self.valid_world) # logging if opt['tensorboard_log'] is True: self.writer.add_metrics('valid', int(math.floor(self.train_time.time())), valid_report) # saving if opt.get('model_file') and opt.get('save_after_valid'): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.agent.save(opt['model_file'] + '.checkpoint') # send valid metrics to agent if the agent wants them if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) # check which metric to look at if '/' in opt['validation_metric']: # if you are multitasking and want your validation metric to be # a metric specific to a subtask, specify your validation metric # as -vmt subtask/metric subtask = opt['validation_metric'].split('/')[0] validation_metric = opt['validation_metric'].split('/')[1] new_valid = valid_report['tasks'][subtask][validation_metric] else: new_valid = valid_report[opt['validation_metric']] # check if this is the best validation so far if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid: print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '')) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file'): print("[ saving best valid model: " + opt['model_file'] + " ]") self.agent.save(opt['model_file']) print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]") save_best_valid(opt['model_file'], self.best_valid) self.saved = True if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[ 'validation_cutoff']: print('[ task solved! stopping. ]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() # check if we are out of patience if opt['validation_patience'] > 0 and self.impatience >= opt[ 'validation_patience']: print('[ ran out of patience! stopping training. ]') return True return False def log(self): opt = self.opt if opt['display_examples']: print(self.world.display() + '\n~~') logs = [] # get report train_report = self.world.report(compute_time=True) self.world.reset_metrics() # time elapsed logs.append('time:{}s'.format(math.floor(self.train_time.time()))) total_exs = self.world.get_total_exs() logs.append('total_exs:{}'.format(total_exs)) exs_per_ep = self.world.num_examples() if exs_per_ep: logs.append('epochs:{}'.format(round(total_exs / exs_per_ep, 2))) if 'time_left' in train_report: logs.append('time_left:{}s'.format( math.floor(train_report.pop('time_left', "")))) log = '[ {} ] {}'.format(' '.join(logs), train_report) print(log) self.log_time.reset() if opt['tensorboard_log'] is True: self.writer.add_metrics('train', int(logs[1].split(":")[1]), train_report) def train(self): opt = self.opt world = self.world with world: while True: # do one example / batch of examples world.parley() self.parleys += 1 # check counters and timers if world.get_total_epochs() >= self.max_num_epochs: self.log() print( '[ num_epochs completed:{} time elapsed:{}s ]'.format( self.max_num_epochs, self.train_time.time())) break if self.train_time.time() > self.max_train_time: print('[ max_train_time elapsed:{}s ]'.format( self.train_time.time())) break if self.log_time.time() > self.log_every_n_secs: self.log() if self.validate_time.time() > self.val_every_n_secs: stop_training = self.validate() if stop_training: break if world.get_total_epochs( ) - self.last_valid_epoch >= self.val_every_n_epochs: stop_training = self.validate() self.last_valid_epoch = world.get_total_epochs() if stop_training: break if self.save_time.time() > self.save_every_n_secs and opt.get( 'model_file'): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.agent.save(opt['model_file'] + '.checkpoint') self.save_time.reset() if not self.saved: # save agent self.agent.save(opt['model_file']) elif opt.get('model_file'): # reload best validation model self.agent = create_agent(opt) v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True) t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True) v_world.shutdown() t_world.shutdown() return v_report, t_report
class DefaultTeacher(FbDialogTeacher): def __init__(self, opt, shared=None): opt = copy.deepcopy(opt) super().__init__(opt, shared) self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available() self.is_combine_attr = (hasattr(self, 'other_task_datafiles') and self.other_task_datafiles) self.random_policy = opt.get('random_policy', False) self.count_sample = opt.get('count_sample', False) self.anti = opt.get('anti', False) if self.random_policy: random.seed(17) if not shared: if not self.stream and opt.get('pace_by', 'sample') == 'bucket': score_list = [episode[0][2] for episode in self.data.data] assert score_list == sorted(score_list) num_buckets = opt.get('num_buckets', int(self.num_episodes() / 10)) lb_indices = [ int(len(score_list) * i / num_buckets) for i in range(num_buckets) ] lbs = [score_list[idx] for idx in lb_indices] bucket_ids = [ self.sort_into_bucket(ctrl_val, lbs) for ctrl_val in score_list ] bucket_cnt = [0 for _ in range(num_buckets)] for i in range(num_buckets): bucket_cnt[i] = bucket_ids.count(i) self.bucket_cnt = bucket_cnt self.lastYs = [None] * self.bsz # build multiple task data self.tasks = [self.data] if self.is_combine_attr: print('[ build multiple task data ... ]') for datafile in self.other_task_datafiles: task_opt = copy.deepcopy(opt) task_opt['datafile'] = datafile self.tasks.append( DialogData(task_opt, data_loader=self.setup_data, cands=self.label_candidates())) print('[ build multiple task data done! ]') # record the selections of each subtasks self.subtasks = opt['subtasks'].split(':') self.subtask_counter = OrderedDict() self.p_selections = OrderedDict() self.c_selections = OrderedDict() for t in self.subtasks: self.subtask_counter[t] = 0 self.p_selections[t] = [] self.c_selections[t] = [] if self.count_sample and not self.stream: self.sample_counter = OrderedDict() for idx, t in enumerate(self.subtasks): self.sample_counter[t] = [ 0 for _ in self.tasks[idx].data ] # setup the tensorboard log if opt['tensorboard_log_teacher'] is True: opt['tensorboard_tag'] = 'task' teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split( ',') opt['tensorboard_metrics'] = ','.join( opt['tensorboard_metrics'].split(',') + teacher_metrics) self.writer = TensorboardLogger(opt) else: self.lastYs = shared['lastYs'] self.tasks = shared['tasks'] if not self.stream and opt.get('pace_by', 'sample') == 'bucket': self.bucket_cnt = shared['bucket_cnt'] if 'writer' in shared: self.writer = shared['writer'] if 'subtask_counter' in shared: self.subtask_counter = shared['subtask_counter'] if 'p_selections' in shared: self.p_selections = shared['p_selections'] if 'c_selections' in shared: self.c_selections = shared['c_selections'] # build the policy net, criterion and optimizer here self.state_dim = 32 + len(self.tasks) # hand-craft features self.action_dim = len(self.tasks) if not shared: self.policy = PolicyNet(self.state_dim, self.action_dim) self.critic = CriticNet(self.state_dim, self.action_dim) init_teacher = get_init_teacher(opt, shared) if init_teacher is not None: # load teacher parameters if available print('[ Loading existing teacher params from {} ]' ''.format(init_teacher)) states = self.load(init_teacher) else: states = {} else: self.policy = shared['policy'] self.critic = shared['critic'] states = shared['states'] if ( # only build an optimizer if we're training 'train' in opt.get('datatype', '') and # and this is the main model shared is None): # for policy net self.optimizer = self.init_optim( [p for p in self.policy.parameters() if p.requires_grad], lr=opt['learningrate_teacher'], optim_states=states.get('optimizer'), saved_optim_type=states.get('optimizer_type')) self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, 'min', factor=0.8, # 0.5 --> 0.8 patience=5, # 3 -- > 5 verbose=True) if 'lr_scheduler' in states: self.scheduler.load_state_dict(states['lr_scheduler']) # for critic net self.optimizer_critic = self.init_optim( [p for p in self.critic.parameters() if p.requires_grad], lr=opt['learningrate_teacher_critic'], optim_states=states.get('optimizer_critic'), saved_optim_type=states.get('optimizer_type')) self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer_critic, 'min', factor=0.8, # 0.5 --> 0.8 patience=5, # 3 -- > 5 verbose=True) if 'lr_scheduler_critic' in states: self.scheduler_critic.load_state_dict( states['lr_scheduler_critic']) self.critic_criterion = torch.nn.SmoothL1Loss() self.reward_metric = opt.get('reward_metric', 'total_metric') self.reward_metric_mode = opt.get('reward_metric_mode', 'max') self.prev_prev_valid_report = states[ 'prev_prev_valid_report'] if 'prev_prev_valid_report' in states else None self.prev_valid_report = states[ 'prev_valid_report'] if 'prev_valid_report' in states else None self.current_valid_report = states[ 'current_valid_report'] if 'current_valid_report' in states else None self.saved_actions = states[ 'saved_actions'] if 'saved_actions' in states else OrderedDict() self.saved_state_actions = states[ 'saved_state_actions'] if 'saved_state_actions' in states else OrderedDict( ) if self.use_cuda: for k, v in self.saved_actions.items(): self.saved_actions[k] = v.cuda() for k, v in self.saved_state_actions.items(): self.saved_state_actions[k] = v.cuda() self._number_teacher_updates = states[ '_number_teacher_updates'] if '_number_teacher_updates' in states else 0 # enable the batch_act self.use_batch_act = self.bsz > 1 self.T = self.opt.get('T', 1000) self.c0 = self.opt.get('c0', 0.01) self.p = self.opt.get('p', 2) # setup the timer self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \ else float('inf') self.action_log_time = Timer() self.move_to_cuda() def move_to_cuda(self): if self.use_cuda: self.policy.cuda() self.critic.cuda() @classmethod def optim_opts(self): """ Fetch optimizer selection. By default, collects everything in torch.optim, as well as importing: - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim Override this (and probably call super()) to add your own optimizers. """ # first pull torch.optim in optims = { k.lower(): v for k, v in optim.__dict__.items() if not k.startswith('__') and k[0].isupper() } try: import apex.optimizers.fused_adam as fused_adam optims['fused_adam'] = fused_adam.FusedAdam except ImportError: pass try: # https://openreview.net/pdf?id=S1fUpoR5FQ from qhoptim.pyt import QHM, QHAdam optims['qhm'] = QHM optims['qhadam'] = QHAdam except ImportError: # no QHM installed pass return optims def init_optim(self, params, lr, optim_states=None, saved_optim_type=None): """ Initialize optimizer with teacher parameters. :param params: parameters from the teacher :param optim_states: optional argument providing states of optimizer to load :param saved_optim_type: type of optimizer being loaded, if changed will skip loading optimizer states """ opt = self.opt # set up optimizer args kwargs = {'lr': lr} if opt.get('momentum_teacher') > 0 and opt['optimizer_teacher'] in [ 'sgd', 'rmsprop', 'qhm' ]: # turn on momentum for optimizers that use it kwargs['momentum'] = opt['momentum_teacher'] if opt['optimizer_teacher'] == 'sgd' and opt.get( 'nesterov_teacher', True): # for sgd, maybe nesterov kwargs['nesterov'] = opt.get('nesterov_teacher', True) elif opt['optimizer_teacher'] == 'qhm': # qhm needs a nu kwargs['nu'] = opt.get('nus_teacher', (0.7, ))[0] elif opt['optimizer_teacher'] == 'adam': # turn on amsgrad for adam # amsgrad paper: https://openreview.net/forum?id=ryQu7f-RZ kwargs['amsgrad'] = True elif opt['optimizer_teacher'] == 'qhadam': # set nus for qhadam kwargs['nus'] = opt.get('nus_teacher', (0.7, 1.0)) if opt['optimizer_teacher'] in [ 'adam', 'sparseadam', 'adamax', 'qhadam' ]: # set betas for optims that use it kwargs['betas'] = opt.get('betas_teacher', (0.9, 0.999)) optim_class = self.optim_opts()[opt['optimizer_teacher']] optimizer = optim_class(params, **kwargs) if optim_states: if saved_optim_type != opt['optimizer_teacher']: print('WARNING: not loading optim state since optim class ' 'changed.') else: try: optimizer.load_state_dict(optim_states) except ValueError: print('WARNING: not loading optim state since model ' 'params changed.') if self.use_cuda: for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() return optimizer def load(self, path): """ Return opt and teacher states. TODO: load behaviors should be consistent with function state_dict(). """ states = torch.load(path, map_location=lambda cpu, _: cpu) if 'policy' in states: self.policy.load_state_dict(states['policy']) if 'critic' in states: self.critic.load_state_dict(states['critic']) if 'optimizer' in states and hasattr(self, 'optimizer'): self.optimizer.load_state_dict(states['optimizer']) if 'optimizer_critic' in states and hasattr(self, 'optimizer_critic'): self.optimizer_critic.load_state_dict(states['optimizer_critic']) return states def share(self): shared = super().share() if hasattr(self, 'bucket_cnt'): shared['bucket_cnt'] = self.bucket_cnt shared['tasks'] = self.tasks shared['policy'] = self.policy shared['critic'] = self.critic shared['states'] = { 'optimizer_type': self.opt['optimizer_teacher'], 'prev_prev_valid_report': self.prev_prev_valid_report, 'prev_valid_report': self.prev_valid_report, 'current_valid_report': self.current_valid_report, 'saved_actions': self.saved_actions, 'saved_state_actions': self.saved_state_actions, } if hasattr(self, 'writer'): shared['writer'] = self.writer if hasattr(self, 'subtask_counter'): shared['subtask_counter'] = self.subtask_counter if hasattr(self, 'p_selections'): shared['p_selections'] = self.p_selections if hasattr(self, 'c_selections'): shared['c_selections'] = self.c_selections return shared @staticmethod def sort_into_bucket(val, bucket_lbs): """ Returns the highest bucket such that val >= lower bound for that bucket. Inputs: val: float. The value to be sorted into a bucket. bucket_lbs: list of floats, sorted ascending. Returns: bucket_id: int in range(num_buckets); the bucket that val belongs to. """ num_buckets = len(bucket_lbs) for bucket_id in range(num_buckets - 1, -1, -1): # iterate descending lb = bucket_lbs[bucket_id] if val >= lb: return bucket_id raise ValueError('val %f is not >= any of the lower bounds: %s' % (val, bucket_lbs)) def pace_function(self, states, sum_num, T=1000, c0=0.01, p=2): train_step = states['train_step'] progress = self.root_p_pace(train_step, T, c0, p) return int(sum_num * progress) @staticmethod def root_p_pace(timestep, T=1000, c0=0.01, p=2): root_p = math.pow( timestep * (1 - math.pow(c0, p)) / T + math.pow(c0, p), 1.0 / p) return min(1.0, root_p) def act(self, observation=None, task_idx=0): """Send new dialog message.""" if not hasattr(self, 'epochDone'): # reset if haven't yet self.reset() # get next example, action is episode_done dict if already out of exs action, self.epochDone = self.next_example(observation=observation, task_idx=task_idx) action['id'] = self.getID() # remember correct answer if available self.lastY = action.get('labels', action.get('eval_labels', None)) if ((not self.datatype.startswith('train') or 'evalmode' in self.datatype) and 'labels' in action): # move labels to eval field so not used for training # but this way the model can use the labels for perplexity or loss action = action.copy() labels = action.pop('labels') if not self.opt.get('hide_labels', False): action['eval_labels'] = labels return action def _cry_for_missing_in_obs(self, something): raise RuntimeError( "{} is needed to include in observations to build states!".format( something)) def _build_states(self, observations): for key in ['train_step', 'train_report', 'loss_desc', 'prob_desc']: if key not in observations[0]: self._cry_for_missing_in_obs(key) train_step = observations[0]['train_step'] # scala train_step = min(train_step / self.T, 1) train_report = observations[0]['train_report'] nll_loss = train_report.get('nll_loss', 0) / 10 # scala loss_desc = observations[0]['loss_desc'] loss_desc = F.normalize(loss_desc, p=2, dim=-1) prob_desc = observations[0]['prob_desc'] prob_desc = F.normalize(prob_desc, p=2, dim=-1) if hasattr(self, 'subtask_counter'): subtask_progress = self.subtask_counter.values() max_min = max(subtask_progress) - min(subtask_progress) subtask_progress = [ (item - min(subtask_progress)) / max_min if max_min > 0 else 0 for item in subtask_progress ] else: subtask_progress = [0] subtask_progress = torch.FloatTensor(subtask_progress) if self.use_cuda: subtask_progress = subtask_progress.cuda() prev_valid_report = self.prev_valid_report if prev_valid_report is None: prev_valid_report = {} bleu = prev_valid_report.get('bleu', 0) valid_nll_loss = prev_valid_report.get('nll_loss', 0) / 10 dist_1_ratio = prev_valid_report.get('dist_1_ratio', 0) dist_2_ratio = prev_valid_report.get('dist_2_ratio', 0) dist_3_ratio = prev_valid_report.get('dist_3_ratio', 0) embed_avg = prev_valid_report.get('embed_avg', 0) embed_greedy = prev_valid_report.get('embed_greedy', 0) embed_extrema = prev_valid_report.get('embed_extrema', 0) embed_coh = prev_valid_report.get('embed_coh', 0) intra_dist_1 = prev_valid_report.get('intra_dist_1', 0) / 10 intra_dist_2 = prev_valid_report.get('intra_dist_2', 0) / 10 intra_dist_3 = prev_valid_report.get('intra_dist_3', 0) / 10 response_length = prev_valid_report.get( 'response_length', 0) / self.opt.get('label_truncate', 100) # sent_entropy_uni = prev_valid_report.get('sent_entropy_uni', 0) / 100 # sent_entropy_bi = prev_valid_report.get('sent_entropy_bi', 0) / 100 # sent_entropy_tri = prev_valid_report.get('sent_entropy_tri', 0) / 100 word_entropy_uni = prev_valid_report.get('word_entropy_uni', 0) / 100 word_entropy_bi = prev_valid_report.get('word_entropy_bi', 0) / 100 word_entropy_tri = prev_valid_report.get('word_entropy_tri', 0) / 100 states = torch.FloatTensor([ train_step, nll_loss, bleu, valid_nll_loss, dist_1_ratio, dist_2_ratio, dist_3_ratio, embed_avg, embed_greedy, embed_extrema, embed_coh, intra_dist_1, intra_dist_2, intra_dist_3, response_length, # sent_entropy_uni, sent_entropy_bi, sent_entropy_tri, word_entropy_uni, word_entropy_bi, word_entropy_tri ]) if self.use_cuda: states = states.cuda() states = torch.cat([states, loss_desc, prob_desc, subtask_progress], dim=-1).unsqueeze(dim=0) return states def __uniform_weights(self): w = 1 / len(self.tasks) weights = torch.FloatTensor([w] * len(self.tasks)) if self.use_cuda: weights = weights.cuda() return weights.unsqueeze(dim=0) def __load_training_batch(self, observations): if observations and len( observations) > 0 and observations[0] and self.is_combine_attr: if not self.random_policy: with torch.no_grad(): current_states = self._build_states(observations) action_probs = self.policy(current_states) if self.action_log_time.time() > self.log_every_n_secs and len( self.tasks) > 1: with torch.no_grad(): # log the action distributions action_p = ','.join([ str(round_sigfigs(x, 4)) for x in action_probs[0].data.tolist() ]) log = '[ {} {} ]'.format('Action probs:', action_p) print(log) self.action_log_time.reset() sample_from = Categorical(action_probs[0]) action = sample_from.sample() train_step = observations[0]['train_step'] self.saved_actions[train_step] = sample_from.log_prob(action) self.saved_state_actions[train_step] = torch.cat( [current_states, action_probs], dim=1) selected_task = action.item() self.subtask_counter[self.subtasks[selected_task]] += 1 probs = action_probs[0].tolist() selection_report = {} for idx, t in enumerate(self.subtasks): selection_report['p_{}'.format(t)] = probs[idx] self.p_selections[t].append(probs[idx]) selection_report['c_{}'.format( t)] = self.subtask_counter[t] self.c_selections[t].append(self.subtask_counter[t]) self.writer.add_metrics(setting='Teacher/task_selection', step=train_step, report=selection_report) else: selected_task = random.choice(range(len(self.tasks))) self.subtask_counter[self.subtasks[selected_task]] += 1 else: selected_task = 0 return self.__load_batch(observations, task_idx=selected_task) def __load_batch(self, observations, task_idx=0): if observations is None: observations = [None] * self.bsz bsz = len(observations) batch = [] # Sample from multiple tasks using the policy net for idx in range(bsz): batch.append(self.act(observations[idx], task_idx=task_idx)) return batch def batch_act(self, observations): """ Returns an entire batch of examples instead of just one. """ if not hasattr(self, 'epochDone'): # reset if haven't yet self.reset() if self.opt['datatype'] == 'train': batch = self.__load_training_batch(observations) else: batch = self.__load_batch(observations) # pad batch if len(batch) < self.bsz: batch += [{ 'episode_done': True, 'id': self.getID() }] * (self.bsz - len(batch)) # remember correct answer if available (for padding, None) for i, ex in enumerate(batch): if 'labels' in ex: labels = ex['labels'] self.lastYs[i] = labels if not self.datatype.startswith( 'train') or 'evalmode' in self.datatype: del ex['labels'] if not self.opt.get('hide_labels', False): ex['eval_labels'] = labels else: self.lastYs[i] = ex.get('eval_labels', None) return batch def next_example(self, observation=None, task_idx=0): """ Returns the next example. If there are multiple examples in the same episode, returns the next one in that episode. If that episode is over, gets a new episode index and returns the first example of that episode. """ if self.stream: action, epoch_done = self.tasks[task_idx].get() else: if self.episode_done: self.episode_idx = self.next_episode_idx() self.entry_idx = 0 else: self.entry_idx += 1 if self.episode_idx >= self.num_episodes(): return {'episode_done': True}, True if observation is None or self.opt['datatype'] != 'train': # The first step of the training or validation mode sampled_episode_idx = self.episode_idx sampled_entry_idx = self.entry_idx else: # --------------- pick the sample according to the pace function ----------- pace_by = self.opt.get('pace_by', 'sample') if pace_by == 'sample': sum_num = self.num_episodes() elif pace_by == 'bucket': sum_num = len(self.bucket_cnt) else: raise ValueError('pace_by must be {} or {}!'.format( 'sample', 'bucket')) states4pace_func = observation if hasattr(self, 'subtask_counter'): states4pace_func = { 'train_step': self.subtask_counter[self.subtasks[task_idx]] } threshold = self.pace_function(states4pace_func, sum_num, self.T, self.c0, self.p) if pace_by == 'sample': stop_step = threshold elif pace_by == 'bucket': stop_step = sum(self.bucket_cnt[:threshold]) else: raise ValueError('pace_by must be {} or {}!'.format( 'sample', 'bucket')) stop_step = self.num_episodes( ) if stop_step > self.num_episodes() else stop_step # sampled_episode_idx = random.choice(list(range(self.num_episodes()))[:stop_step]) sampled_episode_idx = np.random.choice(stop_step) sampled_entry_idx = 0 # make sure the episode only contains one entry if self.anti: sampled_episode_idx = self.num_episodes( ) - 1 - sampled_episode_idx if self.count_sample: self.sample_counter[ self.subtasks[task_idx]][sampled_episode_idx] += 1 ex = self.get(sampled_episode_idx, sampled_entry_idx, task_idx=task_idx) if observation is None or self.opt['datatype'] != 'train': self.episode_done = ex.get('episode_done', False) if (not self.random and self.episode_done and self.episode_idx + self.opt.get("batchsize", 1) >= self.num_episodes()): epoch_done = True else: epoch_done = False else: # in the setting of curriculum leaning, samples are not uniformly # picked from the training set, so, the epoch records here make no sense. epoch_done = False action = ex return action, epoch_done def get(self, episode_idx, entry_idx=0, task_idx=0): return self.tasks[task_idx].get(episode_idx, entry_idx)[0] def update_params(self): self._number_teacher_updates += 1 if self.opt.get('gradient_clip_teacher', -1) > 0: torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.opt['gradient_clip_teacher']) self.optimizer.step() def update_critic_params(self): if self.opt.get('gradient_clip_teacher', -1) > 0: torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.opt['gradient_clip_teacher']) self.optimizer_critic.step() def receive_metrics(self, metrics_dict): if self.is_combine_attr and not self.random_policy: assert self.reward_metric in metrics_dict, '{} is not in the metrics_dict!'.format( self.reward_metric) self.prev_prev_valid_report = self.prev_valid_report self.prev_valid_report = self.current_valid_report self.current_valid_report = metrics_dict delt_reward = None if self.prev_prev_valid_report and self.prev_valid_report and self.current_valid_report: delt_reward1 = self.current_valid_report[ self.reward_metric] - self.prev_valid_report[ self.reward_metric] delt_reward0 = self.prev_valid_report[ self.reward_metric] - self.prev_prev_valid_report[ self.reward_metric] if self.reward_metric_mode == 'min': delt_reward1 = -delt_reward1 delt_reward0 = -delt_reward0 delt_reward = delt_reward1 / (delt_reward0 + 1e-6) - 1 if delt_reward and len(self.saved_actions) > 0 and len( self.saved_state_actions) > 0: reward = torch.clamp(torch.FloatTensor([delt_reward]), -10, 10) if self.use_cuda: reward = reward.cuda() with torch.no_grad(): batch_state_actions = torch.cat(list( self.saved_state_actions.values()), dim=0) if self.use_cuda: batch_state_actions = batch_state_actions.cuda() estimate_rewards = self.critic( batch_state_actions).squeeze() advantages = reward - estimate_rewards # rescale the rewards by ranking episode_len = len(advantages) ranks = torch.FloatTensor( list( reversed( ss.rankdata(advantages.cpu(), method='dense')))).unsqueeze(dim=1) rescaled_rewards = torch.sigmoid( 12 * (0.5 - ranks / episode_len)) rescaled_rewards = [r.item() for r in rescaled_rewards] policy_loss = [] idx = 0 for model_train_step, log_prob in self.saved_actions.items(): policy_loss.append(-log_prob.unsqueeze(dim=0) * rescaled_rewards[idx]) idx += 1 policy_loss = torch.cat(policy_loss).sum() # regularization term regarding action distribution bsz = batch_state_actions.size(0) action_probs = torch.cat(list( self.saved_state_actions.values()), dim=0).narrow(1, self.state_dim, self.action_dim) action_ent = torch.sum( -action_probs * torch.log(action_probs)) / bsz self.policy.train() self.optimizer.zero_grad() policy_loss = policy_loss + self.opt.get('reg_action', 0.001) * (-action_ent) policy_loss.backward() self.update_params() # lr_scheduler step on teacher loss policy_loss_item = policy_loss.item() if self.opt.get('optimizer_teacher', '') == 'sgd': self.scheduler.step(policy_loss_item) # training on the critic self.critic.train() self.optimizer_critic.zero_grad() batch_values = self.critic(batch_state_actions) critic_target = torch.FloatTensor(bsz, 1) critic_target = critic_target.fill_(reward.item()) if self.use_cuda: critic_target = critic_target.cuda() critic_loss = self.critic_criterion(batch_values, critic_target) critic_loss.backward() self.update_critic_params() critic_loss_item = critic_loss.item() if self.opt.get('optimizer_teacher', '') == 'sgd': self.scheduler_critic.step(critic_loss_item) # log something print( '[ reward: {}; mean_advantage_reward: {}; policy loss: {};' ' critic loss: {}; action ent: {}; episode length: {} ]'. format(reward.item(), np.mean(advantages.tolist()), policy_loss_item, critic_loss_item, action_ent.item(), len(self.saved_actions))) report = { 'reward': reward.item(), 'mean_advantage_reward': np.mean(advantages.tolist()), 'policy_loss': policy_loss_item, 'critic_loss': critic_loss_item, 'action_ent': action_ent.item(), } self.writer.add_metrics(setting='Teacher/receive_metrics', step=self._number_teacher_updates, report=report) # clear history actions self.saved_actions.clear() self.saved_state_actions.clear() def state_dict(self): """ Get the state dict for saving TODO: save more teacher-related states for reloading """ states = {} if hasattr(self, 'policy'): # save model params if hasattr(self.policy, 'module'): # did we wrap in a DistributedDataParallel states['policy'] = self.policy.module.state_dict() else: states['policy'] = self.policy.state_dict() if hasattr(self, 'critic'): # save model params if hasattr(self.critic, 'module'): # did we wrap in a DistributedDataParallel states['critic'] = self.critic.module.state_dict() else: states['critic'] = self.critic.state_dict() if hasattr(self, 'optimizer'): # save optimizer params states['optimizer'] = self.optimizer.state_dict() states['optimizer_type'] = self.opt['optimizer_teacher'] if hasattr(self, 'optimizer_critic'): states['optimizer_critic'] = self.optimizer_critic.state_dict() if getattr(self, 'scheduler', None): states['lr_scheduler'] = self.scheduler.state_dict() if getattr(self, 'scheduler_critic', None): states['lr_scheduler_critic'] = self.scheduler_critic.state_dict() states['prev_prev_valid_report'] = self.prev_prev_valid_report states['prev_valid_report'] = self.prev_valid_report states['current_valid_report'] = self.current_valid_report states['saved_actions'] = self.saved_actions states['saved_state_actions'] = self.saved_state_actions states['_number_teacher_updates'] = self._number_teacher_updates return states def save(self, path=None): if path: teacher_path = path else: model_file = self.opt.get('model_file', None) if model_file: teacher_path = model_file + '.teacher' else: teacher_path = None if teacher_path: states = self.state_dict() if states: with open(teacher_path, 'wb') as write: torch.save(states, write) # save opt file with open(teacher_path + '.opt', 'w', encoding='utf-8') as handle: json.dump(self.opt, handle) # for convenience of working with jq, make sure there's a newline handle.write('\n') if self.count_sample: # save sample count info for task_name, task_val in self.sample_counter.items(): with open(teacher_path + '.sample_count.{}'.format(task_name), 'w', encoding='utf-8') as f: f.write('\n'.join([str(item) for item in task_val])) self.write_selections('p_selections', teacher_path) self.write_selections('c_selections', teacher_path) def write_selections(self, selections, teacher_path): if hasattr(self, selections): with open(teacher_path + '.{}'.format(selections), 'w', encoding='utf-8') as f: f.write('\t'.join(self.subtasks)) f.write('\n') for idx in range( len(getattr(self, selections)[self.subtasks[0]])): p_line = [] for t in self.subtasks: p_line.append(str(getattr(self, selections)[t][idx])) f.write('\t'.join(p_line)) f.write('\n')
class TrainLoop(): def __init__(self, opt): if isinstance(opt, ParlaiParser): opt = opt.parse_args() # Possibly build a dictionary (not all models do this). if opt['dict_build_first'] and 'dict_file' in opt: if opt['dict_file'] is None and opt.get( 'model_file_transmitter') and opt.get( 'model_file_receiver'): opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt[ 'model_file_receiver'] + '.dict' print("[ building dictionary first... ]") build_dict(opt, skip_if_built=False) # Create model and assign it to the specified task print("[ create meta-agent ... ]") self.agent = create_agent(opt) print("[ create agent A ... ]") shared = self.agent.share() self.agent_a = create_agent_from_shared(shared) self.agent_a.set_id(suffix=' A') print("[ create agent B ... ]") self.agent_b = create_agent_from_shared(shared) # self.agent_b = create_agent(opt) self.agent_b.set_id(suffix=' B') # self.agent_a.copy(self.agent, 'transmitter') # self.agent_b.copy(self.agent, 'transmitter') self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b]) # TODO: if batch, it is also not parallel # self.world = BatchSelfPlayWorld(opt, self_play_world) self.train_time = Timer() self.train_dis_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() print('[ training... ]') self.parleys_episode = 0 self.max_num_epochs = opt[ 'num_epochs'] if opt['num_epochs'] > 0 else float('inf') self.max_train_time = opt[ 'max_train_time'] if opt['max_train_time'] > 0 else float('inf') self.log_every_n_secs = opt['log_every_n_secs'] if opt[ 'log_every_n_secs'] > 0 else float('inf') self.train_dis_every_n_secs = opt['train_display_every_n_secs'] if opt[ 'train_display_every_n_secs'] > 0 else float('inf') self.val_every_n_secs = opt['validation_every_n_secs'] if opt[ 'validation_every_n_secs'] > 0 else float('inf') self.save_every_n_secs = opt['save_every_n_secs'] if opt[ 'save_every_n_secs'] > 0 else float('inf') self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.best_valid = None if opt.get('model_file_transmitter') and os.path.isfile( opt['model_file_transmitter'] + '.best_valid'): with open(opt['model_file_transmitter'] + ".best_valid", 'r') as f: x = f.readline() self.best_valid = float(x) f.close() self.impatience = 0 self.saved = False self.valid_world = None self.opt = opt if opt['tensorboard_log'] is True: self.writer = TensorboardLogger(opt) def validate(self): opt = self.opt valid_report, self.valid_world = run_eval(self.agent, opt, 'valid', opt['validation_max_exs'], valid_world=self.valid_world) if opt['tensorboard_log'] is True: self.writer.add_metrics('valid', self.parleys_episode, valid_report) if opt.get('model_file_transmitter') and opt.get('save_after_valid'): print("[ saving transmitter checkpoint: " + opt['model_file_transmitter'] + ".checkpoint ]") self.agent.save(component='transmitter') # if opt.get('model_file_receiver') and opt.get('save_after_valid'): # print("[ saving receiver checkpoint: " + opt['model_file_receiver'] + ".checkpoint ]") # self.agent.save(component='receiver') if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) if '/' in opt['validation_metric']: # if you are multitasking and want your validation metric to be # a metric specific to a subtask, specify your validation metric # as -vmt subtask/metric subtask = opt['validation_metric'].split('/')[0] validation_metric = opt['validation_metric'].split('/')[1] new_valid = valid_report['tasks'][subtask][validation_metric] else: new_valid = valid_report[opt['validation_metric']] if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid: print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '')) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file'): print("[ saving best valid model: " + opt['model_file'] + " ]") # the fine-tuned transmitter part is actually what we want for PSquare bot self.agent.save() print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]") save_best_valid(opt['model_file'], self.best_valid) self.saved = True if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[ 'validation_cutoff']: print('[ task solved! stopping. ]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() if 0 < opt['validation_patience'] <= self.impatience: print('[ ran out of patience! stopping training. ]') return True return False def log(self): opt = self.opt if opt['display_examples']: print(self.world.display() + '\n~~') logs = [] # get report train_report = self.world.report() self.world.reset_metrics() # time elapsed logs.append('time:{}s'.format(math.floor(self.train_time.time()))) logs.append('parleys:{}'.format(self.parleys_episode)) if 'time_left' in train_report: logs.append('time_left:{}s'.format( math.floor(train_report.pop('time_left', "")))) if 'num_epochs' in train_report: logs.append('num_epochs:{}'.format( train_report.pop('num_epochs', ''))) log = '[ {} ] {}'.format(' '.join(logs), train_report) print(log) self.log_time.reset() if opt['tensorboard_log'] is True: self.writer.add_metrics('train', self.parleys_episode, train_report) def train(self): # print('#### Validating at {} training episode '.format(self.parleys_episode)) # self.validate() opt = self.opt world = self.world with world: while True: self.parleys_episode += 1 if self.parleys_episode % 100 == 0: print('#### Training {} episode '.format( self.parleys_episode)) if self.train_dis_time.time() > self.train_dis_every_n_secs: is_display = True # clear to zero self.train_dis_time.reset() else: is_display = False world.parley_episode(is_training=True, is_display=is_display) if world.get_total_epochs() >= self.max_num_epochs: self.log() print( '[ num_epochs completed:{} time elapsed:{}s ]'.format( self.max_num_epochs, self.train_time.time())) break if self.train_time.time() > self.max_train_time: print('[ max_train_time elapsed:{}s ]'.format( self.train_time.time())) break if self.log_time.time() > self.log_every_n_secs: self.log() if self.validate_time.time() > self.val_every_n_secs: print('#### Validating at {} training episode '.format( self.parleys_episode)) stop_training = self.validate() if stop_training: break if self.save_time.time() > self.save_every_n_secs: if opt.get('model_file_transmitter'): print("[ saving transmitter checkpoint: " + opt['model_file_transmitter'] + ".checkpoint ]") self.agent.save(opt['model_file_transmitter'] + '.checkpoint', component='transmitter') if opt.get('model_file_receiver'): print("[ saving receiver checkpoint: " + opt['model_file_receiver'] + ".checkpoint ]") self.agent.save(opt['model_file_receiver'] + '.checkpoint', component='receiver') self.save_time.reset() if not self.saved: # save agent # self.agent.save(component='transmitter') self.agent.save() # self.agent.save(component='receiver') # TODO: API for save all components elif opt.get('model_file_transmitter') and opt.get( 'model_file_receiver' ): # TODO: check if both components are necessary # reload best validation model self.agent = create_agent(opt) v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True) t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True) v_world.shutdown() t_world.shutdown() return v_report, t_report