Exemplo n.º 1
0
def test_ifftn(norm, ndim, shuffle):
    """Compare nifft's fallback ``ifftn`` against ``pyfft.ifftn``.

    nifft's feature flags are switched off so that its fallback code path
    is exercised. The result is checked on a complex input (stored as a
    trailing real/imag dimension) and on a real input (``real=True``).
    The flags are restored afterwards so the patch cannot leak into other
    tests.

    Parameters
    ----------
    norm
        Normalization mode forwarded to both ``ifftn`` implementations.
    ndim : int
        Number of dimensions (taken from ``[0, 1, -2, 3]``) to transform.
    shuffle : bool
        If true, the candidate dimensions are shuffled before truncation.
    """
    if not _torch_has_fft_module or not _torch_has_old_fft:
        # Nothing to compare in this environment. A bare return (not
        # ``return True``) avoids pytest's non-None-return warning.
        return

    flag_names = ('_torch_has_complex', '_torch_has_fft_module',
                  '_torch_has_fftshift')
    saved_flags = {name: getattr(nifft, name) for name in flag_names}
    for name in flag_names:
        setattr(nifft, name, False)

    try:
        dims = [0, 1, -2, 3]
        if shuffle:
            import random
            random.shuffle(dims)
        dims = dims[:ndim]

        # complex input, stored as a trailing (real, imag) dimension
        x = torch.randn([4, 9, 16, 33, 2], dtype=torch.double)
        f1 = pyfft.ifftn(torch.complex(x[..., 0], x[..., 1]),
                         norm=norm, dim=dims)
        f2 = nifft.ifftn(x, norm=norm, dim=dims)
        f2 = torch.complex(f2[..., 0], f2[..., 1])
        assert torch.allclose(f1, f2)

        # real input
        x = torch.randn([4, 9, 16, 33], dtype=torch.double)
        f1 = pyfft.ifftn(x, norm=norm, dim=dims)
        f2 = nifft.ifftn(x, real=True, norm=norm, dim=dims)
        f2 = torch.complex(f2[..., 0], f2[..., 1])
        assert torch.allclose(f1, f2, atol=1e-5)
    finally:
        # Always restore the module-level flags, even if an assert fails.
        for name, value in saved_flags.items():
            setattr(nifft, name, value)
Exemplo n.º 2
0
def _laplacian_filter(phase, freq, dims):
    """Apply the Laplacian phase filter of Schofield and Zhu, Eq. (5).

    Computes::

        real(ifft(ig * fft(cos(phase) * ifft(g * fft(sin(phase)))
                           - sin(phase) * ifft(g * fft(cos(phase))))))

    where every transform runs over ``dims``.

    Parameters
    ----------
    phase : tensor
        Input phase. Assumes ifftshift has already been applied.
    freq : (g, ig) pair of tensors
        Fourier-domain kernels, also assumed ifftshift-ed. Presumably ``g``
        is the Laplacian and ``ig`` its (regularized) inverse; confirm
        with the caller.
    dims : sequence of int
        Dimensions over which the Fourier transforms are computed.

    Returns
    -------
    tensor
        Real part of the filtered phase.
    """
    fwd_kernel, inv_kernel = freq

    sin_phase = phase.sin()
    cos_phase = phase.cos()

    def forward(tensor):
        return fft.fftn(tensor, dim=dims)

    def backward(tensor):
        return fft.ifftn(tensor, dim=dims)

    # cos(phase) * K[sin(phase)] - sin(phase) * K[cos(phase)], K = g-filter
    filtered = backward(forward(sin_phase).mul_(fwd_kernel)).mul_(cos_phase)
    filtered -= backward(forward(cos_phase).mul_(fwd_kernel)).mul_(sin_phase)

    # apply the inverse kernel and keep the real part
    return fft.real(backward(forward(filtered).mul_(inv_kernel)))
