def allreduce_word_and_position_embeddings(self):
    """All-reduce the gradients of embeddings that are replicated across pipeline stages.

    Modified from megatron-lm:
    https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407

    Two independent reductions may run:

    * word embeddings: reduced over the embedding group so that the copies
      held by the first and last stages stay in sync. Only done when more
      than one pipeline stage exists, this rank is in the embedding group
      and the model shares its word embeddings.
    * position embeddings (T5): reduced over the position-embedding group
      when this rank belongs to it, more than one pipeline stage exists and
      a pipeline split rank is configured.
    """

    def _select_grad(weight):
        # O2 recipe stores a "main" copy of weights and grads.
        return weight.main_grad if self.megatron_amp_o2 else weight.grad

    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_rank_in_embedding_group()
        and self.enc_dec_model.share_word_embeddings
    ):
        shared_weight = self.enc_dec_model.word_embeddings_weight()
        torch.distributed.all_reduce(
            _select_grad(shared_weight),
            group=parallel_state.get_embedding_group(),
        )

    # All reduce position embeddings for T5.
    if (
        parallel_state.is_rank_in_position_embedding_group()
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.get_pipeline_model_parallel_split_rank() is not None
    ):
        position_weight = self.enc_dec_model.position_embeddings_weight()
        torch.distributed.all_reduce(
            _select_grad(position_weight),
            group=parallel_state.get_position_embedding_group(),
        )
def setup(self, stage=None): # NOTE: super().__init__ will try and setup train/val/test datasets, but we sidestep this using a if self._train_ds is not None condition # We then set things up for real only once setup() of this class is called. resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path if resume_checkpoint_path: try: init_consumed_samples = int( float( re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", resume_checkpoint_path)[0])) except (ValueError, TypeError): logging.warning( "Cannot parse the checkpoint file to get the consumed samples. This is expected if you are not using memmap datasets." ) init_consumed_samples = 0 else: init_consumed_samples = 0 self.init_consumed_samples = init_consumed_samples if stage == 'predict': return # If the user wants to manually override train and validation dataloaders before calling `.fit()` if self._train_dl is not None and self._validation_dl is not None: return self.build_train_valid_test_datasets() self.setup_training_data(self._cfg.train_ds) self.setup_validation_data(self._cfg.validation_ds) if hasattr(self._cfg, 'test_ds'): self.setup_test_data(self._cfg.test_ds) # when using pipeline model parallel the final stage need to initialize word embeddings if parallel_state.get_pipeline_model_parallel_world_size() > 1: self.enc_dec_model.sync_initial_word_embeddings() self.enc_dec_model.sync_initial_position_embeddings()
def initialize_word_embeddings(self, init_method, vocab_size, hidden_size):
    """Create the last-stage copy of the shared word embeddings.

    Parameters are shared between the word embeddings layer and the heads
    at the end of the model. In a pipelined setup with more than one stage,
    the initial embedding layer and the head are on different workers, so
    the following is done:

    1. Create a second copy of word_embeddings on the last stage, with
       initial parameters of 0.0.
    2. Do an all-reduce between the first and last stage to ensure that
       the two copies of word_embeddings start off with the same
       parameter values.
    3. In the training loop, perform an all-reduce between the grads of
       the two word_embeddings layers to ensure that every applied weight
       update is the same on both stages.

    This method only performs step 1; without pipeline parallelism it does
    nothing.

    Args:
        init_method: Initializer forwarded to `VocabParallelEmbedding`.
        vocab_size: Vocabulary size forwarded to `VocabParallelEmbedding`.
        hidden_size: Embedding width forwarded to `VocabParallelEmbedding`.

    Raises:
        Exception: If `self.share_word_embeddings` is false.
    """
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but share_word_embeddings is false')

    # Nothing to do when there is only a single pipeline stage.
    if parallel_state.get_pipeline_model_parallel_world_size() == 1:
        return

    # Only the final stage needs the extra copy.
    if not parallel_state.is_pipeline_last_stage():
        return

    assert not parallel_state.is_pipeline_first_stage()
    self._word_embeddings_for_head_key = 'word_embeddings_for_head'

    # Start from all-zero weights; the first stage's values are copied in
    # later through an all-reduce.
    head_embeddings = tensor_parallel.VocabParallelEmbedding(vocab_size, hidden_size, init_method=init_method)
    self.word_embeddings = head_embeddings
    self.word_embeddings.weight.data.fill_(0)
    self.word_embeddings.weight.shared = True
