def _loss_ssqd_jtv(dat_x, dat_y, tau, lam, voxel_size=1, side='f', bound='dct2'):
    """Image-denoising loss: SSQD fidelity plus joint total variation.

    The value returned is::

        0.5 * sum_c tau[c] * sum_v (x[v, c] - y[v, c]) ** 2
        + sum_v sqrt( sum_c sum_d ((lam[c] * D_d y[v, c]) ** 2 + eps()) )

    where ``D_d y`` is the first-order finite difference of the
    reconstruction along spatial direction ``d`` (computed by ``diff``) and
    ``eps()`` guards the square root against a zero argument.

    Parameters
    ----------
    dat_x : (dmx, dmy, dmz, nchannels) tensor
        Input image
    dat_y : (dmx, dmy, dmz, nchannels) tensor
        Reconstruction image
    tau : (nchannels) tensor
        Channel-specific noise precisions
    lam : (nchannels) tensor
        Channel-specific regularisation values
    voxel_size : float or sequence[float], default=1
        Unit size used in the denominator of the gradient.
    side : {'c', 'f', 'b'}, default='f'
        * 'c': central finite differences
        * 'f': forward finite differences
        * 'b': backward finite differences
    bound : {'dct2', 'dct1', 'dst2', 'dst1', 'dft', 'repeat', 'zero'}, default='dct2'
        Boundary condition.

    Returns
    -------
    nll_yx : tensor
        Loss function value (negative log-posterior)
    """
    # -- fidelity (negative log-likelihood): tau-weighted SSQD per channel
    per_channel_ssqd = (dat_x - dat_y).pow(2).sum(dim=(0, 1, 2))
    fidelity = 0.5 * (tau * per_channel_ssqd).sum()

    # -- regulariser (negative log-prior): joint total variation
    # finite differences, shape=(dmx, dmy, dmz, nchannels, dmgr)
    grad = diff(dat_y, order=1, dim=(0, 1, 2),
                voxel_size=voxel_size, side=side, bound=bound)
    grad = grad * lam[None, None, None, :, None]
    sq_norm = (grad.pow(2) + eps()).sum(dim=-1)  # over gradient directions
    sq_norm = sq_norm.sum(dim=-1)                # over channels (joint)
    prior = sq_norm.sqrt().sum()                 # over voxels

    # negative log-posterior
    return fidelity + prior
def fit_se_log(log_cov, sqdist):
    """Fit the amplitude and length-scale of a squared-exponential kernel

    The model is fitted by ordinary least squares in log space::

        log_cov = log(sig ** 2) - sqdist / (2 * lam ** 2)

    i.e. a straight line ``a + b * sqdist`` with ``a = 2 * log(sig)`` and
    ``b = -1 / (2 * lam ** 2)``. Non-finite entries of `log_cov` (for
    example the log of non-positive covariances) are excluded from the fit.
    One independent fit is performed per batch element.

    Parameters
    ----------
    log_cov : (*batch, vox, vox)
        Log of the empirical covariance matrix
    sqdist : tuple[int] or (vox, vox) tensor
        If a tensor -> it is the pre-computed squared distance map
        If a tuple -> it is the shape and we build the distance map

    Returns
    -------
    sig : (B,) tensor
        Amplitude of the kernel. The batch dimensions are flattened:
        B = prod(batch), or 1 if there is no batch dimension.
    lam : (B,) tensor
        Length-scale of the kernel (same flattened batch).
        NaN where the fitted slope is non-negative.
    """
    log_cov = torch.as_tensor(log_cov).clone()
    backend = utils.backend(log_cov)
    if not torch.is_tensor(sqdist):
        shape = sqdist
        sqdist = dist_map(shape, **backend)
    else:
        sqdist = sqdist.to(**backend).clone()

    # linear regression of y = log_cov against x = sqdist, restricted to
    # the finite entries of log_cov (one regression per batch element)
    eps = constants.eps(log_cov.dtype)
    y = log_cov.reshape([-1, py.prod(sqdist.shape)])     # [B, N]
    msk = torch.isfinite(y)                              # [B, N]
    y[~msk] = 0
    nb_valid = msk.sum(-1, keepdim=True)                 # [B, 1]
    x = sqdist.flatten() * msk                           # [B, N]
    y0 = y.sum(-1, keepdim=True) / nb_valid              # [B, 1]
    x0 = x.sum(-1, keepdim=True) / nb_valid              # [B, 1]

    # Centre the data, then zero the excluded entries again: centring
    # turns them into (-x0, -y0), which would otherwise bias both sums.
    x = (x - x0) * msk
    y = (y - y0) * msk

    b = (x * y).sum(-1) / x.square().sum(-1).clamp_min_(eps)   # [B]
    # Drop the keepdim'ed axis of the means so that `b * x0` stays [B]
    # (a [B] * [B, 1] product would broadcast to [B, B] and mix slopes
    # across batch elements).
    a = y0[..., 0] - b * x0[..., 0]                             # [B]

    # invert the parameterisation: b = -1 / (2 lam^2), a = 2 log(sig)
    lam = b.reciprocal_().mul_(-0.5).sqrt_()
    sig = a.div_(2).exp_()
    return sig, lam
