Ejemplo n.º 1
0
    def distributed_sampler_kwargs(self):
        """Return the keyword arguments used to build the distributed sampler.

        With model parallelism enabled, the GPUs that form one logical GPU must
        all receive the same batch, so the sampler is partitioned by the data
        parallel size and rank stored in ``AppState``. Without model
        parallelism the parent plugin's value is returned unchanged.
        """
        state = AppState()
        if state.model_parallel_size is None:
            # No model parallelism configured: defer to the parent plugin.
            return super(NLPDDPPlugin, self).distributed_sampler_kwargs

        # Data parallel groups are non-trivial here: one replica per logical GPU.
        return {'num_replicas': state.data_parallel_size, 'rank': state.data_parallel_rank}
Ejemplo n.º 2
0
    def start_training(self, trainer: 'Trainer') -> None:
        """PTL hook that is called after DDP is initialized.

        When the wrapped module's ``bert_model`` is a ``MegatronBertEncoder``
        and model parallelism is enabled, every parameter is given a
        ``model_parallel`` attribute and a megatron checkpoint version is made
        available, either from the rank-specific PTL checkpoint being resumed
        or by restoring the pretrained encoder weights.

        Args:
            trainer: the PTL trainer; its ``resume_from_checkpoint`` path is
                rewritten in place to point at this model parallel rank.
        """

        if isinstance(self.lightning_module.bert_model, MegatronBertEncoder):
            app_state = AppState()
            if app_state.model_parallel_size is not None:
                # mpu grad clipping needs parameters to have the attribute model_parallel
                for p in self.lightning_module.parameters():
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                # TODO: figure out how to override clip gradients again
                # Update PTL trainer to use our _clip_gradients
                # self._trainer.accelerator_backend._clip_gradients = self._clip_gradients

                # Compare against None explicitly: 0 is a valid checkpoint
                # version (it is the fallback set further down) and must count
                # as "already set" instead of triggering another restore.
                if get_checkpoint_version() is not None:
                    # Restored from .nemo, checkpoint_version will already be set
                    pass
                elif trainer.resume_from_checkpoint is not None:
                    # PTL auto-resuming, need to update checkpoint name
                    # update path based on model parallel rank
                    filepath = trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    trainer.resume_from_checkpoint = filepath
                    logging.info(
                        f'Resuming training from checkpoint {trainer.resume_from_checkpoint}'
                    )
                    # need to set checkpoint version for megatron-lm
                    checkpoint_version = torch.load(
                        trainer.resume_from_checkpoint).get(
                            'checkpoint_version', None)
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                else:
                    logging.info(
                        f"Restoring from pretrained model parallel checkpoint: {self.lightning_module.bert_model._restore_path}"
                    )
                    self.lightning_module.bert_model.restore_weights(
                        self.lightning_module.bert_model._restore_path)

            self.lightning_module.register_megatron_checkpoint_version()

        return super().start_training(trainer)
Ejemplo n.º 3
0
    def init_ddp_connection(self,
                            global_rank: int = None,
                            world_size: int = None) -> None:
        """Initialize DDP through PTL, then set up model parallelism if needed.

        The ``global_rank`` and ``world_size`` arguments are not forwarded: the
        parent is called without arguments and the model parallel setup reads
        both values from ``AppState``.
        """
        # Let PTL establish the DDP connection first.
        super().init_ddp_connection()

        state = AppState()
        if state.model_parallel_size is None:
            return

        module = self.lightning_module
        # Only initialize once, and only for modules with a megatron encoder.
        needs_init = module.has_megatron_encoder and not module.is_model_parallel_initialized
        if needs_init:
            self.init_model_parallel(state.global_rank, state.world_size)
Ejemplo n.º 4
0
    def start_testing(self, trainer: 'Trainer') -> None:
        """PTL hook that is called after DDP is initialized.

        Raises:
            ValueError: if model parallelism is enabled, a megatron encoder is
                present and no megatron checkpoint version has been set.
        """
        state = AppState()
        mp_with_megatron = state.model_parallel_size is not None and self.has_megatron_encoder

        # A megatron checkpoint version must be known before testing starts.
        if mp_with_megatron and get_checkpoint_version() is None:
            raise ValueError(
                "Unable to find megatron checkpoint version.")

        return super().start_testing(trainer)
Ejemplo n.º 5
0
    def setup(self, stage: str) -> None:
        """PTL hook that is called on all DDP processes.

        Only the ``'fit'`` stage does any work: the BERT model is registered
        when one is attached, and with model parallelism enabled the trainer's
        checkpoint connector is replaced by an ``NLPCheckpointConnector``.
        """
        if stage != 'fit':
            return

        # adds self.bert_model config to .nemo file
        has_bert_model = hasattr(self, 'bert_model') and self.bert_model is not None
        if has_bert_model:
            self.register_bert_model()

        if AppState().model_parallel_size is not None:
            self._trainer.checkpoint_connector = NLPCheckpointConnector(self._trainer)
