def save_checkpoint(args, trainer, epoch_itr, val_loss, save_extra_state):
    from fairseq import distributed_utils, meters

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and args.keep_best_checkpoints > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric, val_loss)] = (
            not hasattr(save_checkpoint, "best")
            or is_better(val_loss, save_checkpoint.best)
        )
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})
    extra_state.update(save_extra_state)

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items()
        if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args.best_checkpoint_metric
            ),
        )
        if not args.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
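# The cleanup loops above rely on a module-level helper, checkpoint_paths(path, pattern),
# that returns the matching checkpoint files sorted in descending order of the pattern's
# numeric capture group (the function also assumes module-level os, collections,
# PathManager, and logger). A minimal sketch of such a helper, not necessarily the exact
# fairseq implementation, could look like this:
import os
import re


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
    """Return checkpoints in `path` matching `pattern`, largest capture group first."""
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            # sort key is the first capture group, interpreted numerically
            entries.append((float(m.group(1)), os.path.join(path, fname)))
    return [filename for _, filename in sorted(entries, reverse=True)]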
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    # keep this last so that it is copied from the first saved checkpoint below
    checkpoint_conds['checkpoint_last.pt'] = True

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
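# A hypothetical sketch of how the 4-argument save_checkpoint above might be driven from
# a training loop. The `args`, `trainer`, `epoch_itr`, and `validate` objects are
# stand-ins for fairseq's real ones, and the names not used above (max_epoch,
# next_epoch_itr, train_step) are assumptions made only for illustration.
def train_and_checkpoint(args, trainer, epoch_itr, validate):
    while epoch_itr.epoch < args.max_epoch:
        for samples in epoch_itr.next_epoch_itr():
            trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            # mid-epoch saves, gated by --save-interval-updates
            if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
                val_loss = validate(args, trainer, epoch_itr)
                save_checkpoint(args, trainer, epoch_itr, val_loss)
        # end-of-epoch save, gated by --save-interval
        if epoch_itr.epoch % args.save_interval == 0:
            val_loss = validate(args, trainer, epoch_itr)
            save_checkpoint(args, trainer, epoch_itr, val_loss)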
def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
                                'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                                'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
                             'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
                             'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged; '
                             'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged assuming --save-interval-updates 500.')
    parser.add_argument('--all-loss-checkpoints', action='store_true',
                        help='if set, average all checkpoint.best_loss_*.pt checkpoints found in --dir.')
    parser.add_argument('--dir',
                        help='directory to search for checkpoint.best_loss_*.pt files when using --all-loss-checkpoints.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (
        args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None
    ), '--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints'
    assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
        'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    if args.all_loss_checkpoints:
        args.inputs = checkpoint_paths(
            args.dir, pattern=r"checkpoint\.best_loss_(\d+\.?\d*)\.pt")

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, 'wb') as f:
        torch.save(new_state, f)
    print('Finished writing averaged checkpoint to {}'.format(args.output))
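# The heavy lifting above is done by average_checkpoints(paths): conceptually it loads
# each checkpoint and takes the element-wise mean of its 'model' parameters. A simplified
# sketch of that idea (not the exact fairseq implementation, which also carries over the
# remaining checkpoint metadata) might be:
import collections

import torch


def average_model_params(paths):
    """Return an OrderedDict whose tensors are the mean of the 'model' params in `paths`."""
    averaged = collections.OrderedDict()
    for path in paths:
        state = torch.load(path, map_location='cpu')
        for name, param in state['model'].items():
            param = param.float()
            if name not in averaged:
                averaged[name] = param.clone()
            else:
                averaged[name] += param
    for name in averaged:
        averaged[name].div_(len(paths))
    return averaged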