def on_validation_epoch_end(self):
    app_state = AppState()
    if hasattr(self, "_train_ds"):
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=self.cfg.data.train_ds.global_batch_size,
            micro_batch_size=self.cfg.data.train_ds.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    # When running `trainer.validate()`, the training dataset is not available.
    else:
        logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.')
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=self.cfg.data.validation_ds.global_batch_size,
            micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    return super().on_validation_epoch_end()

def build_pretraining_data_loader(self, dataset, consumed_samples):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None
    logging.info(f'Building dataloader with consumed samples: {consumed_samples}')

    # Megatron sampler
    if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
        if self.cfg.data.dataloader_type == 'single':
            batch_sampler = MegatronPretrainingSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=self.cfg.micro_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        elif self.cfg.data.dataloader_type == 'cyclic':
            batch_sampler = MegatronPretrainingRandomSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=self.cfg.micro_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        else:
            raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
    else:
        raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"')

    # Torch dataloader.
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=self.cfg.data.num_workers,
        pin_memory=True,
    )

def build_pretraining_data_loader(self, dataset, consumed_samples):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    # Megatron sampler
    if self._cfg.data.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=self._cfg.micro_batch_size,
            data_parallel_rank=parallel_state.get_data_parallel_rank(),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    elif self._cfg.data.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=self._cfg.micro_batch_size,
            data_parallel_rank=parallel_state.get_data_parallel_rank(),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    else:
        raise ValueError('{} dataloader type is not supported.'.format(self._cfg.data.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=self._cfg.data.num_workers,
        pin_memory=True,
    )

def param_hook(*unused):
    """Gradient accumulation and all-reduce."""
    if param.grad.data is None:
        return
    if main_param.grad is None:
        main_param.grad = param.grad.float()
    else:
        main_param.grad.add_(param.grad.data)
    # Deallocate grad memory.
    param.grad = None

    # Asynchronous gradients allreduce across data_parallel ranks
    if self._require_backward_grad_sync:
        if self._grad_allreduce_chunk_size_mb > 0:
            self._main_grad_buffers[i].update_chunk_info(grad_chunk_info)
            while True:
                allreduce_tensor = self._main_grad_buffers[i].get_allreduce_tensor()
                if allreduce_tensor is None:
                    break
                allreduce_tensor.div_(get_data_parallel_world_size())
                torch.distributed.all_reduce(
                    allreduce_tensor,
                    group=get_data_parallel_group(),
                    async_op=True,
                )
        else:
            main_param.grad.div_(get_data_parallel_world_size())
            torch.distributed.all_reduce(  # type: ignore
                main_param.grad,
                group=get_data_parallel_group(),
                async_op=True,
            )

def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
    batch_has_lang_information = len(batch[0]) == 7
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # At this point the batch is a list of dictionaries where each dict is a microbatch.
    # After the _process_global_batch call, processed_batch will be a single dictionary containing the global batch.
    # This is required since the parent class expects a single global batch dictionary.
    processed_batch = self._process_global_batch(batch)

    # Call parent validation step to get the loss.
    # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI; these are ignored in the parent class.
    loss = super().validation_step(processed_batch, batch_idx)

    predicted_token_ids, _ = self.decode(
        tokens_enc=processed_batch['text_enc'],
        enc_mask=processed_batch['enc_mask'],
        num_tokens_to_generate=30,
    )

    # Special ids-to-text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
    preds_text = self.ids_to_text(predicted_token_ids)
    labels_text = self.ids_to_text(processed_batch['labels'])
    input_text = self.ids_to_text(processed_batch['text_enc'])

    if not batch_has_lang_information:
        categories = [None] * len(preds_text)
    else:
        categories = processed_batch['lang']

    metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
    assert len(categories) == len(preds_text) == len(labels_text)
    for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
        # To compute metrics like Pearson or Spearman correlation, we need to cast the predicted string and labels to floats.
        pred, label = self.cast_for_metric(
            pred, label, self.val_metric_name if mode == 'validation' else self.test_metric_name
        )
        if batch_has_lang_information:
            _ = metric(pred, label, category)
        else:
            _ = metric(pred, label)

    return {
        'loss': loss,
        'preds': preds_text,
        'labels': labels_text,
        'categories': categories,
        'inputs': input_text,
    }

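# Illustrative sketch (not part of the model code above): the invariant the microbatch
# calculator is being reconfigured to preserve is
#     global_batch_size == micro_batch_size * data_parallel_size * num_microbatches
# so on a smaller final batch, the global batch size is recomputed from the observed
# micro batch size while the number of microbatches stays fixed. The helper name below
# is hypothetical and exists only to demonstrate the arithmetic.
def _expected_global_batch_size(micro_batch_size: int, data_parallel_size: int, num_microbatches: int) -> int:
    return micro_batch_size * data_parallel_size * num_microbatches


assert _expected_global_batch_size(micro_batch_size=4, data_parallel_size=8, num_microbatches=2) == 64
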
def eval_step(self, batch, batch_idx, dataloader_idx, data_cfg):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
    batch = self.process_global_batch_for_tarred_datasets(batch)

    if data_cfg.dataset_type in ['tarred', 'text']:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=batch['text_enc'].size(0),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # This returns the averaged loss across data-parallel groups.
    reduced_loss = super().validation_step(batch, batch_idx)

    tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask']

    predicted_tokens_ids, _ = self.decode(
        tokens_enc,
        enc_mask,
        tokens_enc.size(1)
        + self._cfg.max_generation_delta,  # Generate up to src-length + max generation delta. TODO: Implement better stopping when everything hits <EOS>.
        tokenizer=self.decoder_tokenizer,
    )

    if self.multilingual:
        source_processor = self.source_processor_list[dataloader_idx]
        target_processor = self.target_processor_list[dataloader_idx]
    else:
        source_processor = self.source_processor
        target_processor = self.target_processor

    # Post-process the translations and inputs to log.
    preds = self.postprocess_outputs(
        outputs=predicted_tokens_ids,
        tokenizer=self.decoder_tokenizer,
        processor=target_processor,
    )
    labels = self.postprocess_outputs(
        outputs=labels,
        tokenizer=self.decoder_tokenizer,
        processor=target_processor,
    )
    encoder_inputs = self.postprocess_outputs(
        outputs=tokens_enc,
        tokenizer=self.encoder_tokenizer,
        processor=source_processor,
    )

    return {
        'inputs': encoder_inputs,
        'translations': preds,
        'ground_truths': labels,
        'loss': reduced_loss,
    }

def _build_train_dataset(self, data_cfg):
    """Build the training dataset."""
    if (
        data_cfg.drop_last is False
        and data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size()
    ):
        raise ValueError(
            f"Cannot use drop_last=False in your training data with gradient accumulation. "
            f"Found grad acc of {data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} "
            f"with global_batch_size {data_cfg.global_batch_size}, micro_batch_size {data_cfg.micro_batch_size}, "
            f"data parallel size {parallel_state.get_data_parallel_world_size()}"
        )
    datasets = []
    # Determine if we are using a single dataset or a list of datasets.
    is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
    is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)

    if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
        raise ValueError("src_list and tgt_list must both be either a ListConfig or a string. ")
    if is_src_list_config:
        if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
            raise ValueError("src_file_name and tgt_file_name must have the same number of elements. ")
    else:
        data_cfg.src_file_name = [data_cfg.src_file_name]
        data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

    for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
        dataset = SequenceToSequenceDataset(
            src_file_name=src,
            tgt_file_name=tgt,
            src_tokenizer=self.tokenizer,
            tgt_tokenizer=self.tokenizer,
            max_src_seq_length=data_cfg.max_src_seq_length,
            max_tgt_seq_length=data_cfg.max_tgt_seq_length,
        )
        datasets.append(dataset)

    if len(datasets) > 1:
        dataset = ConcatDataset(
            datasets=datasets,
            sampling_technique=data_cfg.get('concat_sampling_technique', 'temperature'),
            sampling_temperature=data_cfg.get('concat_sampling_temperature', 5),
            sampling_probabilities=data_cfg.get('concat_sampling_probabilities', [1 / len(datasets)] * len(datasets)),
            global_rank=parallel_state.get_data_parallel_rank(),
            world_size=parallel_state.get_data_parallel_world_size(),
        )
        return dataset
    else:
        return datasets[0]

def _setup_megatron_dataloader_from_config(self, cfg, dataset, consumed_samples):
    logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    if isinstance(dataset, BlendableDataset):
        collate_fn = dataset.datasets[0].collate_fn
    else:
        collate_fn = dataset.collate_fn

    if cfg.get("sampler", "distributed") == 'distributed':
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=consumed_samples,  # Ensures that each time the model is restored, a new seed is used to see examples in a different order.
        )
        return torch.utils.data.DataLoader(
            dataset,
            collate_fn=collate_fn,
            sampler=sampler,
            batch_size=cfg.micro_batch_size,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
        )
    elif cfg.get("sampler", "distributed") == 'megatron':
        batch_sampler = MegatronPretrainingBatchSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=cfg.micro_batch_size,
            global_batch_size=cfg.global_batch_size,
            data_parallel_rank=parallel_state.get_data_parallel_rank(),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
            drop_last=True,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
    else:
        raise ValueError(f"Invalid sampler {cfg.sampler}. Options: ['distributed', 'megatron']")

def training_step(self, batch, batch_idx):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
    batch = self.process_global_batch_for_tarred_datasets(batch)

    app_state = AppState()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=batch['text_enc'].size(0),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return super().training_step(batch, batch_idx)

def training_step(self, batch, batch_idx):
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    return super().training_step(batch, batch_idx)

def build_virtual_prompt_dataset(self, dataset_paths, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory):
    dataset = GPTPromptLearningDataset(
        datasets=dataset_paths,
        tokenizer=self.tokenizer,
        virtual_prompt_source=self.virtual_prompt_source,
        task_templates=self.task_templates,
        pseudo_tokens=self.pseudo_tokens,
        pad_token_id=self.pad_token_id,
        max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings),
        min_seq_length=self.cfg.data.get('min_seq_length', 1),
        add_bos=self.cfg.data.get('add_bos', False),
        add_eos=self.cfg.data.get('add_eos', True),
        for_train=for_train,
    )

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        sampler=sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return dataset, dataloader

def build_train_valid_test_datasets(self):
    self._train_ds = MTEncDecModel._setup_dataset_from_config(
        cfg=self._cfg.train_ds,
        encoder_tokenizer=self.encoder_tokenizer,
        decoder_tokenizer=self.decoder_tokenizer,
        global_rank=parallel_state.get_data_parallel_rank(),
        world_size=parallel_state.get_data_parallel_world_size(),
        multilingual=self.multilingual,
        multilingual_ids=self.multilingual_ids,
    )
    self._validation_ds = MTEncDecModel._setup_eval_dataset_from_config(
        cfg=self._cfg.validation_ds,
        multilingual=self.multilingual,
        multilingual_ids=self.multilingual_ids,
        encoder_tokenizer=self.encoder_tokenizer,
        decoder_tokenizer=self.decoder_tokenizer,
    )
    # Test data config is optional.
    if hasattr(self._cfg, 'test_ds'):
        self._test_ds = MTEncDecModel._setup_eval_dataset_from_config(
            cfg=self._cfg.test_ds,
            multilingual=self.multilingual,
            multilingual_ids=self.multilingual_ids,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
        )

def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
        is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
    """
    app_state = AppState()

    # We initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')

def allreduce_gradients(self):
    """Reduce gradients across data parallel ranks.

    Modified from megatron-lm:
    https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/model/distributed.py#L188
    """
    # Bucketize and all-reduce.
    buckets = {}
    for param in self.parameters():
        if param.requires_grad and param.grad is not None:
            tp = param.data.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(param)
            # param.main_grad = param.grad

    # For each bucket, all-reduce and copy all-reduced grads.
    for tp in buckets:
        bucket = buckets[tp]
        grads = [param.grad.data for param in bucket]
        coalesced = torch._utils._flatten_dense_tensors(grads)
        coalesced /= parallel_state.get_data_parallel_world_size()
        torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group())
        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)

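# Single-process sketch (illustrative assumption, not part of this codebase) of why the
# functions above divide gradients by the data-parallel world size before an all-reduce
# with the default SUM op: pre-dividing turns the summed result into the data-parallel mean.
import torch

_world_size = 4
_per_rank_grads = [torch.full((3,), float(r + 1)) for r in range(_world_size)]

# Each "rank" pre-divides its gradient; summing (what all_reduce with SUM does) yields the mean.
_summed = torch.stack([g / _world_size for g in _per_rank_grads]).sum(dim=0)
_mean = torch.stack(_per_rank_grads).mean(dim=0)
assert torch.allclose(_summed, _mean)
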
def training_step(self, batch, batch_idx):
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # At this point batch is a list of dictionaries where each dict is a microbatch.
    # After the _process_global_batch call, batch will be a single dictionary containing the global batch.
    # This is required since the parent class expects a single global batch dictionary.
    batch = self._process_global_batch(batch)
    return super().training_step(batch, batch_idx)

def build_tarred_train_dataset(self):
    return MTEncDecModel._setup_dataset_from_config(
        cfg=self._cfg.train_ds,
        encoder_tokenizer=self.encoder_tokenizer,
        decoder_tokenizer=self.decoder_tokenizer,
        global_rank=parallel_state.get_data_parallel_rank(),
        world_size=parallel_state.get_data_parallel_world_size(),
        multilingual=self.multilingual,
        multilingual_ids=self.multilingual_ids,
    )

def on_validation_epoch_start(self):
    app_state = AppState()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=self.cfg.data.validation_ds.global_batch_size,
        micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return super().on_validation_epoch_start()

def on_validation_epoch_end(self):
    app_state = AppState()
    if hasattr(self, "_train_ds"):
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=self._cfg.train_ds.global_batch_size,
            micro_batch_size=self._cfg.train_ds.micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
    for data_parallel_size in range(1, self.world_size + 1):
        expected_global_batch_size = self.GLOBAL_BATCH_SIZE
        expected_micro_batch_size = self.MICRO_BATCH_SIZE
        if rampup_batch_size:
            expected_global_batch_size = rampup_batch_size[0]
            num_consumed_samples = 0
            step_of_global_batch_size = rampup_batch_size[1]
            threshold = rampup_batch_size[2]

        if data_parallel_size > 1 and data_parallel_size % 2 != 0:
            continue
        if self.world_size % data_parallel_size != 0:
            continue

        with self.subTest(data_parallel_size=data_parallel_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size())

            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )

            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
            self.assertEqual(
                get_num_microbatches(),
                expected_global_batch_size / expected_micro_batch_size / data_parallel_size,
            )
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size)

            # Make sure `global_batch_size` equals the final global batch size after
            # a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for i in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
            parallel_state.destroy_model_parallel()

def training_step(self, batch, batch_idx):
    # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
    if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
        batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch] for microbatch in batch]
        batch = self.process_global_batch_for_tarred_datasets(batch)
    elif (
        self._cfg.train_ds.dataset_type in ['bin_memmap', 'text_memmap']
        and self._cfg.train_ds.get("sampler", "distributed") == 'distributed'
    ):
        batch = self._process_global_batch_without_megatron_batch_sampler(batch, tokenizer=self.encoder_tokenizer)

    if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=batch['text_enc'].size(0),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    return super().training_step(batch, batch_idx)

def param_hook(*unused):
    # Accumulates gradients on main gradients.
    if param.grad.data is not None:
        if main_param.grad is None:
            main_param.grad = param.grad.float()
        else:
            main_param.grad.add_(param.grad.data)
        # Deallocate grad memory.
        param.grad = None

    # Asynchronous gradients allreduce across data_parallel ranks
    if self._require_backward_grad_sync:
        main_param.grad.div_(get_data_parallel_world_size())
        torch.distributed.all_reduce(
            main_param.grad,
            group=get_data_parallel_group(),
            async_op=True,
        )

def build_data_loader(
    self,
    dataset,
    micro_batch_size,
    global_batch_size,
    shuffle,
    num_workers,
    pin_memory,
    drop_last,
    check_validation_interval,
):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    # This check makes sure the val_check_interval is less than the number of global batches.
    # Normally, PTL would do this check and properly account for gradient accumulation.
    # But now, it is implicit in the apex fwd/bwd functions and so we need to check for this somewhere.
    # The consequence of not doing this is that the training loop will never run validation.
    # NOTE: The progress bar is also broken as a result of this.
    global_batch_size_per_gpu = micro_batch_size * get_num_microbatches()
    if (
        self.trainer.val_check_interval > (sampler.num_samples // global_batch_size_per_gpu)
        and check_validation_interval
    ):
        raise ValueError(
            f"trainer.val_check_interval {self.trainer.val_check_interval} is > number of global batches "
            f"{sampler.num_samples // global_batch_size_per_gpu}"
        )

    # Data loader. Note that batch size is the per GPU batch size.
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        sampler=sampler,
        batch_size=micro_batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )

def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_devices
        is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
    """
    app_state = AppState()

    # We initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # Destroy groups in case they have already been created;
        # this happens with multiple calls to trainer.test, for example.
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
            )

            # Assert that fake tp and pp rank match after model parallel init.
            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()

def _setup_eval_dataloader_from_config(self, cfg: DictConfig, dataset):
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    dataloaders = []
    for _dataset in dataset:
        sampler = torch.utils.data.distributed.DistributedSampler(
            _dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        dataloaders.append(
            torch.utils.data.DataLoader(
                dataset=_dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 0),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
                shuffle=False,
            )
        )
    return dataloaders

def _build_eval_dataset(self, data_cfg):
    """Build the evaluation dataset."""
    if data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size():
        raise ValueError(
            f'You are trying to use "implicit gradient accumulation" of '
            f'{data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} '
            f'in your validation/test datasets. This is not supported. Please set global_batch_size equal to '
            f'micro_batch_size * data_parallel_world_size.'
        )
    datasets = []
    # Determine if we are using a single dataset or a list of datasets.
    is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
    is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)
    is_names_list_config = False
    if hasattr(data_cfg, "names"):
        if isinstance(data_cfg.names, ListConfig):
            is_names_list_config = True

    if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
        raise ValueError("src_list and tgt_list must both be either a ListConfig or a string. ")
    if is_src_list_config:
        if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
            raise ValueError("src_file_name and tgt_file_name must have the same number of elements. ")
        if is_names_list_config and len(data_cfg.names) != len(data_cfg.src_file_name):
            raise ValueError(
                "If you are providing names for each src/tgt file, they must have the same number of elements."
            )
    else:
        data_cfg.src_file_name = [data_cfg.src_file_name]
        data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

    for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
        dataset = SequenceToSequenceDataset(
            src_file_name=src,
            tgt_file_name=tgt,
            src_tokenizer=self.tokenizer,
            tgt_tokenizer=self.tokenizer,
            max_src_seq_length=data_cfg.max_src_seq_length,
            max_tgt_seq_length=data_cfg.max_tgt_seq_length,
        )
        datasets.append(dataset)
    return datasets

def test_initialize_model_parallel(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups.
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)

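# Illustrative sketch (assumed rank layout, matching the checks in the test above): with
# contiguous tensor-parallel groups, the data-parallel coordinates follow from integer
# division and modulo of the global rank. The helper name is hypothetical.
def _data_parallel_coords(global_rank: int, world_size: int, tp_size: int):
    dp_world_size = world_size // tp_size
    dp_rank = global_rank // tp_size
    tp_rank = global_rank % tp_size
    return dp_world_size, dp_rank, tp_rank


assert _data_parallel_coords(global_rank=5, world_size=8, tp_size=2) == (4, 2, 1)
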
def build_pretraining_data_loader(self, dataset, batch_size, shuffle, num_workers, pin_memory):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

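# Hedged usage sketch (illustrative only, not taken from this repository): DistributedSampler
# can be exercised in a single process by passing num_replicas/rank explicitly. With
# shuffle=True the per-epoch ordering only changes when set_epoch() is called each epoch;
# trainer frameworks such as PyTorch Lightning normally call it automatically.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

_toy_dataset = TensorDataset(torch.arange(16))
_toy_sampler = DistributedSampler(_toy_dataset, num_replicas=4, rank=0, shuffle=True)
_toy_loader = DataLoader(_toy_dataset, batch_size=2, sampler=_toy_sampler)

for _epoch in range(2):
    _toy_sampler.set_epoch(_epoch)  # reseeds the shuffle for this epoch
    for (_batch,) in _toy_loader:
        pass  # each of the 4 replicas sees a disjoint quarter of the dataset
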
def _build_dataset(self, data_cfg, check_implict_grad_acc=False):
    if (
        check_implict_grad_acc
        and data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size()
    ):
        raise ValueError(
            f'You are trying to use "implicit gradient accumulation" of '
            f'{data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} '
            f'in your validation/test datasets. This is not supported. Please set global_batch_size equal to '
            f'micro_batch_size * data_parallel_world_size.'
        )
    if data_cfg.task_name == 'xnli':
        dataset = TextToTextXNLIDataset(
            data_cfg.file_path,
            task_name=data_cfg.task_name,
            tokenizer=self.tokenizer,
            max_seq_length=data_cfg.max_seq_length,
            lang_list=self.cfg.eval_languages,
        )
    else:
        dataset = TextToTextGLUEDataset(
            data_cfg.file_path,
            task_name=data_cfg.task_name,
            tokenizer=self.tokenizer,
            max_seq_length=data_cfg.max_seq_length,
        )
    return dataset

def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since it's not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # Get output tensor.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs

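# Minimal sketch (toy logits, no model; illustrative assumption) of the greedy decoding step
# used above: log-softmax over the vocabulary, take the argmax at every position, and append
# the prediction for the last position to the running decoder sequence.
import torch

_logits = torch.randn(2, 5, 11)  # [batch, decoder_seq_len, vocab]
_log_probs, _token_ids = torch.max(torch.nn.functional.log_softmax(_logits, dim=-1), dim=-1)

_predicted = torch.full((2, 5), 3, dtype=torch.long)  # running decoder ids (toy values)
_predicted = torch.cat([_predicted, _token_ids[:, -1].unsqueeze(1)], dim=1)
assert _predicted.shape == (2, 6)
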
def _forward_backward_test_impl(
    self,
    forward_only: bool,
    fwd_bwd_func: FwdStepFunc,
    pipeline_model_parallel_world_size: Optional[int],
    virtual_pipeline_model_parallel_size: Optional[int],
    async_comm: bool = False,
    *,
    default_backend: Optional[str] = None,
    p2p_backend: Optional[str] = None,
) -> None:
    if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
        self.assertIsNotNone(virtual_pipeline_model_parallel_size)
        self.assertGreater(virtual_pipeline_model_parallel_size, 1)
    dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

    for dtype, deallocate_pipeline_outputs in itertools.product(
        dtype_options, self.deallocate_options,
    ):
        grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None

        (
            tensor_model_parallel_world_size,
            data_parallel_size,
            pipeline_model_parallel_world_size,
        ) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
            default_backend=default_backend,
            p2p_backend=p2p_backend,
        )
        pp_utils._reconfigure_microbatch_calculator(
            rank=parallel_state.get_tensor_model_parallel_rank(),
            rampup_batch_size=None,
            global_batch_size=self.GLOBAL_BATCH_SIZE,
            micro_batch_size=self.MICRO_BATCH_SIZE,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

        global_batch_shape = (
            self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(),
            self.HIDDEN_SIZE,
            self.HIDDEN_SIZE,
        )

        batch = None
        if parallel_state.is_pipeline_first_stage():
            batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),)

        model = build_model(
            testing_utils.model_provider_func,
            # Use DDP only when it's better to have
            wrap_with_ddp=data_parallel_size > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            hidden_size=self.HIDDEN_SIZE,
        )

        offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
        for idx, model_module in enumerate(model):
            model_module = model_module.to(dtype)
            model_module.apply(get_init_weights_func(idx * offset))

        _param_groups = _get_params_for_weight_decay_optimization(model)
        optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

        pp_utils.update_num_microbatches(0)

        loss = fwd_bwd_func(
            testing_utils.fwd_step_func,
            batch,
            model,
            forward_only=forward_only,
            # `tensor_shape` is the shape of the micro batch.
            tensor_shape=(
                self.MICRO_BATCH_SIZE,
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            ),
            dtype=dtype,
            async_comm=async_comm,
            grad_scaler=grad_scaler,
            deallocate_pipeline_output=deallocate_pipeline_outputs,
        )

        if dtype == torch.double:
            hidden_size = self.HIDDEN_SIZE
            microbatch_size = self.MICRO_BATCH_SIZE
            total_layers = pipeline_model_parallel_world_size
            if virtual_pipeline_model_parallel_size is not None:
                total_layers *= virtual_pipeline_model_parallel_size
            target_loss, target_model = get_target_loss_and_model(global_batch_shape, hidden_size, total_layers)

            for loss_item in loss:
                x = loss_item['avg']
                torch.testing.assert_close(x.item() / microbatch_size, target_loss.item())

            if not forward_only:
                for vm_id, model_module in enumerate(model):
                    params = list(model_module.parameters())
                    rank = params[0].get_device()
                    offset = pipeline_model_parallel_world_size
                    param_id = rank // data_parallel_size + vm_id * offset
                    target_params = target_model[param_id]

                    torch.testing.assert_close(params[0].cpu(), target_params[0])
                    torch.testing.assert_close(params[1].cpu(), target_params[1])
                    torch.testing.assert_close(params[0].grad.cpu() / microbatch_size, target_params[0].grad)
                    torch.testing.assert_close(params[1].grad.cpu() / microbatch_size, target_params[1].grad)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    self.assertIsNotNone(p.grad)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        parallel_state.destroy_model_parallel()
