def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        set_seed.set_torch_seed(args['dataset']['fixed_validation_seed'])

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir']
                                if distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args['checkpoint']['best_checkpoint_metric']])
    return valid_losses
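# `get_valid_stats` is called in `validate()` above but is not part of this excerpt.
# A minimal sketch of what it is assumed to do (fairseq-style): attach the current
# update count and track the best value of the checkpoint metric seen so far.
# The reference to `checkpoint_utils.save_checkpoint.best` is an assumption, not
# the project's confirmed implementation.
def get_valid_stats(args, trainer, stats):
    stats['num_updates'] = trainer.get_num_updates()
    metric = args['checkpoint']['best_checkpoint_metric']
    if hasattr(checkpoint_utils.save_checkpoint, 'best') and metric in stats:
        best_function = max if args['checkpoint']['maximize_best_checkpoint_metric'] else min
        stats['best_{}'.format(metric)] = best_function(
            checkpoint_utils.save_checkpoint.best, stats[metric])
    return stats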
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if distributed_utils.is_master(self.args):  # only save one checkpoint
        extra_state["metrics"] = metrics.state_dict()
        checkpoint_utils.save_state(
            filename,
            self.args,
            self.get_model().state_dict(),
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args['distributed_training']['fix_batches_to_gpus'],
        # shuffle=(epoch_itr.next_epoch_idx > args['dataset']['curriculum']),
        shuffle=False,
    )
    update_freq = (args['optimization']['update_freq'][epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args['optimization']['update_freq'])
                   else args['optimization']['update_freq'][-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args['common']['tensorboard_logdir']
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args['dataset']['valid_subset'].split(',')
    max_update = args['optimization']['max_update'] or math.inf
    num_updates = 0  # init as 0, for zero-shot learning
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args['common']['log_interval'] == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset epoch-level meters
            metrics.reset_meters('train_inner')

        if (not args['dataset']['disable_validation']
                and args['checkpoint']['save_interval_updates'] > 0
                and num_updates % args['checkpoint']['save_interval_updates'] == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
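# `get_training_stats` is referenced in `train()` above but not included in this
# excerpt. A minimal, assumed version that simply augments the smoothed metric
# values with elapsed wall-clock time; the real helper may compute more (e.g.
# perplexity), and `_TRAIN_START` is a hypothetical module-level timestamp.
import time

_TRAIN_START = time.time()

def get_training_stats(stats):
    stats['wall'] = round(time.time() - _TRAIN_START, 0)
    return stats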
def single_main(args, init_distributed=False):
    assert args['dataset']['max_tokens'] is not None or args['dataset']['max_sentences'] is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    set_seed.set_seed(args['common']['seed'])
    if init_distributed:
        args['distributed_training']['distributed_rank'] = distributed_utils.distributed_init(args)

    # Verify checkpoint directory
    if distributed_utils.is_master(args):
        save_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(save_dir)
        PathManager.rm(os.path.join(save_dir, '*.pt'))  # this code will remove pre-trained models

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (we load training data below, based on the latest checkpoint)
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    LOGGER.info('model {}, criterion {}'.format(args['model']['arch'],
                                                criterion.__class__.__name__))
    LOGGER.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    LOGGER.info('training on {} GPUs'.format(
        args['distributed_training']['distributed_world_size']))
    LOGGER.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args['dataset']['max_tokens'],
        args['dataset']['max_sentences'],
    ))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer, combine=False)

    # 6. Train until the learning rate gets too small
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    while (lr > args['optimization']['min_lr']
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args['dataset']['disable_validation'] \
                and epoch_itr.epoch % args['dataset']['validate_interval'] == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            LOGGER.info('early stop since valid performance hasn\'t improved for last {} runs'
                        .format(args['checkpoint']['patience']))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            combine=False,  # TODO to be checked
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in args['task']['data']),
        )
    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))
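# `should_stop_early` is used for early stopping in `single_main()` above but is
# not part of this excerpt. A minimal sketch, assuming patience-based stopping on
# the first validation loss (mirroring the fairseq helper this code is derived
# from); the attributes stashed on the function object are assumptions.
def should_stop_early(args, valid_loss):
    if valid_loss is None:
        return False
    if args['checkpoint']['patience'] <= 0:
        return False

    def is_better(a, b):
        return a > b if args['checkpoint']['maximize_best_checkpoint_metric'] else a < b

    prev_best = getattr(should_stop_early, 'best', None)
    if prev_best is None or is_better(valid_loss, prev_best):
        # new best value: reset the counter of non-improving validations
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        return should_stop_early.num_runs >= args['checkpoint']['patience']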
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    from ncc import meters
    from ncc.utils import distributed_utils

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args['checkpoint']['maximize_best_checkpoint_metric'] else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args['checkpoint']['no_save'] or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args['checkpoint']['maximize_best_checkpoint_metric'] else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args['checkpoint']['no_epoch_checkpoints']
        and epoch % args['checkpoint']['save_interval'] == 0)
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch and args['checkpoint']['save_interval_updates'] > 0
        and updates % args['checkpoint']['save_interval_updates'] == 0)
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best))
    if val_loss is not None and args['checkpoint']['keep_best_checkpoints'] > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args['checkpoint']['best_checkpoint_metric'], val_loss)] = (
                not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best))
    checkpoint_conds["checkpoint_last.pt"] = not args['checkpoint']['no_last_checkpoints']

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args['checkpoint']['save_dir'], fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp)

        write_timer.stop()
        LOGGER.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {:.6f} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and args['checkpoint']['keep_interval_updates'] > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args['checkpoint']['save_dir'],
                                       pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[args['checkpoint']['keep_interval_updates']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args['checkpoint']['keep_last_epochs'] > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args['checkpoint']['save_dir'],
                                       pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args['checkpoint']['keep_last_epochs']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args['checkpoint']['keep_best_checkpoints'] > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args['checkpoint']['save_dir'],
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args['checkpoint']['best_checkpoint_metric']))
        if not args['checkpoint']['maximize_best_checkpoint_metric']:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args['checkpoint']['keep_best_checkpoints']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
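# `checkpoint_paths` (used above to prune old checkpoints) is assumed to behave
# like the fairseq helper of the same name: list files under `path` whose names
# match `pattern` and return their full paths sorted in descending order of the
# first captured numeric group, so that slicing `[keep_n:]` leaves the newest or
# best checkpoints in place. This is a sketch of that assumed behaviour, not the
# project's own code.
import os
import re

def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            # sort key: the numeric value captured by the pattern (0 if none)
            idx = float(m.group(1)) if m.groups() else 0.0
            entries.append((idx, fname))
    return [os.path.join(path, fname) for _, fname in sorted(entries, reverse=True)]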
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        set_seed.set_torch_seed(args['dataset']['fixed_validation_seed'])

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir']
                                if distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        accs, mrrs, maps, ndcgs = [], [], [], []
        trainer.model.eval()
        trainer.criterion.eval()
        with torch.no_grad():
            for sample in progress:
                sample = trainer._prepare_sample(sample)
                inputs = list(sample['net_input'].values())
                code_repr = trainer.model.code_forward(*inputs[:6])
                desc_repr = trainer.model.desc_forward(*inputs[6:8])
                code_repr = code_repr / code_repr.norm(dim=-1, keepdim=True)
                desc_repr = desc_repr / desc_repr.norm(dim=-1, keepdim=True)
                similarity = code_repr @ desc_repr.t()
                acc, mrr, map, ndcg = inference(similarity)
                accs.append(acc.mean().item())
                mrrs.append(mrr.mean().item())
                maps.append(map.mean().item())
                ndcgs.append(ndcg.mean().item())

        accs = round(float(np.mean(accs)), 6)
        mrrs = round(float(np.mean(mrrs)), 6)
        maps = round(float(np.mean(maps)), 6)
        ndcgs = round(float(np.mean(ndcgs)), 6)
        stats = {'acc': accs, 'mrr': mrrs, 'map': maps, 'ndcg': ndcgs}
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args['checkpoint']['best_checkpoint_metric']])
    return valid_losses
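# `inference` (used in the retrieval `validate` above) is not defined in this
# excerpt. Assuming the ground-truth description for code snippet i is column i
# of `similarity`, a minimal per-example implementation of the four reported
# metrics could look like this; with a single relevant item per query, AP reduces
# to the reciprocal rank and NDCG to 1/log2(rank + 1).
import torch

def inference(similarity):
    # similarity: (N, N) tensor, entry (i, j) = score of code i vs. description j
    n = similarity.size(0)
    diag = torch.arange(n, device=similarity.device)
    gold_scores = similarity[diag, diag].unsqueeze(-1)
    # 1-based rank of the gold description for every code snippet
    ranks = (similarity >= gold_scores).sum(dim=-1).float()
    acc = (ranks == 1).float()            # top-1 accuracy
    mrr = 1.0 / ranks                     # reciprocal rank
    map_score = 1.0 / ranks               # AP with a single relevant item
    ndcg = 1.0 / torch.log2(ranks + 1.0)  # binary-relevance NDCG
    return acc, mrr, map_score, ndcg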
def save_expert_outputs(args, task, trainer):
    print("| Start saving expert outputs..")
    expert_outputs = gen_outputs(args, task, trainer)
    output_path = os.path.join(
        args['checkpoint']['save_dir'],
        'train_output.json.{}'.format(args['distributed_training']['distributed_rank']))
    print('Save topk output at {}'.format(output_path))
    json.dump(expert_outputs, open(output_path, 'w'))
    # distributed_utils.barrier(args, 'save_expert_outputs')
    if distributed_utils.is_master(args):
        expert_outputs_ = []
        # copy valid bleu result
        val_bleu_path1 = os.path.join(args['checkpoint']['save_dir'], 'val_bleu.json')
        val_bleu_path2 = os.path.join(
            args['task']['data'],
            'expert_bleu_{}_{}_{}.json'.format(
                '_'.join(args['task']['programming_langs']),
                args['task']['source_lang'], args['task']['target_lang']))
        cmd = 'cp {} {}'.format(val_bleu_path1, val_bleu_path2)
        print(cmd)
        os.system(cmd)

        for i in range(args['distributed_training']['distributed_world_size']):
            output_path = os.path.join(args['checkpoint']['save_dir'],
                                       'train_output.json.{}'.format(i))
            expert_outputs_.append(json.load(open(output_path, 'r')))
            try:
                os.remove(output_path)
            except:
                pass

        for j in range(len(expert_outputs_[0])):
            for i in range(args['distributed_training']['distributed_world_size']):
                if expert_outputs_[i][j] is not None:
                    expert_outputs[j] = expert_outputs_[i][j]
                    break
            assert expert_outputs[j] is not None

        path = os.path.join(
            args['task']['data'],
            '{}_{}_{}_topk_idx'.format(
                '_'.join(args['task']['programming_langs']),
                args['task']['source_lang'], args['task']['target_lang']))
        TeacherOutputDataset.save_bin(path, [o[0] for o in expert_outputs], np.int32)
        path = os.path.join(
            args['task']['data'],
            '{}_{}_{}_topk_prob'.format(
                '_'.join(args['task']['programming_langs']),
                args['task']['source_lang'], args['task']['target_lang']))
        TeacherOutputDataset.save_bin(path, [o[1] for o in expert_outputs], np.float)
        LOGGER.info(
            "| Save expert@{}_{}_{}. Bleu.Json: {}, TopK.Idx/Prob: {}.".format(
                '_'.join(args['task']['programming_langs']),
                args['task']['source_lang'],
                args['task']['target_lang'],
                val_bleu_path2,
                path,
            ))
def is_data_parallel_master(self):
    return distributed_utils.is_master(self.args)
def validate(args, trainer, task, epoch_itr, valid_subsets, dev_subsets, dev_refs):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args['dataset']['fixed_validation_seed'])

    for subset in valid_subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir']
                                if distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        # calculate accuracy
        match = stats.pop('match')
        total = stats.pop('total')
        valid_acc = match / total
        progress.print(
            {
                'accuracy': f'{round(100. * valid_acc, 2)}%',
                'bleu': stats['bleu'],
                'loss': stats['loss'],
            },
            tag=subset, step=trainer.get_num_updates())

    # for subset in dev_subsets:
    #     hypotheses, references = {}, dev_refs
    #
    #     # Initialize data iterator
    #     itr = task.get_batch_iterator(
    #         dataset=task.dataset(subset),
    #         max_tokens=args['dataset']['max_tokens_valid'],
    #         max_sentences=args['dataset']['max_sentences_valid'],
    #         max_positions=utils.resolve_max_positions(
    #             task.max_positions(),
    #             trainer.get_model().max_positions(),
    #         ),
    #         ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
    #         required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
    #         seed=args['common']['seed'],
    #         num_shards=args['distributed_training']['distributed_world_size'],
    #         shard_id=args['distributed_training']['distributed_rank'],
    #         num_workers=args['dataset']['num_workers'],
    #     ).next_epoch_itr(shuffle=False)
    #     progress = progress_bar.progress_bar(
    #         itr,
    #         log_format=args['common']['log_format'],
    #         log_interval=args['common']['log_interval'],
    #         epoch=epoch_itr.epoch,
    #         prefix=f"valid on '{subset}' subset",
    #         tensorboard_logdir=(
    #             args['common']['tensorboard_logdir'] if distributed_utils.is_master(args) else None
    #         ),
    #         default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
    #     )
    #
    #     # create a new root metrics aggregator so validation metrics
    #     # don't pollute other aggregators (e.g., train meters)
    #     with metrics.aggregate(new_root=True) as agg:
    #         for sample in progress:
    #             with torch.no_grad():
    #                 trainer.model.eval()
    #                 trainer.criterion.eval()
    #                 sample = trainer._prepare_sample(sample)
    #                 hyps, _, _, ids = trainer.task.step_out(sample, trainer.model)
    #                 for idx, hypo in zip(ids, hyps):
    #                     hypotheses[idx] = hypo
    #
    #     from third_party.pycocoevalcap.bleu.google_bleu import compute_bleu
    #     assert set(hypotheses.keys()) == set(references.keys())
    #     bleus = [
    #         compute_bleu([references[idx]], [hypotheses[idx]], smooth=True)[0]
    #         for idx in hypotheses.keys()
    #     ]
    #     dev_bleu = round(100. * sum(bleus) / len(bleus), 2)
    #     # log validation stats
    #     stats = agg.get_smoothed_values()
    #     stats['bleu'] = dev_bleu
    #     stats = get_dev_stats(args, trainer, stats)
    #     progress.print(stats, tag=subset, step=trainer.get_num_updates())
    # return valid_acc, dev_bleu

    return valid_acc, None
def single_main(args, init_distributed=False):
    assert args['dataset']['max_tokens'] is not None or args['dataset']['max_sentences'] is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    random.seed(args['common']['seed'])
    np.random.seed(args['common']['seed'])
    torch.manual_seed(args['common']['seed'])
    torch.cuda.manual_seed(args['common']['seed'])
    if init_distributed:
        args['distributed_training']['distributed_rank'] = distributed_utils.distributed_init(args)

    # Verify checkpoint directory
    if distributed_utils.is_master(args):
        save_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(save_dir)
        remove_files(save_dir, 'pt')  # this code will remove pre-trained models

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (we load training data below, based on the latest checkpoint)
    # calculate accuracy for decay learning rate
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)
    # # compute meteor to select model
    # task.load_dataset(args['dataset']['dev_subset'], combine=False, epoch=1)
    # # load dev/ref.txt
    # dev_refs = load_refs(os.path.join(args['task']['data'], args['dataset']['dev_ref_subset']))

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    LOGGER.info('model {}, criterion {}'.format(args['model']['arch'],
                                                criterion.__class__.__name__))
    LOGGER.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    LOGGER.info('training on {} GPUs'.format(
        args['distributed_training']['distributed_world_size']))
    LOGGER.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args['dataset']['max_tokens'],
        args['dataset']['max_sentences'],
    ))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer, combine=False)

    # 6. Train until the learning rate gets too small
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    dev_subsets = args['dataset']['dev_subset'].split(',')
    valid_accs_after_60e = []
    while (lr > args['optimization']['min_lr']
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args['dataset']['disable_validation'] \
                and epoch_itr.epoch % args['dataset']['validate_interval'] == 0:
            valid_acc, dev_prf = validate(args, trainer, task, epoch_itr,
                                          valid_subsets, dev_subsets, dev_refs=None)
        else:
            valid_acc, dev_prf = None, None

        # if epoch_itr.next_epoch_idx > 61 and valid_acc < valid_accs_after_60e[-1]:
        #     """
        #     We start with a learning rate of 0.5 and start decaying it by a factor
        #     of 0.8 after 60 epochs if accuracy on the validation set goes down, and
        #     terminate training when the learning rate goes below 0.001.
        #     """
        #     lr = trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])
        #
        # if epoch_itr.epoch >= 60:
        #     valid_accs_after_60e.append(valid_acc)
        #     if len(valid_accs_after_60e) > 10 and valid_accs_after_60e[-5] >= valid_acc:
        #         lr = trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])
        #
        valid_accs_after_60e.append(valid_acc)
        if len(valid_accs_after_60e) > 10 and valid_accs_after_60e[-5] >= valid_acc:
            lr = trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])

        # eval on dev and dev.ref data

        # save checkpoint
        if epoch_itr.epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_acc)

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            combine=False,  # TODO to be checked
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in args['task']['data']),
        )
    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))