def geodesic_dist(x, w, conn=1, nb_iter=1, dim=None):
    """Geodesic distance to a label

    The distance map starts at zero inside the label and infinity
    elsewhere. At each sweep, every interior voxel looks at its
    neighbours and keeps the smallest candidate
    ``abs(w[neighbour] - w[centre]) * (d[neighbour] + step)``, where
    ``step`` is the Euclidean length of the offset to that neighbour.

    Parameters
    ----------
    x : (..., *spatial) tensor
        Label map. Non-boolean inputs are binarized with ``x > 0``.
    w : (..., *spatial) tensor
        Weight map whose differences between neighbours scale the
        propagated distance.
    conn : int, default=1
        Connectivity order: offsets that move along more than `conn`
        axes are skipped.
    nb_iter : int, default=1
        Number of propagation sweeps.
    dim : int, default=`x.dim()`
        Number of spatial dimensions.

    Returns
    -------
    y : (..., *spatial) tensor
        Distance map. Voxels that are still at infinity after the last
        sweep are set to the largest finite distance.

    """
    if x.dtype is not torch.bool:
        x = x > 0
    x = x.to(torch.uint8)
    dim = dim or x.dim()

    # zero inside the label, infinite everywhere else
    d = torch.full(x.shape, float('inf'), **utils.backend(w))
    d[x > 0] = 0

    # Windows are extracted without padding, so only voxels that have a
    # complete 3^dim neighbourhood (the `crop` region) get updated.
    crop = (Ellipsis, *([slice(1, -1)] * dim))
    # NOTE(review): presumably `utils.unfold` returns a view of `d`, so that
    # updates written to `d` are seen through `dcrop` -- confirm.
    dcrop = utils.unfold(d, [3] * dim, stride=1)
    w = utils.unfold(w, [3] * dim, stride=1)
    w0 = w[(Ellipsis, *([1] * dim))]  # weight at the centre of each window

    # The neighbour offsets do not depend on the sweep: build them once,
    # in the same order in which the nested loops used to visit them.
    neighbours = []
    for coord in itertools.product([0, 1], repeat=dim):
        if sum(coord) == 0 or sum(coord) > conn:
            continue
        mini_dist = sum(c * c for c in coord) ** 0.5
        seen = set()
        for sgn in itertools.product([-1, 1], repeat=dim):
            coord1 = tuple(1 + c * s for c, s in zip(coord, sgn))
            if coord1 in seen:
                # flipping the sign of a zero component gives a duplicate
                continue
            seen.add(coord1)
            neighbours.append(((Ellipsis, *coord1), mini_dist))

    for _ in range(nb_iter):
        for index, mini_dist in neighbours:
            new_dist = (w[index] - w0).abs() * (dcrop[index] + mini_dist)
            # inf * 0 yields nan: map every non-finite candidate to inf
            new_dist.masked_fill_(torch.isfinite(new_dist).bitwise_not_(),
                                  float('inf'))
            msk = new_dist < d[crop]
            d[crop][msk] = new_dist[msk]

    # replace unreached voxels by the largest finite distance
    msk = torch.isfinite(d).bitwise_not_()
    d[msk] = d[~msk].max()
    return d
def forward(self, q, k, v, **overload):
    """Patch-wise attention between queries and windows of keys/values.

    Parameters
    ----------
    q : (b, c, *spatial)
        Queries
    k : (b, c, *spatial)
        Keys
    v : (b, c, *spatial)
        Values
    overload : dict
        `kernel_size`, `stride`, `padding` and `padding_mode` can be
        overridden at call time.

    Returns
    -------
    x : (b, c, *spatial)

    """
    kernel_size = overload.pop('kernel_size', self.kernel_size)
    # NOTE(review): the default stride is read from `self.kernel_size`,
    # not from a `self.stride` attribute -- confirm this is intended.
    stride = overload.pop('stride', self.kernel_size)
    padding = overload.pop('padding', self.padding)
    padding_mode = overload.pop('padding_mode', self.padding_mode)
    dim = q.dim() - 2

    def pad(tensor):
        # keys and values are padded the same way; queries are untouched
        if padding == 'auto':
            return spatial.pad_same(dim, tensor, kernel_size,
                                    bound=padding_mode)
        if padding:
            amount = [0] * 2 + py.make_list(padding, dim)
            return utils.pad(tensor, amount, side='both', mode=padding_mode)
        return tensor

    k = pad(k)
    v = pad(v)
    kernel = py.make_list(kernel_size, dim)

    def windows(tensor):
        # extract windows and flatten the window dimensions into one
        tensor = utils.unfold(tensor, kernel, stride)
        return tensor.reshape([*tensor.shape[:dim + 2], -1])

    # attention weights: softmax of the query/key dot product over channels
    keys = utils.movedim(windows(k), 1, -1)
    queries = utils.movedim(q[..., None], 1, -1)
    weights = math.softmax(linalg.dot(keys, queries), dim=-1)
    weights = weights[:, None]  # add back the channel dimension

    # new values: weight/value dot product over the window elements
    return linalg.dot(weights, windows(v))
