def smooth_iid():
    """Empirically check the variance of smoothed i.i.d. noise against theory."""
    shape = [128, 128]
    center = tuple([s // 2 for s in shape])
    nrep = 100
    dim = len(shape)
    sig = [0.1, 0.5, 1, 2, 4, 8, 16]  # [0, 0.01, 0.05, 0.1, 0.5, 1, 2, 4, 8]
    fwhm = [2.355 * s for s in sig]   # fwhm = 2*sqrt(2*log(2)) * sigma
    # theoretical curves: kernel peak and integral of the squared kernel
    nc = [(2 * pi * (s**2))**(-dim / 2) if s > 0 else 1 for s in sig]
    yb = [(4 * pi * (s**2))**(-dim / 2) if s > 0 else 1 for s in sig]

    dat = torch.randn([nrep, *shape])
    dat = smooth(dat, fwhm=1.5, basis=1, dim=2)
    var0 = dat.var(0)[center].item()  # reference variance at the center voxel

    varb0 = []
    varb1 = []
    varbd = []
    for f in fwhm:
        sdat = smooth(dat, fwhm=f, basis=0, dim=2)
        varb0.append(sdat.var(0)[center].item() / var0)
        sdat = smooth(dat, fwhm=f, basis=1, dim=2)
        varb1.append(sdat.var(0)[center].item() / var0)
        kernel = gauss_kernel(f, dim)
        sdat = conv(dim, dat, kernel, padding='auto', bound='dct2')
        varbd.append(sdat.var(0)[center].item() / var0)

    for fn in (plt.loglog, plt.semilogy, plt.plot):
        fn(sig, nc, 'k:')
        fn(sig, yb, 'k--')
        fn(sig, varb0, 'r-+')
        fn(sig, varb1, 'b-+')
        fn(sig, varbd, 'g-+')
        plt.xlabel('Smoothing sigma')
        plt.ylabel('Variance')
        plt.show()

    fn = plt.plot
    fn(sig[2:], nc[2:], 'k:')
    fn(sig[2:], yb[2:], 'k--')
    fn(sig[2:], varb0[2:], 'r-+')
    fn(sig[2:], varb1[2:], 'b-+')
    fn(sig[2:], varbd[2:], 'g-+')
    plt.xlabel('Smoothing sigma')
    plt.ylabel('Variance')
    plt.show()

    t = lambda v: v**(-1 / dim)
    varb0 = list(map(t, varb0))
    varb1 = list(map(t, varb1))
    varbd = list(map(t, varbd))
    fn = plt.plot
    fn(sig, sig, 'k--')
    fn(sig[2:], varb0[2:], 'r-+')
    fn(sig[2:], varb1[2:], 'b-+')
    fn(sig[2:], varbd[2:], 'g-+')
    plt.xlabel('Smoothing sigma')
    plt.ylabel('Variance')
    plt.show()
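# Companion Monte-Carlo check (hypothetical helper; assumes the same
# module-level `torch`, `pi` and `smooth` used by `smooth_iid`): unit
# white noise smoothed with a Gaussian of std `s` in `d` dimensions has
# pointwise variance equal to the integral of the squared kernel,
# (4*pi*s**2)**(-d/2) -- the `yb` curve plotted above.
def _check_smoothed_variance(s=2.0, d=2, nrep=500):
    dat = torch.randn([nrep, 128, 128])
    dat = smooth(dat, fwhm=2.355 * s, basis=0, dim=d)
    empirical = dat.var(0)[64, 64].item()
    theory = (4 * pi * s**2) ** (-d / 2)
    return empirical, theory  # agree up to Monte-Carlo/discretization error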
def forward(self, x, noise=None, return_resolution=False):
    """Simulate low resolution: smooth, add noise, downsample, upsample back."""
    if noise is not None:
        noise = noise.expand(x.shape)
    dim = x.dim() - 2
    backend = utils.backend(x)
    resolution_exp = utils.make_vector(self.resolution_exp, x.shape[1], **backend)
    resolution_scale = utils.make_vector(self.resolution_scale, x.shape[1], **backend)
    all_resolutions = []
    out = torch.empty_like(x)
    for b in range(len(x)):
        for c in range(x.shape[1]):
            resolution = self.resolution(resolution_exp[c],
                                         resolution_scale[c]).sample()
            resolution = resolution.clamp_min(1)
            fwhm = [resolution] * dim
            y = smooth(x[b, c], fwhm=fwhm, dim=dim, padding='same', bound='dct2')
            if noise is not None:
                y += noise[b, c]
            factor = [1 / resolution] * dim
            y = y[None, None]  # need batch and channel for resize
            y = resize(y, factor=factor, anchor='f')
            factor = [resolution] * dim
            all_resolutions.append(factor)
            y = resize(y, factor=factor, shape=x.shape[2:], anchor='f')
            out[b, c] = y[0, 0]
    all_resolutions = utils.as_tensor(all_resolutions, **backend)
    return (out, all_resolutions) if return_resolution else out
def backward(self, g, x, w=None, min=None, max=None):
    """
    Parameters
    ----------
    g : (..., *bins) tensor
    x : (..., n, 2) tensor
    w : (..., n) tensor, optional
    min : (...) tensor_like, optional
    max : (...) tensor_like, optional

    Returns
    -------
    g : (..., n, 2) tensor
    """
    backend = dict(dtype=x.dtype, device=x.device)
    n = x.shape[-2]
    xbatch = x.shape[:-2]
    if w is not None:
        _, w = torch.broadcast_tensors(x[..., 0], w)
        batch = w.shape[:-1]
        x = x.expand([*batch, *x.shape[-2:]])
        w = w.reshape([-1, n])
    else:
        batch = xbatch
    x = x.reshape([-1, n, 2])

    if min is None:
        min = x.min(-2, keepdim=True).values
    else:
        min = torch.as_tensor(min, **backend).expand([*xbatch, 2]).reshape([-1, 1, 2])
    if max is None:
        max = x.max(-2, keepdim=True).values
    else:
        max = torch.as_tensor(max, **backend).expand([*xbatch, 2]).reshape([-1, 1, 2])

    x = x.clone()
    bins = torch.as_tensor(self.bins, **backend)
    # map to bin coordinates; equivalent to (x - min) * bins / (max - min) - 0.5,
    # since bins / (1 - max/min) == -min * bins / (max - min)
    x = x.mul_(bins / (max - min)).add_(bins / (1 - max / min)).sub_(0.5)
    min = min.reshape([*xbatch, 2])
    max = max.reshape([*xbatch, 2])

    g = g.reshape([-1, *self.bins])
    # smooth backward
    if any(self.fwhm):
        g = smooth(g, fwhm=self.fwhm, bound=self.bound, dim=2)
    # push data into the histogram
    g = _jhistc_backward(g, x, w, self.order, self.bound, self.extrapolate)
    g = g.mul_(bins / (max - min))  # chain rule through the bin mapping

    # reshape
    g = g.reshape([*batch, n, 2])
    return g
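# The fused in-place bin mapping above is an algebraic rearrangement of the
# usual (x - min) * bins / (max - min) - 0.5, using the identity
# bins / (1 - max/min) == -min * bins / (max - min). A minimal stand-alone
# check (hypothetical helper; requires min != 0):
def _check_bin_mapping():
    import torch
    x = torch.rand(100) * 3 + 1                       # samples in [1, 4]
    mn, mx, bins = 1.0, 4.0, 64
    direct = (x - mn) * bins / (mx - mn) - 0.5
    fused = x * (bins / (mx - mn)) + bins / (1 - mx / mn) - 0.5
    assert torch.allclose(direct, fused, atol=1e-5)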
def forward(self, x, affine=None):
    """
    Parameters
    ----------
    x : (X, Y, Z) tensor or str
    affine : (4, 4) tensor, optional

    Returns
    -------
    seg : (32, oX, oY, oZ) tensor
        Segmentation
    resliced : (oX, oY, oZ) tensor
        Input resliced to 1 mm RAS
    affine : (4, 4) tensor
        Output orientation matrix
    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        if affine is None:
            affine = x.affine
        x = x.fdata()
        x = x.reshape(x.shape[:3])
    x = SynthPreproc.addnoise(x)
    if affine is not None:
        affine, x = spatial.affine_reorient(affine, x, 'RAS')
        vx = spatial.voxel_size(affine)
        fwhm = 0.25 * vx.reciprocal()
        fwhm[vx > 1] = 0
        x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
        x, affine = spatial.resize(x[None, None], vx.tolist(), affine=affine)
        x = x[0, 0]
    oshape = x.shape
    x, crop = SynthPreproc.crop(x)
    x = SynthPreproc.preproc(x)[None, None]
    if self.verbose:
        print('done.', flush=True)
        print('Segmenting... ', end='', flush=True)
    s, x = super().forward(x)[0], x[0, 0]
    if self.verbose:
        print('done.', flush=True)
        print('Postprocessing... ', end='', flush=True)
    s = self.relabel(s.argmax(0))
    x = SynthPreproc.pad(x, oshape, crop)
    s = SynthPreproc.pad(s, oshape, crop)
    if self.verbose:
        print('done.', flush=True)
    return s, x, affine
def forward(self, x):
    """Apply a randomly sampled Gaussian smoothing to each batch element."""
    dim = x.dim() - 2
    backend = dict(dtype=x.dtype, device=x.device)
    fwhm_exp = utils.make_vector(self.fwhm_exp, 1 if self.iso else dim, **backend)
    fwhm_scale = utils.make_vector(self.fwhm_scale, 1 if self.iso else dim, **backend)
    out = torch.as_tensor(x)  # note: shares memory with x, which is overwritten
    for b in range(len(x)):
        fwhm = self.fwhm(fwhm_exp, fwhm_scale).sample()
        fwhm = fwhm.clamp_min_(0).expand([dim]).clone()
        out[b] = smooth(x[b], fwhm=fwhm, dim=dim, padding='same', bound='dct2')
    return out
def _build_pyramid(self, dat, levels, method, dim, bound,
                   mask=None, preview=None):
    levels = list(levels)
    indexed_levels = list(enumerate(levels))
    indexed_levels.sort(key=lambda x: x[1])
    nb_levels = max(levels)
    if mask is not None:
        mask = mask.to(dat.device)
    dats = [dat] * levels.count(0)
    masks = [mask] * levels.count(0)
    previews = [preview] * levels.count(0)
    if mask is not None:
        mask = mask.to(dat.dtype)
    if preview is not None:
        preview = preview.to(dat.dtype)
    for level in range(1, nb_levels + 1):
        shape = dat.shape[-dim:]
        kernel_size = [min(2, s) for s in shape]
        if method[0] == 'g':  # gaussian pyramid
            # We assume the original data has a PSF of 1 input voxel.
            # We smooth by an additional 1-vx FWHM so that the data has a
            # PSF of 2 input voxels == 1 output voxel, then subsample.
            smooth = lambda x: spatial.smooth(
                x, fwhm=1, stride=2, dim=dim, bound=bound)
        elif method[0] == 'a':  # average window
            smooth = lambda x: spatial.pool(
                dim, x, kernel_size=kernel_size, stride=2, reduction='mean')
        elif method[0] == 'm':  # median window
            smooth = lambda x: spatial.pool(
                dim, x, kernel_size=kernel_size, stride=2, reduction='median')
        elif method[0] == 's':  # strides
            slicer = [slice(None, None, 2)] * dim
            smooth = lambda x: x[(Ellipsis, *slicer)]
        else:
            raise ValueError(method)
        dat = smooth(dat)
        if mask is not None:
            mask = smooth(mask)
        if preview is not None:
            preview = smooth(preview)
        dats += [dat] * levels.count(level)
        masks += [mask] * levels.count(level)
        previews += [preview] * levels.count(level)
    reordered_dats = [None] * len(levels)
    reordered_masks = [None] * len(levels)
    reordered_previews = [None] * len(levels)
    for (i, level), dat, mask, preview \
            in zip(indexed_levels, dats, masks, previews):
        reordered_dats[i] = dat
        reordered_masks[i] = mask
        reordered_previews[i] = preview
    return reordered_dats, reordered_masks, reordered_previews
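# Minimal stand-alone illustration of the 'a' (average-window) branch above,
# using torch.nn.functional instead of spatial.pool (assumption: the
# semantics match 2x2 mean pooling with stride 2):
def _demo_average_pyramid(levels=3):
    import torch
    import torch.nn.functional as F
    dat = torch.randn([1, 1, 32, 32])
    pyramid = [dat]
    for _ in range(levels - 1):
        pyramid.append(F.avg_pool2d(pyramid[-1], kernel_size=2, stride=2))
    return [tuple(p.shape[-2:]) for p in pyramid]  # (32,32) -> (16,16) -> (8,8)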
def resize(cls, x, affine, target_vx=1):
    target_vx = utils.make_vector(target_vx, x.dim(), **utils.backend(affine))
    vx = spatial.voxel_size(affine)
    factor = vx / target_vx
    # pre-smooth only along dimensions that are downsampled (factor < 1)
    fwhm = 0.25 * factor.reciprocal()
    fwhm[factor > 1] = 0
    x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
    x, affine = spatial.resize(x[None, None], factor.tolist(), affine=affine)
    x = x[0, 0]
    return x, affine
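# The anti-aliasing rule used by `resize` (and by the preprocessing in the
# segmentation `forward` above), in isolation: when downsampling by `factor`
# (< 1), pre-smooth with a FWHM of 0.25 / factor input voxels; when
# upsampling (factor > 1), no smoothing is needed. Hypothetical helper:
def _antialias_fwhm(factor):
    return 0.0 if factor > 1 else 0.25 / factor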
def forward(self, batch=1, **overload):
    """
    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    shape : sequence[int], optional
    channel : int, optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field
    """
    # get arguments
    shape = overload.get('shape', self.shape)
    channel = overload.get('channel', self.channel)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)
    backend = dict(dtype=dtype, device=device)
    nb_dim = len(shape)

    # sample if parameters are callable / move to device & dtype
    mean = utils.make_vector(self.mean, channel, **backend)
    amplitude = utils.make_vector(self.amplitude, channel, **backend)
    fwhm = utils.make_vector(self.fwhm, channel, **backend)

    # convert SE parameters to noise/kernel parameters
    sigma_se = fwhm / math.sqrt(8 * math.log(2))
    amplitude = amplitude * (2 * pi) ** (nb_dim / 4) * sigma_se.sqrt()
    fwhm = fwhm * math.sqrt(2)

    # smooth
    out = torch.empty([batch, channel, *shape], **backend)
    for b in range(batch):
        for c in range(channel):
            sample = torch.distributions.Normal(mean[c], amplitude[c]).sample(shape)
            out[b, c] = spatial.smooth(
                sample, 'gauss', fwhm[c],  # per-channel FWHM, like mean/amplitude
                basis=self.basis, bound='dct2', dim=nb_dim, padding='same')
    return out
def se_sample_smooth(shape, sigma, lam, mu=None, repeats=1, **backend):
    """Sample random fields with a squared exponential kernel.

    This function uses Gaussian smoothing of white noise.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        SE amplitude.
    lam : float
        SE length-scale.
    mu : () or (*shape) tensor_like
        SE mean
    repeats : int, default=1
        Number of sampled fields.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.
    """
    # Convert SE parameters to Gaussian parameters
    dim = len(shape)
    sigma = sigma * ((2 * constants.pi) ** (dim / 4)) * (lam ** (dim / 2))
    lam = lam / (2 ** 0.5)
    fwhm = lam * 2 * ((2 * math.log(2)) ** 0.5)

    # Sample white noise
    mul = 4
    pad = int(math.ceil(2 * fwhm))
    shape2 = [s * mul + 2 * pad for s in shape]
    field = torch.randn([repeats, *shape2], **backend)

    # Gaussian smoothing
    field = spatial.smooth(field, fwhm=fwhm * mul, dim=dim, basis=0)
    sub = (slice(None),) + (slice(pad + mul // 2, -pad, mul),) * dim
    field = field[sub]
    field *= sigma * (mul ** (dim / 2))

    # Add mean
    if mu is not None:
        mu = torch.as_tensor(mu, **backend)
        field += mu
    return field
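# Usage sketch for `se_sample_smooth` (hypothetical demo; values are
# illustrative). With enough repeats, the per-voxel standard deviation
# should approach `sigma`:
def _demo_se_sample_smooth():
    fields = se_sample_smooth([64, 64], sigma=1.0, lam=5.0, repeats=100)
    voxel_std = fields.std(0)[32, 32].item()  # ~ 1.0, up to sampling error
    return fields, voxel_std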
def slice_to(self, stack, cache_result=False, recompute=True):
    aff = self.exp(cache_result=cache_result, recompute=recompute)
    if recompute or not hasattr(self, '_sliced'):
        aff = spatial.affine_matmul(aff, self.affine)
        aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                               stack.layout)
        aff = spatial.affine_lmdiv(aff_reorient, aff)
        aff = spatial.affine_grid(aff, self.shape)
        sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                   extrapolate=self.extrapolate)
        fwhm = [0] * self.dim
        fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
        sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
        # NOTE: the original per-slice resampling below was left unfinished
        # (both calls were missing their second argument), so it is kept
        # disabled rather than guessed at:
        # slices = []
        # for stack_slice in stack.slices:
        #     aff = spatial.affine_matmul(stack.affine, ...)
        #     aff = spatial.affine_lmdiv(aff_reorient, ...)
    else:
        sliced = self._sliced
    if cache_result:
        self._sliced = sliced
    return sliced
def _cli(args):
    """Command-line interface for `smooth` without exception handling"""
    args = args or sys.argv[1:]
    options = parser(args)
    if options.help:
        print(help)
        return

    fwhm = options.fwhm
    unit = 'mm'
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for fname, ofname in zip(options.files, options.output):
        f = io.map(fname)
        vx = voxel_size(f.affine).tolist()
        dim = len(vx)
        if unit == 'mm':
            # convert mm to voxels using the file's voxel size
            fwhm1 = [f / v for f, v in zip(fwhm, vx)]
        else:
            fwhm1 = fwhm[:len(vx)]
        dat = f.fdata()
        dat = movedim_front2back(dat, dim)
        dat = smooth(dat, type=options.method, fwhm=fwhm1,
                     basis=options.basis, bound=options.padding, dim=dim)
        dat = movedim_back2front(dat, dim)
        folder, base, ext = fileparts(fname)
        ofname = ofname.format(dir=folder or '.', base=base, ext=ext,
                               sep=os.path.sep)
        io.savef(dat, ofname, like=f)
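# The mm -> voxel conversion above, in isolation (hypothetical helper, not
# part of the CLI): an 8 mm FWHM kernel on a 2 mm isotropic image is a
# 4-voxel FWHM kernel.
def _fwhm_mm_to_vox(fwhm_mm=(8., 8., 8.), vx=(2., 2., 2.)):
    return [f / v for f, v in zip(fwhm_mm, vx)]  # -> [4.0, 4.0, 4.0]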
def forward(self, x):
    # foreground mask (probabilistic or hard labels)
    if x.dtype.is_floating_point:
        fwd = 1 - x[:, :1]
    else:
        fwd = (x == 0).bitwise_not_().float()
    # dilate the foreground by smoothing + thresholding
    fwd = spatial.smooth(fwd, fwhm=self.kernel, dim=x.dim() - 2,
                         padding='same', bound='replicate')
    fwd = fwd > 1e-3
    if x.dtype.is_floating_point:
        # background voxels near the foreground become the "bag" class
        bag = x[:, :1] * fwd
        bg = x[:, :1] * fwd.bitwise_not_()
        # concatenate along the channel dimension (x[:, 1:], not x[1:])
        return torch.cat([bg, x[:, 1:], bag], dim=1)
    else:
        x = x.clone()
        bag = (x == 0).bitwise_and_(fwd)
        label = self.label or (x.max() + 1)
        x[bag] = label
        return x
def forward(self, x, min=None, max=None, mask=None):
    """
    Parameters
    ----------
    x : (..., N, 2) tensor
        Input multivariate vector
    min : (..., 2) tensor, optional
    max : (..., 2) tensor, optional
    mask : (..., N) tensor, optional

    Returns
    -------
    h : (..., B, B) tensor
        Joint histogram
    min : (..., 2) tensor
        Minimum value used per batch
    max : (..., 2) tensor
        Maximum value used per batch
    """
    shape = x.shape
    x, min, max = self._prepare(x, min, max)

    # push data into the histogram
    # hidden feature: tell pullpush to use +/- 0.5 tolerance when
    # deciding if a coordinate is inbounds.
    extrapolate = self.extrapolate or 2
    if mask is None:
        h = spatial.grid_count(x[:, None], self.n, self.order,
                               self.bound, extrapolate)[:, 0]
    else:
        mask = mask.to(x.device, x.dtype)
        h = spatial.grid_push(mask, x[:, None], self.n, self.order,
                              self.bound, extrapolate)[:, 0]
    h = h.to(x.dtype)
    h = h.reshape([*shape[:-2], *h.shape[-2:]])
    if self.fwhm:
        h = spatial.smooth(h, fwhm=self.fwhm, bound=self.bound, dim=2)
    return h, min, max
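# Plain-torch reference for the hard-binned case (order=0, no smoothing) of
# the joint histogram above (hypothetical helper; assumes a torch version
# that provides torch.histogramdd, CPU only):
def _reference_joint_hist(x, bins=32):
    import torch
    # x: (N, 2) samples; returns a (bins, bins) count histogram
    hist, _ = torch.histogramdd(x, bins=[bins, bins])
    return hist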
def backward2(self, h, x, w=None, min=None, max=None):
    """
    Parameters
    ----------
    h : (..., *bins, [*bins]) tensor
    x : (..., n, 2) tensor
    w : (..., n) tensor, optional
    min : (...) tensor_like, optional
    max : (...) tensor_like, optional

    Returns
    -------
    h : (..., n, 2) tensor
    """
    backend = dict(dtype=x.dtype, device=x.device)
    n = x.shape[-2]
    xbatch = x.shape[:-2]
    if w is not None:
        _, w = torch.broadcast_tensors(x[..., 0], w)
        batch = w.shape[:-1]
        x = x.expand([*batch, *x.shape[-2:]])
        w = w.reshape([-1, n])
    else:
        batch = xbatch
    x = x.reshape([-1, n, 2])

    if h.shape[:-2] == batch:
        is_diag = True
    elif h.shape[:-4] == batch:
        is_diag = False
    else:
        raise ValueError('Don\'t know what to do with that shape')

    if min is None:
        min = x.min(-2, keepdim=True).values
    else:
        min = torch.as_tensor(min, **backend).expand([*xbatch, 2]).reshape([-1, 1, 2])
    if max is None:
        max = x.max(-2, keepdim=True).values
    else:
        max = torch.as_tensor(max, **backend).expand([*xbatch, 2]).reshape([-1, 1, 2])

    x = x.clone()
    bins = torch.as_tensor(self.bins, **backend)
    # same affine bin mapping as in `backward`
    x = x.mul_(bins / (max - min)).add_(bins / (1 - max / min)).sub_(0.5)
    min = min.reshape([*xbatch, 2])
    max = max.reshape([*xbatch, 2])

    if is_diag:
        h = h.reshape([-1, *self.bins])
    else:
        h = h.reshape([-1, *self.bins, *self.bins])

    # smooth backward
    if any(self.fwhm):
        ker = kernels.smooth(fwhm=self.fwhm)
        if is_diag:
            ker = [k.square_() for k in ker]
            h = smooth(h, kernel=ker, bound=self.bound, dim=2)
        else:
            h = smooth(h, kernel=ker, bound=self.bound, dim=2)
            h = h.transpose(-4, -2).transpose(-3, -1)
            h = smooth(h, kernel=ker, bound=self.bound, dim=2)
            h = h.transpose(-4, -2).transpose(-3, -1)

    # push data into the histogram
    h = _jhistc_backward2(h, x, w, self.order, self.bound, self.extrapolate)
    h = h.mul_((bins / (max - min)).square_())

    # reshape
    h = h.reshape([*batch, n, 2])
    return h
def emmi_soft(moving, fixed, dim=None, prior=None, fwhm=None, max_iter=32,
              weights=None, grad=True, hess=True, return_prior=False):
    # ------------------------------------------------------------------
    #                           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)
    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #                             EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll

        # --------------------------------------------------------------
        # E-step
        # --------------------------------------------------------------
        z = sample_prior(prior.log(), fixed, z)
        z += moving.log()
        z, ll = math.softmax_lse(z, -2, lse=True, weights=weights)

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        z *= weights
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)
        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior, dim=1, basis=0, fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)
        # MI-like normalization
        prior /= prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2, keepdim=True)

        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N Σ_{j,k} prior0[j,k] * log(prior[j,k])
    #    = Σ_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y))
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #      - Σ_{x,y} p(x,y) log p(x)
    #      - Σ_{x,y} p(x,y) log p(y)
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #      - Σ_{x} p(x) log p(x)
    #      - Σ_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    ll = -(prior0 * prior.log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #                            GRADIENTS
    # ------------------------------------------------------------------
    # Keeping only terms that depend on y, the mutual information is
    # H[y] - H[x,y], and the objective is
    # > ll = Σ_n log p(x[n] == j[n]; H)
    #      = Σ_n Σ_j q(x[n] == j) log Σ_k p(x[n] == j | z[n] == k; H) p(z[n] == k)
    if grad or hess:
        norm = linalg.dot(
            prior.transpose(-1, -2).unsqueeze(-1),
            moving.transpose(-1, -2).unsqueeze(-3))
        norm = norm.add_(tiny).reciprocal_()
        g = sample_prior(prior, fixed * norm)
        if hess:
            norm = norm.square_().mul_(fixed).unsqueeze(-1)
            h = moving.new_zeros([*g.shape[:-2], K * (K + 1) // 2, N])
            for j in range(J):
                h[..., :K, :] += prior[..., j, :K, None].square() * norm
                c = K
                for k in range(K):
                    for kk in range(k + 1, K):
                        h[..., c, :] += (prior[..., j, k, None] *
                                         prior[..., j, kk, None] * norm)
                        c += 1
        if grad:
            g *= weights
            g.neg_()
            g = g.reshape([*g.shape[:-1], *shape])
            out.append(g)
        if hess:
            h *= weights
            h = h.reshape([*h.shape[:-1], *shape])
            out.append(h)

    if return_prior:
        out.append(prior)
    return out[0] if len(out) == 1 else tuple(out)
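# Stand-alone restatement of the MI identity spelled out in the comments
# above (hypothetical helper): for a joint table P, MI is the sum of
# P * log(P / (Px * Py)). Useful as a sanity check for the EM output.
def _mi_from_joint(P):
    import torch
    P = torch.as_tensor(P, dtype=torch.double).clamp_min(1e-16)
    P = P / P.sum()
    Px = P.sum(-1, keepdim=True)   # marginal over the second variable
    Py = P.sum(-2, keepdim=True)   # marginal over the first variable
    return (P * (P / (Px * Py)).log()).sum()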
def emmi_hard(moving, fixed, dim=None, prior=None, fwhm=None, max_iter=32,
              weights=None, grad=True, hess=True, return_prior=False):
    # ------------------------------------------------------------------
    #                           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)
    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #                             EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll

        # --------------------------------------------------------------
        # E-step
        # ------
        # estimate responsibilities of each moving cluster for each
        # fixed voxel using Bayes' rule:
        #   p(z[n] == k | x[n] == j[n]) ∝ p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #
        # . j[n] is the discretized fixed image
        # . p(z[n] == k) is the moving template
        # . p(x[n] == j[n] | z[n] == k) is the conditional prior
        #   evaluated at (j[n], k)
        # --------------------------------------------------------------
        z = sample_prior(prior, fixed, z)
        z *= moving

        # --------------------------------------------------------------
        # compute log-likelihood (log_sum of the posterior)
        #   ll = Σ_n log p(x[n] == j[n])
        #      = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #      = Σ_n log Σ_k p(z[n] == k | x[n] == j) + constant{\z}
        # --------------------------------------------------------------
        ll = z.sum(-2, keepdim=True)
        ll = add_tiny_(ll, Nb)
        z /= ll
        ll = ll.log_().mul_(weights).sum([-1, -2], dtype=torch.double)
        z *= weights

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)
        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior, dim=1, basis=0, fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)
            # prior /= prior.sum(dim=[-1, -2], keepdim=True)
        # MI-like normalization
        prior /= add_tiny_(
            prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2, keepdim=True), Nb)

        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N Σ_{j,k} prior0[j,k] * log(prior[j,k])
    #    = Σ_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y))
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #      - Σ_{x,y} p(x,y) log p(x)
    #      - Σ_{x,y} p(x,y) log p(y)
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #      - Σ_{x} p(x) log p(x)
    #      - Σ_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    ll = -(prior0 * add_tiny_(prior, Nb).log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #                            GRADIENTS
    # ------------------------------------------------------------------
    # Keeping only terms that depend on y, the mutual information is
    # H[y] - H[x,y], and the objective is
    # > ll = Σ_n log p(x[n] == j[n]; H)
    #      = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k; H) p(z[n] == k)
    if grad or hess:
        g = sample_prior(prior, fixed)
        norm = linalg.dot(g.transpose(-1, -2), moving.transpose(-1, -2))
        norm = add_tiny_(norm, Nb).unsqueeze(-2).reciprocal_()
        g *= norm
        if hess:
            h = sym_outer(g, -2)
        if grad:
            g *= weights
            g /= -Nm  # g.neg_()
            g = g.reshape([*g.shape[:-1], *shape])
            out.append(g)
        if hess:
            h *= weights
            h /= Nm
            h = h.reshape([*h.shape[:-1], *shape])
            out.append(h)

    if return_prior:
        out.append(prior)
    return out[0] if len(out) == 1 else tuple(out)
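# Toy, self-contained illustration of the E-step used by `emmi_hard`
# (hypothetical helper; dense indexing stands in for `sample_prior`):
# given a conditional table prior[j, k] = p(x == j | z == k) and a template
# moving[k, n], responsibilities are prior[j[n], :] * moving[:, n],
# normalized over k.
def _toy_estep():
    import torch
    J, K, N = 4, 3, 10
    prior = torch.rand([J, K])
    moving = torch.softmax(torch.rand([K, N]), 0)
    fixed = torch.randint(J, [N])
    z = prior[fixed].T * moving        # (K, N): p(x=j[n]|z=k) p(z=k)
    z = z / z.sum(0, keepdim=True)     # normalize over clusters k
    return z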
def forward(self, x): """ Parameters ---------- x : (batch, 1 or classes[-1], *shape) tensor Labels or probabilities Returns ------- x : (batch, channel, *shape) tensor """ batch, _, *shape = x.shape device = x.device dtype = x.dtype if not dtype.is_floating_point: dtype = self.dtype backend = dict(dtype=dtype, device=device) means = torch.as_tensor(self.means, **backend) scales = torch.as_tensor(self.scales, **backend) nb_classes = means.shape[-1] if means.dim() == 2: channel = len(means) elif scales.dim() == 2: channel = len(scales) else: channel = 1 means = means.expand([channel, nb_classes]).clone() scales = scales.expand([channel, nb_classes]).clone() fwhm = utils.make_vector(self.fwhm, nb_classes, **backend) implicit = x.shape[1] < nb_classes out = torch.zeros([batch, channel, *shape], **backend) for k in range(nb_classes): sampler = _get_dist('normal')(means[:, k], scales[:, k]) if x.dtype.is_floating_point: y1 = sampler.sample([batch, *shape]) y1 = utils.movedim(y1, -1, 1) for c, f in enumerate(fwhm): if f > 0: y1[:, c] = spatial.smooth(y1[:, c], fwhm=f, dim=len(shape), padding='same', bound='dct2') if not implicit: x1 = x[:, k, None] elif k > 0: x1 = x[:, k - 1, None] else: x1 = x.sum(1, keepdim=True).neg_().add_(1) out.addcmul(y1, x1) else: mask = (x.squeeze(1) == k) y1 = sampler.sample([mask.sum().long()]) out = utils.movedim(out, 1, -1) out[mask, :] = y1 out = utils.movedim(out, -1, 1) return out
def backward(self, x, g, min=None, max=None, hess=False, mask=None):
    """
    Parameters
    ----------
    x : (..., N, 2) tensor
        Input multidimensional vector
    g : (..., B, B) tensor
        Gradient with respect to the histogram
    min : (..., 2) tensor, optional
    max : (..., 2) tensor, optional
    hess : bool, default=False
        Compute an absolute-value (Hessian-like) term instead of the gradient
    mask : (..., N) tensor, optional

    Returns
    -------
    g : (..., N, 2) tensor
        Gradient with respect to x
    """
    if self.fwhm:
        g = spatial.smooth(g, fwhm=self.fwhm, bound=self.bound, dim=2)

    shape = x.shape
    x, min, max = self._prepare(x, min, max)
    nvox = x.shape[-2]
    min = min.unsqueeze(-2)
    max = max.unsqueeze(-2)
    g = g.reshape([-1, *g.shape[-2:]])
    extrapolate = self.extrapolate or 2

    if not hess:
        g = spatial.grid_grad(g[:, None], x[:, None], self.order,
                              self.bound, extrapolate)
        g = g[:, 0].reshape(shape)
    else:
        # 1) Absolute value of adjoint of gradient
        # we want shapes
        #   o : [batch=1, channel=1, spatial=[1, vox], dim=2]
        #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
        #   x : [batch=1, spatial=[1, vox], dim=2]
        #   -> [batch=1, channel=1, spatial=[B(mov), B(fix)]]
        order = _spatial.inter_to_nitorch([self.order], True)
        bound = _spatial.bound_to_nitorch([self.bound], True)
        o = torch.ones_like(x)
        g.requires_grad_()  # triggers push
        o, = _spatial.grid_grad_backward(o[:, None, None], g[:, None],
                                         x[:, None], bound, order, extrapolate)
        g.requires_grad_(False)
        g *= o[:, 0]
        # 2) Absolute value of gradient
        #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
        #   x : [batch=1, spatial=[1, vox], dim=2]
        #   -> [batch=1, channel=1, spatial=[1, vox], 2]
        g = _spatial.grid_grad(g[:, None], x[:, None], bound, order, extrapolate)
        g = g.reshape(shape)

    # adjoint of the affine bin-mapping function
    nn = torch.as_tensor(self.n, dtype=x.dtype, device=x.device)
    factor = nn / (max - min)
    if hess:
        factor = factor.square_()
    g = g.mul_(factor)
    if mask is not None:
        g *= mask[..., None]
    return g
def _make_image(option, dim=None, device=None):
    """Load an image and build a Gaussian pyramid (if required)

    Returns: ImagePyramid
    """
    dat, mask, affine = _load_image(option.files, dim=dim, device=device,
                                    label=option.label)
    dim = dat.dim() - 1
    if option.mask:
        mask1 = mask
        mask, _, _ = _load_image([option.mask], dim=dim, device=device,
                                 label=option.label)
        if mask.shape[-dim:] != dat.shape[-dim:]:
            raise ValueError('Mask should have the same shape as the image. '
                             f'Got {mask.shape[-dim:]} and {dat.shape[-dim:]}')
        if mask1 is not None:
            mask = mask * mask1
        del mask1
    if option.world:
        # overwrite orientation matrix
        affine = io.transforms.map(option.world).fdata().squeeze()
    for transform in (option.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        affine = spatial.affine_lmdiv(transform, affine)
    if not option.discretize and any(option.rescale):
        dat = _rescale_image(dat, option.rescale)
    if option.pad:
        pad = option.pad
        if isinstance(pad[-1], str):
            *pad, unit = pad
        else:
            unit = 'vox'
        if unit == 'mm':
            voxel_size = spatial.voxel_size(affine)
            pad = torch.as_tensor(pad, **utils.backend(voxel_size))
            pad = pad / voxel_size
            pad = pad.floor().int().tolist()
        else:
            pad = [int(p) for p in pad]
        pad = py.make_list(pad, dim)
        if any(pad):
            affine, _ = spatial.affine_pad(affine, dat.shape[-dim:], pad,
                                           side='both')
            dat = utils.pad(dat, pad, side='both', mode=option.bound)
            if mask is not None:
                mask = utils.pad(mask, pad, side='both', mode=option.bound)
    if option.fwhm:
        fwhm = option.fwhm
        if isinstance(fwhm[-1], str):
            *fwhm, unit = fwhm
        else:
            unit = 'vox'
        if unit == 'mm':
            voxel_size = spatial.voxel_size(affine)
            fwhm = torch.as_tensor(fwhm, **utils.backend(voxel_size))
            fwhm = fwhm / voxel_size
        dat = spatial.smooth(dat, dim=dim, fwhm=fwhm, bound=option.bound)
    image = objects.ImagePyramid(dat, levels=option.pyramid, affine=affine,
                                 dim=dim, bound=option.bound, mask=mask,
                                 extrapolate=option.extrapolate,
                                 method=option.pyramid_method)
    if getattr(option, 'soft_quantize', False) and len(image[0].dat) == 1:
        for level in image:
            level.preview = level.dat
            level.dat = _soft_quantize_image(level.dat, option.soft_quantize)
    elif not option.label and option.discretize:
        for level in image:
            level.preview = level.dat
            level.dat = _discretize_image(level.dat, option.discretize)
    return image
def em_prior(moving, fixed, prior=None, weights=None, fwhm=None,
             max_iter=32, tolerance=1e-5, verbose=0):
    """Estimate H_jk = P[x == j, z == k] by Expectation Maximization

    The objective function is
        Π_n p(x ; H, mu) == Π_n Σ_k p(x | z == k; H) p(z == k; mu)

    Parameters
    ----------
    moving : (B, K, *spatial) tensor
    fixed : (B, J|1, *spatial) tensor
    prior : (B, J, K) tensor, optional
    weights : (B, 1, *spatial) tensor, optional
    fwhm : float, optional
    max_iter : int, default=32
    tolerance : float, default=1e-5
        Tolerance on the log-likelihood gain, for early stopping
    verbose : int, default=0

    Returns
    -------
    mi : (B,) tensor
        Mutual information
    N : (B,) tensor
        Total observation weight
    prior : (B, J, K) tensor
        Regularized joint histogram P[X,Z] / (P[X]*P[Z])
    """
    # ------------------------------------------------------------------
    #                           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16  # assumption: same constant as in emmi_soft/emmi_hard
    dim = fixed.dim() - 2
    moving, fixed, weights = spatial_prepare(moving, fixed, weights)
    B, K, J, shape = spatial_shapes(moving, fixed, prior)
    N = shape.numel()
    Nm = spatial_sum(weights, dim).squeeze(-1)

    # Flatten
    moving = moving.reshape([*moving.shape[:2], -1])
    fixed = fixed.reshape([*fixed.shape[:2], -1])
    weights = weights.reshape([*weights.shape[:2], -1])

    if fwhm is None:
        fwhm = J / 64

    # ------------------------------------------------------------------
    # initialize normalized histogram
    # `prior` contains the "normalized" joint histogram p[x,y] / (p[x] p[y]).
    # However, it is used as the conditional histogram p[x|y] during
    # the E-step. This is because the normalized histogram is equal to
    # the conditional histogram up to a constant term that does not
    # depend on x, and therefore disappears after normalization of the
    # responsibilities.
    # ------------------------------------------------------------------
    if prior is None:
        prior = moving.new_ones([B, J, K])
        prior /= prior.sum(dim=[-1, -2], keepdim=True)
        prior /= prior.sum(dim=-2, keepdim=True) * prior.sum(dim=-1, keepdim=True)
    else:
        prior = prior.clone()
        # prior /= prior.sum(dim=-2, keepdim=True)  # conditional X | Z
        # prior /= prior.sum(dim=-2, keepdim=True) * prior.sum(dim=-1, keepdim=True)

    # ------------------------------------------------------------------
    #                        INFER PRIOR (EM)
    # ------------------------------------------------------------------
    ll_prev = -float('inf')
    z = moving.new_zeros([B, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):

        # --------------------------------------------------------------
        # E-step
        # ------
        # estimate responsibilities of each moving cluster for each
        # fixed voxel using Bayes' rule:
        #   p(z[n] == k | x[n] == j[n]) ∝ p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #
        # . j[n] is the discretized fixed image
        # . p(z[n] == k) is the moving template
        # . p(x[n] == j[n] | z[n] == k) is the conditional prior
        #   evaluated at (j[n], k)
        # --------------------------------------------------------------
        sample_prior(prior, fixed, z)
        z *= moving

        # --------------------------------------------------------------
        # compute log-likelihood (log_sum of the posterior)
        #   ll = Σ_n log p(x[n] == j[n])
        #      = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #      = Σ_n log Σ_k p(z[n] == k | x[n] == j) + constant{\z}
        # --------------------------------------------------------------
        ll = z.sum(-2, keepdim=True) + tiny
        z /= ll
        ll = ll.log_().mul_(weights).sum([-1, -2], dtype=torch.double)
        z *= weights

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) delta(x[n] == j)
        # --------------------------------------------------------------
        scatter_prior(prior0, fixed, z)  # prior[fixed] <- z
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= prior.sum(dim=[-1, -2], keepdim=True).add_(tiny)
        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior, dim=1, basis=0, fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)
        # prior /= prior.sum(dim=-2, keepdim=True)
        prior /= prior.sum(dim=-2, keepdim=True) * prior.sum(dim=-1, keepdim=True)

        if verbose > 0:
            success = ll.sum() > ll_prev
            happy = ':D' if success else ':('
            end = '\n' if verbose > 1 else '\r'
            print(f'(em) | {n_iter:02d} | {(ll/Nm).mean():12.6g} | {happy}',
                  end=end)
        if ll.sum() - ll_prev < tolerance * Nm.sum():
            break
        ll_prev = ll.sum()
    # if verbose == 1:
    #     print('')

    # NOTE:
    # We could return `joint_prior / (prior_x * prior_z)` instead of
    # `conditional_prior = joint_prior / prior_z` as we currently do, which
    # would correspond to optimizing the mutual information instead of
    # the conditional likelihood. But the additional term only depends on
    # the fixed image, so it does not have an impact for registration.
    #
    # Both have the same computational cost, and MI might have a slightly
    # nicer range, so we could do that eventually.

    mi = (prior0 * prior.log()).sum()
    return mi, Nm, prior
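# The M-step normalization above, in isolation (hypothetical toy): turn
# pseudo-counts into a joint distribution, then divide by the product of
# the marginals -- the "MI-like" normalization that makes `prior` contain
# p[x,z] / (p[x] p[z]).
def _toy_prior_normalization():
    import torch
    counts = torch.rand([8, 3])      # (J, K) pseudo-counts
    prior = counts / counts.sum()    # joint distribution p[x, z]
    prior = prior / (prior.sum(-2, keepdim=True) * prior.sum(-1, keepdim=True))
    return prior                     # p[x,z] / (p[x] p[z])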