def save_checkpoint_bleu(args, trainer, epoch_itr, val_loss, val_bleu):
    """Save checkpoints for the current training state, tracking the best BLEU.

    Returns immediately when ``args.no_save`` is set or when
    ``distributed_utils.is_master(args)`` is false.

    Files written under ``args.save_dir`` (each only when its condition holds):

    * ``checkpoint{epoch}.pt`` -- at the end of an epoch, unless
      ``args.no_epoch_checkpoints`` is set, when ``epoch`` is a multiple of
      ``args.save_interval``.
    * ``checkpoint_{epoch}_{updates}.pt`` -- mid-epoch, when
      ``args.save_interval_updates`` is positive and divides ``updates``.
    * ``checkpoint_best_bleu.pt`` -- when ``val_bleu`` is not None and is
      greater than the best value recorded so far (or none is recorded yet).
    * ``checkpoint_last.pt`` -- always.

    The running best is stored on the function object itself as
    ``save_checkpoint_bleu.best_bleu`` and, once set, is included in the
    saved extra state under the key ``'best_bleu'``.

    After saving, old update-interval checkpoints (mid-epoch only, when
    ``args.keep_interval_updates > 0``) and old epoch checkpoints (when
    ``args.keep_last_epochs > 0``) are deleted.

    Args:
        args: namespace with the save/keep options read above.
        trainer: provides ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: validation loss stored in the extra state (may be None).
        val_bleu: validation BLEU used for best-model tracking (may be None).
    """
    if args.no_save:
        return
    if not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    at_epoch_end = epoch_itr.end_of_epoch()
    num_updates = trainer.get_num_updates()

    # Insertion order is the order in which the files are written.
    conds = collections.OrderedDict([
        (
            'checkpoint{}.pt'.format(epoch),
            at_epoch_end
            and not args.no_epoch_checkpoints
            and epoch % args.save_interval == 0,
        ),
        (
            'checkpoint_{}_{}.pt'.format(epoch, num_updates),
            not at_epoch_end
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0,
        ),
        (
            'checkpoint_best_bleu.pt',
            val_bleu is not None
            and (
                not hasattr(save_checkpoint_bleu, 'best_bleu')
                or val_bleu > save_checkpoint_bleu.best_bleu
            ),
        ),
        # always written; kept as the final entry
        ('checkpoint_last.pt', True),
    ])

    # Update the running best only after the "new best" test above has been
    # evaluated against the previous value.
    if val_bleu is not None:
        save_checkpoint_bleu.best_bleu = max(
            val_bleu, getattr(save_checkpoint_bleu, 'best_bleu', val_bleu))

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
        'val_bleu': val_bleu,
    }
    if hasattr(save_checkpoint_bleu, 'best_bleu'):
        extra_state['best_bleu'] = save_checkpoint_bleu.best_bleu

    for name, wanted in conds.items():
        if wanted:
            trainer.save_checkpoint(
                os.path.join(args.save_dir, name), extra_state)

    def _prune(pattern, keep):
        # checkpoints are sorted in descending order, so everything after
        # the first `keep` entries is deleted
        listed = utils.checkpoint_paths(args.save_dir, pattern=pattern)
        for stale in listed[keep:]:
            if os.path.lexists(stale):
                os.remove(stale)

    if not at_epoch_end and args.keep_interval_updates > 0:
        # remove old update-interval checkpoints
        _prune(r'checkpoint_\d+_(\d+)\.pt', args.keep_interval_updates)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints
        _prune(r'checkpoint(\d+)\.pt', args.keep_last_epochs)
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save checkpoints for the current training state and report the time taken.

    Returns immediately when ``args.no_save`` is set or when
    ``distributed_utils.is_master(args)`` is false.

    Files written under ``args.save_dir`` (each only when its condition holds):

    * ``checkpoint{epoch}.pt`` -- at the end of an epoch, unless
      ``args.no_epoch_checkpoints`` is set, when ``epoch`` is a multiple of
      ``args.save_interval``.
    * ``checkpoint_{epoch}_{updates}.pt`` -- mid-epoch, when
      ``args.save_interval_updates`` is positive and divides ``updates``.
    * ``checkpoint_best.pt`` -- when ``val_loss`` is not None and is lower
      than the best value recorded so far (or none is recorded yet).
    * ``checkpoint_last.pt`` -- always.

    The running best loss is stored on the function object as
    ``save_checkpoint.best`` and, once set, is included in the saved extra
    state under the key ``'best'``.

    After saving, old update-interval checkpoints (mid-epoch only, when
    ``args.keep_interval_updates > 0``) and old epoch checkpoints (when
    ``args.keep_last_epochs > 0``) are deleted. Finally a one-line summary
    naming the first file written by this call is printed.

    Args:
        args: namespace with the save/keep options read above.
        trainer: provides ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: validation loss used for best-model tracking (may be None).
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Insertion order is the order in which the files are written.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # always written, listed last

    # Update the running best only after the "new best" test above has been
    # evaluated against the previous value.
    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    # Never empty: 'checkpoint_last.pt' is unconditionally selected.
    saved_paths = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in saved_paths:
        trainer.save_checkpoint(cp, extra_state)

    # The cleanup listings below get their own names so that they do not
    # clobber `saved_paths`, which the final report still needs.
    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        interval_chks = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in interval_chks[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        epoch_chks = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt')
        for old_chk in epoch_chks[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    write_timer.stop()
    print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
        saved_paths[0], epoch, updates, write_timer.sum))
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save checkpoints for the current training state.

    Returns immediately when ``args.no_save`` is set or when
    ``distributed_utils.is_master(args)`` is false.

    Files written under ``args.save_dir`` (each only when its condition holds):

    * ``checkpoint{epoch}.pt`` -- at the end of an epoch, unless
      ``args.no_epoch_checkpoints`` is set, when ``epoch`` is a multiple of
      ``args.save_interval``.
    * ``checkpoint_{epoch}_{updates}.pt`` -- mid-epoch, when
      ``args.save_interval_updates`` is positive and divides ``updates``.
    * ``checkpoint_best.pt`` -- when ``val_loss`` is not None and is lower
      than the best value recorded so far (or none is recorded yet).
    * ``checkpoint_last.pt`` -- always.

    Anything already present at a target path is removed before the new
    file is written. The running best loss is stored on the function object
    as ``save_checkpoint.best`` and, once set, is included in the saved
    extra state under the key ``'best'``.

    After saving mid-epoch with ``args.keep_interval_updates > 0``, old
    update-interval checkpoints are deleted.

    Args:
        args: namespace with the save/keep options read above.
        trainer: provides ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: validation loss used for best-model tracking (may be None).
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Insertion order is the order in which the files are written.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0)
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best))
    checkpoint_conds['checkpoint_last.pt'] = True  # always written, listed last

    # Update the running best only after the "new best" test above has been
    # evaluated against the previous value.
    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    # `best` is unset until the first call with a non-None val_loss; reading
    # it unconditionally would raise AttributeError in that case.
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]

    # Clear every target first. lexists (rather than exists) also catches a
    # dangling symlink, which would otherwise be written through.
    for fn in checkpoints:
        if os.path.lexists(fn):
            os.remove(fn)

    # Each selected path gets its own full copy of the checkpoint.
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        old_checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in old_checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save checkpoints for the current training state.

    Returns immediately when ``args.no_save`` is set or when
    ``distributed_utils.is_master(args)`` is false.

    Files written under ``args.save_dir`` (each only when its condition holds):

    * ``checkpoint{epoch}.pt`` -- at the end of an epoch, unless
      ``args.no_epoch_checkpoints`` is set, when ``epoch`` is a multiple of
      ``args.save_interval``.
    * ``checkpoint_{epoch}_{updates}.pt`` -- mid-epoch, when
      ``args.save_interval_updates`` is positive and divides ``updates``.
    * ``checkpoint_best.pt`` -- when ``val_loss`` is not None and is lower
      than the best value recorded so far (or none is recorded yet).
    * ``checkpoint_last.pt`` -- always.

    The running best loss is stored on the function object as
    ``save_checkpoint.best`` and, once set, is included in the saved extra
    state under the key ``'best'``.

    After saving mid-epoch with ``args.keep_interval_updates > 0``, old
    update-interval checkpoints are deleted.

    Args:
        args: namespace with the save/keep options read above.
        trainer: provides ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: validation loss used for best-model tracking (may be None).
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Insertion order is the order in which the files are written.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # always written, listed last

    # Update the running best only after the "new best" test above has been
    # evaluated against the previous value.
    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    # `best` is unset until the first call with a non-None val_loss; reading
    # it unconditionally would raise AttributeError in that case.
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        old_checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in old_checkpoints[args.keep_interval_updates:]:
            # skip paths that have already disappeared instead of crashing
            if os.path.lexists(old_chk):
                os.remove(old_chk)