def forward(self, x):
    """Randomly permute the patches of `x`.

    The spatial dimensions (everything after the first two) are resized
    with `utils.ensure_shape` to a multiple of `self.kernel`, split into
    patches, shuffled with a random permutation, folded back, and
    resized to the input shape.

    Note that each dimension of size `s` is resized to
    `s + (k - s % k)`, i.e. by a full extra patch when `s` is already a
    multiple of `k`.
    """
    spatial_shape = tuple(x.shape[2:])
    padded_shape = [size + (k - size % k)
                    for size, k in zip(spatial_shape, self.kernel)]

    x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(padded_shape))
    patches = utils.unfold(x, self.kernel, collapse=True)

    # shuffle along the collapsed patch dimension
    order = torch.randperm(patches.shape[2])
    patches = patches[:, :, order]

    x = utils.fold(patches, dim=len(spatial_shape), stride=self.kernel,
                   collapsed=True, shape=padded_shape)
    return utils.ensure_shape(x, tuple(x.shape[:2]) + spatial_shape)
def forward(self, x):
    """Randomly swap pairs of patches of `x`.

    The spatial dimensions (everything after the first two) are resized
    with `utils.ensure_shape` to a multiple of `self.kernel` and split
    into patches. `self.nb_swap` random pairs of patches are then
    exchanged, and the result is folded back and resized to the input
    shape.
    """
    shape = x.shape[2:]
    dim = len(shape)
    pshape = [s + (k - s % k) for s, k in zip(shape, self.kernel)]
    x = utils.ensure_shape(x, (x.shape[0], x.shape[1]) + tuple(pshape))
    x = utils.unfold(x, self.kernel, collapse=True)

    nb_patches = x.shape[2]
    for _ in range(self.nb_swap):
        # `high` is exclusive in torch.randint: use `nb_patches` (not
        # `nb_patches - 1`) so that the last patch can be drawn as well.
        i1, i2 = torch.randint(low=0, high=nb_patches, size=(2,)).tolist()
        # `x[:, :, i]` is a view, so a tuple swap would overwrite the first
        # patch before it is read back. Keep an explicit copy instead.
        tmp = x[:, :, i1].clone()
        x[:, :, i1] = x[:, :, i2]
        x[:, :, i2] = tmp

    x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True,
                   shape=pshape)
    x = utils.ensure_shape(x, (x.shape[0], x.shape[1]) + tuple(shape))
    return x
def kernel_apply(kspace, patterns, kernel_size, kernels, inplace=False):
    """Apply a GRAPPA kernel to an accelerated k-space

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    kspace : ([*batch], coils, *freq)
        Accelerated k-space
    patterns : (*freq) tensor[long]
        Code of sampling pattern about each k-space location
    kernel_size : sequence of int
        GRAPPA kernel size
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        Dictionary of GRAPPA kernels (keys are pattern codes)
    inplace : bool, default=False
        If True, write the estimated values into `kspace` itself.
        Otherwise, write them into a copy.

    Returns
    -------
    kspace : ([*batch], coils, *freq)

    """
    ndim = patterns.dim()
    coils = kspace.shape[-ndim - 1]
    batch = kspace.shape[:-ndim - 1]
    kernel_size = py.make_list(kernel_size, ndim)

    kspace_out = kspace if inplace else kspace.clone()

    # one window of size `kernel_size` about each k-space location
    kspace = utils.pad(kspace, [(k - 1) // 2 for k in kernel_size],
                       side='both')
    kspace = utils.unfold(kspace, kernel_size, stride=1)

    # One full slice per kernel dimension, so that `mask` always lines up
    # with the `freq` dimensions (a literal `:, :` only works when ndim == 2).
    all_kernel_dims = (slice(None),) * ndim

    for code, kernel in kernels.items():
        pattern = code_to_pattern(code, kernel_size, device=kspace.device)
        pattern_size = int(pattern.sum())  # plain int for the shape lists
        mask = patterns == code

        # windows about the locations that have this pattern, restricted
        # to the elements that belong to the pattern
        source = kspace[(Ellipsis, mask, *all_kernel_dims)][..., pattern]
        source = source.transpose(-2, -3)
        source = source.reshape([*batch, -1, coils * pattern_size])

        kernel = kernel.reshape([*batch, coils, coils * pattern_size])
        estimate = source.matmul(kernel.transpose(-1, -2))
        kspace_out[..., mask] = estimate.transpose(-1, -2)

    return kspace_out
def forward(self, x):
    """Randomly permute patches of `x`, with a randomly sampled patch size.

    One patch size per spatial dimension is sampled from
    `self.kernel(exp, scale).sample()`, clamped with `min=4` and
    `max=<size of that dimension>`, and truncated to an integer. The
    tensor is then resized to a multiple of the patch size, split into
    patches, shuffled, folded back, and resized to the input shape.
    """
    dim = x.dim() - 2
    backend = utils.backend(x)
    kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
    kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)

    spatial_shape = x.shape[2:]
    samples = [self.kernel(k_e, k_s).sample()
               for k_e, k_s in zip(kernel_exp, kernel_scale)]
    kernel = [torch.clamp(sample, min=4, max=size).int().item()
              for sample, size in zip(samples, spatial_shape)]

    # a dimension that is already a multiple of `k` still grows by `k`
    padded_shape = [size + (k - size % k)
                    for size, k in zip(spatial_shape, kernel)]

    x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(padded_shape))
    patches = utils.unfold(x, kernel, collapse=True)
    order = torch.randperm(patches.shape[2])
    patches = patches[:, :, order]
    x = utils.fold(patches, dim=dim, stride=kernel, collapsed=True,
                   shape=padded_shape)
    return utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(spatial_shape))
