def _smooth_for_reg(dat, mat, samp): """Smoothing for image registration. FWHM is computed from voxel size and sub-sampling amount. Parameters ---------- dat : (X, Y, Z) tensor_like 3D image volume. mat : (4, 4) tensor_like Affine matrix. samp : float Amount of sub-sampling (in mm). Returns ------- dat : (Nx, Ny, Nz) tensor_like Smoothed 3D image volume. """ if samp <= 0: return dat samp = torch.tensor((samp, ) * 3, dtype=dat.dtype, device=dat.device) # Make smoothing kernel vx = voxel_size(mat).to(dat.device).type(dat.dtype) fwhm = torch.sqrt( torch.max(samp**2 - vx**2, torch.zeros(3, device=dat.device, dtype=dat.dtype))) / vx smo = smooth(('gauss', ) * 3, fwhm=fwhm, device=dat.device, dtype=dat.dtype, sep=True) # Padding amount for subsequent convolution size_pad = (smo[0].shape[2], smo[1].shape[3], smo[2].shape[4]) size_pad = (torch.tensor(size_pad) - 1) // 2 size_pad = tuple(size_pad.int().tolist()) # Smooth deformation with Gaussian kernel (by separable convolution) dat = pad(dat, size_pad, side='both') dat = dat[None, None, ...] dat = F.conv3d(dat, smo[0]) dat = F.conv3d(dat, smo[1]) dat = F.conv3d(dat, smo[2])[0, 0, ...] return dat
def backward2(self, h, x, w=None, min=None, max=None):
    """Map a histogram-space second-order term back onto the samples.

    The samples are rescaled to bin coordinates, `h` is optionally smoothed
    with the histogram kernel, `_jhistc_backward2` is applied, and the
    result is multiplied by the squared intensity-to-bin scaling factor.

    Parameters
    ----------
    h : (..., *bins, [*bins]) tensor
        Term defined on the joint histogram. A trailing shape of `(*bins)`
        is handled as the diagonal case, `(*bins, *bins)` as the full case.
    x : (..., n, 2) tensor
        Pairs of intensities.
    w : (..., n) tensor, optional
        Sample weights, broadcast against the batch dimensions of `x`.
    min : (...) tensor_like, optional
        Lower limit of the histogram range. Defaults to the minimum of `x`
        over the sample dimension.
    max : (...) tensor_like, optional
        Upper limit of the histogram range. Defaults to the maximum of `x`
        over the sample dimension.

    Returns
    -------
    h : (..., n, 2) tensor

    Raises
    ------
    ValueError
        If the leading dimensions of `h` match the batch shape in neither
        the diagonal nor the full layout.

    """
    backend = dict(dtype=x.dtype, device=x.device)
    n = x.shape[-2]
    if w is not None:
        _, w = torch.broadcast_tensors(x[..., 0], w)
        batch = w.shape[:-1]
        x = x.expand([*batch, *x.shape[-2:]])
        w = w.reshape([-1, n])
    else:
        batch = x.shape[:-2]
    x = x.reshape([-1, n, 2])

    if h.shape[:-2] == batch:
        is_diag = True
    elif h.shape[:-4] == batch:
        is_diag = False
    else:
        raise ValueError('Don\'t know what to do with that shape')

    # Limits are kept in the flattened (B, 1, 2) layout for the whole
    # function so that they broadcast against the flattened samples and
    # against the flattened result. User-supplied limits are expanded to
    # the *broadcast* batch shape (the one `x` was expanded to), not to
    # the original batch shape of `x`.
    if min is None:
        min = x.min(-2, keepdim=True).values
    else:
        min = torch.as_tensor(min, **backend)
        min = min.expand([*batch, 2]).reshape([-1, 1, 2])
    if max is None:
        max = x.max(-2, keepdim=True).values
    else:
        max = torch.as_tensor(max, **backend)
        max = max.expand([*batch, 2]).reshape([-1, 1, 2])

    # Convert intensities to bin coordinates (min -> -0.5, max -> bins-0.5).
    # `bins / (1 - max / min)` equals `-min * bins / (max - min)`.
    x = x.clone()
    bins = torch.as_tensor(self.bins, **backend)
    scale = bins / (max - min)
    x = x.mul_(scale).add_(bins / (1 - max / min)).sub_(0.5)

    if is_diag:
        h = h.reshape([-1, *self.bins])
    else:
        h = h.reshape([-1, *self.bins, *self.bins])

    # smooth backward
    if any(self.fwhm):
        ker = kernels.smooth(fwhm=self.fwhm)
        if is_diag:
            # diagonal case: the kernel weights are squared
            ker = [k.square_() for k in ker]
            h = smooth(h, kernel=ker, bound=self.bound, dim=2)
        else:
            # full case: smooth the last pair of bin dimensions, swap the
            # two pairs, smooth again, and swap back
            h = smooth(h, kernel=ker, bound=self.bound, dim=2)
            h = h.transpose(-4, -2).transpose(-3, -1)
            h = smooth(h, kernel=ker, bound=self.bound, dim=2)
            h = h.transpose(-4, -2).transpose(-3, -1)

    # push data into the histogram
    h = _jhistc_backward2(h, x, w, self.order, self.bound, self.extrapolate)
    # NOTE(review): assumes `_jhistc_backward2` returns a (B, n, 2) tensor,
    # as implied by the reshape below -- confirm against its definition.
    h = h.mul_(scale.square())

    # reshape
    h = h.reshape([*batch, n, 2])
    return h
