Example #1
import torch
import torch.fft as fft
from typing import cast


def spectrum_to_basis(spectrum: torch.Tensor,
                      l2_normalize: bool = True) -> torch.Tensor:
    """Convert a spectrum matrix to a Fourier basis via 2D inverse FFT. The returned basis has shape (H, W).

    Note:
        - Currently only the case H == W is supported; if H != W, the returned basis may be wrong.
        - To apply a 2D FFT, the ``dim`` argument of ``torch.fft.irfftn`` should be ``(-2, -1)``.

    Args:
        spectrum (torch.Tensor): 2D spectrum matrix of shape (H, W//2+1),
                                 where (H, W) is the size of the 2D Fourier basis to produce.
        l2_normalize (bool): If True, the basis is L2-normalized.

    Returns:
        torch.Tensor: 2D Fourier basis.

    """
    assert len(spectrum.size()) == 2
    H = spectrum.size(-2)  # currently, only consider the case H==W
    basis = fft.irfftn(spectrum, s=(H, H), dim=(-2, -1))

    if l2_normalize:
        return cast(torch.Tensor, basis / basis.norm(dim=(-2, -1))[None, None])
    else:
        return cast(torch.Tensor, basis)
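
A minimal usage sketch (assuming the imports above): a one-hot spectrum selects a single Fourier basis image.

import torch

# One-hot spectrum of shape (H, W//2+1) with H == W == 32:
# the nonzero entry at (2, 3) selects the basis with that frequency pair.
spectrum = torch.zeros(32, 17)
spectrum[2, 3] = 1.0

basis = spectrum_to_basis(spectrum)   # shape (32, 32)
print(basis.shape, basis.norm())      # norm is 1.0 after L2 normalization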
Example #2
def cifft2(a, signal_sizes=None):
    """Do inverse FFT corresponding to cfft2."""

    b_in = torch.view_as_complex(irfftshift2(a))

    s = [-1, 2 * b_in.size(-1) - 1] if signal_sizes is None else signal_sizes
    return torch_fft.irfftn(b_in, s=s, dim=[-2, -1])
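
`irfftshift2` is a helper from the surrounding library (pytracking's fourier module) that undoes the shift applied by the corresponding `cfft2`. A minimal sketch of why passing explicit `signal_sizes` matters: `irfftn` cannot recover an odd last dimension unless told its length.

import torch
import torch.fft as torch_fft

x = torch.randn(5, 7)                       # odd last dimension
X = torch_fft.rfftn(x, dim=[-2, -1])

# Without s, irfftn assumes an even last dim and returns shape (5, 6):
bad = torch_fft.irfftn(X, dim=[-2, -1])
# With explicit sizes the round trip is exact:
good = torch_fft.irfftn(X, s=x.shape, dim=[-2, -1])
print(bad.shape, torch.allclose(good, x, atol=1e-6))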
Example #3
def tikhonov_filter(s, *, lmbda=1.0, npd=16, dtype=torch.float32):
    r"""Lowpass filter based on Tikhonov regularization.

    Lowpass filter image(s) and return low and high frequency
    components, consisting of the lowpass filtered image and its
    difference with the input image. The lowpass filter is equivalent to
    Tikhonov regularization with `lmbda` as the regularization parameter
    and a discrete gradient as the operator in the regularization term,
    i.e. the lowpass component is the solution to

    .. math::
      \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s}
      \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;,

    where :math:`\mathbf{s}` is the input image, :math:`\lambda` is the
    regularization parameter, and :math:`G_i` is an operator that
    computes the discrete gradient along image axis :math:`i`. Once the
    lowpass component :math:`\mathbf{x}` has been computed, the highpass
    component is just :math:`\mathbf{s} - \mathbf{x}`.

    Parameters
    ----------
    s : array_like
      Input image or array of images.
    lmbda : float, optional (default=1.0)
      Regularization parameter controlling lowpass filtering.
    npd : int, optional (default=16)
      Number of samples to pad at image boundaries.
    dtype : torch.dtype, optional (default=torch.float32)
      Data type of the gradient filters.

    Returns
    -------
    slp : array_like
      Lowpass image or array of images.
    shp : array_like
      Highpass image or array of images.
    """

    grv = torch.tensor([-1.0, 1.0], dtype=dtype, device=s.device).reshape([2, 1])
    gcv = torch.tensor([-1.0, 1.0], dtype=dtype, device=s.device).reshape([1, 2])
    fftopt = {"s": (s.shape[0] + 2 * npd, s.shape[1] + 2 * npd), "dim": (0, 1)}
    Gr = tfft.rfftn(grv, **fftopt)
    Gc = tfft.rfftn(gcv, **fftopt)
    A = 1.0 + lmbda * (torch.conj(Gr) * Gr + torch.conj(Gc) * Gc).real
    if s.ndim > 2:
        A = A[(slice(None), ) * 2 + (np.newaxis, ) * (s.ndim - 2)]
    fill = ((npd, npd), ) * 2 + ((0, 0), ) * (s.ndim - 2)
    # torch.nn.functional.pad lacks a symmetric mode, so pad via NumPy and move back.
    snp = np.pad(s.cpu().numpy(), fill, 'symmetric')
    sp = torch.from_numpy(snp).to(s.device)
    spshp = sp.shape
    sp = tfft.rfftn(sp, dim=(0, 1))
    sp /= A
    sp = tfft.irfftn(sp, s=spshp[0:2], dim=(0, 1))
    slp = sp[npd:(sp.shape[0] - npd), npd:(sp.shape[1] - npd)]
    shp = s - slp
    return slp, shp
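
The snippet assumes module-level imports `numpy as np` and `torch.fft as tfft`. A minimal usage sketch on a random image; the two returned components sum back to the input.

import numpy as np
import torch
import torch.fft as tfft

img = torch.rand(128, 128)
slp, shp = tikhonov_filter(img, lmbda=5.0)

# Lowpass + highpass reconstructs the input; larger lmbda smooths more.
print(torch.allclose(slp + shp, img, atol=1e-5))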
Example #4
    def forward(
        self,
        signal: Tensor,
        kernel: Tensor,
        bias: Tensor = None,
        padding: Union[int, Iterable[int]] = 0,
        stride: Union[int, Iterable[int]] = 1,
        groups: int = 1,
    ) -> Tensor:
        # Cast padding & stride to tuples.
        padding_ = self.to_ntuple(padding, n=signal.ndim - 2)
        stride_ = self.to_ntuple(stride, n=signal.ndim - 2)

        # Pad the input signal & kernel tensors
        signal_padding = [p for p in padding_[::-1] for _ in range(2)]
        signal = f.pad(signal, signal_padding)

        # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
        # have *even* length.  Just pad with one more zero if the final dimension is odd.
        if signal.size(-1) % 2 != 0:
            signal_ = f.pad(signal, [0, 1])
        else:
            signal_ = signal

        kernel_padding = [
            pad for i in reversed(range(2, signal_.ndim))
            for pad in [0, signal_.size(i) - kernel.size(i)]
        ]
        padded_kernel = f.pad(kernel, kernel_padding)

        # Perform fourier convolution -- FFT, matrix multiply, then IFFT
        signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
        kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))

        # Conjugate the kernel spectrum so the product implements cross-correlation,
        # matching the semantics of PyTorch's direct convolution.
        kernel_fr.imag *= -1
        output_fr = self.complex_matmul(signal_fr, kernel_fr, groups=groups)
        output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

        # Remove extra padded values
        crop_slices = [slice(0, output.size(0)),
                       slice(0, output.size(1))] + [
                           slice(0, (signal.size(i) - kernel.size(i) + 1),
                                 stride_[i - 2])
                           for i in range(2, signal.ndim)
                       ]
        output = output[crop_slices].contiguous()

        # Optionally, add a bias term before returning.
        if bias is not None:
            bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
            output += bias.view(bias_shape)

        return output
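
The snippet assumes a `complex_matmul` helper (here a method). A plausible sketch of what such a helper typically does, under the assumption that it contracts the per-group input channels of the two spectra (hypothetical implementation, not the repo's verbatim code):

import torch
from torch import Tensor

def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Multiply spectra a (B, Cin, ...) and b (Cout, Cin/groups, ...),
    summing over the per-group input channels."""
    a = a.view(a.size(0), groups, -1, *a.shape[2:])      # (B, g, Cin/g, ...)
    b = b.view(groups, -1, b.size(1), *b.shape[2:])      # (g, Cout/g, Cin/g, ...)
    # Broadcast multiply over the frequency axes, contract the channel axis.
    out = torch.einsum("bgi...,goi...->bgo...", a, b)
    return out.reshape(out.size(0), -1, *out.shape[3:])  # (B, Cout, ...)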
Example #5
    def forward(self, x):
        # Extract features and apply the cosine window to suppress boundary effects.
        x = self.feature(x) * self.config.cos_window

        xf = fft.rfftn(x, dim=[-2, -1])

        # Cross-correlate the search features with the stored template spectrum.
        kxzf = torch.sum(xf * torch.conj(self.model_zf), dim=1, keepdim=True)

        # Correlation-filter response back in the spatial domain.
        response = fft.irfftn(kxzf * self.model_alphaf, dim=[-2, -1])

        return response
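
`model_zf` and `model_alphaf` are buffers set when the tracker is initialized on a template patch. A hedged sketch of that initialization (the names `feature`, `cos_window`, `z`, `label` are assumptions; in Examples #7 and #8 below the label arrives already rfft-transformed, so no extra transform is needed there):

import torch
import torch.fft as fft

def init_model(feature, cos_window, z, label, lambda0=1e-4):
    # Template feature spectrum.
    zf = fft.rfftn(feature(z) * cos_window, dim=[-2, -1])
    # Kernel auto-correlation energy, summed over channels.
    kzzf = torch.sum(zf.real ** 2 + zf.imag ** 2, dim=1, keepdim=True)
    # Closed-form ridge-regression filter in the Fourier domain.
    alphaf = fft.rfftn(label, dim=[-2, -1]) / (kzzf + lambda0)
    return zf, alphaf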
Example #6
def fft_conv(
    signal: Tensor, kernel: Tensor, bias: Tensor = None, padding: int = 0,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order ot mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Optional, Tensor) Bias tensor to add to the output.
        padding: (int) Number of zero samples to pad the input on the last dimension.

    Returns:
        (Tensor) Convolved tensor
    """
    # Pad the input signal & kernel tensors
    signal_padding = (signal.ndim - 2) * [padding, padding]
    signal = f.pad(signal, signal_padding)
    kernel_padding = [
        pad for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT.
    # Note: unlike the other variants in this list, this version does not pad an
    # odd final dimension to even length, so irfftn can return one sample fewer.
    signal_fr = rfftn(signal, dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))

    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values
    crop_slices = [slice(0, output.shape[0]), slice(0, output.shape[1])] + [
        slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in range(2, signal.ndim)
    ]
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
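
A quick numerical check of `fft_conv` against PyTorch's direct convolution (a sketch; assumes a `complex_matmul` like the one sketched after Example #4, and an even padded length so the odd-dimension caveat above does not apply):

import torch
import torch.nn.functional as F

signal = torch.randn(2, 3, 64)          # (batch, channels, length)
kernel = torch.randn(4, 3, 9)
bias = torch.randn(4)

out_fft = fft_conv(signal, kernel, bias=bias, padding=4)
out_ref = F.conv1d(signal, kernel, bias=bias, padding=4)
print(torch.allclose(out_fft, out_ref, atol=1e-4))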
Example #7
def fft_new(z, x, label):
    zf = fft.rfftn(z, dim=[-2, -1])
    xf = fft.rfftn(x, dim=[-2, -1])

    # R[batch, 1, 121, 61]
    kzzf = torch.sum(torch.real(zf)**2 + torch.imag(zf)**2,
                     dim=1,
                     keepdim=True)

    # C[batch, 1, 121, 61]
    t = xf * torch.conj(zf)
    kxzf = torch.sum(t, dim=1, keepdim=True)

    # [batch, 1, 121, 61]; lambda0 is a module-level ridge-regularization constant
    alphaf = label.to(device=z.device) / (kzzf + lambda0)

    # R[batch, 1, 121, 121]
    return fft.irfftn(kxzf * alphaf, s=[121, 121], dim=[-2, -1])
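
The `label` passed in is the regression target already in the rfft domain. A hedged sketch of building such a target from a wrapped 2D Gaussian (sizes chosen to match the shape comments above; the helper name is an assumption):

import torch
import torch.fft as fft

def gaussian_label(size=121, sigma=2.0):
    # Gaussian peak wrapped to the top-left corner (circular-shift convention).
    coords = torch.arange(size, dtype=torch.float32)
    dist = torch.minimum(coords, size - coords) ** 2   # wrapped squared distance
    g = torch.exp(-0.5 * (dist[:, None] + dist[None, :]) / sigma ** 2)
    # rfft2 -> shape (121, 61), matching the comments above.
    return fft.rfftn(g, dim=[-2, -1])[None, None]      # [1, 1, 121, 61]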
Example #8
    def forward(self, template, search, label):
        # Template shape: R[batch, 32, 121, 121]
        # Search shape:   R[batch, 32, 121, 121]
        # Label shape:    R[batch,  1, 121,  61]

        # zf & xf shape: C[batch, 32, 121, 61]
        zf = fft.rfftn(template, dim=[-2, -1])
        xf = fft.rfftn(search, dim=[-2, -1])

        # R[batch, 1, 121, 61]
        kzzf = torch.sum(zf.real**2 + zf.imag**2, dim=1, keepdim=True)

        # C[batch, 1, 121, 61]
        t = xf * torch.conj(zf)
        kxzf = torch.sum(t, dim=1, keepdim=True)

        # C[batch, 1, 121, 61]
        alphaf = label.to(device=template.device) / (kzzf + self.lambda0)

        # R[batch, 1, 121, 121]
        response = fft.irfftn(kxzf * alphaf, s=[121, 121], dim=[-2, -1])
        return response
Example #9
def _fft_conv_transposend(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    output_padding: Tuple[int, ...],
    groups: int,
    dilation: Tuple[int, ...],
) -> Tensor:

    output_size = _conv_transpose_shape(input.shape[2:], weight.shape[2:],
                                        stride, padding, output_padding,
                                        dilation)
    padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_output_size, weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # Choose an FFT size divisible by both stride and dilation; the factor is
        # doubled on the one-sided (rfft) dimension so that the full-spectrum
        # length stays divisible as well.
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size // st)
        weight_s.append(s_size // d)

    X = rfft(input, n=s[-1])
    W = rfft(weight, n=weight_s[-1])

    # Transposed convolution with stride upsamples the input; in the frequency
    # domain this amounts to tiling the one-sided spectrum with alternately
    # conjugated copies (Hermitian symmetry of a real signal's spectrum).
    if stride[-1] > 1:
        X_neg_freq = X.flip(-1)[..., 1:]
        X_neg_freq.imag.mul_(-1)

        tmp = [X]
        for i in range(1, stride[-1]):
            if i % 2:
                tmp.append(X_neg_freq)
            else:
                tmp.append(X[..., 1:])

        X = torch.cat(tmp, -1)

    # The same spectrum-tiling trick implements dilation of the kernel.
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(s) > 1:
        X = fftn(X, s=s[:-1], dim=tuple(range(2, X.ndim - 1)))
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + stride[:-1] + (1, )
        if sum(repeats) > X.ndim:
            X = X.repeat(*repeats)

        repeats = (1, 1) + dilation[:-1] + (1, )
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)

    Y = _complex_matmul(X, W, groups, True)

    output = irfftn(Y, dim=tuple(range(2, Y.ndim)))

    # Remove extra padded values
    index = (slice(None), ) * 2 + tuple(
        slice(p, o + p) for p, o in zip(padding, output_size))
    output = output[index].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
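
`_lcm` and `_conv_transpose_shape` are private helpers of the surrounding module. A plausible sketch of both, following the standard transposed-convolution output formula (assumed, not the module's verbatim code):

import math
from typing import Tuple

def _lcm(a: int, b: int) -> int:
    return a * b // math.gcd(a, b)

def _conv_transpose_shape(
    input_shape: Tuple[int, ...],
    kernel_shape: Tuple[int, ...],
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    output_padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
) -> Tuple[int, ...]:
    # Standard transposed-convolution output size, per the PyTorch docs.
    return tuple(
        (i - 1) * s - 2 * p + d * (k - 1) + op + 1
        for i, k, s, p, op, d in zip(
            input_shape, kernel_shape, stride, padding, output_padding, dilation)
    )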
Example #10
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order ot mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int]) Number of zero samples to pad the
            input on the last dimension.
        stride: (Union[int, Iterable[int]) Stride size for computing output values.

    Returns:
        (Tensor) Convolved tensor
    """
    # Cast padding & stride to tuples.
    padding_ = to_ntuple(padding, n=signal.ndim - 2)
    stride_ = to_ntuple(stride, n=signal.ndim - 2)

    # Pad the input signal & kernel tensors
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length.  Just pad with one more zero if the final dimension is odd.
    if signal.size(-1) % 2 != 0:
        signal_ = f.pad(signal, [0, 1])
    else:
        signal_ = signal

    kernel_padding = [
        pad for i in reversed(range(2, signal_.ndim))
        for pad in [0, signal_.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))

    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values
    crop_slices = [slice(
        0, output.size(0)), slice(0, output.size(1))] + [
            slice(0, (signal.size(i) - kernel.size(i) + 1), stride_[i - 2])
            for i in range(2, signal.ndim)
        ]
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
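
`to_ntuple` broadcasts an int or an iterable to an n-tuple. A minimal sketch of the assumed helper:

from typing import Iterable, Tuple, Union

def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    """Cast `val` to an n-tuple: ints are repeated, iterables length-checked."""
    if isinstance(val, Iterable):
        out = tuple(val)
        assert len(out) == n, f"Cannot cast {val} to a tuple of length {n}."
        return out
    return n * (val,)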
Example #11
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
    padding_mode: str = "constant",
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order ot mimic the PyTorch direct convolution).
    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int]) Number of zero samples to pad the
            input on the last dimension.
        stride: (Union[int, Iterable[int]) Stride size for computing output values.
        groups: (Union[int, Iterable[int]])
        padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
                      reflection not available for 3d.
    Returns:
        (Tensor) Convolved tensor
    """
    # Cast stride to tuple.
    stride_ = to_ntuple(stride, n=signal.ndim - 2)

    if padding != "same":
        padding_ = to_ntuple(padding, n=signal.ndim - 2)
    else:
        # "same": pad by half the kernel size on each spatial dimension.
        padding_ = [k // 2 for k in kernel.shape[2:]]

    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    # Pad the input signal & kernel tensors
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length.  Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # original signal size without padding to even
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    kernel_padding = [
        pad for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]

    padded_kernel = f.pad(kernel, kernel_padding)
    assert (
        padded_kernel.shape[1:] == signal.shape[1:]
    ), f"padded kernel shape {padded_kernel.shape} not equal to signal shape {signal.shape}"

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    # signal = signal.reshape(signal.size(0), groups, -1, *signal.shape[2:])
    signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))

    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values
    crop_slices = [slice(None), slice(None)] + [
        slice(
            0,
            (signal_size[i] - kernel.size(i) + (kernel.size(i) % 2)),
            stride_[i - 2],
        ) for i in range(2, signal.ndim)
    ]
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
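
A shape check of the "same" mode (a sketch; assumes `to_ntuple` and `complex_matmul` as above):

import torch

signal = torch.randn(1, 3, 32, 32)
kernel = torch.randn(8, 3, 5, 5)

out = fft_conv(signal, kernel, padding="same", padding_mode="reflect")
print(out.shape)   # torch.Size([1, 8, 32, 32]) -- spatial size preserved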
Example #12
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO."""

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1, 1, 1, 1)

    if getattr(params, 'reg_window_square', False):
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        wrg = torch.arange(-int((sz[0] - 1) / 2),
                           int(sz[0] / 2 + 1),
                           dtype=torch.float32).view(1, 1, -1, 1)
        wcg = torch.arange(-int((sz[1] - 1) / 2),
                           int(sz[1] / 2 + 1),
                           dtype=torch.float32).view(1, 1, 1, -1)
    else:
        wrg = torch.cat([
            torch.arange(0, int(sz[0] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, -1, 1)
        wcg = torch.cat([
            torch.arange(0, int(sz[1] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, 1, -1)

    # Construct regularization window
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
                 (torch.abs(wrg / reg_scale[0]) ** params.reg_window_power +
                  torch.abs(wcg / reg_scale[1]) ** params.reg_window_power) + params.reg_window_min

    # Compute DFT and enforce sparsity
    reg_window_dft = torch.view_as_real(
        torch_fft.rfftn(reg_window, dim=[-2, -1])) / sz.prod()
    reg_window_dft_abs = complex.abs(reg_window_dft)
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold *
                   reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum
    reg_window_sparse = torch_fft.irfftn(torch.view_as_complex(reg_window_dft),
                                         s=sz.long().tolist(),
                                         dim=[-2, -1])
    # Adjust the DC (real) component so the spatial window's minimum is reg_window_min.
    reg_window_dft[
        0, 0, 0, 0,
        0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Crop the DFT to its nonzero support
    max_inds, _ = reg_window_dft.nonzero(as_tuple=False).max(dim=0)
    mid_ind = int((reg_window_dft.shape[2] - 1) / 2)
    top = max_inds[-2].item() + 1
    bottom = 2 * mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        reg_window_dft = torch.cat(
            [reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft
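
`complex` and `fourier` are pytracking utility modules, so the call below is shown commented out. The `params` object must expose the attributes read by the function; a minimal illustrative configuration (values are placeholders, not tuned):

from types import SimpleNamespace
import torch

params = SimpleNamespace(
    use_reg_window=True,
    reg_window_square=False,
    reg_window_centered=True,
    reg_window_min=1e-4,        # floor of the spatial window
    reg_window_edge=10.0,       # penalty at the patch border
    reg_window_power=2,         # polynomial order of the window
    reg_sparsity_threshold=0.05,
)

# Requires pytracking's complex/fourier modules on the import path:
# reg_filter = get_reg_filter(torch.tensor([250., 250.]),
#                             torch.tensor([100., 80.]), params)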
Example #13
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order ot mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int]) Number of zero samples to pad the
            input on the last dimension.
        stride: (Union[int, Iterable[int]) Stride size for computing output values.

    Returns:
        (Tensor) Convolved tensor
    """
    # This variant takes `padding` as a ready-made tuple and fixes the stride.
    padding_ = padding
    stride_ = (1, 1)

    # Pad the input signal & kernel tensors
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length.  Just pad with one more zero if the final dimension is odd.
    if signal.size(-1) % 2 != 0:
        signal_ = f.pad(signal, [0, 1])
    else:
        signal_ = signal

    kernel_padding = [
        pad for i in reversed(range(2, signal_.ndim))
        for pad in [0, signal_.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))

    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values.  Unlike the standard variant, this one also
    # crops the user-specified padding back off, with the original code's
    # one-sample-wider window around the "valid" region.
    crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
        slice(padding_[i - 2] - 1,
              signal.size(i) - kernel.size(i) - padding_[i - 2] + 2,
              stride_[i - 2]) for i in range(2, signal.ndim)
    ]
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output