Example #1
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Recompute the forward pass with gradients enabled.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None,) + grads
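For context, the RNG states restored in this `backward` have to be captured by the matching `forward`. Below is a minimal sketch, not the actual apex/Megatron-LM implementation, of how such a checkpoint `Function.forward` might stash the states under the attribute names this `backward` reads (`fwd_cpu_rng_state`, `fwd_cuda_rng_state`, `fwd_cuda_rng_state_tracker`); `get_cuda_rng_tracker` is assumed to be the same module-level helper used above.

import torch


class CheckpointFunctionSketch(torch.autograd.Function):
    """Minimal sketch of the forward that pairs with the backward above (illustrative only)."""

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # Capture RNG states so backward() can replay the forward pass exactly.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()  # module-level helper, as above
        with torch.no_grad():
            outputs = run_function(*args)
        # Save the inputs so backward() can recompute activations from them.
        ctx.save_for_backward(*args)
        return outputs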
Example #2
def gather_recv_prev_wait():
    tensor_recv_prev_req.wait()
    # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
    # A sync seems to be needed before the gather; otherwise losses jump around, e.g., in run_gpt_minimal_test.
    torch.cuda.synchronize()
    return (gather_split_1d_tensor(tensor_recv_prev).view(
        tensor_shape).requires_grad_())
Example #3
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: Optional[torch.dtype] = None,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
    async_comm: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor,
                                                          FutureTensor, None]]:
    """Base function for communication of tensors between stages.

    dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
    torch.float32 is used.

    See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
    for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline:
            optional, if :obj:`True`, skip the scatter/gather optimization even when ``scatter_gather_tensors_in_pipeline`` is :obj:`True`.
        dtype_: optional, the dtype of the tensors to communicate; if provided, the receive buffers are allocated with
            this dtype and created with ``requires_grad=False``.

    Keyword args:
        scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype: Optional and legacy. Defaults to torch.float. Pass this argument if you have deliberately called
            `.half()` or `.bfloat16()` on your model.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` (or `FutureTensor` when ``async_comm`` is :obj:`True`) if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` (or `FutureTensor` when ``async_comm`` is :obj:`True`) if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
        )
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1) //
            parallel_state.get_tensor_model_parallel_world_size(), )
    else:
        tensor_chunk_shape = tensor_shape

    # The dtype logic below is copied from NVIDIA/Megatron-LM repo:
    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
    # NOTE (mkozuki): Currently NeMo implements APEX AMP O2-style mixed precision on top of PyTorch. In O2 style,
    # forcing p2p comm to use FP32 would be a performance killer, so the `dtype_` argument is reinstated here with a
    # default value of `None`.
    # NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(
                tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(
                tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(
        tensor_send_prev,
        tensor_send_next,
        tensor_recv_prev,
        tensor_recv_next,
        async_comm=async_comm)

    if async_comm:
        tensor_recv_prev_waitfunc = None
        tensor_recv_next_waitfunc = None
        # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
        # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
        if tensor_recv_prev_req is not None:

            def tensor_recv_prev_wait():
                tensor_recv_prev_req.wait()
                torch.cuda.synchronize()

            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
        if tensor_recv_next_req is not None:

            def tensor_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()

            tensor_recv_next_waitfunc = tensor_recv_next_wait
    else:
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if not async_comm:
            if recv_prev:
                tensor_recv_prev = (gather_split_1d_tensor(
                    tensor_recv_prev).view(tensor_shape).requires_grad_())

            if recv_next:
                tensor_recv_next = (gather_split_1d_tensor(
                    tensor_recv_next).view(tensor_shape).requires_grad_())
        else:

            def gather_recv_prev_wait():
                tensor_recv_prev_req.wait()
                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
                # A sync seems to be needed before the gather; otherwise losses jump around, e.g., in run_gpt_minimal_test.
                torch.cuda.synchronize()
                return (gather_split_1d_tensor(tensor_recv_prev).view(
                    tensor_shape).requires_grad_())

            def gather_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
                return (gather_split_1d_tensor(tensor_recv_next).view(
                    tensor_shape).requires_grad_())

            tensor_recv_prev_waitfunc = gather_recv_prev_wait
            tensor_recv_next_waitfunc = gather_recv_next_wait
    if async_comm:
        future_tensor_recv_prev = None
        future_tensor_recv_next = None
        if tensor_recv_prev is not None:
            future_tensor_recv_prev = FutureTensor(tensor_recv_prev,
                                                   tensor_recv_prev_waitfunc)
        if tensor_recv_next is not None:
            future_tensor_recv_next = FutureTensor(tensor_recv_next,
                                                   tensor_recv_next_waitfunc)
        return future_tensor_recv_prev, future_tensor_recv_next

    return tensor_recv_prev, tensor_recv_next
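To make the call pattern concrete, here is a hedged usage sketch showing how a pipeline schedule typically wraps `_communicate`. The wrapper names (`recv_forward`, `send_forward`) mirror the helpers in Megatron-LM/apex but are written here as illustrations; the exact signatures in the real code may differ.

def recv_forward(tensor_shape, dtype=None, async_comm=False):
    """Receive the activation of the previous pipeline stage (illustrative wrapper)."""
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
    )
    # A plain torch.Tensor, or a FutureTensor whose wait callback gathers the chunks
    # when async_comm=True (see gather_recv_prev_wait above).
    return input_tensor


def send_forward(output_tensor, tensor_shape, dtype=None, async_comm=False):
    """Send this stage's activation to the next pipeline stage (illustrative wrapper)."""
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
    )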
Example #4
def gather_recv_next_wait():
    tensor_recv_next_req.wait()
    torch.cuda.synchronize()
    return (gather_split_1d_tensor(tensor_recv_next).view(
        tensor_shape).requires_grad_())
Example #5
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """Base function for communication of tensors between stages.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains fewer tokens than the default sequence length.
        override_scatter_gather_tensors_in_pipeline:
            optional, if :obj:`True`, skip the scatter/gather optimization even when ``scatter_gather_tensors_in_pipeline`` is :obj:`True`.
        dtype_: requested dtype of the communicated tensors; in this variant p2p communication always uses FP32, and a
            warning is emitted on rank 0 when a non-FP32 dtype (or ``params_dtype``) is requested (see below).

    Keyword args:
        scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype: Optional and legacy. Defaults to torch.float. Pass this argument if you have deliberately called
            `.half()` or `.bfloat16()` on your model.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
    else:
        tensor_chunk_shape = tensor_shape

    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    #     requires_grad = False
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
    dtype = torch.float32
    requires_grad = True

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()
            )

        if recv_next:
            tensor_recv_next = (
                gather_split_1d_tensor(tensor_recv_next)
                .view(tensor_shape)
                .requires_grad_()
            )

    return tensor_recv_prev, tensor_recv_next
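Both variants of `_communicate` above delegate the actual point-to-point transfers to `_run_p2pops`, which is not shown. The following is a hedged sketch of the synchronous form of that helper, assuming `parallel_state` exposes `get_pipeline_model_parallel_prev_rank()` / `get_pipeline_model_parallel_next_rank()`; the async variant used in Example #3 would instead return the individual work handles rather than waiting on them.

from typing import List, Optional

import torch


def _run_p2pops_sketch(
    tensor_send_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
) -> None:
    """Hedged sketch of a synchronous _run_p2pops helper; not the actual apex implementation."""
    ops: List[torch.distributed.P2POp] = []
    prev_rank = parallel_state.get_pipeline_model_parallel_prev_rank()  # assumed helper
    next_rank = parallel_state.get_pipeline_model_parallel_next_rank()  # assumed helper
    if tensor_send_prev is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next, next_rank))
    if ops:
        # batch_isend_irecv posts all ops at once and returns a list of async work handles;
        # the caller in Example #5 follows this call with torch.cuda.synchronize().
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()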