def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
    batch_has_lang_information = len(batch[0]) == 7
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # At this point batch is a list of dictionaries where each dict is a microbatch.
    # After the _process_global_batch call, processed_batch will be a single dictionary containing the global batch.
    # This is required since the parent class expects a single global batch dictionary.
    processed_batch = self._process_global_batch(batch)

    # Call parent validation step to get the loss.
    # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI;
    # these will be ignored in the parent class.
    loss = super().validation_step(processed_batch, batch_idx)

    predicted_token_ids, _ = self.decode(
        tokens_enc=processed_batch['text_enc'], enc_mask=processed_batch['enc_mask'], num_tokens_to_generate=30
    )

    # Special ids-to-text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
    preds_text = self.ids_to_text(predicted_token_ids)
    labels_text = self.ids_to_text(processed_batch['labels'])
    input_text = self.ids_to_text(processed_batch['text_enc'])

    if not batch_has_lang_information:
        categories = [None] * len(preds_text)
    else:
        categories = processed_batch['lang']

    metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
    assert len(categories) == len(preds_text) == len(labels_text)
    for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
        # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats.
        pred, label = self.cast_for_metric(
            pred, label, self.val_metric_name if mode == 'validation' else self.test_metric_name
        )
        if batch_has_lang_information:
            _ = metric(pred, label, category)
        else:
            _ = metric(pred, label)

    return {
        'loss': loss,
        'preds': preds_text,
        'labels': labels_text,
        'categories': categories,
        'inputs': input_text,
    }
def _fetch_next_batch(self, iterator: Iterator) -> None:
    start_output = self.on_fetch_start()
    batch = [next(iterator) for _ in range(get_num_microbatches())]
    self.fetched += 1
    if not self.prefetch_batches and self._has_len:
        # when we don't prefetch but the dataloader is sized, we use the length for `done`
        dataloader = self.dataloader
        assert isinstance(dataloader, Sized)  # `_has_len` is True
        self.done = self.fetched >= len(dataloader)
    self.on_fetch_end(batch, start_output)
def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
    for data_parallel_size in range(1, self.world_size + 1):

        expected_global_batch_size = self.GLOBAL_BATCH_SIZE
        expected_micro_batch_size = self.MICRO_BATCH_SIZE
        if rampup_batch_size:
            expected_global_batch_size = rampup_batch_size[0]
            num_consumed_samples = 0
            step_of_global_batch_size = rampup_batch_size[1]
            threshold = rampup_batch_size[2]

        if data_parallel_size > 1 and data_parallel_size % 2 != 0:
            continue
        if self.world_size % data_parallel_size != 0:
            continue
        with self.subTest(data_parallel_size=data_parallel_size):
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size())

            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )

            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
            self.assertEqual(
                get_num_microbatches(),
                expected_global_batch_size / expected_micro_batch_size / data_parallel_size,
            )
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size)

            # Make sure `global_batch_size` equals the final global batch size after
            # a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for i in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
            parallel_state.destroy_model_parallel()
def get_forward_backward_func(
    virtual_pipeline_model_parallel_size, pipeline_model_parallel_size,
):
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if virtual_pipeline_model_parallel_size is not None:
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
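A minimal sketch of how the returned schedule function might be invoked; `fwd_step_func`, `batch`, `model`, `seq_length`, `micro_batch_size`, and `hidden_size` are illustrative assumptions, not names from the snippet above.

# Hypothetical usage sketch: pick a schedule based on the current parallel configuration,
# then run one global batch through it.
forward_backward_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_size=parallel_state.get_pipeline_model_parallel_world_size(),
)
losses_reduced = forward_backward_func(
    fwd_step_func,                                   # model-specific step returning (output, loss_func)
    batch,                                           # microbatches for this data-parallel rank
    model,
    forward_only=False,
    tensor_shape=[seq_length, micro_batch_size, hidden_size],
)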
def training_step(self, batch, batch_idx):
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
    return super().training_step(batch, batch_idx)
def build_data_loader(
    self, dataset, micro_batch_size, global_batch_size, shuffle, num_workers, pin_memory, drop_last, check_validation_interval,
):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
    )

    # This check makes sure the val_check_interval is less than the number of global batches.
    # Normally, PTL would do this check and properly account for gradient accumulation.
    # But now, it is implicit in the apex fwd/bwd functions and so we need to check for this somewhere.
    # The consequence of not doing this is that the training loop will never run validation.
    # NOTE: The progress bar is also broken as a result of this.
    global_batch_size_per_gpu = micro_batch_size * get_num_microbatches()
    if (
        self.trainer.val_check_interval > (sampler.num_samples // global_batch_size_per_gpu)
        and check_validation_interval
    ):
        raise ValueError(
            f"trainer.val_check_interval {self.trainer.val_check_interval} is > number of global batches "
            f"{sampler.num_samples // global_batch_size_per_gpu}"
        )

    # Data loader. Note that batch size is the per-GPU batch size.
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        sampler=sampler,
        batch_size=micro_batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def get_forward_backward_func(
    virtual_pipeline_model_parallel_size, pipeline_model_parallel_size,
):
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if virtual_pipeline_model_parallel_size is not None:
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            warnings.warn(
                "Pipeline Model Parallel with interleaving scheduling is experimental. "
                "To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: "
                f"{virtual_pipeline_model_parallel_size}",
                ExperimentalWarning,
            )
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
def training_step(self, batch, batch_idx):
    micro_batch_size = batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # At this point batch is a list of dictionaries where each dict is a microbatch.
    # After the _process_global_batch call, batch will be a single dictionary containing the global batch.
    # This is required since the parent class expects a single global batch dictionary.
    batch = self._process_global_batch(batch)
    return super().training_step(batch, batch_idx)
def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: torch.nn.Module,
    input_tensor: Optional[torch.Tensor],
    losses_reduced: List[torch.Tensor],
):
    """Forward step for passed-in model.

    If first stage, the input tensor is obtained from the batch; otherwise the passed-in input_tensor is used.

    Returns output tensor.

    Args:
        forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
            returns the model's output and the loss function.
        batch: minibatch
        model: unwrappable model
        input_tensor:
        losses_reduced:

    Returns:
        output_tensor
    """
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679  # NOQA
    # for the details of `set_input_tensor`.
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(batch, model)
    # print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
    if parallel_state.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()
    return output_tensor
def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
    batch_has_lang_information = len(batch[0]) == 7

    # XNLI batches have language information that needs to be removed before calling the parent validation step.
    if batch_has_lang_information:
        processed_batch = []
        for micro_batch in batch:
            micro_batch = {k: v for k, v in micro_batch.items() if k != 'lang'}
            processed_batch.append(micro_batch)
    else:
        processed_batch = batch

    micro_batch_size = processed_batch[0]['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=micro_batch_size
            * parallel_state.get_data_parallel_world_size()
            * get_num_microbatches(),
            micro_batch_size=micro_batch_size,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

    # Call parent validation step to get the loss.
    loss = super().validation_step(processed_batch, batch_idx)

    # The remainder of the code runs the decoding loop and computes accuracies.
    if batch_has_lang_information:
        tokens_enc, _, _, labels, enc_mask, _, langs = self.process_global_batch(batch)
    else:
        tokens_enc, _, _, labels, enc_mask, _ = self.process_global_batch(batch)

    predicted_token_ids, _ = self.decode(tokens_enc=tokens_enc, enc_mask=enc_mask, num_tokens_to_generate=30)

    preds_text, labels_text = self.preds_and_labels_to_text(predicted_token_ids, labels)

    if not batch_has_lang_information:
        if (
            mode == 'validation'
            and hasattr(self.cfg.data.validation_ds, "names")
            and isinstance(self.cfg.data.validation_ds.names, ListConfig)
        ):
            categories = [self.cfg.data.validation_ds.names[dataloader_idx]] * len(preds_text)
        elif (
            mode == 'test'
            and hasattr(self.cfg.data.test_ds, "names")
            and isinstance(self.cfg.data.test_ds.names, ListConfig)
        ):
            categories = [self.cfg.data.test_ds.names[dataloader_idx]] * len(preds_text)
        else:
            categories = [None] * len(preds_text)
    else:
        categories = langs

    metric = self.val_metric if mode == 'validation' else self.test_metric
    assert len(categories) == len(preds_text) == len(labels_text)
    for _, (pred, label, category) in enumerate(zip(preds_text, labels_text, categories)):
        _ = metric(pred, label, category)

    return {'loss': loss, 'preds': preds_text, 'labels': labels_text, 'categories': categories}
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()
    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func, cur_microbatch, model, input_tensor, losses_reduced, dtype, disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func, cur_microbatch, model, input_tensor, losses_reduced, dtype, disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    custom_sync_context_handler=None,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A List of torch.Tensors
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        grad_scaler:
        dtype:
        disable_autocast: Turn off the `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
            Should be used when your forward and loss computation is already in an autocast context,
            to avoid unnecessarily nesting autocast contexts.
        custom_sync_context_handler:
        **kwargs: Added to handle `tensor_shape` which has no effect on this function.

    Returns:
        a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]
    model_type = get_model_type(model)

    if custom_sync_context_handler is not None:
        context_handler = custom_sync_context_handler
    elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync
    else:
        context_handler = placeholder_handler

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(
                forward_step_func,
                cur_micro_batch,
                model,
                input_tensor,
                losses_reduced,
                dtype=dtype,
                disable_autocast=disable_autocast,
            )
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(
                    input_tensor,
                    output_tensor,
                    output_tensor_grad,
                    model_type=model_type,
                    grad_scaler=grad_scaler,
                )

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func,
        get_kth_microbatch(batch, num_micro_batches - 1),
        model,
        input_tensor,
        losses_reduced,
        dtype=dtype,
        disable_autocast=disable_autocast,
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
        )
    return losses_reduced
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding,
    # since it's not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # get output tensor
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs
def compute_consumed_samples(self, steps_since_resume=0):
    app_state = AppState()
    consumed_samples = (
        self.init_consumed_samples
        + steps_since_resume * app_state.data_parallel_size * self.cfg.micro_batch_size * get_num_microbatches()
    )
    return int(consumed_samples)
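A quick numeric sanity check of the formula above; all values here are made up for illustration and are not taken from the snippet.

# Minimal sketch of the same arithmetic with assumed numbers:
init_consumed_samples = 1000
data_parallel_size = 4
micro_batch_size = 2
num_microbatches = 8        # what get_num_microbatches() would return at this point
steps_since_resume = 10

# Each step consumes data_parallel_size * micro_batch_size * num_microbatches samples.
consumed = init_consumed_samples + steps_since_resume * data_parallel_size * micro_batch_size * num_microbatches
assert consumed == 1640     # 1000 + 10 * (4 * 2 * 8)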
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)

    return losses_reduced
def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: torch.nn.Module,
    input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
    losses_reduced: List[torch.Tensor],
    dtype: torch.dtype,
    disable_autocast: bool = False,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Forward step for passed-in model.

    If first stage, the input tensor is obtained from the batch; otherwise the passed-in input_tensor is used.

    Returns output tensor.

    Args:
        forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
            returns the model's output and the loss function.
        batch: minibatch
        model: unwrappable model
        input_tensor:
        losses_reduced:
        dtype:
        disable_autocast:

    Returns:
        output_tensor
    """
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
    model_type = get_model_type(unwrapped_model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679  # NOQA
    # for the details of `set_input_tensor`.
    unwrap_output_tensor = not isinstance(input_tensor, list)
    if unwrap_output_tensor:
        input_tensor = [input_tensor]

    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]

    unwrapped_model.set_input_tensor(input_tensor)
    with torch.cuda.amp.autocast(
        enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
        dtype=dtype,
    ):
        output_tensor, loss_func = forward_step_func(batch, model)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = loss_func(output_tensor)
            loss, loss_reduced = output_tensor
            output_tensor = loss / get_num_microbatches()
            losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if parallel_state.is_pipeline_stage_after_split() and model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int, forward_only: bool, BatchSamplerCls,
) -> None:
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling.
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously, but this test uses a custom model so it's safe to abuse.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
    )
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}")

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(
            f"iter: {i} / {NUM_ITERATIONS} "
            f"local batchsize: {get_num_samples(local_mini_batch)} "
            f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}"
        )
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (
            parallel_state.get_data_parallel_world_size() * get_num_microbatches() * micro_batch_size
        )

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            else:
                optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` are a list of `Batch`'s and a list of `torch.nn.Module`'s,
    respectively. This means that the model is split into model chunks.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown
    Note that if `forward_only` this scheduling consists of only the warmup phase.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if the number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches depending on
        # stage ID (more forward passes for earlier stages; later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    ###################################################################################################################
    # Helper function definitions.
    ###################################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if parallel_state.is_pipeline_first_stage() and len(input_tensors[model_chunk_id]) == len(
            output_tensors[model_chunk_id]
        ):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
        )

        return input_tensor_grad

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype
            )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype)
            )
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype
                )
            )

    return losses_reduced
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
    # Check whether DDP is initialized. This is needed when running inference outside of the training loop.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

        # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training.
        _reconfigure_microbatch_calculator(
            rank=0,  # This doesn't matter since it is only used for logging
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=1,  # We check above to make sure that data-parallel size is always 1 at inference.
        )

    # If classes that inherit from this class are using a different tokenizer, use it instead of self.tokenizer.
    tokenizer = self.tokenizer if tokenizer is None else tokenizer
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding,
    # since it's not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # get output tensor
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs
def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A List of torch.Tensors
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        **kwargs: Added to handle `tensor_shape` which has no effect on this function.

    Returns:
        a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    context_handler = placeholder_handler
    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(input_tensor, output_tensor, output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
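A minimal usage sketch of the schedule above, showing the expected shape of a forward step function: it returns (output, loss_func), and loss_func returns a loss tensor plus a dictionary of reduced values. `my_model`, `global_batch`, and the cross-entropy loss here are illustrative assumptions, not part of the snippet above.

import torch

def fwd_step_func(microbatch, model):
    # microbatch is one slice of the global batch produced by get_kth_microbatch,
    # here assumed to be a [inputs, targets] pair.
    x, y = microbatch
    logits = model(x)

    def loss_func(logits):
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss, {"avg_loss": loss.detach()}

    return logits, loss_func

# global_batch: e.g. a list [inputs, targets] covering all microbatches on this rank.
losses_reduced = forward_backward_no_pipelining(
    fwd_step_func,
    global_batch,
    my_model,
    forward_only=False,
)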