Ejemplo n.º 6
0
    def start_training(self, trainer: 'Trainer') -> None:
        """PTL hook that is called after DDP is initialized.

        For modules with a megatron encoder this makes sure a megatron
        checkpoint version is available before training starts. With model
        parallelism enabled, parameters are first tagged with a
        ``model_parallel`` attribute and a PTL checkpoint being resumed is
        redirected to this model parallel rank's file. In every other case the
        encoder weights are restored through the lightning module.

        Args:
            trainer: the PTL trainer; its ``resume_from_checkpoint`` path may
                be rewritten in place.
        """
        module = self.lightning_module

        if module.has_megatron_encoder:
            app_state = AppState()
            uses_model_parallel = app_state.model_parallel_size is not None

            if uses_model_parallel:
                # mpu grad clipping needs parameters to have the attribute model_parallel
                for param in module.parameters():
                    if not hasattr(param, 'model_parallel'):
                        param.model_parallel = False

            # Nothing to do when a megatron checkpoint was already restored.
            if get_checkpoint_version() is None:
                if uses_model_parallel and trainer.resume_from_checkpoint is not None:
                    # PTL auto-resuming: point the path at this model parallel rank.
                    filepath = trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    trainer.resume_from_checkpoint = (
                        f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    )
                    logging.info(
                        f'Resuming training from checkpoint {trainer.resume_from_checkpoint}'
                    )

                    # megatron-lm needs the checkpoint version to be set
                    loaded = torch.load(trainer.resume_from_checkpoint)
                    checkpoint_version = loaded.get('checkpoint_version', None)
                    if checkpoint_version is None:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                    else:
                        set_checkpoint_version(checkpoint_version)
                else:
                    module.restore_megatron_encoder_weights()

            module.register_megatron_checkpoint_version()

        return super().start_training(trainer)
Ejemplo n.º 7
0
    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        """Initialize DDP through PTL, then set up model parallelism if needed.

        Model parallel groups are created only when a model parallel size is
        configured, the module's ``bert_model`` is a ``MegatronBertEncoder``
        and no model parallel group exists yet.
        """
        # Let PTL establish the DDP connection first.
        super().init_ddp_connection(global_rank, world_size)

        state = AppState()
        if state.model_parallel_size is None:
            return
        if not isinstance(self.lightning_module.bert_model, MegatronBertEncoder):
            return
        if state.model_parallel_group is None:
            self.init_model_parallel(state.global_rank, state.world_size)
Ejemplo n.º 8
0
 def training_step(self, batch, batch_idx):
     """Run a training step on a pre-batched (tarred) global batch.

     The microbatch calculator is reconfigured from the size of the processed
     batch before the parent ``training_step`` is invoked.
     """
     # Tarred datasets are pre-batched and the dataloader is asked for batch
     # size 1, so the extra leading dim of 3-d tensors is squeezed away.
     squeezed = []
     for microbatch in batch:
         squeezed.append([x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch])
     batch = self.process_global_batch_for_tarred_datasets(squeezed)

     local_batch_size = batch['text_enc'].size(0)
     dp_world_size = parallel_state.get_data_parallel_world_size()
     _reconfigure_microbatch_calculator(
         rank=AppState().global_rank,
         rampup_batch_size=None,
         global_batch_size=local_batch_size * dp_world_size,
         micro_batch_size=local_batch_size,
         data_parallel_size=dp_world_size,
     )
     return super().training_step(batch, batch_idx)
Ejemplo n.º 9
0
 def training_step(self, batch, batch_idx):
     """Run a training step, reconfiguring microbatches for an odd-sized batch.

     When the first microbatch's size differs from the configured
     ``micro_batch_size``, the microbatch calculator is reconfigured to match
     before the parent ``training_step`` is invoked.
     """
     micro_batch_size = batch[0]['text_enc'].size(0)
     expected_size = self.cfg.data.train_ds.micro_batch_size

     # Per the original author this should happen only on the last batch of the dataset.
     if micro_batch_size != expected_size:
         dp_world_size = parallel_state.get_data_parallel_world_size()
         _reconfigure_microbatch_calculator(
             rank=AppState().global_rank,
             rampup_batch_size=None,
             global_batch_size=micro_batch_size * dp_world_size * get_num_microbatches(),
             micro_batch_size=micro_batch_size,
             data_parallel_size=dp_world_size,
         )
     return super().training_step(batch, batch_idx)
Ejemplo n.º 10
0
    def setup(self, stage: str) -> None:
        """PTL hook that is called after DDP is initialized.

        Called at the beginning of fit and test; only the ``'fit'`` stage does
        any work here.

        Args:
            stage (str): either 'fit' or 'test'

        Raises:
            NotImplementedError: if model parallelism is enabled but
                ``self.bert_model`` is not a ``MegatronBertEncoder``.
        """
        # TODO: implement model parallel for test stage
        if stage != 'fit':
            return

        # adds self.bert_model config to .nemo file
        self.register_bert_model()

        state = AppState()
        if state.model_parallel_size is None:
            return

        if state.model_parallel_group is None:
            self.init_model_parallel(state.global_rank, state.world_size)

        # Update PTL trainer to use our configure_ddp
        self._trainer.accelerator_backend.configure_ddp = self.configure_ddp

        if not isinstance(self.bert_model, MegatronBertEncoder):
            raise NotImplementedError(
                f'The BERT encoder: {self.bert_model} does not support model parallelism yet.'
            )

        logging.info(
            f"restoring model parallel checkpoint: {self.bert_model._restore_path}"
        )
        # model parallel checkpoints need to be restored after torch.distributed is initialized
        self.bert_model.restore_weights(self.bert_model._restore_path)

        logging.info("replacing sampler with model parallel sampler")
        # Partition the training data by data parallel rank, not by global rank.
        mp_sampler = torch.utils.data.distributed.DistributedSampler(
            self._train_dl.dataset,
            num_replicas=state.data_parallel_size,
            rank=state.data_parallel_rank,
        )
        self._train_dl = self._trainer.replace_sampler(self._train_dl, mp_sampler)