def forward(self, batch=1, **overload):
    """Draw a batch of smooth random fields.

    White Gaussian noise is sampled on a padded lattice and convolved with
    a separable Gaussian kernel, independently for every batch element and
    channel.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        Per-call replacements for the attributes `shape`, `mean`,
        `amplitude`, `fwhm`, `channel`, `basis`, `dtype` and `device`.
        `mean`, `amplitude` and `fwhm` may be callables, in which case
        they are called (without arguments) to obtain the value.

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field

    """
    # --- gather arguments (per-call overrides win over attributes) ------
    shape = overload.get('shape', self.shape)
    mean = overload.get('mean', self.mean)
    amplitude = overload.get('amplitude', self.amplitude)
    fwhm = overload.get('fwhm', self.fwhm)
    channel = overload.get('channel', self.channel)
    basis = overload.get('basis', self.basis)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = dict(dtype=dtype, device=device)

    # --- evaluate parameter samplers -----------------------------------
    if callable(mean):
        mean = mean()
    if callable(amplitude):
        amplitude = amplitude()
    if callable(fwhm):
        fwhm = fwhm()

    # --- move to tensors and broadcast ---------------------------------
    nb_dim = len(shape)
    full_shape = [batch, channel, *shape]
    mean = torch.as_tensor(mean, **backend).expand(full_shape)
    amplitude = torch.as_tensor(amplitude, **backend).expand(full_shape)
    fwhm = torch.as_tensor(fwhm, **backend).expand([batch, channel, nb_dim])

    # Convolution matching the number of spatial dimensions
    # (None when it is not 1, 2 or 3).
    conv = {
        1: torch.nn.functional.conv1d,
        2: torch.nn.functional.conv2d,
        3: torch.nn.functional.conv3d,
    }.get(nb_dim)

    # --- convert SE parameters to noise/kernel parameters --------------
    sigma_se = fwhm / math.sqrt(8 * math.log(2))
    sigma_se = sigma_se.prod(dim=-1)
    sigma_se = unsqueeze(sigma_se, dim=-1, ndim=nb_dim)
    amplitude = amplitude * (2 * pi)**(nb_dim / 4) * sigma_se.sqrt()
    fwhm = fwhm * math.sqrt(2)

    # --- sample and smooth, one (batch, channel) pair at a time --------
    per_batch = []
    for b in range(batch):
        per_channel = []
        for c in range(channel):
            kernel = smooth('gauss', fwhm[b, c], basis=basis,
                            device=device, dtype=dtype)

            # The noise lattice is enlarged by (kernel length - 1) along
            # each axis; the un-padded convolutions below shrink it by
            # the same amount.
            pad_shape = [size + kernel[d].shape[d + 2] - 1
                         for d, size in enumerate(shape)]
            mean1 = ensure_shape(mean[b, c], pad_shape,
                                 mode='reflect2', side='both')
            amplitude1 = ensure_shape(amplitude[b, c], pad_shape,
                                      mode='reflect2', side='both')

            # white noise with the requested mean / standard deviation
            noise = torch.distributions.Normal(mean1, amplitude1).sample()
            noise = noise[None, None, ...]

            # separable smoothing
            for ker in kernel:
                noise = conv(noise, ker)
            per_channel.append(noise)
        per_batch.append(torch.cat(per_channel, dim=1))

    return torch.cat(per_batch, dim=0)
