Example #1
import argparse
import os

import torch

# Note: `utils` (providing ManagedCheckpoints) and `average_checkpoints` are
# module-level helpers from the surrounding project; they are assumed to be
# imported alongside the standard-library modules above.
def _save_averaged_checkpoint(args, extra_state):
    epoch, offset = extra_state["epoch"], extra_state["batch_offset"]
    if not hasattr(_save_averaged_checkpoint, "last_avg_checkpoints"):
        if args.max_checkpoints_kept == 0:
            raise argparse.ArgumentTypeError("--max-checkpoints-kept must be != 0.")
        _save_averaged_checkpoint.last_avg_checkpoints = utils.ManagedCheckpoints(
            max(args.max_checkpoints_kept, 1), auto_clear=args.max_checkpoints_kept > 0
        )

    last_checkpoints = extra_state["last_checkpoints"].get_last_n(
        1 if args.no_epoch_checkpoints else args.generate_bleu_eval_avg_checkpoints
    )
    if args.log_verbose:
        print(
            f"Reading {len(last_checkpoints)} previous "
            f"checkpoints for averaging in epoch {epoch}, offset {offset}.",
            flush=True,
        )
    averaged_state = average_checkpoints.average_checkpoints(last_checkpoints)
    filename = os.path.join(args.save_dir, f"averaged_checkpoint{epoch}_{offset}.pt")
    _save_averaged_checkpoint.last_avg_checkpoints.append(filename)
    if args.log_verbose:
        print(
            f"Preparing to save averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    torch.save(averaged_state, filename)
    if args.log_verbose:
        print(
            f"Finished saving averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    return filename
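
The call to average_checkpoints.average_checkpoints above produces a single state whose parameters are the element-wise mean of the parameters stored in the given checkpoint files. The project's implementation is not shown in this snippet; the sketch below is a minimal, self-contained illustration of that idea, using a hypothetical helper name and toy state dicts rather than the real API.

import torch


def average_state_dicts(state_dicts):
    # Element-wise mean of each parameter tensor across all checkpoints.
    averaged = {}
    for key in state_dicts[0]:
        stacked = torch.stack([sd[key].float() for sd in state_dicts])
        averaged[key] = stacked.mean(dim=0)
    return averaged


# Toy illustration with two fake "checkpoints".
ckpt_a = {"linear.weight": torch.ones(2, 2), "linear.bias": torch.zeros(2)}
ckpt_b = {"linear.weight": 3 * torch.ones(2, 2), "linear.bias": torch.ones(2)}
print(average_state_dicts([ckpt_a, ckpt_b])["linear.weight"])  # all entries 2.0
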
Example #2
# As in Example #1, this function expects `argparse`, `os`, `utils`, and the
# project's save_checkpoint_maybe_continuous helper to be available at module
# level.
def save_checkpoint(trainer, args, extra_state):
    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]
    val_loss = extra_state["val_loss"]

    if args.log_verbose:
        print(
            f"Preparing to save checkpoints for epoch {epoch}, "
            f"offset {batch_offset}. ",
            flush=True,
        )

    if "last_checkpoints" not in extra_state:
        if args.generate_bleu_eval_avg_checkpoints < 1:
            raise argparse.ArgumentTypeError(
                "--generate-bleu-eval-avg-checkpoints must be >= 1.")
        extra_state["last_checkpoints"] = utils.ManagedCheckpoints(
            max(args.generate_bleu_eval_avg_checkpoints,
                args.max_checkpoints_kept),
            # Don't auto_clear checkpoints for no_epoch_checkpoints, because
            # we are only going to reuse the same file.
            auto_clear=(args.max_checkpoints_kept > 0
                        and not args.no_epoch_checkpoints),
        )

    # batch_offset being None means that we're at the end of an epoch.
    if batch_offset is None:
        if not args.no_epoch_checkpoints:
            epoch_filename = os.path.join(args.save_dir,
                                          f"checkpoint{epoch}.pt")
            save_checkpoint_maybe_continuous(epoch_filename, trainer,
                                             extra_state)
            extra_state["last_checkpoints"].append(epoch_filename)

        assert val_loss is not None

        if ("checkpoint_lowest_loss" not in extra_state
                or val_loss < extra_state["checkpoint_lowest_loss"]):
            extra_state["checkpoint_lowest_loss"] = val_loss
            best_filename = os.path.join(args.save_dir, "checkpoint_best.pt")
            save_checkpoint_maybe_continuous(best_filename, trainer,
                                             extra_state)

    # Otherwise, we're in the middle of an epoch.
    elif not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(args.save_dir,
                                      f"checkpoint{epoch}_{batch_offset}.pt")
        save_checkpoint_maybe_continuous(epoch_filename, trainer, extra_state)
        extra_state["last_checkpoints"].append(epoch_filename)

    last_filename = os.path.join(args.save_dir, "checkpoint_last.pt")
    save_checkpoint_maybe_continuous(last_filename, trainer, extra_state)

    # This ensures we'll always have at least one checkpoint in the list to use
    # for BLEU eval, even if we're not saving epoch checkpoints. In that case no
    # epoch checkpoint file exists, so append the "last" checkpoint instead
    # (appending epoch_filename here would raise a NameError).
    if args.no_epoch_checkpoints:
        extra_state["last_checkpoints"].append(last_filename)
    if args.log_verbose:
        print(
            f"Finished saving checkpoints for epoch {epoch}, "
            f"offset {batch_offset}.",
            flush=True,
        )