Ejemplo n.º 11
0
    def save_to(self, model, save_path: str):
        """Save ``model`` to ``save_path``, handling model parallel weights.

        With a model parallel size greater than 1, every model parallel rank
        (on data parallel rank 0) writes its own weights next to ``save_path``.
        After a barrier, the rank with model parallel rank 0 and data parallel
        rank 0 gathers those files, the config and any artifacts into a
        temporary folder and builds the .nemo file from it. In every other
        case the parent implementation is used and its result returned.
        """
        app_state = AppState()
        uses_model_parallel = (
            app_state.model_parallel_size is not None and app_state.model_parallel_size > 1
        )
        if not uses_model_parallel:
            return super().save_to(model, save_path)

        dir_name = os.path.dirname(save_path)

        # Step 1: each model parallel rank dumps its own shard of the weights.
        if app_state.data_parallel_rank == 0:
            shard_name = f'mp_rank_{app_state.model_parallel_rank:02d}_' + self.model_weights_ckpt
            self._save_state_dict_to_disk(model.state_dict(), os.path.join(dir_name, shard_name))

        # Wait until every rank has finished writing.
        torch.distributed.barrier()

        # Step 2: a single rank assembles the .nemo file from all the shards.
        if not (app_state.model_parallel_rank == 0 and app_state.data_parallel_rank == 0):
            return

        with tempfile.TemporaryDirectory() as tmpdir:
            # move weights to the tmpdir, one sub-folder per model parallel rank
            for mp_rank in range(app_state.model_parallel_size):
                rank_dir = os.path.join(tmpdir, f'mp_rank_{mp_rank:02d}')
                os.makedirs(rank_dir)
                shard_path = os.path.join(dir_name, f'mp_rank_{mp_rank:02d}_' + self.model_weights_ckpt)
                shutil.move(shard_path, os.path.join(rank_dir, self.model_weights_ckpt))

            # create config and artifacts in tmpdir
            config_yaml = os.path.join(tmpdir, self.model_config_yaml)
            model.to_config_file(path2yaml_file=config_yaml)
            if hasattr(model, 'artifacts') and model.artifacts is not None:
                self._handle_artifacts(model, nemo_file_folder=tmpdir)
                self._update_artifact_paths(model, path2yaml_file=config_yaml)

            # create tar file
            self._make_nemo_file_from_folder(save_path, tmpdir)
Ejemplo n.º 12
0
def inject_model_parallel_rank(filepath):
    """
    Insert a model parallel rank folder between the directory and file name.

    The folder is ``mp_rank_XX`` when pipeline parallelism is unset or of
    size 1, and ``tp_rank_XX_pp_rank_XXX`` otherwise. The path is returned
    untouched when the model parallel size is unset or not greater than 1.
    """
    app_state = AppState()
    uses_model_parallel = (
        app_state.model_parallel_size is not None and app_state.model_parallel_size > 1
    )
    if not uses_model_parallel:
        return filepath

    # filepath needs to be updated to include the rank folder
    dirname = os.path.dirname(filepath)
    basename = os.path.basename(filepath)

    tensor_only = (
        app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1
    )
    if tensor_only:
        return f'{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}/{basename}'
    return f'{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_{app_state.pipeline_model_parallel_rank:03d}/{basename}'
Ejemplo n.º 13
0
    def save_to(self, save_path: str):
        """
        Save the model instance (weights and configuration) into a .nemo file.
        The "restore_from" method can be used to fully restore the instance
        from that file.

        A .nemo file is an archive (tar.gz) containing:
            model_config.yaml - model configuration in .yaml format, which can
                be deserialized into the cfg argument of the model's constructor
            model_weights.ckpt - model checkpoint

        Args:
            save_path: Path to .nemo file where model instance should be saved
        """
        save_path = os.path.abspath(save_path)

        if AppState().model_parallel_size is None:
            # super.save_to only runs on global rank 0
            return super().save_to(save_path)

        # Model parallel case: use the default save path handling directly.
        self._default_save_to(save_path)
Ejemplo n.º 14
0
 def training_step(self, batch, batch_idx):
     """Run a training step on a list of microbatch dictionaries.

     The microbatch calculator is reconfigured when the first microbatch's
     size differs from the configured ``micro_batch_size``. The microbatches
     are then merged into one global batch for the parent ``training_step``.
     """
     micro_batch_size = batch[0]['text_enc'].size(0)
     expected_size = self.cfg.data.train_ds.micro_batch_size

     # Per the original author this should happen only on the last batch of the dataset.
     if micro_batch_size != expected_size:
         dp_world_size = parallel_state.get_data_parallel_world_size()
         _reconfigure_microbatch_calculator(
             rank=AppState().global_rank,
             rampup_batch_size=None,
             global_batch_size=micro_batch_size * dp_world_size * get_num_microbatches(),
             micro_batch_size=micro_batch_size,
             data_parallel_size=dp_world_size,
         )

     # Here batch is a list of dictionaries, one per microbatch. The parent
     # class expects a single global batch dictionary, so merge them first.
     global_batch = self._process_global_batch(batch)
     return super().training_step(global_batch, batch_idx)
Ejemplo n.º 15
0
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """Initialize Megatron-LM model parallel groups if model parallelism is used.

        Existing groups are destroyed first. When ``torch.distributed`` is
        initialized they are rebuilt from the sizes stored in ``AppState``, and
        the resulting groups, data parallel rank and size are written back.

        Args:
            global_rank (int): the global process index (not read by this method).
            world_size (int): the total number of GPUs, num_nodes * num_devices
                (not read by this method).
        """
        app_state = AppState()

        # megatron-lm model parallel and data parallel groups are initialized
        # after DDP has been initialized with PTL.
        if app_state.model_parallel_size is None:
            return

        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if not torch.distributed.is_initialized():
            return

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
            pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
        )

        # assert that fake tp and pp rank match after model parallel init
        assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
        assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

        # Mirror the freshly created groups and data parallel info into AppState.
        mirrored = (
            ('tensor_model_parallel_group', parallel_state.get_tensor_model_parallel_group),
            ('data_parallel_group', parallel_state.get_data_parallel_group),
            ('data_parallel_rank', parallel_state.get_data_parallel_rank),
            ('data_parallel_size', parallel_state.get_data_parallel_world_size),
            ('pipeline_model_parallel_group', parallel_state.get_pipeline_model_parallel_group),
        )
        for attribute, getter in mirrored:
            setattr(app_state, attribute, getter())
