def save_averaged_checkpoint(args, extra_state):
    """Average the most recent checkpoints and write the result to disk.

    A rolling registry of previously written averaged-checkpoint paths is
    stored on the function object itself (``last_avg_checkpoints``), so that
    ``ManagedCheckpoints`` can clear stale files across calls.

    Args:
        args: parsed CLI namespace; reads ``max_checkpoints_kept``,
            ``no_epoch_checkpoints``, ``generate_bleu_eval_avg_checkpoints``,
            ``log_verbose``, and ``save_dir``.
        extra_state: dict carrying ``"epoch"``, ``"batch_offset"``, and
            ``"last_checkpoints"`` (an object exposing ``get_last_n()``).

    Returns:
        Path of the averaged checkpoint file that was written.

    Raises:
        argparse.ArgumentTypeError: if ``--max-checkpoints-kept`` is 0.
    """
    epoch = extra_state["epoch"]
    offset = extra_state["batch_offset"]

    # Lazily create the rolling checkpoint registry on first invocation.
    if not hasattr(save_averaged_checkpoint, "last_avg_checkpoints"):
        if args.max_checkpoints_kept == 0:
            raise argparse.ArgumentTypeError("--max-checkpoints-kept must be != 0.")
        save_averaged_checkpoint.last_avg_checkpoints = ManagedCheckpoints(
            max(args.max_checkpoints_kept, 1),
            auto_clear=args.max_checkpoints_kept > 0,
        )

    # How many trailing checkpoints go into the average: a single one when
    # per-epoch checkpoints are disabled, otherwise the configured count.
    window = (
        1 if args.no_epoch_checkpoints else args.generate_bleu_eval_avg_checkpoints
    )
    last_checkpoints = extra_state["last_checkpoints"].get_last_n(window)

    if args.log_verbose:
        print(
            f"| Reading {len(last_checkpoints)} previous "
            f"checkpoints for averaging in epoch {epoch}, offset {offset}.",
            flush=True,
        )

    averaged_state = average_checkpoints.average_checkpoints(last_checkpoints)
    filename = os.path.join(args.save_dir, f"averaged_checkpoint{epoch}_{offset}.pt")
    # Register the new path before writing so the registry can manage it.
    save_averaged_checkpoint.last_avg_checkpoints.append(filename)

    if args.log_verbose:
        print(
            f"| Preparing to save averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    utils.torch_persistent_save(averaged_state, filename)
    if args.log_verbose:
        print(
            f"| Finished saving averaged checkpoint for "
            f"epoch {epoch}, offset {offset}.",
            flush=True,
        )
    return filename
def save_state(state, filename):
    """Persist *state* to *filename* using ``torch_persistent_save``.

    Thin convenience wrapper; all serialization details are delegated.
    """
    torch_persistent_save(state, filename)