def estimate_fwhm(dat, vx=None, verbose=0, mn=-inf, mx=inf):
    """Estimates full width at half maximum (FWHM) and noise standard
    deviation (sd) of a 2D or 3D image.

    It is assumed that the image has been generated as:
        dat = Ky + n,
    where K is Gaussian smoothing with some FWHM and n is
    additive Gaussian noise. FWHM and n are estimated.

    Parameters
    ----------
    dat : str or (*spatial) tensor
        Image data or path to nifti file
    vx : [sequence of] float, default=1
        Voxel size. When `dat` is a path or a mapped array and `vx` is
        not given, it is read from the file's affine.
    verbose : {0, 1, 2}, default=0
        Verbosity level:
            * 0: No verbosity
            * 1: Print FWHM and sd to screen
            * 2: 1 + show mask
    mn : float, optional
        Exclude values below (only voxels with `dat > mn` are kept)
    mx : float, optional
        Exclude values above (only voxels with `dat <= mx` are kept)

    Returns
    -------
    fwhm : (dim,) tensor
        Estimated FWHM
    sd : scalar tensor
        Estimated noise standard deviation.

    References
    ----------
    ..[1] "Linked independent component analysis for multimodal data fusion."
          Appendix A
          Groves AR, Beckmann CF, Smith SM, Woolrich MW.
          Neuroimage. 2011 Feb 1;54(3):2198-217.

    """
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if vx is None:
            vx = get_voxel_size(dat.affine)
        dat = dat.fdata(rand=True, missing=0)
    dat = torch.as_tensor(dat)

    dim = dat.dim()
    if vx is None:
        vx = 1
    vx = utils.make_vector(vx, dim)
    backend = utils.backend(dat)

    # Make mask (True = voxel is kept)
    msk = (dat > mn).bitwise_and_(dat <= mx)
    dat = dat.masked_fill(~msk, 0)
    # TODO: we should erode the mask so that only voxels whose neighbours
    #       are in the mask are considered when computing gradients.
    if verbose >= 2:
        show_slices(msk)

    # Compute image gradient
    g = diff(dat, dim=range(dim), side='central', voxel_size=vx,
             bound='dft').abs_()
    # Drop the one-voxel border of the field of view
    slicer = (slice(1, -1), ) * dim
    g = g[(*slicer, None)]
    # Discard gradients at *excluded* voxels, so that the gradient sum runs
    # over the same voxels as the intensity sum below. (Zeroing the masked-in
    # voxels instead would leave a zero sum -- and hence a division by zero
    # in the FWHM -- whenever no voxel is excluded, which is the default.)
    g[~msk[slicer], :] = 0
    g = g.reshape([-1, dim]).sum(0, dtype=torch.double)

    # Make dat have zero mean (over the kept, non-border voxels)
    dat = dat[slicer]
    dat = dat[msk[slicer]]
    x0 = dat - dat.mean()

    # Compute FWHM
    fwhm = pymath.sqrt(4 * pymath.log(2)) * x0.abs().sum(dtype=torch.double)
    fwhm = fwhm / g
    if verbose >= 1:
        print(f'FWHM={fwhm.tolist()}')

    # Compute noise standard deviation
    # (sx, sy, sz: value returned by `smooth` at x=0 for each axis)
    sx = smooth('gauss', fwhm[0], x=0, **backend)[0][0, 0, 0]
    sy = smooth('gauss', fwhm[1], x=0, **backend)[0][0, 0, 0]
    sz = 1.0
    if dim == 3:
        sz = smooth('gauss', fwhm[2], x=0, **backend)[0][0, 0, 0]
    sc = (sx * sy * sz) / dim
    sc.clamp_min_(1)
    sd = torch.sqrt(x0.square().sum(dtype=torch.double) / (x0.numel() * sc))
    if verbose >= 1:
        print(f'sd={sd.tolist()}')

    return fwhm, sd