Ejemplo n.º 16
0
    def on_validation_epoch_end(self):
        """Reconfigure the microbatch calculator once validation has finished.

        Training batch sizes are used when the instance has a ``_train_ds``
        attribute; otherwise the validation batch sizes are used and a warning
        is logged. The parent hook's result is returned.
        """
        app_state = AppState()

        if hasattr(self, "_train_ds"):
            ds_cfg = self.cfg.data.train_ds
        else:
            # When running `trainer.validate()`, the training dataset is not available.
            logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.')
            ds_cfg = self.cfg.data.validation_ds

        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=ds_cfg.global_batch_size,
            micro_batch_size=ds_cfg.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

        return super().on_validation_epoch_end()
Ejemplo n.º 17
0
    def save_checkpoint(self, filepath, weights_only: bool):
        """Slightly modified version of PyTorch Lightning's save_checkpoint.

        With model parallelism enabled, the checkpoint is written under an
        ``mp_rank_XX`` folder inserted before the file name, and only by
        processes whose data parallel rank is 0. Nothing is written when no
        model parallel size is configured.

        Args:
            filepath: destination path of the checkpoint file.
            weights_only (bool): forwarded to ``self.dump_checkpoint``.

        Returns:
            None
        """
        app_state = AppState()
        if app_state.model_parallel_size is None:
            return None

        # filepath needs to be updated to include mp_rank
        dirname = os.path.dirname(filepath)
        basename = os.path.basename(filepath)
        filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'

        # dump states as a checkpoint dictionary object (done on every rank)
        checkpoint = self.dump_checkpoint(weights_only)

        # each model parallel rank needs to save a copy of its model,
        # but only one data parallel replica writes it
        if app_state.data_parallel_rank != 0:
            return None

        backend = self.trainer.accelerator_backend
        if backend:
            checkpoint = backend.on_save(checkpoint)

        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            # Retry without the hyper parameters when they cannot be pickled.
            checkpoint.pop(LightningModule.CHECKPOINT_HYPER_PARAMS_KEY, None)
            rank_zero_warn(
                'Warning, `hyper_parameters` dropped from checkpoint.'
                f' An attribute is not picklable {err}')
            atomic_save(checkpoint, filepath)
        return None
Ejemplo n.º 18
0
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """Initialize Megatron-LM model parallel groups if model parallelism is used.

        Groups are created only when ``torch.distributed`` is initialized and
        ``AppState`` holds no data parallel group yet. The resulting groups,
        ranks and data parallel size are stored in ``AppState`` and logged.

        Args:
            global_rank (int): the global process index (not read by this method).
            world_size (int): the total number of GPUs, num_nodes * num_gpus
                (not read by this method).
        """
        app_state = AppState()

        # megatron-lm model parallel and data parallel groups are initialized
        # after DDP has been initialized with PTL.
        if app_state.model_parallel_size is None:
            return

        can_initialize = torch.distributed.is_initialized() and app_state.data_parallel_group is None
        if not can_initialize:
            return

        parallel_state.initialize_model_parallel(app_state.model_parallel_size)

        # Groups first, then the ranks and size derived from them.
        app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
        app_state.data_parallel_group = parallel_state.get_data_parallel_group()
        app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
        app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
        app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()

        logging.info(f'mp_rank: {app_state.model_parallel_rank}')
        logging.info(f'dp_rank: {app_state.data_parallel_rank}')
Ejemplo n.º 19
0
    def configure_ddp(self):
        """Override the DDP setup when model parallelism is used.

        With a model parallel size configured, the model is wrapped in
        ``DistributedDataParallel`` using the data parallel process group from
        ``AppState``. Otherwise the parent implementation is used.
        """
        app_state = AppState()

        if app_state.model_parallel_size is None:
            super().configure_ddp()
            return

        logging.info(f"Configuring DDP for model parallelism.")

        # With model parallelism, multiple GPUs form a large "logical GPU",
        # so data parallel groups span multiple GPUs and are non-trivial.
        device_ids = self.determine_ddp_device_ids()
        wrapped_module = LightningDistributedModule(self.model)
        self._model = DistributedDataParallel(
            wrapped_module,
            device_ids=device_ids,
            output_device=device_ids[0],
            process_group=app_state.data_parallel_group,
            **self._ddp_kwargs,
        )
Ejemplo n.º 20
0
 def training_step(self, batch, batch_idx):
     """Prepare the batch according to the dataset type, then run the parent step.

     'tarred' and 'text' datasets are squeezed, merged into a global batch and
     used to reconfigure the microbatch calculator. 'bin_memmap' and
     'text_memmap' datasets using the 'distributed' sampler are merged with
     the encoder tokenizer. Any other configuration passes the batch through.
     """
     train_cfg = self._cfg.train_ds
     dataset_type = train_cfg.dataset_type

     if dataset_type in ['tarred', 'text']:
         # These datasets are pre-batched and the dataloader is asked for batch
         # size 1, so the extra leading dim of 3-d tensors is squeezed away.
         squeezed = []
         for microbatch in batch:
             squeezed.append([x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch])
         batch = self.process_global_batch_for_tarred_datasets(squeezed)

         local_batch_size = batch['text_enc'].size(0)
         dp_world_size = parallel_state.get_data_parallel_world_size()
         _reconfigure_microbatch_calculator(
             rank=AppState().global_rank,
             rampup_batch_size=None,
             global_batch_size=local_batch_size * dp_world_size,
             micro_batch_size=local_batch_size,
             data_parallel_size=dp_world_size,
         )
     elif (dataset_type in ['bin_memmap', 'text_memmap']
           and train_cfg.get("sampler", "distributed") == 'distributed'):
         batch = self._process_global_batch_without_megatron_batch_sampler(
             batch, tokenizer=self.encoder_tokenizer)

     return super().training_step(batch, batch_idx)
