Example #1
File: ffts.py  Project: aisari/torchsar
def ifft(x, n=None, axis=0, norm="backward", shift=False):
    """IFFT in torchsar

    IFFT in torchsar. Since ifft in torch only supports complex-to-complex transforms,
    a real input is lifted to complex by inserting an all-zero imaginary part
    (``torch.stack((x, torch.zeros_like(x)), dim=-1)``); alternatively you can use
    torch's irfft.

    Parameters
    ----------
    x : {torch array}
        Both the complex and the real representation are supported. When :attr:`x`
        is given in the real representation (last dimension of size 2 holding the
        real and imaginary parts), it is viewed as a complex tensor for the IFFT
        and converted back to the real representation afterwards.
    n : int, optional
        number of ifft points (the default is None, which means the size of the
        transformed axis)
    axis : int, optional
        axis of the ifft (the default is 0, i.e. the first dimension)
    norm : str, optional
        Normalization mode. For the backward transform (ifft()), these correspond to:
        - "forward" - no normalization
        - "backward" - normalize by ``1/n`` (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        the IFFT result, a torch array with the same representation (real or complex) as :attr:`x`

    Raises
    ------
    ValueError
        Raised when the number of IFFT points is smaller than the signal size.
    """

    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            axis += 1
    else:
        realflag = False

    if shift:
        y = thfft.ifftshift(thfft.ifft(thfft.ifftshift(x, dim=axis),
                                       n=n,
                                       dim=axis,
                                       norm=norm),
                            dim=axis)
    else:
        y = thfft.ifft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y
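A minimal usage sketch for the ifft above (not from the torchsar project itself), assuming torch >= 1.8 imported as th and the function in scope; it transforms the real (last-dimension-2) representation and cross-checks against torch.fft.ifft on the complex view.

import torch as th

x = th.randn(4, 128, 2)                  # real representation: last dim = (real, imag)
y = ifft(x, n=128, axis=1, norm="backward", shift=False)

# Reference result computed directly on the complex view of the same data.
ref = th.fft.ifft(th.view_as_complex(x), n=128, dim=1, norm="backward")
assert th.allclose(th.view_as_complex(y), ref, atol=1e-6)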
Example #2
    def forward(self, class_one, class_two):
        batch_size, depth, height, width = class_one.size()

        class_one_flat = class_one.permute(0, 2, 3, 1).contiguous().view(
            -1, self.input_dim1)
        class_two_flat = class_two.permute(0, 2, 3, 1).contiguous().view(
            -1, self.input_dim2)

        sketch_1 = class_one_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = class_two_flat.mm(self.sparse_sketch_matrix2)

        fft1_real = afft.fft(sketch_1)
        fft1_imag = afft.fft(Variable(torch.zeros(sketch_1.size())).cuda())
        fft2_real = afft.fft(sketch_2)
        fft2_imag = afft.fft(Variable(torch.zeros(sketch_2.size())).cuda())

        fft_product_real = fft1_real.mul(fft2_real) - fft1_imag.mul(fft2_imag)
        fft_product_imag = fft1_real.mul(fft2_imag) + fft1_imag.mul(fft2_real)

        cbp_flat = afft.ifft(fft_product_real)
        #cbd_flat_2 = afft.ifft(fft_product_imag)[0]

        cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

        if self.sum_pool:
            cbp = cbp.sum(dim=1).sum(dim=1)

        return cbp.float()
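For reference, the frequency-domain product used above can be written with the torch.fft module directly; this is a hedged sketch rather than the project's code, and sketch_1/sketch_2 are assumed to be the real-valued sketch projections computed above.

import torch

def fft_elementwise_product(sketch_1, sketch_2):
    # Complex FFT along the feature dimension, elementwise product in the
    # frequency domain, real part of the inverse transform.
    f1 = torch.fft.fft(sketch_1, dim=-1)
    f2 = torch.fft.fft(sketch_2, dim=-1)
    return torch.fft.ifft(f1 * f2, dim=-1).real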
Example #3
def _istft(x,
           n_fft,
           ola_weight,
           win_length,
           window,
           hop_length,
           center,
           normalized,
           onesided,
           pad_mode,
           return_complex,
           norm_envelope=None):
    """
    A helper function to do istft.
    """
    if onesided:
        x = fft.irfft(x,
                      n=n_fft,
                      dim=-2,
                      norm='ortho' if normalized else 'backward')
    else:
        x = fft.ifft(x,
                     n=n_fft,
                     dim=-2,
                     norm='ortho' if normalized else 'backward').real

    x, norm_envelope = _ola(x,
                            hop_length,
                            ola_weight,
                            padding=n_fft // 2 if center else 0,
                            norm_envelope=norm_envelope)
    return x, norm_envelope
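A hedged sanity check for the per-frame inverse used above: with normalized=True, the onesided branch is an orthonormal irfft of an orthonormal rfft, so a single frame round-trips exactly up to floating-point error (the project's _ola overlap-add step is not reproduced here).

import torch

n_fft = 512
frames = torch.randn(3, n_fft, 7)                         # (batch, n_fft, n_frames)
spec = torch.fft.rfft(frames, n=n_fft, dim=-2, norm='ortho')
rec = torch.fft.irfft(spec, n=n_fft, dim=-2, norm='ortho')
assert torch.allclose(rec, frames, atol=1e-5)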
Example #4
def hilbert(x, ndft=None):
    r"""Analytic signal of x.

    Return the analytic signal of a real signal x, x + j\hat{x}, where \hat{x}
    is the Hilbert transform of x.

    Parameters
    ----------
    x: torch.Tensor
        Audio signal to be analyzed.
        Always assumes x is real, and x.shape[-1] is the signal length.
    ndft: int, optional
        FFT length. If given, it must be greater than the signal length; the
        signal is zero-padded to this length before the transform.

    Returns
    -------
    out: torch.Tensor
        Complex analytic signal; out.shape matches the (possibly zero-padded)
        input shape.

    """
    if ndft is None:
        sig = x
    else:
        assert ndft > x.size(-1)
        sig = F.pad(x, (0, ndft - x.size(-1)))
    xspec = fft.fft(sig)
    siglen = sig.size(-1)
    hh = torch.zeros(siglen, dtype=sig.dtype, device=sig.device)
    if siglen % 2 == 0:
        hh[0] = hh[siglen // 2] = 1
        hh[1:siglen // 2] = 2
    else:
        hh[0] = 1
        hh[1:(siglen + 1) // 2] = 2

    return fft.ifft(xspec * hh)
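Hypothetical usage of the function above (it assumes torch.fft imported as fft, as in the snippet): the magnitude of the analytic signal gives the envelope of a narrowband signal, which is close to 1 here away from the edges.

import math
import torch

t = torch.linspace(0.0, 1.0, 1024)
x = torch.sin(2 * math.pi * 50.0 * t)    # 50 Hz tone with unit amplitude
envelope = hilbert(x).abs()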
Example #5
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))

            # For 1.8+
            # pytorch now offers a complex64 data type
            kspace = torch.view_as_complex(kspace)
            kspace = ifftshift(kspace, dim=(0, 1))
            # norm=forward means no normalization
            target = ifft(kspace, dim=(0, 1), norm="forward")
            target = ifftshift(target, dim=(0, 1))

            # Plot images to confirm fft worked
            # t_img = complex_magnitude(target)
            # print(t_img.dtype, t_img.shape)
            # plt.imshow(t_img)
            # plt.show()
            # plt.imshow(target.real)
            # plt.show()

            # center crop and resize
            # target = torch.unsqueeze(target, dim=0)
            # target = center_crop(target, (128, 128))
            # target = torch.squeeze(target)

            # Crop out ends
            target = np.stack([target.real, target.imag], axis=-1)
            target = target[100:-100, 24:-24, :]

            # Downsample in image space
            shape = target.shape
            target = tf.image.resize(
                target,
                (IMG_H, IMG_W),
                method="lanczos5",
                # preserve_aspect_ratio=True,
                antialias=True,
            ).numpy()

            # Get kspace of cropped image
            target = torch.view_as_complex(torch.from_numpy(target))
            kspace = fftshift(target, dim=(0, 1))
            kspace = fft(kspace, dim=(0, 1))
            # Realign kspace to keep high freq signal in center
            # Note that original fastmri code did not do this...
            kspace = fftshift(kspace, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

        return kspace, target
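A hedged sketch of the centered-transform convention used in this loader (shift, transform, shift back), written with torch.fft.fft2/ifft2; with matching "ortho" norms the image -> k-space -> image round trip is the identity up to floating-point error, whereas the loader above uses norm="forward" for the inverse transform instead.

import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift

img = torch.randn(320, 320, dtype=torch.complex64)
kspace = fft2(ifftshift(img, dim=(0, 1)), dim=(0, 1), norm="ortho")
kspace = fftshift(kspace, dim=(0, 1))
back = ifft2(ifftshift(kspace, dim=(0, 1)), dim=(0, 1), norm="ortho")
back = fftshift(back, dim=(0, 1))
assert torch.allclose(back, img, atol=1e-4)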
Example #6
def _hilbert_transform(x):
    if torch.is_complex(x):
        raise ValueError("x must be real.")

    N = x.shape[-1]
    f = fft(x, N, dim=-1)
    h = torch.zeros_like(f)
    if N % 2 == 0:
        h[..., 0] = h[..., N // 2] = 1
        h[..., 1:N // 2] = 2
    else:
        h[..., 0] = 1
        h[..., 1:(N + 1) // 2] = 2

    return ifft(f * h, dim=-1)
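A hedged cross-check of the weighting mask above against SciPy's analytic-signal routine (SciPy is an extra dependency used only for validation here); both implement the same one-sided-spectrum trick.

import torch
from scipy.signal import hilbert as scipy_hilbert

x = torch.randn(2, 500)
ours = _hilbert_transform(x)
ref = torch.from_numpy(scipy_hilbert(x.numpy(), axis=-1)).to(ours.dtype)
assert (ours - ref).abs().max() < 1e-3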
Example #7
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
            kspace = torch.view_as_complex(kspace)
            kspace = ifftshift(kspace, dim=(0, 1))
            target = ifft(kspace, dim=(0, 1), norm="forward")
            target = ifftshift(target, dim=(0, 1))

            # transform
            target = torch.stack([target.real, target.imag])
            target = self.deform(target)  # outputs numpy
            target = torch.from_numpy(target)
            target = target.permute(1, 2, 0).contiguous()
            # center crop and resize
            # target = center_crop(target, (128, 128))
            # target = resize(target, (128,128))

            # Crop out ends
            target = target.numpy()[100:-100, 24:-24, :]
            # Downsample in image space
            target = tf.image.resize(
                target,
                (IMG_H, IMG_W),
                method="lanczos5",
                # preserve_aspect_ratio=True,
                antialias=True,
            ).numpy()

            # Making contiguous is necessary for complex view
            target = torch.from_numpy(target)
            target = target.contiguous()
            target = torch.view_as_complex(target)

            kspace = fftshift(target, dim=(0, 1))
            kspace = fft(kspace, dim=(0, 1))
            kspace = fftshift(kspace, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

        return kspace, target
Example #8
            output = fft_fn(output, 1)
            output = utils.movedim(output, -2, -j - ndim - 1)
        if norm == 'forward':
            output /= py.prod(s)

        # Make complex and move back dimension to its original position
        newdim = list(range(-ndim - 1, -1))
        output = utils.movedim(output, newdim, dim)
        if _torch_has_complex:
            output = torch.view_as_complex(output)

        return output


if _torch_has_fft_module:
    ifft = lambda *a, real=None, **k: fft_mod.ifft(*a, **k)
else:

    def ifft(input, n=None, dim=-1, norm='backward', real=None):
        """One dimensional discrete inverse Fourier transform.

        Parameters
        ----------
        input : tensor
            Input signal.
            If torch <= 1.5, the last dimension must be of length 2 and
            contain the real and imaginary parts of the signal, unless
            `real is True`.
        n : int, optional
            Signal length. If given, the input will either be zero-padded
            or trimmed to this length before computing the FFT.
Example #9
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int], dilation: Tuple[int],
                groups: int) -> Tensor:

    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride,
                              padding, dilation)
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        weight_s.append(s_size // d)

    X = rfftn(padded_input, s=s)

    W = rfft(weight, n=weight_s[-1])
    # handle dilation for the last dim
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        W.imag.mul_(-1)
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    Y = _complex_matmul(X, W, groups)

    # handle stride
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                Y = Y.reshape(*Y.shape[:i + 2], st, -1,
                              *Y.shape[i + 3:]).mean(i + 2)

            Y = ifft(Y, dim=i + 2)
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:],
                Y.stride())

    if stride[-1] > 1:
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2,
                                                 step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(
            -2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(
            -1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag),
                                              -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
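A minimal sketch of the convolution theorem behind _fft_convnd, restricted to stride=1, dilation=1, groups=1 (the routine above additionally handles those cases in the frequency domain): conjugating the kernel spectrum turns the circular product into a cross-correlation, which matches F.conv2d on the valid region.

import torch
import torch.nn.functional as F
from torch.fft import rfftn, irfftn

x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 5, 5)
out_h = x.size(-2) - w.size(-2) + 1
out_w = x.size(-1) - w.size(-1) + 1

X = rfftn(x, dim=(-2, -1))                                # (1, 3, 32, 17)
W = rfftn(w, s=x.shape[-2:], dim=(-2, -1))                # kernel zero-padded to the input size
Y = (X.unsqueeze(1) * W.conj().unsqueeze(0)).sum(dim=2)   # sum over input channels
y = irfftn(Y, s=x.shape[-2:], dim=(-2, -1))[..., :out_h, :out_w]

assert torch.allclose(y, F.conv2d(x, w), atol=1e-3)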
Example #10
def _ft_surrogate(x=None, f=None, eps=1, random_state=None):
    """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.

    Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
    and modified.

    MIT License

    Copyright (c) 2018 Clifford Lab

    Permission is hereby granted, free of charge, to any person obtaining a
    copy of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the
    Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    DEALINGS IN THE SOFTWARE.

    Parameters
    ----------
    x: torch.tensor, optional
        Single EEG channel signal in time space. Should not be passed if f is
        given. Defaults to None.
    f: torch.tensor, optional
        Fourier spectrum of a single EEG channel signal. Should not be passed
        if x is given. Defaults to None.
    eps: float, optional
        Float between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled: [0, `eps` * 2 * `pi`]. Defaults to 1.
    random_state: int | numpy.random.Generator, optional
        By default None.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """
    assert isinstance(
        eps,
        (Real, torch.FloatTensor, torch.cuda.FloatTensor)
    ) and 0 <= eps <= 1, f"eps must be a float between 0 and 1. Got {eps}."
    if f is None:
        assert x is not None, 'Neither x nor f provided.'
        f = fft(x.double(), dim=-1)
        device = x.device
    else:
        device = f.device
    n = f.shape[-1]
    random_phase = _new_random_fft_phase[n % 2](
        n,
        device=device,
        random_state=random_state
    )
    f_shifted = f * torch.exp(eps * random_phase)
    shifted = ifft(f_shifted, dim=-1)
    return shifted.real.float()
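A hedged sketch of the same augmentation written with a real FFT; it randomizes every rfft bin (including the DC and Nyquist bins, unlike the project's _new_random_fft_phase helper) and is meant only to illustrate the idea.

import math
import torch
from torch.fft import rfft, irfft

def ft_surrogate_sketch(x, eps=1.0, generator=None):
    spec = rfft(x.double(), dim=-1)
    # Uniform random phases in [0, eps * 2 * pi) applied to every frequency bin.
    phase = 2 * math.pi * eps * torch.rand(
        spec.shape, generator=generator, dtype=torch.double, device=x.device)
    return irfft(spec * torch.exp(1j * phase), n=x.shape[-1], dim=-1).float()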
Example #11
 def postprocess(self, X):
     Z = ifft(X, norm='ortho', dim=2)
     return F.normalize(Z).real
Example #12
def morlet_fft_convolution(x, log_scales, dtime, unpadding=0, density=True, omega0=6.0):
    """
    Calculates a Morlet continuous wavelet transform
    for a given signal across a range of frequencies

    Parameters:
    ===========
    x : torch.Tensor, shape (batch, channels, sequence)
        A batch of multichannel sequences
    log_scales : torch.Tensor, shape (1, 1, freqs, 1)
        A tensor of logarithmic scales
    dtime : float
        Change in time per sample. The inverse of the sampling frequency.
    unpadding : int
        The amount of padding to remove from each side of the sequence
    density : bool (default = True)
        Whether to normalize so the power spectrum sums to one.
        This effectively removes amplitude fluctuations.
    omega0 : float
        Dimensionless omega0 parameter for wavelet transform
    Returns:
    ========
    out : torch.Tensor, shape (batch, channels, freqs, sequence)
        The transformed signal.
    """

    n_sequence = x.shape[-1]

    # Pad with extra zero if needed
    pad_sequence = n_sequence % 2 != 0
    x = F.pad(x, (0, 1)) if pad_sequence else x

    # Set index to remove the extra zero if added
    idx0 = (n_sequence // 2) + unpadding
    idx1 = (
        (n_sequence // 2) + n_sequence - unpadding - 1
        if pad_sequence
        else (n_sequence // 2) + n_sequence - unpadding
    )
    x = F.pad(x, (n_sequence // 2, n_sequence // 2))

    # (batch, channels, sequence) -> (batch, channels, freqs, sequence)
    x = x.unsqueeze(-2)

    # Calculate the omega values
    n_padded = x.shape[-1]
    omegas = (
        -2
        * np.pi
        * torch.arange(start=-n_padded // 2, end=n_padded // 2, device=x.device)
        / (n_padded * dtime)
    )[
        None, None, None
    ]  # (sequence,) -> (batch, channels, freqs, sequence)

    # Fourier transform the padded signal
    x_hat = fftshift1d(fft.fft(x, dim=-1))

    # Calculate the wavelets
    morlet = morlet_conj_ft(omegas * log_scales.exp(), omega0)

    # Perform the wavelet transform
    convolved = (
        fft.ifft(morlet * x_hat, dim=-1)[..., idx0:idx1] * log_scales.mul(0.5).exp()
    )

    power = convolved.abs()

    # scale power to account for the disproportionately
    # large response at low frequencies
    # power_scale = (
    #    np.pi ** -0.25
    #    * np.exp(0.25 * (omega0 - np.sqrt(omega0 ** 2 + 2)) ** 2)
    #    / scales.mul(2).sqrt()
    # )
    log_power_scale = (
        -0.25 * LOG_PI
        + 0.25 * (omega0 - np.sqrt(omega0 ** 2 + 2)) ** 2
        - log_scales.add(LOG2).mul(0.5)
    )
    log_power_scaled = power.log() + log_power_scale
    log_total_power = log_power_scaled.logsumexp(
        (1, 2), keepdims=True
    )  # (channels, freqs)
    log_density = log_power_scaled - log_total_power
    return log_density.exp()
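A hypothetical call pattern for the transform above; fftshift1d, morlet_conj_ft, LOG_PI and LOG2 are assumed to be defined elsewhere in the project, and the scale formula below is only a rough Morlet scale-to-frequency conversion for omega0 = 6.

import math
import torch

fs = 100.0                                     # sampling frequency, Hz
x = torch.randn(8, 4, 1000)                    # (batch, channels, sequence)
freqs = torch.linspace(1.0, 25.0, 30)          # target frequencies, Hz
scales = 6.0 / (2 * math.pi * freqs)           # rough scale for each frequency
log_scales = scales.log().view(1, 1, -1, 1)    # (1, 1, freqs, 1) as documented
log_density = morlet_fft_convolution(x, log_scales, dtime=1.0 / fs)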
Example #13
 def irfft(x):
     return fft.ifft(x,
                     n=n_fft,
                     dim=-2,
                     norm='ortho' if normalized else 'backward').real
Example #14
    def split_step_solver(self, a, T, d, c, L):
        '''
        Parameters
        ----------
        a : TYPE: torch.int64 tensor of shape [batch_size, seq_len], optional
            DESCRIPTION: Pulse amplitude set.
        T : TYPE: float
            DESCRIPTION: Pulse width.
        d : TYPE: float
            DESCRIPTION: dispersion coefficient.
        c : TYPE: float
            DESCRIPTION: nonlinearity coefficient.
        L : TYPE: float
            DESCRIPTION: End of transmission line (z_end).

        Returns
        -------
        t : TYPE: torch.float32 tensor of shape [dim_t]
            DESCRIPTION: Time points. See prepare_data description.
        z : TYPE: torch.float32 tensor of shape [dim_z]
            DESCRIPTION: Points in space. See prepare_data description.
        u : TYPE: torch.complex128 tensor of shape [batch_size, dim_z, dim_t].
            DESCRIPTION: Output of the split-step solution.
        '''
        z = torch.linspace(0, L, self.dim_z, device=self.device)

        tMax = self.t_end + 4 * sqrt(2 * (1 + L**2))
        tMin = -tMax

        dt = (tMax - tMin) / self.dim_t
        t = torch.linspace(tMin, tMax - dt, self.dim_t, device=self.device)

        # prepare frequencies
        dw = 2 * pi / (tMax - tMin)
        w = dw * torch.cat(
            (torch.arange(0, self.dim_t / 2 + 1, device=self.device),
             torch.arange(-self.dim_t / 2 + 1, 0, device=self.device)))

        # prepare linear propagator
        LP = torch.exp(-1j * d * self.dz / 2 * w**2)

        # Set initial condition
        u = torch.zeros(a.shape[0],
                        self.dim_z,
                        self.dim_t,
                        dtype=self.complex_type,
                        device=self.device)

        buf = self.Etanal(t, torch.tensor(0, device=self.device), a, T)
        u[:, 0, :] = buf

        n = 0
        # Numerical integration (split-step)
        for i in range(1, int(L / self.dz) + 1):
            buf = ifft(LP * fft(buf))
            buf = buf * torch.exp(1j * c * self.dz * buf.abs()**2)
            buf = ifft(LP * fft(buf))

            if i % self.z_stride == 0:
                n += 1
                u[:, n, :] = buf

        # Dispersion compensation procedure (back propagation D**(-1))
        if self.dispersion_compensate:
            zw = torch.mm(z.view(z.shape[-1], 1), w.view(1, w.shape[-1])**2)
            u = ifft(torch.exp(1j * d * zw) * fft(u))

        return t, z, u
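The loop above is a symmetric split-step (Strang) scheme; a hedged sketch of a single step, mirroring the loop body, with w being the angular-frequency grid prepared as above:

import torch
from torch.fft import fft, ifft

def split_step(u, w, d, c, dz):
    lin_half = torch.exp(-1j * d * dz / 2 * w**2)    # linear (dispersion) propagator for dz/2
    u = ifft(lin_half * fft(u))                      # half linear step in Fourier space
    u = u * torch.exp(1j * c * dz * u.abs()**2)      # full nonlinear phase rotation in time
    u = ifft(lin_half * fft(u))                      # second half linear step
    return u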
Example #15
def cwt(
    data: torch.Tensor,
    scales: Union[np.ndarray, torch.Tensor],  # type: ignore
    wavelet: Union[ContinuousWavelet, str],
    sampling_period: float = 1.0,
) -> Tuple[torch.Tensor, np.ndarray]:  # type: ignore
    """Compute the single dimensional continuous wavelet transform.

    This function is a PyTorch port of pywt.cwt as found at:
    https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py

    Args:
        data (torch.Tensor): The input tensor of shape [batch_size, time].
        scales (torch.Tensor or np.array):
            The wavelet scales to use. One can use
            ``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to
            determine which physical frequency ``f`` a scale corresponds to.
            Here, ``f`` is in hertz when the ``sampling_period`` is given in
            seconds.
        wavelet (ContinuousWavelet or str): The continuous wavelet to work with.
        sampling_period (float): Sampling period for the frequencies output (optional).
            The values computed for ``coefs`` are independent of the choice of
            ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
            period).

    Raises:
        ValueError: If a scale is too small for the input signal.

    Returns:
        Tuple[torch.Tensor, np.ndarray]: A tuple with the transformation matrix
            and frequencies in this order.
    """
    # accept array_like input; make a copy to ensure a contiguous array
    if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)
    if type(scales) is torch.Tensor:
        scales = scales.numpy()
    elif np.isscalar(scales):
        scales = np.array([scales])
    # if not np.isscalar(axis):
    #    raise np.AxisError("axis must be a scalar.")

    precision = 10
    int_psi, x = integrate_wavelet(wavelet, precision=precision)
    if type(wavelet) is ContinuousWavelet:
        int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
    int_psi = torch.tensor(int_psi, device=data.device)

    # convert int_psi, x to the same precision as the data
    x = np.asarray(x, dtype=data.cpu().numpy().real.dtype)

    size_scale0 = -1
    fft_data = None

    out = []
    for scale in scales:
        step = x[1] - x[0]
        j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
        j = j.astype(int)  # floor
        if j[-1] >= len(int_psi):
            j = np.extract(j < len(int_psi), j)
        int_psi_scale = int_psi[j].flip(0)

        # The padding is selected for:
        # - optimal FFT complexity
        # - to be larger than the two signals length to avoid circular
        #   convolution
        size_scale = _next_fast_len(data.shape[-1] + len(int_psi_scale) - 1)
        if size_scale != size_scale0:
            # Must recompute fft_data when the padding size changes.
            fft_data = fft(data, size_scale, dim=-1)
        size_scale0 = size_scale
        fft_wav = fft(int_psi_scale, size_scale, dim=-1)
        conv = ifft(fft_wav * fft_data, dim=-1)
        conv = conv[..., :data.shape[-1] + len(int_psi_scale) - 1]

        coef = -np.sqrt(scale) * torch.diff(conv, dim=-1)

        # transform axis is always -1
        d = (coef.shape[-1] - data.shape[-1]) / 2.0
        if d > 0:
            coef = coef[..., int(np.floor(d)):-int(np.ceil(d))]
        elif d < 0:
            raise ValueError("Selected scale of {} too small.".format(scale))

        out.append(coef)
    out_tensor = torch.stack(out)
    if type(wavelet) is Wavelet:
        out_tensor = out_tensor.real
    else:
        out_tensor = out_tensor if wavelet.complex_cwt else out_tensor.real

    frequencies = scale2frequency(wavelet, scales, precision)
    if np.isscalar(frequencies):
        frequencies = np.array([frequencies])
    frequencies /= sampling_period
    return out_tensor, frequencies
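Hypothetical usage, mirroring pywt.cwt ("morl" is a pywt wavelet name; the sampling period here is an assumption for illustration):

import numpy as np
import torch

signal = torch.randn(2, 1024)                  # (batch_size, time)
scales = np.geomspace(1, 128, num=32)
coefs, freqs = cwt(signal, scales, "morl", sampling_period=1.0 / 250.0)
# coefs: (n_scales, batch_size, time); freqs are in Hz for a 250 Hz sampling rate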