def iter_train(self, only_cached=False, _top_epoch=False, ewc_params=None):
    epoch = -1
    base_path = util.path_model_trainer(self.ranker, self.vocab, self, self.dataset)
    context = {
        'epoch': epoch,
        'batch_size': self.config['batch_size'],
        'batches_per_epoch': self.config['batches_per_epoch'],
        'num_microbatches': 1 if self.config['grad_acc_batch'] == 0 else self.config['batch_size'] // self.config['grad_acc_batch'],
        'device': self.device,
        'base_path': base_path,
    }

    files = trainers.misc.PathManager(base_path)
    self.logger.info(f'train path: {base_path}, batches_per_epoch: {context["batches_per_epoch"]}')
    b_count = context['batches_per_epoch'] * context['num_microbatches'] * self.batch_size

    ranker = self.ranker.to(self.device)
    optimizer = self.create_optimizer()

    if f'{epoch}.p' not in files['weights']:
        ranker.save(files['weights'][f'{epoch}.p'])
    if f'{epoch}.p' not in files['optimizer']:
        torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])

    context.update({
        'ranker': lambda: ranker,
        'ranker_path': files['weights'][f'{epoch}.p'],
        'optimizer': lambda: optimizer,
        'optimizer_path': files['optimizer'][f'{epoch}.p'],
    })

    # load ranker/optimizer from the previous pipeline's best checkpoint (saved as -2.p)
    if _top_epoch:
        # e.g. trainer.pipeline=msmarco_train_bm25_k1-0.82_b-0.68.100_mspairs
        __path = "/".join(base_path.split("/")[:-1]) + "/" + self.config['pipeline']
        self.logger.info(f'loading prev model: {__path}')
        w_path = os.path.join(__path, 'weights', '-2.p')
        oppt_path = os.path.join(__path, 'optimizer', '-2.p')
        ranker.load(w_path)
        optimizer.load_state_dict(torch.load(oppt_path))
        if w_path != files['weights']['-2.p']:
            copyfile(w_path, files['weights']['-2.p'])
            copyfile(oppt_path, files['optimizer']['-2.p'])
        context.update({
            'ranker': lambda: ranker,
            'ranker_path': files['weights']['-2.p'],
            'optimizer': lambda: optimizer,
            'optimizer_path': files['optimizer']['-2.p'],
        })

    yield context  # before training

    while True:
        context = dict(context)
        epoch = context['epoch'] = context['epoch'] + 1

        if epoch in files['complete.tsv']:
            context.update({
                'loss': files['loss.txt'][epoch],
                'data_loss': files['data_loss.txt'][epoch],
                'losses': {},
                'acc': files['acc.tsv'][epoch],
                'unsup_acc': files['unsup_acc.tsv'][epoch],
                'ranker': _load_ranker(ranker, files['weights'][f'{epoch}.p']),
                'ranker_path': files['weights'][f'{epoch}.p'],
                'optimizer': _load_optimizer(optimizer, files['optimizer'][f'{epoch}.p']),
                'optimizer_path': files['optimizer'][f'{epoch}.p'],
                'cached': True,
            })
            if not only_cached:
                self.fast_forward(b_count)  # skip this epoch
            yield context
            continue

        if only_cached:
            break  # no more cached

        # forward to previous versions (if needed)
        ranker = context['ranker']()
        optimizer = context['optimizer']()

        ranker.train()
        context.update({
            'loss': 0.0,
            'losses': {},
            'acc': 0.0,
            'unsup_acc': 0.0,
        })

        # load the stored EWC statistics (optimal parameters and Fisher diagonals) for each previous task
        _optpar_params = {}
        _fisher_params = {}
        ewc_tasks = ewc_params.tasks if ewc_params is not None else []
        for task in ewc_tasks:
            self.logger.debug(f'EWC: loading stored parameters for task {task}')
            _optpar_params[task] = {}
            _fisher_params[task] = {}
            for name, param in ranker.named_parameters():
                if param.requires_grad:
                    optpar_path = ewc_params.getOptpar(task, name)
                    fisher_path = ewc_params.getFisher(task, name)
                    with open(optpar_path, 'rb') as f:
                        _optpar_params[task][name] = pickle.load(f).to(self.device)
                    with open(fisher_path, 'rb') as f:
                        _fisher_params[task][name] = pickle.load(f).to(self.device)

        with tqdm(leave=False, total=b_count, ncols=100, desc=f'train {epoch}') as pbar:
            for b in range(context['batches_per_epoch']):
                for _ in range(context['num_microbatches']):
                    self.epoch = epoch
                    train_batch_result = self.train_batch()
                    losses = train_batch_result['losses']
                    loss_weights = train_batch_result['loss_weights']
                    acc = train_batch_result.get('acc')
                    unsup_acc = train_batch_result.get('unsup_acc')

                    # EWC penalty: pull parameters toward the optima of previous tasks,
                    # weighted by the (diagonal) Fisher estimates and ewc_lambda
                    if not self._ewc:
                        for task in _fisher_params:
                            for name, param in ranker.named_parameters():
                                if param.requires_grad:
                                    fisher = _fisher_params[task][name]
                                    optpar = _optpar_params[task][name]
                                    losses['data'] += (fisher * (optpar - param).pow(2)).sum() * ewc_params.ewc_lambda
                        # (an earlier variant reloaded the optpar/fisher tensors from disk on every batch)
                    losses['data'] = losses['data'].mean()

                    loss = sum(losses[k] * loss_weights.get(k, 1.) for k in losses) / context['num_microbatches']
                    context['loss'] += loss.item()
                    for lname, lvalue in losses.items():
                        context['losses'].setdefault(lname, 0.)
                        context['losses'][lname] += lvalue.item() / context['num_microbatches']
                    if acc is not None:
                        context['acc'] += acc.item() / context['num_microbatches']
                    if unsup_acc is not None:
                        context['unsup_acc'] += unsup_acc.item() / context['num_microbatches']
                    if loss.grad_fn is not None:
                        if hasattr(optimizer, 'backward'):
                            optimizer.backward(loss)
                        else:
                            loss.backward()
                    else:
                        self.logger.warn('loss has no grad_fn; skipping batch')
                    pbar.update(self.batch_size)

                postfix = {
                    'loss': context['loss'] / (b + 1),
                }
                for lname, lvalue in context['losses'].items():
                    if lname in loss_weights and loss_weights[lname] != 1.:
                        postfix[f'{lname}({loss_weights[lname]})'] = lvalue / (b + 1)
                    else:
                        postfix[lname] = lvalue / (b + 1)
                if postfix['loss'] == postfix['data']:
                    del postfix['data']
                pbar.set_postfix(postfix)
                optimizer.step()
                optimizer.zero_grad()

        context.update({
            'ranker': lambda: ranker,
            'ranker_path': files['weights'][f'{epoch}.p'],
            'optimizer': lambda: optimizer,
            'optimizer_path': files['optimizer'][f'{epoch}.p'],
            'loss': context['loss'] / context['batches_per_epoch'],
            'losses': {k: v / context['batches_per_epoch'] for k, v in context['losses'].items()},
            'acc': context['acc'] / context['batches_per_epoch'],
            'unsup_acc': context['unsup_acc'] / context['batches_per_epoch'],
            'cached': False,
        })

        if self._ewc:
            # EWC extraction pass: expose the trainable parameters (with accumulated gradients)
            # so the caller can derive the optimal parameters and Fisher values for this task.
            # EWC implementation adapted from
            # https://github.com/ContinualAI/colab/blob/master/notebooks/intro_to_continual_learning.ipynb
            params_for_ewc = {n: p for n, p in ranker.named_parameters() if p.requires_grad}
            yield params_for_ewc
            optimizer.step()
            optimizer.zero_grad()

        # save stuff
        ranker.save(files['weights'][f'{epoch}.p'])
        torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])
        files['loss.txt'][epoch] = context['loss']
        for lname, lvalue in context['losses'].items():
            files[f'loss_{lname}.txt'][epoch] = lvalue
        files['acc.tsv'][epoch] = context['acc']
        files['unsup_acc.tsv'][epoch] = context['unsup_acc']
        files['complete.tsv'][epoch] = 1  # mark as completed

        yield context
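
# Illustrative sketch only (not part of the original module and not called anywhere): the EWC
# regularization applied inside the batch loop above, factored out as a standalone helper.
# It assumes `fisher_params` / `optpar_params` follow the same {task: {param_name: tensor}}
# layout as `_fisher_params` / `_optpar_params`, and that `ewc_lambda` plays the role of
# `ewc_params.ewc_lambda`; the name `_ewc_penalty_sketch` is hypothetical.
def _ewc_penalty_sketch(model, fisher_params, optpar_params, ewc_lambda):
    """Return ewc_lambda * sum over tasks and parameters of F_i * (theta_i* - theta_i)^2."""
    penalty = 0.0
    for task, fishers in fisher_params.items():
        for name, param in model.named_parameters():
            if param.requires_grad and name in fishers:
                fisher = fishers[name]              # diagonal Fisher estimate for this parameter
                optpar = optpar_params[task][name]  # parameter value at the previous task's optimum
                penalty = penalty + (fisher * (optpar - param).pow(2)).sum()
    return penalty * ewc_lambda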
def _load_ranker_weights_epoch(self, ranker, vocab, trainer, dataset):
    epoch = self.config['epoch']
    base_path = util.path_model_trainer(ranker, vocab, trainer, dataset)
    weight_path = os.path.join(base_path, 'weights', f'{epoch}.p')
    self._load_ranker_weights_file_path(ranker, weight_path)
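
# `_load_ranker_weights_file_path` is defined elsewhere; the function below is a hypothetical
# sketch of what it is assumed to do, based on how checkpoints are restored in iter_train
# (ranker.load(path)). It is illustrative only, not the actual implementation.
def _load_ranker_weights_file_path_sketch(logger, ranker, weight_path):
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f'missing ranker weights: {weight_path}')
    logger.info(f'loading ranker weights from {weight_path}')
    ranker.load(weight_path)  # same call used when loading the -2.p checkpoint in iter_train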
def run(self):
    validator = self.valid_pred.pred_ctxt()

    top_epoch, top_value, top_train_ctxt, top_valid_ctxt = None, None, None, None
    prev_train_ctxt = None

    file_output = {
        'ranker': self.trainer.ranker.path_segment(),
        'vocab': self.trainer.vocab.path_segment(),
        'trainer': self.trainer.path_segment(),
        'dataset': self.trainer.dataset.path_segment(),
        'valid_ds': self.valid_pred.dataset.path_segment(),
        'validation_metric': self.config['val_metric'],
        'logfile': util.path_log(),
    }

    # initialize dataset(s)
    if not self.config['skip_ds_init']:
        self.trainer.dataset.init(force=False)
        self.valid_pred.dataset.init(force=False)
        if self.config['test']:
            self.test_pred.dataset.init(force=False)

    base_path_g = None

    # initialize EWC: load (or create) the EWCValues object shared across tasks
    base_path = util.path_model_trainer(self.trainer.ranker, self.trainer.vocab,
                                        self.trainer, self.trainer.dataset)
    ewc_path = "/".join(base_path.split("/")[:-1]) + "/ewc-" + self.config['ewc'] + ".pickle"
    task_name = self.trainer.dataset.path_segment().split("_")[0]
    try:
        with open(ewc_path, "rb") as f:
            my_ewc = pickle.load(f)
    except (OSError, IOError):
        my_ewc = EWCValues(path=ewc_path, ewc_lambda=float(self.config['ewc']))

    _train_it = self.trainer.iter_train(only_cached=self.config['only_cached'],
                                        _top_epoch=self.config.get('finetune'),
                                        ewc_params=my_ewc)

    for train_ctxt in _train_it:
        if self.config.get('onlytest'):
            base_path_g = train_ctxt['base_path']
            self.logger.debug('[catfog] skipping training')
            top_train_ctxt = train_ctxt
            break

        if prev_train_ctxt is not None and top_epoch is not None and prev_train_ctxt is not top_train_ctxt:
            self._purge_weights(prev_train_ctxt)

        if train_ctxt['epoch'] >= 0 and not self.config['only_cached']:
            message = self._build_train_msg(train_ctxt)
            if train_ctxt['cached']:
                self.logger.debug(f'[train] [cached] {message}')
            else:
                self.logger.debug(f'[train] {message}')

        if train_ctxt['epoch'] == -1 and not self.config['initial_eval']:
            continue

        valid_ctxt = dict(validator(train_ctxt))

        message = self._build_valid_msg(valid_ctxt)

        if valid_ctxt['epoch'] >= self.config['warmup']:
            if self.config['val_metric'] == '':
                top_epoch = valid_ctxt['epoch']
                top_train_ctxt = train_ctxt
                top_valid_ctxt = valid_ctxt
            elif top_value is None or valid_ctxt['metrics'][self.config['val_metric']] > top_value:
                message += ' <---'
                top_epoch = valid_ctxt['epoch']
                top_value = valid_ctxt['metrics'][self.config['val_metric']]
                if top_train_ctxt is not None:
                    self._purge_weights(top_train_ctxt)
                top_train_ctxt = train_ctxt
                top_valid_ctxt = valid_ctxt
            else:
                if prev_train_ctxt is not None:
                    self._purge_weights(prev_train_ctxt)

        if not self.config['only_cached']:
            if valid_ctxt['cached']:
                self.logger.debug(f'[valid] [cached] {message}')
            else:
                self.logger.info(f'[valid] {message}')

        if top_epoch is not None:
            epochs_since_imp = valid_ctxt['epoch'] - top_epoch
            if self.config['early_stop'] > 0 and epochs_since_imp >= self.config['early_stop']:
                self.logger.warn('stopping after epoch {epoch} ({early_stop} epochs with no '
                                 'improvement to {val_metric})'.format(**valid_ctxt, **self.config))
                break

        if train_ctxt['epoch'] >= self.config['max_epoch']:
            self.logger.warn('stopping after epoch {max_epoch} (max_epoch)'.format(**self.config))
            break

        prev_train_ctxt = train_ctxt

    if not self.config.get('onlytest'):
        self.logger.info('top validation epoch={} {}={}'.format(top_epoch, self.config['val_metric'], top_value))
        self.logger.info(f'[catfog: top_train_ctxt] {top_train_ctxt}')
        file_output.update({
            'valid_epoch': top_epoch,
            'valid_run': top_valid_ctxt['run_path'],
            'valid_metrics': top_valid_ctxt['metrics'],
        })

    # save the top training epoch for faster testing, without re-running the training phase
    if not self.config.get('onlytest'):
        # move the best checkpoint to -2.p
        self.trainer.save_best(top_epoch, top_train_ctxt['base_path'])

        # EWC: run one extra pass to recover the parameters for this task
        self.trainer.setewc()
        ewc_params = next(_train_it)

        # EWC implementation adapted from
        # https://github.com/ContinualAI/colab/blob/master/notebooks/intro_to_continual_learning.ipynb
        # gradients accumulated can be used to calculate fisher (here the squared parameter
        # values are stored); optpar holds the parameter values at this task's optimum
        for name, param in ewc_params.items():
            optpar_path = my_ewc.getOptpar(task_name, name)
            with open(optpar_path, "wb") as f:
                pickle.dump(param, f)
            fisher_path = my_ewc.getFisher(task_name, name)
            fisher = param.detach().pow(2)  # detach to avoid serializing the autograd graph
            with open(fisher_path, "wb") as f:
                pickle.dump(fisher, f)
            my_ewc.addNew(task_name, name)

        # persist the updated EWC object for subsequent tasks
        with open(ewc_path, "wb") as f:
            pickle.dump(my_ewc, f)

    if self.config.get('onlytest'):
        # for onlytest, also set finetune=true so the best epoch (-2.p) is loaded at the first iteration
        self.logger.debug('[catfog] loading top context')
        self.logger.debug(f'[catfog] Top epoch context: {dict(top_train_ctxt)}')

    if self.config['test']:
        self.logger.info('Starting load ranker')
        top_train_ctxt['ranker'] = onir.trainers.base._load_ranker(top_train_ctxt['ranker'](),
                                                                   top_train_ctxt['ranker_path'])
        self.logger.info('Starting test predictor run')
        with self.logger.duration('testing'):
            test_ctxt = self.test_pred.run(top_train_ctxt)
        file_output.update({
            'test_ds': self.test_pred.dataset.path_segment(),
            'test_run': test_ctxt['run_path'],
            'test_metrics': test_ctxt['metrics'],
        })

    with open(util.path_modelspace() + '/val_test.jsonl', 'at') as f:
        json.dump(file_output, f)
        f.write('\n')

    if not self.config.get('onlytest'):
        self.logger.info('valid run at {}'.format(valid_ctxt['run_path']))
    if self.config['test']:
        self.logger.info('test run at {}'.format(test_ctxt['run_path']))
    if not self.config.get('onlytest'):
        self.logger.info('valid ' + self._build_valid_msg(top_valid_ctxt))
    if self.config['test']:
        self.logger.info('test ' + self._build_valid_msg(test_ctxt))
        self._write_metrics_file(test_ctxt)
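
# `EWCValues` is imported from elsewhere; the class below is a sketch reconstructed from how the
# object is used in this file (path, ewc_lambda, tasks, getOptpar, getFisher, addNew). The on-disk
# layout chosen in _param_path is an assumption, and this is not the actual implementation.
class EWCValuesSketch:
    def __init__(self, path, ewc_lambda):
        self.path = path              # where this object itself is pickled
        self.ewc_lambda = ewc_lambda  # strength of the EWC penalty
        self.tasks = []               # names of tasks with stored statistics
        self._base = os.path.dirname(path)

    def _param_path(self, kind, task, name):
        # hypothetical layout: one pickle per (task, parameter) pair
        return os.path.join(self._base, f'ewc-{kind}-{task}-{name}.pickle')

    def getOptpar(self, task, name):
        return self._param_path('optpar', task, name)

    def getFisher(self, task, name):
        return self._param_path('fisher', task, name)

    def addNew(self, task, name):
        # record that statistics exist for (task, name); iter_train only iterates over tasks
        if task not in self.tasks:
            self.tasks.append(task)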
def iter_train(self, only_cached=False, _top_epoch=False):
    epoch = -1
    base_path = util.path_model_trainer(self.ranker, self.vocab, self, self.dataset)
    context = {
        'epoch': epoch,
        'batch_size': self.config['batch_size'],
        'batches_per_epoch': self.config['batches_per_epoch'],
        'num_microbatches': 1 if self.config['grad_acc_batch'] == 0 else self.config['batch_size'] // self.config['grad_acc_batch'],
        'device': self.device,
        'base_path': base_path,
    }

    files = trainers.misc.PathManager(base_path)
    self.logger.info(f'train path: {base_path}, batches_per_epoch: {context["batches_per_epoch"]}')
    b_count = context['batches_per_epoch'] * context['num_microbatches'] * self.batch_size

    ranker = self.ranker.to(self.device)
    optimizer = self.create_optimizer()

    if f'{epoch}.p' not in files['weights']:
        ranker.save(files['weights'][f'{epoch}.p'])
    if f'{epoch}.p' not in files['optimizer']:
        torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])

    context.update({
        'ranker': lambda: ranker,
        'ranker_path': files['weights'][f'{epoch}.p'],
        'optimizer': lambda: optimizer,
        'optimizer_path': files['optimizer'][f'{epoch}.p'],
    })

    # load ranker/optimizer
    if _top_epoch:
        # e.g. trainer.pipeline=msmarco_train_bm25_k1-0.82_b-0.68.100_mspairs
        __path = "/".join(base_path.split("/")[:-1]) + "/" + self.config['pipeline']
        self.logger.info(f'loading prev model: {__path}')
        w_path = os.path.join(__path, 'weights', '-2.p')
        oppt_path = os.path.join(__path, 'optimizer', '-2.p')
        ranker.load(w_path)
        optimizer.load_state_dict(torch.load(oppt_path))
        if w_path != files['weights']['-2.p']:
            copyfile(w_path, files['weights']['-2.p'])
            copyfile(oppt_path, files['optimizer']['-2.p'])
        context.update({
            'ranker': lambda: ranker,
            'ranker_path': files['weights']['-2.p'],
            'optimizer': lambda: optimizer,
            'optimizer_path': files['optimizer']['-2.p'],
        })

    yield context  # before training

    while True:
        context = dict(context)
        epoch = context['epoch'] = context['epoch'] + 1

        if epoch in files['complete.tsv']:
            context.update({
                'loss': files['loss.txt'][epoch],
                'data_loss': files['data_loss.txt'][epoch],
                'losses': {},
                'acc': files['acc.tsv'][epoch],
                'unsup_acc': files['unsup_acc.tsv'][epoch],
                'ranker': _load_ranker(ranker, files['weights'][f'{epoch}.p']),
                'ranker_path': files['weights'][f'{epoch}.p'],
                'optimizer': _load_optimizer(optimizer, files['optimizer'][f'{epoch}.p']),
                'optimizer_path': files['optimizer'][f'{epoch}.p'],
                'cached': True,
            })
            if not only_cached:
                self.fast_forward(b_count)  # skip this epoch
            yield context
            continue

        if only_cached:
            break  # no more cached

        # forward to previous versions (if needed)
        ranker = context['ranker']()
        optimizer = context['optimizer']()

        ranker.train()
        context.update({
            'loss': 0.0,
            'losses': {},
            'acc': 0.0,
            'unsup_acc': 0.0,
        })

        with tqdm(leave=False, total=b_count, ncols=100, desc=f'train {epoch}') as pbar:
            for b in range(context['batches_per_epoch']):
                for _ in range(context['num_microbatches']):
                    self.epoch = epoch
                    train_batch_result = self.train_batch()
                    losses = train_batch_result['losses']
                    loss_weights = train_batch_result['loss_weights']
                    acc = train_batch_result.get('acc')
                    unsup_acc = train_batch_result.get('unsup_acc')

                    loss = sum(losses[k] * loss_weights.get(k, 1.) for k in losses) / context['num_microbatches']
                    context['loss'] += loss.item()
                    for lname, lvalue in losses.items():
                        context['losses'].setdefault(lname, 0.)
                        context['losses'][lname] += lvalue.item() / context['num_microbatches']
                    if acc is not None:
                        context['acc'] += acc.item() / context['num_microbatches']
                    if unsup_acc is not None:
                        context['unsup_acc'] += unsup_acc.item() / context['num_microbatches']
                    if loss.grad_fn is not None:
                        if hasattr(optimizer, 'backward'):
                            optimizer.backward(loss)
                        else:
                            loss.backward()
                    else:
                        self.logger.warn('loss has no grad_fn; skipping batch')
                    pbar.update(self.batch_size)

                postfix = {
                    'loss': context['loss'] / (b + 1),
                }
                for lname, lvalue in context['losses'].items():
                    if lname in loss_weights and loss_weights[lname] != 1.:
                        postfix[f'{lname}({loss_weights[lname]})'] = lvalue / (b + 1)
                    else:
                        postfix[lname] = lvalue / (b + 1)
                if postfix['loss'] == postfix['data']:
                    del postfix['data']
                pbar.set_postfix(postfix)
                optimizer.step()
                optimizer.zero_grad()

        context.update({
            'ranker': lambda: ranker,
            'ranker_path': files['weights'][f'{epoch}.p'],
            'optimizer': lambda: optimizer,
            'optimizer_path': files['optimizer'][f'{epoch}.p'],
            'loss': context['loss'] / context['batches_per_epoch'],
            'losses': {k: v / context['batches_per_epoch'] for k, v in context['losses'].items()},
            'acc': context['acc'] / context['batches_per_epoch'],
            'unsup_acc': context['unsup_acc'] / context['batches_per_epoch'],
            'cached': False,
        })

        # save stuff
        ranker.save(files['weights'][f'{epoch}.p'])
        torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])
        files['loss.txt'][epoch] = context['loss']
        for lname, lvalue in context['losses'].items():
            files[f'loss_{lname}.txt'][epoch] = lvalue
        files['acc.tsv'][epoch] = context['acc']
        files['unsup_acc.tsv'][epoch] = context['unsup_acc']
        files['complete.tsv'][epoch] = 1  # mark as completed

        yield context