Ejemplo n.º 21
0
    def _clip_gradients(self, optimizer, clip_val=None):
        """Override of PTL gradient clipping.

        Uses Megatron-LM's model parallel gradient clipping when a model
        parallel size is configured, and the stock ``Accelerator``
        implementation otherwise.

        Args:
            optimizer: forwarded to ``Accelerator._clip_gradients`` in the
                non model parallel case.
            clip_val: maximum gradient norm. Defaults to None, in which case
                the trainer's ``gradient_clip_val`` is used.

        Raises:
            ValueError: if model parallelism is configured but the model
                parallel groups have not been initialized.
        """
        app_state = AppState()

        # get clip_val from trainer if None is provided
        if clip_val is None:
            clip_val = float(self._trainer.gradient_clip_val)

        if app_state.model_parallel_size is None:
            return Accelerator._clip_gradients(self, optimizer, clip_val)

        model = self._trainer.get_model()
        parameters = model.parameters()
        if not mpu.model_parallel_is_initialized():
            raise ValueError('Model parallel groups must be intialized to use model parallel gradient clipping.')
        mpu.grads.clip_grad_norm(parameters=parameters, max_norm=clip_val)
Ejemplo n.º 22
0
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        """Initialize megatron model parallel state if the model has not done so.

        Args:
            cfg: configuration; ``tensor_model_parallel_size`` (default 1) and
                ``seed`` (default 1234) are read from it.
            trainer: PTL trainer providing the global rank, local rank and
                world size.
        """
        app_state = AppState()

        if not app_state._is_megatron_initialized:
            logging.info(
                f"Initializing megatron since it hasn't been initialized by the model. This is normal if you are using a NeMo model with Megatron dataloaders."
            )
            app_state.global_rank = trainer.global_rank
            app_state.world_size = trainer.world_size
            app_state.model_parallel_size = 1
            app_state.model_parallel_rank = trainer.global_rank

            initialize_model_parallel_for_nemo(
                world_size=trainer.world_size,
                global_rank=trainer.global_rank,
                local_rank=trainer.local_rank,
                tensor_model_parallel_size=cfg.get(
                    'tensor_model_parallel_size', 1),
                # Read the seed from the cfg argument: nothing in this
                # constructor has assigned self.cfg at this point.
                seed=cfg.get('seed', 1234),
            )
Ejemplo n.º 23
0
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """Initialize Megatron-LM model parallel groups if model parallelism is used.

        The model parallel and data parallel groups are created through
        ``mpu``; the groups and this process's rank within each are stored in
        ``AppState`` and logged.

        Args:
            global_rank (int): the global process index (not read by this method).
            world_size (int): the total number of GPUs, num_nodes * num_gpus
                (not read by this method).
        """
        app_state = AppState()

        # megatron-lm model parallel and data parallel groups are initialized
        # after DDP has been initialized with PTL.
        if app_state.model_parallel_size is None:
            return

        mpu.initialize_model_parallel(app_state.model_parallel_size)

        # Groups first: the ranks below are looked up within these groups.
        app_state.model_parallel_group = mpu.get_model_parallel_group()
        app_state.data_parallel_group = mpu.get_data_parallel_group()

        get_rank = torch.distributed.get_rank
        app_state.model_parallel_rank = get_rank(group=app_state.model_parallel_group)
        app_state.data_parallel_rank = get_rank(group=app_state.data_parallel_group)

        logging.info(f'mp_rank: {app_state.model_parallel_rank}')
        logging.info(f'dp_rank: {app_state.data_parallel_rank}')
Ejemplo n.º 24
0
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        """Initialize megatron state if needed and compile the dataset helpers.

        Args:
            cfg: configuration; ``tensor_model_parallel_size`` (default 1) and
                ``seed`` (default 1234) are read from it.
            trainer: PTL trainer providing the global rank, local rank and
                world size.

        Raises:
            ImportError: if the megatron dataset helper modules cannot be
                imported; the original error is attached as the cause.
        """
        app_state = AppState()

        if not app_state._is_megatron_initialized:
            logging.info(
                f"Initializing megatron since it hasn't been initialized by the model. This is normal if you are using a NeMo model with Megatron dataloaders."
            )
            app_state.global_rank = trainer.global_rank
            app_state.world_size = trainer.world_size
            app_state.tensor_model_parallel_size = 1
            app_state.tensor_model_parallel_rank = trainer.global_rank

            initialize_model_parallel_for_nemo(
                world_size=trainer.world_size,
                global_rank=trainer.global_rank,
                local_rank=trainer.local_rank,
                tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
                # Read the seed from the cfg argument: nothing in this
                # constructor has assigned self.cfg at this point.
                seed=cfg.get('seed', 1234),
            )

        try:
            from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper

            # Only global rank zero compiles; the other ranks wait at the barrier.
            if is_global_rank_zero():
                compile_helper()

            if torch.distributed.is_available() and torch.distributed.is_initialized():
                torch.distributed.barrier()

            # Imported only to verify that the compiled helpers are importable.
            from nemo.collections.nlp.data.language_modeling.megatron import helpers

            logging.info('Megatron dataset helper compiled successfully.')
        except ImportError as err:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError(
                f'Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file.'
            ) from err
