def dist_map(shape, dtype=None, device=None):
    """Return the squared distance between all pairs in a FOV.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the field of view.
    dtype : optional
        Forwarded to `spatial.identity_grid`.
    device : optional
        Forwarded to `spatial.identity_grid`.

    Returns
    -------
    dist : (prod(shape), prod(shape)) tensor
        Squared distance map

    """
    shape = py.make_tuple(shape)
    ndim = len(shape)
    # Flatten the grid into a 2D table with `ndim` columns
    # (presumably one row of coordinates per voxel).
    coords = spatial.identity_grid(shape, dtype=dtype, device=device)
    coords = coords.reshape([-1, ndim])
    # Broadcast rows against rows to obtain every pairwise difference,
    # then square in place (the difference is a fresh temporary) and
    # sum over the coordinate axis.
    delta = coords[:, None, :] - coords[None, :, :]
    return delta.square_().sum(-1)
def susceptibility_phantom(shape, radius=None, dtype=None, device=None):
    """Generate a circle/sphere susceptibility phantom

    Parameters
    ----------
    shape : sequence of int
        Shape of the output map.
    radius : float, default=min(shape)/4
        Radius of the circle/sphere. Only `None` triggers the default;
        an explicit `0` is honoured.
    dtype : optional
        Forwarded to `identity_grid`.
    device : optional
        Forwarded to `identity_grid`.

    Returns
    -------
    f : (*shape) tensor[bool]
        susceptibility delta map

    """
    shape = py.make_tuple(shape)
    # Test against None explicitly: `radius or default` would silently
    # discard a legitimate radius of 0.
    if radius is None:
        radius = min(shape) / 4.
    f = identity_grid(shape, dtype=dtype, device=device)
    # Shift each component of the grid by half the matching size.
    # `unbind` returns views, so the in-place subtraction updates `f`.
    for comp, s in zip(f.unbind(-1), shape):
        comp -= s / 2
    # Euclidean norm along the last axis, thresholded at the radius.
    f = f.square().sum(-1).sqrt() <= radius
    return f
