def unwrap(phase, dim=None, bound='dct2', max_iter=0, tol=1e-5):
    """Laplacian unwrapping of the phase

    Parameters
    ----------
    phase : tensor
        Wrapped phase, in radian
    dim : int, default=phase.dim()
        Number of spatial dimensions
    bound : str, default='dct2'
        Boundary condition. Unless it is 'dct' or 'circulant', the phase
        is padded by half its size on both sides of each spatial
        dimension using this mode before filtering, and cropped back
        to its original size afterwards.
    max_iter : int, default=0
        Maximum number of unwrapping iterations.
        If 0, return the Laplacian filtered phase, which is not
        exactly equal to the input phase modulo 2 pi.
    tol : float, default=1e-5
        Tolerance for early stopping: iterations stop once the mean
        absolute correction (a multiple of 2 pi per voxel) falls
        below this value.

    Returns
    -------
    unwrapped : tensor

    References
    ----------
    .. "Fast phase unwrapping algorithm for interferometric applications"
       Marvin A. Schofield and Yimei Zhu
       Optics Letters (2003)

    """
    # TODO: would be nice to use DCT/DST rather than padding once they
    #   are available in PyTorch.

    dim = dim or phase.dim()
    dims = list(range(-dim, 0))
    shape = bigshape = phase.shape[-dim:]

    needs_padding = bound not in ('dct', 'circulant')
    if needs_padding:
        phase = utils.pad(phase, [d//2 for d in shape],
                          side='both', mode=bound)
        bigshape = phase.shape[-dim:]

    freq = _laplacian_freq(bigshape, **utils.backend(phase))
    phase = fft.ifftshift(phase, dim=dims)
    twopi = 2 * pymath.pi

    if max_iter == 0:
        phase = _laplacian_filter(phase, freq, dims)
    else:
        for n_iter in range(1, max_iter+1):
            # Difference between the filtered and current phase, rounded
            # to the nearest multiple of 2 pi
            correction = _laplacian_filter(phase, freq, dims)
            correction.sub_(phase).div_(twopi).round_().mul_(twopi)
            phase += correction
            # The correction is signed: use its magnitude so that negative
            # (or mutually cancelling) corrections do not trigger a
            # premature stop. Skipped on the last iteration, where the
            # loop ends anyway.
            if n_iter < max_iter and correction.abs().mean() < tol:
                break

    phase = fft.fftshift(phase, dim=dims)

    if needs_padding:
        # crop the padding that was added on both sides
        slicer = [slice(d//2, d+d//2) for d in shape]
        phase = phase[(Ellipsis, *slicer)]

    return phase
def _laplacian_freq(shape, **backend):
    """Compute the squared Fourier frequency on the lattice and its inverse.

    Parameters
    ----------
    shape : sequence of int
        Lattice shape
    **backend
        Keyword arguments (e.g. dtype, device) forwarded to
        `torch.as_tensor` and `spatial.identity_grid`.

    Returns
    -------
    sqfreq : tensor
        Sum of squared normalised frequencies, after `ifftshift`.
    inv_sqfreq : tensor
        Element-wise reciprocal of `sqfreq`, with the entry at index
        zero along each lattice dimension set to zero.

    """
    ndim = len(shape)
    extent = torch.as_tensor(shape, **backend)

    # coordinates centred on the middle voxel, normalised by the lattice size
    coord = spatial.identity_grid(extent, **backend)
    coord -= extent // 2
    coord /= extent

    sqfreq = coord.square_().sum(-1)
    if fft._torch_has_old_fft:
        # extra trailing dimension for the old-style fft interface
        sqfreq = sqfreq.unsqueeze(-1)
    sqfreq = fft.ifftshift(sqfreq, dim=list(range(ndim)))

    inv_sqfreq = sqfreq.reciprocal()
    origin = (0,) * ndim
    inv_sqfreq[origin] = 0  # overwrite the reciprocal at the first voxel
    return sqfreq, inv_sqfreq
def greens(shape, absolute=_default_absolute, membrane=_default_membrane,
           bending=_default_bending, lame=_default_lame, factor=1,
           voxel_size=1, dtype=None, device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : tuple[int]
        Output shape
    absolute : float, default=_default_absolute
        Penalty on absolute values. If falsy, it is replaced by 1e-3
        times the largest of the other penalties.
    membrane : float, default=_default_membrane
        Penalty on membrane energy
    bending : float, default=_default_bending
        Penalty on bending energy
    lame : float or (float, float), default=_default_lame
        Penalty on linear-elastic energy
    factor : float, default=1
        Accepted for interface compatibility; currently not used by
        this function.
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape, [dim, dim]) tensor
        The trailing (dim, dim) dimensions are only present when a
        linear-elastic penalty is used.

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    shape = py.make_tuple(shape)
    ndim = len(shape)
    lame1, lame2 = py.make_list(lame, 2)
    elastic = bool(lame1 or lame2)

    if not absolute:
        absolute = max(absolute, max(membrane, bending, lame1, lame2) * 1e-3)

    prm = {
        'absolute': absolute,
        'membrane': membrane,
        'bending': bending,
        'voxel_size': voxel_size,
        'bound': 'dft',
    }
    if elastic:
        prm['lame'] = lame

    # allocate: a (dim, dim) matrix per voxel in the elastic case,
    # a (1, 1) matrix otherwise
    nchannels = ndim if elastic else 1
    kernel = torch.zeros([*shape, nchannels, nchannels],
                         dtype=dtype, device=device)

    # only the central patch is used to generate the kernel:
    # 5 voxels wide with a bending penalty, 3 voxels wide otherwise
    radius = 2 if bending else 1
    patch = kernel[tuple(slice(s // 2 - radius, s // 2 + radius + 1)
                         for s in shape)]
    middle = (radius,) * ndim

    # generate kernel by applying the regulariser to a unit impulse
    if elastic:
        for d in range(ndim):
            patch[middle + (d, d)] = 1
            patch[..., :, d] = regulariser_grid(patch[..., :, d], **prm)
    else:
        patch[middle] = 1
        patch[..., 0, 0] = regulariser(patch[None, ..., 0, 0],
                                       **prm, dim=ndim)[0]

    kernel = fft.ifftshift(kernel, dim=range(ndim))

    # fourier transform (computed in double precision)
    # symmetric kernel -> real coefficients
    out_dtype = kernel.dtype
    kernel = fft.fftn(kernel.double(), dim=list(range(ndim)), real=True)
    kernel = fft.real(kernel).to(dtype=out_dtype)

    # invert: matrix inverse per voxel in the elastic case,
    # scalar reciprocal otherwise
    if elastic:
        return _inv(kernel)
    return kernel[..., 0, 0].reciprocal_()
def mrfield_greens(shape, absolute=0, membrane=0, bending=0, factor=1,
                   voxel_size=1, dtype=None, device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : tuple[int]
        Output shape
    absolute : float, default=0
        Penalty on absolute values. If falsy, it is replaced by 1e-6
        times the largest of `membrane` and `bending`.
    membrane : float, default=0
        Penalty on membrane energy
    bending : float, default=0
        Penalty on bending energy
    factor : float, default=1
        Factor forwarded to the regulariser
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape) tensor

    """
    # Adapted from the geodesic shooting code
    # NOTE(review): with all penalties at zero (the defaults), the
    # regulariser presumably yields a zero kernel whose reciprocal is
    # not finite -- confirm that callers always pass a penalty.

    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    if not absolute:
        # we need some regularization to invert
        absolute = max(absolute, max(membrane, bending) * 1e-6)
    prm = dict(
        absolute=absolute,
        membrane=membrane,
        bending=bending,
        factor=factor,
        voxel_size=voxel_size,
        bound='dft')

    # allocate
    kernel = torch.zeros(shape, **backend)

    # only use center to generate kernel
    if bending:
        subkernel = kernel[tuple(slice(s // 2 - 2, s // 2 + 3) for s in shape)]
        subsize = 5
    else:
        subkernel = kernel[tuple(slice(s // 2 - 1, s // 2 + 2) for s in shape)]
        subsize = 3

    # generate kernel by applying the regulariser to a unit impulse
    center = (subsize // 2, ) * dim
    subkernel[center] = 1
    subkernel[...] = regulariser(subkernel, **prm, dim=dim)

    kernel = ifftshift(kernel, dim=range(dim))

    # fourier transform
    #   symmetric kernel -> real coefficients
    if utils.torch_version('>=', (1, 8)):
        # `torch.fft.fftn` takes the *indices* of the transformed
        # dimensions (not their number), and `.real` is an attribute.
        kernel = torch.fft.fftn(kernel, dim=list(range(dim))).real
    else:
        # legacy interface: the second argument is the number of
        # signal dimensions, and complex values are stored in a
        # trailing dimension of size 2
        if torch.backends.mkl.is_available():
            # use rfft
            kernel = torch.rfft(kernel, dim, onesided=False)
        else:
            zero = kernel.new_zeros([]).expand(kernel.shape)
            kernel = torch.stack([kernel, zero], dim=-1)
            kernel = torch.fft(kernel, dim)
        kernel = kernel[..., 0]  # should be real

    kernel = kernel.reciprocal_()
    return kernel
def mrfield_greens2(shape, zdim=-1, voxel_size=1, dtype=None, device=None):
    """Semi-analytical second derivative of the Greens kernel.

    This function implements exactly the solution from Jenkinson et al.
    (Same as in the FSL source code), with the assumption that
    no gradients are played and the main field is constant and has
    no orthogonal components (Bz = B0, Bx = By = 0).

    The Greens kernel and its second derivatives are derived analytically
    and integrated numerically over a voxel.

    The returned tensor has already been Fourier transformed and could
    be cached if multiple field simulations with the same lattice size
    must be performed in a row.

    Parameters
    ----------
    shape : sequence of int
        Lattice shape
    zdim : int, default=-1
        Dimension of the main magnetic field
    voxel_size : [sequence of] int
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    kernel : (*shape) tensor
        Fourier transform of the (second derivatives of the) Greens kernel.

    Raises
    ------
    NotImplementedError
        If the lattice is not three-dimensional.
    ValueError
        If `zdim` is not a valid dimension of a 3D lattice.

    """
    import itertools

    def atan(num, den):
        # use atan2 wherever the denominator is (almost) zero
        return torch.where(den.abs() > 1e-8,
                           torch.atan_(num / den),
                           torch.atan2(num, den))

    dim = len(py.make_list(shape))
    if dim != 3:
        raise NotImplementedError

    g0 = identity_grid(shape, dtype=dtype, device=device)
    voxel_size = utils.make_vector(voxel_size, dim,
                                   dtype=torch.double).tolist()

    # the two dimensions orthogonal to the main field
    if zdim in (-1, 2):
        odims = [-3, -2]
    elif zdim in (-2, 1):
        odims = [-3, -1]
    elif zdim in (-3, 0):
        odims = [-2, -1]
    else:
        raise ValueError('zdim must be in range [-3, 2] for a 3D lattice '
                         'but got {}'.format(zdim))

    def make_shifted(shift):
        g = g0.clone()
        # `unbind` returns views, so in-place ops below modify `g`
        for g1, s, v, t in zip(g.unbind(-1), shape, voxel_size, shift):
            g1 -= s // 2  # make center voxel zero
            g1 += t       # apply shift
            g1 *= v       # convert to mm
        return g

    # signed sum of the kernel evaluated at all half-voxel shifts
    g = 0
    for shift in itertools.product([-0.5, 0.5], repeat=dim):
        g1 = make_shifted(shift)
        r = g1.square().sum(-1).sqrt_()
        g1 = atan(g1[..., odims[0]] * g1[..., odims[1]], g1[..., zdim] * r)
        if py.prod(shift) < 0:
            g -= g1
        else:
            g += g1

    g /= 4. * constants.pi
    g = ifftshift(g, range(dim))  # move center voxel to first voxel

    # fourier transform
    #   symmetric kernel -> real coefficients
    if utils.torch_version('>=', (1, 8)):
        # `torch.fft.fftn` takes the *indices* of the transformed
        # dimensions (not their number), and `.real` is an attribute.
        g = torch.fft.fftn(g, dim=list(range(dim))).real
    else:
        # legacy interface: the second argument is the number of
        # signal dimensions, and complex values are stored in a
        # trailing dimension of size 2
        if torch.backends.mkl.is_available():
            # use rfft
            g = torch.rfft(g, dim, onesided=False)
        else:
            zero = g.new_zeros([]).expand(g.shape)
            g = torch.stack([g, zero], dim=-1)
            g = torch.fft(g, dim)
        g = g[..., 0]  # should be real

    return g