Ejemplo n.º 25
0
    def inference_epoch_end(self, outputs, mode, data_cfg):
        """Aggregate per-dataloader losses and metrics at the end of a validation/test epoch.

        For every dataloader the loss is reduced through the parent class'
        ``validation_epoch_end``, the matching metric object is computed, logged and reset.
        When ``data_cfg.write_predictions_to_file`` is set, the predictions of all ranks are
        gathered, deduplicated and written to disk by global rank 0 only.

        Args:
            outputs: step outputs; either a list of dicts (single dataloader) or a list of such
                lists (one per dataloader). Each dict is read for 'loss' and, when predictions are
                written, for 'preds', 'labels', 'categories' and 'inputs'.
            mode: 'validation' selects the validation metric objects and log keys; 'test' selects
                the test ones (any other value uses the test metrics and skips the final logging).
            data_cfg: dataset config; read for ``write_predictions_to_file`` and
                ``output_file_path_prefix`` and forwarded to ``self._determine_log_key``.

        Returns:
            ``None`` when ``outputs`` is empty, otherwise the tuple
            ``(averaged_loss, averaged_metric)`` averaged over all dataloaders.

        Raises:
            ValueError: if predictions should be written but ``output_file_path_prefix`` is missing
                or ``None``.
        """
        # Parent class will handle logging of the loss.
        if not outputs:
            return
        # A single dataloader yields a flat list of step dicts; wrap it so the loop below can
        # always iterate per dataloader.
        if isinstance(outputs[0], dict):
            outputs = [outputs]

        averaged_loss = []
        averaged_metric = []
        metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name
        # Log metrics for each provided validation/test dataset.
        for dataloader_idx, output in enumerate(outputs):
            loss = super().validation_epoch_end([x['loss'] for x in output])
            # Determine the key used to log the loss based on the user provided name of the dataset or the dataloader index.
            loss_log_key = self._determine_log_key(data_cfg, dataloader_idx, "loss", mode)
            # Determine the key used to log the eval metric based on the user provided name of the dataset or the dataloader index.
            metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode)
            self.log(loss_log_key, loss)
            metric_object = (
                self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
            )
            metric = metric_object.compute()
            # Handle logging of GLUE/XNLI separately here. XNLI has a separate metric per language.
            if isinstance(metric, dict):
                # GLUE case:
                if len(metric) == 1 and 'acc' in metric:
                    metric = metric['acc']
                    self.log(metric_log_key, metric)
                    logging.info(f"{mode} {metric_name}: {metric}")
                # XNLI case where the metric dictionary contains the language and the computed metric as values.
                else:
                    for k, v in metric.items():
                        if k != 'acc' and 'total' not in k:
                            self.log(metric_log_key + f'_{k}', v)
                            logging.info(f"{mode} {metric_name} lang {k} : {v}")
                    # Only the overall accuracy takes part in the cross-dataset average.
                    metric = metric['acc']
            else:
                self.log(metric_log_key, metric)
                logging.info(f"{mode} {metric_name}: {metric}")
            metric_object.reset()

            averaged_loss.append(loss)
            averaged_metric.append(metric)

            # Write predictions, labels, and inputs to a file for each validation/test dataset.
            if data_cfg.get("write_predictions_to_file", False):

                # Check if the user provided a prefix path to the file(s) they want to write.
                if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None:
                    raise ValueError(
                        "Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file."
                    )

                # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks.
                # This is a collective call, so every rank has to take part in it.
                gathered_outputs = [None for _ in range(self.world_size)]
                torch.distributed.all_gather_object(
                    gathered_outputs,
                    [
                        {
                            'preds': x['preds'],
                            'labels': x['labels'],
                            'categories': x['categories'],
                            'inputs': x['inputs'],
                        }
                        for x in output
                    ],
                )

                # Figure out what the suffix of the file should be.
                filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode)

                # PTL models have a self.global_rank attribute and we want to write to disk only on global rank 0.
                # The write itself is kept under this guard: other ranks would otherwise write
                # an empty result to the very same path.
                if self.global_rank == 0:
                    # Keep a set of ground truths and inputs to write deduplicated predictions. Distributed Sampler may duplicate examples.
                    gt_inp_set = set()
                    deduplicated_outputs = {
                        'preds': [],
                        'labels': [],
                        'categories': [],
                        'inputs': [],
                    }
                    for rank_outputs in gathered_outputs:
                        for batch in rank_outputs:
                            for pred, label, inp, category in zip(
                                batch['preds'], batch['labels'], batch['inputs'], batch['categories']
                            ):
                                dedup_key = inp + label
                                if dedup_key not in gt_inp_set:
                                    gt_inp_set.add(dedup_key)
                                    deduplicated_outputs['preds'].append(pred)
                                    deduplicated_outputs['labels'].append(label)
                                    deduplicated_outputs['categories'].append(category)
                                    deduplicated_outputs['inputs'].append(inp)
                    self.write_predictions_to_file(
                        deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}"
                    )
                # All ranks wait here until rank 0 has finished writing.
                torch.distributed.barrier()

        # Logging of the averaged metrics:
        averaged_loss = sum(averaged_loss) / len(averaged_loss)
        averaged_metric = sum(averaged_metric) / len(averaged_metric)

        # Handle case where metrics can be nan or inf. This can break checkpoint save/load.
        # NOTE(review): torch.isinf/isnan expect a tensor — assumes the metric objects return tensors; confirm.
        if torch.isinf(averaged_metric) or torch.isnan(averaged_metric):
            app_state = AppState()
            monitor_mode = app_state.checkpoint_callback_params.mode
            assert monitor_mode in ['min', 'max']
            # Substitute the worst possible value for the monitored direction.
            averaged_metric = 0.0 if monitor_mode == 'max' else 1e5

        if mode == 'validation':
            self.log("validation_loss", averaged_loss)
            self.log(f"validation_{self.val_metric_name}", averaged_metric)
        elif mode == 'test':
            self.log("test_loss", averaged_loss)
            self.log(f"test_{self.test_metric_name}", averaged_metric)

        return averaged_loss, averaged_metric
    def complete(self, request: Dict):
        """
            Autoregressively invokes language model in the inference mode
        Args:
            request: Dictionary with the following fields
                * prompt: sequence whose first element is the prompt text; it is echoed back in the response.
                * masked_sample: encoder input token ids (presumably a batched tensor; only the first
                  sample is reported back as ``masked_input`` — confirm against the caller).
                * tokens_to_generate: how many tokens to generate while doing prompt completion.
        Returns:
            response: A python dictionary with the following fields
                * prompt: original text of the prompt
                * masked_input: tokens of the first encoder input sample joined by single spaces
                * completion: a python dictionary with the following subfields:
                    * tokens: a list of triples (token_id, token, log_prob) comprising completion
                    * text: completion text (as a single string)

        """
        app_state = AppState()

        # The complete method only works with global batch = micro batch size = data parallel size = 1.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,
            data_parallel_size=1,
        )

        response = {}
        self.freeze()
        # Unfreeze in a ``finally`` so a failure during decoding does not leave the model frozen.
        try:
            # naive greedy slow loop
            # TODO: add option for BeamSearchDecoder

            response['prompt'] = request['prompt'][0]
            response['completion'] = {}
            tokens_enc = request['masked_sample']

            response['masked_input'] = ' '.join(self.tokenizer.ids_to_tokens(tokens_enc[0].cpu().numpy().tolist()))
            enc_mask = tokens_enc != self.tokenizer.pad_id

            predicted_tokens_ids, log_probs = self.decode(tokens_enc, enc_mask, int(request['tokens_to_generate']))
            predicted_tokens_ids = predicted_tokens_ids.cpu().numpy()[0].tolist()
            log_probs = log_probs.cpu().numpy()[0].tolist()

            eos_id = self.tokenizer.eos_id
            if eos_id in predicted_tokens_ids:
                # Keep only what was generated before the first end-of-sequence token.
                predicted_tokens_ids = predicted_tokens_ids[: predicted_tokens_ids.index(eos_id)]
            else:
                pad_id = self.tokenizer.pad_id
                predicted_tokens_ids = [token_id for token_id in predicted_tokens_ids if token_id != pad_id]

            # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match.
            if hasattr(self.tokenizer, 'special_token_to_id'):
                special_ids = set(self.tokenizer.special_token_to_id.values())
                predicted_tokens_ids = [token_id for token_id in predicted_tokens_ids if token_id not in special_ids]

            predicted_tokens_dec = self.tokenizer.ids_to_tokens(predicted_tokens_ids)
            response['completion']['text'] = self.tokenizer.tokens_to_text(predicted_tokens_dec)
            # NOTE(review): log_probs is not filtered together with the token ids, so once pad/special
            # tokens were removed a triple may carry the log-prob of a different position — confirm
            # against what ``decode`` returns.
            response['completion']['tokens'] = list(zip(predicted_tokens_ids, predicted_tokens_dec, log_probs))
        finally:
            self.unfreeze()
        return response
