def setup_training_state(args, trainer, task):
    """Set up the directory for saving checkpoints.
    Load pretrained model if specified."""
    os.makedirs(args.save_dir, exist_ok=True)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.isfile(
        args.pretrained_checkpoint_file
    ):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(f"| Restoring individual models from {args.multi_model_restore_files}")
        multi_model.import_individual_models(args.multi_model_restore_files, trainer)
    else:
        loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)

    print(f"| extra_state: {extra_state}")
    return extra_state
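
# A minimal, hypothetical illustration (not part of the training code) of the
# path-resolution note in setup_training_state() above: os.path.join() drops
# earlier components when a later component is an absolute path, so an
# absolute --restore-file bypasses --save-dir entirely. The paths are made up
# and POSIX-style separators are assumed.
def _example_restore_path_resolution():
    import os  # local import so the sketch is self-contained

    save_dir = "/checkpoints/current_run"  # hypothetical --save-dir
    # A relative --restore-file resolves under --save-dir.
    assert os.path.join(save_dir, "checkpoint_last.pt") == (
        "/checkpoints/current_run/checkpoint_last.pt"
    )
    # An absolute --restore-file is used as-is; --save-dir is ignored.
    assert os.path.join(save_dir, "/old_run/checkpoint_best.pt") == (
        "/old_run/checkpoint_best.pt"
    )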
def test_load_checkpoint_no_restore_state(self):
    """Train for one step, save a checkpoint, and make sure it is loaded
    properly WITHOUT loading the extra state from the checkpoint."""
    test_save_file = test_utils.make_temp_file()
    test_args = test_utils.ModelParamsDict()
    test_args.distributed_rank = 0
    extra_state = test_utils.create_dummy_extra_state(epoch=2)
    trainer, _ = test_utils.gpu_train_step(test_args)
    trainer.save_checkpoint(test_save_file, extra_state)
    loaded, extra_state = checkpoint.load_existing_checkpoint(
        test_save_file, trainer, restore_state=False
    )
    # Loading the checkpoint without restoring state should reset extra state.
    assert loaded and extra_state is None
    os.remove(test_save_file)
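
# A hedged companion sketch (not from the original test suite): the same flow
# with restore_state=True. It reuses the test_utils helpers from the test
# above; the expectation that extra state comes back is an assumption based on
# how setup_training_state() consumes load_existing_checkpoint()'s return
# value, and the restored contents are not asserted in detail here.
def test_load_checkpoint_restore_state_sketch(self):
    test_save_file = test_utils.make_temp_file()
    test_args = test_utils.ModelParamsDict()
    test_args.distributed_rank = 0
    extra_state = test_utils.create_dummy_extra_state(epoch=2)
    trainer, _ = test_utils.gpu_train_step(test_args)
    trainer.save_checkpoint(test_save_file, extra_state)
    loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
        test_save_file, trainer, restore_state=True
    )
    # With restore_state=True the saved extra state should be returned rather
    # than discarded (assumption, see note above).
    assert loaded and loaded_extra_state is not None
    os.remove(test_save_file)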
def setup_training_state(args, trainer, task, epoch_itr):
    """Set up the directory for saving checkpoints.
    Load pretrained model if specified."""
    os.makedirs(args.save_dir, exist_ok=True)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.isfile(
        args.pretrained_checkpoint_file
    ):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(f"| Restoring individual models from {args.multi_model_restore_files}")
        multi_model.import_individual_models(args.multi_model_restore_files, trainer)
    else:
        loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)

    # Reset the start time for the current training run.
    extra_state["start_time"] = time.time()

    # Skip printing the full training progress history to prevent log spam.
    training_progress = extra_state["training_progress"]
    extra_state["training_progress"] = (
        ["...truncated...", training_progress[-1]]
        if len(training_progress) > 0
        else []
    )
    print(f"| extra_state: {extra_state}")
    extra_state["training_progress"] = training_progress

    epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        epoch -= 1  # this will be incremented when we call epoch_itr.next_epoch_itr()
    epoch_itr.load_state_dict(
        {"epoch": epoch, "iterations_in_epoch": extra_state["batch_offset"]}
    )

    checkpoint_manager = None
    if distributed_utils.is_master(args):
        checkpoint_manager = checkpoint.CheckpointManager(
            num_avg_checkpoints=args.num_avg_checkpoints,
            auto_clear_checkpoints=args.auto_clear_checkpoints,
            log_verbose=args.log_verbose,
            checkpoint_files=extra_state["checkpoint_files"],
        )

    return extra_state, epoch_itr, checkpoint_manager
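
# A minimal, hypothetical sketch of how this setup_training_state() variant
# might be wired into a training entry point. The builder used to create
# `trainer`, `task`, and `epoch_itr` is a placeholder, not the project's
# actual API; only the call to setup_training_state() and the handling of its
# return values mirror the function above.
def _example_training_driver(args):
    # Hypothetical helper standing in for the real model/task/iterator setup.
    trainer, task, epoch_itr = build_training_components(args)  # hypothetical

    extra_state, epoch_itr, checkpoint_manager = setup_training_state(
        args=args, trainer=trainer, task=task, epoch_itr=epoch_itr
    )

    # checkpoint_manager is only created on the master rank, so any use of it
    # in the training loop needs to be guarded.
    if checkpoint_manager is not None:
        pass  # e.g. save/average checkpoints from the master process

    return extra_state, epoch_itr, checkpoint_manager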