def histeq(x, n=1024, dim=None):
    """Histogram equalization

    Notes
    -----
    .. The minimum and maximum values of the input tensor are preserved.
    .. A piecewise linear transform is applied so that the output
       quantiles match those of a "template" histogram.
    .. By default, the template histogram is flat.

    Parameters
    ----------
    x : tensor
        Input image
    n : int or tensor
        Number of bins or target histogram.
        A target histogram passed as a tensor is not modified.
    dim : [sequence of] int, optional
        Dimensions along which to compute the histogram.
        Default: all.

    Returns
    -------
    x : tensor
        Transformed image
    """
    x = torch.as_tensor(x)

    # compute target cumulative histogram
    # (out-of-place, so that a user-provided histogram is left untouched)
    if torch.is_tensor(n):
        other_hist = n
        n = len(other_hist)
    else:
        other_hist = x.new_full([n], 1 / n)
    other_hist = other_hist + constants.eps(other_hist.dtype)
    other_hist = other_hist.cumsum(-1) / other_hist.sum(-1, keepdim=True)
    other_hist[..., -1] = 1

    # compute cumulative histogram of the input
    # NOTE(review): vmin/vmax look like they are reduced without keepdim, and
    # `grid` below is flattened to one batch element per voxel while `shift`
    # has one per histogram, so a partial `dim` (batched use) may not
    # broadcast as intended -- TODO confirm against math.min / grid_pull.
    vmin = math.min(x, dim=dim)
    vmax = math.max(x, dim=dim)
    batch_shape = vmin.shape
    hist = utils.histc(x, n, dim=dim, min=vmin, max=vmax)
    hist += constants.eps(hist.dtype)
    hist = hist.cumsum(-1) / hist.sum(-1, keepdim=True)
    hist[..., -1] = 1

    # match histograms
    hist = hist.reshape([-1])
    shift = _hist_to_quantile(other_hist[None], hist)
    shift = shift.reshape([-1, n])
    shift /= n

    # reshape
    shift = shift.reshape([*batch_shape, n])

    # interpolate and apply shift
    # Map [vmin, vmax] -> [-1, n - 1]. Subtracting vmin before scaling keeps
    # constant inputs (vmin == vmax) finite: they map to -1 rather than to
    # the 0/0 or n/0 produced by an additive offset n / (1 - vmax / vmin).
    eps = constants.eps(x.dtype)
    grid = (x - vmin).mul_(n / (vmax - vmin + eps)).sub_(1)
    grid = grid.flatten()[:, None, None]
    shift = spatial.grid_pull(shift.reshape([-1, 1, n]), grid,
                              bound='zero', extrapolate=True)
    shift = shift.reshape(x.shape)
    x = (x - vmin) * shift + vmin
    return x
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : tensor_like or list[tensor_like], optional
        Minimum value. Should be broadcastable to batch.
        A list provides one value per image; the joint (elementwise)
        minimum across images is used.
        Default: 5th percentile of each batch element.
    max : tensor_like or list[tensor_like], optional
        Maximum value. Should be broadcastable to batch.
        A list provides one value per image; the joint (elementwise)
        maximum across images is used.
        Default: 95th percentile of each batch element.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization (`True` is treated as 'linear').
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.
        If float, the signal is taken to that power before being equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images.
        Intensities are scaled within [0, 1].
    """
    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    # rescale min/max
    # User-provided bounds get two *trailing* singleton dims so that a
    # batch-shaped value lines up with the leading batch dims of
    # (*batch, H, W) (scalars are unaffected).
    mins = py.make_list(min, len(images))
    maxs = py.make_list(max, len(images))
    mins = [
        utils.quantile(image, 0.05, bins=2048, dim=[-1, -2], keepdim=True)
        if mn is None else torch.as_tensor(mn, **backend)[..., None, None]
        for image, mn in zip(images, mins)
    ]
    maxs = [
        utils.quantile(image, 0.95, bins=2048, dim=[-1, -2], keepdim=True)
        if mx is None else torch.as_tensor(mx, **backend)[..., None, None]
        for image, mx in zip(images, maxs)
    ]
    lo = mins[0]
    for mn in mins[1:]:
        lo = torch.min(lo, mn)
    hi = maxs[0]
    for mx in maxs[1:]:
        hi = torch.max(hi, mx)

    # clip to [lo, hi], then map to [0, 1] as (x - lo) / (hi - lo).
    # Subtracting lo first keeps a degenerate range (lo == hi) finite.
    images = [torch.max(torch.min(image, hi), lo) for image in images]
    images = [image.sub_(lo).mul_(1 / (hi - lo + eps)) for image in images]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # reshape and concatenate: flatten each image to (*batch, H*W) and
    # concatenate them so that the equalization is computed jointly
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    # optional intensity transform applied before equalization
    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        if eq >= 0:
            images = images.pow(eq)
        else:
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()

    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # rescale to [0, 1] (eps guards against a constant signal)
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True).clamp_min(eps)

    # split back into the individual images
    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]
    return tuple(images) if len(images) > 1 else images[0]
def is_inside(points, vertices, faces=None):
    """Test if a point is inside a polygon/surface.

    The polygon or surface *must* be closed.

    Parameters
    ----------
    points : (..., dim) tensor
        Coordinates of points to test
    vertices : (nv, dim) tensor
        Vertex coordinates
    faces : (nf, dim) tensor[int]
        Faces are encoded by the indices of its vertices
        (segments in 2D, triangles in 3D).
        By default, assume that vertices are ordered and define a closed
        curve (2D polygon).

    Returns
    -------
    check : (...) tensor[bool]
        True where the point is inside the shape.

    Raises
    ------
    ValueError
        If the points are neither 2D nor 3D.

    Notes
    -----
    The ray direction is drawn with `torch.randn`, so the result depends
    on the state of torch's random number generator.
    """
    # This function uses a ray-tracing technique (even-odd rule):
    #
    # A half-line is started in each point. If it crosses an odd
    # number of faces, the point is inside the shape. If it crosses an
    # even number of faces, it is not.
    #
    # In practice, we loop through faces (as we expect there are far
    # fewer faces than points) and compute intersection points between
    # all lines and each face in a batched fashion. We only want to
    # send these rays in one direction, so we only keep intersections
    # lying on the `-ray` side of each point.
    points = torch.as_tensor(points)
    vertices = torch.as_tensor(vertices)
    if faces is None:
        # closed polyline: (0, 1), (1, 2), ..., (nv - 1, 0)
        faces = [(i, i + 1) for i in range(len(vertices) - 1)]
        faces += [(len(vertices) - 1, 0)]
        faces = utils.as_tensor(faces, dtype=torch.long)
    points, vertices = utils.to_max_dtype(points, vertices)
    points, vertices, faces = utils.to_max_device(points, vertices, faces)
    backend = utils.backend(points)

    batch = points.shape[:-1]
    dim = points.shape[-1]
    if dim not in (2, 3):
        raise ValueError(f'Only 2D and 3D shapes are supported, got dim={dim}')
    eps = constants.eps(points.dtype)

    cross = points.new_zeros(batch, dtype=torch.long)
    ray = torch.randn(dim, **backend)
    for face in faces:
        face = vertices[face]
        origin = face[0]
        u = face[1] - origin

        # compute normal vector
        if dim == 3:
            v = face[2] - origin
            # normal = u x v
            norm = torch.stack([
                u[1] * v[2] - u[2] * v[1],
                u[2] * v[0] - u[0] * v[2],
                u[0] * v[1] - u[1] * v[0],
            ])
        else:
            # normal = u rotated by 90 degrees
            norm = torch.stack([-u[1], u[0]])

        # skip faces parallel to the ray; degenerate faces (zero normal)
        # yield a NaN cosine and are skipped as well
        cosine = linalg.dot(ray, norm).abs() / (ray.norm() * norm.norm())
        if not cosine >= eps:
            continue

        # compute intersection between ray and plane
        #   plane: <norm, x - origin> = 0
        #   line:  x = p + t * ray
        #   =>     t = -<norm, p - origin> / <norm, ray>
        # `depth` holds -t; keeping depth >= 0 shoots every ray along -ray
        depth = linalg.dot(norm, points - origin)
        depth /= linalg.dot(norm, ray)
        halfmask = depth >= 0
        depth = depth[halfmask]
        hit = depth[..., None] * (-ray) + points[halfmask]
        hit -= origin  # coordinates relative to the face origin

        # check if the intersection is inside the face, using its
        # coordinates in the (generally non-orthonormal) edge frame
        if dim == 3:
            # barycentric coordinates: solve hit = a * u + b * v through
            # the 2x2 Gram system (det > 0 because the normal is nonzero)
            uu = linalg.dot(u, u)
            uv = linalg.dot(u, v)
            vv = linalg.dot(v, v)
            hu = linalg.dot(hit, u)
            hv = linalg.dot(hit, v)
            det = uu * vv - uv * uv
            a = (vv * hu - uv * hv) / det
            b = (uu * hv - uv * hu) / det
            inside = (a >= 0) & (b > 0) & (a + b < 1)
        else:
            a = linalg.dot(hit, u) / u.norm().square_()
            inside = (a >= 0) & (a < 1)
        cross[halfmask] += inside

    # inside <=> odd number of crossings
    return cross.bitwise_and_(1).bool()
def pnorm(x, dims=-1):
    """Normalize a tensor so that its sum across `dims` is one.

    Entries are first clamped below at ``eps(x.dtype)`` so that the result
    is strictly positive (its log is finite). The clamp is out-of-place:
    the input tensor is left untouched.

    Parameters
    ----------
    x : tensor
        Non-negative weights (e.g. an unnormalized histogram).
    dims : int or sequence[int], default=-1
        Dimension(s) over which the result sums to one.

    Returns
    -------
    x : tensor
        Normalized tensor, same shape as the input.
    """
    dims = make_list(dims)
    x = x.clamp_min(eps(x.dtype))
    return x / nansum(x, dim=dims, keepdim=True)
def forward(self, x, y, **overload):
    """Negative (optionally normalized) mutual information.

    Parameters
    ----------
    x : tensor (batch, 1, *spatial)
    y : tensor (batch, 1, *spatial)
    overload : dict
        All parameters defined at build time can be overridden at call
        time (`min_val`, `max_val`, `nb_bins`, `fwhm`, `order`,
        `normalize`, `patch_size`, `patch_stride`, `mask`).

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.
    """
    # check inputs
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    nb_dim = x.dim() - 2
    if x.shape[1] != 1 or y.shape[1] != 1:
        raise ValueError('Mutual info is only implemented for '
                         'single channel tensors.')
    shape = x.shape[2:]

    # get parameters
    min_val = overload.get('min_val', self.min_val)
    max_val = overload.get('max_val', self.max_val)
    nb_bins = overload.get('nb_bins', self.nb_bins)
    fwhm = overload.get('fwhm', self.fwhm)
    order = overload.get('order', self.order)
    normalize = overload.get('normalize', self.normalize)
    patch_size = overload.get('patch_size', self.patch_size)
    patch_stride = overload.get('patch_stride', self.patch_stride)
    mask = overload.get('mask', self.mask)

    # reshape
    if patch_size:
        # extract patches about each voxel; a falsy or oversized patch
        # size along a dimension means "the whole dimension"
        patch_size = make_list(patch_size, nb_dim)
        patch_size = [
            min(pch or dim, dim) for pch, dim in zip(patch_size, shape)
        ]
        x = utils.unfold(x[:, 0], patch_size, patch_stride, collapse=True)
        y = utils.unfold(y[:, 0], patch_size, patch_stride, collapse=True)

    # collapse spatial dimensions -> [B, C, N]; we don't need them anymore
    # (presumably C is the number of patches when patch_size is set)
    x = x.reshape((*x.shape[:2], -1))
    y = y.reshape((*y.shape[:2], -1))

    # exclude masked values: the mask is True where a voxel is ignored
    # (as interpreted by joint_hist_gaussian). A number thresholds values
    # <= it; a callable returns the mask. When both images are masked, a
    # voxel is ignored only if it is masked in both (logical AND).
    mask_x, mask_y = make_list(mask, 2)
    mask = None
    if callable(mask_x):
        mask = mask_x(x)
    elif mask_x is not None:
        mask = x <= mask_x
    if callable(mask_y):
        mask = (mask & mask_y(y)) if mask is not None else mask_y(y)
    elif mask_y is not None:
        mask = (mask & (y <= mask_y)) if mask is not None else (y <= mask_y)

    if order == 'inf':
        p_xy = joint_hist_gaussian(x, y, nb_bins, min_val, max_val, fwhm, mask)
    else:
        p_xy = joint_hist_spline(x, y, nb_bins, min_val, max_val, order, mask)

    def pnorm(x, dims=-1):
        """Normalize a tensor so that its sum across `dims` is one.

        The clamp is out-of-place: the input tensor is left untouched."""
        dims = make_list(dims)
        x = x.clamp_min(eps(x.dtype))
        x = x / nansum(x, dim=dims, keepdim=True)
        return x

    # compute probabilities
    # The joint histogram is [B, C, x_bins, y_bins] (as built by
    # joint_hist_gaussian; joint_hist_spline is assumed to match -- TODO
    # confirm), so the x marginal sums over the last (y) dimension.
    p_x = pnorm(p_xy.sum(dim=-1))  # -> [B, C, nb_bins]
    p_y = pnorm(p_xy.sum(dim=-2))  # -> [B, C, nb_bins]
    p_xy = pnorm(p_xy, [-1, -2])

    # compute entropies
    h_x = -(p_x * p_x.log()).sum(dim=-1)            # -> [B, C]
    h_y = -(p_y * p_y.log()).sum(dim=-1)            # -> [B, C]
    h_xy = -(p_xy * p_xy.log()).sum(dim=[-1, -2])   # -> [B, C]

    # negative mutual information
    mi = h_xy - (h_x + h_y)

    # normalize
    if normalize == 'studholme':
        mi = mi / h_xy.clamp_min(eps(x.dtype))
        mi += 1
    elif normalize not in (None, 'none'):
        normalize = (lambda a, b: (a+b)/2) if normalize == 'arithmetic' else \
                    (lambda a, b: (a*b).sqrt()) if normalize == 'geometric' else \
                    torch.min if normalize == 'min' else \
                    torch.max if normalize == 'max' else \
                    normalize
        mi = mi / normalize(h_x, h_y).clamp_min(eps(x.dtype))
        mi += 1

    # reduce
    return super().forward(mi)
def joint_hist_gaussian(x, y, bins=64, min=None, max=None, fwhm=1, mask=None):
    """Compute joint histogram with Gaussian window

    Each voxel contributes to every bin through a Gaussian window whose
    full width at half maximum is `fwhm` bin widths. Bin centres are
    spread uniformly between `min` and `max` (inclusive), which default to
    the per-(batch, channel) extrema of the unmasked voxels.

    Parameters
    ----------
    x : (batch, channel, voxels) tensor
    y : (batch, channel, voxels) tensor
    bins : int or (int, int), default=64
        Number of bins (for x and y).
    min : float or (float, float), optional
    max : float or (float, float), optional
    fwhm : float or (float, float), default=1
        Window width, in bins.
    mask : (batch, channel, voxels) tensor[bool], optional
        Voxels where `mask` is True are *excluded*: they are ignored when
        computing the default min/max and receive zero weight.

    Returns
    -------
    h : (batch, channel, bins_x, bins_y)
        Unnormalized joint histogram (dim -2 indexes x bins, dim -1 y bins).
    """
    backend = utils.backend(x)
    x_min, y_min = py.make_list(min, 2)
    x_max, y_max = py.make_list(max, 2)
    x_nbins, y_nbins = py.make_list(bins, 2)
    x_fwhm, y_fwhm = py.make_list(fwhm, 2)

    def reduced_values(out):
        """Return the values of a min/max reduction.

        `torch.min`/`torch.max` called with `dim` return a (values, indices)
        tuple; `nanmin`/`nanmax` may return either such a tuple or a plain
        tensor, so both conventions are accepted.
        """
        return out[0] if isinstance(out, tuple) else out

    def get_bins(x, min, max, nbins):
        """Compute the histogram bins."""
        # TODO: It's suboptimal to have bin centers fall at the
        #   min and max. Better to shift them slightly inside.
        if mask is not None:
            # we set masked values to nan so that we can exclude them when
            # computing min/max
            val_nan = torch.as_tensor(nan, **backend)
            x = torch.where(mask, val_nan, x)
            min_fn, max_fn = nanmin, nanmax
        else:
            min_fn, max_fn = torch.min, torch.max
        min = reduced_values(min_fn(x, dim=-1)) if min is None else min
        min = torch.as_tensor(min, **backend)
        min = unsqueeze(min, dim=2, ndim=4 - min.dim())   # -> [B, C, 1, 1]
        max = reduced_values(max_fn(x, dim=-1)) if max is None else max
        max = torch.as_tensor(max, **backend)
        max = unsqueeze(max, dim=2, ndim=4 - max.dim())   # -> [B, C, 1, 1]
        bins = torch.linspace(0, 1, nbins, **backend)
        bins = unsqueeze(bins, dim=0, ndim=3)             # -> [1, 1, 1, nb_bins]
        bins = min + bins * (max - min)                   # -> [B, C, 1, nb_bins]
        binwidth = (max - min) / (nbins - 1)              # -> [B, C, 1, 1]
        return bins, binwidth

    def window(values, centers, binwidth, width):
        """Gaussian weight of each value in each bin -> [B, C, N, nb_bins]."""
        # FWHM -> variance: sigma = FWHM / (2 * sqrt(2 * ln 2))
        var = ((width * binwidth) ** 2) / (8 * math.log(2))
        var = var.clamp(min=eps(values.dtype))
        return (-(values[..., None] - centers).square() / (2 * var)).exp()

    # prepare bins
    x_bins, x_binwidth = get_bins(x.detach(), x_min, x_max, x_nbins)
    y_bins, y_binwidth = get_bins(y.detach(), y_min, y_max, y_nbins)

    # masked values become inf so that they get zero weight in the histogram
    if mask is not None:
        val_inf = torch.as_tensor(inf, **backend)
        x = torch.where(mask, val_inf, x)
        y = torch.where(mask, val_inf, y)

    # compute distances and collapse
    x = window(x, x_bins, x_binwidth, x_fwhm)      # -> [B, C, N, x_bins]
    y = window(y, y_bins, y_binwidth, y_fwhm)      # -> [B, C, N, y_bins]
    return torch.matmul(x.transpose(-1, -2), y)    # -> [B, C, x_bins, y_bins]
def forward(self, x, y, **overload):
    """Negative (optionally normalized) mutual information.

    The joint histogram is computed with a Gaussian window.

    Parameters
    ----------
    x : tensor (batch, 1, *spatial)
    y : tensor (batch, 1, *spatial)
    overload : dict
        All parameters defined at build time can be overridden at call
        time (`min_val`, `max_val`, `nb_bins`, `fwhm`, `normalize`,
        `patch_size`, `patch_stride`, `mask`).

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.
    """
    # check inputs
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    dtype = x.dtype
    device = x.device
    nb_dim = x.dim() - 2
    if x.shape[1] != 1 or y.shape[1] != 1:
        raise ValueError('Mutual info is only implemented for '
                         'single channel tensors.')
    shape = x.shape[2:]

    # get parameters
    x_min, y_min = make_list(overload.get('min_val', self.min_val), 2)
    x_max, y_max = make_list(overload.get('max_val', self.max_val), 2)
    x_nbins, y_nbins = make_list(overload.get('nb_bins', self.nb_bins), 2)
    x_fwhm, y_fwhm = make_list(overload.get('fwhm', self.fwhm), 2)
    normalize = overload.get('normalize', self.normalize)
    patch_size = overload.get('patch_size', self.patch_size)
    patch_stride = overload.get('patch_stride', self.patch_stride)
    mask = overload.get('mask', self.mask)

    # reshape
    if patch_size:
        # extract patches about each voxel; a falsy or oversized patch
        # size along a dimension means "the whole dimension", and a
        # missing stride defaults to the patch size (non-overlapping)
        patch_size = make_list(patch_size, nb_dim)
        patch_size = [
            min(pch or dim, dim) for pch, dim in zip(patch_size, shape)
        ]
        patch_stride = make_list(patch_stride, nb_dim)
        patch_stride = [
            sz if st is None else st
            for sz, st in zip(patch_size, patch_stride)
        ]
        x = x[:, 0, ...]
        y = y[:, 0, ...]
        for d, (sz, st) in enumerate(zip(patch_size, patch_stride)):
            x = x.unfold(dimension=d + 1, size=sz, step=st)
            y = y.unfold(dimension=d + 1, size=sz, step=st)
        x = x.reshape((x.shape[0], -1, *patch_size))
        y = y.reshape((y.shape[0], -1, *patch_size))
        # now, the spatial dimension of x and y is `patch_size` and
        # their channel dimension is the number of patches

    # collapse spatial dimensions -> [B, C, N]; we don't need them anymore
    x = x.reshape((*x.shape[:2], -1))
    y = y.reshape((*y.shape[:2], -1))

    # exclude masked values: the mask is True where a voxel is ignored.
    # A number thresholds values <= it; a callable returns the mask. When
    # both images are masked, a voxel is ignored only if it is masked in
    # both (logical AND).
    mask_x, mask_y = make_list(mask, 2)
    mask = None
    if callable(mask_x):
        mask = mask_x(x)
    elif mask_x is not None:
        mask = x <= mask_x
    if callable(mask_y):
        mask = (mask & mask_y(y)) if mask is not None else mask_y(y)
    elif mask_y is not None:
        mask = (mask & (y <= mask_y)) if mask is not None else (y <= mask_y)

    def reduced_values(out):
        """Return the values of a min/max reduction.

        `torch.min`/`torch.max` called with `dim` return a (values, indices)
        tuple; `nanmin`/`nanmax` may return either such a tuple or a plain
        tensor, so both conventions are accepted.
        """
        return out[0] if isinstance(out, tuple) else out

    def get_bins(x, min, max, nbins):
        """Compute the histogram bins."""
        # TODO: It's suboptimal to have bin centers fall at the
        #   min and max. Better to shift them slightly inside.
        if mask is not None:
            # we set masked values to nan so that we can exclude them when
            # computing min/max
            val_nan = torch.as_tensor(nan, dtype=x.dtype, device=x.device)
            x = torch.where(mask, val_nan, x)
            min_fn, max_fn = nanmin, nanmax
        else:
            min_fn, max_fn = torch.min, torch.max
        min = reduced_values(min_fn(x, dim=-1)) if min is None else min
        min = torch.as_tensor(min, dtype=dtype, device=device)
        min = unsqueeze(min, dim=2, ndim=4 - min.dim())   # -> [B, C, 1, 1]
        max = reduced_values(max_fn(x, dim=-1)) if max is None else max
        max = torch.as_tensor(max, dtype=dtype, device=device)
        max = unsqueeze(max, dim=2, ndim=4 - max.dim())   # -> [B, C, 1, 1]
        bins = torch.linspace(0, 1, nbins, dtype=dtype, device=device)
        bins = unsqueeze(bins, dim=0, ndim=3)             # -> [1, 1, 1, nb_bins]
        bins = min + bins * (max - min)                   # -> [B, C, 1, nb_bins]
        binwidth = (max - min) / (nbins - 1)              # -> [B, C, 1, 1]
        return bins, binwidth

    # prepare bins
    x_bins, x_binwidth = get_bins(x.detach(), x_min, x_max, x_nbins)
    y_bins, y_binwidth = get_bins(y.detach(), y_min, y_max, y_nbins)

    # masked values become inf so that they get zero weight in the histogram
    if mask is not None:
        val_inf = torch.as_tensor(inf, dtype=x.dtype, device=x.device)
        x = torch.where(mask, val_inf, x)
        y = torch.where(mask, val_inf, y)

    # compute Gaussian windows (FWHM -> variance: fwhm^2 / (8 ln 2))
    x = x[..., None]                                       # -> [B, C, N, 1]
    y = y[..., None]                                       # -> [B, C, N, 1]
    x_var = ((x_fwhm * x_binwidth)**2) / (8 * math.log(2))
    x_var = x_var.clamp(min=eps(x.dtype))
    x = -(x - x_bins).square() / (2 * x_var)
    x = x.exp()
    y_var = ((y_fwhm * y_binwidth)**2) / (8 * math.log(2))
    y_var = y_var.clamp(min=eps(y.dtype))
    y = -(y - y_bins).square() / (2 * y_var)
    y = y.exp()                                            # -> [B, C, N, nb_bins]

    def pnorm(x, dims=-1):
        """Normalize a tensor so that its sum across `dims` is one."""
        dims = make_list(dims)
        x = x.clamp(min=eps(x.dtype))
        x = x / nansum(x, dim=dims, keepdim=True)
        return x

    # compute probabilities
    p_x = pnorm(x.sum(dim=2))                  # -> [B, C, nb_bins]
    p_y = pnorm(y.sum(dim=2))                  # -> [B, C, nb_bins]
    x = x.transpose(-1, -2)                    # -> [B, C, nb_bins, N]
    p_xy = torch.matmul(x, y)                  # -> [B, C, nb_bins, nb_bins]
    p_xy = pnorm(p_xy, [-1, -2])

    # compute entropies
    h_x = -(p_x * p_x.log()).sum(dim=-1)            # -> [B, C]
    h_y = -(p_y * p_y.log()).sum(dim=-1)            # -> [B, C]
    h_xy = -(p_xy * p_xy.log()).sum(dim=[-1, -2])   # -> [B, C]

    # negative mutual information
    mi = h_xy - (h_x + h_y)

    # normalize (the eps clamp keeps zero-entropy inputs finite)
    if normalize not in (None, 'none'):
        normalize = (lambda a, b: (a+b)/2) if normalize == 'arithmetic' else \
                    (lambda a, b: (a*b).sqrt()) if normalize == 'geometric' else \
                    torch.min if normalize == 'min' else \
                    torch.max if normalize == 'max' else \
                    normalize
        mi = mi / normalize(h_x, h_y).clamp_min(eps(x.dtype))
        mi += 1

    # reduce
    return super().forward(mi)