def default_extra_state(args) -> Dict[str, Any]:
    """Return the initial training extra_state dict (legacy schema).

    NOTE(review): this definition is shadowed by a later re-definition of
    ``default_extra_state`` in this same file (the newer "tune_eval"/
    "tune_bleu" schema); only the later definition survives at import time,
    which makes this block dead code — confirm and consider removing it.
    """
    return {
        "epoch": 1,
        "batch_offset": 0,
        "start_time": time.time(),
        "val_loss": None,
        "checkpoint_lowest_loss": None,
        "validate": {"lowest_loss": None, "num_since_best": None},
        "last_bleu_eval": 0,
        "evaluate_bleu": {"best": None, "best_epoch": None, "num_since_best": None},
        "last_checkpoints": ManagedCheckpoints(
            max(args.generate_bleu_eval_avg_checkpoints, args.max_checkpoints_kept),
            # Don't auto_clear checkpoints for no_epoch_checkpoints, because
            # we are only going to reuse the same file.
            auto_clear=(
                args.max_checkpoints_kept > 0 and not args.no_epoch_checkpoints
            ),
        ),
    }
def default_extra_state(args) -> Dict[str, Any]:
    """Build the initial extra_state dictionary for a fresh training run."""
    # We have both checkpoint_lowest_loss and tune_eval.lowest_loss since we
    # may have seen a lower loss during validation and updated
    # tune_eval.lowest_loss, but may not have written a new checkpoint with
    # that loss yet.
    tune_eval = {
        "loss": None,
        "perplexity": None,
        "lowest_loss": None,
        "num_since_best": 0,
    }
    tune_bleu = {
        "current": None,
        "best": None,
        "best_epoch": None,
        "num_since_best": 0,
        "last_eval_step": 0,
    }
    # Don't auto_clear checkpoints for no_epoch_checkpoints, because
    # we are only going to reuse the same file.
    should_auto_clear = (
        args.max_checkpoints_kept > 0 and not args.no_epoch_checkpoints
    )
    checkpoints = ManagedCheckpoints(
        max(args.generate_bleu_eval_avg_checkpoints, args.max_checkpoints_kept),
        auto_clear=should_auto_clear,
    )
    return {
        "epoch": 1,
        "batch_offset": 0,
        "start_time": time.time(),
        "checkpoint_lowest_loss": None,
        "tune_eval": tune_eval,
        "tune_bleu": tune_bleu,
        "last_checkpoints": checkpoints,
    }
