Example #1
import torch
import torch.nn.functional as F

# get_layer_type, _create_or_extend_grad_sample and
# sum_over_all_but_batch_and_last_n are helpers assumed to be defined elsewhere.
def _compute_norm_grad_sample(layer, A, B, batch_dim=0):
    layer_type = get_layer_type(layer)
    if layer_type == "LayerNorm":
        _create_or_extend_grad_sample(
            layer.weight,
            sum_over_all_but_batch_and_last_n(
                F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
                layer.weight.dim(),
            ),
            batch_dim,
        )
        _create_or_extend_grad_sample(
            layer.bias,
            sum_over_all_but_batch_and_last_n(B, layer.bias.dim()),
            batch_dim,
        )
    elif layer_type == "GroupNorm":
        gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
        _create_or_extend_grad_sample(layer.weight,
                                      torch.einsum("ni...->ni", gs), batch_dim)
        if layer.bias is not None:
            _create_or_extend_grad_sample(layer.bias,
                                          torch.einsum("ni...->ni", B),
                                          batch_dim)
    elif layer_type in {"InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"}:
        gs = F.instance_norm(A, eps=layer.eps) * B
        _create_or_extend_grad_sample(layer.weight,
                                      torch.einsum("ni...->ni", gs), batch_dim)
        if layer.bias is not None:
            _create_or_extend_grad_sample(layer.bias,
                                          torch.einsum("ni...->ni", B),
                                          batch_dim)
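
The helper sum_over_all_but_batch_and_last_n is called but never shown. Going by its name and by how it is invoked (with layer.weight.dim() as the second argument), a minimal sketch could look like the following; this is an assumption about its behavior, not the actual library implementation:

def sum_over_all_but_batch_and_last_n(tensor, n_dims):
    # Assumed behavior: reduce every dimension except dim 0 (the batch)
    # and the last n_dims, so (N, *, d1, ..., dn) becomes (N, d1, ..., dn).
    if tensor.dim() == n_dims + 1:
        return tensor  # nothing to sum over
    dims = list(range(1, tensor.dim() - n_dims))
    return tensor.sum(dim=dims)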
Example #2
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def _compute_norm_grad_sample(
    # for some reason pyre doesn't understand that
    # nn.LayerNorm and nn.modules.normalization.LayerNorm is the same thing
    # pyre-ignore[11]
    layer: Union[
        nn.LayerNorm,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
    ],
    A: torch.Tensor,
    B: torch.Tensor,
    batch_dim: int = 0,
) -> None:
    """
    Computes per-sample gradients for normalization layers.

    Args:
        layer: Normalization layer
        A: Activations (layer input)
        B: Backpropagations (gradient of the loss w.r.t. the layer output)
        batch_dim: Batch dimension position
    """
    layer_type = get_layer_type(layer)
    if layer_type == "LayerNorm":
        _create_or_extend_grad_sample(
            layer.weight,
            sum_over_all_but_batch_and_last_n(
                F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
                layer.weight.dim(),
            ),
            batch_dim,
        )
        _create_or_extend_grad_sample(
            layer.bias,
            sum_over_all_but_batch_and_last_n(B, layer.bias.dim()),
            batch_dim,
        )
    elif layer_type == "GroupNorm":
        gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
        _create_or_extend_grad_sample(
            layer.weight, torch.einsum("ni...->ni", gs), batch_dim
        )
        if layer.bias is not None:
            _create_or_extend_grad_sample(
                layer.bias, torch.einsum("ni...->ni", B), batch_dim
            )
    elif layer_type in {"InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"}:
        gs = F.instance_norm(A, eps=layer.eps) * B
        _create_or_extend_grad_sample(
            layer.weight, torch.einsum("ni...->ni", gs), batch_dim
        )
        if layer.bias is not None:
            _create_or_extend_grad_sample(
                layer.bias, torch.einsum("ni...->ni", B), batch_dim
            )
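
One way to sanity-check the LayerNorm formula is to compare the batch-summed per-sample gradient against the ordinary gradient from autograd. This harness is illustrative only and reuses the sum_over_all_but_batch_and_last_n sketch above:

# Hypothetical check: summing per-sample gradients over the batch should
# recover layer.weight.grad for a sum-reduced loss.
N, d = 4, 3
layer = nn.LayerNorm(d)
A = torch.randn(N, d)
layer(A).sum().backward()
B = torch.ones(N, d)  # dL/d(out) for a sum loss
gs = F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B
per_sample = sum_over_all_but_batch_and_last_n(gs, layer.weight.dim())
assert torch.allclose(per_sample.sum(dim=0), layer.weight.grad)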
Example #3
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_weight(self):
    if self.use_layernorm:
        weight = self.scale * F.layer_norm(
            self.weight, self.weight.shape[1:], eps=self.eps
        )
    else:
        mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)
        std = torch.std(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        weight = self.scale * (self.weight - mean) / (std + self.eps)
    if self.gain is not None:
        weight = weight * self.gain
    return weight
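
get_weight is shown detached from its class. It performs scaled weight standardization (normalizing a convolution weight over all non-output dimensions), so a purely illustrative host class could look like this; the class name, scale definition and attribute defaults are assumptions:

class WSConv2d(nn.Conv2d):
    # Hypothetical host for get_weight above: a Conv2d whose weight is
    # standardized over (in_channels, kh, kw) before every forward pass.
    def __init__(self, *args, eps=1e-5, use_layernorm=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        self.use_layernorm = use_layernorm
        self.scale = self.weight[0].numel() ** -0.5  # assumed: 1/sqrt(fan_in)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))

    get_weight = get_weight  # reuse the module-level function above as a method

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)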
def _compute_norm_grad_sample(layer, A, B):
    layer_type = get_layer_type(layer)
    if layer_type == "LayerNorm":
        layer.weight.grad_sample = sum_over_all_but_batch_and_last_n(
            F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
            layer.weight.dim(),
        )
        layer.bias.grad_sample = sum_over_all_but_batch_and_last_n(
            B, layer.bias.dim())
    elif layer_type == "GroupNorm":
        gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
        layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
        if layer.bias is not None:
            layer.bias.grad_sample = torch.einsum("ni...->ni", B)
    elif layer_type in {"InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"}:
        gs = F.instance_norm(A, eps=layer.eps) * B
        layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
        if layer.bias is not None:
            layer.bias.grad_sample = torch.einsum("ni...->ni", B)
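
The einsum pattern "ni...->ni" used for GroupNorm and InstanceNorm sums out every trailing (spatial or temporal) dimension while keeping the batch and channel dimensions, which a quick check confirms:

gs = torch.randn(8, 4, 16, 16)  # (batch, channels, H, W)
assert torch.allclose(torch.einsum("ni...->ni", gs), gs.sum(dim=(2, 3)))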
def compute_grad_sample(
    model: nn.Module, loss_type: str = "mean", batch_dim: int = 0
) -> None:
    """
    Compute per-example gradients and save them under 'param.grad_sample'.
    Must be called after loss.backward().

    Args:
        model: Model whose parameters will receive .grad_sample attributes
        loss_type: Either "mean" or "sum", depending on whether the
            backpropagated loss was averaged or summed over the batch
        batch_dim: Batch dimension position
    """

    if loss_type not in ("sum", "mean"):
        raise ValueError(f"loss_type = {loss_type}. Only 'sum' and 'mean' supported")
    for layer in model.modules():
        layer_type = get_layer_type(layer)
        if not requires_grad(layer) or layer_type not in _supported_layers:
            continue
        if not hasattr(layer, "activations"):
            raise ValueError(
                f"No activations detected for {type(layer)},"
                " run forward after add_hooks(model)"
            )
        if not hasattr(layer, "backprops_list"):
            raise ValueError(
                "No backprops detected, run backward after add_hooks(model)"
            )
        if len(layer.backprops_list) != 1:
            raise ValueError(
                "Multiple backprops detected, make sure to call clear_backprops(model)"
            )

        A = layer.activations
        n = A.shape[batch_dim]
        if loss_type == "mean":
            B = layer.backprops_list[0] * n
        else:  # loss_type == "sum"
            B = layer.backprops_list[0]

        if batch_dim != 0:
            A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim])
            B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim])

        if layer_type == "Linear":
            gs = torch.einsum("n...i,n...j->n...ij", B, A)
            layer.weight.grad_sample = torch.einsum("n...ij->nij", gs)
            if layer.bias is not None:
                layer.bias.grad_sample = torch.einsum("n...k->nk", B)

        if layer_type == "LayerNorm":
            layer.weight.grad_sample = sum_over_all_but_batch_and_last_n(
                F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
                layer.weight.dim(),
            )
            layer.bias.grad_sample = sum_over_all_but_batch_and_last_n(
                B, layer.bias.dim()
            )

        if layer_type == "GroupNorm":
            gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
            layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
            if layer.bias is not None:
                layer.bias.grad_sample = torch.einsum("ni...->ni", B)

        elif layer_type in ("InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"):
            gs = F.instance_norm(A, eps=layer.eps) * B
            layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
            if layer.bias is not None:
                layer.bias.grad_sample = torch.einsum("ni...->ni", B)

        elif layer_type in ("Conv2d", "Conv1d"):
            # get A and B in shape depending on the Conv layer
            if layer_type == "Conv2d":
                A = torch.nn.functional.unfold(
                    A, layer.kernel_size, padding=layer.padding, stride=layer.stride
                )
                B = B.reshape(n, -1, A.shape[-1])
            elif layer_type == "Conv1d":
                # unfold doesn't work for 3D tensors; so force it to be 4D
                A = A.unsqueeze(-2)  # add the H dimension
                # set arguments to tuples with appropriate second element
                A = torch.nn.functional.unfold(
                    A,
                    (1, layer.kernel_size[0]),
                    padding=(0, layer.padding[0]),
                    stride=(1, layer.stride[0]),
                )
                B = B.reshape(n, -1, A.shape[-1])
            try:
                # n=batch_sz; o=num_out_channels; p=num_in_channels*kernel_sz;
                # q=num_sliding_positions
                grad_sample = (
                    torch.einsum("noq,npq->nop", B, A)
                    if layer.groups == 1
                    else torch.einsum("njk,njk->nj", B, A)
                )
                shape = [n] + list(layer.weight.shape)
                layer.weight.grad_sample = grad_sample.reshape(shape)
            except Exception as e:
                raise type(e)(
                    f"{e} There is probably a problem with {layer_type}.groups"
                    + "It should be either 1 or in_channel"
                )
            if layer.bias is not None:
                layer.bias.grad_sample = torch.sum(B, dim=2)
        if layer_type == "SequenceBias":
            layer.bias.grad_sample = B[:, -1]
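
Judging from the error messages in compute_grad_sample, the intended call sequence is: register hooks, run one forward and one backward pass, then harvest per-sample gradients. A sketch of that flow, assuming add_hooks is the companion function the error messages refer to:

model = nn.Sequential(nn.Linear(10, 5), nn.LayerNorm(5))
x = torch.randn(16, 10)

add_hooks(model)  # assumed companion: captures activations and backprops via hooks
model(x).sum().backward()
compute_grad_sample(model, loss_type="sum")

for p in model.parameters():
    # each grad_sample carries a leading batch dimension of 16
    print(p.shape, p.grad_sample.shape)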