def forward(self, x):
    """Zero out randomly chosen patches of `x`.

    The following is repeated `self.nb_drop` times: a patch size is
    sampled per spatial dimension from
    `self.kernel(exp, scale).sample()` (clamped with `min=4` and
    `max=<size of that dimension>`), the tensor is resized to a multiple
    of the patch size and split into patches, one random patch is set to
    zero, and the tensor is folded back and resized to the input shape.
    """
    dim = x.dim() - 2
    backend = utils.backend(x)
    kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
    kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)
    shape = x.shape[2:]

    for _ in range(self.nb_drop):
        kernel = [self.kernel(k_e, k_s).sample()
                  for k_e, k_s in zip(kernel_exp, kernel_scale)]
        kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                  for i, k in enumerate(kernel)]
        pshape = [s + (k - s % k) for s, k in zip(shape, kernel)]
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1]) + tuple(pshape))
        x = utils.unfold(x, kernel, collapse=True)

        # `high` is exclusive in torch.randint: use the number of patches
        # itself, otherwise the last patch can never be dropped.
        i1 = torch.randint(low=0, high=x.shape[2], size=(1,)).item()
        x[:, :, i1] = 0

        x = utils.fold(x, dim=dim, stride=kernel, collapsed=True,
                       shape=pshape)
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1]) + tuple(shape))
    return x
def get_pattern_codes(sampling_mask, kernel_size):
    """Compute the code of the sampling pattern about each voxel

    The mask is padded by `(k - 1) // 2` on both sides of each
    dimension, one window of size `kernel_size` is extracted about each
    location, and each window is converted by `pattern_to_code`.

    Parameters
    ----------
    sampling_mask : (*freq) tensor[bool]
    kernel_size : [sequence of] int

    Returns
    -------
    pattern_mask : (*freq) tensor[long]

    """
    ndim = sampling_mask.dim()
    kernel_size = py.make_list(kernel_size, ndim)
    half_size = [(k - 1) // 2 for k in kernel_size]

    windows = utils.pad(sampling_mask.long(), half_size, side='both')
    windows = utils.unfold(windows, kernel_size, stride=1)
    return pattern_to_code(windows, ndim)
def get_patches(volume, patch=3, stride=1):
    """Extract patches from an image/volume.

    Parameters
    ----------
    volume : (batch, *shape) tensor_like
    patch : int, default=3
        Patch size, applied to every dimension after the first.
    stride : int, default=1
        Stride between patches, passed as-is to `utils.unfold`.

    Returns
    -------
    patched_volume : (nb_patches, batch, *patch_shape)

    """
    dim = len(volume.shape) - 1
    patch = utils.make_list(patch, dim)
    # the last positional argument asks `utils.unfold` to collapse the patches
    volume = utils.unfold(volume, patch, stride, True)
    # move the patch dimension in front of the batch dimension
    return volume.transpose(0, 1)
def forward(self, x, output_padding=None, output_shape=None):
    """Place the values of `x` on a strided grid of a zero-filled output.

    The output shape is given by `self.shape`. If `self.fill` is true,
    the output is seen as windows of size `stride` and each input value
    is broadcast to a whole window. Otherwise, input values are written
    every `stride` voxels, starting at `offset`.

    Parameters
    ----------
    x : (batch, channel, *in_spatial) tensor
    output_padding : [sequence of] int, default=self.output_padding
    output_shape : [sequence of] int, default=self.output_shape

    Returns
    -------
    x : (batch, channel, *out_spatial) tensor

    """
    dim = x.dim() - 2
    offset = py.make_list(self.offset, dim)
    stride = py.make_list(self.stride, dim)
    out_shape = self.shape(x, output_padding=output_padding,
                           output_shape=output_shape)
    y = x.new_zeros(out_shape)
    keep = (slice(None), slice(None))  # batch and channel are never sliced

    if self.fill:
        # window view of the output; singleton dimensions appended to `x`
        # let `copy_` broadcast each value to its window
        target = utils.unfold(y, stride)
        x = utils.unsqueeze(x, -1, dim)
        region = tuple(slice(o, o + sz * st)
                       for sz, st, o in zip(x.shape[2:], stride, offset))
    else:
        target = y
        region = tuple(slice(o, None, st) for o, st in zip(offset, stride))

    destination = target[keep + region]
    # crop the input to what fits in the destination
    source = x[keep + tuple(slice(mx) for mx in destination.shape[2:])]
    destination.copy_(source)
    return y
def forward(self, x, model, **fwdargs):
    """Apply `model` to each patch of `x` and fold the results back.

    The spatial dimensions are resized with `utils.ensure_shape`, split
    into patches of size `self.patch_size` with stride `self.stride`,
    and `model(patch, **fwdargs)` is called on each patch in turn. The
    outputs are stacked, folded back with `self.reduction`, and resized
    to the input spatial shape.
    """
    spatial_shape = tuple(x.shape[2:])
    dim = len(spatial_shape)

    def per_dim(value):
        # an integer is replicated across spatial dimensions
        return [value] * dim if isinstance(value, int) else value

    patch_size = per_dim(self.patch_size)
    stride = per_dim(self.stride)
    padded_shape = [size + (k - size % s)
                    for size, k, s in zip(spatial_shape, patch_size, stride)]

    x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(padded_shape))
    patches = utils.unfold(x, kernel_size=self.patch_size,
                           stride=self.stride, collapse=True)

    # one model call per patch, along the collapsed patch dimension
    outputs = [model(patch, **fwdargs) for patch in patches.unbind(dim=2)]
    x = torch.stack(outputs, dim=2)

    x = utils.fold(x, dim=dim, stride=self.stride, collapsed=True,
                   shape=padded_shape, reduction=self.reduction)
    return utils.ensure_shape(x, tuple(x.shape[:2]) + spatial_shape)
