Example #1
def setup_training_state(args, trainer, task):
    """Set up the directory for saving checkpoints.
    Load pretrained model if specified."""
    os.makedirs(args.save_dir, exist_ok=True)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
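    #
    # For example, os.path.join("/checkpoints/run1", "/old_run/checkpoint_last.pt")
    # returns "/old_run/checkpoint_last.pt" and discards the save_dir prefix.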
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.isfile(
            args.pretrained_checkpoint_file):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(
            f"| Restoring individual models from {args.multi_model_restore_files}"
        )
        multi_model.import_individual_models(args.multi_model_restore_files,
                                             trainer)
    else:
        loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)
    print(f"| extra_state: {extra_state}")
    return extra_state
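
A self-contained sketch of just the checkpoint-resolution policy described in the
comment above, factored out as a hypothetical helper (resolve_checkpoint is not part
of the original code); it returns the path to load and whether training state should
be restored from it:

import os

def resolve_checkpoint(save_dir, restore_file, pretrained_checkpoint_file,
                       load_pretrained_checkpoint_state=False):
    # Prefer a checkpoint already written under save_dir, i.e. resume this run.
    candidate = os.path.join(save_dir, restore_file)
    if os.path.isfile(candidate):
        return candidate, True
    # Otherwise fall back to another run's pretrained checkpoint, restoring its
    # training state only if explicitly requested.
    if pretrained_checkpoint_file and os.path.isfile(pretrained_checkpoint_file):
        return pretrained_checkpoint_file, load_pretrained_checkpoint_state
    # Nothing exists yet: hand back the save_dir candidate so the caller can
    # start training from scratch.
    return candidate, True
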
Example #2
def test_load_checkpoint_no_restore_state(self):
    """Train for one step, save a checkpoint, and make sure it is loaded
    properly WITHOUT loading the extra state from the checkpoint."""
    test_save_file = test_utils.make_temp_file()
    test_args = test_utils.ModelParamsDict()
    test_args.distributed_rank = 0
    extra_state = test_utils.create_dummy_extra_state(epoch=2)
    trainer, _ = test_utils.gpu_train_step(test_args)
    trainer.save_checkpoint(test_save_file, extra_state)
    loaded, extra_state = checkpoint.load_existing_checkpoint(
        test_save_file, trainer, restore_state=False
    )
    # With restore_state=False, the checkpoint's extra state should not be
    # returned, so extra_state comes back as None.
    assert loaded and extra_state is None
    os.remove(test_save_file)
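
The contract this test pins down: load_existing_checkpoint returns a (loaded,
extra_state) pair, and extra_state comes back as None when restore_state=False.
A minimal illustrative loader with that shape, assuming a torch-serialized dict with
"model" and "extra_state" entries and a trainer exposing get_model() (a sketch only,
not the project's actual implementation):

import os
import torch

def load_checkpoint_sketch(checkpoint_path, trainer, restore_state=True):
    if not os.path.isfile(checkpoint_path):
        return False, None
    state = torch.load(checkpoint_path, map_location="cpu")
    # Always restore the model weights; only hand back the saved extra state
    # when the caller asked for it.
    trainer.get_model().load_state_dict(state["model"])
    extra_state = state.get("extra_state") if restore_state else None
    return True, extra_state
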
Example #3
def setup_training_state(args, trainer, task, epoch_itr):
    """Set up the directory for saving checkpoints.
    Load pretrained model if specified."""
    os.makedirs(args.save_dir, exist_ok=True)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.isfile(
        args.pretrained_checkpoint_file
    ):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(f"| Restoring individual models from {args.multi_model_restore_files}")
        multi_model.import_individual_models(args.multi_model_restore_files, trainer)
    else:
        loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)

    # Reset the start time for the current training run.
    extra_state["start_time"] = time.time()

    # Temporarily truncate training_progress so the print below doesn't spam
    # the log; the full list is restored right after printing.
    training_progress = extra_state["training_progress"]
    extra_state["training_progress"] = (
        ["...truncated...", training_progress[-1]] if len(training_progress) > 0 else []
    )
    print(f"| extra_state: {extra_state}")
    extra_state["training_progress"] = training_progress

    epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        epoch -= 1  # this will be incremented when we call epoch_itr.next_epoch_itr()
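    # Restore the data iterator's position so training resumes from the same
    # epoch and batch offset it was interrupted at.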
    epoch_itr.load_state_dict(
        {"epoch": epoch, "iterations_in_epoch": extra_state["batch_offset"]}
    )

    checkpoint_manager = None
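    # Only the master worker creates a CheckpointManager (which handles averaging
    # recent checkpoints and clearing old ones); other workers leave it as None.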
    if distributed_utils.is_master(args):
        checkpoint_manager = checkpoint.CheckpointManager(
            num_avg_checkpoints=args.num_avg_checkpoints,
            auto_clear_checkpoints=args.auto_clear_checkpoints,
            log_verbose=args.log_verbose,
            checkpoint_files=extra_state["checkpoint_files"],
        )

    return extra_state, epoch_itr, checkpoint_manager
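
The default_extra_state helper is referenced above but not shown. Based solely on the
fields that setup_training_state reads and writes ("epoch", "batch_offset",
"start_time", "training_progress", "checkpoint_files"), a plausible sketch would be
the following; the real helper may initialize more fields:

import time

def default_extra_state(args):
    # Minimal initial state covering only the keys used above; args is accepted
    # to match the call site, but this sketch does not use it.
    return {
        "epoch": 1,
        "batch_offset": 0,
        "start_time": time.time(),
        "training_progress": [],
        "checkpoint_files": [],
    }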