def test_ifftn(norm, ndim, shuffle):
    """Check nifft.ifftn against torch.fft.ifftn on complex and real inputs.

    Parameters
    ----------
    norm : str or None
        FFT normalization mode, forwarded to both implementations.
    ndim : int
        Number of dimensions to transform over.
    shuffle : bool
        Whether to shuffle the candidate transform dimensions.
    """
    if not _torch_has_fft_module or not _torch_has_old_fft:
        # Nothing to compare on this torch version.
        # Bare `return` (not `return True`): pytest warns on tests that
        # return a non-None value (PytestReturnNotNoneWarning).
        return

    # Force the fallback (old-style FFT) code paths inside nifft.
    # NOTE(review): these module-level flags are not restored afterwards,
    # which can leak into other tests — consider a fixture/monkeypatch.
    nifft._torch_has_complex = False
    nifft._torch_has_fft_module = False
    nifft._torch_has_fftshift = False

    dims = [0, 1, -2, 3]
    if shuffle:
        import random
        random.shuffle(dims)
    dims = dims[:ndim]

    # Complex input: nifft uses a (..., 2) real/imag layout.
    x = torch.randn([4, 9, 16, 33, 2], dtype=torch.double)
    f1 = pyfft.ifftn(torch.complex(x[..., 0], x[..., 1]), norm=norm, dim=dims)
    f2 = nifft.ifftn(x, norm=norm, dim=dims)
    f2 = torch.complex(f2[..., 0], f2[..., 1])
    assert torch.allclose(f1, f2)

    # Real input.
    x = torch.randn([4, 9, 16, 33], dtype=torch.double)
    f1 = pyfft.ifftn(x, norm=norm, dim=dims)
    f2 = nifft.ifftn(x, real=True, norm=norm, dim=dims)
    f2 = torch.complex(f2[..., 0], f2[..., 1])
    assert torch.allclose(f1, f2, atol=1e-5)
def _laplacian_filter(phase, freq, dims):
    """Laplacian phase filter — Eq (5) from Schofield and Zhu.

    Assumes ifftshift has been applied to `phase` and `freq`.

    Parameters
    ----------
    phase : tensor
        Phase map.
    freq : (tensor, tensor)
        Pair of frequency-domain kernels (forward and inverse Laplacian).
    dims : sequence[int]
        Dimensions along which to transform.
    """
    lap, inv_lap = freq

    def forward(x):
        return fft.fftn(x, dim=dims)

    def backward(x):
        return fft.ifftn(x, dim=dims)

    sin_p = phase.sin()
    cos_p = phase.cos()
    out = backward(forward(sin_p).mul_(lap)).mul_(cos_p)
    out -= backward(forward(cos_p).mul_(lap)).mul_(sin_p)
    out = backward(forward(out).mul_(inv_lap))
    return fft.real(out)