def save_averaged_checkpoint(args, extra_state):
    """Average the most recently saved checkpoints and write the result.

    Reads the last N checkpoint paths from extra_state["last_checkpoints"],
    averages their parameters, saves the averaged model to
    ``averaged_checkpoint{epoch}_{offset}.pt`` under args.save_dir, and
    returns that filename.
    """
    epoch = extra_state["epoch"]
    offset = extra_state["batch_offset"]

    # Lazily create the managed list of averaged checkpoints on first call;
    # it lives as an attribute on the function itself.
    if not hasattr(save_averaged_checkpoint, "last_avg_checkpoints"):
        if args.max_checkpoints_kept == 0:
            raise argparse.ArgumentTypeError("--max-checkpoints-kept must be != 0.")
        save_averaged_checkpoint.last_avg_checkpoints = ManagedCheckpoints(
            max(args.max_checkpoints_kept, 1),
            auto_clear=args.max_checkpoints_kept > 0,
        )

    # With --no-epoch-checkpoints only a single (reused) checkpoint exists.
    num_to_average = (
        1 if args.no_epoch_checkpoints else args.generate_bleu_eval_avg_checkpoints
    )
    recent_checkpoints = extra_state["last_checkpoints"].get_last_n(num_to_average)
    if args.log_verbose:
        print(
            f"| Reading {len(recent_checkpoints)} previous "
            f"checkpoints for averaging in epoch {epoch}, offset {offset}.",
            flush=True,
        )
    averaged_state = average_checkpoints.average_checkpoints(recent_checkpoints)

    filename = os.path.join(args.save_dir, f"averaged_checkpoint{epoch}_{offset}.pt")
    save_averaged_checkpoint.last_avg_checkpoints.append(filename)
    if args.log_verbose:
        print(
            f"| Preparing to save averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    utils.torch_persistent_save(averaged_state, filename)
    if args.log_verbose:
        print(
            f"| Finished saving averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    return filename
def save_checkpoint(trainer, args, extra_state):
    """Save training checkpoints for the current epoch/offset.

    Depending on args, writes an epoch (or mid-epoch) checkpoint, updates
    checkpoint_best.pt when validation loss improves, and always refreshes
    checkpoint_last.pt. Mutates extra_state ("last_checkpoints",
    "checkpoint_lowest_loss") in place.
    """
    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]
    val_loss = extra_state["val_loss"]
    if args.log_verbose:
        print(
            f"Preparing to save checkpoints for epoch {epoch}, "
            f"offset {batch_offset}. ",
            flush=True,
        )

    # Backfill last_checkpoints for extra_state dicts restored from runs that
    # predate this field.
    if "last_checkpoints" not in extra_state:
        if args.generate_bleu_eval_avg_checkpoints < 1:
            raise argparse.ArgumentTypeError(
                "--generate-bleu-eval-avg-checkpoints must be >= 1.")
        extra_state["last_checkpoints"] = ManagedCheckpoints(
            max(args.generate_bleu_eval_avg_checkpoints, args.max_checkpoints_kept),
            # Don't auto_clear checkpoints for no_epoch_checkpoints, because
            # we are only going to reuse the same file.
            auto_clear=(
                args.max_checkpoints_kept > 0 and not args.no_epoch_checkpoints
            ),
        )

    # batch_offset being None means that we're at the end of an epoch.
    if batch_offset is None:
        if not args.no_epoch_checkpoints:
            epoch_filename = os.path.join(args.save_dir, f"checkpoint{epoch}.pt")
            save_checkpoint_maybe_continuous(epoch_filename, trainer, extra_state)
            extra_state["last_checkpoints"].append(epoch_filename)

        # End-of-epoch saves must come with a validation loss so we can track
        # the best checkpoint.
        assert val_loss is not None
        if (
            "checkpoint_lowest_loss" not in extra_state
            or val_loss < extra_state["checkpoint_lowest_loss"]
        ):
            extra_state["checkpoint_lowest_loss"] = val_loss
            best_filename = os.path.join(args.save_dir, "checkpoint_best.pt")
            save_checkpoint_maybe_continuous(best_filename, trainer, extra_state)

    # Otherwise, we're in the middle of an epoch.
    elif not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(
            args.save_dir, f"checkpoint{epoch}_{batch_offset}.pt"
        )
        save_checkpoint_maybe_continuous(epoch_filename, trainer, extra_state)
        extra_state["last_checkpoints"].append(epoch_filename)

    last_filename = os.path.join(args.save_dir, "checkpoint_last.pt")
    save_checkpoint_maybe_continuous(last_filename, trainer, extra_state)

    # This ensures we'll always have at least one checkpoint in the list to use
    # for BLEU eval, even if we're not saving epoch checkpoints.
    # BUG FIX: the original appended `epoch_filename` here, but that variable
    # is only assigned in branches guarded by `not args.no_epoch_checkpoints`,
    # so this line raised NameError whenever it actually executed. The intent
    # (per the comment above) is to track the reused checkpoint_last.pt file.
    if args.no_epoch_checkpoints:
        extra_state["last_checkpoints"].append(last_filename)

    if args.log_verbose:
        print(
            f"Finished saving checkpoints for epoch {epoch}, "
            f"offset {batch_offset}.",
            flush=True,
        )