def _save_averaged_checkpoint(args, extra_state):
    epoch, offset = extra_state["epoch"], extra_state["batch_offset"]
    if not hasattr(_save_averaged_checkpoint, "last_avg_checkpoints"):
        if args.max_checkpoints_kept == 0:
            raise argparse.ArgumentTypeError("--max-checkpoints-kept must be != 0.")
        _save_averaged_checkpoint.last_avg_checkpoints = utils.ManagedCheckpoints(
            max(args.max_checkpoints_kept, 1),
            auto_clear=args.max_checkpoints_kept > 0,
        )

    last_checkpoints = extra_state["last_checkpoints"].get_last_n(
        1 if args.no_epoch_checkpoints else args.generate_bleu_eval_avg_checkpoints
    )
    if args.log_verbose:
        print(
            f"Reading {len(last_checkpoints)} previous "
            f"checkpoints for averaging in epoch {epoch}, offset {offset}.",
            flush=True,
        )
    averaged_state = average_checkpoints.average_checkpoints(last_checkpoints)
    filename = os.path.join(args.save_dir, f"averaged_checkpoint{epoch}_{offset}.pt")
    _save_averaged_checkpoint.last_avg_checkpoints.append(filename)
    if args.log_verbose:
        print(
            f"Preparing to save averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    torch.save(averaged_state, filename)
    if args.log_verbose:
        print(
            f"Finished saving averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    return filename
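

# Illustrative sketch only (not the module's average_checkpoints implementation):
# checkpoint averaging, as used by _save_averaged_checkpoint above, conceptually
# takes the element-wise mean of the model parameters stored in several
# checkpoint files. The "model" key and the helper name below are assumptions
# made for this example, not part of the original code.
def _average_model_params_sketch(checkpoint_paths):
    averaged = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location="cpu")
        params = state["model"]  # assumes params are stored under "model"
        if averaged is None:
            averaged = {name: tensor.clone().float() for name, tensor in params.items()}
        else:
            for name, tensor in params.items():
                averaged[name] += tensor.float()
    # Divide the accumulated sums by the number of checkpoints to get the mean.
    return {name: tensor / len(checkpoint_paths) for name, tensor in averaged.items()}
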
def save_checkpoint(trainer, args, extra_state):
    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]
    val_loss = extra_state["val_loss"]

    if args.log_verbose:
        print(
            f"Preparing to save checkpoints for epoch {epoch}, "
            f"offset {batch_offset}.",
            flush=True,
        )

    if "last_checkpoints" not in extra_state:
        if args.generate_bleu_eval_avg_checkpoints < 1:
            raise argparse.ArgumentTypeError(
                "--generate-bleu-eval-avg-checkpoints must be >= 1."
            )
        extra_state["last_checkpoints"] = utils.ManagedCheckpoints(
            max(args.generate_bleu_eval_avg_checkpoints, args.max_checkpoints_kept),
            # Don't auto_clear checkpoints for no_epoch_checkpoints, because
            # we are only going to reuse the same file.
            auto_clear=(
                args.max_checkpoints_kept > 0 and not args.no_epoch_checkpoints
            ),
        )

    # batch_offset being None means that we're at the end of an epoch.
    if batch_offset is None:
        if not args.no_epoch_checkpoints:
            epoch_filename = os.path.join(args.save_dir, f"checkpoint{epoch}.pt")
            save_checkpoint_maybe_continuous(epoch_filename, trainer, extra_state)
            extra_state["last_checkpoints"].append(epoch_filename)

        assert val_loss is not None
        if (
            "checkpoint_lowest_loss" not in extra_state
            or val_loss < extra_state["checkpoint_lowest_loss"]
        ):
            extra_state["checkpoint_lowest_loss"] = val_loss
            best_filename = os.path.join(args.save_dir, "checkpoint_best.pt")
            save_checkpoint_maybe_continuous(best_filename, trainer, extra_state)

    # Otherwise, we're in the middle of an epoch.
    elif not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(
            args.save_dir, f"checkpoint{epoch}_{batch_offset}.pt"
        )
        save_checkpoint_maybe_continuous(epoch_filename, trainer, extra_state)
        extra_state["last_checkpoints"].append(epoch_filename)

    last_filename = os.path.join(args.save_dir, "checkpoint_last.pt")
    save_checkpoint_maybe_continuous(last_filename, trainer, extra_state)
    # This ensures we'll always have at least one checkpoint in the list to use
    # for BLEU eval, even if we're not saving epoch checkpoints. Append
    # last_filename here, since no epoch checkpoint (and hence no epoch_filename)
    # exists when --no-epoch-checkpoints is set.
    if args.no_epoch_checkpoints:
        extra_state["last_checkpoints"].append(last_filename)

    if args.log_verbose:
        print(
            f"Finished saving checkpoints for epoch {epoch}, "
            f"offset {batch_offset}.",
            flush=True,
        )
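

# Minimal usage sketch (hypothetical; the actual call sites live in the training
# loop elsewhere in this module). save_checkpoint expects extra_state to carry
# "epoch", "batch_offset", and "val_loss", and the same dict must be reused
# across calls so that "last_checkpoints" persists. The names `trainer`, `args`,
# `epoch`, and `validation_loss` are assumed to come from the caller.
def _end_of_epoch_checkpoint_sketch(trainer, args, extra_state, epoch, validation_loss):
    extra_state["epoch"] = epoch
    extra_state["batch_offset"] = None  # None marks the end of an epoch.
    extra_state["val_loss"] = validation_loss
    save_checkpoint(trainer, args, extra_state)
    # Averaging only makes sense once save_checkpoint has populated
    # extra_state["last_checkpoints"].
    return _save_averaged_checkpoint(args, extra_state)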