Example #1
import collections
import logging
import os

from fairseq.checkpoint_utils import checkpoint_paths
from fairseq.file_io import PathManager

logger = logging.getLogger(__name__)


def save_checkpoint(args, trainer, epoch_itr, val_loss, save_extra_state):
    from fairseq import distributed_utils, meters

    # track the best validation score seen so far across calls to save_checkpoint
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # map each candidate checkpoint filename to the condition under which it should be written
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and args.keep_best_checkpoints > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric, val_loss)] = (
            not hasattr(save_checkpoint, "best")
            or is_better(val_loss, save_checkpoint.best)
        )
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})
    extra_state.update(save_extra_state)

    checkpoints = [
        os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # write the first checkpoint to disk once, then copy it to the other selected filenames
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[args.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
        if not args.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
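
A minimal calling sketch for Example #1, in the spirit of a fairseq training loop; valid_losses and the extra-state payload below are illustrative, not part of the original code:

# after validation, pass the metric value and any extra state to persist
valid_losses = [2.31]  # hypothetical value of args.best_checkpoint_metric on the valid set
save_checkpoint(
    args, trainer, epoch_itr, valid_losses[0],
    save_extra_state={"my_extra_state": trainer.get_num_updates()},  # hypothetical payload
)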
Example #2
import argparse
import collections
import os
import shutil

import torch

from fairseq import checkpoint_utils, distributed_utils
from fairseq.checkpoint_utils import checkpoint_paths
from fairseq.file_io import PathManager
from fairseq.meters import StopwatchMeter


def save_checkpoint(args, trainer, epoch_itr, val_loss):
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(
        epoch, updates)] = (not end_of_epoch and args.save_interval_updates > 0
                            and updates % args.save_interval_updates == 0)
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and (not hasattr(save_checkpoint, 'best')
                                  or val_loss < save_checkpoint.best))
    checkpoint_conds[
        'checkpoint_last.pt'] = True  # keep this last; it is refreshed on every save as a copy of the primary checkpoint

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print(
            '| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'
            .format(checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir,
            pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir,
            pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)


# average_checkpoints() and last_n_checkpoints() are assumed to be defined in the
# same script as this main() (as in fairseq's scripts/average_checkpoints.py).
def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
        'produce a new checkpoint', )
    # fmt: off
    parser.add_argument('--inputs',
                        nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help=
        'Write the new checkpoint containing the averaged weights to this path.'
    )
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument(
        '--num-epoch-checkpoints',
        type=int,
        help=
        'if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
        'and average last this many of them.')
    num_group.add_argument(
        '--num-update-checkpoints',
        type=int,
        help=
        'if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
        'and average last this many of them.')
    parser.add_argument(
        '--checkpoint-upper-bound',
        type=int,
        help=
        'when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
        'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged; '
        'with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged, assuming --save-interval-updates 500'
    )
    parser.add_argument('--all-loss-checkpoints',
                        action='store_true',
                        help='if set, average all checkpoint.best_loss_*.pt files found in --dir.')
    parser.add_argument('--dir',
                        help='directory searched for checkpoint.best_loss_*.pt files when --all-loss-checkpoints is set.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None), \
        '--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints'
    assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
        'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    if args.all_loss_checkpoints:
        args.inputs = checkpoint_paths(
            args.dir, pattern=r"checkpoint\.best_loss_(\d+\.?\d*)\.pt")

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, 'wb') as f:
        torch.save(new_state, f)
    print('Finished writing averaged checkpoint to {}'.format(args.output))
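
A quick way to exercise main() from a Python session, with hypothetical paths (the flags mirror the argparse definitions above; on the command line the same flags would be passed to the script directly):

import sys

# hypothetical invocation: average the last 5 epoch checkpoints found under checkpoints/
sys.argv = [
    "average_checkpoints.py",
    "--inputs", "checkpoints/",
    "--num-epoch-checkpoints", "5",
    "--output", "checkpoints/averaged.pt",
]
main()

Passing --all-loss-checkpoints together with --dir instead would average every checkpoint.best_loss_*.pt found in that directory.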