def greens(shape, absolute=_default_absolute, membrane=_default_membrane,
           bending=_default_bending, lame=_default_lame, factor=1,
           voxel_size=1, dtype=None, device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : tuple[int]
        Output shape
    absolute : float, default=_default_absolute
        Penalty on absolute values
    membrane : float, default=_default_membrane
        Penalty on membrane energy
    bending : float, default=_default_bending
        Penalty on bending energy
    lame : float or (float, float), default=_default_lame
        Penalty on linear-elastic energy
    factor : float, default=1
        Accepted by the signature but not forwarded to the regulariser.
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape, [dim, dim]) tensor
        The trailing `(dim, dim)` axes are only present when a
        linear-elastic penalty is used.

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    shape = py.make_tuple(shape)
    ndim = len(shape)
    lame1, lame2 = py.make_list(lame, 2)
    elastic = bool(lame1 or lame2)

    if not absolute:
        absolute = max(absolute, max(membrane, bending, lame1, lame2) * 1e-3)

    # NOTE: `factor` is intentionally not part of these parameters.
    prm = {
        'absolute': absolute,
        'membrane': membrane,
        'bending': bending,
        'voxel_size': voxel_size,
        'bound': 'dft',
    }
    if elastic:
        prm['lame'] = lame

    # Allocate: a full (dim, dim) matrix per voxel in the elastic case,
    # a single (1, 1) entry otherwise.
    nchan = ndim if elastic else 1
    kernel = torch.zeros([*shape, nchan, nchan], dtype=dtype, device=device)

    # Only the central patch is used to generate the kernel:
    # 5 voxels wide with a bending penalty, 3 voxels wide without.
    halfwidth = 2 if bending else 1
    patch = tuple(slice(s // 2 - halfwidth, s // 2 + halfwidth + 1)
                  for s in shape)
    subkernel = kernel[patch]
    middle = (halfwidth,) * ndim

    # Generate the kernel by applying the regulariser to a unit impulse
    # placed at the centre of the patch (`subkernel` is a view of `kernel`).
    if elastic:
        for d in range(ndim):
            subkernel[middle + (d, d)] = 1
            column = subkernel[..., :, d]
            subkernel[..., :, d] = regulariser_grid(column, **prm)
    else:
        subkernel[middle] = 1
        impulse = subkernel[None, ..., 0, 0]
        subkernel[..., 0, 0] = regulariser(impulse, **prm, dim=ndim)[0]

    kernel = fft.ifftshift(kernel, dim=range(ndim))

    # Fourier transform
    #   symmetric kernel -> real coefficients
    # The transform is carried out in double precision and the result
    # is cast back to the dtype the kernel was allocated with.
    out_dtype = kernel.dtype
    kernel = fft.fftn(kernel.double(), dim=list(range(ndim)), real=True)
    kernel = fft.real(kernel).to(dtype=out_dtype)

    # Invert: matrix inverse per frequency in the elastic case,
    # element-wise reciprocal otherwise.
    if elastic:
        return _inv(kernel)
    return kernel[..., 0, 0].reciprocal_()
def empirical_cov(series, nb_dim=1, dim=None, subtract_mean=True,
                  flatten=False, keepdim=False, return_mean=False):
    """Compute an empirical covariance

    Parameters
    ----------
    series : (..., *dims) tensor_like
        Sample series
    nb_dim : int, default=1
        Number of spatial dimensions.
    dim : [sequence of] int, default=None
        Dimensions that are reduced when computing the covariance.
        If None: all but the last `nb_dim`.
    subtract_mean : bool, default=True
        Subtract empirical mean before computing the covariance.
    flatten : bool, default=False
        If True, flatten the 'covariance' dimensions.
    keepdim : bool, default=False
        Keep reduced dimensions.
    return_mean : bool, default=False
        Also return the empirical mean. The mean is computed even when
        `subtract_mean` is False.

    Returns
    -------
    cov : (..., *dims, *dims) or (..., prod(dims), prod(dims)) tensor
        Covariance.
    mean : (..., *dims) or (..., prod(dims)) tensor, if `return_mean`
        Mean.

    """
    # Convert to tensor
    series = torch.as_tensor(series)
    nb_prespatial = series.dim() - nb_dim
    prespatial = series.shape[:-nb_dim]
    spatial_shape = series.shape[-nb_dim:]

    # Normalise `dim` to a list of non-negative indices
    if dim is None:
        dim = range(nb_prespatial)
    dim = py.make_tuple(dim)
    dim = [series.dim() + d if d < 0 else d for d in dim]
    reduced = [prespatial[d] for d in dim]
    batch = [prespatial[d] for d in range(nb_prespatial) if d not in dim]

    # The mean is needed either to centre the series or to be returned.
    # Computing it only under `subtract_mean` left it undefined when
    # `return_mean=True` and `subtract_mean=False`.
    mean = None
    if subtract_mean or return_mean:
        mean = series.mean(dim=dim, keepdim=True)
    if subtract_mean:
        series = series - mean

    # Compute empirical covariance.
    # Flatten spatial dimensions, move reduced dimensions next to them,
    # and collapse to (*batch, reduced, spatial).
    series = series.reshape([*series.shape[:-nb_dim], -1])
    series = utils.movedim(series, dim, -2)
    series = series.reshape([*batch, -1, series.shape[-1]])
    n_reduced = series.shape[-2]
    n_vox = series.shape[-1]

    # Torch's matmul just uses too much memory
    # We don't expect to have more than about 100 time frames,
    # so it is better to unroll the loop in python.
    # cov = torch.matmul(series.transpose(-1, -2), series)
    series_t = series.transpose(-1, -2)  # loop-invariant view
    cov = None
    buf = series.new_empty([*batch, n_vox, n_vox])
    for i in range(n_reduced):
        # Outer product of the i-th sample with itself, written into `buf`
        buf = torch.mul(series_t[..., :, i, None],
                        series[..., i, None, :], out=buf)
        if cov is None:
            cov = buf.clone()
        else:
            cov += buf
    cov /= py.prod(reduced)

    # Build output shapes
    if keepdim:
        outshape = [1 if d in dim else s for d, s in enumerate(prespatial)]
    else:
        outshape = list(batch)
    if flatten:
        nb_vox_total = py.prod(spatial_shape)
        outshape_mean = outshape + [nb_vox_total]
        outshape = outshape + [nb_vox_total] * 2
    else:
        outshape_mean = outshape + list(spatial_shape)
        outshape = outshape + list(spatial_shape) * 2

    cov = cov.reshape(outshape)
    if return_mean:
        mean = mean.reshape(outshape_mean)
        return cov, mean
    return cov
def mrfield_greens(shape, absolute=0, membrane=0, bending=0, factor=1,
                   voxel_size=1, dtype=None, device=None):
    """Generate the Greens function of a regulariser in Fourier space.

    Parameters
    ----------
    shape : tuple[int]
        Output shape
    absolute : float, default=0
        Penalty on absolute values. If zero, it is replaced by
        `1e-6 * max(membrane, bending)`.
    membrane : float, default=0
        Penalty on membrane energy
    bending : float, default=0
        Penalty on bending energy
    factor : float, default=1
        Forwarded to `regulariser`.
    voxel_size : [sequence of] float, default=1
        Voxel size
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    greens : (*shape) tensor

    """
    # Adapted from the geodesic shooting code

    backend = dict(dtype=dtype, device=device)
    shape = py.make_tuple(shape)
    dim = len(shape)
    if not absolute:
        # we need some regularization to invert
        absolute = max(absolute, max(membrane, bending) * 1e-6)
    prm = dict(
        absolute=absolute,
        membrane=membrane,
        bending=bending,
        factor=factor,
        voxel_size=voxel_size,
        bound='dft')

    # allocate
    kernel = torch.zeros(shape, **backend)

    # only use center to generate kernel
    # (5 voxels wide with a bending penalty, 3 voxels wide without)
    halfwidth = 2 if bending else 1
    patch = tuple(slice(s // 2 - halfwidth, s // 2 + halfwidth + 1)
                  for s in shape)
    subkernel = kernel[patch]

    # generate kernel: regularise a unit impulse at the patch centre
    # (`subkernel` is a view, so this writes into `kernel`)
    center = (halfwidth,) * dim
    subkernel[center] = 1
    subkernel[...] = regulariser(subkernel, **prm, dim=dim)

    kernel = ifftshift(kernel, dim=range(dim))

    # fourier transform
    #   symmetric kernel -> real coefficients
    if utils.torch_version('>=', (1, 8)):
        # `dim` must list the axes to transform (not their count), and
        # `.real` is an attribute of complex tensors, not a method.
        kernel = torch.fft.fftn(kernel, dim=list(range(dim))).real
    else:
        # `is_available` is a function: without the call the test was
        # always true and the fallback branch could never run.
        if torch.backends.mkl.is_available():
            # use rfft
            kernel = torch.rfft(kernel, dim, onesided=False)
        else:
            zero = kernel.new_zeros([]).expand(kernel.shape)
            kernel = torch.stack([kernel, zero], dim=-1)
            kernel = torch.fft(kernel, dim)
        kernel = kernel[..., 0]  # should be real

    kernel = kernel.reciprocal_()
    return kernel