Ejemplo n.º 1
0
def _test_logical_or(test_case, shape, device):
    """Check ``flow.logical_or`` against ``np.logical_or`` on random data.

    Args:
        test_case: object exposing ``assertTrue`` (used to report the result).
        shape: shape passed to ``np.random.randint`` for both operands.
        device: device specifier forwarded to ``flow.device``.
    """
    # Operands are random integers in [0, 3), so both zero (falsy) and
    # non-zero (truthy) entries occur.
    lhs_np = np.random.randint(3, size=shape)
    rhs_np = np.random.randint(3, size=shape)

    target = flow.device(device)
    lhs = flow.tensor(lhs_np, dtype=flow.float32, device=target)
    rhs = flow.tensor(rhs_np, dtype=flow.float32, device=target)

    actual = flow.logical_or(lhs, rhs)
    expected = np.logical_or(lhs_np, rhs_np)
    test_case.assertTrue(np.array_equal(actual.numpy(), expected))
Ejemplo n.º 2
0
def clip_grad_norm_(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> Tensor:
    r"""Clip the gradient norm of an iterable of parameters, in place.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Every gradient is then multiplied
    in place by ``max_norm / (total_norm + 1e-6)``, capped at ``1.0``.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized. Tensors whose
            ``grad`` is ``None`` are ignored.
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameter gradients (viewed as a single vector),
        computed before the gradients are scaled. A scalar zero tensor is
        returned when no parameter carries a gradient.

    Raises:
        RuntimeError: if ``error_if_nonfinite`` is True and the total norm
            is ``nan`` or infinite.

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> import numpy as np
        >>> x1 = flow.tensor(np.array([[2, 3, 4], [1.5, 2.6, 3.7]]).astype(np.float32), requires_grad=True)
        >>> m1 = flow.nn.ReLU()
        >>> out1 = m1(x1)
        >>> out1 = out1.sum()
        >>> out1.backward()
        >>> norm1 = flow.nn.utils.clip_grad_norm_(x1, 0.6, 1.0)
        >>> norm1
        tensor(6., dtype=oneflow.float32)
        >>> x1.grad
        tensor([[0.1000, 0.1000, 0.1000],
                [0.1000, 0.1000, 0.1000]], dtype=oneflow.float32)
        >>> x2 = flow.tensor(np.array([[-2, -3, -4], [2.5, 0, 3.2]]).astype(np.float32), requires_grad=True)
        >>> out2 = flow.atan(x2)
        >>> out2 = out2.sum()
        >>> out2.backward()
        >>> norm2 = flow.nn.utils.clip_grad_norm_(x2, 0.5)
        >>> norm2
        tensor(1.0394, dtype=oneflow.float32)
        >>> x2.grad
        tensor([[0.0962, 0.0481, 0.0283],
                [0.0663, 0.4810, 0.0428]], dtype=oneflow.float32)

    """

    if isinstance(parameters, (Tensor, flow._oneflow_internal.Tensor)):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not parameters:
        return flow.tensor(0.0)

    # The first parameter decides whether the global-tensor or the
    # local-tensor code path is used for every parameter.
    first = parameters[0]
    use_global = first.is_global

    if use_global:
        assert all(p.is_global for p in parameters
                   ), "All parameters must be consistent tensor."
        sbp_broadcast = [flow.sbp.broadcast for _ in first.sbp]
        first_placement = first.placement

        def _detached_grad(p):
            # Gradient converted to broadcast SBP before it is reduced.
            return p.grad.detach().to_global(sbp=sbp_broadcast)

        def _colocate(t):
            # Move a per-parameter norm onto the first parameter's placement
            # so that the norms can be stacked together.
            return t.to_global(placement=first_placement)
    else:
        first_device = first.grad.device

        def _detached_grad(p):
            return p.grad.detach()

        def _colocate(t):
            # Move a per-parameter norm onto the first gradient's device so
            # that the norms can be stacked together.
            return t.to(first_device)

    if norm_type == float("inf"):
        # Infinity norm: largest absolute gradient entry.
        per_param = [
            _colocate(_detached_grad(p).abs().max()) for p in parameters
        ]
        if len(per_param) == 1:
            total_norm = per_param[0]
        else:
            total_norm = flow.max(flow.stack(per_param))
    elif norm_type == float("-inf"):
        # Negative-infinity norm: smallest absolute gradient entry.
        per_param = [
            _colocate(_detached_grad(p).abs().min()) for p in parameters
        ]
        if len(per_param) == 1:
            total_norm = per_param[0]
        else:
            total_norm = flow.min(flow.stack(per_param))
    else:
        # p-norm of the per-parameter p-norms.
        per_param = [
            _colocate(flow.linalg.vector_norm(_detached_grad(p), norm_type))
            for p in parameters
        ]
        total_norm = flow.linalg.vector_norm(flow.stack(per_param), norm_type)

    if error_if_nonfinite and flow.logical_or(total_norm.isnan(),
                                              total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`")

    # The small epsilon guards against division by a zero norm; clamping at
    # 1.0 means gradients are only ever scaled down, never up.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = clip_coef.clamp(max=1.0)
    for p in parameters:
        if use_global:
            scale = clip_coef_clamped.to_global(placement=p.placement)
        else:
            scale = clip_coef_clamped.to(p.grad.device)
        p.grad.detach().mul_(scale)
    return total_norm
Ejemplo n.º 3
0
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Run the multi-head attention forward pass.

    Args:
        query: 3-D tensor unpacked as ``(tgt_len, bsz, embed_dim)``.
        key: 3-D tensor whose first dimension is the source length.
        value: tensor matching ``key`` (full shape, or only the first two
            dims when ``use_separate_proj_weight`` is True).
        embed_dim_to_check: expected size of the last dim of ``query``.
        num_heads: number of heads; must divide ``embed_dim``.
        in_proj_weight: packed input projection weight, used when
            ``use_separate_proj_weight`` is False.
        in_proj_bias: optional input projection bias; split into three
            chunks along dim 0 when separate projection weights are used.
        bias_k, bias_v: optional biases appended to the projected key and
            value. Both must be given, or both must be None.
        add_zero_attn: if True, append one all-zero entry to key and value
            along dim 1 after the heads-first reshape.
        dropout_p: dropout probability; forced to ``0.0`` when not training.
        out_proj_weight, out_proj_bias: parameters of the output projection.
        training: whether dropout is active.
        key_padding_mask: optional mask of shape ``(bsz, src_len)``.
        need_weights: if True, also return head-averaged attention weights.
        attn_mask: optional non-floating-point mask, either 2-D
            ``(tgt_len, src_len)`` or 3-D ``(bsz * num_heads, tgt_len, src_len)``.
        use_separate_proj_weight: use ``q/k/v_proj_weight`` instead of
            ``in_proj_weight``.
        q_proj_weight, k_proj_weight, v_proj_weight: separate projection
            weights, required when ``use_separate_proj_weight`` is True.
        static_k, static_v: optional precomputed key/value used instead of
            the projected ones; dim 0 must equal ``bsz * num_heads`` and
            dim 2 must equal ``head_dim``.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_weights`` is the attention
        weight tensor averaged over heads, or ``None`` when ``need_weights``
        is False.

    Raises:
        AssertionError: if a shape or argument consistency check fails.
        RuntimeError: if ``attn_mask`` has an unsupported rank or shape.
    """
    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    # embed_dim can be a tensor when JIT tracing
    head_dim = (embed_dim.div(num_heads)
                if isinstance(embed_dim, Tensor) else embed_dim // num_heads)
    assert (head_dim * num_heads == embed_dim
            ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

    #
    # validate key/value shapes, then compute in-projection
    #
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
        assert (q_proj_weight is not None
                ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (k_proj_weight is not None
                ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (v_proj_weight is not None
                ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is not None:
            b_q, b_k, b_v = in_proj_bias.chunk(3, dim=0)
        else:
            b_q = b_k = b_v = None
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight,
                                        in_proj_bias)

    # size of the merged batch/head dimension used from here on
    batch_heads = bsz * num_heads

    # prep attention mask
    if attn_mask is not None:
        assert (
            not attn_mask.dtype.is_floating_point
        ), f"Only integer type are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        mask_rank = attn_mask.dim()
        if mask_rank == 2:
            expected_2d = (tgt_len, src_len)
            if attn_mask.shape != expected_2d:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {expected_2d}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif mask_rank == 3:
            expected_3d = (batch_heads, tgt_len, src_len)
            if attn_mask.shape != expected_3d:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {expected_3d}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {mask_rank} is not supported")

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = flow.cat([k, bias_k.repeat((1, bsz, 1))])
        v = flow.cat([v, bias_v.repeat((1, bsz, 1))])
        # both masks grow by one trailing entry to cover the appended bias
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1, 0, 0))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1, 0, 0))
    else:
        assert bias_k is None
        assert bias_v is None

    def _heads_first(projected, static, label):
        # Reshape a projected key/value to heads-first layout, or validate
        # and return the caller-supplied static tensor instead.
        if static is None:
            return projected.reshape(-1, batch_heads,
                                     head_dim).transpose(0, 1)
        assert (
            static.size(0) == batch_heads
        ), f"expecting {label}.size(0) of {batch_heads}, but got {static.size(0)}"
        assert (
            static.size(2) == head_dim
        ), f"expecting {label}.size(2) of {head_dim}, but got {static.size(2)}"
        return static

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    # replace torch.contiguous with reshape
    q = q.reshape(tgt_len, batch_heads, head_dim).transpose(0, 1)
    k = _heads_first(k, static_k, "static_k")
    v = _heads_first(v, static_v, "static_v")

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_shape = (batch_heads, 1, head_dim)

        def _with_zero_entry(t):
            # Append one all-zero entry along dim 1.
            filler = flow.zeros(zero_shape, dtype=t.dtype, device=t.device)
            return flow.cat([t, filler], dim=1)

        k = _with_zero_entry(k)
        v = _with_zero_entry(v)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1, 0, 0))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1, 0, 0))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        # repeat the per-batch mask for every head and every target position
        per_head = key_padding_mask.reshape(bsz, 1, 1, src_len)
        per_head = per_head.expand(-1, num_heads, tgt_len, -1)
        key_padding_mask = per_head.reshape(batch_heads, tgt_len, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = flow.logical_or(
                attn_mask.expand(batch_heads, -1, -1), key_padding_mask)

    # convert mask to float: masked positions become -inf, the rest 0
    if attn_mask is not None and not attn_mask.dtype.is_floating_point:
        float_mask = flow.zeros_like(attn_mask).to(flow.float)
        attn_mask = float_mask.masked_fill(attn_mask, float("-inf"))

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    attn_output, attn_output_weights = _scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).reshape(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if not need_weights:
        return attn_output, None

    # average attention weights over heads
    per_head_weights = attn_output_weights.reshape(bsz, num_heads, tgt_len,
                                                   src_len)
    return attn_output, per_head_weights.sum(dim=1) / num_heads