def build_pretraining_data_loader(self, dataset, consumed_samples):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None
    logging.info(f'Building dataloader with consumed samples: {consumed_samples}')

    # Megatron sampler
    if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
        if self.cfg.data.dataloader_type == 'single':
            batch_sampler = MegatronPretrainingSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=self.cfg.micro_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        elif self.cfg.data.dataloader_type == 'cyclic':
            batch_sampler = MegatronPretrainingRandomSampler(
                total_samples=len(dataset),
                consumed_samples=consumed_samples,
                micro_batch_size=self.cfg.micro_batch_size,
                data_parallel_rank=parallel_state.get_data_parallel_rank(),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        else:
            raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
    else:
        raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"')

    # Torch dataloader.
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=self.cfg.data.num_workers,
        pin_memory=True,
    )
def build_pretraining_data_loader(self, dataset, consumed_samples):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    # Megatron sampler
    if self._cfg.data.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=self._cfg.micro_batch_size,
            data_parallel_rank=parallel_state.get_data_parallel_rank(),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    elif self._cfg.data.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=self._cfg.micro_batch_size,
            data_parallel_rank=parallel_state.get_data_parallel_rank(),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    else:
        raise ValueError(
            '{} dataloader type is not supported.'.format(self._cfg.data.dataloader_type)
        )

    # Torch dataloader.
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=self._cfg.data.num_workers,
        pin_memory=True,
    )
def _setup_megatron_dataloader_from_config(self, cfg, dataset, consumed_samples):
    logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    if isinstance(dataset, BlendableDataset):
        collate_fn = dataset.datasets[0].collate_fn
    else:
        collate_fn = dataset.collate_fn

    if cfg.get("sampler", "distributed") == 'distributed':
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            # Use consumed_samples as the seed so that each time the model is restored,
            # examples are seen in a different order.
            seed=consumed_samples,
        )
        return torch.utils.data.DataLoader(
            dataset,
            collate_fn=collate_fn,
            sampler=sampler,
            batch_size=cfg.micro_batch_size,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
        )
    elif cfg.get("sampler", "distributed") == 'megatron':
        batch_sampler = MegatronPretrainingBatchSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=cfg.micro_batch_size,
            global_batch_size=cfg.global_batch_size,
            data_parallel_rank=parallel_state.get_data_parallel_rank(),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
            drop_last=True,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
    else:
        raise ValueError(f"Invalid sampler {cfg.sampler}. Options: ['distributed', 'megatron']")
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')
def build_virtual_prompt_dataset(
    self, dataset_paths, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory
):
    dataset = GPTPromptLearningDataset(
        datasets=dataset_paths,
        tokenizer=self.tokenizer,
        virtual_prompt_source=self.virtual_prompt_source,
        task_templates=self.task_templates,
        pseudo_tokens=self.pseudo_tokens,
        pad_token_id=self.pad_token_id,
        max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings),
        min_seq_length=self.cfg.data.get('min_seq_length', 1),
        add_bos=self.cfg.data.get('add_bos', False),
        add_eos=self.cfg.data.get('add_eos', True),
        for_train=for_train,
    )

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        sampler=sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return dataset, dataloader
def build_train_valid_test_datasets(self):
    self._train_ds = MTEncDecModel._setup_dataset_from_config(
        cfg=self._cfg.train_ds,
        encoder_tokenizer=self.encoder_tokenizer,
        decoder_tokenizer=self.decoder_tokenizer,
        global_rank=parallel_state.get_data_parallel_rank(),
        world_size=parallel_state.get_data_parallel_world_size(),
        multilingual=self.multilingual,
        multilingual_ids=self.multilingual_ids,
    )
    self._validation_ds = MTEncDecModel._setup_eval_dataset_from_config(
        cfg=self._cfg.validation_ds,
        multilingual=self.multilingual,
        multilingual_ids=self.multilingual_ids,
        encoder_tokenizer=self.encoder_tokenizer,
        decoder_tokenizer=self.decoder_tokenizer,
    )
    # Test data config is optional.
    if hasattr(self._cfg, 'test_ds'):
        self._test_ds = MTEncDecModel._setup_eval_dataset_from_config(
            cfg=self._cfg.test_ds,
            multilingual=self.multilingual,
            multilingual_ids=self.multilingual_ids,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
        )
def __init__(self, cfg, trainer, tokenizer, name, size):
    super().__init__()
    self.name = name
    self.tokenizer = tokenizer
    self._cfg = cfg
    self.size = size
    seed_val = parallel_state.get_data_parallel_rank() * 131 + 97
    torch.manual_seed(seed_val)
def build_tarred_train_dataset(self):
    return MTEncDecModel._setup_dataset_from_config(
        cfg=self._cfg.train_ds,
        encoder_tokenizer=self.encoder_tokenizer,
        decoder_tokenizer=self.decoder_tokenizer,
        global_rank=parallel_state.get_data_parallel_rank(),
        world_size=parallel_state.get_data_parallel_world_size(),
        multilingual=self.multilingual,
        multilingual_ids=self.multilingual_ids,
    )
def _build_train_dataset(self, data_cfg):
    """Build the training dataset."""
    if (
        data_cfg.drop_last is False
        and data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size()
    ):
        raise ValueError(
            f"Cannot use drop_last=False in your training data with gradient accumulation. "
            f"Found gradient accumulation of "
            f"{data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} "
            f"with global_batch_size {data_cfg.global_batch_size}, micro_batch_size {data_cfg.micro_batch_size}, "
            f"data parallel size {parallel_state.get_data_parallel_world_size()}."
        )
    datasets = []
    # Determine if we are using a single dataset or a list of datasets.
    is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
    is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)

    if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
        raise ValueError("src_file_name and tgt_file_name must both be either a ListConfig or a string.")
    if is_src_list_config:
        if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
            raise ValueError("src_file_name and tgt_file_name must have the same number of elements.")
    else:
        data_cfg.src_file_name = [data_cfg.src_file_name]
        data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

    for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
        dataset = SequenceToSequenceDataset(
            src_file_name=src,
            tgt_file_name=tgt,
            src_tokenizer=self.tokenizer,
            tgt_tokenizer=self.tokenizer,
            max_src_seq_length=data_cfg.max_src_seq_length,
            max_tgt_seq_length=data_cfg.max_tgt_seq_length,
        )
        datasets.append(dataset)

    if len(datasets) > 1:
        dataset = ConcatDataset(
            datasets=datasets,
            sampling_technique=data_cfg.get('concat_sampling_technique', 'temperature'),
            sampling_temperature=data_cfg.get('concat_sampling_temperature', 5),
            sampling_probabilities=data_cfg.get(
                'concat_sampling_probabilities', [1 / len(datasets)] * len(datasets)
            ),
            global_rank=parallel_state.get_data_parallel_rank(),
            world_size=parallel_state.get_data_parallel_world_size(),
        )
        return dataset
    else:
        return datasets[0]
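# A minimal sketch (illustrative numbers, not from the original source) of the batch
# arithmetic behind the drop_last check in _build_train_dataset above:
# gradient accumulation = global_batch_size // (micro_batch_size * data_parallel_size).
# The helper name below is hypothetical.
def _implied_grad_accumulation(global_batch_size, micro_batch_size, data_parallel_size):
    return global_batch_size // (micro_batch_size * data_parallel_size)

# e.g. _implied_grad_accumulation(256, 4, 8) == 8, so with drop_last=False the check
# above would raise.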
def test_broadcast_data(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.format(tensor_model_parallel_size))

    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
    tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12],
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5,)).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if parallel_state.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + " memory (MB)"
    string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
    string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes)
    string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
    string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes)
    if parallel_state.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
def build_data_loader(
    self,
    dataset,
    micro_batch_size,
    global_batch_size,
    shuffle,
    num_workers,
    pin_memory,
    drop_last,
    check_validation_interval,
):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    # This check makes sure the val_check_interval is less than the number of global batches.
    # Normally, PTL would do this check and properly account for gradient accumulation.
    # But now, it is implicit in the apex fwd/bwd functions and so we need to check for this somewhere.
    # The consequence of not doing this is that the training loop will never run validation.
    # NOTE: Prog bar is also broken as a result of this.
    global_batch_size_per_gpu = micro_batch_size * get_num_microbatches()
    if (
        self.trainer.val_check_interval > (sampler.num_samples // global_batch_size_per_gpu)
        and check_validation_interval
    ):
        raise ValueError(
            f"trainer.val_check_interval {self.trainer.val_check_interval} is > number of global batches "
            f"{sampler.num_samples // global_batch_size_per_gpu}"
        )

    # Data loader. Note that batch size is the per GPU batch size.
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        sampler=sampler,
        batch_size=micro_batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
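# Illustrative note (hypothetical numbers, not from the original source) on the
# val_check_interval guard in build_data_loader above: the number of global batches
# seen by this data-parallel rank is sampler.num_samples // (micro_batch_size *
# get_num_microbatches()). With 10_000 samples on this rank, micro_batch_size=4 and
# 8 microbatches, that is 10_000 // 32 = 312, so trainer.val_check_interval must not
# exceed 312 or validation would never run.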
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_devices
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
            )

            # assert that the fake tp and pp ranks match after model parallel init
            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
def _setup_eval_dataloader_from_config(self, cfg: DictConfig, dataset):
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    dataloaders = []
    for _dataset in dataset:
        sampler = torch.utils.data.distributed.DistributedSampler(
            _dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        dataloaders.append(
            torch.utils.data.DataLoader(
                dataset=_dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 0),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
                shuffle=False,
            )
        )
    return dataloaders
def build_pretraining_data_loader(self, dataset, batch_size, shuffle, num_workers, pin_memory):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
def test_initialize_model_parallel(tensor_model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(
        tensor_model_parallel_size,
        torch.distributed.get_world_size(),
    )
    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
    assert parallel_state.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == parallel_state.get_tensor_model_parallel_world_size()
    assert rank == parallel_state.get_tensor_model_parallel_rank()
    check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == parallel_state.get_data_parallel_world_size()
    assert rank == parallel_state.get_data_parallel_rank()
    check(parallel_state.get_data_parallel_group(), world_size, rank)

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
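# A small hypothetical helper (not from the original source) mirroring the assertions in
# test_initialize_model_parallel above: with contiguous tensor-parallel groups,
# tp_rank = global_rank % tp_size and dp_rank = global_rank // tp_size.
def _expected_ranks(global_rank, world_size, tensor_model_parallel_size):
    return {
        "tensor_model_parallel_rank": global_rank % tensor_model_parallel_size,
        "data_parallel_rank": global_rank // tensor_model_parallel_size,
        "data_parallel_world_size": world_size // tensor_model_parallel_size,
    }

# e.g. _expected_ranks(5, 8, 2) -> tp_rank 1, dp_rank 2, dp world size 4.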
def eval_epoch_end(self, outputs, mode):
    if isinstance(outputs[0], dict):
        outputs = [outputs]

    loss_list = []
    bleu_score_list = []
    for dataloader_idx, output in enumerate(outputs):
        averaged_loss = average_losses_across_data_parallel_group([x['loss'] for x in output])
        inputs = list(itertools.chain(*[x['inputs'] for x in output]))
        translations = list(itertools.chain(*[x['translations'] for x in output]))
        ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output]))
        assert len(translations) == len(inputs)
        assert len(translations) == len(ground_truths)

        # Gather translations and ground truths from all workers
        tr_gt_inp = [None for _ in range(parallel_state.get_data_parallel_world_size())]
        # we also need to drop pairs where ground truth is an empty string
        torch.distributed.all_gather_object(
            tr_gt_inp,
            [(t, g, i) for (t, g, i) in zip(translations, ground_truths, inputs)],
            group=parallel_state.get_data_parallel_group(),
        )
        if parallel_state.get_data_parallel_rank() == 0:
            _translations = []
            _ground_truths = []
            _inputs = []

            # Deduplicate sentences that may have been distributed across multiple data parallel ranks.
            gt_inp_set = set()
            for rank in range(0, parallel_state.get_data_parallel_world_size()):
                for t, g, i in tr_gt_inp[rank]:
                    if g + i not in gt_inp_set:
                        gt_inp_set.add(g + i)
                        _translations.append(t)
                        _ground_truths.append(g)
                        _inputs.append(i)

            if self.tgt_language in ['ja']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="ja-mecab")
            elif self.tgt_language in ['zh']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="zh")
            else:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a")

            bleu_score = sacre_bleu.score * parallel_state.get_data_parallel_world_size()

            dataset_name = "Validation" if mode == 'val' else "Test"
            logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}")
            logging.info(
                f"{dataset_name}, Dataloader index: {dataloader_idx}, "
                f"SacreBLEU = {bleu_score / parallel_state.get_data_parallel_world_size()}"
            )
            logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:")
            logging.info('============================================================')
            for example_idx in range(0, 3):
                random_index = random.randint(0, len(_translations) - 1)
                logging.info(" " + '\u0332'.join(f"Example {example_idx}:"))
                logging.info(f" Input: {_inputs[random_index]}")
                logging.info(f" Prediction: {_translations[random_index]}")
                logging.info(f" Ground Truth: {_ground_truths[random_index]}")
                logging.info('============================================================')
        else:
            bleu_score = 0.0

        loss_list.append(averaged_loss[0].cpu().numpy())
        bleu_score_list.append(bleu_score)
        if dataloader_idx == 0:
            self.log(f'{mode}_sacreBLEU', bleu_score, sync_dist=True)
            self.log(f'{mode}_loss', averaged_loss[0], prog_bar=True)
            if self.multilingual:
                self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss[0], mode)
        else:
            if self.multilingual:
                self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss[0], mode)
            else:
                self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}', bleu_score, sync_dist=True)
                self.log(f'{mode}_loss_dl_index_{dataloader_idx}', averaged_loss[0], prog_bar=False)

    if len(loss_list) > 1:
        self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
        self.log(f"{mode}_sacreBLEU_avg", np.mean(bleu_score_list), sync_dist=True)
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args,
    **kwargs,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and passes
    `*args` and `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model with
            `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]

    return model
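# A minimal sketch (not from the original source) of the provider contract that
# build_model above expects: the provider receives any *args/**kwargs plus the
# injected pre_process/post_process flags and returns an nn.Module. The function
# name and the Linear stand-in module below are hypothetical.
import torch

def example_model_provider_func(hidden_size, pre_process=True, post_process=True):
    # pre_process/post_process would normally toggle embedding / output layers on the
    # first and last pipeline stages; here they are simply recorded on a stand-in module.
    module = torch.nn.Linear(hidden_size, hidden_size)
    module.pre_process = pre_process
    module.post_process = post_process
    return module

# models = build_model(example_model_provider_func, wrap_with_ddp=True,
#                      virtual_pipeline_model_parallel_size=None, hidden_size=1024)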
def _build_eval_dataset(self, data_cfg, mode='train'):
    """Build the evaluation dataset."""
    if data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size():
        raise ValueError(
            f'You are trying to use "implicit gradient accumulation" of '
            f'{data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} '
            f'in your validation/test datasets. This is not supported. Please set global_batch_size equal to '
            f'micro_batch_size * data_parallel_world_size.'
        )
    datasets = []
    # Determine if we are using a single dataset or a list of datasets.
    is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
    is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)
    is_names_list_config = False
    if hasattr(data_cfg, "names"):
        if isinstance(data_cfg.names, ListConfig):
            is_names_list_config = True

    if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
        raise ValueError("src_file_name and tgt_file_name must both be either a ListConfig or a string.")
    if is_src_list_config:
        if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
            raise ValueError("src_file_name and tgt_file_name must have the same number of elements.")
        if is_names_list_config and len(data_cfg.names) != len(data_cfg.src_file_name):
            raise ValueError(
                "If you are providing names for each src/tgt file, they must have the same number of elements."
            )
    else:
        data_cfg.src_file_name = [data_cfg.src_file_name]
        data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

    for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
        dataset = SequenceToSequenceDataset(
            src_file_name=src,
            tgt_file_name=tgt,
            tokenizer=self.tokenizer,
            max_src_seq_length=data_cfg.max_src_seq_length,
            max_tgt_seq_length=data_cfg.max_tgt_seq_length,
        )
        datasets.append(dataset)

    if mode == 'train' and len(datasets) > 1:
        dataset = ConcatDataset(
            datasets=datasets,
            sampling_technique=data_cfg.get('concat_sampling_technique', 'temperature'),
            sampling_temperature=data_cfg.get('concat_sampling_temperature', 5),
            sampling_probabilities=data_cfg.get(
                'concat_sampling_probabilities', [1 / len(datasets)] * len(datasets)
            ),
            global_rank=parallel_state.get_data_parallel_rank(),
            world_size=parallel_state.get_data_parallel_world_size(),
        )
        return dataset
    return datasets
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and passes
    `*args` and `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model with
            `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type: `ModelType.encoder_or_decoder` or `ModelType.encoder_and_decoder`.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder."
                    )
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (
        parallel_state.model_parallel_is_initialized()
        and parallel_state.get_data_parallel_rank() == 0
    ):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]

    return model
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int,
    forward_only: bool,
    BatchSamplerCls,
) -> None:
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling.
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously, but this test uses a custom model so it's safe to abuse.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
    )
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}")

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(
            f"iter: {i} / {NUM_ITERATIONS} "
            f"local batchsize: {get_num_samples(local_mini_batch)} "
            f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}"
        )
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (
            parallel_state.get_data_parallel_world_size() * get_num_microbatches() * micro_batch_size
        )

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)