def __init__(self, *args):
    """
    Parameters
    ----------
    Either
        quat : (..., 4) tensor
    Or
        orientation : (..., 3) tensor
        attitude : (...) tensor
    Or
        i, j, k, r : tensors
    """
    if len(args) == 1:
        ijkr = torch.as_tensor(args[0])
        i = ijkr[..., 0]
        j = ijkr[..., 1]
        k = ijkr[..., 2]
        r = ijkr[..., 3]
    elif len(args) == 2:
        ijk, r = utils.to_max_backend(*args)
        i = ijk[..., 0]
        j = ijk[..., 1]
        k = ijk[..., 2]
    elif len(args) == 4:
        i, j, k, r = utils.to_max_backend(*args)
    else:
        raise ValueError('Expected 1, 2 or 4 arguments')
    self.i = i
    self.j = j
    self.k = k
    self.r = r
def build_se(sqdist, sigma, lam, **backend):
    """Build a squared-exponential covariance matrix.

    Parameters
    ----------
    sqdist : sequence[int] or (vox, vox) tensor
        If a tensor -> it is the pre-computed squared distance map.
        If a tuple -> it is the shape and we build the distance map.
    sigma : (*batch) tensor_like
        Amplitude
    lam : (*batch) tensor_like
        Length-scale

    Returns
    -------
    cov : (*batch, vox, vox) tensor
        Covariance matrix
    """
    lam, sigma = utils.to_max_backend(lam, sigma, **backend, force_float=True)
    backend = utils.backend(lam)

    # Build the SE covariance matrix
    if not torch.is_tensor(sqdist):
        shape = sqdist
        e = dist_map(shape, **backend)
    else:
        e = sqdist.to(**backend)
        del sqdist
    lam = lam[..., None, None]
    sigma = sigma[..., None, None]
    e = e.mul(-0.5 / (lam ** 2)).exp_().mul_(sigma ** 2)
    return e
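
# Example (not part of the original source): a minimal sketch of `build_se`
# with a pre-computed squared-distance map, assuming `utils.to_max_backend`
# promotes python scalars to tensors. The covariance implemented above is
# k(d^2) = sigma^2 * exp(-d^2 / (2 * lam^2)), so the diagonal (zero
# distance) should equal sigma^2.
def _example_build_se():
    import torch
    points = torch.arange(8.)[:, None]
    sqdist = torch.cdist(points, points).square()   # (8, 8) distance map
    cov = build_se(sqdist, sigma=2., lam=1.5)
    assert torch.allclose(cov.diagonal(), torch.full([8], 4.))
    return cov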
def irls_tukey_reweight(moving, fixed, lam=1, c=4.685, joint=False,
                        dim=None, mask=None):
    """Update iteratively reweighted least-squares weights for Tukey's biweight

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Equivalent to Gaussian noise precision
        (used to standardize the residuals)
    c : float, default=4.685
        Tukey's threshold.
        Approximately equal to the number of standard deviations
        above which the loss is capped.
    joint : bool, default=False
        Pool squared residuals across channels before weighting.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : ([B], K|1, *spatial) tensor, optional
        Mask of voxels to include in the loss.

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights
    """
    if lam is None:
        lam = 1
    c = c * c
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    weights = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        weights = weights.mul_(mask)
    if joint:
        weights = weights.sum(dim=-dim - 1, keepdim=True)
    zeromsk = weights > c
    weights = weights.div_(-c).add_(1).square()
    # indexing with a boolean mask returns a copy, so
    # `weights[zeromsk].zero_()` would be a no-op: fill in place instead
    weights.masked_fill_(zeromsk, 0)
    return weights
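
# Example (not part of the original source): a quick sanity check of the
# Tukey weights; standardized residuals beyond the threshold `c` should
# receive exactly zero weight.
def _example_tukey_weights():
    import torch
    fixed = torch.zeros(1, 8)
    moving = torch.linspace(0., 10., 8).reshape(1, 8)
    w = irls_tukey_reweight(moving, fixed, lam=1, c=4.685, dim=1)
    assert (w[moving.square() > 4.685 ** 2] == 0).all()
    return w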
def irls_laplace_reweight(moving, fixed, lam=1, joint=False, eps=1e-5,
                          dim=None, mask=None):
    """Update iteratively reweighted least-squares weights for the l1 loss

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Inverse-squared scale parameter of the Laplace distribution.
        (equivalent to Gaussian noise precision)
    joint : bool, default=False
        Pool squared residuals across channels before weighting.
    eps : float, default=1e-5
        Small value used to stabilise the reciprocal.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : ([B], K|1, *spatial) tensor, optional
        Mask of voxels to include in the loss.

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights
    """
    if lam is None:
        lam = 1
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    weights = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        weights = weights.mul_(mask)
    if joint:
        weights = weights.sum(dim=-dim - 1, keepdim=True)
    weights = weights.sqrt_().clamp_min_(eps).reciprocal_()
    if mask is not None:
        weights = weights.masked_fill_(mask == 0, 0)
    return weights
def mrfield_greens_apply(mom, greens):
    """Apply the Greens function to a momentum field.

    Parameters
    ----------
    mom : (..., *spatial) tensor
        Momentum
    greens : (*spatial) tensor
        Greens function

    Returns
    -------
    field : (..., *spatial) tensor
        Field
    """
    mom, greens = utils.to_max_backend(mom, greens)
    dim = greens.dim()

    # fourier transform (over the last `dim` spatial dimensions)
    if utils.torch_version('>=', (1, 8)):
        mom = torch.fft.fftn(mom, dim=list(range(-dim, 0)))
    else:
        if torch.backends.mkl.is_available():
            # use rfft
            mom = torch.rfft(mom, dim, onesided=False)
        else:
            zero = mom.new_zeros([]).expand(mom.shape)
            mom = torch.stack([mom, zero], dim=-1)
            mom = torch.fft(mom, dim)

    # voxel-wise multiplication
    if utils.torch_version('>=', (1, 8)):
        mom = mom * greens               # complex tensor
    else:
        mom = mom * greens[..., None]    # real/imaginary in a trailing dim

    # inverse fourier transform
    if utils.torch_version('>=', (1, 8)):
        mom = torch.fft.ifftn(mom, dim=list(range(-dim, 0))).real
    else:
        mom = torch.ifft(mom, dim)[..., 0]

    return mom
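
# Example (not part of the original source): applying a Greens function is
# a circular convolution, implemented above as a pointwise product in the
# frequency domain. A trivial check, assuming torch >= 1.8: an all-ones
# Greens function (flat spectrum) should leave the momentum unchanged.
def _example_mrfield_greens_identity():
    import torch
    mom = torch.randn(4, 4)
    greens = torch.ones(4, 4)
    out = mrfield_greens_apply(mom, greens)
    assert torch.allclose(out, mom, atol=1e-5)
    return out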
def mp2rage(pd, r1, r2s=None, transmit=None, receive=None, gfactor=None,
            tr=6.25, ti1=0.8, ti2=2.2, tx=None, te=None, fa=(4, 5),
            n=160, eff=0.96, sigma=None, device=None, return_combined=True):
    """Simulate data generated by a (simplified) MP2RAGE sequence.

    The defaults are the parameters used at 3T in the original MP2RAGE
    paper. However, I don't get a nice image with these parameters when
    they are applied to maps obtained at 3T with the hMRI toolbox. Here
    are (unrealistic) parameters that seem to give a decent contrast:
    tr=6.25, ti1=1.4, ti2=4.5, tx=5.8e-3, fa=(4, 5), n=160, eff=0.96

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-bias is not included.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    tr : float, default=6.25
        Full repetition time, in sec. (Time between two inversion pulses)
    ti1 : float, default=0.8
        First inversion time, in sec.
        (Time between inversion pulse and middle of the first echo train)
    ti2 : float, default=2.2
        Second inversion time, in sec.
        (Time between inversion pulse and middle of the second echo train)
    tx : float, default=2*te or 5.8e-3
        Excitation repetition time, in sec.
        (Time between two excitation pulses within the echo train)
    te : float, default=tx/2
        Echo time, in sec.
    fa : float or (float, float), default=(4, 5)
        Flip angle of the first and second acquisition block, in deg.
        If only one value, it is shared between the blocks.
    n : int, default=160
        Number of excitation pulses (= phase encoding steps) per train.
    eff : float, default=0.96
        Efficiency of the inversion pulse.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise
        (no sampling if `None`)

    Returns
    -------
    mp2rage : tensor, if return_combined is True
        Simulated MP2RAGE image
    image1 : tensor, if return_combined is False
        Image at first inversion time
    image2 : tensor, if return_combined is False
        Image at second inversion time

    References
    ----------
    ..[1] "MP2RAGE, a self bias-field corrected sequence for improved
           segmentation and T1-mapping at high field."
          Marques JP, Kober T, Krueger G, van der Zwaag W,
          Van de Moortele PF, Gruetter R.
          Neuroimage. 2010 Jan 15;49(2):1271-81.
          doi: 10.1016/j.neuroimage.2009.10.002
    """
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, transmit, receive, gfactor)
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, transmit, receive, gfactor, device=device)

    if tx is None and te is None:
        tx = 5.8e-3
    tx = tx or 2 * te                    # Time between excitation pulses
    te = te or tx / 2                    # Echo time
    fa1, fa2 = py.make_list(fa, 2)
    fa1 = fa1 * constants.pi / 180       # Flip angle of first GRE block
    fa2 = fa2 * constants.pi / 180      # Flip angle of second GRE block
    n = n or min(pd.shape)               # Number of readouts (PE steps) per loop
    tr1 = n * tx                         # First GRE block
    tr2 = n * tx                         # Second GRE block
    tp = ti1 - tr1 / 2                   # Preparation time
    tw = ti2 - tr2 / 2 - ti1 - tr1 / 2   # Wait time between GRE blocks
    td = tr - ti2 - tr2 / 2              # Recovery time

    if return_combined and not sigma:
        m = mp2rage_nonoise(pd, r1, tx, tp, tw, td, tr, fa1, fa2, n, eff,
                            transmit)
        m = torch.where(~torch.isfinite(m), m.new_zeros([]), m)
        return m

    mi1, mi2 = mp2rage_uncombined(pd, r1, r2s, tx, tp, tw, td, tr, te,
                                  fa1, fa2, n, eff, transmit, receive)

    # noise
    mi1 = add_noise(mi1, std=sigma, gfactor=gfactor)
    mi2 = add_noise(mi2, std=sigma, gfactor=gfactor)

    if return_combined:
        m = mp2rage_from_ir(mi1, mi2)
        m = torch.where(~torch.isfinite(m), m.new_zeros([]), m)
        return m
    else:
        mi1 = torch.where(~torch.isfinite(mi1), mi1.new_zeros([]), mi1)
        mi2 = torch.where(~torch.isfinite(mi2), mi2.new_zeros([]), mi2)
        return mi1, mi2
def fs_to_affine(shape, voxel_size=1., x=None, y=None, z=None, c=0.,
                 source='voxel', dest='ras'):
    """Transform FreeSurfer orientation parameters into an affine matrix.

    The returned matrix is effectively a "<source> to <dest>" transform.

    Parameters
    ----------
    shape : sequence of int
    voxel_size : [sequence of] float, default=1
    x : [sequence of] float, default=[1, 0, 0]
    y : [sequence of] float, default=[0, 1, 0]
    z : [sequence of] float, default=[0, 0, 1]
    c : [sequence of] float, default=0
    source : {'voxel', 'physical', 'ras'}, default='voxel'
    dest : {'voxel', 'physical', 'ras'}, default='ras'

    Returns
    -------
    affine : (4, 4) tensor
    """
    dim = len(shape)
    shape, voxel_size, x, y, z, c \
        = utils.to_max_backend(shape, voxel_size, x, y, z, c)
    backend = dict(dtype=shape.dtype, device=shape.device)
    voxel_size = utils.make_vector(voxel_size, dim)
    if x is None:
        x = [1, 0, 0]
    if y is None:
        y = [0, 1, 0]
    if z is None:
        z = [0, 0, 1]
    x = utils.make_vector(x, dim)
    y = utils.make_vector(y, dim)
    z = utils.make_vector(z, dim)
    c = utils.make_vector(c, dim)

    shift = shape / 2.
    shift = -shift * voxel_size
    vox2phys = Orientation(shift, voxel_size).affine()
    phys2ras = XYZC(x, y, z, c).affine()

    affines = []
    if source.lower().startswith('vox'):
        affines.append(vox2phys)
        middle_space = 'phys'
    elif source.lower().startswith('phys'):
        if dest.lower().startswith('vox'):
            affines.append(affine_inv(vox2phys))
            middle_space = 'vox'
        else:
            affines.append(phys2ras)
            middle_space = 'ras'
    elif source.lower() == 'ras':
        affines.append(affine_inv(phys2ras))
        middle_space = 'phys'
    else:
        # We need a matrix to switch orientations
        affines.append(layout_matrix(source, **backend))
        middle_space = 'ras'

    if dest.lower().startswith('phys'):
        if middle_space == 'vox':
            affines.append(vox2phys)
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
    elif dest.lower().startswith('vox'):
        if middle_space == 'phys':
            affines.append(affine_inv(vox2phys))
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
            affines.append(affine_inv(vox2phys))
    elif dest.lower().startswith('ras'):
        if middle_space == 'phys':
            affines.append(phys2ras)
        elif middle_space.lower().startswith('vox'):
            affines.append(vox2phys)
            affines.append(phys2ras)
    else:
        if middle_space == 'phys':
            affines.append(phys2ras)            # phys -> ras
        elif middle_space == 'vox':
            affines.append(vox2phys)
            affines.append(phys2ras)
        layout = layout_matrix(dest, **backend)
        affines.append(affine_inv(layout))      # ras -> dest layout

    affine, *affines = affines
    for aff in affines:
        affine = affine_matmul(aff, affine)
    return affine
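
# Example (not part of the original source): a hedged usage sketch with the
# FreeSurfer defaults (axis-aligned x/y/z, zero center `c`), mapping voxel
# coordinates of a 128^3 volume with 1 mm isotropic voxels to RAS space.
def _example_fs_to_affine():
    aff = fs_to_affine([128, 128, 128], voxel_size=1.,
                       source='voxel', dest='ras')
    assert aff.shape == (4, 4)
    return aff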
def register(fixed=None, moving=None, dim=None, loss='mse',
             basis='CSO', optim='ogm', max_iter=500, lr=1, ls=6,
             plot=False, klosure=RegisterStep, logaff=None, verbose=True):
    """Affine registration between two images using Lie groups.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor
        Fixed image
    moving : (..., K, *spatial) tensor
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    optim : {'gn', 'gd', 'momentum', 'nesterov', 'ogm', 'lbfgs'}, default='ogm'
        'gn'       : Gauss-Newton
        'gd'       : Gradient descent
        'momentum' : Gradient descent with momentum
        'nesterov' : Nesterov-accelerated gradient descent
        'ogm'      : Optimized gradient descent (Kim & Fessler)
        'lbfgs'    : Limited-memory BFGS
    max_iter : int, default=500
        Maximum number of Gauss-Newton or Gradient descent iterations
    lr : float, default=1
        Learning rate.
    ls : int, default=6
        Number of line search iterations.
    plot : bool, default=False
        Plot progress

    Returns
    -------
    logaff : (...) tensor
        Affine parameters in the Lie algebra (log-affine).
    """
    # If no inputs provided: demo "circle to square"
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    basis = spatial.affine_basis(basis, dim, **utils.backend(fixed))
    if logaff is None:
        logaff = torch.zeros(len(basis), **utils.backend(fixed))
        # logaff = torch.zeros(12, **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_affine(optim, lr, ls, max_iter)

    # init loss
    loss = losses.make_loss(loss, dim)

    # optimize
    if verbose:
        print(f'{"it":3s} | {"fit":^12s} + {"reg":^12s} '
              f'= {"obj":^12s} | {"gain":^12s}')
        print('-' * 63)
    closure = klosure(moving, fixed, loss, basis=basis,
                      verbose=verbose, plot=plot, max_iter=optim.max_iter)
    logaff = optim.iter(logaff, closure)
    if verbose:
        print('')
    return logaff
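
# Example (not part of the original source): calling `register()` with no
# arguments falls back on the built-in "circle to square" demo, so a quick
# smoke test needs no data:
#
#     logaff = register(max_iter=20, verbose=False)
#     print(logaff)   # Lie-algebra parameters of the fitted affine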
def conv(dim, tensor, kernel, bias=None, stride=1, padding=0,
         bound='zero', dilation=1, groups=1):
    """Perform a convolution

    Parameters
    ----------
    dim : {1, 2, 3}
        Number of spatial dimensions
    tensor : (*batch, [channel_in,] *spatial_in) tensor
        Input tensor
    kernel : ([channel_in, channel_out,] *kernel_size) tensor
        Convolution kernel
    bias : ([channel_out,]) tensor, optional
        Bias tensor
    stride : int or sequence[int], default=1
        Strides between output elements.
    padding : 'same' or int or sequence[int], default=0
        Padding performed before the convolution.
        If 'same', the padding is chosen such that the shape of the
        output tensor is `spatial_in // stride`.
        ('valid' is accepted as an alias for 0, and 'auto' for 'same'.)
    bound : str, default='zero'
        Boundary conditions used in the padding.
    dilation : int or sequence[int], default=1
        Dilation of the kernel.
    groups : int, default=1

    Returns
    -------
    convolved : (*batch, [channel_out], *spatial_out)
    """
    # move everything to the same dtype/device
    tensor, kernel, bias = utils.to_max_backend(tensor, kernel, bias)

    # sanity checks + reshape for torch's conv
    if kernel.dim() not in (dim, dim + 2):
        raise ValueError('Kernel shape should be (*kernel_size) or '
                         '(channel_in, channel_out, *kernel_size) but '
                         'got {}'.format(kernel.shape))
    has_channels = kernel.dim() == dim + 2
    channels_in = kernel.shape[0] if has_channels else 1
    channels_out = kernel.shape[1] if has_channels else 1
    kernel_size = kernel.shape[(2 * has_channels):]
    kernel = kernel.reshape([channels_in, channels_out, *kernel_size])
    batch = tensor.shape[:-(dim + has_channels)]
    spatial_in = tensor.shape[-dim:]
    if has_channels and tensor.shape[-(dim + has_channels)] != channels_in:
        raise ValueError('Number of input channels not consistent: '
                         'Got {} (kernel) and {} (tensor).'
                         .format(channels_in,
                                 tensor.shape[-(dim + has_channels)]))
    tensor = tensor.reshape([-1, channels_in, *spatial_in])
    # `if bias:` would raise on multi-element tensors: test against None
    if bias is not None:
        bias = bias.flatten()
        if bias.numel() == 1:
            bias = bias.expand(channels_out)
        elif bias.numel() != channels_out:
            raise ValueError('Number of output channels not consistent: '
                             'Got {} (kernel) and {} (bias).'
                             .format(channels_out, bias.numel()))

    # Perform padding
    dilation = make_list(dilation, dim)
    padding = make_list(padding, dim)
    padding = [0 if p == 'valid' else 'same' if p == 'auto' else p
               for p in padding]
    for i in range(dim):
        if isinstance(padding[i], str):
            assert padding[i].lower() == 'same'
            if kernel_size[i] % 2 == 0:
                raise ValueError('Cannot compute "same" padding '
                                 'for even-sized kernels.')
            padding[i] = dilation[i] * (kernel_size[i] // 2)
    if bound != 'zero' and sum(padding) > 0:
        tensor = core.utils.pad(tensor, padding, bound, side='both')
        padding = 0

    conv_fn = (F.conv1d if dim == 1 else
               F.conv2d if dim == 2 else
               F.conv3d if dim == 3 else None)
    if not conv_fn:
        raise NotImplementedError('Convolution is only implemented in '
                                  'dimension 1, 2 or 3.')
    tensor = conv_fn(tensor, kernel, bias, stride=stride, padding=padding,
                     dilation=dilation, groups=groups)
    spatial_out = tensor.shape[-dim:]
    channels_out = [channels_out] if has_channels else []
    tensor = tensor.reshape([*batch, *channels_out, *spatial_out])
    return tensor
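
# Example (not part of the original source): smoothing a batch of 2D images
# with a small averaging kernel, using 'same' padding and assuming the
# 'replicate' boundary mode is supported by `core.utils.pad`.
def _example_conv_smooth():
    import torch
    image = torch.randn(1, 32, 32)           # (batch, *spatial), no channels
    kernel = torch.full([3, 3], 1. / 9.)     # plain (*kernel_size) kernel
    out = conv(2, image, kernel, padding='same', bound='replicate')
    assert out.shape == image.shape
    return out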
def lcc(moving, fixed, dim=None, patch=20, stride=1, lam=1, mode='g',
        grad=True, hess=True, mask=None):
    """Local correlation coefficient (squared)

    This function implements a squared version of Cachier and Pennec's
    local correlation coefficient, so that anti-correlations are not
    penalized.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    patch : int, default=20
        Patch size
    stride : int, default=1
        Stride between patches
    lam : float or ([B], K|1, [*spatial]) tensor_like, default=1
        Precision of the NCC distribution
    mode : str, default='g'
        Local averaging mode (passed through to `local_mean`)
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian
    mask : (..., *spatial) tensor, optional
        Mask of voxels to include in the loss

    Returns
    -------
    ll : () tensor

    References
    ----------
    ..[1] "3D Non-Rigid Registration by Gradient Descent on a Gaussian-
           Windowed Similarity Measure using Convolutions"
          Pascal Cachier, Xavier Pennec
          MMBIA (2000)
    """
    if moving.requires_grad:
        sqrt_ = torch.sqrt
        div_ = torch.div
    else:
        sqrt_ = torch.sqrt_
        div_ = lambda x, y: x.div_(y)

    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    if mask is not None:
        mask = mask.to(**utils.backend(fixed))
    else:
        mask = fixed.new_ones(fixed.shape[-dim:])
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)

    patch = list(map(float, py.ensure_list(patch)))
    stride = py.ensure_list(stride)
    stride = [s or 0 for s in stride]
    fwd = lambda x: local_mean(x, patch, stride, dim=dim, mode=mode,
                               mask=mask, cache=local_cache)
    bwd = lambda x: local_mean(x, patch, stride, dim=dim, mode=mode,
                               mask=mask, backward=True, shape=shape,
                               cache=local_cache)
    sumall = lambda x: x.sum(list(range(-dim, 0)), keepdim=True)

    # compute ncc within each patch
    mom0, moving_mean, fixed_mean, moving_std, fixed_std, corr = \
        _suffstat(fwd, moving, fixed)
    mom0 = mom0.div_(sumall(mom0).clamp_min_(1e-5)).mul_(lam)
    moving_std = sqrt_(moving_std.addcmul_(moving_mean, moving_mean,
                                           value=-1))
    fixed_std = sqrt_(fixed_std.addcmul_(fixed_mean, fixed_mean, value=-1))
    moving_std.clamp_min_(1e-5)
    fixed_std.clamp_min_(1e-5)
    corr = div_(div_(corr.addcmul_(moving_mean, fixed_mean, value=-1),
                     moving_std), fixed_std)
    corr2 = corr.square().neg_().add_(1).clamp_min_(1e-8)

    out = []
    if grad or hess:
        h = (corr / moving_std).square_().mul_(mom0).div_(corr2)
        h = bwd(h)
        if grad:
            # g = G' * (corr.*(corr.*xmean./xstd - ymean./ystd)./xstd)
            #     - x .* (G' * (corr./ xstd).^2)
            #     + y .* (G' * (corr ./ (xstd.*ystd)))
            # g = -2 * g
            fixed_mean = fixed_mean.div_(fixed_std)
            moving_mean = moving_mean.div_(moving_std)
            g = fixed_mean.addcmul_(corr, moving_mean, value=-1)
            fixed_mean = moving_mean = None
            g = g.mul_(corr).div_(moving_std).mul_(mom0).div_(corr2)
            g = bwd(g)
            g = g.addcmul_(h, moving)
            g = g.addcmul_(
                bwd(corr.div_(moving_std).div_(fixed_std)
                        .mul_(mom0).div_(corr2)),
                fixed, value=-1)
            g = g.mul_(2)
            out.append(g)
        if hess:
            # h = 2 * (G' * (corr./ xstd).^2)
            h = h.mul_(2)
            out.append(h)

    # return stuff
    corr = corr2.log_().mul_(mom0)
    corr = corr.sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
def fse(pd, r1, r2=None, receive=None, gfactor=None,
        te=0.02, tr=5, sigma=None, device=None):
    """Simulate data generated by a (simplified) Fast Spin-Echo (FSE)
    sequence.

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2 : tensor_like, optional
        Transverse relaxation rate, in 1/sec.

    Fields
    ------
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    te : float, default=0.02
        Echo time, in sec
    tr : float, default=5
        Repetition time, in sec.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled noise (no sampling if `None`)

    Returns
    -------
    sim : tensor
        Simulated FSE image
    """
    pd, r1, r2, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2, receive, gfactor)
    pd, r1, r2, receive, gfactor \
        = utils.to(pd, r1, r2, receive, gfactor, device=device)

    if receive is not None:
        pd = pd * receive
        del receive

    e1 = r1.mul(tr).neg_().exp_()
    signal = pd * (1 - e1)
    if r2 is not None:       # `r2` is documented as optional: guard it
        e2 = r2.mul(te).neg_().exp_()
        signal = signal * e2

    # noise
    signal = add_noise(signal, std=sigma)
    return signal
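
# Example (not part of the original source): the FSE signal implemented
# above is pd * (1 - exp(-tr * r1)) * exp(-te * r2), i.e. saturation
# recovery weighted by T2 decay, assuming `add_noise(..., std=None)` is a
# no-op as documented.
def _example_fse():
    import torch
    sim = fse(pd=torch.tensor(100.), r1=torch.tensor(1.),
              r2=torch.tensor(10.), te=0.02, tr=5.)
    expected = 100. * (1 - torch.tensor(-5.).exp()) * torch.tensor(-0.2).exp()
    assert torch.allclose(sim, expected)
    return sim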
def physio_sample(shape=None, sigma_p=0.008, lam_p=0.4, sigma_0=2.,
                  sigma_r=1., lam_r=0.2, signal=100., repeats=100,
                  sampler='svd', **backend):
    """Sample from the fMRI physiological model.

    Parameters
    ----------
    shape : list[int], default=[32, 32]
        Shape of the field of view
    sigma_p : float, default=0.008
        Amplitude of the physiological noise.
    lam_p : float, default=0.4
        Length-scale of the physiological noise (i.e., smoothness).
    sigma_0 : float, default=2.0
        Amplitude of the thermal noise.
    sigma_r : float, default=1.
        Amplitude of the reconstruction filter.
    lam_r : float, default=0.2
        Length-scale of the reconstruction filter.
    signal : float, default=100.
        Mean signal.
    repeats : int, default=100
        Number of repeats in the time series

    Returns
    -------
    time_series : (repeats, *shape) tensor[dtype]
        fMRI time series.
        Forward model: ReconFilter(Signal * Physio + Thermal)
    replicate_series : (repeats, *shape) tensor[dtype]
        Replicate series.
        Forward model: ReconFilter(Signal + Thermal)
    """
    shape = [32, 32] if shape is None else shape
    dim = len(shape)

    sigma_p_recon = sigma_p * sigma_r \
                  * (1 + 2 * (lam_r ** 2) / (lam_p ** 2)) ** (-dim / 4)
    sigma_0_recon = sigma_0 * sigma_r \
                  * (4. * constants.pi * lam_r ** 2) ** (-dim / 4)
    lam_p_recon = (lam_p ** 2 + 2. * lam_r ** 2) ** 0.5
    lam_0_recon = (2. ** 0.5) * lam_r
    param = sigma_p_recon, sigma_0_recon, lam_p_recon, lam_0_recon
    param = utils.to_max_backend(*param, **backend)
    sigma_p_recon, sigma_0_recon, lam_p_recon, lam_0_recon = param
    backend = utils.backend(sigma_p_recon)

    # thermal noise (*) recon
    tr = lambda: se_sample(shape, sigma_0_recon, lam_0_recon,
                           repeats=repeats, sampler=sampler, **backend)
    # physio noise (*) recon
    pr = lambda: se_sample(shape, sigma_p_recon, lam_p_recon,
                           repeats=repeats, sampler=sampler, **backend)

    time_series = signal * (1. + pr()) + tr()
    replicate_series = signal + tr()
    return time_series, replicate_series
def compose(self, orient_in, deformation, orient_mean,
            affine=None, orient_out=None, shape_out=None):
    """Compose a deformation defined in a mean space to an image space.

    Parameters
    ----------
    orient_in : (4, 4) tensor
        Orientation of the input image
    deformation : (*shape_mean, 3) tensor
        Random deformation
    orient_mean : (4, 4) tensor
        Orientation of the mean space (where the deformation is)
    affine : (4, 4) tensor, default=identity
        Random affine
    orient_out : (4, 4) tensor, default=orient_in
        Orientation of the output image
    shape_out : sequence[int], default=shape_mean
        Shape of the output image

    Returns
    -------
    grid : (*shape_out, 3) tensor
        Voxel-to-voxel transform
    """
    if orient_out is None:
        orient_out = orient_in
    if shape_out is None:
        shape_out = deformation.shape[:-1]
    if affine is None:
        affine = torch.eye(4, 4, device=orient_in.device,
                           dtype=orient_in.dtype)
    shape_mean = deformation.shape[:-1]

    orient_in, affine, deformation, orient_mean, orient_out \
        = utils.to_max_backend(orient_in, affine, deformation,
                               orient_mean, orient_out)
    backend = utils.backend(deformation)
    eye = torch.eye(4, **backend)

    # Compose deformation on the right
    right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
    if not (shape_mean == shape_out and right_affine.allclose(eye)):
        # the mean space and native space are not the same
        # we must compose the diffeo with a dense affine transform
        # we write the diffeo as an identity plus a displacement
        # (id + disp)(aff) = aff + disp(aff)
        # -------
        # to displacement
        deformation = deformation - spatial.identity_grid(
            deformation.shape[:-1], **backend)
        trf = spatial.affine_grid(right_affine, shape_out)
        deformation = spatial.grid_pull(
            utils.movedim(deformation, -1, 0)[None], trf[None],
            bound='dft', extrapolate=True)
        deformation = utils.movedim(deformation[0], 0, -1)
        trf = trf + deformation  # add displacement
    else:
        # the spaces coincide: the deformation grid is used as is
        trf = deformation

    # Compose deformation on the left
    # the output of the diffeo(right) are mean_space voxels
    # we must compose on the left with `in\(aff(mean))`
    # -------
    left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                        affine)
    left_affine = spatial.affine_matmul(left_affine, orient_mean)
    trf = spatial.affine_matvec(left_affine, trf)

    return trf
def register(fixed=None, moving=None, dim=None, lam=1., loss='mse',
             optim='nesterov', hilbert=None, max_iter=500, sub_iter=16,
             lr=None, ls=0, plot=False, klosure=RegisterStep,
             kernel=None, **prm):
    """Nonlinear registration between two images using smooth displacements.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor
        Fixed image
    moving : (..., K, *spatial) tensor
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    lam : float, default=1
        Modulate regularisation
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    optim : {'relax', 'cg', 'gd', 'momentum', 'nesterov', 'lbfgs'}, default='nesterov'
        'relax'    : Gauss-Newton (linear system solved by relaxation)
        'cg'       : Gauss-Newton (linear system solved by conjugate gradient)
        'gd'       : Gradient descent
        'momentum' : Gradient descent with momentum
        'nesterov' : Nesterov-accelerated gradient descent
        'lbfgs'    : Limited-memory BFGS
    hilbert : bool, default=True
        Use Hilbert preconditioning (not used if optim is second order)
    max_iter : int, default=500
        Maximum number of Gauss-Newton or Gradient descent iterations
    sub_iter : int, default=16
        Number of relax/cg iterations per GN step
    lr : float, optional
        Learning rate.
    ls : int, default=0
        Number of line search iterations.
    absolute : float, default=1e-4
        Penalty on absolute displacements
    membrane : float, default=1e-3
        Penalty on first derivatives
    bending : float, default=0.2
        Penalty on second derivatives
    lame : (float, float), default=(0.05, 0.2)
        Penalty on zooms and shears

    Returns
    -------
    disp : (..., *spatial, dim) tensor
        Displacement field.
    """
    defaults_velocity(prm)

    # If no inputs provided: demo "circle to square"
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    lam = lam / py.prod(shape)
    prm['factor'] = lam
    velshape = [*fixed.shape[:-dim - 1], *shape, dim]
    vel = torch.zeros(velshape, **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_grid(optim, lr, ls, max_iter, sub_iter,
                                         **prm)
    if hilbert is None:
        hilbert = not optim.requires_hess
    if hilbert and kernel is None:
        kernel = spatial.greens(shape, **prm, **utils.backend(fixed))
    if kernel is not None:
        optim.preconditioner = lambda x: spatial.greens_apply(x, kernel)

    # init loss
    loss = losses.make_loss(loss, dim)

    # optimize
    print(f'{"it":3s} | {"fit":^12s} + {"reg":^12s} '
          f'= {"obj":^12s} | {"gain":^12s}')
    print('-' * 63)
    closure = klosure(moving, fixed, loss, plot=plot,
                      max_iter=optim.max_iter, **prm)
    vel = optim.iter(vel, closure)
    print('')
    return vel
def mprage(pd, r1, r2s=None, transmit=None, receive=None, gfactor=None,
           tr=2.3, ti=0.9, tx=None, te=None, fa=9,
           n=160, eff=0.96, sigma=None, device=None):
    """Simulate data generated by a (simplified) MP-RAGE sequence.

    Default parameters mimic the ADNI-3 protocol on 3T Siemens scanners.
    Our implementation is based on the MP2RAGE paper, with the sequence
    stripped of its second GRE readout block.

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-bias is not included.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    tr : float, default=2.3
        Repetition time, in sec. (Time between two inversion pulses)
    ti : float, default=0.9
        Inversion time, in sec.
        (Time between inversion pulse and middle of the echo train)
    tx : float, default=2*te or 6e-3
        Excitation repetition time, in sec.
        (Time between two excitation pulses within the echo train)
    te : float, default=tx/2
        Echo time, in sec
    fa : float, default=9
        Flip angle, in deg
    n : int, default=160
        Number of excitation pulses (= phase encoding steps) per train.
    eff : float, default=0.96
        Efficiency of the inversion pulse.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise
        (no sampling if `None`)

    Returns
    -------
    sim : tensor
        Simulated MPRAGE image

    References
    ----------
    ..[1] "MP2RAGE, a self bias-field corrected sequence for improved
           segmentation and T1-mapping at high field."
          Marques JP, Kober T, Krueger G, van der Zwaag W,
          Van de Moortele PF, Gruetter R.
          Neuroimage. 2010 Jan 15;49(2):1271-81.
          doi: 10.1016/j.neuroimage.2009.10.002
    """
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, transmit, receive, gfactor)
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, transmit, receive, gfactor, device=device)
    backend = utils.backend(pd)

    if tx is None and te is None:
        tx = 6e-3
    tx = tx or 2 * te                # Time between excitation pulses
    te = te or tx / 2                # Echo time
    fa = fa * constants.pi / 180     # Flip angle of GRE block
    n = n or min(pd.shape)           # Number of readouts (PE steps) per loop
    tr1 = n * tx                     # GRE block
    tp = ti - tr1 / 2                # Preparation time
    td = tr - ti - tr1 / 2           # Recovery time
    m = n // 2                       # Middle of echo train

    if transmit is not None:
        fa = transmit * fa
        del transmit
    fa = torch.as_tensor(fa, **backend)

    # precompute exponential terms
    ex = r1.mul(-tx).exp()
    ep = r1.mul(-tp).exp()
    ed = r1.mul(-td).exp()
    e1 = r1.mul(-tr).exp()
    c = fa.cos()

    # steady state
    s = (1 - ep) * (c * ex).pow(n)
    s = s + (1 - ex) * (1 - (c * ex).pow(n)) / (1 - c * ex)
    s = s * ed + (1 - ed)
    s = s * pd / (1 + eff * c.pow(n) * e1)

    # IR component
    s = -eff * s * ep / pd + (1 - ep)
    s = s * (c * ex).pow(m - 1)
    s = s + (1 - ex) * (1 - (c * ex).pow(m - 1)) / (1 - c * ex)
    s = s * fa.sin()
    s = s.abs()

    # Modulation (PD, B1-, R2*)
    if receive is not None:
        pd = pd * receive
        del receive
    s = s * pd
    if r2s is not None:
        e2 = r2s.mul(-te).exp_()
        s = s * e2
        del r2s

    # noise
    s = add_noise(s, std=sigma, gfactor=gfactor)
    return s
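
# Example (not part of the original source): a hedged smoke test of the
# MP-RAGE forward model with the ADNI-3-like defaults, assuming
# `add_noise(..., std=None)` is a no-op as documented. The simulation is
# voxel-wise, so the output shape matches the input maps.
def _example_mprage():
    import torch
    pd = torch.full([8, 8], 80.)
    r1 = torch.full([8, 8], 1.)
    sim = mprage(pd, r1)
    assert sim.shape == pd.shape
    return sim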
def mp2rage_old(pd, r1, r2s=None, transmit=None, receive=None, gfactor=None,
                tr=6.25, ti1=0.8, ti2=2.2, tx=None, te=None, fa=(4, 5),
                n=160, eff=0.96, sigma=None, device=None,
                return_combined=True):
    """Simulate data generated by a (simplified) MP2RAGE sequence.

    The defaults are the parameters used at 3T in the original MP2RAGE
    paper. However, I don't get a nice image with these parameters when
    they are applied to maps obtained at 3T with the hMRI toolbox. Here
    are (unrealistic) parameters that seem to give a decent contrast:
    tr=6.25, ti1=1.4, ti2=4.5, tx=5.8e-3, fa=(4, 5), n=160, eff=0.96

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec.
        If not provided, T2*-bias is not included.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    tr : float, default=6.25
        Full repetition time, in sec. (Time between two inversion pulses)
    ti1 : float, default=0.8
        First inversion time, in sec.
        (Time between inversion pulse and middle of the first echo train)
    ti2 : float, default=2.2
        Second inversion time, in sec.
        (Time between inversion pulse and middle of the second echo train)
    tx : float, default=2*te or 5.8e-3
        Excitation repetition time, in sec.
        (Time between two excitation pulses within the echo train)
    te : float, default=tx/2
        Echo time, in sec.
    fa : float or (float, float), default=(4, 5)
        Flip angle of the first and second acquisition block, in deg.
        If only one value, it is shared between the blocks.
    n : int, default=160
        Number of excitation pulses (= phase encoding steps) per train.
    eff : float, default=0.96
        Efficiency of the inversion pulse.

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise
        (no sampling if `None`)

    Returns
    -------
    mp2rage : tensor, if return_combined is True
        Simulated MP2RAGE image
    image1 : tensor, if return_combined is False
        Image at first inversion time
    image2 : tensor, if return_combined is False
        Image at second inversion time

    References
    ----------
    ..[1] "MP2RAGE, a self bias-field corrected sequence for improved
           segmentation and T1-mapping at high field."
          Marques JP, Kober T, Krueger G, van der Zwaag W,
          Van de Moortele PF, Gruetter R.
          Neuroimage. 2010 Jan 15;49(2):1271-81.
          doi: 10.1016/j.neuroimage.2009.10.002
    """
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, transmit, receive, gfactor)
    pd, r1, r2s, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, transmit, receive, gfactor, device=device)
    backend = utils.backend(pd)

    if tx is None and te is None:
        tx = 5.8e-3
    tx = tx or 2 * te                    # Time between excitation pulses
    te = te or tx / 2                    # Echo time
    fa1, fa2 = py.make_list(fa, 2)
    fa1 = fa1 * constants.pi / 180       # Flip angle of first GRE block
    fa2 = fa2 * constants.pi / 180       # Flip angle of second GRE block
    n = n or min(pd.shape)               # Number of readouts (PE steps) per loop
    tr1 = n * tx                         # First GRE block
    tr2 = n * tx                         # Second GRE block
    tp = ti1 - tr1 / 2                   # Preparation time
    tw = ti2 - tr2 / 2 - ti1 - tr1 / 2   # Wait time between GRE blocks
    td = tr - ti2 - tr2 / 2              # Recovery time
    m = n // 2                           # Middle of echo train

    if transmit is not None:
        fa1 = transmit * fa1
        fa2 = transmit * fa2
        del transmit
    fa1 = torch.as_tensor(fa1, **backend)
    fa2 = torch.as_tensor(fa2, **backend)

    # precompute exponential terms
    ex = r1.mul(-tx).exp()
    ep = r1.mul(-tp).exp()
    ew = r1.mul(-tw).exp()
    ed = r1.mul(-td).exp()
    e1 = r1.mul(-tr).exp()
    c1 = fa1.cos()
    c2 = fa2.cos()

    # steady state
    mss = (1 - ep) * (c1 * ex).pow(n)
    mss = mss + (1 - ex) * (1 - (c1 * ex).pow(n)) / (1 - c1 * ex)
    mss = mss * ew + (1 - ew)
    mss = mss * (c2 * ex).pow(n)
    mss = mss + (1 - ex) * (1 - (c2 * ex).pow(n)) / (1 - c2 * ex)
    mss = mss * ed + (1 - ed)
    mss = mss * pd / (1 + eff * (c1 * c2).pow(n) * e1)

    # IR components
    mi1 = -eff * mss * ep / pd + (1 - ep)
    mi1 = mi1 * (c1 * ex).pow(m - 1)
    mi1 = mi1 + (1 - ex) * (1 - (c1 * ex).pow(m - 1)) / (1 - c1 * ex)
    mi1 = mi1 * fa1.sin()
    mi1 = mi1.abs()

    mi2 = (mss / pd - (1 - ed)) / (ed * (c2 * ex).pow(m))
    mi2 = mi2 + (1 - ex) * (1 - (c2 * ex).pow(-m)) / (1 - c2 * ex)
    mi2 = mi2 * fa2.sin()
    mi2 = mi2.abs()

    if return_combined and not sigma:
        m = (mi1 * mi2) / (mi1.square() + mi2.square())
        m = torch.where(~torch.isfinite(m), m.new_zeros([]), m)
        return m

    # Common component (PD, B1-, R2*)
    if receive is not None:
        pd = pd * receive
        del receive
    mi1 = mi1 * pd
    mi2 = mi2 * pd
    if r2s is not None:
        e2 = r2s.mul(-te).exp_()
        mi1 = mi1 * e2
        mi2 = mi2 * e2
        del r2s

    # noise
    mi1 = add_noise(mi1, std=sigma, gfactor=gfactor)
    mi2 = add_noise(mi2, std=sigma, gfactor=gfactor)

    if return_combined:
        m = (mi1 * mi2) / (mi1.square() + mi2.square())
        m = torch.where(~torch.isfinite(m), m.new_zeros([]), m)
        return m
    else:
        mi1 = torch.where(~torch.isfinite(mi1), mi1.new_zeros([]), mi1)
        mi2 = torch.where(~torch.isfinite(mi2), mi2.new_zeros([]), mi2)
        return mi1, mi2
def mse(moving, fixed, lam=1, dim=None, grad=True, hess=True, mask=None):
    """Mean-squared error loss for optimisation-based registration.

    (A factor 1/2 is included, and the loss is averaged across voxels,
    but not across channels or batches)

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Gaussian noise precision (or IRLS weights)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image
    h : (..., K, *spatial) tensor, optional
        (Diagonal) Hessian with respect to the moving image
    """
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    nvox = py.prod(fixed.shape[-dim:])

    if moving.requires_grad:
        ll = moving - fixed
        if mask is not None:
            ll = ll.mul_(mask)
        ll = ll.square().mul_(lam).sum() / (2 * nvox)
    else:
        ll = moving - fixed
        if mask is not None:
            ll = ll.mul_(mask)
        ll = ll.square_().mul_(lam).sum() / (2 * nvox)

    out = [ll]
    if grad:
        g = moving - fixed
        if mask is not None:
            g = g.mul_(mask)
        g = g.mul_(lam).div_(nvox)
        out.append(g)
    if hess:
        h = lam / nvox
        if mask is not None:
            h = mask * h
        out.append(h)
    return tuple(out) if len(out) > 1 else out[0]
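
# Example (not part of the original source): the returned gradient is the
# derivative of the voxel-averaged 0.5 * lam * ||moving - fixed||^2, so it
# equals lam * (moving - fixed) / nvox.
def _example_mse_grad():
    import torch
    moving = torch.randn(1, 8, 8)
    fixed = torch.randn(1, 8, 8)
    ll, g, h = mse(moving, fixed, lam=2., dim=2)
    assert torch.allclose(g, 2. * (moving - fixed) / 64)
    return ll, g, h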
def __init__(self, x, y, z, c):
    x, y, z, c = utils.to_max_backend(x, y, z, c)
    self.x = x
    self.y = y
    self.z = z
    self.c = c
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: 5th percentile of each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: 95th percentile of each batch element.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.
        If float, the signal is taken to that power before being equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images.
        Intensities are scaled within [0, 1].
    """
    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    # rescale min/max
    min = py.make_list(min, len(images))
    max = py.make_list(max, len(images))
    min = [utils.quantile(image, 0.05, bins=2048, dim=[-1, -2], keepdim=True)
           if mn is None else torch.as_tensor(mn, **backend)[None, None]
           for image, mn in zip(images, min)]
    min, *othermin = min
    for mn in othermin:
        min = torch.min(min, mn)
    del othermin
    max = [utils.quantile(image, 0.95, bins=2048, dim=[-1, -2], keepdim=True)
           if mx is None else torch.as_tensor(mx, **backend)[None, None]
           for image, mx in zip(images, max)]
    max, *othermax = max
    for mx in othermax:
        max = torch.max(max, mx)
    del othermax
    images = [torch.max(torch.min(image, max), min) for image in images]
    # (x - min) / (max - min), using 1 / (1 - max/min) == -min / (max - min)
    images = [image.mul_(1 / (max - min + eps)).add_(1 / (1 - max / min))
              for image in images]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # reshape and concatenate
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        if eq >= 0:
            images = images.pow(eq)
        else:
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()
    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # rescale min/max
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True)

    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]
    return tuple(images) if len(images) > 1 else images[0]
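
# Example (not part of the original source): rescaling a batch of images to
# [0, 1] with the default 5th/95th percentile bounds and no equalization,
# assuming `utils.quantile` supports the `bins`/`dim`/`keepdim` arguments
# used above. Note that the rescaling is done in place on the input tensor.
def _example_intensity_preproc():
    import torch
    image = torch.rand(2, 64, 64) * 1000.
    out = intensity_preproc(image)
    assert out.shape == image.shape
    return out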
def __init__(self, axis, angle):
    axis, angle = utils.to_max_backend(axis, angle)
    self.ax = axis
    self.theta = angle
def cc(moving, fixed, dim=None, grad=True, hess=True, mask=None):
    """Squared Pearson's correlation coefficient loss

        1 - (E[(x - mu_x)'(y - mu_y)] / (s_x * s_y)) ** 2

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian

    Returns
    -------
    ll : () tensor
    """
    moving, fixed = utils.to_max_backend(moving, fixed)
    moving = moving.clone()
    fixed = fixed.clone()
    dim = dim or (fixed.dim() - 1)
    dims = list(range(-dim, 0))
    if mask is not None:
        mask = mask.to(fixed.device)
        mean = lambda x: (x * mask).sum(dim=dims, keepdim=True).div_(
            mask.sum(dim=dims, keepdim=True))
    else:
        mean = lambda x: x.mean(dim=dims, keepdim=True)

    n = py.prod(fixed.shape[-dim:])
    moving -= mean(moving)
    fixed -= mean(fixed)
    sigm = mean(moving.square()).sqrt_()
    sigf = mean(fixed.square()).sqrt_()
    moving = moving.div_(sigm)
    fixed = fixed.div_(sigf)

    corr = mean(moving * fixed)
    corr2 = 1 - corr.square()
    corr2.clamp_min_(1e-8)

    out = []
    if grad:
        g = 2 * corr * (moving * corr - fixed) / (n * sigm)
        g /= corr2  # chain rule for log
        if mask is not None:
            g = g.mul_(mask)
        out.append(g)

    if hess:
        # approximate hessian
        h = 2 * (corr / sigm).square() / n
        h /= corr2  # chain rule for log
        if mask is not None:
            h = h * mask
        out.append(h)

    # return stuff
    corr = corr2.log_().sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
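
# Example (not part of the original source): evaluating the loss, gradient
# and approximate Hessian between two random 2D images; the gradient has
# the same shape as the moving image.
def _example_cc():
    import torch
    moving = torch.randn(1, 16, 16)
    fixed = torch.randn(1, 16, 16)
    ll, g, h = cc(moving, fixed, dim=2)
    assert g.shape == moving.shape
    return ll, g, h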
def __init__(self, shift, scale, orientation='RAS'):
    shift, scale = utils.to_max_backend(shift, scale)
    self.shift = shift
    self.scale = scale
    self.orientation = orientation
def greens_apply(mom, greens, factor=1, voxel_size=1):
    """Apply the Greens function to a momentum field.

    Parameters
    ----------
    mom : (..., *spatial, dim) tensor
        Momentum
    greens : (*spatial, [dim, dim]) tensor
        Greens function
    factor : float, default=1
        Regularisation factor (the output is divided by it).
    voxel_size : [sequence of] float, default=1
        Voxel size.
        Only needed when no penalty is put on linear-elasticity.

    Returns
    -------
    vel : (..., *spatial, dim) tensor
        Velocity
    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    mom, greens = utils.to_max_backend(mom, greens)
    dim = mom.shape[-1]

    # fourier transform
    # (the `fft` helper abstracts away the pre/post torch-1.8 APIs)
    mom = fft.fftn(mom, dim=list(range(-dim - 1, -1)), real=True)

    # voxel-wise multiplication
    if greens.dim() == dim:
        # diagonal Greens function: one scalar per frequency
        voxel_size = utils.make_vector(voxel_size, dim, **utils.backend(mom))
        voxel_size = voxel_size.square().reciprocal()
        greens = greens.unsqueeze(-1)
        mom = fft.mul(mom, greens, real=(False, True))
        mom = fft.mul(mom, voxel_size, real=(False, True))
    else:
        # full Greens function: a (dim, dim) matrix per frequency
        mom = fft.mul(mom, greens, real=(False, True))

    # inverse fourier transform
    mom = fft.real(fft.ifftn(mom, dim=list(range(-dim - 1, -1))))

    mom /= factor
    return mom
def spgr(pd, r1, r2s=None, mt=None, transmit=None, receive=None,
         gfactor=None, te=0, tr=25e-3, fa=20, sigma=None, device=None):
    """Simulate data generated by a Spoiled Gradient-Echo (SPGR/FLASH)
    sequence.

    Tissue parameters
    -----------------
    pd : tensor_like
        Proton density
    r1 : tensor_like
        Longitudinal relaxation rate, in 1/sec
    r2s : tensor_like, optional
        Transverse relaxation rate, in 1/sec. Mandatory if any `te > 0`.
    mt : tensor_like, optional
        MTsat. Mandatory if any `mtpulse == True`.

    Fields
    ------
    transmit : tensor_like, optional
        Transmit B1 field
    receive : tensor_like, optional
        Receive B1 field
    gfactor : tensor_like, optional
        G-factor map.
        If provided and `sigma` is not `None`, the g-factor map is used
        to sample non-stationary noise.

    Sequence parameters
    -------------------
    te : float, default=0
        Echo time, in sec
    tr : float, default=25e-3
        Repetition time, in sec
    fa : float, default=20
        Flip angle, in deg

    Noise
    -----
    sigma : float, optional
        Standard-deviation of the sampled Rician noise
        (no sampling if `None`)

    Returns
    -------
    sim : tensor
        Simulated SPGR image
    """
    pd, r1, r2s, mt, transmit, receive, gfactor \
        = utils.to_max_backend(pd, r1, r2s, mt, transmit, receive, gfactor)
    pd, r1, r2s, mt, transmit, receive, gfactor \
        = utils.to(pd, r1, r2s, mt, transmit, receive, gfactor,
                   device=device)
    backend = utils.backend(pd)

    fa = fa * constants.pi / 180.
    if transmit is not None:
        fa = fa * transmit
        del transmit
    fa = torch.as_tensor(fa, **backend)

    if receive is not None:
        pd = pd * receive
        del receive

    pd = pd * fa.sin()
    fa = fa.cos()
    e1, r1 = r1.mul(tr).neg_().exp(), None
    signal = pd * (1 - e1)
    if mt is not None:
        omt = mt.neg().add_(1)
        signal *= omt
        signal /= (1 - fa * omt * e1)
        del omt
    else:
        signal /= (1 - fa * e1)
    if r2s is not None:
        e2, r2s = r2s.mul(te).neg_().exp(), None
        signal *= e2
        del e2

    # noise
    signal = add_noise(signal, std=sigma)
    return signal
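
# Example (not part of the original source): with te=0 and no MT pulse, the
# implementation above reduces to the Ernst equation
#     S = pd * sin(fa) * (1 - exp(-tr * r1)) / (1 - cos(fa) * exp(-tr * r1)),
# assuming `add_noise(..., std=None)` is a no-op as documented.
def _example_spgr_ernst():
    import math
    import torch
    sim = spgr(pd=torch.tensor(100.), r1=torch.tensor(1.), tr=25e-3, fa=20)
    e1 = math.exp(-25e-3)
    expected = 100. * math.sin(math.radians(20)) * (1 - e1) \
               / (1 - math.cos(math.radians(20)) * e1)
    assert torch.allclose(sim, torch.tensor(expected))
    return sim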
def lgmmh(moving, fixed, dim=None, bins=3, patch=7, stride=1,
          grad=True, hess=True, mode='g', max_iter=128,
          theta=None, return_theta=False):
    """Local GMM-based similarity, with optional gradient and Hessian."""
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]

    if not isinstance(patch, (list, tuple)):
        patch = [patch]
    patch = list(patch)
    if not isinstance(stride, (list, tuple)):
        stride = [stride]
    stride = [s or 0 for s in stride]
    fwd = Fwd(patch, stride, dim, mode)
    bwd = Bwd(patch, stride, dim, mode, shape)

    gmmfit = fit_lgmm2(moving, fixed, bins, max_iter, dim,
                       patch=patch, stride=stride, mode=mode, theta=theta)

    # drop unused variables
    get = gmmfit.get if return_theta else gmmfit.pop
    pop = gmmfit.pop
    z = pop('resp')
    moving_mean = get('xmean')
    fixed_mean = get('ymean')
    moving_var = get('xvar')
    fixed_var = get('yvar')
    corr = get('corr')
    prior = get('prior')
    out = [pop('nll')]

    nvox = py.prod(z.shape[-dim:])
    moving = moving.unsqueeze(-dim - 1)
    fixed = fixed.unsqueeze(-dim - 1)

    if grad or hess:
        # computed here so that the Hessian can be requested without the
        # gradient
        z0 = fwd(z, None).clamp_min_(1e-10)

    if grad:
        # gradient of the GMM entropy
        # L = 0.5 * \sum_k pi_k log|\Sigma_k| + cte
        @torch.jit.script
        def make_grad(bwd: Bwd, z, z0, moving, fixed,
                      moving_mean, fixed_mean, moving_var, fixed_var,
                      corr, prior) -> Tensor:
            cov = corr * (moving_var * fixed_var).sqrt()
            idet = moving_var * fixed_var * (1 - corr * corr)
            idet = prior / idet
            # gradient of determinant + chain rule of log
            g = moving * bwd(fixed_var * idet, z, z0) \
                - fixed * bwd(cov * idet, z, z0)
            g -= bwd((moving_mean * fixed_var - fixed_mean * cov) * idet,
                     z, z0)
            g = g.sum(-bwd.dim - 1)
            return g

        g = make_grad(bwd, z, z0, moving, fixed, moving_mean, fixed_mean,
                      moving_var, fixed_var, corr, prior)
        g.div_(nvox)
        out.append(g)

    if hess:
        # # hessian of (1 - corr^2)
        # imoving_var = moving_var.reciprocal()
        # corr2 = corr * corr
        # h = corr2 * imoving_var
        # # chain rule (with Fisher's scoring)
        # h /= 1 - corr2
        # # hessian of log(moving_var)
        # h += imoving_var
        # # weight by proportion and sum
        # h = h * (z * prior)
        # h = h.sum(-1)

        @torch.jit.script
        def make_hess(bwd: Bwd, z, z0, moving_var, corr, prior) -> Tensor:
            h = (1 - z0) * prior / (moving_var * (1 - corr * corr))
            h = bwd(h, z, z0)
            h = h.sum(-bwd.dim - 1)
            return h

        h = make_hess(bwd, z, z0, moving_var, corr, prior)
        h.div_(nvox)
        out.append(h)

    if return_theta:
        out.append(gmmfit)
    return out[0] if len(out) == 1 else tuple(out)
def dice_nolog(moving, fixed, dim=None, grad=True, hess=True, mask=None,
               add_background=False, weighted=False):
    """Dice loss for optimisation-based registration.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image of probabilities (post-softmax).
        The background class should be omitted.
    fixed : (..., K, *spatial) tensor
        Fixed image of probabilities.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : (..., *spatial) tensor, optional
        Mask of voxels to include in the loss (all by default)
    add_background : bool, default=False
        Include the Dice of the (implicit) background class in the loss.
    weighted : bool or tensor, default=False
        Weights for each class. If True, weight by positive rate.

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image.
    h : (..., K, *spatial) tensor, optional
        Hessian with respect to the moving image.
    """
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    nc = moving.shape[-dim - 1]                             # nb classes - bck
    fixed = utils.slice_tensor(fixed, slice(nc), -dim - 1)  # remove bkg class
    if mask is not None:
        mask = mask.to(moving.device)
        nvox = mask.sum(range(-dim - 1), keepdim=True)
    else:
        nvox = py.prod(fixed.shape[-dim:])

    @torch.jit.script
    def rescale(x, dim_channel: int, add_background: bool = False):
        """Ensure that a tensor is in [0, 1]"""
        x = x.clamp_min(0)
        x = x / x.sum(dim_channel, keepdim=True).clamp_min_(1)
        if add_background:
            # append the implicit background class
            # (`torch.cat`, not `torch.stack`: the background slice already
            # carries a channel dimension of size 1)
            x = torch.cat([x, 1 - x.sum(dim_channel, keepdim=True)],
                          dim_channel)
        return x

    moving = rescale(moving, -dim - 1, add_background)
    fixed = rescale(fixed, -dim - 1, add_background)
    if mask is not None:
        moving *= mask
        fixed *= mask

    if weighted is True:
        weighted = fixed.sum(list(range(-dim, 0)), keepdim=True).div_(nvox)
    elif weighted is not False:
        weighted = torch.as_tensor(weighted, **utils.backend(moving))
        for _ in range(dim):
            weighted = weighted.unsqueeze(-1)
    else:
        weighted = None

    @torch.jit.script
    def loss_components(moving, fixed, dim: int,
                        weighted: Optional[Tensor] = None):
        """Compute the (negative) DiceLoss, (positive) Dice and union"""
        dims = [d for d in range(-dim, 0)]
        overlap = (moving * fixed).sum(dims, keepdim=True)
        union = (moving + fixed).sum(dims, keepdim=True)
        union += 1e-5
        dice = 2 * overlap / union
        if weighted is not None:
            ll = 1 - weighted * dice
        else:
            ll = 1 - dice
        ll = ll.sum()
        return ll, dice, union

    ll, dice, union = loss_components(moving, fixed, dim, weighted)
    out = [ll]

    # gradient
    if grad:
        @torch.jit.script
        def do_grad(dice, fixed, union):
            return (dice - 2 * fixed) / union

        g = do_grad(dice, fixed, union)
        if weighted is not None:
            g *= weighted
        if add_background:
            g_last = utils.slice_tensor(g, slice(-1, None), -dim - 1)
            g = utils.slice_tensor(g, slice(-1), -dim - 1)
            g -= g_last
        if mask is not None:
            g *= mask
        out.append(g)

    # hessian
    if hess:
        @torch.jit.script
        def do_hess(dice, fixed, union, nvox, dim: int):
            dims = [d for d in range(-dim, 0)]
            positive_rate = fixed.sum(dims, keepdim=True) / nvox
            h = (dice - fixed - positive_rate).abs()
            h = 2 * nvox * h / union.square()
            return h

        nvox = torch.as_tensor(nvox, device=moving.device)
        h = do_hess(dice, fixed, union, nvox, dim)
        if weighted is not None:
            h *= weighted
        if add_background:
            h_foreground = utils.slice_tensor(h, slice(-1), -dim - 1)
            h = utils.slice_tensor(h, slice(-1, None), -dim - 1)
            # background Hessian, stored in compact symmetric form
            hshape = list(h.shape)
            hshape[-dim - 1] = nc * (nc + 1) // 2
            h = h.expand(hshape).clone()
            diag = utils.slice_tensor(h, range(nc), -dim - 1)
            diag += h_foreground
        if mask is not None:
            h *= mask
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
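
# Example (not part of the original source): the Dice loss of a one-hot
# probability map with itself. For binary maps, overlap equals half the
# union, so the Dice score is 1 and the loss is ~0 (up to the 1e-5
# stabiliser added to the union).
def _example_dice_self():
    import torch
    labels = torch.randint(0, 3, [16, 16])    # 3 classes incl. background
    prob = torch.stack([labels == 1, labels == 2]).float()  # one-hot, no bkg
    ll, g, h = dice_nolog(prob, prob.clone(), dim=2)
    assert ll.abs() < 1e-3
    return ll, g, h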