def kernel_fit(calib, kernel_size, patterns, lam=0.01):
    """Compute GRAPPA kernels

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    calib : ([*batch], coils, *freq)
        Fully-sampled calibration data
    kernel_size : sequence[int]
        GRAPPA kernel size
    patterns : (N,) tensor[int]
        Code of patterns for which to learn a kernel.
        See `pattern_to_code`. A boolean tensor is first converted to
        codes with `pattern_to_code`.
    lam : float, default=0.01
        Tikhonov regularization

    Returns
    -------
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        GRAPPA kernels. Patterns that contain the center of the kernel
        and empty patterns are skipped.

    """
    kernel_size = py.make_list(kernel_size)
    ndim = len(kernel_size)
    coils = calib.shape[-ndim - 1]
    batch = calib.shape[:-ndim - 1]

    # find all possible patterns
    patterns = utils.as_tensor(patterns, device=calib.device)
    if patterns.dtype is torch.bool:
        patterns = pattern_to_code(patterns, ndim)
    patterns = patterns.flatten()

    # learn one kernel for each pattern
    calib = utils.unfold(calib, kernel_size, collapse=True)  # [*B, C, N, *K]
    calib = utils.movedim(calib, -ndim - 1, -ndim - 2)       # [*B, N, C, *K]

    def t(x):
        return x.transpose(-1, -2)

    def conjt(x):
        return t(x).conj()

    def diag(x):
        return x.diagonal(0, -1, -2)

    # The target (center of each window) is the same for every pattern:
    # extract it once, outside of the loop.
    center = (Ellipsis, *[(k - 1) // 2 for k in kernel_size])
    calib_target = calib[center]                              # [*B, N, C]
    calib_size = calib_target.shape[-2]

    kernels = {}
    for pattern_code in patterns:
        code = pattern_code.item()
        if code in kernels:
            # duplicated code: this kernel has already been fitted
            continue
        if code_has_center(pattern_code, kernel_size):
            continue
        pattern = code_to_pattern(pattern_code, kernel_size,
                                  device=calib.device)
        pattern_size = int(pattern.sum())  # plain int for the shape lists
        if pattern_size == 0:
            continue

        calib_source = calib[..., pattern]                    # [*B, N, C, P]
        calib_source = calib_source.reshape(
            [*batch, calib_size, pattern_size * coils])       # [*B, N, C*P]

        # solve the regularized normal equations
        H = conjt(calib_source).matmul(calib_source)          # [*B, C*P, C*P]
        # `diag` returns a view: both in-place additions load the diagonal
        # of H (relative term first, then absolute term).
        diag(H).add_(lam * diag(H).abs().max(-1, keepdim=True).values)
        diag(H).add_(lam)
        g = conjt(calib_source).matmul(calib_target)          # [*B, C*P, C]
        k = linalg.lmdiv(H, g).transpose(-1, -2)              # [*B, C, C*P]
        k = k.reshape([*batch, coils, coils, pattern_size])   # [*B, C, C, P]
        kernels[code] = k

    return kernels
def dilate_likely_voxels(labels, intensity, label=None, nb_iter=1,
                         dist_ratio=1, half_patch=3, conn=1, dim=None):
    """Dilate labels into voxels with a similar intensity.

    Notes
    -----
    .. Voxels get switched if their intensity is closer to the
       foreground intensity than to the background intensity, in
       terms of Gaussian distance
       (abs(intensity - class_mean)/class_std) computed in a local
       patch.
    .. Adapted from neurite-sandbox (author: B Fischl)

    Parameters
    ----------
    labels : (..., *spatial) tensor
        Tensor of labels
    intensity : (..., *spatial) tensor
        Tensor of intensities
    label : int, optional
        Label to dilate. Default: binarize input labels.
    nb_iter : int, default=1
        Number of iterations
    dist_ratio : float, default=1
        Value that decides how much closer from the foreground
        intensity than the background intensity a voxel must be to be
        flipped. Smaller == easier to switch.
    half_patch : int, default=3
        Half-size of the window used to compute intensity statistics.
    conn : int, default=1
        Connectivity order
    dim : int, default=`labels.dim()`
        Number of spatial dimensions

    Returns
    -------
    labels : (..., *spatial) tensor
        Dilated labels

    """
    in_dtype = labels.dtype
    foreground = (labels > 0) if label is None else (labels == label)
    dim = dim or labels.dim()
    patch = [2 * half_patch + 1] * dim
    patch_dims = list(range(-dim, 0))
    center = (Ellipsis, *([half_patch] * dim))

    def unfold(x):
        """Extract one window of size `patch` about each voxel"""
        return utils.unfold(x, patch, stride=1)

    def mean_var(intensity, fg):
        """Compute mean and variance of the intensities inside a mask"""
        sum_fg = fg.sum(patch_dims)
        mean_fg = (intensity * fg).sum(patch_dims)
        var_fg = (intensity.square() * fg).sum(patch_dims)
        mean_fg /= sum_fg
        var_fg /= sum_fg
        var_fg -= mean_fg.square()
        return mean_fg, var_fg

    intensity = unfold(intensity)

    if isinstance(conn, int):
        conn = connectivity_kernel(dim, conn, device=foreground.device,
                                   dtype=torch.uint8)

    for _ in range(nb_iter):

        # candidates: voxels added by one step of dilation
        dilated = dilate(foreground, conn=conn, dim=dim)
        dilated = dilated.bitwise_xor_(foreground)

        # extract the windows centered about a candidate
        win_dilated = unfold(dilated)
        msk_dilated = win_dilated[center]
        win_dilated = win_dilated[msk_dilated, ...]
        win_intensity = intensity[msk_dilated, ...]
        win_fg = unfold(foreground)[msk_dilated, ...]
        win_bg = ~(win_fg | win_dilated)

        # compute statistics
        mean_fg, var_fg = mean_var(win_intensity, win_fg)
        mean_bg, var_bg = mean_var(win_intensity, win_bg)

        # Gaussian distance between each candidate and both classes.
        # The criterion must be evaluated on these distances, i.e. after
        # they have been computed, not on the raw class means.
        center_intensity = win_intensity[center]
        dist_fg = mean_fg.sub_(center_intensity).abs_().div_(var_fg.sqrt_())
        dist_bg = mean_bg.sub_(center_intensity).abs_().div_(var_bg.sqrt_())
        crit = dist_ratio * dist_fg < dist_bg

        # set value
        win_fg[center].masked_fill_(crit, 1)
        unfold(foreground)[msk_dilated, ...] = win_fg

    if label is None:
        labels = foreground.to(in_dtype)
    else:
        labels = labels.clone()
        labels[foreground] = label
    return labels
def extract_patches(inp, size=64, stride=None, output=None, transform=None):
    """Extract patches from a 3D volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, default=64
        Patch size.
    stride : [sequence of] int, default=size
        Stride between patches.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}_{j}_{k}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        and `i`, `j`, `k` are the coordinates (starting at 1) of the
        patch.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default.

    Returns
    -------
    output : list[str] or (tensor, tensor)
        If the input is a path, the output paths are returned.
        Else, the unfolded data and orientation matrices are returned.
            Data will have shape (nx, ny, nz, *size, *channels).
            Affines will have shape (nx, ny, nz, 4, 4).

    """
    dirname = base = ext = fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}_{j}_{k}{ext}'
        dirname, base, ext = py.fileparts(fname)

    dat, aff0 = inp

    shape = dat.shape[:3]
    size = py.make_list(size, 3)
    stride = py.make_list(stride, 3)
    stride = [st or sz for st, sz in zip(stride, size)]

    # unfold the three spatial dimensions, then move the patch grid
    # and the patch dimensions to the front
    dat = utils.movedim(dat, [0, 1, 2], [-3, -2, -1])
    dat = utils.unfold(dat, size, stride)
    dat = utils.movedim(dat, [-6, -5, -4, -3, -2, -1], [0, 1, 2, 3, 4, 5])

    # patch coordinates, in the order in which filenames are consumed
    grid = [(i, j, k)
            for i in range(dat.shape[0])
            for j in range(dat.shape[1])
            for k in range(dat.shape[2])]

    # orientation matrix of each patch
    aff = aff0.new_empty(dat.shape[:3] + aff0.shape)
    for index in grid:
        sub = tuple(slice(st * idx, st * idx + sz)
                    for st, sz, idx in zip(stride, size, index))
        aff[index], _ = spatial.affine_sub(aff0, shape, sub)

    def fill(template, index):
        """Format a filename template for the patch at `index` (1-based)"""
        i, j, k = index
        fields = dict(sep=os.path.sep, i=i + 1, j=j + 1, k=k + 1)
        if is_file:
            fields.update(dir=dirname or '.', base=base, ext=ext)
        return template.format(**fields)

    formatted_output = []
    if output:
        # Iterate instead of popping, so that a list provided by the
        # caller is left untouched.
        output = py.make_list(output, len(grid))
        for index, template in zip(grid, output):
            # Coordinates start at 1 for path and tensor inputs alike,
            # which keeps data and transform filenames consistent.
            out1 = fill(template, index)
            if is_file:
                io.volumes.savef(dat[index], out1, like=fname,
                                 affine=aff[index])
            else:
                io.volumes.savef(dat[index], out1, affine=aff[index])
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(grid))
        for index, template in zip(grid, transform):
            trf1 = fill(template, index)
            io.transforms.savef(torch.eye(4), trf1,
                                source=aff0, target=aff[index])

    if is_file:
        return formatted_output
    return dat, aff
