Example #1
    def _load_checkpoint(self, filename):
        """Helper function for loading megatron checkpoints.

        Args:
            filename (str): Path to megatron checkpoint.
        """
        state_dict = torch.load(filename, map_location='cpu')
        if 'checkpoint_version' in state_dict:
            if state_dict['checkpoint_version'] is not None:
                set_checkpoint_version(state_dict['checkpoint_version'])
                logging.info(
                    f"Megatron-lm checkpoint version found. Setting checkpoint_version to {state_dict['checkpoint_version']}."
                )
        else:
            logging.warning(
                'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
            )
            set_checkpoint_version(0)
        # to load from Megatron pretrained checkpoint
        if 'model' in state_dict:
            self.language_model.load_state_dict(
                state_dict['model'][self._language_model_key])
        else:
            self.load_state_dict(state_dict)

        logging.info(f"Checkpoint loaded from from {filename}")
Example #2
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """ LightningModule hook that's used to restore things saved with on_save_checkpoint."""

        if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
            if get_checkpoint_version():
                assert (
                    checkpoint['checkpoint_version'] == get_checkpoint_version()
                ), 'checkpoint_version found in on_load_checkpoint differs from the value returned by get_checkpoint_version()'
            else:
                set_checkpoint_version(checkpoint['checkpoint_version'])
                logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}")
        return None
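The hook above reads the 'checkpoint_version' entry that the corresponding on_save_checkpoint is expected to have written. That counterpart is not shown in these examples; the following is only a sketch of what it presumably stores, assuming the same MegatronBertEncoder / get_checkpoint_version names that are in scope above.

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """ Sketch of the counterpart hook: record the current Megatron checkpoint version. """
        if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
            # Assumption: the key written here is the one on_load_checkpoint reads back.
            checkpoint['checkpoint_version'] = get_checkpoint_version()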
Example #3
    def start_training(self, trainer: 'Trainer') -> None:
        """ PTL Hook that is called after DPP is initialized. """

        if isinstance(self.lightning_module.bert_model, MegatronBertEncoder):
            app_state = AppState()
            if app_state.model_parallel_size is not None:
                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self.lightning_module.parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                # TODO: figure out how to override clip gradients again
                # Update PTL trainer to use our _clip_gradients
                # self._trainer.accelerator_backend._clip_gradients = self._clip_gradients

                if get_checkpoint_version():
                    # Restored from .nemo, checkpoint_version will already be set
                    pass
                elif trainer.resume_from_checkpoint is not None:
                    # PTL auto-resuming, need to update checkpoint name
                    # update path based on model parallel rank
                    filepath = trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    trainer.resume_from_checkpoint = filepath
                    logging.info(
                        f'Resuming training from checkpoint {trainer.resume_from_checkpoint}'
                    )
                    # need to set checkpoint version for megatron-lm
                    checkpoint_version = torch.load(
                        trainer.resume_from_checkpoint).get(
                            'checkpoint_version', None)
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                else:
                    logging.info(
                        f"Restoring from pretrained model parallel checkpoint: {self.lightning_module.bert_model._restore_path}"
                    )
                    self.lightning_module.bert_model.restore_weights(
                        self.lightning_module.bert_model._restore_path)

            self.lightning_module.register_megatron_checkpoint_version()

        return super().start_training(trainer)
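The auto-resume branch above (and the same branch in examples #4 and #5) rewrites the checkpoint path by going up two directory levels and descending into the per-rank mp_rank_XX directory. A standalone sketch of that rewrite; the helper name is hypothetical and not part of the original code.

import os


def _inject_model_parallel_rank(filepath: str, model_parallel_rank: int) -> str:
    """Sketch: rewrite a checkpoint path to point at the given model parallel rank's directory."""
    dirname = os.path.dirname(os.path.dirname(filepath))
    basename = os.path.basename(filepath)
    return f'{dirname}/mp_rank_{model_parallel_rank:02d}/{basename}'


# _inject_model_parallel_rank('/exp/checkpoints/last.ckpt', 1) -> '/exp/mp_rank_01/last.ckpt'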
Example #4
    def start_training(self, trainer: 'Trainer') -> None:
        """ PTL Hook that is called after DPP is initialized. """

        if self.lightning_module.has_megatron_encoder:
            app_state = AppState()
            if app_state.model_parallel_size is not None:
                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self.lightning_module.parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                if get_checkpoint_version() is not None:
                    # megatron checkpoint already restored
                    pass
                elif trainer.resume_from_checkpoint is not None:
                    # PTL auto-resuming, need to update checkpoint name
                    # update path based on model parallel rank
                    filepath = trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    trainer.resume_from_checkpoint = filepath
                    logging.info(
                        f'Resuming training from checkpoint {trainer.resume_from_checkpoint}'
                    )
                    # need to set checkpoint version for megatron-lm
                    checkpoint_version = torch.load(
                        trainer.resume_from_checkpoint).get(
                            'checkpoint_version', None)
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                else:
                    self.lightning_module.restore_megatron_encoder_weights()
            else:
                if get_checkpoint_version() is not None:
                    # megatron checkpoint already restored
                    pass
                else:
                    self.lightning_module.restore_megatron_encoder_weights()

            self.lightning_module.register_megatron_checkpoint_version()

        return super().start_training(trainer)
Example #5
    def setup(self, stage: str) -> None:
        """ PTL hook that is called after DDP is initialized.
            Called at the beginning of fit and test.

        Args:
            stage (str): either 'fit' or 'test'
        """
        # TODO: implement model parallel for test stage
        if stage == 'fit':
            # set find_unused_parameters to True by default for NLP models
            if isinstance(self.trainer.accelerator.training_type_plugin,
                          DDPPlugin):
                self.trainer.accelerator.training_type_plugin._ddp_kwargs[
                    'find_unused_parameters'] = True

            # adds self.bert_model config to .nemo file
            if hasattr(self, 'bert_model') and self.bert_model is not None:
                self.register_bert_model()

            app_state = AppState()

            if app_state.model_parallel_size is not None:

                if app_state.model_parallel_group is None:
                    self.init_model_parallel(app_state.global_rank,
                                             app_state.world_size)

                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self._trainer.get_model().parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                # Update PTL trainer to use our configure_ddp
                self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp
                # Update PTL trainer to use our _clip_gradients
                self._trainer.accelerator_backend._clip_gradients = self._clip_gradients
                self._trainer.checkpoint_connector = NLPCheckpointConnector(
                    self._trainer)

                # Configure checkpointing for model parallel
                if app_state.create_checkpoint_callback:
                    # global rank 0 is configured by exp_manager
                    if not is_global_rank_zero(
                    ) and app_state.data_parallel_rank == 0:
                        configure_checkpointing(
                            self._trainer,
                            app_state.log_dir,
                            app_state.checkpoint_name,
                            app_state.checkpoint_callback_params,
                        )

                if isinstance(self.bert_model, MegatronBertEncoder):
                    self.bert_model.complete_lazy_init()

                    # model parallel checkpoints need to be restored after torch.distributed is initialized
                    if self._trainer.resume_from_checkpoint is not None:
                        # update path based on model parallel rank
                        filepath = self._trainer.resume_from_checkpoint
                        dirname = os.path.dirname(os.path.dirname(filepath))
                        basename = os.path.basename(filepath)
                        filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                        self._trainer.resume_from_checkpoint = filepath
                        logging.info(
                            f'Resuming training from checkpoint {self._trainer.resume_from_checkpoint}'
                        )
                        # need to set checkpoint version for megatron-lm
                        checkpoint_version = torch.load(
                            self._trainer.resume_from_checkpoint).get(
                                'checkpoint_version', None)
                        if checkpoint_version is not None:
                            set_checkpoint_version(checkpoint_version)
                        else:
                            logging.warning(
                                'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                            )
                            set_checkpoint_version(0)
                    else:
                        logging.info(
                            f"Restoring from pretrained model parallel checkpoint: {self.bert_model._restore_path}"
                        )
                        self.bert_model.restore_weights(
                            self.bert_model._restore_path)

                    logging.info(
                        "Replacing sampler with model parallel sampler")
                    mp_sampler = torch.utils.data.distributed.DistributedSampler(
                        self._train_dl.dataset,
                        num_replicas=app_state.data_parallel_size,
                        rank=app_state.data_parallel_rank,
                    )
                    mp_dl = self._trainer.replace_sampler(
                        self._train_dl, mp_sampler)
                    self._train_dl = mp_dl
                else:
                    raise NotImplementedError(
                        f'The BERT encoder: {self.bert_model} does not support model parallelism yet.'
                    )
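The sampler replacement at the end of example #5 partitions data across data parallel replicas only, so model parallel ranks that belong to the same data parallel replica see identical batches. Below is a self-contained sketch of that sampler construction with plain torch; the toy dataset and the replica count are placeholders, not values from the original code.

import torch

# Placeholder dataset standing in for self._train_dl.dataset above.
dataset = torch.utils.data.TensorDataset(torch.arange(16).float())

sampler = torch.utils.data.distributed.DistributedSampler(
    dataset,
    num_replicas=2,  # app_state.data_parallel_size in the example above
    rank=0,          # app_state.data_parallel_rank in the example above
)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)
# Each data parallel rank iterates over a disjoint half of the dataset.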
Example #6
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native load_checkpoint by megatron
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, DistributedDataParallel):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last("WARNING: could not find the metadata file {} ".format(
            tracker_filename))
        print_rank_last(
            "    will not load any checkpoints and will start from "
            "random")
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(
                        tracker_filename))
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration,
                                                     release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank())
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later"
        .format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        ))

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(
                " > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException:
            print_rank_last("could not load the checkpoint")
            sys.exit()
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local,
                                  args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last("A metadata file exists but unable to load "
                                "iteration from checkpoint {}, exiting".format(
                                    checkpoint_name_local))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              "consumed_train_samples", 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              "consumed_valid_samples", 0)
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last("Unable to load optimizer from checkpoint {}. "
                            "Specify --no-load-optim or --finetune to prevent "
                            "attempting to load the optimizer state, "
                            "exiting ...".format(checkpoint_name_local))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last("Unable to load optimizer from checkpoint {}. "
                            "Specify --no-load-rng or --finetune to prevent "
                            "attempting to load the optimizer state, "
                            "exiting ...".format(checkpoint_name_local))
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        "  successfully loaded checkpoint (with expert parametes updated) from {} at iteration {}"
        .format(args.load, iteration))

    return iteration
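The tracker file parsed near the top of example #6 contains either an integer iteration or the literal string "release". A minimal, self-contained sketch of that convention follows; the tracker filename used here is Megatron's usual 'latest_checkpointed_iteration.txt', assumed rather than taken from this snippet.

import os
import tempfile

load_dir = tempfile.mkdtemp()
tracker_filename = os.path.join(load_dir, "latest_checkpointed_iteration.txt")
with open(tracker_filename, "w") as f:
    f.write("1000")  # or the string "release" for a release checkpoint

with open(tracker_filename, "r") as f:
    metastring = f.read().strip()

release = metastring == "release"
iteration = 0 if release else int(metastring)
print(iteration, release)  # -> 1000 False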
Example #7
    def restore_from(
        cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
        save_restore_connector: SaveRestoreConnector = None,
    ):
        """
        Restores model instance (weights and configuration) from .nemo file.

        Args:
            restore_path: path to .nemo file from which model should be instantiated
            override_config_path: path to a yaml config that will override the internal
                config file or an OmegaConf / DictConfig object representing the model config.
            map_location: Optional torch.device() to map the instantiated model to a device.
                By default (None), it will select a GPU if available, falling back to CPU otherwise.
            strict: Passed to load_state_dict. Set to True by default.
            return_config: If set to true, will return just the underlying config of the restored
                model as an OmegaConf DictConfig object without instantiating the model.
            trainer: PyTorch Lightning trainer. Must be passed in order to use model parallel .nemo
            save_restore_connector: Optional SaveRestoreConnector implementing the save/restore logic;
                defaults to a plain SaveRestoreConnector when not provided.

            Example:
                ```
                model = nemo.collections.nlp.models.TokenClassificationModel.restore_from('token_classification.nemo')
                assert isinstance(model, nemo.collections.nlp.models.TokenClassificationModel)
                ```

        Returns:
            An instance of type cls or its underlying config (if return_config is set).
        """
        if save_restore_connector is None:
            save_restore_connector = SaveRestoreConnector()

        if not os.path.exists(restore_path):
            raise FileNotFoundError(f"Can't find {restore_path}")

        app_state = AppState()
        app_state.model_restore_path = os.path.abspath(
            os.path.expanduser(restore_path))

        # detect if we have a model parallel .nemo file
        with tempfile.TemporaryDirectory() as tmpdir:
            cwd = os.getcwd()
            os.chdir(tmpdir)
            # detect if model parallel from tarfile
            tar = tarfile.open(app_state.model_restore_path, "r:gz")
            names = tar.getnames()
            mp_ranks = []
            for name in names:
                if 'mp_rank' in name:
                    mp_ranks.append(name)
            if mp_ranks:
                app_state.model_parallel_size = len(
                    mp_ranks
                ) // 2  # directory and file are included in getnames()

                # get checkpoint version
                checkpoint_version_member = None
                for member in tar.getmembers():
                    if 'megatron_checkpoint_version.json' in member.name:
                        checkpoint_version_member = member
                tar.extract(checkpoint_version_member, tmpdir)
                with open(checkpoint_version_member.name, 'r') as f:
                    checkpoint_version = json.load(f).get(
                        'checkpoint_version', None)
                logging.info(
                    (f'Detected model parallel .nemo file: {restore_path}. '
                     f'Assuming megatron model parallelism with '
                     f'model_parallel_size: {app_state.model_parallel_size} '
                     f'and checkpoint version: {checkpoint_version}'))
            tar.close()
            os.chdir(cwd)

        if app_state.model_parallel_size is not None:
            if not isinstance(trainer, Trainer):
                raise ValueError(
                    "trainer must be a PyTorch Lightning Trainer to restore model parallel .nemo files."
                )

            if checkpoint_version is None:
                raise ValueError(
                    "Restoring from megatron model parallel .nemo but could not find megatron checkpoint version."
                )
            else:
                logging.info(
                    f"Setting megatron checkpoint version: {checkpoint_version}"
                )
                set_checkpoint_version(checkpoint_version)

            app_state.world_size = trainer.num_gpus * trainer.num_nodes

            if trainer.local_rank is not None:
                app_state.local_rank = trainer.local_rank
            else:
                raise ValueError(
                    "trainer.local_rank is None. local_rank needed to restore model parallel models."
                )

            model_parallel_rank = compute_model_parallel_rank(
                trainer.local_rank, app_state.model_parallel_size)
            app_state.model_parallel_rank = model_parallel_rank

            cls.update_save_restore_connector(save_restore_connector)
            restored_model = cls._save_restore_connector.restore_from(
                cls, app_state.model_restore_path, override_config_path,
                map_location, strict, return_config)
            restored_model.set_trainer(trainer)
            return restored_model
        else:
            return super().restore_from(
                app_state.model_restore_path,
                override_config_path,
                map_location,
                strict,
                return_config,
                save_restore_connector=save_restore_connector,
            )
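A hedged usage sketch for the model parallel case, following the docstring example above. The trainer configuration is a placeholder (it has to cover the GPUs and nodes of the model parallel group); nothing here is prescribed by the snippet itself.

from pytorch_lightning import Trainer

import nemo.collections.nlp as nemo_nlp

# Placeholder trainer; configure gpus/num_nodes to match the model parallel setup.
trainer = Trainer()
model = nemo_nlp.models.TokenClassificationModel.restore_from(
    restore_path='token_classification.nemo',
    trainer=trainer,
)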
Example #8
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should be a file, or a directory if using model parallel
        """
        self._restore_path = restore_path
        if os.path.isfile(restore_path):
            logging.info(
                f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism'
            )
            state_dict = torch.load(restore_path, map_location='cpu')
            if 'checkpoint_version' in state_dict:
                if state_dict['checkpoint_version'] is not None:
                    set_checkpoint_version(state_dict['checkpoint_version'])
            else:
                logging.warning(
                    'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                )
                set_checkpoint_version(0)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(
                    state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
            logging.info(f"weights restored from {restore_path}")
        elif os.path.isdir(restore_path):
            # TODO: need to refactor this so we're not repeating code

            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(
                    group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                logging.info(
                    f'Restoring model parallel checkpoint from: {mp_restore_path}'
                )
                state_dict = torch.load(mp_restore_path, map_location='cpu')
                if 'checkpoint_version' in state_dict:
                    if state_dict['checkpoint_version'] is not None:
                        set_checkpoint_version(
                            state_dict['checkpoint_version'])
                else:
                    logging.warning(
                        'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                    )
                    set_checkpoint_version(0)
                # to load from Megatron pretrained checkpoint
                if 'model' in state_dict:
                    self.language_model.load_state_dict(
                        state_dict['model'][self._language_model_key])
                else:
                    self.load_state_dict(state_dict)
            else:
                logging.info(
                    'Model parallel group is not initialized yet. Will not restore model parallel checkpoint.'
                )
        else:
            logging.error(
                f'restore_path: {restore_path} must be a file or directory.')
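For the directory case in example #8, the layout expected by the docstring looks like the following (illustrative only):

restore_path/
    mp_rank_00/model_optim_rng.pt
    mp_rank_01/model_optim_rng.pt
    ...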