def _hist_2d(img0, img1, mx_int, fwhm):
    """Make 2D histogram, requires:
        * Images same size.
        * Images same min and max intensities (non-negative).

    Parameters
    ----------
    img0 : (X, Y, Z) tensor_like
        First image volume.
    img1 : (X, Y, Z) tensor_like
        Second image volume.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    ----------
    H : (mx_int + 1, mx_int + 1) tensor_like
        Joint intensity histogram (smoothed, clamped to be non-negative,
        plus a small epsilon).

    Notes
    ----------
    Naive method for computing a 2D histogram:
        h = torch.zeros((mx_int + 1, mx_int + 1))
        for n in range(num_vox):
            h[img0[n], img1[n]] += 1

    NOTE(review): the linear index built below is
    `img0 + (mx_int + 1) * img1` (MATLAB's column-major sub2ind), while
    `put_` indexes the tensor as if it were 1-D. For the freshly allocated
    `H` this means a count lands at `H[img1[n], img0[n]]`, i.e. the
    transpose of the naive method above. Behaviour is left unchanged here;
    confirm whether callers rely on the axis order.

    """
    fwhm = (fwhm, ) * 2

    # Convert each 'coordinate' of intensities to an index
    # (replicates the sub2ind function of MATLAB)
    img0 = img0.flatten().floor()
    img1 = img1.flatten().floor()
    sub = torch.stack((img0, img1), dim=1)  # (num_vox, 2)
    to_ind = torch.tensor((1, mx_int + 1), dtype=sub.dtype,
                          device=img0.device)[..., None]  # (2, 1)
    ind = torch.tensordot(sub, to_ind, dims=([1], [0]))  # (nvox, 1)

    # Build histogram H by adding up counts according to the indices in ind
    H = torch.zeros(mx_int + 1, mx_int + 1,
                    device=img0.device, dtype=ind.dtype)
    H.put_(ind.long(),
           torch.ones(1, device=img0.device, dtype=ind.dtype).expand_as(ind),
           accumulate=True)

    # Smoothing kernel. It must share the histogram's dtype, otherwise the
    # convolutions below reject the dtype mismatch (e.g. float64 images
    # with a float32 kernel). Non-floating histograms keep the previous
    # float32 kernel.
    smo_dtype = H.dtype if H.is_floating_point() else torch.float32
    smo = smooth(('gauss', ) * 2, fwhm=fwhm,
                 device=img0.device, dtype=smo_dtype, sep=True)

    # Pad by the kernel half-width so the un-padded convolutions do not
    # shrink the histogram
    p = (smo[0].shape[2], smo[1].shape[3])
    p = (torch.tensor(p) - 1) // 2
    p = tuple(p.int().tolist())
    H = pad(H, p, side='both')

    # Smooth (separable: one kernel per histogram axis)
    H = H[None, None, ...]
    H = F.conv2d(H, smo[0])
    H = F.conv2d(H, smo[1])
    H = H[0, 0, ...]

    # Clamp
    H = H.clamp_min(0.0)

    # Add eps
    H = H + 1e-7

    return H
