Example #1
def native_layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    computation_dtype = utils.get_computation_dtype(input.dtype)

    axis = input.dim() - len(normalized_shape)
    # Degenerate case: zero elements in the leading (batch) dimensions;
    # return the input unchanged and empty statistics.
    if prod(list(input.shape[:axis])) == 0:
        mean = input.new_zeros((0, ), dtype=computation_dtype)
        rstd = input.new_zeros((0, ), dtype=computation_dtype)
        out = input
    else:
        reduction_dims = list(range(axis, input.dim()))
        out, mean, rstd = normalize(input, reduction_dims, eps)

        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias

        out = out.to(dtype=input.dtype)

    # On CPU, mean/rstd are returned in the input dtype; other devices keep
    # the computation dtype.
    if input.device.type == 'cpu':
        mean = mean.to(dtype=input.dtype)
        rstd = rstd.to(dtype=input.dtype)
    return (out, mean, rstd)
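
To make the decomposition above concrete, here is a small self-contained check of the same math written with public torch ops only and compared against torch.nn.functional.layer_norm. The tensor shapes and variable names below are illustrative, not part of the original code.

import torch

# Recompute layer norm the way the decomposition does, using public ops only.
x = torch.randn(4, 8, 16)
normalized_shape = [16]
w = torch.randn(16)
b = torch.randn(16)
eps = 1e-5

dims = list(range(x.dim() - len(normalized_shape), x.dim()))
mean = x.mean(dim=dims, keepdim=True)
rstd = torch.rsqrt(x.var(dim=dims, unbiased=False, keepdim=True) + eps)
out = (x - mean) * rstd * w + b

ref = torch.nn.functional.layer_norm(x, normalized_shape, w, b, eps)
# Allow for small floating-point differences vs. the fused eager kernel.
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)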
Example #2
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    reduction_dims = [0] + list(range(2, input.dim()))
    computation_dtype = utils.get_computation_dtype(input.dtype)
    if training:
        output, mean, rstd = normalize(input, reduction_dims, eps)

        save_mean = _squeeze_multiple(mean, reduction_dims)
        save_rstd = _squeeze_multiple(rstd, reduction_dims)
        if running_mean is not None:
            running_mean.copy_(momentum * save_mean +
                               (1 - momentum) * running_mean)
        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't exactly match eager's numerics, which accumulates the variance
            # sum and applies the Bessel correction directly; reproducing that would mean
            # re-implementing var here for a negligible gain in the running statistics.
            unbiased_var = torch.var(input, reduction_dims,
                                     unbiased=False) * (n / (n - 1))
            running_var.copy_(momentum * unbiased_var +
                              (1 - momentum) * running_var)
    else:
        assert running_mean is not None and running_var is not None
        running_mean = running_mean.to(dtype=computation_dtype)
        running_var = running_var.to(dtype=computation_dtype)
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Annoying inconsistency: the CPU and CUDA eager kernels return
        # save_mean / save_rstd with different shapes in inference mode.
        if input.device.type != "cpu":
            save_mean = running_mean
            save_rstd = invstd
        else:
            save_mean = input.new_zeros((0, ))
            save_rstd = input.new_zeros((0, ))
        mean = _unsqueeze_to_dim(mean, input.dim() - 1)
        invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
        output = ((input - mean) * invstd)

    if weight is None:
        weight = input.new_ones(())

    if bias is None:
        bias = input.new_zeros(())

    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
    bias = _unsqueeze_to_dim(bias, input.dim() - 1)
    output = output * weight + bias
    if input.device.type == 'cpu':
        save_mean = save_mean.to(dtype=input.dtype)
        save_rstd = save_rstd.to(dtype=input.dtype)
    return output.to(dtype=input.dtype), save_mean, save_rstd
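
As above, a minimal sanity check of the training-mode path in terms of public ops, compared against torch.nn.functional.batch_norm with no running statistics. Shapes and names below are illustrative.

import torch

# Per-channel statistics are reduced over the batch and spatial dims (0, 2, 3 here).
x = torch.randn(8, 3, 5, 5)
weight = torch.randn(3)
bias = torch.randn(3)
eps = 1e-5

dims = [0] + list(range(2, x.dim()))
mean = x.mean(dim=dims, keepdim=True)
rstd = torch.rsqrt(x.var(dim=dims, unbiased=False, keepdim=True) + eps)
out = (x - mean) * rstd * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

ref = torch.nn.functional.batch_norm(x, None, None, weight, bias,
                                     training=True, eps=eps)
# Allow for small floating-point differences vs. the fused eager kernel.
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)

Note that the decomposition normalizes with the biased variance but updates running_var with the unbiased one, which is why it applies the Bessel correction n / (n - 1) by hand.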
Example #3
def normalize(input, norm_dims, eps):
    computation_dtype = utils.get_computation_dtype(input.dtype)
    input_acc = input.to(dtype=computation_dtype)
    biased_var = torch.var(input_acc, dim=norm_dims, unbiased=False, keepdim=True)
    mean = torch.mean(input_acc, dim=norm_dims, keepdim=True)
    rstd = torch.rsqrt(biased_var + eps)

    # mean/rstd are in computation_dtype, so (input - mean) promotes out to it as well.
    out = (input - mean) * rstd
    return out, mean, rstd
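
A short sketch of why the upcast matters: with a float16 input, the statistics are computed in float32, and (input - mean) * rstd promotes to float32 automatically, which the callers above then cast back to the input dtype. The names and shapes below are illustrative.

import torch

x = torch.randn(4, 16, dtype=torch.float16)
x_acc = x.float()                        # stand-in for the computation-dtype upcast
mean = x_acc.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(x_acc.var(dim=-1, unbiased=False, keepdim=True) + 1e-5)

out = (x - mean) * rstd                  # float16 op float32 promotes to float32
print(out.dtype, mean.dtype, rstd.dtype)  # torch.float32 torch.float32 torch.float32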
Example #4
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)
    grad_out_cast, input_cast, weight_cast, bias_cast = [
        x.to(computation_dtype) if x is not None else x
        for x in (grad_out, input, weight, bias)
    ]
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = []
    outer_dim_indices: List[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    # Degenerate (empty) input: return zero-filled gradients.
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    x_hat = (input_cast - mean) * rstd
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)

    # Combine into the standard layer-norm input gradient:
    #   d_input = rstd / N * (N * grad_x_hat - sum(grad_x_hat)
    #                         - x_hat * sum(grad_x_hat * x_hat))
    inner = a - b - c3
    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        d_input = (rstd / N) * inner

    if output_mask[1] and weight_cast is not None:
        if len(outer_dim_indices) > 0:
            d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices,
                                 False)
        else:
            d_weight = grad_out_cast * x_hat

    if output_mask[2] and bias_cast is not None:
        if len(outer_dim_indices) > 0:
            d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
        else:
            d_bias = grad_out_cast

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
        _maybe_cast(d_bias, input.dtype),
    )
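
Finally, a small autograd cross-check of the d_input, d_weight, and d_bias formulas above, written with public ops only. The seed, shapes, and names below are illustrative.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)
eps = 1e-5
go = torch.randn(4, 8)

# Reference gradients from autograd.
y = torch.nn.functional.layer_norm(x, [8], w, b, eps)
y.backward(go)

# Manual gradients following the decomposition.
N = 8
xd, wd = x.detach(), w.detach()
mean = xd.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(xd.var(dim=-1, unbiased=False, keepdim=True) + eps)
x_hat = (xd - mean) * rstd
g_hat = go * wd

d_input = (rstd / N) * (N * g_hat
                        - g_hat.sum(-1, keepdim=True)
                        - x_hat * (g_hat * x_hat).sum(-1, keepdim=True))
d_weight = (go * x_hat).sum(0)
d_bias = go.sum(0)

torch.testing.assert_close(d_input, x.grad, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(d_weight, w.grad, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(d_bias, b.grad, rtol=1e-4, atol=1e-5)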