def setup(self, stage=None): resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path if resume_checkpoint_path: try: init_consumed_samples = int( float(re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", resume_checkpoint_path)[0]) ) except (ValueError, TypeError): logging.warning("Cannot parse the checkpoint file to get the consumed samples. assume it is zero.") init_consumed_samples = 0 else: init_consumed_samples = 0 self.init_consumed_samples = init_consumed_samples """A PTL method to setup the training, validation and test datasets.""" if stage == 'predict': return if self._train_dl is not None and self._validation_dl is not None: return self.build_train_valid_test_datasets() self.setup_training_data(self._cfg.data) self.setup_validation_data(self._cfg.data) self.setup_test_data(self._cfg.data) # when using pipeline model parallel the final stage need to initialize word embeddings if parallel_state.get_pipeline_model_parallel_world_size() > 1: self.enc_dec_model.sync_initial_word_embeddings() self.enc_dec_model.sync_initial_position_embeddings()
def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
    """Map a microbatch iteration number to the model chunk that handles it.

    Iterations are grouped in runs of `pipeline size * num_model_chunks`;
    within a run, each consecutive block of `pipeline size` iterations goes
    to the next chunk. For backward iterations the chunk order is reversed.

    `num_model_chunks` is read from the enclosing scope.
    """
    stages = parallel_state.get_pipeline_model_parallel_world_size()
    position_in_group = microbatch_id % (stages * num_model_chunks)
    chunk_id = position_in_group // stages
    return chunk_id if forward else num_model_chunks - chunk_id - 1
def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
    """Initialize model parallelism with a virtual pipeline size and a split rank, then check the getters."""
    if self.world_size < 4:
        self.skipTest("requires >= 4 GPUs")
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    # Tensor parallelism of 2 is only used when more than 4 ranks exist.
    tp_size = 2 if self.world_size > 4 else 1
    pp_size = self.world_size // tp_size
    virtual_pp_size = 2
    split_rank = pp_size // 2

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=tp_size,
        pipeline_model_parallel_size_=pp_size,
        virtual_pipeline_model_parallel_size_=virtual_pp_size,
        pipeline_model_parallel_split_rank_=split_rank,
    )

    self.assertEqual(
        calc_expected_tensor_model_paralell_rank(self.rank, tp_size),
        parallel_state.get_tensor_model_parallel_rank(),
    )
    self.assertEqual(pp_size, parallel_state.get_pipeline_model_parallel_world_size())
    self.assertEqual(
        virtual_pp_size,
        parallel_state.get_virtual_pipeline_model_parallel_world_size(),
    )

    # NOTE(review): for tp_size > 1 this formula may not match how pipeline
    # groups are laid out -- confirm against `initialize_model_parallel`.
    rank_without_tp_offset = self.rank - (self.rank % tp_size)
    expected_pipeline_rank = rank_without_tp_offset % pp_size
    self.assertEqual(expected_pipeline_rank, parallel_state.get_pipeline_model_parallel_rank())

    # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
    # `initialize_model_parallel`, it's set to 0.
    self.assertEqual(0, parallel_state.get_virtual_pipeline_model_parallel_rank())
    self.assertEqual(split_rank, parallel_state.get_pipeline_model_parallel_split_rank())

    # The split rank can be overridden after initialization.
    fake_split_rank = 77
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    self.assertEqual(fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank())

    parallel_state.destroy_model_parallel()
def get_forward_backward_func(
    virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_size,
):
    """Pick the forward/backward schedule matching the parallel configuration.

    Args:
        virtual_pipeline_model_parallel_size: `None` selects the
            non-interleaved pipeline schedule; any other value selects the
            interleaved one.
        pipeline_model_parallel_size: Pipeline size the number of
            microbatches must be divisible by for the interleaved schedule.

    Returns:
        The schedule function to call.

    Raises:
        RuntimeError: If the interleaved schedule is selected and the number
            of microbatches is not divisible by `pipeline_model_parallel_size`.
    """
    # A single pipeline stage needs no pipelining at all.
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining

    if virtual_pipeline_model_parallel_size is None:
        return forward_backward_pipelining_without_interleaving

    if get_num_microbatches() % pipeline_model_parallel_size != 0:
        raise RuntimeError(
            "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
        )
    return _forward_backward_pipelining_with_interleaving
def get_forward_backward_func(
    virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_size,
):
    """Pick the forward/backward schedule matching the parallel configuration.

    Selecting the interleaved schedule emits an `ExperimentalWarning`.

    Args:
        virtual_pipeline_model_parallel_size: `None` selects the
            non-interleaved pipeline schedule; any other value selects the
            interleaved one.
        pipeline_model_parallel_size: Pipeline size the number of
            microbatches must be divisible by for the interleaved schedule.

    Returns:
        The schedule function to call.

    Raises:
        RuntimeError: If the interleaved schedule is selected and the number
            of microbatches is not divisible by `pipeline_model_parallel_size`.
    """
    # A single pipeline stage needs no pipelining at all.
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining

    if virtual_pipeline_model_parallel_size is None:
        return forward_backward_pipelining_without_interleaving

    if get_num_microbatches() % pipeline_model_parallel_size != 0:
        raise RuntimeError(
            "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
        )

    warnings.warn(
        "Pipeline Model Parallel with interleaving scheduling is experimental. "
        f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
        ExperimentalWarning,
    )
    return _forward_backward_pipelining_with_interleaving
def setup(self, stage=None): """ PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. """ resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path if resume_checkpoint_path: try: init_consumed_samples = int( float( re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", resume_checkpoint_path)[0])) except (ValueError, TypeError): logging.warning( "Cannot parse the checkpoint file to get the consumed samples. assume it is zero." ) init_consumed_samples = 0 else: init_consumed_samples = 0 self.init_consumed_samples = init_consumed_samples if stage == 'predict': return else: # TODO: consider adding a ModelPT guard to check if model is being restored. # allowing restored models to optionally setup datasets self.build_train_valid_test_datasets() self.setup_training_data(self.cfg.data) self.setup_validation_data(self.cfg.data) self.setup_test_data(self.cfg.data) # when using pipeline model parallel the final stage need to initialize word embeddings if parallel_state.get_pipeline_model_parallel_world_size() > 1: self.model.sync_initial_word_embeddings()
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
    This means that model is split into model chunks.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown
    Note that if `forward_only` this scheduling consists of only warmup phase.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A list of `torch.nn.Module`, one entry per model chunk.

    Keyword args:
        forward_only: If :obj:`True`, every microbatch is treated as a warmup
            microbatch and no backward step is run.
        tensor_shape: Shape of tensor. Forwarded to the p2p communication calls.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler: Forwarded to `backward_step`.
        disable_autocast: Forwarded to `forward_step`.
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.
        **kwargs: Accepted but not read by this function.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.

    Raises:
        RuntimeError: If `model` is not a list.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

    num_model_chunks: int = len(model)
    # Per-chunk FIFO queues of received inputs / produced outputs that are
    # still waiting for their backward pass.
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    # Per-chunk index of the next microbatch to read from `batch`.
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        # Only defined when backward passes will run.
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
            [] for _ in range(num_model_chunks)
        ]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    # Here a "microbatch" is one (microbatch, model chunk) pair.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    ###################################################################################################################
    # Helper function definitions.
    ###################################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            # Backward passes visit the chunks in reverse order.
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        # On the first stage nothing is received, so a `None` placeholder is
        # queued when the input queue has no entry for this step yet.
        if (parallel_state.is_pipeline_first_stage() and
                len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # On the last stage no gradient is received, so a `None` placeholder
        # is queued when the gradient queue is empty.
        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        # Oldest saved tensors first (FIFO).
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs)

        return input_tensor_grad

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    # Receive the input of the very first forward step (chunk 0).
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(
            f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
        )

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
            # Last warmup step before the steady state: also receive the
            # first output gradient for the upcoming backward pass.
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                dtype=dtype)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            backward_model_chunk_id)
        _logger.debug(
            f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
        )
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            # No gradient was received during warmup in this mode, so fetch
            # the first one now.
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
            )
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype))

    return losses_reduced
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int,
    forward_only: bool,
    BatchSamplerCls,
) -> None:
    """Drive the interleaved pipeline schedule while the number of microbatches changes between iterations.

    Args:
        pipeline_model_parallel_size: Pipeline size passed to
            `parallel_state.initialize_model_parallel`.
        forward_only: Forwarded to the schedule; when False every parameter
            is required to have a gradient after each iteration.
        BatchSamplerCls: Batch sampler class used to build the data loader.
            Its instance's `local_minibatch_size` is reassigned on every
            iteration.

    Raises:
        RuntimeError: If a parameter has no gradient after a
            forward-backward iteration.
    """
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously but this test uses custom model so it's safe to abuse.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size())

    print_separator(
        f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
    )

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(
        _get_params_for_weight_decay_optimization(model))

    # NOTE(review): `micro_batch_size` is not defined in this function;
    # presumably a module-level name -- confirm it matches `args.micro_batch_size`.
    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        """Return the length of a tensor, or a nested list of lengths for a list/tuple of batches."""
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        # Recompute the number of microbatches for the samples consumed so far.
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        # Resize the sampler in place so the next batch has the new size.
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(f"iter: {i} / {NUM_ITERATIONS} "
                     f"local batchsize: {get_num_samples(local_mini_batch)} "
                     f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}")
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (parallel_state.get_data_parallel_world_size() *
                             get_num_microbatches() * micro_batch_size)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            else:
                # for/else: the loop has no `break`, so this runs once every
                # parameter has been checked.
                optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`. After
            `listify_model` it must contain exactly one module.

    Keyword args:
        forward_only: If :obj:`True`, no tensors are saved and no backward
            step is run.
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.

    Raises:
        RuntimeError: If `model` resolves to anything other than a single module.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    # Lower pipeline ranks run more warmup forwards: world size - rank - 1,
    # capped at the total number of microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size() -
        parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor,
                                       tensor_shape=tensor_shape)

        if not forward_only:
            # Keep the pair for the matching backward pass in cooldown.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor,
                                           tensor_shape=tensor_shape)

            if not last_iteration:
                # NOTE(review): the label says "(last iteration)" but this
                # branch runs on every iteration except the last one.
                _logger.debug("receive fwd (last iteration)")
                input_tensor = p2p_communication.recv_forward(
                    tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            if last_iteration:
                # No further forward input is needed after the last steady step.
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad,
                                                tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        # One backward pass for each forward pass done during warmup.
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(
                tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad,
                                            tensor_shape=tensor_shape)

    return losses_reduced
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args,
    **kwargs
) -> List[torch.nn.Module]:
    """Instantiate the model chunk(s) owned by this pipeline-parallel rank.

    `pre_process` and `post_process` are injected into `**kwargs` (from
    `parallel_state.is_pipeline_first_stage()` / `is_pipeline_last_stage()`)
    before `model_provider_func(*args, **kwargs)` is called.

    Args:
        model_provider_func: Callable taking `*args` and `**kwargs` and
            returning a `torch.nn.Module`.
        wrap_with_ddp: When true, each module is wrapped with
            `torch.nn.parallel.distributed.DistributedDataParallel` over the
            data-parallel process group.
        virtual_pipeline_model_parallel_size: Number of model chunks to build
            for the interleaved schedule. Only honoured when the pipeline
            model parallel world size is greater than 1.
        *args: Positional arguments forwarded to `model_provider_func`.
        **kwargs: Keyword arguments forwarded to `model_provider_func`.

    Returns:
        A list of modules: one per virtual pipeline rank when interleaving is
        active, otherwise a single-element list.
    """

    def _build_chunk():
        # The stage flags depend on the virtual rank that is current at call time.
        kwargs["pre_process"] = parallel_state.is_pipeline_first_stage()
        kwargs["post_process"] = parallel_state.is_pipeline_last_stage()
        return model_provider_func(*args, **kwargs)

    interleaved = (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    )
    if interleaved:
        model = []
        for virtual_rank in range(virtual_pipeline_model_parallel_size):
            # The virtual rank must be set before the stage flags are queried.
            parallel_state.set_virtual_pipeline_model_parallel_rank(virtual_rank)
            model.append(_build_chunk())
    else:
        model = _build_chunk()
        if not isinstance(model, list):
            model = [model]

    # Parameters that are not tensor-model-parallel lack the corresponding
    # attributes; give every parameter the defaults so the optimizer can rely
    # on them.
    for chunk in model:
        for weight in chunk.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(weight)

    # Report the parameter count once per model-parallel rank.
    if parallel_state.get_data_parallel_rank() == 0:
        report = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum(weight.nelement() for chunk in model for weight in chunk.parameters()),
        )
        print(report, flush=True)

    # Move every chunk to the current CUDA device.
    for chunk in model:
        chunk.cuda(torch.cuda.current_device())

    if not wrap_with_ddp:
        return model

    device = torch.cuda.current_device()
    return [
        torch.nn.parallel.distributed.DistributedDataParallel(
            chunk,
            device_ids=[device],
            output_device=device,
            process_group=parallel_state.get_data_parallel_group(),
        )
        for chunk in model
    ]
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwawrgs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list containing exactly one `torch.nn.Module`.

    Keyword args:
        forward_only: If true, no backward pass is run: inputs/outputs are not
            saved for backward and the cooldown phase is skipped.
        tensor_shape: Shape of tensor. Required for P2P communication.
        decoder_sequence_length: Forwarded to `get_tensor_shapes` when computing
            the send/receive shapes.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler: Forwarded to `backward_step`.
        disable_autocast: Forwarded to `forward_step`.
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.
        **kwawrgs: Extra keyword arguments are accepted and ignored.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.

    Raises:
        RuntimeError: If `model` does not resolve to exactly one module.
    """
    model = listify_model(model)
    if len(model) != 1:
        # `model` is always a list here, so report its length rather than its type.
        msg = f"`model` is expected to be a single `nn.Module`, but got {len(model)} model chunks"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches: later stages need fewer of them.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    # What this stage receives has the shape the previous stage sends.
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                # Only reached when more microbatches follow, so the log must
                # not claim this is the last iteration.
                _logger.debug("receive fwd")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            # Gradients w.r.t. the input travel back with the *receive* shapes.
            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype
                )

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
def __init__(
    self,
    init_method,
    output_layer_init_method,
    encoder_attn_mask_type,
    vocab_size,
    max_position_embeddings,
    hidden_size,
    ffn_hidden_size,
    num_layers,
    num_tokentypes,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    add_pooler=False,
    pre_process=True,
    post_process=True,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    onnx_safe=False,
    use_soft_prompts=False,
    num_prompt_tokens=10,
    prompt_tags=None,
):
    """Assemble the language model's submodules.

    Depending on the flags, builds an embedding (only when `pre_process`), a
    prompt table (only when `use_soft_prompts`), the encoder transformer, an
    optional decoder transformer (`add_decoder`; asserts that the pipeline
    model parallel world size is 1) and an optional pooler (only when both
    `post_process` and `add_pooler` are set).

    When `kv_channels` is None it defaults to
    `hidden_size // num_attention_heads`, which must divide evenly.
    """
    super(TransformerLanguageModel, self).__init__()

    # Keep the configuration on the instance.
    self.pre_process = pre_process
    self.post_process = post_process
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.num_tokentypes = num_tokentypes
    self.init_method = init_method
    self.encoder_attn_mask_type = encoder_attn_mask_type
    self.add_decoder = add_decoder
    self.decoder_attn_mask_type = decoder_attn_mask_type
    self.add_pooler = add_pooler
    self.hidden_dropout = hidden_dropout
    self.output_layer_init_method = output_layer_init_method
    self.use_soft_prompts = use_soft_prompts
    self.prompt_tags = prompt_tags
    self.num_prompt_tokens = num_prompt_tokens

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    # Token / position embeddings exist only on a pre-processing stage.
    if pre_process:
        self.embedding = Embedding(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            max_sequence_length=max_position_embeddings,
            init_method=init_method,
            num_tokentypes=num_tokentypes,
            use_cpu_initialization=use_cpu_initialization,
            embedding_dropout_prob=hidden_dropout,
        )
        self._embedding_key = 'embedding'

    # Soft prompts.
    if use_soft_prompts:
        self.prompt_table = PromptTable(
            prompt_tags=prompt_tags, num_prompt_tokens=num_prompt_tokens, hidden_size=hidden_size,
        )
        self._prompt_table_key = 'prompt_table'

    # Settings shared verbatim by the encoder and decoder stacks.
    transformer_kwargs = dict(
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        num_layers=num_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        pre_process=pre_process,
        post_process=post_process,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        hidden_dropout=hidden_dropout,
        use_cpu_initialization=use_cpu_initialization,
        bias_gelu_fusion=bias_gelu_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
    )

    # Encoder transformer.
    self.encoder = ParallelTransformer(self_attn_mask_type=encoder_attn_mask_type, **transformer_kwargs)
    self._encoder_key = 'encoder'

    # Optional decoder transformer.
    if add_decoder:
        assert (
            parallel_state.get_pipeline_model_parallel_world_size() == 1
        ), 'pipeline parallelism is not supported in the presence of decoder'
        self.decoder = ParallelTransformer(
            layer_type=LayerType.decoder, self_attn_mask_type=decoder_attn_mask_type, **transformer_kwargs
        )
        self._decoder_key = 'decoder'

    # Optional pooler, only on a post-processing stage.
    if post_process and add_pooler:
        self.pooler = Pooler(hidden_size, init_method)
        self._pooler_key = 'pooler'
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type: `ModelType.encoder_or_decoder` (default) or
            `ModelType.encoder_and_decoder`. For the latter, `add_encoder` and
            `add_decoder` are also set in `**kwargs`, and a pipeline split rank
            is required when the pipeline world size is greater than 1.
            Not consulted on the virtual (interleaved) pipeline path.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.

    Raises:
        RuntimeError: If `model_type` is `encoder_and_decoder`, the pipeline
            world size is greater than 1 and no split rank is configured.
        ValueError: If `model_type` is not one of the two supported values on
            the non-interleaved path.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder."
                    )
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                # The encoder occupies ranks [0, split_rank) and the decoder
                # the rest, so each half has its own first and last stage.
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        else:
            # Without this branch `model` would be unbound below and the caller
            # would see an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported model_type: {model_type!r}; expected "
                "ModelType.encoder_or_decoder or ModelType.encoder_and_decoder."
            )
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (
        parallel_state.model_parallel_is_initialized()
        and parallel_state.get_data_parallel_rank() == 0
    ):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]
    return model
