def backward(ctx, *args):
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "Checkpointing is not compatible with .grad(), "
            "please use .backward() if possible"
        )
    inputs = ctx.saved_tensors
    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
        inputs[0].data = gather_split_1d_tensor(inputs[0].data)
        inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

    # Store the current states.
    bwd_cpu_rng_state = torch.get_rng_state()
    bwd_cuda_rng_state = torch.cuda.get_rng_state()
    bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # Set the states to what they were before the forward pass.
    torch.set_rng_state(ctx.fwd_cpu_rng_state)
    _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

    # Recompute the forward pass.
    detached_inputs = detach_variable(inputs)
    with torch.enable_grad():
        outputs = ctx.run_function(*detached_inputs)

    # Set the states back to what they were at the start of this function.
    torch.set_rng_state(bwd_cpu_rng_state)
    _set_cuda_rng_state(bwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    torch.autograd.backward(outputs, args)
    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
    return (None,) + grads
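# Usage sketch (illustrative, not part of the source): the RNG save/restore pattern
# in `backward` above can be demonstrated with a plain dropout op. The helper name
# `replayed_dropout` is hypothetical; it shows that restoring the RNG state before
# re-running a stochastic op reproduces the exact same mask, which is what makes the
# recomputed forward pass consistent with the original one.
import torch
import torch.nn.functional as F

def replayed_dropout(x: torch.Tensor) -> torch.Tensor:
    saved = torch.get_rng_state()       # snapshot, as `backward` does before recompute
    y_first = F.dropout(x, p=0.5)       # original "forward" pass
    torch.set_rng_state(saved)          # restore, then replay
    y_replayed = F.dropout(x, p=0.5)    # recomputed pass sees identical randomness
    assert torch.equal(y_first, y_replayed)
    return y_replayed

replayed_dropout(torch.randn(8))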
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: Optional[torch.dtype] = None,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
    async_comm: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
    """Base function for communication of tensors between stages.

    dtype logic: If none of ``dtype_``, ``params_dtype``, and ``fp32_residual_connection``
    is specified, torch.float32 is used.

    See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
    for the details of ``dtype_``, ``params_dtype``, and ``fp32_residual_connection``.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether a tensor should be received from the previous rank.
        recv_next: boolean for whether a tensor should be received from the next rank.
        tensor_shape: optional, used when the input sequence contains fewer tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline:
            optional, used when ``tensor_shape`` is provided to disable the scatter-gather optimization.
        dtype_: optional, used when ``tensor_shape`` is provided; specifies the dtype of the communicated tensors.

    Keyword args:
        scatter_gather_tensors_in_pipeline:
            Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype:
            Optional and legacy. Defaults to torch.float. Pass this argument if you deliberately
            call `.half()` or `.bfloat16()` on your model.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
        async_comm: Optional. If :obj:`True`, communicate asynchronously and return
            :class:`FutureTensor` wrappers instead of waiting in this function.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` (or `FutureTensor` if `async_comm`) if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` (or `FutureTensor` if `async_comm`) if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`.
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
        )
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1)
            // parallel_state.get_tensor_model_parallel_world_size(),
        )
    else:
        tensor_chunk_shape = tensor_shape

    # The dtype logic below is copied from the NVIDIA/Megatron-LM repo:
    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
    # NOTE (mkozuki): Currently NeMo implements APEX AMP O2 style using PyTorch. In O2 style, forcing p2p comm to
    # use FP32 would be a performance killer, so the `dtype_` argument is reinstated with a default value of `None`.
    # NOTE (mkozuki): In PyTorch AMP, i.e. inside a `torch.cuda.amp.autocast` context, activation tensors can be
    # FP32, FP16, or BF16, and there is no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict the model architecture.
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensors into smaller chunks if using the scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(
        tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm
    )

    if async_comm:
        tensor_recv_prev_waitfunc = None
        tensor_recv_next_waitfunc = None
        # TODO: investigate whether this is necessary for correctness
        # (ref: https://github.com/pytorch/pytorch/issues/38642).
        # See also the sync added for the async_comm callbacks below in
        # `gather_recv_prev_wait` and `gather_recv_next_wait`.
        if tensor_recv_prev_req is not None:
            def tensor_recv_prev_wait():
                tensor_recv_prev_req.wait()
                torch.cuda.synchronize()
            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
        if tensor_recv_next_req is not None:
            def tensor_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
            tensor_recv_next_waitfunc = tensor_recv_next_wait
    else:
        # To protect against a race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using the scatter-gather optimization, gather the smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if not async_comm:
            if recv_prev:
                tensor_recv_prev = (
                    gather_split_1d_tensor(tensor_recv_prev).view(tensor_shape).requires_grad_()
                )
            if recv_next:
                tensor_recv_next = (
                    gather_split_1d_tensor(tensor_recv_next).view(tensor_shape).requires_grad_()
                )
        else:
            def gather_recv_prev_wait():
                tensor_recv_prev_req.wait()
                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
                # A sync seems to be needed before the gather; otherwise losses jump around, e.g., in run_gpt_minimal_test.
                torch.cuda.synchronize()
                return gather_split_1d_tensor(tensor_recv_prev).view(tensor_shape).requires_grad_()

            def gather_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
                return gather_split_1d_tensor(tensor_recv_next).view(tensor_shape).requires_grad_()

            tensor_recv_prev_waitfunc = gather_recv_prev_wait
            tensor_recv_next_waitfunc = gather_recv_next_wait

    if async_comm:
        future_tensor_recv_prev = None
        future_tensor_recv_next = None
        if tensor_recv_prev is not None:
            future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
        if tensor_recv_next is not None:
            future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
        return future_tensor_recv_prev, future_tensor_recv_next

    return tensor_recv_prev, tensor_recv_next
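# Usage sketch (illustrative; `output_tensor`, `seq_length`, `micro_batch_size`, and
# `hidden_size` are hypothetical stand-ins, and parallel_state plus the default
# process group are assumed to be initialized on an interior pipeline stage). With
# `async_comm=True` the receives come back as `FutureTensor` wrappers; materializing
# the value is deferred until the wait function runs, assumed here to be exposed via
# a `get()`-style accessor on `FutureTensor`.
input_future, _ = _communicate(
    tensor_send_next=output_tensor,   # hand this stage's activation to the next stage
    tensor_send_prev=None,
    recv_prev=True,                   # expect the previous stage's activation
    recv_next=False,
    tensor_shape=(seq_length, micro_batch_size, hidden_size),
    dtype_=output_tensor.dtype,
    async_comm=True,
)
input_tensor = input_future.get() if input_future is not None else None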
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """Base function for communication of tensors between stages.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether a tensor should be received from the previous rank.
        recv_next: boolean for whether a tensor should be received from the next rank.
        tensor_shape: optional, used when the input sequence contains fewer tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline:
            optional, used when ``tensor_shape`` is provided to disable the scatter-gather optimization.
        dtype_: optional, used when ``tensor_shape`` is provided; specifies the dtype of the communicated tensors.

    Keyword args:
        scatter_gather_tensors_in_pipeline:
            Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype:
            Optional and legacy. Defaults to torch.float. Pass this argument if you deliberately
            call `.half()` or `.bfloat16()` on your model.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`.
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
        )
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1)
            // parallel_state.get_tensor_model_parallel_world_size(),
        )
    else:
        tensor_chunk_shape = tensor_shape

    # NOTE (mkozuki): In PyTorch AMP, i.e. inside a `torch.cuda.amp.autocast` context, activation tensors can be
    # FP32, FP16, or BF16, and there is no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict the model architecture.
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    #     requires_grad = False
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
    dtype = torch.float32
    requires_grad = True

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensors into smaller chunks if using the scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)

    # To protect against a race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using the scatter-gather optimization, gather the smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()
            )
        if recv_next:
            tensor_recv_next = (
                gather_split_1d_tensor(tensor_recv_next)
                .view(tensor_shape)
                .requires_grad_()
            )

    return tensor_recv_prev, tensor_recv_next
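# Minimal single-process sketch of the scatter-gather optimization above (pure torch;
# `tp_world_size` stands in for parallel_state.get_tensor_model_parallel_world_size(),
# and `split_into_equal_chunks` is a hypothetical local analogue of
# split_tensor_into_1d_equal_chunks): each rank sends only a 1/tp_world_size flat
# slice over the pipeline link, and the receiver reassembles the full activation
# before calling .view(), as gather_split_1d_tensor does via all-gather.
import torch

def split_into_equal_chunks(t: torch.Tensor, tp_world_size: int, rank: int) -> torch.Tensor:
    flat = t.view(-1)
    chunk = flat.numel() // tp_world_size               # assumes numel divides evenly
    return flat[rank * chunk:(rank + 1) * chunk]

x = torch.randn(4, 2, 8)
chunks = [split_into_equal_chunks(x, 4, r) for r in range(4)]  # what each rank would send
assert torch.equal(torch.cat(chunks).view(x.shape), x)         # "gather" reconstructs x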