def greens_apply(mom, greens, factor=1, voxel_size=1):
    """Apply the Greens function to a momentum field.

    Parameters
    ----------
    mom : (..., *spatial, dim) tensor
        Momentum
    greens : (*spatial, [dim, dim]) tensor
        Greens function
    factor : float, default=1
        Global scaling; the output is divided by this value.
    voxel_size : [sequence of] float, default=1
        Voxel size. Only needed when no penalty is put on linear-elasticity.

    Returns
    -------
    vel : (..., *spatial, dim) tensor
        Velocity
    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    mom, greens = utils.to_max_backend(mom, greens)
    dim = mom.shape[-1]

    # forward Fourier transform over the spatial dimensions
    # (input is real-valued; `real=True` lets the wrapper pick an rfft path)
    mom = fft.fftn(mom, dim=list(range(-dim - 1, -1)), real=True)

    # voxel-wise product with the Greens function in frequency space
    if greens.dim() == dim:
        # scalar (per-frequency) kernel: also undo the squared voxel scaling
        voxel_size = utils.make_vector(voxel_size, dim, **utils.backend(mom))
        voxel_size = voxel_size.square().reciprocal()
        greens = greens.unsqueeze(-1)
        mom = fft.mul(mom, greens, real=(False, True))
        mom = fft.mul(mom, voxel_size, real=(False, True))
    else:
        # matrix-valued kernel (linear elasticity) — presumably fft.mul
        # performs the matrix-vector product here; confirm against fft.mul
        mom = fft.mul(mom, greens, real=(False, True))

    # inverse Fourier transform; keep the real part
    mom = fft.real(fft.ifftn(mom, dim=list(range(-dim - 1, -1))))
    mom /= factor
    return mom
def forward(self, batch=1, **overload):
    """Sample a random displacement field.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        Per-call overrides for `shape`, `mean`, `voxel_size`,
        `dtype`, `device`.

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
    """
    # resolve parameters: per-call overrides fall back to attributes
    shape = overload.get('shape', self.shape)
    mean = overload.get('mean', self.mean)
    voxel_size = overload.get('voxel_size', self.voxel_size)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = {'dtype': dtype, 'device': device}

    nb_dim = len(shape)
    voxel_size = utils.make_vector(voxel_size, nb_dim, **backend).tolist()
    lame = py.make_list(self.lame, 2)

    # Greens kernel: reuse the cached one when shape/voxel size match
    cache_hit = (hasattr(self, '_greens')
                 and self._voxel_size == voxel_size
                 and self._shape == shape)
    if cache_hit:
        greens = self._greens.to(dtype=dtype, device=device)
    else:
        greens = spatial.greens(
            shape, absolute=self.absolute, membrane=self.membrane,
            bending=self.bending, lame=self.lame,
            voxel_size=voxel_size, device=device, dtype=dtype)
        if any(lame):
            # matrix-valued kernel: square root through an SVD
            greens, scale, _ = torch.svd(greens)
            scale = scale.sqrt_()
            greens *= scale.unsqueeze(-1)
        else:
            greens = greens.sqrt_()
        if self.cache_greens:
            self._greens = greens
            self._voxel_size = voxel_size
            self._shape = shape

    # white Gaussian noise (real/imag stacked on dim 0),
    # colored by sqrt(greens) in the frequency domain
    sample = torch.randn([2, batch, *shape, nb_dim], **backend)
    if greens.dim() > nb_dim:
        # lame: matrix-vector product per frequency
        sample = linalg.matvec(greens, sample)
    else:
        sample = sample * greens.unsqueeze(-1)
    vx = utils.make_vector(voxel_size, nb_dim, **backend)
    sample = sample / vx.sqrt()
    sample = fft.complex(sample[0], sample[1])

    # back to the spatial domain
    sample = fft.real(fft.ifftn(sample, dim=list(range(-nb_dim - 1, -1))))
    sample *= py.prod(shape)

    # shift by the requested mean
    sample += mean
    return sample
def forward(self, batch=1, **overload):
    """Sample a random scalar field per channel.

    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    shape : sequence[int], optional
    channel : int, optional
    voxel_size : float or (dim,) vector_like, optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
    """
    # resolve parameters: per-call overrides fall back to attributes
    shape = overload.get('shape', self.shape)
    channel = overload.get('channel', self.channel)
    voxel_size = overload.get('voxel_size', self.voxel_size)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = dict(dtype=dtype, device=device)

    nb_dim = len(shape)
    voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
    voxel_size = voxel_size.tolist()

    # BUG FIX: `mean` was only computed on the cache-miss branch, so a
    # cache hit raised NameError at the final `sample += ...`.
    # It must be available on both paths.
    mean = utils.make_vector(self.mean, channel, **backend)

    if (hasattr(self, '_greens')
            and self._voxel_size == voxel_size
            and self._channel == channel
            and self._shape == shape):
        greens = self._greens.to(dtype=dtype, device=device)
    else:
        absolute = utils.make_vector(self.absolute, channel, **backend)
        membrane = utils.make_vector(self.membrane, channel, **backend)
        bending = utils.make_vector(self.bending, channel, **backend)
        greens = []
        for c in range(channel):
            greens.append(spatial.greens(
                shape, absolute=absolute[c], membrane=membrane[c],
                bending=bending[c], lame=0, voxel_size=voxel_size,
                device=device, dtype=dtype))
        greens = torch.stack(greens)
        greens = greens.sqrt_()
        if self.cache_greens:
            self._greens = greens
            self._voxel_size = voxel_size
            self._shape = shape
            # BUG FIX: the cache check above reads `self._channel`, but it
            # was never stored, so a second call raised AttributeError
            # (or the cache could never be hit).
            self._channel = channel

    # white Gaussian noise (real/imag stacked on dim 0),
    # colored by sqrt(greens) in the frequency domain
    sample = torch.randn([2, batch, channel, *shape], **backend)
    # BUG FIX: greens has shape (channel, *shape), which already aligns
    # with the trailing dims of sample (2, batch, channel, *shape).
    # The previous `greens.unsqueeze(-1)` — copied from the displacement
    # sampler, where the last dim is the vector dimension — misaligned
    # the broadcast here.
    sample *= greens
    sample = fft.complex(sample[0], sample[1])

    # inverse Fourier transform; keep the real part
    dims = list(range(-nb_dim, 0))
    sample = fft.real(fft.ifftn(sample, dim=dims))
    sample *= py.prod(shape)

    # add the per-channel mean (broadcast over the spatial dims)
    sample += utils.unsqueeze(mean, -1, len(shape))
    return sample