def __init__(
    self,
    init_method,
    output_layer_init_method,
    num_layers,
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    layer_type=LayerType.encoder,
    self_attn_mask_type=AttnMaskType.padding,
    pre_process=True,
    post_process=True,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    hidden_dropout=0.1,
    use_cpu_initialization=False,
    bias_gelu_fusion=True,
    openai_gelu=False,
    onnx_safe=False,
):
    """Build this pipeline stage's slice of the transformer stack.

    `num_layers` is the total layer count; it must be divisible by the
    pipeline model parallel world size, and each stage builds
    `num_layers // world_size` contiguous layers. Layer numbers are 1-based
    and offset by `pipeline_rank * layers_per_stage`. A final layer norm is
    added only when `post_process` is true.

    When `kv_channels` is None it defaults to
    `hidden_size // num_attention_heads`, which must divide evenly.
    """
    super(ParallelTransformer, self).__init__()

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    self.fp32_residual_connection = fp32_residual_connection
    self.pre_process = pre_process
    self.post_process = post_process
    self.input_tensor = None

    # Activation checkpointing configuration.
    self.activations_checkpoint_method = activations_checkpoint_method
    self.activations_checkpoint_num_layers = activations_checkpoint_num_layers

    # Layers are split evenly across pipeline stages.
    pipeline_world_size = parallel_state.get_pipeline_model_parallel_world_size()
    assert (
        num_layers % pipeline_world_size == 0
    ), 'num_layers must be divisible by pipeline_model_parallel_size'
    self.num_layers = num_layers // pipeline_world_size

    # TODO: get virtual_pipeline_model_parallel_size from apex.mpu and support
    # interleaved (virtual) pipeline stages. The intended scheme divides the
    # stage's layers further into model chunks, e.g.
    #   8 layers, 2 stages, 4 model chunks:    Stage 0: [0] [2] [4] [6]
    #                                          Stage 1: [1] [3] [5] [7]
    #   8 layers, 2 stages, 2 virtual stages:  Stage 0: [0, 1] [4, 5]
    #                                          Stage 1: [2, 3] [6, 7]
    # with offset = virtual_rank * (num_layers // virtual_size)
    #               + pipeline_rank * layers_per_chunk.
    # Until then each stage gets a contiguous set of layers.
    first_layer_number = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers + 1

    self.layers = torch.nn.ModuleList(
        [
            ParallelTransformerLayer(
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                layer_number=layer_number,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )
            for layer_number in range(first_layer_number, first_layer_number + self.num_layers)
        ]
    )

    # Final layer norm before output, only on a post-processing stage.
    if self.post_process:
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    model_type: ModelType,
    *,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    deallocate_pipeline_outputs: bool = False,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Run the backward pass for one microbatch on this pipeline stage.

    Each of `input_tensor`, `output_tensor` and `output_tensor_grad` may be a
    single value or a list; `FutureTensor` entries are resolved with `.get()`.
    Backward is driven from the first output tensor using the first output
    gradient.

    Args:
        input_tensor: Stage input(s); `None` entries are skipped.
        output_tensor: Stage output(s).
        output_tensor_grad: Gradient(s) of the loss w.r.t. the stage
            output(s). When the first entry is `None` and `grad_scaler` is
            given, the first output is scaled before backward.
        model_type: When equal to `ModelType.encoder_and_decoder` on a stage
            after the pipeline split, the second output gradient (if present)
            is added in place to the last input gradient.

    Keyword Arguments:
        grad_scaler: Optional scaler applied to the first output as described
            above.
        deallocate_pipeline_outputs: Experimental. Use `custom_backward`
            instead of `torch.autograd.backward`.

    Returns:
        The gradient of the first input when `input_tensor` was not a list,
        otherwise the list of input gradients (`None` for `None` inputs).
    """

    def _as_resolved_list(value):
        # Wrap a lone value and wait on any pending FutureTensor.
        items = value if isinstance(value, list) else [value]
        return [item.get() if isinstance(item, FutureTensor) else item for item in items]

    unwrap_input_tensor_grad = not isinstance(input_tensor, list)

    inputs = _as_resolved_list(input_tensor)
    # Keep .grad on the inputs so it can be read back after backward.
    for inp in inputs:
        if inp is not None:
            inp.retain_grad()

    outputs = _as_resolved_list(output_tensor)
    output_grads = _as_resolved_list(output_tensor_grad)

    # Backward pass.
    if grad_scaler is not None and output_grads[0] is None:
        outputs[0] = grad_scaler.scale(outputs[0])

    if deallocate_pipeline_outputs:
        custom_backward(outputs[0], output_grads[0])
    else:
        torch.autograd.backward(outputs[0], grad_tensors=output_grads[0])

    # Collect the grad of each input.
    input_grads = [None if inp is None else inp.grad for inp in inputs]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    # TODO (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
        and output_grads[1] is not None
    ):
        input_grads[-1].add_(output_grads[1])

    if unwrap_input_tensor_grad:
        return input_grads[0]
    return input_grads
args.micro_batch_size, args.data_parallel_size, # args.data_parallel_size, ) world_size = torch.distributed.get_world_size() print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE") parallel_state.initialize_model_parallel( tensor_model_parallel_size_=args.tensor_model_parallel_size, pipeline_model_parallel_size_=args.pipeline_model_parallel_size, default_backend="nccl", p2p_backend="ucc" if HAS_TORCH_UCC else "nccl", ) pipeline_model_parallel_size = ( parallel_state.get_pipeline_model_parallel_world_size() ) model_parallel_cuda_manual_seed(0) model = build_model( gpt_model_provider, wrap_with_ddp=True, virtual_pipeline_model_parallel_size=None, cpu_offload=args.cpu_offload, ) assert isinstance(model, list), model _param_groups = _get_params_for_weight_decay_optimization(model) optim = torch.optim.Adam(_param_groups) runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm) parallel_state.destroy_model_parallel() torch.distributed.barrier()
def forward_backward_func_template(
    name: str,
    forward_backward_func,
    pipeline_model_parallel_size: int,
    forward_only: bool,
) -> None:
    """Exercise one forward/backward schedule end to end.

    Initializes model parallel state for the schedule identified by `name`
    ("no_pipelining", "interleaving", or anything else for the plain
    pipelined setup), builds the model with DDP, runs `forward_backward_func`
    on a random batch and, unless `forward_only`, checks that every parameter
    received a gradient. Rank 0 prints `TEST_SUCCESS_MESSAGE` after a barrier.

    Raises:
        RuntimeError: If a parameter has no gradient after a backward run.
    """
    print_separator(
        f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}"
    )

    # Only the interleaved schedule uses virtual pipeline stages.
    virtual_size = 2 if name == "interleaving" else None

    if name == "no_pipelining":
        # The no-pipelining schedule is run with tensor and pipeline model
        # parallel sizes of 1.
        parallel_state.initialize_model_parallel(1, 1, None)
    else:
        # A virtual pipeline size is needed to enable the interleaved
        # schedule; it is None for the non-interleaved one.
        parallel_state.initialize_model_parallel(1, pipeline_model_parallel_size, virtual_size)

    if virtual_size is not None:
        # Check the experimental warning message.
        get_forward_backward_func(virtual_size, pipeline_model_parallel_size)

    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_size,
    )
    assert isinstance(model, list)
    assert len(model) == (1 if virtual_size is None else virtual_size)

    # Constructing the optimizer is part of the check; it is not stepped.
    _param_groups = _get_params_for_weight_decay_optimization(model)
    torch.optim.Adam(_param_groups, lr=1e-4)

    # The batch holds this data-parallel rank's share of the global batch,
    # while p2p communication uses the micro-batch shape.
    batch_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size]
    batch = (torch.randn(batch_shape).cuda(),)
    tensor_shape = [micro_batch_size, hidden_size]

    update_num_microbatches(0)
    forward_backward_func(
        fwd_step_func,
        batch,
        model,
        forward_only=forward_only,
        tensor_shape=tensor_shape,
    )

    if not forward_only:
        for module in model:
            for param in module.parameters():
                if param.grad is None:
                    raise RuntimeError("grad not found")

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)