def mov2fix(fixed, moving, warped, vel=None, cat=False, dim=None, title=None):
    """Plot registration live

    Shows, for up to three batch elements, the moving image, the moved
    (warped) image, a checkerboard of fixed and moved images, the fixed
    image and, if `vel` is provided, the magnitude of the velocity.
    For 3D inputs, the middle slice along the last spatial dimension is
    shown. Does nothing if `plt` is None.
    """
    if plt is None:
        return

    warped = warped.detach()
    if vel is not None:
        vel = vel.detach()

    dim = dim or (fixed.dim() - 1)

    # insert a leading dimension where one is missing
    if fixed.dim() < dim + 2:
        fixed = fixed[None]
    if moving.dim() < dim + 2:
        moving = moving[None]
    if warped.dim() < dim + 2:
        warped = warped[None]
    if vel is not None and vel.dim() < dim + 2:
        vel = vel[None]
    nb_channels = fixed.shape[-dim - 1]
    nb_batch = len(fixed)

    if dim == 3:
        # middle slice (the last dimension of `vel` is indexed last)
        fixed = fixed[..., fixed.shape[-1] // 2]
        moving = moving[..., moving.shape[-1] // 2]
        warped = warped[..., warped.shape[-1] // 2]
        if vel is not None:
            vel = vel[..., vel.shape[-2] // 2, :]
    if vel is not None:
        # magnitude across the last dimension
        vel = vel.square().sum(-1).sqrt()

    if cat:
        moving = math.softmax(moving, dim=1, implicit=True)
        warped = math.softmax(warped, dim=1, implicit=True)

    # checkerboard: copy every other patch of the moved image into a
    # copy of the fixed image
    checker = fixed.clone()
    patch = max([s // 8 for s in fixed.shape])
    kernel = [patch] * 2
    step = [2 * patch] * 2
    checker_unfold = utils.unfold(checker, kernel, step)
    warped_unfold = utils.unfold(warped, kernel, step)
    checker_unfold.copy_(warped_unfold)

    nb_rows = min(nb_batch, 3)
    nb_cols = 4 + (vel is not None)
    panels = [(moving, 'moving'), (warped, 'moved'),
              (checker, 'checker'), (fixed, 'fixed')]

    for b in range(nb_rows):
        first = b * nb_cols
        for col, (image, name) in enumerate(panels, start=1):
            plt.subplot(nb_rows, nb_cols, first + col)
            plt.imshow(image[b, 0].cpu())
            plt.title(name)
            plt.axis('off')
        if vel is not None:
            plt.subplot(nb_rows, nb_cols, first + 5)
            plt.imshow(vel[b].cpu())
            plt.title('velocity')
            plt.axis('off')
            plt.colorbar()

    if title:
        plt.suptitle(title)
    plt.gcf().canvas.flush_events()
    plt.show(block=False)
def forward(self, x, y, **overload):
    """Negative (optionally normalized) mutual information of `x` and `y`.

    Parameters
    ----------
    x : tensor (batch, 1, *spatial)
    y : tensor (batch, 1, *spatial)
    overload : dict
        All parameters defined at build time can be overridden
        at call time.

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.

    """
    # check inputs
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    nb_dim = x.dim() - 2
    if not (x.shape[1] == 1 and y.shape[1] == 1):
        raise ValueError('Mutual info is only implemented for '
                         'single channel tensors.')
    shape = x.shape[2:]

    def option(name):
        # call-time value if provided, build-time value otherwise
        return overload.get(name, getattr(self, name))

    min_val = option('min_val')
    max_val = option('max_val')
    nb_bins = option('nb_bins')
    fwhm = option('fwhm')
    order = option('order')
    normalize = option('normalize')
    patch_size = option('patch_size')
    patch_stride = option('patch_stride')
    mask = option('mask')

    if patch_size:
        # extract patches about each voxel
        # (a patch never exceeds the size of the image)
        patch_size = make_list(patch_size, nb_dim)
        patch_size = [min(pch or sz, sz)
                      for pch, sz in zip(patch_size, shape)]
        x = utils.unfold(x[:, 0], patch_size, patch_stride, collapse=True)
        y = utils.unfold(y[:, 0], patch_size, patch_stride, collapse=True)

    # collapse spatial dimensions -> we don't need them anymore
    x = x.reshape((*x.shape[:2], -1))
    y = y.reshape((*y.shape[:2], -1))

    def image_mask(rule, image):
        # a rule is either a callable or a threshold (`image <= rule`)
        if callable(rule):
            return rule(image)
        if rule is not None:
            return image <= rule
        return None

    # combine the masks of both images
    mask_x, mask_y = make_list(mask, 2)
    mask = image_mask(mask_x, x)
    other = image_mask(mask_y, y)
    if other is not None:
        mask = other if mask is None else (mask & other)

    # joint histogram
    if order == 'inf':
        p_xy = joint_hist_gaussian(x, y, nb_bins, min_val, max_val,
                                   fwhm, mask)
    else:
        p_xy = joint_hist_spline(x, y, nb_bins, min_val, max_val,
                                 order, mask)

    def pnorm(x, dims=-1):
        """Normalize a tensor so that its sum across `dims` is one."""
        dims = make_list(dims)
        x = x.clamp_min_(eps(x.dtype))
        return x / nansum(x, dim=dims, keepdim=True)

    # compute probabilities
    # (marginals first: `pnorm` clamps its input in place)
    p_x = pnorm(p_xy.sum(dim=-2))           # -> [B, C, nb_bins]
    p_y = pnorm(p_xy.sum(dim=-1))           # -> [B, C, nb_bins]
    p_xy = pnorm(p_xy, [-1, -2])

    def entropy(p, dims):
        return -(p * p.log()).sum(dim=dims)

    # compute entropies
    h_x = entropy(p_x, -1)                  # -> [B, C]
    h_y = entropy(p_y, -1)                  # -> [B, C]
    h_xy = entropy(p_xy, [-1, -2])          # -> [B, C]

    # negative mutual information
    mi = h_xy - (h_x + h_y)

    # normalize
    if normalize == 'studholme':
        mi = mi / h_xy.clamp_min_(eps(x.dtype))
        mi += 1
    elif normalize not in (None, 'none'):
        if normalize == 'arithmetic':
            def normalize(a, b):
                return (a + b) / 2
        elif normalize == 'geometric':
            def normalize(a, b):
                return (a * b).sqrt()
        elif normalize == 'min':
            normalize = torch.min
        elif normalize == 'max':
            normalize = torch.max
        # anything else is called as `normalize(h_x, h_y)`
        mi = mi / normalize(h_x, h_y).clamp_min_(eps(x.dtype))
        mi += 1

    # reduce
    return super().forward(mi)
def mov2fix(self, fixed, moving, warped, vel=None, cat=False, dim=None,
            title=None):
    """Plot registration live

    Shows, for up to three batch elements, the moving image, the moved
    (warped) image, a checkerboard of fixed and moved images, the fixed
    image and, if `vel` is provided, the magnitude of the displacement.
    For 3D inputs, the three middle slices are shown. The last row
    plots `self.all_ll` against the iteration number.

    Nothing is drawn if less than `1/self.framerate` seconds have
    elapsed since the last plot (`self._last_plot`).
    """
    import time
    tic = self._last_plot
    toc = time.time()
    if toc - tic < 1 / self.framerate:
        return
    self._last_plot = toc

    import matplotlib.pyplot as plt

    warped = warped.detach()
    if vel is not None:
        vel = vel.detach()

    dim = dim or (fixed.dim() - 1)
    # insert a leading dimension where one is missing
    if fixed.dim() < dim + 2:
        fixed = fixed[None]
    if moving.dim() < dim + 2:
        moving = moving[None]
    if warped.dim() < dim + 2:
        warped = warped[None]
    if vel is not None and vel.dim() < dim + 2:
        vel = vel[None]
    nb_channels = fixed.shape[-dim - 1]
    nb_batch = len(fixed)

    def rescale2d(x):
        """Rescale to [0, 1] between the 0.5% and 99.5% quantiles"""
        if not x.dtype.is_floating_point:
            x = x.float()
        mn, mx = utils.quantile(x, [0.005, 0.995],
                                dim=range(-2, 0), bins=1024).unbind(-1)
        mx = mx.max(mn + 1e-8)
        mn, mx = mn[..., None, None], mx[..., None, None]
        x = x.sub(mn).div_(mx - mn).clamp_(0, 1)
        return x

    def mid_slices(x):
        """Middle slice along each of the last three dimensions"""
        return [x[..., x.shape[-1] // 2],
                x[..., x.shape[-2] // 2, :],
                x[..., x.shape[-3] // 2, :, :]]

    # From here on, `fixed`, `moving`, `warped` and `vel` are lists with
    # one tensor per displayed view (`vel` is empty when not provided).
    if dim == 3:
        fixed = [rescale2d(f) for f in mid_slices(fixed)]
        moving = [rescale2d(f) for f in mid_slices(moving)]
        warped = [rescale2d(f) for f in mid_slices(warped)]
        if vel is not None:
            # the last dimension of `vel` is kept for the magnitude
            vel = [vel[..., vel.shape[-2] // 2, :],
                   vel[..., vel.shape[-3] // 2, :, :],
                   vel[..., vel.shape[-4] // 2, :, :, :]]
            vel = [v.square().sum(-1).sqrt() for v in vel]
        else:
            vel = []
    else:
        # Single view: wrap each tensor in a list. Iterating over the
        # tensor itself would split it along the batch dimension.
        fixed = [rescale2d(fixed)]
        moving = [rescale2d(moving)]
        warped = [rescale2d(warped)]
        vel = [vel.square().sum(-1).sqrt()] if vel is not None else []

    if cat:
        moving = [math.softmax(img, dim=1, implicit=True) for img in moving]
        warped = [math.softmax(img, dim=1, implicit=True) for img in warped]

    # checkerboard: copy every other patch of the moved image into a
    # copy of the fixed image
    checker = []
    for f, w in zip(fixed, warped):
        patch = max([s // 8 for s in f.shape])
        patch = [max(min(patch, s), 1) for s in f.shape]
        broad_shape = utils.expanded_shape(f.shape, w.shape)
        f = f.expand(broad_shape).clone()
        w = w.expand(broad_shape)
        checker_unfold = utils.unfold(f, patch, [2 * p for p in patch])
        warped_unfold = utils.unfold(w, patch, [2 * p for p in patch])
        checker_unfold.copy_(warped_unfold)
        checker.append(f)

    kdim = 3 if dim == 3 else 1
    bdim = min(nb_batch, 3)
    nb_rows = kdim * bdim + 1
    nb_cols = 4 + bool(vel)
    panels = [(moving, 'moving'), (warped, 'moved'),
              (checker, 'checker'), (fixed, 'fixed')]

    # NOTE(review): the last row holds a single axes (and each colorbar
    # presumably adds one), so the number of axes appears never to match
    # `nb_rows*nb_cols` and the figure is redrawn from scratch at every
    # call -- confirm before relying on the blitting branch below.
    if len(self.figure.axes) != nb_rows * nb_cols:
        self.figure.clf()

        for b in range(bdim):
            for k in range(kdim):
                first = (b + k * bdim) * nb_cols
                for col, (images, name) in enumerate(panels, start=1):
                    plt.subplot(nb_rows, nb_cols, first + col)
                    plt.imshow(images[k][b, 0].cpu())
                    if b == 0 and k == 0:
                        plt.title(name)
                    plt.axis('off')
                if vel:
                    plt.subplot(nb_rows, nb_cols, first + 5)
                    plt.imshow(vel[k][b].cpu())
                    if b == 0 and k == 0:
                        plt.title('displacement')
                    plt.axis('off')
                    plt.colorbar()
        plt.subplot(nb_rows, 1, nb_rows)
        plt.plot(list(range(1, len(self.all_ll) + 1)), self.all_ll)
        plt.ylabel('NLL')
        plt.xlabel('iteration')
        if title:
            plt.suptitle(title)

        self.figure.canvas.draw()
        self.plt_saved = [self.figure.canvas.copy_from_bbox(ax.bbox)
                          for ax in self.figure.axes]
        self.figure.canvas.flush_events()
        plt.show(block=False)

    else:
        self.figure.canvas.draw()
        for elem in self.plt_saved:
            self.figure.canvas.restore_region(elem)

        for b in range(bdim):
            for k in range(kdim):
                j = (b + k * bdim) * nb_cols
                axes = self.figure.axes
                axes[j].images[0].set_data(moving[k][b, 0].cpu())
                axes[j + 1].images[0].set_data(warped[k][b, 0].cpu())
                axes[j + 2].images[0].set_data(checker[k][b, 0].cpu())
                axes[j + 3].images[0].set_data(fixed[k][b, 0].cpu())
                # `vel` is an empty list (not None) when no velocity was
                # provided: test its truth value, as in the branch above
                if vel:
                    axes[j + 4].images[0].set_data(vel[k][b].cpu())
        lldata = (list(range(1, len(self.all_ll) + 1)), self.all_ll)
        self.figure.axes[-1].lines[0].set_data(lldata)
        if title:
            self.figure._suptitle.set_text(title)

        for ax in self.figure.axes:
            ax.draw_artist(ax.images[0])
            self.figure.canvas.blit(ax.bbox)
        self.figure.canvas.flush_events()