Example #1
def _get_params_for_weight_decay_optimization(
        model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
    from apex.normalization.fused_layer_norm import FusedLayerNorm  # NOQA
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, FusedLayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in list(module_._parameters.values())
                     if p is not None])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params
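For reference, here is a self-contained sketch of the same grouping logic that swaps apex's FusedLayerNorm for stock torch.nn.LayerNorm, so it runs without apex; the helper name split_params_for_weight_decay is illustrative, not part of apex:

import torch
from typing import Any, Dict, Tuple


def split_params_for_weight_decay(
    model: torch.nn.Module,
    no_weight_decay_modules: Tuple[type, ...] = (torch.nn.LayerNorm,),
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Mirrors the grouping above: layernorm parameters and biases are exempt
    # from weight decay, everything else decays.
    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    for module in model.modules():
        if isinstance(module, no_weight_decay_modules):
            no_weight_decay_params["params"].extend(module.parameters(recurse=False))
        else:
            for name, param in module.named_parameters(recurse=False):
                group = no_weight_decay_params if name == "bias" else weight_decay_params
                group["params"].append(param)
    return weight_decay_params, no_weight_decay_params


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
wd_group, no_wd_group = split_params_for_weight_decay(model)
optimizer = torch.optim.AdamW([wd_group, no_wd_group], lr=1e-3, weight_decay=0.01)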
Example #2
def _get_params_for_weight_decay_optimization(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    no_weight_decay_modules=(FusedLayerNorm, ),
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, no_weight_decay_modules):
                no_weight_decay_params["params"].extend([
                    p for p in list(module_._parameters.values())
                    if p is not None
                ])
            else:
                weight_decay_params["params"].extend([
                    p for n, p in list(module_._parameters.items())
                    if p is not None and n != "bias"
                ])
                no_weight_decay_params["params"].extend([
                    p for n, p in list(module_._parameters.items())
                    if p is not None and n == "bias"
                ])

    return weight_decay_params, no_weight_decay_params
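Compared with Example #1, this variant takes the no-decay module set as a keyword-only argument, so callers can widen it (for example, to cover other normalization layers) without editing the function body.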
Example #3
def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    custom_sync_context_handler=None,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns the model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only: If :obj:`True`, run forward passes only and skip all backward passes.
        grad_scaler: Gradient scaler for mixed-precision training, forwarded to `backward_step`.
        dtype: The dtype used by `forward_step`'s autocast context.
        disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
            Should be used when your forward and loss computation is already inside an autocast
            context, to avoid unnecessarily nesting autocast contexts.
        custom_sync_context_handler: Context manager used in place of DDP's `no_sync` to
            defer gradient synchronization until the last microbatch.
        **kwargs: Accepted so callers can pass `tensor_shape`, which has no effect on this function.

    Returns:
        A list of dictionaries of loss `torch.Tensor`s if this is the last pipeline stage; an empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model = model[0]
    model_type = get_model_type(model)

    if custom_sync_context_handler is not None:
        context_handler = custom_sync_context_handler
    elif isinstance(model,
                    torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync
    else:
        context_handler = placeholder_handler
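    # When the model is wrapped in DDP, `no_sync` defers the gradient
    # all-reduce: gradients accumulate locally for every microbatch except
    # the last one, which runs outside this context below and triggers a
    # single synchronization.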

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(
                forward_step_func,
                cur_micro_batch,
                model,
                input_tensor,
                losses_reduced,
                dtype=dtype,
                disable_autocast=disable_autocast,
            )
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(
                    input_tensor,
                    output_tensor,
                    output_tensor_grad,
                    model_type=model_type,
                    grad_scaler=grad_scaler,
                )

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func,
        get_kth_microbatch(batch, num_micro_batches - 1),
        model,
        input_tensor,
        losses_reduced,
        dtype=dtype,
        disable_autocast=disable_autocast,
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
        )

    return losses_reduced
Example #4
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns the model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only: If :obj:`True`, run forward passes only and skip all backward passes.
        tensor_shape: Shape of the communicated tensor. Required for P2P communication.

    Returns:
        A list of loss `torch.Tensor`s if this is the last pipeline stage; an empty list otherwise.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size() -
        parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches
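    # For example, with a pipeline-parallel world size of 4, rank 1, and
    # 8 microbatches: num_warmup_microbatches = min(4 - 1 - 1, 8) = 2, so
    # num_microbatches_remaining = 8 - 2 = 6.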

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor,
                                       tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor,
                                           tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = p2p_communication.recv_forward(
                    tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad,
                                                tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(
                tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad,
                                            tensor_shape=tensor_shape)

    return losses_reduced
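Each schedule above expects a forward_step_func with the contract described in its docstring. A minimal sketch of one, with a placeholder mean loss (the names example_forward_step and avg_loss are illustrative, not part of apex):

import torch


def example_forward_step(batch: torch.Tensor, model: torch.nn.Module):
    # Returns the model's forward output plus a loss function that maps the
    # output tensor to (loss tensor, dict of named metric tensors), matching
    # the contract documented above.
    output = model(batch)

    def loss_func(output_tensor: torch.Tensor):
        loss = output_tensor.float().mean()  # placeholder loss for illustration
        return loss, {"avg_loss": loss.detach()}

    return output, loss_func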
Example #5
def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns the model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only: If :obj:`True`, run forward passes only and skip all backward passes.
        **kwargs: Accepted so callers can pass `tensor_shape`, which has no effect on this function.

    Returns:
        A list of dictionaries of loss `torch.Tensor`s if this is the last pipeline stage; an empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    context_handler = placeholder_handler
    if isinstance(model,
                  torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(forward_step_func, cur_micro_batch,
                                         model, input_tensor, losses_reduced)
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(input_tensor, output_tensor, output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1),
        model, input_tensor, losses_reduced)
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
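The `no_sync` handling above is a standard gradient-accumulation pattern. A stand-alone sketch, under the assumption that `microbatches` is a list of input tensors and `loss_fn` maps a model output to a scalar loss (both placeholders):

import contextlib

import torch


def accumulate_with_single_sync(model, microbatches, loss_fn):
    # Run all but the last microbatch under `no_sync` (when the model is DDP)
    # so gradients accumulate locally; the final backward pass, outside the
    # context, triggers the one gradient all-reduce.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        context = model.no_sync
    else:
        context = contextlib.nullcontext
    with context():
        for microbatch in microbatches[:-1]:
            loss_fn(model(microbatch)).backward()
    loss_fn(model(microbatches[-1])).backward()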
Example #6
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns the model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only: If :obj:`True`, run forward passes only and skip all backward passes.
        tensor_shape: Shape of the communicated tensor. Required for P2P communication.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler: Gradient scaler for mixed-precision training, forwarded to `backward_step`.
        disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        A list of loss `torch.Tensor`s if this is the last pipeline stage; an empty list otherwise.
    """
    # timers = get_timers()

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
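    # A stage receives tensors shaped like the preceding stage's output
    # (rank - 1) and sends tensors shaped like its own output (rank).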

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []
    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
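Finally, a hedged invocation sketch. Everything apart from the documented signature is a placeholder: in real use, torch.distributed and apex's parallel_state must already be initialized, and the model passed in is the pipeline stage owned by this rank.

import torch

seq_len, micro_batch_size, hidden_size = 128, 2, 1024   # placeholder sizes
stage = torch.nn.Linear(hidden_size, hidden_size)       # stand-in for a real stage
batch = [torch.randn(seq_len, micro_batch_size, hidden_size)]

losses = forward_backward_pipelining_without_interleaving(
    example_forward_step,  # a forward_step_func like the one sketched earlier
    batch,
    stage,
    forward_only=False,
    tensor_shape=(seq_len, micro_batch_size, hidden_size),
    dtype=torch.float32,   # p2p dtype; float32 is the documented default
)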