Code example #1
from typing import Tuple, Union

import torch
from torch import Tensor


def spmat_interp_adjoint(
    data: Tensor,
    interp_mats: Union[Tensor, Tuple[Tensor, Tensor]],
    grid_size: Tensor,
) -> Tensor:
    """Sparse matrix interpolation adjoint backend."""
    if not isinstance(interp_mats, tuple):
        raise TypeError("interp_mats must be a 2-tuple of (real_mat, imag_mat).")

    coef_mat_real, coef_mat_imag = interp_mats
    batch_size, num_coils = data.shape[:2]

    # sparse matrix multiply requires real
    data = torch.view_as_real(data)
    output_size = [batch_size, num_coils] + grid_size.tolist()

    # we have to do these transposes because torch.mm requires first to be spmatrix
    real_kdat = data.select(-1, 0).view(-1, data.shape[-2]).t().contiguous()
    imag_kdat = data.select(-1, 1).view(-1, data.shape[-2]).t().contiguous()
    coef_mat_real = coef_mat_real.t()
    coef_mat_imag = coef_mat_imag.t()

    # apply multiplies with complex conjugate
    image = torch.stack(
        [
            (torch.mm(coef_mat_real, real_kdat) +
             torch.mm(coef_mat_imag, imag_kdat)).t(),
            (torch.mm(coef_mat_real, imag_kdat) -
             torch.mm(coef_mat_imag, real_kdat)).t(),
        ],
        dim=-1,
    )

    return torch.view_as_complex(image).reshape(*output_size)
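
The stacked products above implement the conjugate-transpose (adjoint) multiply: for A = A_r + i*A_i and data y = y_r + i*y_i, A^H y = (A_r^T y_r + A_i^T y_i) + i*(A_r^T y_i - A_i^T y_r). A quick dense sanity check of that identity (sizes are arbitrary, not taken from the library):

import torch

A = torch.randn(5, 3, dtype=torch.complex64)     # interpolation matrix
y = torch.randn(5, dtype=torch.complex64)        # k-space data
ref = A.conj().t() @ y                           # native adjoint multiply

Ar, Ai, yr, yi = A.real, A.imag, y.real, y.imag
out = torch.complex(Ar.t() @ yr + Ai.t() @ yi,   # real part
                    Ar.t() @ yi - Ai.t() @ yr)   # imaginary part
print(torch.allclose(ref, out, atol=1e-5))       # True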
Code example #2
File: dct.py Project: Helmholtz-AI-Energy/HyDe
def dct(x: torch.Tensor, norm: str = "None"):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = torch.tensor(x_shape[-1], device=x.device)
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # if TORCH_VER >= 1.8:
    Vc = torch.view_as_real(torch.fft.fft(v))
    # else:
    #     Vc = torch.rfft(v, 1, onesided=False)

    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == "ortho":
        V[:, 0] /= torch.sqrt(4 * N)   # equivalent to np.sqrt(N) * 2
        V[:, 1:] /= torch.sqrt(2 * N)  # equivalent to np.sqrt(N / 2) * 2

    V = 2 * V.view(x_shape)

    return V
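
A quick usage sketch for the function above (it assumes ``numpy as np`` and ``torch`` are imported alongside the definition; sizes are arbitrary):

import numpy as np
import torch

x = torch.randn(8, 64)
X = dct(x, norm="ortho")      # DCT-II along the last dimension
print(X.shape)                # torch.Size([8, 64])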
Code example #3
    def forward(self, x: Tensor) -> Tensor:
        """STFT forward path
        Args:
            x (Tensor): audio waveform of
                shape (nb_samples, nb_channels, nb_timesteps)
        Returns:
            STFT (Tensor): complex stft of
                shape (nb_samples, nb_channels, nb_bins, nb_frames, complex=2)
                last axis is stacked real and imaginary
        """

        shape = x.size()
        nb_samples, nb_channels, nb_timesteps = shape

        # pack batch
        x = x.view(-1, shape[-1])

        complex_stft = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.n_hop,
            window=self.window,
            center=self.center,
            normalized=False,
            onesided=True,
            pad_mode="reflect",
            return_complex=True,
        )
        stft_f = torch.view_as_real(complex_stft)
        # unpack batch
        stft_f = stft_f.view(shape[:-1] + stft_f.shape[-3:])
        return stft_f
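
Outside the class, the same real-view packing can be sketched directly with ``torch.stft`` (a minimal illustration; the sizes below are arbitrary):

import torch

x = torch.randn(2, 4, 4096).view(-1, 4096)           # pack (samples, channels) into one batch dim
spec = torch.stft(x, n_fft=512, hop_length=256,
                  window=torch.hann_window(512), return_complex=True)
print(spec.shape)                                     # (8, 257, n_frames), complex dtype
print(torch.view_as_real(spec).shape)                 # (8, 257, n_frames, 2)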
Code example #4
    def grid_1d_from_2d(self, x, dx, vis, y):

        nvis, N = x.shape
        W = self.config['W']
        Dx = W * dx

        xref = torch.ceil((x - 0.5 * W * dx) / dx) * dx
        xndx = torch.arange(W, dtype=xref.dtype, device=xref.device)
        xg = xref + xndx * dx

        gcf_val = _gcf_kaiser(xg - x, Dx, self.beta).float()

        # Batch mm unsupported for complex yet for torch CUDA
        vis_ri = torch.view_as_real(vis)
        vis_r = vis_ri[:, :, -2]
        vis_i = vis_ri[:, :, -1]
        vis2_r = torch.matmul(vis_r[:, :, None], gcf_val[:, None, :])
        vis2_i = torch.matmul(vis_i[:, :, None], gcf_val[:, None, :])
        vis2 = torch.view_as_complex(torch.stack([vis2_r, vis2_i], dim=-1))

        # vis2 = torch.matmul(vis[:, :, None], gcf_val[:, None, :])
        vis2 = vis2.reshape(nvis, -1)
        x2 = xg.repeat(1, N)
        y2 = torch.repeat_interleave(y, W, dim=-1)
        return x2, vis2, y2
Code example #5
def _handle_complex(tensor):
    """
    Returns a real view of a tensor if complex dtype else just the tensor
    need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
    """
    return torch.view_as_real(tensor) if not isinstance(tensor,
                                                        torch.nn.UninitializedParameter) and tensor.is_complex() else tensor
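
A minimal check of the helper's behavior (assuming the function above is in scope; shapes chosen only for illustration):

import torch

p = torch.nn.Parameter(torch.randn(3, dtype=torch.complex64))
print(_handle_complex(p).shape)   # torch.Size([3, 2]) -- real view of the complex parameter
q = torch.nn.Parameter(torch.randn(3))
print(_handle_complex(q).shape)   # torch.Size([3]) -- real tensors pass through unchanged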
Code example #6
        def _test_all_gather(self,
                             group,
                             group_id,
                             rank,
                             cuda=False,
                             rank_to_GPU=None,
                             dtype=torch.float,
                             qtype=None):
            for dest in group:
                tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype)
                tensors = [
                    _build_tensor([dest + 1, dest + 1], -1, dtype=dtype)
                    for i in group
                ]
                expected_tensors = [
                    _build_tensor([dest + 1, dest + 1], i, dtype=dtype)
                    for i in group
                ]
                if cuda:
                    tensor = tensor.cuda(rank_to_GPU[rank][0])
                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
                if tensors[0].dtype == torch.complex64:
                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
                else:
                    tensor_shapes = [tensors[0].shape]
                allgather = quant.auto_quantize(dist.all_gather,
                                                qtype,
                                                quant_loss=None)
                allgather(tensors, tensor, group=group_id, async_op=False)

                for t1, t2 in zip(tensors, expected_tensors):
                    self.assertEqual(t1, t2)
Code example #7
File: ffts.py Project: aisari/torchsar
import torch as th
import torch.fft as thfft


def ifft(x, n=None, axis=0, norm="backward", shift=False):
    """IFFT in torchsar

    IFFT in torchsar. Since ifft in torch only supports complex-to-complex
    transformation, for a real ifft you can insert an imaginary part of zeros
    (``torch.stack((x, torch.zeros_like(x)), dim=-1)``), or use torch's irfft.

    Parameters
    ----------
    x : {torch array}
        Both complex and real representations are supported. When :attr:`x` is
        given in the real representation (last dimension is 2: real, imag), it is
        converted to a complex tensor for the IFFT and converted back afterwards.
    n : int, optional
        number of ifft points (the default is None --> equals the signal dimension)
    axis : int, optional
        axis of ifft (the default is 0, which is the first dimension)
    norm : str, optional
        Normalization mode. For the backward transform (ifft()), these correspond to:
        - "forward" - no normalization
        - "backward" - normalize by ``1/n`` (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        ifft result, a torch array with the same type as :attr:`x`

    Raises
    ------
    ValueError
        nfft is smaller than the signal dimension.
    """

    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            axis += 1
    else:
        realflag = False

    if shift:
        y = thfft.ifftshift(thfft.ifft(thfft.ifftshift(x, dim=axis),
                                       n=n,
                                       dim=axis,
                                       norm=norm),
                            dim=axis)
    else:
        y = thfft.ifft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y
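
A short usage sketch for the real-representation path (assumes the function above is in scope; sizes are arbitrary):

import torch as th

sig = th.randn(8, 2)       # real representation: last dim holds (real, imag)
out = ifft(sig, axis=0)    # converted to complex internally, IFFT'd, converted back
print(out.shape)           # torch.Size([8, 2])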
Code example #8
def bound_complex_mask(mask: ComplexTensor, bound_type="tanh"):
    r"""Bound a complex mask, as proposed in [1], section 3.2.

    Valid bound types, for a complex mask $M = |M| ⋅ e^{i φ(M)}$:

    - Unbounded ("UBD"): :math:`M_{\mathrm{UBD}} = M`
    - Sigmoid ("BDSS"): :math:`M_{\mathrm{BDSS}} = σ(|M|) e^{i σ(φ(M))}`
    - Tanh ("BDT"): :math:`M_{\mathrm{BDT}} = \mathrm{tanh}(|M|) e^{i φ(M)}`

    Args:
        bound_type (str or None): The type of bound to use, either of
            "tanh"/"bdt" (default), "sigmoid"/"bdss" or None/"UBD" (unbounded).

    References
        - [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net",
          Hyeong-Seok Choi et al. https://arxiv.org/abs/1903.03107
    """
    if bound_type in {"BDSS", "sigmoid"}:
        return on_reim(torch.sigmoid)(mask)
    elif bound_type in {"BDT", "tanh", "UBD", None}:
        mask_mag, mask_phase = torchaudio.functional.magphase(torch.view_as_real(mask))
        if bound_type in {"BDT", "tanh"}:
            mask_mag_bounded = torch.tanh(mask_mag)
        else:
            mask_mag_bounded = mask_mag
        return torch_complex_from_magphase(mask_mag_bounded, mask_phase)
    else:
        raise ValueError(f"Unknown mask bound {bound_type}")
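
The core of the tanh ("BDT") bound can be sketched with plain torch, bypassing the asteroid helpers used above (a rough illustration, not the library API):

import torch

mask = torch.randn(4, 257, dtype=torch.complex64)
mag, phase = mask.abs(), mask.angle()
bounded = torch.polar(torch.tanh(mag), phase)   # tanh(|M|) * exp(i * phase(M))
print(bool(bounded.abs().max() <= 1.0))         # True: magnitudes are bounded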
Code example #9
    def test_batch_TimeStretch(self, test_pseudo_complex):
        rate = 2
        num_freq = 1025
        num_frames = 400

        spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
        pattern = [3, 1, 1, 1]
        if test_pseudo_complex:
            spec = torch.view_as_real(spec)
            pattern += [1]

        # Single then transform then batch
        expected = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=num_freq,
            hop_length=512,
        )(spec).repeat(*pattern)

        # Batch then transform
        computed = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=num_freq,
            hop_length=512,
        )(spec.repeat(*pattern))

        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
Code example #10
File: fftc.py Project: zongjg/fastMRI
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Normalization mode. See ``torch.fft.ifft``.

    Returns:
        The IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.ifftn(  # type: ignore
            torch.view_as_complex(data),
            dim=(-2, -1),
            norm=norm))
    data = fftshift(data, dim=[-3, -2])

    return data
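
For reference, the same centered inverse FFT can be written with torch's native shift helpers on a complex tensor (a rough equivalent of the helper above, assuming PyTorch >= 1.8; not the fastMRI code itself):

import torch

ksp = torch.randn(1, 320, 320, dtype=torch.complex64)
img = torch.fft.fftshift(
    torch.fft.ifftn(torch.fft.ifftshift(ksp, dim=(-2, -1)), dim=(-2, -1), norm="ortho"),
    dim=(-2, -1),
)
print(torch.view_as_real(img).shape)   # torch.Size([1, 320, 320, 2])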
Code example #11
def ifft2(data):
    assert data.shape[-1] == 2
    data = ifftshift(data, axes=(-3, -2))
    data = torch.view_as_complex(data)
    data = torch.fft.ifftn(data, dim=(-2, -1), norm='ortho')
    data = torch.view_as_real(data)
    return data
Code example #12
File: data.py Project: AlexGrinch/NeMo
    def get_spec(self, audio):
        with torch.cuda.amp.autocast(enabled=False):
            spec = self.stft(audio)
            if spec.dtype in [torch.cfloat, torch.cdouble]:
                spec = torch.view_as_real(spec)
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
        return spec
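
The real-view magnitude above is essentially ``spec.abs()`` with a small epsilon added for numerical stability; a quick equivalence check (arbitrary sizes, not NeMo code):

import torch

spec = torch.randn(4, 201, 50, dtype=torch.cfloat)
mag_view = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
print(torch.allclose(mag_view, spec.abs(), atol=1e-3))   # True up to the 1e-9 epsilon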
Code example #13
    def forward(self, input):
        assert input.shape[1] == input.shape[2]
        
        # padding
        input = self.pad(torch.view_as_complex(input))

        # scaling
        input = self.deformation(input)
        #self.deformation(input)
        
        # to Fourier domain
        input = complex_ifftshift(input)
        input = complex_fft(input, 2)
        input = complex_fftshift(input)
#         input = torch.view_as_real(complex_fftshift(input))

        # Zernike layers in the Fourier plane
        input = self.zernike_ft(input)

        # to direct domain
#         input = torch.view_as_complex(input)
        input = complex_ifftshift(input)
        input = complex_ifft(input, 2)
        input = complex_fftshift(input)
#         input = torch.view_as_real(input)
         
        # Zernike layers in the direct plane
        input = self.zernike_direct(input)
        
        # Crop at the center (because of coeff) 
        input = crop_center(input,self.nxy)

        return torch.view_as_real(input)
Code example #14
    def test_phase_vocoder(self, rate, test_pseudo_complex):
        hop_length = 256
        num_freq = 1025
        num_frames = 400
        torch.random.manual_seed(42)

        # Due to cumulative sum, numerical error when using torch.float32 causes
        # the bottom-right values of the stretched spectrogram to not match
        # with librosa.
        spec = torch.randn(num_freq,
                           num_frames,
                           device=self.device,
                           dtype=torch.complex128)
        phase_advance = torch.linspace(0,
                                       np.pi * hop_length,
                                       num_freq,
                                       device=self.device,
                                       dtype=torch.float64)[..., None]

        stretched = F.phase_vocoder(
            torch.view_as_real(spec) if test_pseudo_complex else spec,
            rate=rate,
            phase_advance=phase_advance)

        expected_stretched = librosa.phase_vocoder(spec.cpu().numpy(),
                                                   rate=rate,
                                                   hop_length=hop_length)

        self.assertEqual(
            torch.view_as_complex(stretched)
            if test_pseudo_complex else stretched,
            torch.from_numpy(expected_stretched))
Code example #15
def spmat_interp(image: Tensor,
                 interp_mats: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tensor:
    """Sparse matrix interpolation backend."""
    if not isinstance(interp_mats, tuple):
        raise TypeError("interp_mats must be a 2-tuple of (real_mat, imag_mat).")

    coef_mat_real, coef_mat_imag = interp_mats
    batch_size, num_coils = image.shape[:2]

    # sparse matrix multiply requires real
    image = torch.view_as_real(image)
    output_size = [batch_size, num_coils, -1]

    # we have to do these transposes because torch.mm requires first to be spmatrix
    image = image.reshape(batch_size * num_coils, -1, 2)
    real_griddat = image.select(-1, 0).t().contiguous()
    imag_griddat = image.select(-1, 1).t().contiguous()

    # apply multiplies
    kdat = torch.stack(
        [
            (torch.mm(coef_mat_real, real_griddat) -
             torch.mm(coef_mat_imag, imag_griddat)).t(),
            (torch.mm(coef_mat_real, imag_griddat) +
             torch.mm(coef_mat_imag, real_griddat)).t(),
        ],
        dim=-1,
    )

    return torch.view_as_complex(kdat).reshape(*output_size)
Code example #16
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    if is_torchver_higher18:  # For PyTorch >= 1.8, return_complex=True is strongly preferred
        x_stft = torch.stft(x,
                            fft_size,
                            hop_size,
                            win_length,
                            window,
                            return_complex=True)
        x_stft = torch.view_as_real(x_stft)
    else:
        x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
Code example #17
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V
Code example #18
File: functional_impl.py Project: twistedmove/audio
    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
        """Verify the output shape of phase vocoder"""
        hop_length = 256
        num_freq = 1025
        num_frames = 400
        batch_size = 2

        torch.random.manual_seed(42)
        spec = torch.randn(batch_size,
                           num_freq,
                           num_frames,
                           dtype=self.complex_dtype,
                           device=self.device)
        if test_pseudo_complex:
            spec = torch.view_as_real(spec)

        phase_advance = torch.linspace(0,
                                       np.pi * hop_length,
                                       num_freq,
                                       dtype=self.real_dtype,
                                       device=self.device)[..., None]

        spec_stretch = F.phase_vocoder(spec,
                                       rate=rate,
                                       phase_advance=phase_advance)

        assert spec.dim() == spec_stretch.dim()
        expected_shape = torch.Size(
            [batch_size, num_freq,
             int(np.ceil(num_frames / rate))])
        output_shape = (torch.view_as_complex(spec_stretch)
                        if test_pseudo_complex else spec_stretch).shape
        assert output_shape == expected_shape
Code example #19
File: test_meta.py Project: donhuvy/pytorch
    def test_view_as_real(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertEqual(m.shape, y.shape)
        self.assertEqual(m.stride(), y.stride())
        self.assertEqual(m.dtype, y.dtype)
Code example #20
File: autograd_test_impl.py Project: zkneupper/audio
    def test_timestretch_non_zero(self, rate, test_pseudo_complex):
        """Verify that ``T.TimeStretch`` does not fail if it's not close to 0

        ``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
        for cases where input is not zero.

        As tested above, when spectrogram contains values close to zero, the gradients are unstable
        and gradcheck fails.

        In this test, we generate spectrogram from random signal, then we push the points around
        zero away from the origin.

        This process does not reflect the real use-case, and it is not practical for users, but
        this helps us understand to what degree the function is differentiable and when not.
        """
        n_fft = 16
        transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
        waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)

        # 1e-3 is too small (on CPU)
        epsilon = 1e-2
        too_close = spectrogram.abs() < epsilon
        spectrogram[too_close] = epsilon * spectrogram[
            too_close] / spectrogram[too_close].abs()
        if test_pseudo_complex:
            spectrogram = torch.view_as_real(spectrogram)
        self.assert_grad(transform, [spectrogram])
Code example #21
File: adamax.py Project: huaxz1986/pytorch
def _multi_tensor_adamax(params: List[Tensor],
                         grads: List[Tensor],
                         exp_avgs: List[Tensor],
                         exp_infs: List[Tensor],
                         state_steps: List[Tensor],
                         *,
                         beta1: float,
                         beta2: float,
                         lr: float,
                         weight_decay: float,
                         eps: float,
                         maximize: bool):

    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(grads)

    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    exp_infs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_infs]

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    # Update biased first moment estimate.
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    # Update the exponentially weighted infinity norm.
    torch._foreach_mul_(exp_infs, beta2)

    for exp_inf, grad in zip(exp_infs, grads):
        norm_buf = torch.cat([
            exp_inf.unsqueeze(0),
            grad.abs().add_(eps).unsqueeze_(0)
        ], 0)
        torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))

    bias_corrections = [1 - beta1 ** step.item() for step in state_steps]
    clr = [-1 * (lr / bias_correction) for bias_correction in bias_corrections]
    torch._foreach_addcdiv_(params, exp_avgs, exp_infs, clr)
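
The complex handling above relies on ``torch.view_as_real`` returning a view, so the in-place ``_foreach_*`` updates on the real views also update the underlying complex parameters. A minimal sketch of that property (values are arbitrary):

import torch

p = torch.zeros(2, dtype=torch.complex64)
p_real = torch.view_as_real(p)                      # shares storage with p
torch._foreach_add_([p_real], [torch.ones(2, 2)])   # in-place update on the real view
print(p)                                            # tensor([1.+1.j, 1.+1.j])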
Code example #22
File: vsi.py Project: francois-rozet/piqa
def sdsp(
    x: Tensor,
    filtr: Tensor,
    value_range: float = 1.,
    sigma_c: float = 0.001,
    sigma_d: float = 145.,
) -> Tensor:
    r"""Detects salient regions from :math:`x`.

    Args:
        x: An input tensor, :math:`(N, 3, H, W)`.
        filtr: The frequency domain filter, :math:`(H, W)`.
        value_range: The value range :math:`L` of the input (usually `1.` or `255`).

    Note:
        For the remaining arguments, refer to [Zhang2013]_.

    Returns:
        The visual saliency tensor, :math:`(N, H, W)`.

    Example:
        >>> x = torch.rand(5, 3, 256, 256)
        >>> filtr = sdsp_filter(x)
        >>> vs = sdsp(x, filtr)
        >>> vs.size()
        torch.Size([5, 256, 256])
    """

    x_lab = xyz_to_lab(rgb_to_xyz(x, value_range))

    # Frequency prior
    x_f = fft.ifft2(fft.fft2(x_lab) * filtr)
    x_f = cx.real(torch.view_as_real(x_f))

    s_f = l2_norm(x_f, dims=[1])

    # Color prior
    x_ab = x_lab[:, 1:]

    lo, _ = x_ab.flatten(-2).min(dim=-1)
    up, _ = x_ab.flatten(-2).max(dim=-1)

    lo = lo.view(lo.shape + (1, 1))
    up = up.view(lo.shape)
    span = torch.where(up > lo, up - lo, torch.tensor(1.).to(lo))

    x_ab = (x_ab - lo) / span

    s_c = 1. - torch.exp(-torch.sum(x_ab**2, dim=1) / sigma_c**2)

    # Location prior
    a, b = [torch.arange(n).to(x) - (n - 1) / 2 for n in x.shape[-2:]]

    s_d = torch.exp(-(a[None, :]**2 + b[:, None]**2) / sigma_d**2)

    # Visual saliency
    vs = s_f * s_c * s_d

    return vs
Code example #23
def s2_fft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, complex]
    :return:  [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in
    batch_size = x.size()[:-3]

    x = x.view(-1, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, complex]
    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in,
                           nl=b_out,
                           weighted=not for_grad,
                           device=x.device)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.view_as_real(torch.fft.fft(
        torch.view_as_complex(x)))  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in,
                                               nspec=nspec,
                                               nbatch=nbatch,
                                               device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[
                        x.contiguous().data_ptr(),
                        wigner.contiguous().data_ptr(),
                        output.data_ptr()
                    ],
                    stream=stream)
        # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l**2, l**2 + 2 * l + 1)
            xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output
Code example #24
def A_realspace(r, t, psi, out):
    """

    :param r:   K x 2
    :param t:   BB x NY x NX
    :param psi: B x K x MY x MX
    :param out: K x MY x MX
    :return:
    """
    gpu = cuda.get_current_device()
    threadsperblock = gpu.MAX_THREADS_PER_BLOCK
    blockspergrid = m.ceil(np.prod(out.shape) / threadsperblock)
    # print(r.shape,t.shape,psi.shape,out.shape)
    A_realspace_kernel[blockspergrid, threadsperblock](r, th.view_as_real(t),
                                                       th.view_as_real(psi),
                                                       th.view_as_real(out))
    return out
Code example #25
File: gw.py Project: francois-rozet/amsi
    def events(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r""" x* """

        x = self.postprocess(self.x_star)
        x = torch.tensor(x)
        x = torch.view_as_real(x)

        return None, x[None]
Code example #26
def _single_tensor_adagrad(params: List[Tensor], grads: List[Tensor],
                           state_sums: List[Tensor], state_steps: List[Tensor],
                           *, lr: float, weight_decay: float, lr_decay: float,
                           eps: float, has_sparse_grad: bool):

    for (param, grad, state_sum, step_t) in zip(params, grads, state_sums,
                                                state_steps):
        # update step
        step_t += 1
        step = step_t.item()

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices,
                                        grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices,
                                    grad_values / std_values),
                       alpha=-clr)
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)
Code example #27
def irfft(input, n=None):
    if torch.is_complex(input):
        input = torch.view_as_real(input)
    else:
        input = torch.nn.functional.pad(input[..., None], (0, 1))
    if n is None:
        n = 2 * (input.size(-1) - 1)
    return torch.irfft(input, 1, signal_sizes=(n,))
Code example #28
def _matrix_pow(matrix: torch.Tensor, p: float) -> torch.Tensor:
    # torch.eig returns eigenvalues as an (n, 2) real tensor (real and imaginary
    # columns), so view it as a complex tensor before taking the power.
    vals, vecs = torch.eig(matrix, eigenvectors=True)
    vals = torch.view_as_complex(vals.contiguous())
    vals_pow = vals.pow(p)
    # keep only the real part of the powered eigenvalues
    vals_pow = torch.view_as_real(vals_pow)[:, 0]
    matrix_pow = torch.matmul(
        vecs, torch.matmul(torch.diag(vals_pow), torch.inverse(vecs)))
    return matrix_pow
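
On recent PyTorch releases ``torch.eig`` has been removed; a rough equivalent using ``torch.linalg.eig`` (a sketch of the intended behavior, not the original project's code) could look like:

import torch

def _matrix_pow_linalg(matrix: torch.Tensor, p: float) -> torch.Tensor:
    # torch.linalg.eig returns complex eigenvalues/eigenvectors directly,
    # so no view_as_complex round trip on a packed real tensor is needed
    vals, vecs = torch.linalg.eig(matrix)
    vals_pow = torch.view_as_real(vals.pow(p))[:, 0]   # real part, as above
    out = vecs @ torch.diag(vals_pow.to(vecs.dtype)) @ torch.linalg.inv(vecs)
    return out.real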
Code example #29
    def test_InverseSpectrogram_pseudocomplex(self):
        tensor = common_utils.get_whitenoise(sample_rate=8000)
        spectrogram = common_utils.get_spectrogram(tensor,
                                                   n_fft=400,
                                                   hop_length=100)
        spectrogram = torch.view_as_real(spectrogram)
        self._assert_consistency(
            T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)
Code example #30
def adagrad(params: List[Tensor], grads: List[Tensor],
            state_sums: List[Tensor], state_steps: List[int], *, lr: float,
            weight_decay: float, lr_decay: float, eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """

    for (param, grad, state_sum, step) in zip(params, grads, state_sums,
                                              state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices,
                                        grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices,
                                    grad_values / std_values),
                       alpha=-clr)
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)