def _proj_info(dim_y, mat_y, dim_x, mat_x, rigid=None, prof_ip=0, prof_tp=0,
               gap=0.0, device='cpu', scl=0.0, samp=0):
    """Define projection operator object, to be used with _proj_apply.

    Args:
        dim_y ((int, int, int)): High-res image dimensions (3,).
        mat_y (torch.tensor): High-res affine matrix (4, 4).
        dim_x ((int, int, int)): Low-res image dimensions (3,).
        mat_x (torch.tensor): Low-res affine matrix (4, 4).
        rigid (torch.tensor): Rigid transformation aligning x to y (4, 4),
            defaults to eye(4).
        prof_ip (int, optional): In-plane slice profile
            (0=rect|1=tri|2=gauss), defaults to 0.
        prof_tp (int, optional): Through-plane slice profile
            (0=rect|1=tri|2=gauss), defaults to 0.
        gap (float, optional): Slice-gap between 0 and 1, defaults to 0.
        device (torch.device, optional): Device. Defaults to 'cpu'.
        scl (float, optional): Odd/even slice scaling, defaults to 0.
        samp (float, optional): Sub-sampling distance, in the same units as
            the voxel sizes. Values <= 0 disable sub-sampling.
            Defaults to 0.

    Returns:
        po (_proj_op()): Projection operator object.

    """
    # Get projection operator object
    po = _proj_op()
    # Data types
    dtype = torch.float64
    dtype_smo_ker = torch.float32
    # Output properties
    if not isinstance(dim_y, torch.Tensor):
        dim_y = torch.tensor(dim_y, device=device, dtype=dtype)
    po.dim_y = dim_y
    po.mat_y = mat_y
    po.vx_y = voxel_size(mat_y)
    # Input properties
    if not isinstance(dim_x, torch.Tensor):
        dim_x = torch.tensor(dim_x, device=device, dtype=dtype)
    po.dim_x = dim_x
    po.mat_x = mat_x
    po.vx_x = voxel_size(mat_x)
    # Number of dimensions
    ndim = len(dim_y)
    one = torch.tensor((1, ) * ndim, device=device, dtype=torch.float64)
    if rigid is None:
        po.rigid = torch.eye(ndim + 1, device=device, dtype=dtype)
    else:
        po.rigid = rigid.type(dtype).to(device)
    # Slice-profile: the axis with the largest low-res voxel size is
    # treated as the through-plane (thick-slice) direction
    gap_cn = torch.zeros(ndim, device=device, dtype=dtype)
    profile = torch.tensor((prof_ip, ) * ndim, device=device, dtype=dtype)
    dim_thick = torch.max(po.vx_x, dim=0)[1]
    gap_cn[dim_thick] = gap
    profile[dim_thick] = prof_tp
    po.dim_thick = dim_thick
    if samp > 0:
        # Sub-sampling
        samp = torch.tensor((samp, ) * ndim, device=device,
                            dtype=torch.float64)
        # Intermediate to lowres
        sk = torch.max(one, torch.floor(samp * one / po.vx_x + 0.5))
        D_x = torch.diag(torch.cat((sk, one[0, None])))
        po.D_x = D_x
        # Modulate lowres
        po.mat_x = po.mat_x.mm(D_x)
        po.dim_x = D_x.inverse()[:ndim, :ndim].mm(
            po.dim_x[..., None]).floor().squeeze()
        # The high-res grid is sub-sampled only when its voxel size differs
        # from the low-res one. (Comparing `po.vx_x` with itself is always
        # zero, which would make this branch unreachable.)
        if torch.sum(torch.abs(po.vx_x - po.vx_y)) > 1e-4:
            # Intermediate to highres (only for superres)
            sk = torch.max(one, torch.floor(samp * one / po.vx_y + 0.5))
            D_y = torch.diag(torch.cat((sk, one[0, None])))
            po.D_y = D_y
            # Modulate highres
            po.mat_y = po.mat_y.mm(D_y)
            po.vx_y = voxel_size(po.mat_y)
            po.dim_y = D_y.inverse()[:ndim, :ndim].mm(
                po.dim_y[..., None]).floor().squeeze()
        po.vx_x = voxel_size(po.mat_x)
    # Make intermediate
    # ratio = mat_y \ mat_x. `torch.solve(B, A)` is not available in recent
    # PyTorch releases; `torch.linalg.solve(A, B)` solves the same system
    # with the arguments swapped, so prefer it when it exists.
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'solve'):
        ratio = linalg.solve(po.mat_y, po.mat_x)
    else:
        ratio = torch.solve(po.mat_x, po.mat_y)[0]
    ratio = (ratio[:ndim, :ndim]**2).sum(0).sqrt()
    ratio = ratio.ceil().clamp(1)  # ratio low/high >= 1
    mat_yx = torch.cat(
        (ratio, torch.ones(1, device=device, dtype=dtype))).diag()
    po.mat_yx = po.mat_x.matmul(mat_yx.inverse())  # mat_x/mat_yx
    po.dim_yx = (po.dim_x - 1) * ratio + 1
    # Make elements with ratio <= 1 use dirac profile
    profile[ratio == 1] = -1
    profile = profile.int().tolist()
    # Make smoothing kernel (slice-profile)
    fwhm = (1. - gap_cn) * ratio
    smo_ker = smooth(profile, fwhm, sep=False, dtype=dtype_smo_ker,
                     device=device)
    po.smo_ker = smo_ker
    # Add offset to intermediate space
    off = torch.tensor(smo_ker.shape[-ndim:], dtype=dtype, device=device)
    off = -(off - 1) // 2  # set offset
    mat_off = torch.eye(ndim + 1, dtype=torch.float64, device=device)
    mat_off[:ndim, -1] = off
    po.dim_yx = po.dim_yx + 2 * torch.abs(off)
    po.mat_yx = torch.matmul(po.mat_yx, mat_off)
    # Odd/even slice scaling
    if isinstance(scl, torch.Tensor):
        po.scl = scl
    else:
        po.scl = torch.tensor(scl, dtype=torch.float32, device=device)
    # To tuple of ints
    po.dim_y = tuple(po.dim_y.int().tolist())
    po.dim_yx = tuple(po.dim_yx.int().tolist())
    po.dim_x = tuple(po.dim_x.int().tolist())
    po.ratio = tuple(ratio.int().tolist())

    return po