Ejemplo n.º 27
0
    def setup(self, stage: str) -> None:
        """ PTL hook that is called after DDP is initialized.
            Called at the beginning of fit and test.

        Args:
            stage (str): either 'fit' or 'test'
        """
        # TODO: implement model parallel for test stage
        if stage == 'fit':
            # set find_unused_parameters to True by default for NLP models
            if isinstance(self.trainer.accelerator.training_type_plugin,
                          DDPPlugin):
                self.trainer.accelerator.training_type_plugin._ddp_kwargs[
                    'find_unused_parameters'] = True

            # adds self.bert_model config to .nemo file
            if hasattr(self, 'bert_model') and self.bert_model is not None:
                self.register_bert_model()

            app_state = AppState()

            if app_state.model_parallel_size is not None:

                if app_state.model_parallel_group is None:
                    self.init_model_parallel(app_state.global_rank,
                                             app_state.world_size)

                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self._trainer.get_model().parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                # Update PTL trainer to use our configure_ddp
                self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp
                # Update PTL trainer to use our _clip_gradients
                self._trainer.accelerator_backend._clip_gradients = self._clip_gradients
                self._trainer.checkpoint_connector = NLPCheckpointConnector(
                    self._trainer)

                # Configure checkpointing for model parallel
                if app_state.create_checkpoint_callback:
                    # global rank 0 is configured by exp_manager
                    if not is_global_rank_zero(
                    ) and app_state.data_parallel_rank == 0:
                        configure_checkpointing(
                            self._trainer,
                            app_state.log_dir,
                            app_state.checkpoint_name,
                            app_state.checkpoint_callback_params,
                        )

                if isinstance(self.bert_model, MegatronBertEncoder):
                    self.bert_model.complete_lazy_init()

                    # model parallel checkpoints need to be restored after torch.distributed is initialized
                    if self._trainer.resume_from_checkpoint is not None:
                        # update path based on model parallel rank
                        filepath = self._trainer.resume_from_checkpoint
                        dirname = os.path.dirname(os.path.dirname(filepath))
                        basename = os.path.basename(filepath)
                        filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                        self._trainer.resume_from_checkpoint = filepath
                        logging.info(
                            f'Resuming training from checkpoint {self._trainer.resume_from_checkpoint}'
                        )
                        # need to set checkpoint version for megatron-lm
                        checkpoint_version = torch.load(
                            self._trainer.resume_from_checkpoint).get(
                                'checkpoint_version', None)
                        if checkpoint_version is not None:
                            set_checkpoint_version(checkpoint_version)
                        else:
                            logging.warning(
                                'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                            )
                            set_checkpoint_version(0)
                    else:
                        logging.info(
                            f"Restoring from pretrained model parallel checkpoint: {self.bert_model._restore_path}"
                        )
                        self.bert_model.restore_weights(
                            self.bert_model._restore_path)

                    logging.info(
                        "Replacing sampler with model parallel sampler")
                    mp_sampler = torch.utils.data.distributed.DistributedSampler(
                        self._train_dl.dataset,
                        num_replicas=app_state.data_parallel_size,
                        rank=app_state.data_parallel_rank,
                    )
                    mp_dl = self._trainer.replace_sampler(
                        self._train_dl, mp_sampler)
                    self._train_dl = mp_dl
                else:
                    raise NotImplementedError(
                        f'The BERT encoder: {self.bert_model} does not support model parallelism yet.'
                    )
Ejemplo n.º 28
0
def get_lm_model(
    config_dict: Optional[dict] = None,
    config_file: Optional[str] = None,
    vocab_file: Optional[str] = None,
    trainer: Optional[Trainer] = None,
    cfg: DictConfig = None,
) -> BertModule:
    """
    Helper function to instantiate a language model encoder, either from scratch or a pretrained model.

    A Megatron BERT model is built when the tokenizer name contains "megatron" or when
    ``cfg.language_model.pretrained_model_name`` is one of the pretrained Megatron BERT models;
    its pretraining heads and pooler are replaced by identity modules. Otherwise the model is
    created through ``get_huggingface_lm_model`` and, when ``cfg.language_model.lm_checkpoint``
    is set, its weights are restored from that checkpoint.

    Args:
        config_dict: model configuration dictionary, forwarded to the HuggingFace helper
        config_file: path to the model configuration file, forwarded to the HuggingFace helper
        vocab_file: path to vocab_file to be used with Megatron-LM (not read by this function)
        trainer: an instance of a PyTorch Lightning trainer, passed to the Megatron model loaders
        cfg: a model configuration; required, read for its ``language_model`` and ``tokenizer`` entries
    Returns:
        The instantiated language model encoder
    Raises:
        ValueError: if ``cfg`` is not provided, or if a HuggingFace model is requested with an
            ``lm_checkpoint`` that does not exist on disk while no model is being restored.
    """
    # cfg defaults to None only for signature compatibility; every code path needs it.
    if cfg is None:
        raise ValueError('cfg must be provided to get_lm_model: language_model and tokenizer settings are read from it.')

    lm_cfg = cfg.language_model
    pretrained_model_name = lm_cfg.get('pretrained_model_name')

    # check valid model type
    if pretrained_model_name and pretrained_model_name not in get_pretrained_lm_models_list(include_external=False):
        logging.warning(
            f'{pretrained_model_name} is not in get_pretrained_lm_models_list(include_external=False), '
            f'will be using AutoModel from HuggingFace.')

    # warning when user passes both configuration dict and file
    if config_dict and config_file:
        logging.warning(
            f"Both config_dict and config_file were found, defaulting to use config_file: {config_file} will be used."
        )

    all_pretrained_megatron_bert_models = get_megatron_pretrained_bert_models()
    tokenizer_name = cfg.tokenizer.get("tokenizer_name", "") if cfg.tokenizer is not None else None
    uses_megatron_tokenizer = tokenizer_name is not None and "megatron" in tokenizer_name

    if uses_megatron_tokenizer or (pretrained_model_name or '') in all_pretrained_megatron_bert_models:
        import torch

        from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel

        class Identity(torch.nn.Module):
            """Pass-through module that returns its first input unchanged."""

            def __init__(self):
                super(Identity, self).__init__()

            def forward(self, x, *args):
                return x

        if lm_cfg.get("lm_checkpoint"):
            model = MegatronBertModel.restore_from(restore_path=lm_cfg.lm_checkpoint, trainer=trainer)
        else:
            model = MegatronBertModel.from_pretrained(pretrained_model_name, trainer=trainer)
        # remove the headers that are only relevant for pretraining
        model.model.lm_head = Identity()
        model.model.binary_head = Identity()
        model.model.language_model.pooler = Identity()

    else:
        model = get_huggingface_lm_model(
            config_dict=config_dict,
            config_file=config_file,
            pretrained_model_name=lm_cfg.pretrained_model_name,
        )

        if lm_cfg.get("lm_checkpoint"):
            app_state = AppState()
            # While a model is being restored the checkpoint path is not required to exist on disk.
            if not app_state.is_model_being_restored and not os.path.exists(lm_cfg.lm_checkpoint):
                raise ValueError(f'{lm_cfg.lm_checkpoint} not found')
            model.restore_weights(restore_path=lm_cfg.lm_checkpoint)

    return model
Ejemplo n.º 29
0
 def compute_consumed_samples(self, global_step):
     """Return the number of samples consumed after ``global_step`` steps.

     The count is the product of the step count, the data parallel size taken from
     ``AppState``, the configured micro batch size and the trainer's gradient
     accumulation factor, truncated to ``int``.
     """
     data_parallel_size = AppState().data_parallel_size
     micro_batch_size = self.cfg.micro_batch_size
     grad_accumulation = self.trainer.accumulate_grad_batches
     return int(global_step * data_parallel_size * micro_batch_size * grad_accumulation)
Ejemplo n.º 30
0
 def is_model_parallel_initialized(self):
     """Return ``True`` when ``AppState`` holds a model parallel group, ``False`` otherwise."""
     return AppState().model_parallel_group is not None