def backward_step_helper(microbatch_id: int) -> torch.Tensor:
    """Helper method to run backward step with model split into chunks

    (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
    """
    model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
    model_type = get_model_type(model[model_chunk_id])
    parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

    if parallel_state.is_pipeline_last_stage():
        if len(output_tensor_grads[model_chunk_id]) == 0:
            output_tensor_grads[model_chunk_id].append(None)
    input_tensor = input_tensors[model_chunk_id].pop(0)
    output_tensor = output_tensors[model_chunk_id].pop(0)
    output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
    input_tensor_grad = backward_step(
        input_tensor,
        output_tensor,
        output_tensor_grad,
        model_type=model_type,
        grad_scaler=grad_scaler,
        deallocate_pipeline_outputs=deallocate_pipeline_outputs,
    )

    return input_tensor_grad
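
# `backward_step_helper` above is a closure over the interleaved schedule's per-chunk
# queues (`input_tensors`, `output_tensors`, `output_tensor_grads`). For context, a
# minimal sketch of how a microbatch id could map to a virtual model chunk, mirroring
# the role `get_model_chunk_id` plays here. The helper below is a hypothetical
# illustration, not the library's implementation.
def _example_get_model_chunk_id(
    microbatch_id: int,
    pipeline_parallel_size: int,
    num_model_chunks: int,
    forward: bool,
) -> int:
    # Microbatches advance through virtual chunks in groups of `pipeline_parallel_size`.
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        # The backward pass visits the virtual chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id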
def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    custom_sync_context_handler=None,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function. The loss function is supposed to
            take one `torch.Tensor` and return a `torch.Tensor` of loss and a dictionary of
            `str` and `torch.Tensor`.
        batch: A list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        grad_scaler:
        dtype:
        disable_autocast: Turn off the `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
            Should be used when your forward and loss computation is already inside an autocast
            context, to avoid unnecessarily nesting autocast contexts.
        custom_sync_context_handler:
        **kwargs: Added to handle `tensor_shape`, which has no effect on this function.

    Returns:
        a list of dictionaries of loss `torch.Tensor`s if the last stage, an empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model = model[0]
    model_type = get_model_type(model)

    if custom_sync_context_handler is not None:
        context_handler = custom_sync_context_handler
    elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync
    else:
        context_handler = placeholder_handler

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(
                forward_step_func,
                cur_micro_batch,
                model,
                input_tensor,
                losses_reduced,
                dtype=dtype,
                disable_autocast=disable_autocast,
            )
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(
                    input_tensor,
                    output_tensor,
                    output_tensor_grad,
                    model_type=model_type,
                    grad_scaler=grad_scaler,
                )

    # Run computation for the last microbatch outside of the context handler
    # (we want to synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func,
        get_kth_microbatch(batch, num_micro_batches - 1),
        model,
        input_tensor,
        losses_reduced,
        dtype=dtype,
        disable_autocast=disable_autocast,
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
        )
    return losses_reduced
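
# A minimal usage sketch for `forward_backward_no_pipelining`, assuming parallel state
# and the microbatch calculator have already been initialized and CUDA is available.
# `loss_func` illustrates the contract from the docstring above: it takes the forward
# output and returns a loss tensor plus a dict of `str` -> `torch.Tensor`. The concrete
# model, shapes, and names here are made up for illustration.
def _example_no_pipelining_usage():
    model = torch.nn.Linear(8, 8).cuda()

    def forward_step_func(micro_batch, model):
        # `get_kth_microbatch` hands each step a list of tensor slices.
        (x,) = micro_batch
        output = model(x)

        def loss_func(output_tensor):
            loss = output_tensor.float().mean()
            return loss, {"avg_loss": loss.detach()}

        return output, loss_func

    global_batch = [torch.randn(get_num_microbatches() * 4, 8, device="cuda")]
    return forward_backward_no_pipelining(
        forward_step_func,
        global_batch,
        model,
        forward_only=False,
    )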
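# A rough sketch, under assumptions, of the shape bookkeeping that `get_tensor_shapes`
# performs for the non-interleaved schedule below: encoder-decoder models exchange two
# tensors across the encoder/decoder split (decoder activations plus encoder output),
# while single-stack models exchange one. `is_before_split` stands in for the library's
# rank predicate and is a hypothetical parameter, as is this helper itself.
def _example_tensor_shapes(
    tensor_shape,  # (sequence_length, micro_batch_size, hidden_size)
    is_encoder_and_decoder: bool,
    is_before_split: bool,
    decoder_sequence_length=None,
):
    seq_length, micro_batch_size, hidden = tensor_shape
    if not is_encoder_and_decoder:
        return [(seq_length, micro_batch_size, hidden)]
    if is_before_split:
        # Encoder-side stages only pass the encoder activations.
        return [(seq_length, micro_batch_size, hidden)]
    # Decoder-side stages exchange decoder activations plus the encoder output.
    return [
        (decoder_sequence_length, micro_batch_size, hidden),
        (seq_length, micro_batch_size, hidden),
    ]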
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:

    1. warmup
    2. 1F1B a.k.a. steady state
    3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function. The loss function is supposed to
            take one `torch.Tensor` and return a `torch.Tensor` of loss and a dictionary of
            `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, an empty list otherwise.
    """
    # timers = get_timers()

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                _logger.debug("receive fwd (not last iteration)")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)
            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )
            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
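
# A small worked example of the warmup arithmetic above, assuming a 4-stage pipeline:
# rank r runs `world_size - r - 1` warmup forward passes (capped at the number of
# microbatches), so the last stage enters the 1F1B steady state immediately.
def _example_warmup_counts(pipeline_world_size: int = 4, num_microbatches: int = 8):
    counts = {}
    for rank in range(pipeline_world_size):
        num_warmup = min(pipeline_world_size - rank - 1, num_microbatches)
        counts[rank] = (num_warmup, num_microbatches - num_warmup)
    # e.g. {0: (3, 5), 1: (2, 6), 2: (1, 7), 3: (0, 8)} -> (warmup, steady 1F1B)
    return counts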