Пример #1
0
    def test_send_backward_recv_backward(self):
        """Check that backward gradients flow from stage ``r + 1`` to stage ``r``.

        The last stage only sends, the first stage only receives, and every
        middle stage does both in one call. Every stage except the last must end
        up holding the tensor created by its successor.
        """
        self._init_model_parallel()
        grad = self.create_tensor(self.rank)
        # Shared keyword arguments for every p2p call below.
        comm_kwargs = dict(tensor_shape=self.shape, dtype=self.dtype)

        received = None
        if parallel_state.is_pipeline_first_stage():
            received = p2p_communication.recv_backward(**comm_kwargs)
        elif parallel_state.is_pipeline_last_stage():
            p2p_communication.send_backward(input_tensor_grad=grad, **comm_kwargs)
        else:
            received = p2p_communication.send_backward_recv_backward(
                input_tensor_grad=grad,
                recv_next=True,
                **comm_kwargs,
            )

        if not parallel_state.is_pipeline_last_stage():
            self.assertEqual(received, self.create_tensor(self.rank + 1))
        else:
            # Nothing is downstream of the last stage, so nothing was received.
            self.assertIsNone(received)
def recv_backward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
    """Receive one output-gradient tensor per entry of ``tensor_shapes``.

    Args:
        tensor_shapes: Shapes of the tensors to receive, in order. A ``None``
            entry means nothing is received for that slot.
        dtype: Forwarded unchanged to ``p2p_communication.recv_backward``.

    Returns:
        A list aligned with ``tensor_shapes``: ``None`` where the shape was
        ``None``, otherwise whatever ``p2p_communication.recv_backward`` returned
        for that shape.
    """
    return [
        None
        if shape is None
        else p2p_communication.recv_backward(tensor_shape=shape, dtype=dtype)
        for shape in tensor_shapes
    ]
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s. Microbatch ``k`` is
            obtained with ``get_kth_microbatch(batch, k)``.
        model: A `torch.nn.Module` or a list containing exactly one `torch.nn.Module`.

    Keyword args:
        forward_only: If ``True``, only forward passes are run: input/output tensors
            are not kept, no gradients are exchanged and the cooldown phase is skipped.
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.

    Raises:
        RuntimeError: If ``model`` does not resolve to exactly one module.
    """
    model = listify_model(model)
    if len(model) != 1:
        # Include the module count: the container type alone does not tell the
        # caller what was wrong with the argument.
        msg = (
            "`model` is expected to be a single `nn.Module`, but got "
            f"{type(model).__name__} of {len(model)} modules"
        )
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches. Stage ``rank`` runs
    # ``world_size - rank - 1`` forward passes before entering 1F1B (the last
    # stage runs none), capped at the number of available microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size() -
        parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    # Input, output tensors only need to be saved when doing backward passes.
    # They are used as FIFO queues: appended after each forward pass, popped
    # from the front for each backward pass.
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor,
                                       tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        # After the last steady iteration there is no further forward input
        # to receive.
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor,
                                           tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd")
                input_tensor = p2p_communication.recv_forward(
                    tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            if last_iteration:
                # No more forward inputs will arrive; only send the gradient.
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad,
                                                tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        # Drain the queued activations left over from the warmup phase.
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(
                tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad,
                                            tensor_shape=tensor_shape)

    return losses_reduced
Пример #4
0
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `model` is a list of `torch.nn.Module`, one entry per
    model chunk (virtual pipeline stage) hosted on this rank.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown
    Note that if `forward_only` this scheduling consists of only warmup phase.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: The minibatch. Each forward step of chunk ``c`` takes microbatch
            ``get_kth_microbatch(batch, curr_iters[c])``, where ``curr_iters[c]``
            counts the forward steps already run on that chunk.
        model: A list of `torch.nn.Module`, one per model chunk. A bare module
            (not a list) is rejected with `RuntimeError`.

    Keyword args:
        forward_only: If ``True``, every microbatch is run in the warmup loop, no
            tensors are kept for a backward pass and no gradients are exchanged.
        tensor_shape: Shape of tensor. Passed to every p2p communication call.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler: Forwarded to ``backward_step``.
        disable_autocast: Forwarded to ``forward_step``.
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental. Also forwarded to ``backward_step``.
        **kwargs: Accepted for interface compatibility; not used by this function.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise
        (this is ``losses_reduced``, which is filled in by ``forward_step``).

    Raises:
        RuntimeError: If ``model`` is not a list.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

    num_model_chunks: int = len(model)
    # Per-chunk queues of activations. Entries are appended as inputs are
    # received / outputs are produced, and popped from the front (FIFO) by the
    # backward pass.
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    # Number of forward steps already run per chunk; used as the microbatch index.
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    # Only defined when backward passes run. Every use below is reachable only
    # when `forward_only` is False: in forward-only mode all microbatches are
    # warmup microbatches (so the steady loop does not run) and the cooldown is
    # skipped.
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
            [] for _ in range(num_model_chunks)
        ]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size(
    )
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank(
    )

    # Compute number of warmup and remaining microbatches.
    # Counted in forward steps: each microbatch is run once per model chunk.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size -
                                       pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks -
                                        1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    ###################################################################################################################
    # Helper function definitions.
    ###################################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number.

        Consecutive runs of ``pipeline_parallel_size`` iteration numbers map to
        the same chunk, cycling through the chunks; the backward pass visits the
        chunks in reverse order. E.g. with 2 stages and 2 chunks, forward ids
        0,1 -> chunk 0; 2,3 -> chunk 1; 4,5 -> chunk 0 again.
        """
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size(
        )
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size *
                                                  num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int,
                            curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).

        Mutates ``curr_iters`` and the enclosing ``input_tensors`` /
        ``output_tensors`` queues.
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        # The first virtual stage receives nothing from upstream: when no input
        # is pending for this chunk (queues of equal length), queue a `None`
        # placeholder so the input/output queues stay aligned.
        if (parallel_state.is_pipeline_first_stage()
                and len(input_tensors[model_chunk_id]) == len(
                    output_tensors[model_chunk_id])):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        # Advance so the next forward on this chunk takes the next microbatch.
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).

        Consumes the oldest queued input/output/grad entries of the chunk.
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # The last virtual stage receives no gradient from downstream; queue a
        # `None` placeholder when none is pending.
        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs)

        return input_tensor_grad

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    # Receive the input for the very first forward step (chunk 0) up front; each
    # loop iteration below receives the input for the *next* iteration.
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        # On the first physical stage, chunk 0 has no previous stage to receive
        # from; and after the final forward step there is nothing left to receive.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(
            f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
        )

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        # On the final warmup step, when a steady state follows, also start
        # receiving the first output gradient. The first backward step runs on
        # the last chunk (get_model_chunk_id(0, forward=False)), so the gradient
        # is queued there.
        if k == (num_warmup_microbatches -
                 1) and not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks -
                                1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                dtype=dtype)
        # NOTE(review): presumably releases the sent activation's storage when
        # `deallocate_pipeline_outputs` is set; helper is defined elsewhere.
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            backward_model_chunk_id)
        _logger.debug(
            f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
        )
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            # Nothing arrives when the lagging stage just ran the final chunk
            # (its output is not sent, see above); otherwise the received
            # tensor feeds the following chunk on this stage.
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            # Mirror of the forward case above, with chunk order reversed.
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        # When every microbatch ran in warmup, the warmup loop skipped its
        # final send/recv-backward step, so the first output gradient (for the
        # last chunk) has not been received yet.
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape,
                                                dtype=dtype))
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
            )
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1,
                                                              forward=False)
            # No gradient comes from downstream for the last chunk on the last
            # physical stage, nor after the final backward step.
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype))

    return losses_reduced