Example #1
    def on_validation_epoch_end(self):
        app_state = AppState()
        if hasattr(self, "_train_ds"):
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=self.cfg.data.train_ds.global_batch_size,
                micro_batch_size=self.cfg.data.train_ds.micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        # When running `trainer.validate()`, the training dataset is not available.
        else:
            logging.warning(
                'No training data found, reconfiguring microbatches based on validation batch sizes.'
            )
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=self.cfg.data.validation_ds.global_batch_size,
                micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

        return super().on_validation_epoch_end()
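
A note on the reconfiguration pattern above (and repeated throughout these examples): the microbatch calculator only has to keep one invariant, global_batch_size == micro_batch_size * num_microbatches * data_parallel_size. A minimal, self-contained sketch of that arithmetic (an illustration, not the apex implementation):

def num_microbatches(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    """Number of microbatches each data-parallel rank runs per step (illustrative only)."""
    per_rank_batch = global_batch_size // data_parallel_size
    assert per_rank_batch % micro_batch_size == 0, "global batch must split evenly into microbatches"
    return per_rank_batch // micro_batch_size

# Example: global batch 128, micro batch 4, 8 data-parallel ranks -> 4 microbatches per rank.
assert num_microbatches(128, 4, 8) == 4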
Example #2
    def build_pretraining_data_loader(self, dataset, consumed_samples):
        """Buld dataloader given an input dataset."""

        if dataset is None:
            return None

        logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
        # Megatron sampler
        if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
            if self.cfg.data.dataloader_type == 'single':
                batch_sampler = MegatronPretrainingSampler(
                    total_samples=len(dataset),
                    consumed_samples=consumed_samples,
                    micro_batch_size=self.cfg.micro_batch_size,
                    data_parallel_rank=parallel_state.get_data_parallel_rank(),
                    data_parallel_size=parallel_state.get_data_parallel_world_size(),
                )
            elif self.cfg.data.dataloader_type == 'cyclic':
                batch_sampler = MegatronPretrainingRandomSampler(
                    total_samples=len(dataset),
                    consumed_samples=consumed_samples,
                    micro_batch_size=self.cfg.micro_batch_size,
                    data_parallel_rank=parallel_state.get_data_parallel_rank(),
                    data_parallel_size=parallel_state.get_data_parallel_world_size(),
                )
            else:
                raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
        else:
            raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"')

        # Torch dataloader.
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, num_workers=self.cfg.data.num_workers, pin_memory=True,
        )
Example #3
    def build_pretraining_data_loader(self, dataset, consumed_samples):
        """Buld dataloader given an input dataset."""

        if dataset is None:
            return None

        # Megatron sampler
        if self._cfg.data.dataloader_type == 'single':
            batch_sampler = MegatronPretrainingSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=self._cfg.micro_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        elif self._cfg.data.dataloader_type == 'cyclic':
            batch_sampler = MegatronPretrainingRandomSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=self._cfg.micro_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        else:
            raise Exception('{} dataloader type is not supported.'.format(self._cfg.data.dataloader_type))

        # Torch dataloader.
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, num_workers=self._cfg.data.num_workers, pin_memory=True,
        )
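
The 'single' and 'cyclic' sampler classes used above come from Megatron/NeMo. As a rough mental model only, here is a hypothetical, simplified stand-in for the 'single' (sequential) case: walk the dataset from consumed_samples onward and hand each data-parallel rank a contiguous micro-batch slice of every global step.

class SimpleSequentialBatchSampler:
    """Hypothetical simplification of a 'single'-style pretraining batch sampler (not the NeMo class)."""

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.rank_offset = data_parallel_rank * micro_batch_size
        self.micro_batch_times_dp = micro_batch_size * data_parallel_size

    def __iter__(self):
        batch = []
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_dp:
                # Each data-parallel rank keeps only its contiguous micro-batch slice.
                yield batch[self.rank_offset:self.rank_offset + self.micro_batch_size]
                batch = []

# With 2 data-parallel ranks and micro_batch_size=2, rank 0 yields [0, 1], [4, 5], ...
# and rank 1 yields [2, 3], [6, 7], ...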
Example #4
        def param_hook(*unused):
            """Gradient accumulation and all-reduce."""
            if param.grad is None:
                return
            if main_param.grad is None:
                main_param.grad = param.grad.float()
            else:
                main_param.grad.add_(param.grad.data)
            # Deallocate grad memory.
            param.grad = None

            # Asynchronous gradients allreduce across data_parallel ranks
            if self._require_backward_grad_sync:
                if self._grad_allreduce_chunk_size_mb > 0:
                    self._main_grad_buffers[i].update_chunk_info(grad_chunk_info)
                    while True:
                        allreduce_tensor = self._main_grad_buffers[i].get_allreduce_tensor()
                        if allreduce_tensor is None:
                            break
                        allreduce_tensor.div_(get_data_parallel_world_size())
                        torch.distributed.all_reduce(
                            allreduce_tensor,
                            group=get_data_parallel_group(),
                            async_op=True)
                else:
                    main_param.grad.div_(get_data_parallel_world_size())
                    torch.distributed.all_reduce(  # type: ignore
                        main_param.grad,
                        group=get_data_parallel_group(),
                        async_op=True)
Example #5
    def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
        batch_has_lang_information = len(batch[0]) == 7

        micro_batch_size = batch[0]['text_enc'].size(0)
        # This should happen only on the last batch of the dataset.
        if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=micro_batch_size
                * parallel_state.get_data_parallel_world_size()
                * get_num_microbatches(),
                micro_batch_size=micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

        # At this point, batch is a list of dictionaries where each dict is a microbatch.
        # After the _process_global_batch call, processed_batch will be a single dictionary containing the global batch.
        # This is required since the parent class expects a single global batch dictionary.
        processed_batch = self._process_global_batch(batch)

        # Call parent validation step to get the loss.
        # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI, this will be ignored in the parent class.
        loss = super().validation_step(processed_batch, batch_idx)

        predicted_token_ids, _ = self.decode(
            tokens_enc=processed_batch['text_enc'], enc_mask=processed_batch['enc_mask'], num_tokens_to_generate=30
        )

        # Special ids to text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
        preds_text = self.ids_to_text(predicted_token_ids)
        labels_text = self.ids_to_text(processed_batch['labels'])
        input_text = self.ids_to_text(processed_batch['text_enc'])

        if not batch_has_lang_information:
            categories = [None] * len(preds_text)
        else:
            categories = processed_batch['lang']

        metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
        assert len(categories) == len(preds_text) == len(labels_text)
        for pred, label, category in zip(preds_text, labels_text, categories):
            # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats.
            pred, label = self.cast_for_metric(
                pred, label, self.val_metric_name if mode == 'validation' else self.test_metric_name
            )
            if batch_has_lang_information:
                _ = metric(pred, label, category)
            else:
                _ = metric(pred, label)

        return {
            'loss': loss,
            'preds': preds_text,
            'labels': labels_text,
            'categories': categories,
            'inputs': input_text,
        }
Example #6
    def eval_step(self, batch, batch_idx, dataloader_idx, data_cfg):
        # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
        batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch]
                 for microbatch in batch]
        batch = self.process_global_batch_for_tarred_datasets(batch)
        if data_cfg.dataset_type in ['tarred', 'text']:
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
                micro_batch_size=batch['text_enc'].size(0),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        # This returns the averaged loss across data-parallel groups.
        reduced_loss = super().validation_step(batch, batch_idx)
        tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask']
        predicted_tokens_ids, _ = self.decode(
            tokens_enc,
            enc_mask,
            tokens_enc.size(1) + self._cfg.max_generation_delta,  # Generate up to src-length + max generation delta. TODO: Implement better stopping when everything hits <EOS>.
            tokenizer=self.decoder_tokenizer,
        )
        if self.multilingual:
            source_processor = self.source_processor_list[dataloader_idx]
            target_processor = self.target_processor_list[dataloader_idx]
        else:
            source_processor = self.source_processor
            target_processor = self.target_processor

        # Post-process the translations and inputs to log.
        preds = self.postprocess_outputs(
            outputs=predicted_tokens_ids,
            tokenizer=self.decoder_tokenizer,
            processor=target_processor,
        )
        labels = self.postprocess_outputs(
            outputs=labels,
            tokenizer=self.decoder_tokenizer,
            processor=target_processor,
        )
        encoder_inputs = self.postprocess_outputs(
            outputs=tokens_enc,
            tokenizer=self.encoder_tokenizer,
            processor=source_processor,
        )

        return {
            'inputs': encoder_inputs,
            'translations': preds,
            'ground_truths': labels,
            'loss': reduced_loss,
        }
Example #7
    def _build_train_dataset(self, data_cfg):
        """Build the training dataset."""
        if (data_cfg.drop_last is False
                and data_cfg.global_batch_size > data_cfg.micro_batch_size *
                parallel_state.get_data_parallel_world_size()):
            raise ValueError(
                f"Cannot use drop_last=False in your training data with gradient accumulation found grad acc of {data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} with global_batch_size {data_cfg.global_batch_size}, micro_batch_size {data_cfg.micro_batch_size}, data parallel size {parallel_state.get_data_parallel_world_size()}"
            )
        datasets = []
        # Determine if we are using a single dataset or a list of datasets.
        is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
        is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)

        if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
            raise ValueError(
                "src_file_name and tgt_file_name must both be either a ListConfig or a string."
            )
        if is_src_list_config:
            if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
                raise ValueError(
                    "src_file_name and tgt_file_name must have the same number of elements. "
                )
        else:
            data_cfg.src_file_name = [data_cfg.src_file_name]
            data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

        for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
            dataset = SequenceToSequenceDataset(
                src_file_name=src,
                tgt_file_name=tgt,
                src_tokenizer=self.tokenizer,
                tgt_tokenizer=self.tokenizer,
                max_src_seq_length=data_cfg.max_src_seq_length,
                max_tgt_seq_length=data_cfg.max_tgt_seq_length,
            )
            datasets.append(dataset)

        if len(datasets) > 1:
            dataset = ConcatDataset(
                datasets=datasets,
                sampling_technique=data_cfg.get('concat_sampling_technique',
                                                'temperature'),
                sampling_temperature=data_cfg.get(
                    'concat_sampling_temperature', 5),
                sampling_probabilities=data_cfg.get(
                    'concat_sampling_probabilities',
                    [1 / len(datasets)] * len(datasets)),
                global_rank=parallel_state.get_data_parallel_rank(),
                world_size=parallel_state.get_data_parallel_world_size(),
            )
            return dataset
        else:
            return datasets[0]
Example #8
    def _setup_megatron_dataloader_from_config(self, cfg, dataset,
                                               consumed_samples):
        logging.info(
            f'Building dataloader with consumed samples: {consumed_samples}')
        rank = parallel_state.get_data_parallel_rank()
        world_size = parallel_state.get_data_parallel_world_size()
        if isinstance(dataset, BlendableDataset):
            collate_fn = dataset.datasets[0].collate_fn
        else:
            collate_fn = dataset.collate_fn

        if cfg.get("sampler", "distributed") == 'distributed':
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=True,
                seed=consumed_samples,  # Ensures that each time the model is restored, a new seed is used to see examples in a different order.
            )
            return torch.utils.data.DataLoader(
                dataset,
                collate_fn=collate_fn,
                sampler=sampler,
                batch_size=cfg.micro_batch_size,
                num_workers=cfg.num_workers,
                pin_memory=cfg.pin_memory,
                drop_last=cfg.drop_last,
            )
        elif cfg.get("sampler", "distributed") == 'megatron':
            batch_sampler = MegatronPretrainingBatchSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=cfg.micro_batch_size,
                global_batch_size=cfg.global_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(
                ),
                drop_last=True,
            )
            return torch.utils.data.DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=collate_fn,
                num_workers=cfg.num_workers,
                pin_memory=cfg.pin_memory,
            )
        else:
            raise ValueError(
                f"Invalid sampler {cfg.sampler}. Options: ['distributed', 'megatron']"
            )
Example #9
 def training_step(self, batch, batch_idx):
      # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
     batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch]
              for microbatch in batch]
     batch = self.process_global_batch_for_tarred_datasets(batch)
     app_state = AppState()
     _reconfigure_microbatch_calculator(
         rank=app_state.global_rank,
         rampup_batch_size=None,
         global_batch_size=batch['text_enc'].size(0) *
         parallel_state.get_data_parallel_world_size(),
         micro_batch_size=batch['text_enc'].size(0),
         data_parallel_size=parallel_state.get_data_parallel_world_size(),
     )
     return super().training_step(batch, batch_idx)
Example #10
 def training_step(self, batch, batch_idx):
     micro_batch_size = batch[0]['text_enc'].size(0)
     # This should happen only on the last batch of the dataset.
     if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=micro_batch_size *
             parallel_state.get_data_parallel_world_size() *
             get_num_microbatches(),
             micro_batch_size=micro_batch_size,
             data_parallel_size=parallel_state.get_data_parallel_world_size(
             ),
         )
     return super().training_step(batch, batch_idx)
Example #11
    def build_virtual_prompt_dataset(self, dataset_paths, batch_size,
                                     for_train, drop_last, shuffle,
                                     num_workers, pin_memory):
        dataset = GPTPromptLearningDataset(
            datasets=dataset_paths,
            tokenizer=self.tokenizer,
            virtual_prompt_source=self.virtual_prompt_source,
            task_templates=self.task_templates,
            pseudo_tokens=self.pseudo_tokens,
            pad_token_id=self.pad_token_id,
            max_seq_length=self.cfg.data.get(
                'max_seq_length',
                self.frozen_model.cfg.max_position_embeddings),
            min_seq_length=self.cfg.data.get('min_seq_length', 1),
            add_bos=self.cfg.data.get('add_bos', False),
            add_eos=self.cfg.data.get('add_eos', True),
            for_train=for_train,
        )

        rank = parallel_state.get_data_parallel_rank()
        world_size = parallel_state.get_data_parallel_world_size()
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

        dataloader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=dataset.collate_fn,
            sampler=sampler,
            batch_size=batch_size,
            drop_last=drop_last,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

        return dataset, dataloader
Example #12
 def build_train_valid_test_datasets(self):
     self._train_ds = MTEncDecModel._setup_dataset_from_config(
         cfg=self._cfg.train_ds,
         encoder_tokenizer=self.encoder_tokenizer,
         decoder_tokenizer=self.decoder_tokenizer,
         global_rank=parallel_state.get_data_parallel_rank(),
         world_size=parallel_state.get_data_parallel_world_size(),
         multilingual=self.multilingual,
         multilingual_ids=self.multilingual_ids,
     )
     self._validation_ds = MTEncDecModel._setup_eval_dataset_from_config(
         cfg=self._cfg.validation_ds,
         multilingual=self.multilingual,
         multilingual_ids=self.multilingual_ids,
         encoder_tokenizer=self.encoder_tokenizer,
         decoder_tokenizer=self.decoder_tokenizer,
     )
     # Test data config is optional.
     if hasattr(self._cfg, 'test_ds'):
         self._test_ds = MTEncDecModel._setup_eval_dataset_from_config(
             cfg=self._cfg.test_ds,
             multilingual=self.multilingual,
             multilingual_ids=self.multilingual_ids,
             encoder_tokenizer=self.encoder_tokenizer,
             decoder_tokenizer=self.decoder_tokenizer,
         )
Example #13
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """ Initializes Megatron-LM model parallel if using model parallelism.

        Args:
            global_rank (int): the global process index.
            world_size (int): the total number of GPUs, num_nodes * num_gpus
        """
        app_state = AppState()

        # we initialize megatron-lm model parallel and data parallel groups
        # after initializing DDP with PTL.
        if app_state.model_parallel_size is not None:
            if torch.distributed.is_initialized():
                parallel_state.initialize_model_parallel(app_state.model_parallel_size)
                app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
                app_state.data_parallel_group = parallel_state.get_data_parallel_group()
                app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
                app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
                app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
                logging.info(f'mp_rank: {app_state.model_parallel_rank}')
                logging.info(f'dp_rank: {app_state.data_parallel_rank}')
Example #14
    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks.
           Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/model/distributed.py#L188
        """
        # Bucketize and all-reduce
        buckets = {}
        for param in self.parameters():
            if param.requires_grad and param.grad is not None:
                tp = param.data.type()
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(param)
                # param.main_grad = param.grad

        # For each bucket, all-reduce and copy all-reduced grads.
        for tp in buckets:
            bucket = buckets[tp]
            grads = [param.grad.data for param in bucket]
            coalesced = torch._utils._flatten_dense_tensors(grads)
            coalesced /= parallel_state.get_data_parallel_world_size()
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_data_parallel_group())
            for buf, synced in zip(
                    grads,
                    torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
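
The bucketed all-reduce above relies on torch's flatten/unflatten helpers. Below is a single-process sketch of just the flatten, average, and copy-back steps; the all_reduce itself is omitted because it needs an initialized process group, and the divisor here simply stands in for the data-parallel world size.

import torch

# Two gradients of different shapes, as they might appear in one dtype bucket.
grads = [torch.ones(2, 2), torch.arange(4.0)]
coalesced = torch._utils._flatten_dense_tensors(grads)
coalesced /= 4  # stand-in for dividing by parallel_state.get_data_parallel_world_size()
# torch.distributed.all_reduce(coalesced, group=...) would go here in the real code.
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)  # write the averaged values back into the original grad tensors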
Example #15
 def training_step(self, batch, batch_idx):
     micro_batch_size = batch[0]['text_enc'].size(0)
     # This should happen only on the last batch of the dataset.
     if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=micro_batch_size
             * parallel_state.get_data_parallel_world_size()
             * get_num_microbatches(),
             micro_batch_size=micro_batch_size,
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
         )
      # At this point, batch is a list of dictionaries where each dict is a microbatch.
      # After the _process_global_batch call, batch will be a single dictionary containing the global batch.
      # This is required since the parent class expects a single global batch dictionary.
     batch = self._process_global_batch(batch)
     return super().training_step(batch, batch_idx)
Example #16
 def build_tarred_train_dataset(self):
     return MTEncDecModel._setup_dataset_from_config(
         cfg=self._cfg.train_ds,
         encoder_tokenizer=self.encoder_tokenizer,
         decoder_tokenizer=self.decoder_tokenizer,
         global_rank=parallel_state.get_data_parallel_rank(),
         world_size=parallel_state.get_data_parallel_world_size(),
         multilingual=self.multilingual,
         multilingual_ids=self.multilingual_ids,
     )
Example #17
 def on_validation_epoch_start(self):
     app_state = AppState()
     _reconfigure_microbatch_calculator(
         rank=app_state.global_rank,
         rampup_batch_size=None,
         global_batch_size=self.cfg.data.validation_ds.global_batch_size,
         micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
         data_parallel_size=parallel_state.get_data_parallel_world_size(),
     )
     return super().on_validation_epoch_start()
Example #18
 def on_validation_epoch_end(self):
     app_state = AppState()
     if hasattr(self, "_train_ds"):
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=self._cfg.train_ds.global_batch_size,
             micro_batch_size=self._cfg.train_ds.micro_batch_size,
             data_parallel_size=parallel_state.get_data_parallel_world_size(
             ),
         )
Example #19
    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        for data_parallel_size in range(1, self.world_size + 1):

            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
            if rampup_batch_size:
                expected_global_batch_size = rampup_batch_size[0]
                num_consumed_samples = 0
                step_of_global_batch_size = rampup_batch_size[1]
                threshold = rampup_batch_size[2]

            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            with self.subTest(data_parallel_size=data_parallel_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=self.world_size //
                    data_parallel_size,
                    pipeline_model_parallel_size_=1,
                )
                self.assertEqual(data_parallel_size,
                                 parallel_state.get_data_parallel_world_size())

                _reconfigure_microbatch_calculator(
                    self.rank,
                    rampup_batch_size,
                    self.GLOBAL_BATCH_SIZE,
                    self.MICRO_BATCH_SIZE,
                    data_parallel_size,
                )

                self.assertEqual(get_micro_batch_size(),
                                 expected_micro_batch_size)
                self.assertEqual(
                    get_num_microbatches(), expected_global_batch_size /
                    expected_micro_batch_size / data_parallel_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(current_global_batch_size,
                                 expected_global_batch_size)

                # Make sure `global_batch_size` equals to the final global batch size after
                # certain number of updates.
                if rampup_batch_size:
                    update_num_microbatches(current_global_batch_size)
                    for i in range(100):
                        current_global_batch_size = get_current_global_batch_size()
                        update_num_microbatches(current_global_batch_size)
                    current_global_batch_size = get_current_global_batch_size()
                    self.assertEqual(get_current_global_batch_size(),
                                     self.GLOBAL_BATCH_SIZE)
                parallel_state.destroy_model_parallel()
Example #20
 def training_step(self, batch, batch_idx):
      # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
     if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
         batch = [[
             x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch
         ] for microbatch in batch]
         batch = self.process_global_batch_for_tarred_datasets(batch)
     elif (self._cfg.train_ds.dataset_type in ['bin_memmap', 'text_memmap']
           and self._cfg.train_ds.get("sampler",
                                      "distributed") == 'distributed'):
         batch = self._process_global_batch_without_megatron_batch_sampler(
             batch, tokenizer=self.encoder_tokenizer)
     if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=batch['text_enc'].size(0) *
             parallel_state.get_data_parallel_world_size(),
             micro_batch_size=batch['text_enc'].size(0),
             data_parallel_size=parallel_state.get_data_parallel_world_size(
             ),
         )
     return super().training_step(batch, batch_idx)
Example #21
        def param_hook(*unused):
            # Accumulates gradients on main gradients
            if param.grad is not None:
                if main_param.grad is None:
                    main_param.grad = param.grad.float()
                else:
                    main_param.grad.add_(param.grad.data)
                # Deallocate grad memory.
                param.grad = None

                # Asynchronous gradients allreduce across data_parallel ranks
                if self._require_backward_grad_sync:
                    main_param.grad.div_(get_data_parallel_world_size())
                    torch.distributed.all_reduce(
                        main_param.grad,
                        group=get_data_parallel_group(),
                        async_op=True)
Example #22
    def build_data_loader(
        self,
        dataset,
        micro_batch_size,
        global_batch_size,
        shuffle,
        num_workers,
        pin_memory,
        drop_last,
        check_validation_interval,
    ):
        """Buld dataloader given an input dataset."""

        if dataset is None:
            return None

        rank = parallel_state.get_data_parallel_rank()
        world_size = parallel_state.get_data_parallel_world_size()
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
        )
        # This check makes sure the val_check_interval is less than the number of global batches.
        # Normally, PTL would do this check and properly account for gradient accumulation.
        # But now, it is implicit in the apex fwd/bwd functions and so we need to check for this somewhere.
        # The consequence of not doing this is that training loop will never run validation.
        # NOTE: Prog bar is also broken as a result of this.
        global_batch_size_per_gpu = micro_batch_size * get_num_microbatches()
        if (
            self.trainer.val_check_interval > (sampler.num_samples // global_batch_size_per_gpu)
            and check_validation_interval
        ):
            raise ValueError(
                f"trainer.val_check_interval {self.trainer.val_check_interval} is > number of global batches {sampler.num_samples // global_batch_size}"
            )

        # Data loader. Note that batch size is the per GPU batch size.
        return torch.utils.data.DataLoader(
            dataset,
            collate_fn=dataset.collate_fn,
            sampler=sampler,
            batch_size=micro_batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
        )
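
A worked example of the val_check_interval guard above, with hypothetical sizes (the two names below stand in for get_num_microbatches() and sampler.num_samples):

micro_batch_size = 4
num_microbatches = 8            # stands in for get_num_microbatches()
samples_on_this_rank = 10_000   # stands in for sampler.num_samples
global_batch_size_per_gpu = micro_batch_size * num_microbatches
# 10_000 // 32 == 312 global batches per epoch on this rank, so
# trainer.val_check_interval must not exceed 312 for validation to ever run.
print(samples_on_this_rank // global_batch_size_per_gpu)  # 312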
Example #23
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """ Initializes Megatron-LM model parallel if using model parallelism.

        Args:
            global_rank (int): the global process index.
            world_size (int): the total number of GPUs, num_nodes * num_devices
        """
        app_state = AppState()

        # we initialize megatron-lm model parallel and data parallel groups
        # after initializing DDP with PTL.
        if app_state.model_parallel_size is not None:
            # destroy groups in case they have already been created
            # this happens with multiple calls to trainer.test for example
            parallel_state.destroy_model_parallel()
            if torch.distributed.is_initialized():
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                    pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                    pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
                )

                # assert that fake tp and pp rank match after model parallel init
                assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
                assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

                app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
                app_state.data_parallel_group = parallel_state.get_data_parallel_group()
                app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
                app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
                app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
Example #24
    def _setup_eval_dataloader_from_config(self, cfg: DictConfig, dataset):

        rank = parallel_state.get_data_parallel_rank()
        world_size = parallel_state.get_data_parallel_world_size()
        dataloaders = []
        for _dataset in dataset:
            sampler = torch.utils.data.distributed.DistributedSampler(
                _dataset, num_replicas=world_size, rank=rank, shuffle=False)
            dataloaders.append(
                torch.utils.data.DataLoader(
                    dataset=_dataset,
                    batch_size=1,
                    sampler=sampler,
                    num_workers=cfg.get("num_workers", 0),
                    pin_memory=cfg.get("pin_memory", False),
                    drop_last=cfg.get("drop_last", False),
                    shuffle=False,
                ))

        return dataloaders
Example #25
    def _build_eval_dataset(self, data_cfg):
        """Build the evaluation dataset."""
        if data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size():
            raise ValueError(
                f'You are trying to use "implicit gradient accumulation" of {data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} in your validation/test datasets. This is not supported. Please set global_batch_size equal to micro_batch_size * data_parallel_world_size.'
            )
        datasets = []
        # Determine if we are using a single dataset or a list of datasets.
        is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
        is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)
        is_names_list_config = False
        if hasattr(data_cfg, "names"):
            if isinstance(data_cfg.names, ListConfig):
                is_names_list_config = True

        if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
            raise ValueError("src_list and tgt_list must both be either a ListConfig or a string. ")
        if is_src_list_config:
            if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
                raise ValueError("src_file_name and tgt_file_name must have the same number of elements. ")
            if is_names_list_config and len(data_cfg.names) != len(data_cfg.src_file_name):
                raise ValueError(
                    "If you are providing names for each src/tgt file, they must have the same number of elements."
                )
        else:
            data_cfg.src_file_name = [data_cfg.src_file_name]
            data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

        for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
            dataset = SequenceToSequenceDataset(
                src_file_name=src,
                tgt_file_name=tgt,
                src_tokenizer=self.tokenizer,
                tgt_tokenizer=self.tokenizer,
                max_src_seq_length=data_cfg.max_src_seq_length,
                max_tgt_seq_length=data_cfg.max_tgt_seq_length,
            )
            datasets.append(dataset)

        return datasets
Example #26
def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
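
The rank bookkeeping in the test above implies the usual Megatron layout: tp_rank = global_rank % tp_size and dp_rank = global_rank // tp_size. A small single-process illustration with a hypothetical world size of 8 and tensor_model_parallel_size of 2:

world_size, tp_size = 8, 2
for rank in range(world_size):
    print(f"global rank {rank}: tp_rank={rank % tp_size}, dp_rank={rank // tp_size}")
# Ranks that share a tp_rank (e.g. 0, 2, 4, 6) belong to the same data-parallel group,
# while consecutive ranks (e.g. 0 and 1) form a tensor-model-parallel group.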
Example #27
    def build_pretraining_data_loader(self, dataset, batch_size, shuffle,
                                      num_workers, pin_memory):
        """Buld dataloader given an input dataset."""

        if dataset is None:
            return None

        rank = parallel_state.get_data_parallel_rank()
        world_size = parallel_state.get_data_parallel_world_size()
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

        # Data loader. Note that batch size is the per GPU batch size.
        return torch.utils.data.DataLoader(
            dataset,
            collate_fn=dataset.collate_fn,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False,
        )
Example #28
 def _build_dataset(self, data_cfg, check_implict_grad_acc=False):
     if (check_implict_grad_acc
             and data_cfg.global_batch_size > data_cfg.micro_batch_size *
             parallel_state.get_data_parallel_world_size()):
         raise ValueError(
             f'You are trying to use "implicit gradient accumulation" of {data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} in your validation/test datasets. This is not supported. Please set global_batch_size equal to micro_batch_size * data_parallel_world_size.'
         )
     if data_cfg.task_name == 'xnli':
         dataset = TextToTextXNLIDataset(
             data_cfg.file_path,
             task_name=data_cfg.task_name,
             tokenizer=self.tokenizer,
             max_seq_length=data_cfg.max_seq_length,
             lang_list=self.cfg.eval_languages,
         )
     else:
         dataset = TextToTextGLUEDataset(
             data_cfg.file_path,
             task_name=data_cfg.task_name,
             tokenizer=self.tokenizer,
             max_seq_length=data_cfg.max_seq_length,
         )
     return dataset
Example #29
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        predicted_tokens_dec = (
            torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
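
The two _reconfigure_microbatch_calculator calls above simply switch between a single microbatch per step while decoding and the original split afterwards. An arithmetic illustration with hypothetical sizes:

data_parallel_size = 4
global_batch_per_gpu = 16            # tokens_enc.size(0) on this rank (hypothetical)
num_micro_batches_before_decode = 4  # get_num_microbatches() before decode (hypothetical)
global_batch_size = global_batch_per_gpu * data_parallel_size

# During decode: micro_batch_size == global_batch_per_gpu, so each rank runs one microbatch.
assert global_batch_size // (global_batch_per_gpu * data_parallel_size) == 1
# After decode: micro_batch_size is restored to 16 // 4 == 4, giving 4 microbatches again.
restored_micro_batch_size = global_batch_per_gpu // num_micro_batches_before_decode
assert global_batch_size // (restored_micro_batch_size * data_parallel_size) == num_micro_batches_before_decode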
Example #30
    def _forward_backward_test_impl(
        self,
        forward_only: bool,
        fwd_bwd_func: FwdStepFunc,
        pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
        async_comm: bool = False,
        *,
        default_backend: Optional[str] = None,
        p2p_backend: Optional[str] = None,
    ) -> None:
        if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
            self.assertIsNotNone(virtual_pipeline_model_parallel_size)
            self.assertGreater(virtual_pipeline_model_parallel_size, 1)
        dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

        for dtype, deallocate_pipeline_outputs in itertools.product(
                dtype_options,
                self.deallocate_options,
        ):
            grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None

            (tensor_model_parallel_world_size, data_parallel_size,
             pipeline_model_parallel_world_size
             ) = _get_default_world_sizes_model_parallel_world_size(
                 pipeline_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
                default_backend=default_backend,
                p2p_backend=p2p_backend,
            )
            pp_utils._reconfigure_microbatch_calculator(
                rank=parallel_state.get_tensor_model_parallel_rank(),
                rampup_batch_size=None,
                global_batch_size=self.GLOBAL_BATCH_SIZE,
                micro_batch_size=self.MICRO_BATCH_SIZE,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

            global_batch_shape = (
                self.GLOBAL_BATCH_SIZE //
                parallel_state.get_data_parallel_world_size(),
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            )

            batch = None
            if parallel_state.is_pipeline_first_stage():
                batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )

            model = build_model(
                testing_utils.model_provider_func,
                # Wrap with DDP only when there is more than one data-parallel rank.
                wrap_with_ddp=data_parallel_size > 1,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
                hidden_size=self.HIDDEN_SIZE,
            )

            offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
            for idx, model_module in enumerate(model):
                model_module = model_module.to(dtype)
                model_module.apply(get_init_weights_func(idx * offset))

            _param_groups = _get_params_for_weight_decay_optimization(model)
            optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

            pp_utils.update_num_microbatches(0)

            loss = fwd_bwd_func(
                testing_utils.fwd_step_func,
                batch,
                model,
                forward_only=forward_only,
                # `tensor_shape` is the shape of micro batch.
                tensor_shape=(
                    self.MICRO_BATCH_SIZE,
                    self.HIDDEN_SIZE,
                    self.HIDDEN_SIZE,
                ),
                dtype=dtype,
                async_comm=async_comm,
                grad_scaler=grad_scaler,
                deallocate_pipeline_output=deallocate_pipeline_outputs,
            )

            if dtype == torch.double:
                hidden_size = self.HIDDEN_SIZE
                microbatch_size = self.MICRO_BATCH_SIZE
                total_layers = pipeline_model_parallel_world_size
                if virtual_pipeline_model_parallel_size is not None:
                    total_layers *= virtual_pipeline_model_parallel_size
                target_loss, target_model = get_target_loss_and_model(
                    global_batch_shape, hidden_size, total_layers)

                for loss_item in loss:
                    x = loss_item['avg']
                    torch.testing.assert_close(x.item() / microbatch_size,
                                               target_loss.item())

                if not forward_only:
                    for vm_id, model_module in enumerate(model):
                        params = list(model_module.parameters())
                        rank = params[0].get_device()
                        offset = pipeline_model_parallel_world_size
                        param_id = rank // data_parallel_size + vm_id * offset
                        target_params = target_model[param_id]

                        torch.testing.assert_close(params[0].cpu(),
                                                   target_params[0])
                        torch.testing.assert_close(params[1].cpu(),
                                                   target_params[1])
                        torch.testing.assert_close(
                            params[0].grad.cpu() / microbatch_size,
                            target_params[0].grad)
                        torch.testing.assert_close(
                            params[1].grad.cpu() / microbatch_size,
                            target_params[1].grad)

            if not forward_only:
                for m in model:
                    for p in m.parameters():
                        self.assertIsNotNone(p.grad)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            parallel_state.destroy_model_parallel()