Exemplo n.º 3
0
def greens_apply(mom, greens, factor=1, voxel_size=1):
    """Apply the Greens function to a momentum field.

    Parameters
    ----------
    mom : (..., *spatial, dim) tensor
        Momentum
    greens : (*spatial, [dim, dim]) tensor
        Greens function
    factor : float, default=1
        The resulting velocity is divided by this value.
    voxel_size : [sequence of] float, default=1
        Voxel size. Only needed when no penalty is put on linear-elasticity.

    Returns
    -------
    vel : (..., *spatial, dim) tensor
        Velocity

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    mom, greens = utils.to_max_backend(mom, greens)
    dim = mom.shape[-1]
    # The spatial dimensions sit just before the trailing component dimension.
    spatial_dims = list(range(-dim - 1, -1))

    # fourier transform (the component dimension is not transformed)
    mom = fft.fftn(mom, dim=spatial_dims, real=True)

    # voxel-wise multiplication by the Greens function
    if greens.dim() == dim:
        # Scalar kernel: the same kernel multiplies every component, and each
        # component is then divided by its squared voxel size. The voxel-size
        # vector uses greens' (real) backend because fft.mul is told this
        # operand is real, whereas mom is complex at this point.
        voxel_size = utils.make_vector(voxel_size, dim,
                                       **utils.backend(greens))
        voxel_size = voxel_size.square().reciprocal()
        greens = greens.unsqueeze(-1)
        mom = fft.mul(mom, greens, real=(False, True))
        mom = fft.mul(mom, voxel_size, real=(False, True))
    else:
        # NOTE(review): greens is (*spatial, dim, dim) here while mom is
        # (..., *spatial, dim). A plain element-wise product does not
        # broadcast these correctly; a previous implementation used
        # linalg.matvec. Confirm fft.mul performs a matrix-vector product.
        mom = fft.mul(mom, greens, real=(False, True))

    # inverse fourier transform, keeping the real part
    mom = fft.real(fft.ifftn(mom, dim=spatial_dims))
    mom /= factor

    return mom
Exemplo n.º 4
0
    def forward(self, batch=1, **overload):
        """Sample a random vector field.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
            Spatial shape. Defaults to ``self.shape``.
        mean : float or tensor, optional
            Value added to the sampled field. Defaults to ``self.mean``.
        voxel_size : float or (dim,) vector_like, optional
            Defaults to ``self.voxel_size``.
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, *shape, dim) tensor
            Generated random field, with ``dim == len(shape)`` components
            per voxel.

        """

        # get arguments (keyword overloads take precedence over attributes)
        shape = overload.get('shape', self.shape)
        mean = overload.get('mean', self.mean)
        voxel_size = overload.get('voxel_size', self.voxel_size)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        nb_dim = len(shape)
        vx = utils.make_vector(voxel_size, nb_dim, **backend)
        # A plain list is used as the cache key (compared with ``==``).
        voxel_size = vx.tolist()
        lame = py.make_list(self.lame, 2)

        if (hasattr(self, '_greens')
                and self._voxel_size == voxel_size
                and self._shape == shape):
            # reuse the cached square-root kernel
            greens = self._greens.to(dtype=dtype, device=device)
        else:
            greens = spatial.greens(
                shape,
                absolute=self.absolute,
                membrane=self.membrane,
                bending=self.bending,
                lame=self.lame,
                voxel_size=voxel_size,
                device=device,
                dtype=dtype)
            if any(lame):
                # Matrix-valued kernel G = U S U^T. Its square root factor
                # is L = U diag(sqrt(S)), so that L L^T = G. This scales the
                # COLUMNS of U, hence unsqueeze(-2). Scaling the rows, as
                # unsqueeze(-1) does, would give a covariance of diag(S).
                greens, scale, _ = torch.svd(greens)
                scale = scale.sqrt_()
                greens *= scale.unsqueeze(-2)
            else:
                greens = greens.sqrt_()

            if self.cache_greens:
                self._greens = greens
                self._voxel_size = voxel_size
                self._shape = shape

        # two independent white-noise draws: real and imaginary parts
        sample = torch.randn([2, batch, *shape, nb_dim], **backend)

        # multiply by square root of greens
        if greens.dim() > nb_dim:  # lame: one (dim, dim) factor per voxel
            sample = linalg.matvec(greens, sample)
        else:
            sample = sample * greens.unsqueeze(-1)
            sample = sample / vx.sqrt()
        sample = fft.complex(sample[0], sample[1])

        # inverse Fourier transform over the spatial dimensions
        dims = list(range(-nb_dim-1, -1))
        sample = fft.real(fft.ifftn(sample, dim=dims))
        sample *= py.prod(shape)

        # add mean
        sample += mean

        return sample
Exemplo n.º 5
0
    def forward(self, batch=1, **overload):
        """Sample a random multi-channel scalar field.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        voxel_size : float or (dim,) vector_like, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments (keyword overloads take precedence over attributes)
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        voxel_size = overload.get('voxel_size', self.voxel_size)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        nb_dim = len(shape)
        voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
        # A plain list is used as the cache key (compared with ``==``).
        voxel_size = voxel_size.tolist()

        # The mean is needed on every call, including when the Greens
        # kernels come from the cache, so it is built outside the cache test.
        mean = utils.make_vector(self.mean, channel, **backend)

        if (hasattr(self, '_greens')
                and self._voxel_size == voxel_size
                and self._channel == channel
                and self._shape == shape):
            # reuse the cached square-root kernels
            greens = self._greens.to(dtype=dtype, device=device)
        else:
            absolute = utils.make_vector(self.absolute, channel, **backend)
            membrane = utils.make_vector(self.membrane, channel, **backend)
            bending = utils.make_vector(self.bending, channel, **backend)

            # one scalar kernel per channel, stacked along a leading
            # channel dimension
            greens = torch.stack([
                spatial.greens(
                    shape,
                    absolute=absolute[c],
                    membrane=membrane[c],
                    bending=bending[c],
                    lame=0,
                    voxel_size=voxel_size,
                    device=device,
                    dtype=dtype)
                for c in range(channel)
            ])
            greens = greens.sqrt_()

            if self.cache_greens:
                self._greens = greens
                self._voxel_size = voxel_size
                # _channel is part of the cache test above, so it must be
                # stored as well.
                self._channel = channel
                self._shape = shape

        # Sample white noise (real and imaginary parts). There is no trailing
        # component dimension here, so greens (channel, *shape) already
        # aligns with the trailing dimensions of the noise.
        sample = torch.randn([2, batch, channel, *shape], **backend)
        sample *= greens
        sample = fft.complex(sample[0], sample[1])

        # inverse Fourier transform over the spatial dimensions
        dims = list(range(-nb_dim, 0))
        sample = fft.real(fft.ifftn(sample, dim=dims))
        sample *= py.prod(shape)

        # add the per-channel mean, broadcast over spatial dimensions
        sample += utils.unsqueeze(mean, -1, len(shape))

        return sample