def _compute_norm_grad_sample(layer, A, B, batch_dim=0):
    layer_type = get_layer_type(layer)
    if layer_type == "LayerNorm":
        _create_or_extend_grad_sample(
            layer.weight,
            sum_over_all_but_batch_and_last_n(
                F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
                layer.weight.dim(),
            ),
            batch_dim,
        )
        _create_or_extend_grad_sample(
            layer.bias,
            sum_over_all_but_batch_and_last_n(B, layer.bias.dim()),
            batch_dim,
        )
    elif layer_type == "GroupNorm":
        gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
        _create_or_extend_grad_sample(
            layer.weight, torch.einsum("ni...->ni", gs), batch_dim
        )
        if layer.bias is not None:
            _create_or_extend_grad_sample(
                layer.bias, torch.einsum("ni...->ni", B), batch_dim
            )
    elif layer_type in {"InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"}:
        gs = F.instance_norm(A, eps=layer.eps) * B
        _create_or_extend_grad_sample(
            layer.weight, torch.einsum("ni...->ni", gs), batch_dim
        )
        if layer.bias is not None:
            _create_or_extend_grad_sample(
                layer.bias, torch.einsum("ni...->ni", B), batch_dim
            )
def _compute_norm_grad_sample(
    # for some reason pyre doesn't understand that
    # nn.LayerNorm and nn.modules.normalization.LayerNorm is the same thing
    # pyre-ignore[11]
    layer: Union[
        nn.LayerNorm,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
    ],
    A: torch.Tensor,
    B: torch.Tensor,
    batch_dim: int = 0,
) -> None:
    """
    Computes per sample gradients for normalization layers

    Args:
        layer: Layer
        A: Activations
        B: Backpropagations
        batch_dim: Batch dimension position
    """
    layer_type = get_layer_type(layer)
    if layer_type == "LayerNorm":
        _create_or_extend_grad_sample(
            layer.weight,
            sum_over_all_but_batch_and_last_n(
                F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
                layer.weight.dim(),
            ),
            batch_dim,
        )
        _create_or_extend_grad_sample(
            layer.bias,
            sum_over_all_but_batch_and_last_n(B, layer.bias.dim()),
            batch_dim,
        )
    elif layer_type == "GroupNorm":
        gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
        _create_or_extend_grad_sample(
            layer.weight, torch.einsum("ni...->ni", gs), batch_dim
        )
        if layer.bias is not None:
            _create_or_extend_grad_sample(
                layer.bias, torch.einsum("ni...->ni", B), batch_dim
            )
    elif layer_type in {"InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"}:
        gs = F.instance_norm(A, eps=layer.eps) * B
        _create_or_extend_grad_sample(
            layer.weight, torch.einsum("ni...->ni", gs), batch_dim
        )
        if layer.bias is not None:
            _create_or_extend_grad_sample(
                layer.bias, torch.einsum("ni...->ni", B), batch_dim
            )
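# A minimal standalone sanity check (not part of the module above) that the
# LayerNorm branch of _compute_norm_grad_sample matches per-example gradients
# obtained by running autograd one example at a time. All names, shapes, and
# tolerances below are illustrative assumptions for this sketch only.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n, t, d = 4, 3, 5                        # batch size, sequence length, features
ln = nn.LayerNorm(d)
x = torch.randn(n, t, d)                 # activations A
grad_out = torch.randn(n, t, d)          # backprops B (dL/dy)

# normalized input times backprops, summed over everything except the batch
# dimension and the weight dimensions (here: sum over t)
x_hat = F.layer_norm(x, ln.normalized_shape, eps=ln.eps)
weight_grad_sample = (x_hat * grad_out).sum(dim=1)   # shape (n, d)
bias_grad_sample = grad_out.sum(dim=1)               # shape (n, d)

# reference: per-example gradients via autograd, one example at a time
expected_weight = torch.stack([
    torch.autograd.grad(ln(x[i : i + 1]), ln.weight, grad_outputs=grad_out[i : i + 1])[0]
    for i in range(n)
])
expected_bias = torch.stack([
    torch.autograd.grad(ln(x[i : i + 1]), ln.bias, grad_outputs=grad_out[i : i + 1])[0]
    for i in range(n)
])
assert torch.allclose(weight_grad_sample, expected_weight, atol=1e-5)
assert torch.allclose(bias_grad_sample, expected_bias, atol=1e-5)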
def _compute_norm_grad_sample(layer, A, B):
    """
    Computes per sample gradients for normalization layers, writing them
    directly to param.grad_sample.
    """
    layer_type = get_layer_type(layer)
    if layer_type == "LayerNorm":
        layer.weight.grad_sample = sum_over_all_but_batch_and_last_n(
            F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
            layer.weight.dim(),
        )
        layer.bias.grad_sample = sum_over_all_but_batch_and_last_n(
            B, layer.bias.dim()
        )
    elif layer_type == "GroupNorm":
        gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
        layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
        if layer.bias is not None:
            layer.bias.grad_sample = torch.einsum("ni...->ni", B)
    elif layer_type in {"InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"}:
        gs = F.instance_norm(A, eps=layer.eps) * B
        layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
        if layer.bias is not None:
            layer.bias.grad_sample = torch.einsum("ni...->ni", B)
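# The same style of standalone check for the GroupNorm branch: einsum("ni...->ni")
# sums the (normalized input * backprops) product over all spatial positions,
# leaving one gradient row per example and per channel. Shapes and names below
# are illustrative assumptions for this sketch only.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n, c, h, w = 4, 6, 5, 5
gn = nn.GroupNorm(num_groups=2, num_channels=c)
x = torch.randn(n, c, h, w)              # activations A
grad_out = torch.randn(n, c, h, w)       # backprops B (dL/dy)

gs = F.group_norm(x, gn.num_groups, eps=gn.eps) * grad_out
weight_grad_sample = torch.einsum("ni...->ni", gs)    # shape (n, c)

# reference: per-example gradients via autograd, one example at a time
expected = torch.stack([
    torch.autograd.grad(gn(x[i : i + 1]), gn.weight, grad_outputs=grad_out[i : i + 1])[0]
    for i in range(n)
])
assert torch.allclose(weight_grad_sample, expected, atol=1e-5)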
def compute_grad_sample(
    model: nn.Module, loss_type: str = "mean", batch_dim: int = 0
) -> None:
    """
    Compute per-example gradients and save them under 'param.grad_sample'.
    Must be called after loss.backward().

    Args:
        model: Model with hooks added via add_hooks(model)
        loss_type: either "mean" or "sum" depending on whether the backpropped
            loss was averaged or summed over the batch
        batch_dim: Batch dimension position
    """
    if loss_type not in ("sum", "mean"):
        raise ValueError(f"loss_type = {loss_type}. Only 'sum' and 'mean' supported")
    for layer in model.modules():
        layer_type = get_layer_type(layer)
        if not requires_grad(layer) or layer_type not in _supported_layers:
            continue
        if not hasattr(layer, "activations"):
            raise ValueError(
                f"No activations detected for {type(layer)},"
                " run forward after add_hooks(model)"
            )
        if not hasattr(layer, "backprops_list"):
            raise ValueError(
                "No backprops detected, run backward after add_hooks(model)"
            )
        if len(layer.backprops_list) != 1:
            raise ValueError(
                "Multiple backprops detected, make sure to call clear_backprops(model)"
            )

        A = layer.activations
        n = A.shape[batch_dim]
        if loss_type == "mean":
            # undo the 1/n averaging so grad_sample holds true per-example gradients
            B = layer.backprops_list[0] * n
        else:  # loss_type == "sum"
            B = layer.backprops_list[0]

        # move the batch dimension to the front
        if batch_dim != 0:
            A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim])
            B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim])

        if layer_type == "Linear":
            gs = torch.einsum("n...i,n...j->n...ij", B, A)
            layer.weight.grad_sample = torch.einsum("n...ij->nij", gs)
            if layer.bias is not None:
                layer.bias.grad_sample = torch.einsum("n...k->nk", B)
        elif layer_type == "LayerNorm":
            layer.weight.grad_sample = sum_over_all_but_batch_and_last_n(
                F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B,
                layer.weight.dim(),
            )
            layer.bias.grad_sample = sum_over_all_but_batch_and_last_n(
                B, layer.bias.dim()
            )
        elif layer_type == "GroupNorm":
            gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
            layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
            if layer.bias is not None:
                layer.bias.grad_sample = torch.einsum("ni...->ni", B)
        elif layer_type in ("InstanceNorm1d", "InstanceNorm2d", "InstanceNorm3d"):
            gs = F.instance_norm(A, eps=layer.eps) * B
            layer.weight.grad_sample = torch.einsum("ni...->ni", gs)
            if layer.bias is not None:
                layer.bias.grad_sample = torch.einsum("ni...->ni", B)
        elif layer_type in ("Conv2d", "Conv1d"):
            # get A and B in shape depending on the Conv layer
            if layer_type == "Conv2d":
                A = torch.nn.functional.unfold(
                    A, layer.kernel_size, padding=layer.padding, stride=layer.stride
                )
                B = B.reshape(n, -1, A.shape[-1])
            elif layer_type == "Conv1d":
                # unfold doesn't work for 3D tensors, so force it to be 4D
                A = A.unsqueeze(-2)  # add the H dimension
                # set arguments to tuples with appropriate second element
                A = torch.nn.functional.unfold(
                    A,
                    (1, layer.kernel_size[0]),
                    padding=(0, layer.padding[0]),
                    stride=(1, layer.stride[0]),
                )
                B = B.reshape(n, -1, A.shape[-1])
            try:
                # n=batch_sz; o=num_out_channels; p=num_in_channels*kernel_sz
                grad_sample = (
                    torch.einsum("noq,npq->nop", B, A)
                    if layer.groups == 1
                    else torch.einsum("njk,njk->nj", B, A)
                )
                shape = [n] + list(layer.weight.shape)
                layer.weight.grad_sample = grad_sample.reshape(shape)
            except Exception as e:
                raise type(e)(
                    f"{e} There is probably a problem with {layer_type}.groups."
                    " It should be either 1 or in_channels."
                )
            if layer.bias is not None:
                layer.bias.grad_sample = torch.sum(B, dim=2)
        elif layer_type == "SequenceBias":
            layer.bias.grad_sample = B[:, -1]
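# Hypothetical end-to-end usage of the hook-based API above. add_hooks(...) is the
# companion helper referenced in the error messages; its import path, the model,
# shapes, and loss below are placeholder assumptions for this sketch, and the
# final check assumes the hooks populate layer.activations / layer.backprops_list
# exactly as compute_grad_sample expects.
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(nn.Linear(10, 16), nn.LayerNorm(16), nn.Linear(16, 2))
x = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

add_hooks(model)                                   # capture activations/backprops
loss = F.cross_entropy(model(x), targets)          # mean-reduced over the batch
loss.backward()
compute_grad_sample(model, loss_type="mean")       # must match the loss reduction

for name, p in model.named_parameters():
    print(name, p.grad_sample.shape)               # leading batch dim of size 8

# For a mean-reduced loss, per-example gradients averaged over the batch should
# reproduce p.grad (for loss_type="sum", sum instead of averaging).
for p in model.parameters():
    assert torch.allclose(p.grad_sample.mean(dim=0), p.grad, atol=1e-6)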