Example #1
    def forward_pass(
            self,
            subsampling: List[int] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Do a forward pass. Return input, output, and parameters.

        If sub-sampling is None, the forward pass is calculated on the whole batch.

        Args:
            subsampling: Indices of selected samples. Default: ``None`` (all samples).

        Returns:
            input, output, and loss of the forward pass
        """
        input = self.input.clone()
        target = self.target.clone()

        if subsampling is not None:
            batch_axis = 0
            input = subsample(self.input,
                              dim=batch_axis,
                              subsampling=subsampling)
            target = subsample(self.target,
                               dim=batch_axis,
                               subsampling=subsampling)

        output = self.model(input)
        loss = self.loss_function(output, target)

        return input, output, loss
Example #2
    def _forward_pass(
        module: LSTM, mat: Tensor, subsampling: List[int] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """This performs an additional forward pass and returns the hidden variables.

        This is important because the PyTorch implementation does not grant access to
        some of the hidden variables. Those are computed and returned.

        See also the forward pass description in the class docstring.

        Args:
            module: LSTM layer with stored input and output.
            mat: matrix, only used to extract device and shapes.
            subsampling: Indices of active samples. Defaults to ``None`` (all samples).

        Returns:
            ifgo, c, c_tanh (all in format ``[N, T, ...]``)
        """
        _, N, T, _ = mat.shape
        H: int = module.hidden_size
        H0: int = 0 * H
        H1: int = 1 * H
        H2: int = 2 * H
        H3: int = 3 * H
        H4: int = 4 * H
        # forward pass, saving i, f, g, o, c, c_tanh -> ifgo, c, c_tanh
        ifgo: Tensor = zeros(N, T, 4 * H, device=mat.device, dtype=mat.dtype)
        c: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype)
        c_tanh: Tensor = zeros(N, T, H, device=mat.device, dtype=mat.dtype)

        input0 = subsample(module.input0, dim=0, subsampling=subsampling)
        output = subsample(module.output, dim=0, subsampling=subsampling)

        for t in range(T):
            ifgo[:, t] = (
                einsum("hi,ni->nh", module.weight_ih_l0, input0[:, t])
                + module.bias_ih_l0
                + module.bias_hh_l0
            )
            if t != 0:
                ifgo[:, t] += einsum("hg,ng->nh", module.weight_hh_l0, output[:, t - 1])
            ifgo[:, t, H0:H1] = sigmoid(ifgo[:, t, H0:H1])
            ifgo[:, t, H1:H2] = sigmoid(ifgo[:, t, H1:H2])
            ifgo[:, t, H2:H3] = tanh(ifgo[:, t, H2:H3])
            ifgo[:, t, H3:H4] = sigmoid(ifgo[:, t, H3:H4])
            c[:, t] = ifgo[:, t, H0:H1] * ifgo[:, t, H2:H3]
            if t != 0:
                c[:, t] += ifgo[:, t, H1:H2] * c[:, t - 1]
            c_tanh[:, t] = tanh(c[:, t])

        return ifgo, c, c_tanh
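Note (added for orientation, not part of the source): the loop above restates the standard single-layer LSTM recurrence in PyTorch's ``ifgo`` gate ordering, with h_0 = c_0 = 0. In equations,

\begin{aligned}
a_t &= W_{ih}\, x_t + b_{ih} + W_{hh}\, h_{t-1} + b_{hh} && \text{(pre-activations, the unactivated ifgo)} \\
i_t &= \sigma(a_t^{(1)}),\quad f_t = \sigma(a_t^{(2)}),\quad g_t = \tanh(a_t^{(3)}),\quad o_t = \sigma(a_t^{(4)}) && \text{(the four $H$-sized blocks of ifgo)} \\
c_t &= i_t \odot g_t + f_t \odot c_{t-1} && \text{(cell state, returned as c)} \\
h_t &= o_t \odot \tanh(c_t) && \text{(hidden state; $\tanh(c_t)$ is returned as c\_tanh)}
\end{aligned}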
Example #3
File: rnn.py  Project: f-dangel/backpack
    def _weight_ih_l0_jac_t_mat_prod(
        self,
        module: RNN,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Apply transposed Jacobian of the output w.r.t. weight_ih_l0.

        Args:
            module: extended module
            g_inp: input gradient
            g_out: output gradient
            mat: matrix to multiply
            sum_batch: Whether to sum along batch axis. Defaults to True.
            subsampling: Indices of active samples. Defaults to ``None`` (all samples).

        Returns:
            Product of the transposed Jacobian and ``mat``
        """
        self._check_parameters(module)
        return einsum(
            f"vnth,ntj->v{'' if sum_batch else 'n'}hj",
            self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat,
                                   subsampling),
            subsample(module.input0, dim=0, subsampling=subsampling),
        )
Example #4
File: rnn.py  Project: f-dangel/backpack
    def _weight_hh_l0_jac_t_mat_prod(
        self,
        module: RNN,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Apply transposed Jacobian of the output w.r.t. weight_hh_l0.

        Args:
            module: extended module
            g_inp: input gradient
            g_out: output gradient
            mat: matrix to multiply
            sum_batch: Whether to sum along batch axis. Defaults to True.
            subsampling: Indices of active samples. Defaults to ``None`` (all samples).

        Returns:
            Product of the transposed Jacobian and ``mat``
        """
        self._check_parameters(module)
        _, N, _, H = mat.shape
        output = subsample(module.output, dim=0, subsampling=subsampling)
        single_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype)
        output_shifted = cat([single_step, output[:, :-1]], dim=1)
        return einsum(
            f"vnth,ntk->v{'' if sum_batch else 'n'}hk",
            self._a_jac_t_mat_prod(module, module.weight_hh_l0, mat,
                                   subsampling),
            output_shifted,
        )
Example #5
        def param_function(
            ext: BatchGrad,
            module: Module,
            g_inp: Tuple[Tensor],
            g_out: Tuple[Tensor],
            bpQuantities: None,
        ) -> Tensor:
            """Calculates batch_grad with the help of derivatives object.

            Args:
                ext: extension that is used
                module: module that performed forward pass
                g_inp: input gradient tensors
                g_out: output gradient tensors
                bpQuantities: additional quantities for second order

            Returns:
                Scaled individual gradients
            """
            subsampling = ext.get_subsampling()
            batch_axis = 0

            return self._derivatives.param_mjp(
                param_str,
                module,
                g_inp,
                g_out,
                subsample(g_out[0], dim=batch_axis, subsampling=subsampling),
                sum_batch=False,
                subsampling=subsampling,
            )
Example #6
    def _weight_jac_t_mat_prod(
        self,
        module: Union[ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        V = mat.shape[0]
        G = module.groups
        C_in = module.input0.shape[1]
        N = module.output.shape[0] if subsampling is None else len(subsampling)
        C_out = module.output.shape[1]

        mat_reshape = mat.reshape(V, N, G, C_out // G,
                                  *module.output.shape[2:])

        u = unfold_by_conv_transpose(
            subsample(module.input0, subsampling=subsampling),
            module).reshape(N, G, C_in // G, *module.weight.shape[2:],
                            *module.output.shape[2:])

        dims_kern = "xyz"[:self.conv_dims]
        dims_data = "abc"[:self.conv_dims]
        result_str = ("vgio" if sum_batch else "vngio") + dims_kern
        equation = f"ngi{dims_kern}{dims_data},vngo{dims_data}->{result_str}"

        final_shape = ((V, *module.weight.shape) if sum_batch else
                       (V, N, *module.weight.shape))

        return einsum(equation, u, mat_reshape).reshape(final_shape)
Example #7
    def _weight_jac_t_mat_prod(
        self,
        module: Linear,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Batch-apply transposed Jacobian of the output w.r.t. the weight.

        Args:
            module: Linear layer.
            g_inp: Gradients w.r.t. module input. Not required by the implementation.
            g_out: Gradients w.r.t. module output. Not required by the implementation.
            mat: Batch of ``V`` vectors of same shape as the layer output
                (``[N, *, out_features]``) to which the transposed output-input Jacobian
                is applied. Has shape ``[V, N, *, out_features]`` if subsampling is not
                used, otherwise ``N`` must be ``len(subsampling)`` instead.
            sum_batch: Sum the result's batch axis. Default: ``True``.
            subsampling: Indices of samples along the output's batch dimension that
                should be considered. Defaults to ``None`` (use all samples).

        Returns:
            Batched transposed Jacobian vector products. Has shape
            ``[V, N, *module.weight.shape]`` when ``sum_batch`` is ``False``. With
            ``sum_batch=True``, has shape ``[V, *module.weight.shape]``. If sub-
            sampling is used, ``N`` must be ``len(subsampling)`` instead.
        """
        d_weight = subsample(module.input0, subsampling=subsampling)

        equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi"
        return einsum(equation, mat, d_weight)
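The einsum above is just a batched outer product between the vectors in ``mat`` and the layer input. A minimal standalone sketch (a hypothetical check, not part of BackPACK; all names below are made up) comparing the ``sum_batch=False`` case with ``V = 1`` and a 2d input against per-sample autograd gradients of a ``Linear`` layer:

import torch

# Hypothetical self-contained check of the einsum above (no extra "..." dims).
N, D_in, D_out = 4, 3, 2
layer = torch.nn.Linear(D_in, D_out)
x = torch.randn(N, D_in)
g_out = torch.randn(N, D_out)  # vectors the transposed Jacobian is applied to

mat = g_out.unsqueeze(0)                          # [V=1, N, D_out]
jac_t_mat = torch.einsum("vno,ni->vnoi", mat, x)  # [1, N, D_out, D_in]

# Reference: per-sample weight gradients via autograd, one sample at a time.
for n in range(N):
    layer.zero_grad()
    layer(x[n]).backward(g_out[n])
    assert torch.allclose(jac_t_mat[0, n], layer.weight.grad)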
Example #8
    def forward_pass(self,
                     input_requires_grad: bool = False,
                     subsampling: List[int] = None
                     ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]:
        """Do a forward pass. Return input, output, and parameters."""
        input: Tensor = self.input.clone().detach()

        if subsampling is not None:
            batch_axis = 0
            input = subsample(input, dim=batch_axis, subsampling=subsampling)

        if input_requires_grad and input.dtype is not long:
            input.requires_grad = True

        if self.is_loss():
            assert subsampling is None
            output: Tensor = self.module(input, self.target)
        else:
            output: Tensor = self.module(input)

        if isinstance(output, tuple):
            # holds for RNN, GRU, LSTM, which return a tuple (output, ...)
            output: Tensor = output[0]

        return input, output, dict(self.module.named_parameters())
Example #9
def _check_like(mat, module, name, diff=1, *args, **kwargs):
    if name in ["output", "input0"] and "subsampling" in kwargs.keys():
        compare = subsample(getattr(module, name),
                            dim=0,
                            subsampling=kwargs["subsampling"])
    else:
        compare = getattr(module, name)

    return check_shape(mat, compare, diff=diff)
Example #10
File: tanh.py  Project: f-dangel/backpack
 def df(
     self,
     module: Tanh,
     g_inp: Tuple[Tensor],
     g_out: Tuple[Tensor],
     subsampling: List[int] = None,
 ) -> Tensor:
     output = subsample(module.output, subsampling=subsampling)
     return 1.0 - output**2
Example #11
 def df(
     self,
     module: LogSigmoid,
     g_inp: Tuple[Tensor],
     g_out: Tuple[Tensor],
     subsampling: List[int] = None,
 ) -> Tensor:
     """First Logsigmoid derivative: `logsigmoid'(x) = 1 / (e^x + 1) `."""
     input0 = subsample(module.input0, subsampling=subsampling)
     return 1 / (exp(input0) + 1)
Example #12
 def df(
     self,
     module: Sigmoid,
     g_inp: Tuple[Tensor],
     g_out: Tuple[Tensor],
     subsampling: List[int] = None,
 ) -> Tensor:
     """First sigmoid derivative: `σ'(x) = σ(x) (1 - σ(x))`."""
     output = subsample(module.output, subsampling=subsampling)
     return output * (1.0 - output)
Example #13
 def df(
     self,
     module: ReLU,
     g_inp: Tuple[Tensor],
     g_out: Tuple[Tensor],
     subsampling: List[int] = None,
 ) -> Tensor:
     """First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1`."""
     input0 = subsample(module.input0, subsampling=subsampling)
     return gt(input0, 0).to(input0.dtype)
Example #14
 def df(
     self,
     module: LeakyReLU,
     g_inp: Tuple[Tensor],
     g_out: Tuple[Tensor],
     subsampling: List[int] = None,
 ) -> Tensor:
     """``LeakyReLU'(x) = negative_slope if x < 0 else 1``."""
     input0 = subsample(module.input0, subsampling=subsampling)
     df_leakyrelu = gt(input0, 0).to(input0.dtype)
     df_leakyrelu[df_leakyrelu == 0] = module.negative_slope
     return df_leakyrelu
Example #15
    def jac_t_vec_prod(self,
                       vec: Tensor,
                       subsampling=None) -> Tensor:  # noqa: D102
        input, output, _ = self.problem.forward_pass(input_requires_grad=True)

        if subsampling is None:
            return transposed_jacobian_vector_product(output, input, vec)[0]
        else:
            # for each sample, multiply by full input Jacobian, slice out result:
            # ( (∂ output[n] / ∂ input)ᵀ v[n] )[n]
            batch_axis = 0
            output = subsample(output, dim=batch_axis, subsampling=subsampling)
            output = output.split(1, dim=batch_axis)
            vec = vec.split(1, dim=batch_axis)

            vjps: List[Tensor] = []
            for sample_idx, out, v in zip(subsampling, output, vec):
                vjp = transposed_jacobian_vector_product(out, input, v)[0]
                vjp = subsample(vjp, dim=batch_axis, subsampling=[sample_idx])
                vjps.append(vjp)

            return cat(vjps, dim=batch_axis)
Example #16
    def _get_probs(module: CrossEntropyLoss, subsampling: List[int] = None) -> Tensor:
        """Compute the softmax probabilities from the module input.

        Args:
            module: cross-entropy loss with I/O.
            subsampling: Indices of samples to be considered. Default of ``None`` uses
                the full mini-batch.

        Returns:
            Softmax probabilities
        """
        input0 = subsample(module.input0, subsampling=subsampling)
        return softmax(input0, dim=1)
Example #17
    def _weight_jac_t_mat_prod(
        self,
        module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        x_hat, _ = self._get_normalized_input_and_var(module)
        x_hat = subsample(x_hat, subsampling=subsampling)

        equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c"
        return einsum(equation, mat, x_hat)
Example #18
 def df(
     self,
     module: Dropout,
     g_inp: Tuple[Tensor],
     g_out: Tuple[Tensor],
     subsampling: List[int] = None,
 ) -> Tensor:  # noqa: D102
     output = subsample(module.output, subsampling=subsampling)
     if module.training:
         scaling = 1 / (1 - module.p)
         mask = 1 - eq(output, 0.0).to(output.dtype)
         return mask * scaling
     else:
         return ones_like(output)
Example #19
    def __same_conv_weight_jac_t(
        self,
        module: Union[Conv1d, Conv2d, Conv3d],
        mat: Tensor,
        sum_batch: bool,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Uses convolution of same order."""
        G = module.groups
        V = mat.shape[0]
        C_out = module.output.shape[1]
        N = module.output.shape[0] if subsampling is None else len(subsampling)
        C_in = module.input0.shape[1]
        C_in_axis = 1
        N_axis = 0

        # treat channel groups like vectorization (v) and batch (n) axes
        mat = rearrange(mat, "v n (g c) ... -> (v n g) c ...", g=G, c=C_out // G)
        repeat_pattern = [1, C_in // G] + [1 for _ in range(self.conv_dims)]
        mat = mat.repeat(*repeat_pattern)
        mat = rearrange(mat, "a b ... -> (a b) ...")
        mat = mat.unsqueeze(C_in_axis)

        input = rearrange(
            subsample(module.input0, subsampling=subsampling), "n c ... -> (n c) ..."
        )
        input = input.unsqueeze(N_axis)
        repeat_pattern = [1, V] + [1 for _ in range(self.conv_dims)]
        input = input.repeat(*repeat_pattern)

        grad_weight = self.conv_func(
            input,
            mat,
            bias=None,
            stride=module.dilation,
            padding=module.padding,
            dilation=module.stride,
            groups=C_in * N * V,
        ).squeeze(0)

        for dim in range(self.conv_dims):
            axis = dim + 1
            size = module.weight.shape[2 + dim]
            grad_weight = grad_weight.narrow(axis, 0, size)

        dim = {"g": G, "v": V, "n": N, "i": C_in // G, "o": C_out // G}
        if sum_batch:
            return reduce(grad_weight, "(v n g i o) ... -> v (g o) i ...", "sum", **dim)
        else:
            return rearrange(grad_weight, "(v n g i o) ... -> v n (g o) i ...", **dim)
Example #20
    def df(
        self,
        module: ELU,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        """First ELU derivative: `ELU'(x) = alpha * e^x if x <= 0 else 1`."""
        input0 = subsample(module.input0, subsampling=subsampling)
        non_pos = le(input0, 0)

        result = ones_like(input0)
        result[non_pos] = module.alpha * exp(input0[non_pos])

        return result
Example #21
 def get_pooling_idx(
     self,
     module: Union[MaxPool1d, MaxPool2d, MaxPool3d],
     subsampling: List[int] = None,
 ) -> Tensor:
     _, pool_idx = self.maxpool(
         subsample(module.input0, subsampling=subsampling),
         kernel_size=module.kernel_size,
         stride=module.stride,
         padding=module.padding,
         dilation=module.dilation,
         return_indices=True,
         ceil_mode=module.ceil_mode,
     )
     return pool_idx
Example #22
File: selu.py  Project: f-dangel/backpack
    def df(
        self,
        module: SELU,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        """First SELU derivative: `SELU'(x) = scale if x > 0 else scale*alpha*e^x`."""
        input0 = subsample(module.input0, subsampling=subsampling)
        non_pos = le(input0, 0)

        result = self.scale * ones_like(input0)
        result[non_pos] = self.scale * self.alpha * exp(input0[non_pos])

        return result
Example #23
    def _jac_t_mat_prod(
        self,
        module: Slicing,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        self.no_slice_batch_axis(module)

        input0 = module.input0
        result_shape = (mat.shape[0], *subsample(input0, subsampling=subsampling).shape)
        result = zeros(result_shape, device=input0.device, dtype=input0.dtype)
        result[(slice(None),) + module.slice_info] = mat

        return result
Example #24
    def _weight_jac_t_mat_prod(
        self,
        module: Embedding,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        self._check_parameters(module)

        input0 = subsample(module.input0, subsampling=subsampling)
        # delta[s] is a one-hot indicator of which input positions look up embedding index s
        delta = zeros(module.num_embeddings, *input0.shape, device=mat.device)
        for s in range(module.num_embeddings):
            delta[s] = input0 == s
        equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh"
        return einsum(equation, delta, mat)
Example #25
    def input_hessian_via_sqrt_hessian(
            self,
            mc_samples: int = None,
            chunks: int = 1,
            subsampling: List[int] = None) -> Tensor:
        """Computes the Hessian w.r.t. to the input from its matrix square root.

        Args:
            mc_samples: If int, uses an MC approximation with the specified
                number of samples. If None, uses the exact Hessian. Defaults to None.
            chunks: Maximum sequential split of the computation. Default: ``1``.
                Only used if mc_samples is specified.
            subsampling: Indices of active samples. ``None`` uses all samples.

        Returns:
            Hessian with respect to the input. Has shape
            ``[N, A, B, ..., N, A, B, ...]`` where ``N`` is the batch size or number
            of active samples when sub-sampling is used, and ``[A, B, ...]`` are the
            input's feature dimensions.
        """
        self.store_forward_io()

        if mc_samples is not None:
            chunk_samples = chunk_sizes(mc_samples, chunks)
            chunk_weights = [samples / mc_samples for samples in chunk_samples]

            individual_hessians: Tensor = sum(
                weight * self._sample_hessians_from_sqrt(
                    self.problem.derivative.sqrt_hessian_sampled(
                        self.problem.module,
                        None,
                        None,
                        mc_samples=samples,
                        subsampling=subsampling,
                    ))
                for weight, samples in zip(chunk_weights, chunk_samples))
        else:
            sqrt_hessian = self.problem.derivative.sqrt_hessian(
                self.problem.module, None, None, subsampling=subsampling)
            individual_hessians = self._sample_hessians_from_sqrt(sqrt_hessian)

        input0 = subsample(self.problem.module.input0, subsampling=subsampling)
        return self._embed_sample_hessians(individual_hessians, input0)
Example #26
    def _weight_ih_l0_jac_t_mat_prod(
        self,
        module: LSTM,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        self._check_parameters(module)

        IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
            module, mat, subsampling=subsampling
        )
        return einsum(
            f"vnth,nti->v{'' if sum_batch else 'n'}hi",
            IFGO_prod,
            subsample(module.input0, dim=0, subsampling=subsampling),
        )
Example #27
File: rnn.py  Project: f-dangel/backpack
    def _a_jac_t_mat_prod(
        cls,
        module: RNN,
        weight_hh_l0: Tensor,
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Calculates jacobian vector product wrt a.

        Args:
            module: RNN module
            weight_hh_l0: weight matrix hidden-to-hidden
            mat: matrix to multiply
            subsampling: subsampling

        Returns:
            jacobian vector product wrt a
        """
        V, N, T, H = mat.shape
        output = subsample(module.output, dim=0, subsampling=subsampling)
        a_jac_t_mat_prod: Tensor = zeros(V,
                                         N,
                                         T,
                                         H,
                                         device=mat.device,
                                         dtype=mat.dtype)
        for t in reversed(range(T)):
            if t == (T - 1):
                a_jac_t_mat_prod[:, :, t] = einsum("vnh,nh->vnh", mat[:, :, t],
                                                   1 - output[:, t]**2)
            else:
                a_jac_t_mat_prod[:, :, t] = einsum(
                    "vnh,nh->vnh",
                    mat[:, :, t] + einsum(
                        "vng,gh->vnh",
                        a_jac_t_mat_prod[:, :, t + 1],
                        weight_hh_l0,
                    ),
                    1 - output[:, t]**2,
                )
        return a_jac_t_mat_prod
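To make the backward loop explicit (added note, not from the source): with h_t = tanh(a_t) and a_t = W_ih x_t + b_ih + W_hh h_{t-1} + b_hh, the loop accumulates the transposed Jacobian of the full output sequence w.r.t. a_t applied to ``mat`` via the reverse-time recursion

\begin{aligned}
r_{T-1} &= \left(1 - h_{T-1}^{2}\right) \odot m_{T-1}, \\
r_t &= \left(1 - h_t^{2}\right) \odot \left(m_t + W_{hh}^{\top} r_{t+1}\right), \qquad t < T-1,
\end{aligned}

where m_t is ``mat[:, :, t]`` and r_t is ``a_jac_t_mat_prod[:, :, t]`` (zero-based time index as in the code).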
Example #28
    def _weight_hh_l0_jac_t_mat_prod(
        self,
        module: LSTM,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        self._check_parameters(module)
        _, N, _, H = mat.shape
        IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
            module, mat, subsampling=subsampling
        )

        subsampled_output = subsample(module.output, dim=0, subsampling=subsampling)
        single_step = zeros(N, 1, H, device=mat.device, dtype=mat.dtype)
        return einsum(
            f"vnth,ntg->v{'' if sum_batch else 'n'}hg",
            IFGO_prod,
            cat([single_step, subsampled_output[:, :-1]], dim=1),
        )
Example #29
def test_subsample():
    """Test slicing operations for sub-sampling a tensor's batch axis."""
    manual_seed(0)
    tensor = rand(3, 4, 5, 6)

    # leave tensor untouched when `subsampling = None`
    assert id(subsample(tensor)) == id(tensor)
    assert allclose(subsample(tensor), tensor)

    # slice along correct dimension
    idx = [2, 0]
    assert allclose(subsample(tensor, dim=0, subsampling=idx), tensor[idx])
    assert allclose(subsample(tensor, dim=1, subsampling=idx), tensor[:, idx])
    assert allclose(subsample(tensor, dim=2, subsampling=idx), tensor[:, :, idx])
    assert allclose(subsample(tensor, dim=3, subsampling=idx), tensor[:, :, :, idx])
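None of the examples on this page shows ``subsample`` itself. A minimal sketch consistent with the behaviour this test checks (return the tensor unchanged for ``subsampling=None``, otherwise index the requested axis) could look as follows; this is an illustrative reconstruction, not necessarily BackPACK's exact implementation.

from typing import List

from torch import Tensor


def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor:
    """Select samples along one axis; ``None`` keeps the full tensor (sketch)."""
    if subsampling is None:
        return tensor
    # slice all leading axes, then pick the requested indices at ``dim``
    return tensor[(slice(None),) * dim + (subsampling,)]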
Example #30
    def __higher_conv_weight_jac_t(
        self,
        module: Union[Conv1d, Conv2d, Conv3d],
        mat: Tensor,
        sum_batch: bool,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Requires higher-order convolution.

        The algorithm is proposed in:

            - Rochette, G., Manoel, A., & Tramel, E. W., Efficient per-example
              gradient computations in convolutional neural networks (2019).
        """
        G = module.groups
        V = mat.shape[0]
        C_out = module.output.shape[1]
        N = module.output.shape[0] if subsampling is None else len(subsampling)
        C_in = module.input0.shape[1]

        higher_conv_func = get_conv_function(self.conv_dims + 1)

        spatial_dim = (C_in // G,) + module.input0.shape[2:]
        spatial_dim_axis = (1, V) + tuple([1] * (self.conv_dims + 1))
        spatial_dim_new = (C_in // G,) + module.weight.shape[2:]

        # Reshape to extract groups from the convolutional layer
        # Channels are seen as an extra spatial dimension with kernel size 1
        input_conv = (
            subsample(module.input0, subsampling=subsampling)
            .reshape(1, N * G, *spatial_dim)
            .repeat(*spatial_dim_axis)
        )
        # Compute convolution between input and output; the batch size is seen
        # as channels, taking advantage of the `groups` argument
        mat_conv = rearrange(mat, "v n c ... -> (v n c) ...").unsqueeze(1).unsqueeze(2)

        stride = (1, *module.stride)
        dilation = (1, *module.dilation)
        padding = (0, *module.padding)

        conv = higher_conv_func(
            input_conv,
            mat_conv,
            groups=V * N * G,
            stride=dilation,
            dilation=stride,
            padding=padding,
        ).squeeze(0)

        # Because of rounding shapes when using non-default stride or dilation,
        # convolution result must be truncated to convolution kernel size
        for axis in range(2, 2 + self.conv_dims):
            conv = conv.narrow(axis, 0, module.weight.shape[axis])

        new_shape = [V, N, C_out, *spatial_dim_new]
        weight_grad = conv.reshape(*new_shape)

        if sum_batch:
            weight_grad = weight_grad.sum(1)

        return weight_grad