def init_gmm2(x, y, bins=6, dim=None, mask=None):
    """Initialize parameters of a 2D GMM by drawing quantiles.

    Parameters
    ----------
    x : (..., *spatial) tensor
        Moving image (first dimension)
    y : (..., *spatial) tensor
        Fixed image (first dimension)
    bins : int, default=6
        Number of clusters
    dim : int, default=`x.dim()-1`
        Number of spatial dimensions
    mask : (..., *spatial) tensor, optional
        Mask or weights

    Returns
    -------
    dict:
        xmean : (..., *1, bins) tensor
            Mean of the moving (1st dimension) image
        ymean : (..., *1, bins) tensor
            Mean of the fixed (1st dimension) image
        xvar : (..., *1, bins) tensor
            Variance of the moving (1st dimension) image
        yvar : (..., *1, bins) tensor
            Variance of the fixed (1st dimension) image
        corr : (..., *1, bins) tensor
            Correlation coefficient
        prior : (..., *1, bins) tensor
            Proportion of each class
    """
    if mask is not None:
        mask = mask[..., 0]
    dim = dim or x.dim() - 1
    # evenly spaced quantiles in [0, 1] delimit the initial clusters
    quantiles = torch.arange(bins + 1, dtype=x.dtype, device=x.device).div_(bins)
    xmean = utils.quantile(x[..., 0], quantiles,
                           dim=range(-dim, 0), keepdim=True, mask=mask)
    ymean = utils.quantile(y[..., 0], quantiles,
                           dim=range(-dim, 0), keepdim=True, mask=mask)
    # treat the width of each inter-quantile interval as a FWHM
    # (FWHM = 2*sqrt(2*ln 2) * sigma ~= 2.355 * sigma)
    xvar = (xmean[..., 1:] - xmean[..., :-1]).div_(2.355).square_()
    yvar = (ymean[..., 1:] - ymean[..., :-1]).div_(2.355).square_()
    # cluster means sit at the midpoints of consecutive quantiles
    xmean = (xmean[..., 1:] + xmean[..., :-1]).div_(2)
    ymean = (ymean[..., 1:] + ymean[..., :-1]).div_(2)
    corr = torch.zeros_like(yvar)
    prior = y.new_full([bins], 1 / bins)
    return dict(xmean=xmean, ymean=ymean, xvar=xvar, yvar=yvar,
                corr=corr, prior=prior)
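
# Why 2.355? Each inter-quantile gap is treated as a full width at half
# maximum, and for a Gaussian FWHM = 2*sqrt(2*ln 2)*sigma ~= 2.355*sigma.
# Minimal smoke test, assuming the nitorch-style `utils.quantile` used
# throughout this module is importable (shapes are illustrative):
import math
import torch

assert abs(2 * math.sqrt(2 * math.log(2)) - 2.355) < 1e-3

x = torch.randn(8, 8, 1)   # moving image with a trailing channel axis
y = torch.randn(8, 8, 1)   # fixed image
params = init_gmm2(x, y, bins=4)
print({k: tuple(v.shape) for k, v in params.items()})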
def load(files, is_label=False):
    """Load one multi-channel multi-file volume.

    Returns a (channels, *spatial) tensor
    """
    dats = []
    for file in files:
        if is_label:
            dat = io.volumes.load(file.fname,
                                  dtype=torch.int32, device=device)
        else:
            dat = io.volumes.loadf(file.fname, rand=True,
                                   dtype=torch.float32, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        dat = dat[..., file.subchannels]
        dat = utils.movedim(dat, -1, 0)
        if not is_label:
            # robust rescaling: map the (1%, 95%) quantiles to (0, 1).
            # (Labels must keep their integer values, and the in-place
            # float ops below would fail on an int32 tensor anyway.)
            dim = dat.dim() - 1
            qt = utils.quantile(dat, (0.01, 0.95),
                                dim=range(-dim, 0), keepdim=True)
            mn, mx = qt.unbind(-1)
            dat = dat.sub_(mn).div_(mx - mn)
        dats.append(dat)
        del dat
    dats = torch.cat(dats, dim=0)
    if is_label and len(dats) > 1:
        warn('Multi-channel label images are not accepted. '
             'Using only the first channel.')
        dats = dats[:1]
    return dats
def rescale2d(x):
    """Rescale a 2D image so that its (0.5%, 99.5%) quantiles map to (0, 1)."""
    if not x.dtype.is_floating_point:
        x = x.float()
    mn, mx = utils.quantile(x, [0.005, 0.995],
                            dim=range(-2, 0), bins=1024).unbind(-1)
    # guard against degenerate images where both quantiles coincide
    mx = mx.max(mn + 1e-8)
    mn, mx = mn[..., None, None], mx[..., None, None]
    x = x.sub(mn).div_(mx - mn).clamp_(0, 1)
    return x
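
# Minimal usage sketch (assumes nitorch's `utils.quantile`; `bins=1024`
# trades exact quantiles for a faster histogram approximation):
import torch

img = torch.randint(0, 4096, (3, 64, 64))   # e.g. raw 2D slices
out = rescale2d(img)                         # float, clamped to [0, 1]
assert out.min() >= 0 and out.max() <= 1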
def addnoise_(cls, x):
    """Replace non-finite/zero voxels with a small amount of uniform noise."""
    # estimate a noise amplitude from the 0.5% quantile of the gaps
    # between consecutive unique intensity values
    v = x.unique().sort().values
    if v.shape[0] > 1:
        v = v[1:] - v[:-1]
    v = utils.quantile(v, 0.005).item()
    # mask non-finite values and zeros, zero them out, then add noise there
    mask = torch.isfinite(x).bitwise_not_().bitwise_or_(x == 0)
    x.masked_fill_(mask, 0)
    x.addcmul_(mask.to(x.dtype), torch.rand_like(x), value=v)
    return x
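
# Tiny demonstration: zeros (and non-finite voxels) are replaced by a small
# uniform perturbation whose scale comes from the gaps between observed
# intensities. Assumes nitorch's `utils.quantile`; `cls` is unused in this
# extract, so a placeholder works:
import torch

x = torch.tensor([[0.0, 3.0], [1.0, 2.0]])
x = addnoise_(None, x)
assert torch.isfinite(x).all()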
def discretize(dat, nbins=256, mask=None):
    """Discretize an image into a number of bins"""
    dim = dat.dim() - 2
    dims = range(-dim, 0)
    # robust range: clip to the (0.05%, 99.95%) quantiles before binning
    mn, mx = utils.quantile(dat, (0.0005, 0.9995),
                            dim=dims, keepdim=True, mask=mask).unbind(-1)
    dat = dat.sub_(mn).div_(mx - mn).clamp_(0, 1).mul_(nbins - 1)
    dat = make_finite_(dat)
    dat = dat.long()
    return dat
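
# The same binning with plain torch, as a self-contained sketch
# (`torch.quantile` stands in for `utils.quantile`, without mask support):
import torch

dat = torch.randn(1, 4, 32, 32)                       # (B, C, *spatial)
flat = dat.reshape(*dat.shape[:2], -1)
q = torch.quantile(flat, torch.tensor([0.0005, 0.9995]), dim=-1)
mn, mx = q[0][..., None, None], q[1][..., None, None]
idx = dat.sub(mn).div_(mx - mn).clamp_(0, 1).mul_(255).long()  # bins 0..255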
def cutoff(dat, cutoff, dim=None):
    """Clip data when outside of a range defined by percentiles

    Parameters
    ----------
    dat : tensor or ndarray
        Input data
    cutoff : max or (min, max)
        Percentile cutoffs (in [0, 1])
    dim : int, optional
        Dimension(s) along which to compute percentiles
        (only used for ndarray input)

    Returns
    -------
    dat : tensor or ndarray
        Clipped data
    """
    if cutoff is None:
        return dat
    cutoff = sorted([100 * val for val in py.make_sequence(cutoff)])
    if len(cutoff) > 2:
        raise ValueError('At most two percentiles (min, max) should '
                         'be provided. Got {}.'.format(len(cutoff)))
    if torch.is_tensor(dat):
        dat_for_quantile = dat
        if not dat.dtype.is_floating_point:
            dat_for_quantile = dat_for_quantile.float()
        cutoff = [val / 100 for val in cutoff]
        pct = utils.quantile(dat_for_quantile, cutoff, bins=1024)
        if len(pct) == 1:
            mn, mx = None, pct[0]
        else:
            mn, mx = pct[0], pct[1]
        if not dat.dtype.is_floating_point:
            # round outward so integer data is not clipped too aggressively
            mx = mx.ceil()
            if mn is not None:
                mn = mn.floor()
        mx = mx.to(dat.dtype)
        if mn is not None:
            mn = mn.to(dat.dtype)
        dat.clamp_(mn, mx)
        return dat
    else:
        pct = np.nanpercentile(dat, cutoff, axis=dim, keepdims=True)
        if len(pct) == 1:
            dat = np.clip(dat, a_min=None, a_max=pct[0])
        else:
            dat = np.clip(dat, a_min=pct[0], a_max=pct[1])
        return dat
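
# Usage sketch: the ndarray branch only needs numpy, while tensors go
# through nitorch's `utils.quantile` (percentile cutoffs are given in [0, 1]):
import numpy as np

arr = np.random.randn(64, 64) * 100
clipped = cutoff(arr, (0.01, 0.99))   # clip to the 1st/99th percentiles
upper = cutoff(arr.copy(), 0.95)      # upper cutoff only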
def _discretize_image(dat, nbins=256):
    """Discretize an image into a number of bins

    Input  : (C, *spatial) tensor[float]
    Returns: (C, *spatial) tensor[long]
    """
    dim = dat.dim() - 1
    mn, mx = utils.quantile(dat, (0.0005, 0.9995),
                            dim=range(-dim, 0), keepdim=True).unbind(-1)
    dat = dat.sub_(mn).div_(mx - mn).clamp_(0, 1).mul_(nbins - 1)
    dat = dat.long()
    return dat
def forward(self, image, **overload):
    """Affinely rescale an image so that its (qmin, qmax) quantiles
    map to (vmin, vmax)."""
    qmin = overload.get('qmin', self.qmin)
    qmax = overload.get('qmax', self.qmax)
    vmin = overload.get('vmin', self.vmin)
    vmax = overload.get('vmax', self.vmax)
    dim = image.dim() - 2
    mn, mx = utils.quantile(image, (qmin, qmax),
                            dim=range(-dim, 0), keepdim=True,
                            bins=self.bins).unbind(-1)
    image = (image - mn) * ((vmax - vmin) / (mx - mn)) + vmin
    return image
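
# The mapping is the affine function
#   f(v) = (v - mn) * (vmax - vmin) / (mx - mn) + vmin,
# so f(mn) == vmin and f(mx) == vmax. Self-contained check with plain torch:
import torch

img = torch.linspace(0, 100, 11)
mn, mx, vmin, vmax = 10.0, 90.0, 0.0, 1.0
out = (img - mn) * ((vmax - vmin) / (mx - mn)) + vmin
assert torch.isclose(out[1], torch.tensor(0.0))   # f(10) == 0
assert torch.isclose(out[9], torch.tensor(1.0))   # f(90) == 1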
def _soft_quantize_image(dat, nbins=16):
    """Soft-quantize an image into a number of bins

    Input  : (1, *spatial) tensor[float]
    Returns: (nbins, *spatial) tensor[float]
    """
    dim = dat.dim() - 1
    dat = dat[0]
    mn, mx = utils.quantile(dat, (0.0005, 0.9995),
                            dim=range(-dim, 0), keepdim=True).unbind(-1)
    dat = dat.sub_(mn).div_(mx - mn).clamp_(0, 1).mul_(nbins)
    # Gaussian-shaped soft assignment to bin centers; the kernel width
    # follows the FWHM convention (2.355 = 2*sqrt(2*ln 2))
    centers = torch.linspace(0, nbins, nbins + 1, **utils.backend(dat))
    centers = (centers[1:] + centers[:-1]) / 2
    centers = centers.flip(0)
    centers = centers[(Ellipsis,) + (None,) * dim]
    dat = (centers - dat).square().mul_(-2.355 ** 2).exp_()
    dat /= dat.sum(0, keepdim=True)
    return dat
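
# Self-contained sketch of the soft assignment with plain torch (the kernel
# width factor 2.355 follows the same FWHM convention as above):
import torch

nbins, dim = 4, 2
dat = torch.rand(8, 8) * nbins                   # already scaled to [0, nbins]
centers = torch.linspace(0, nbins, nbins + 1)
centers = ((centers[1:] + centers[:-1]) / 2).flip(0)
centers = centers[(Ellipsis,) + (None,) * dim]   # (nbins, 1, 1)
soft = (centers - dat).square().mul_(-2.355 ** 2).exp_()
soft /= soft.sum(0, keepdim=True)                # soft one-hot: sums to 1
assert torch.allclose(soft.sum(0), torch.ones(8, 8))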
def _rescale_image(dat, quantiles):
    """Rescale an image between (0, 1) based on two quantiles"""
    dim = dat.dim() - 1
    if not isinstance(quantiles, (list, tuple)):
        quantiles = [quantiles]
    # quantiles are given as percentages, e.g. (0, 95)
    if len(quantiles) == 0:
        mn, mx = 0, 95
    elif len(quantiles) == 1:
        mn, mx = 0, quantiles[0]
    else:
        mn, mx = quantiles
    # normalize both bounds to [0, 1] before querying the quantiles
    mn, mx = mn / 100, mx / 100
    mn, mx = utils.quantile(dat, (mn, mx), dim=range(-dim, 0),
                            keepdim=True, bins=1024).unbind(-1)
    dat = dat.sub_(mn).div_(mx - mn)
    return dat
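
# Usage sketch: quantiles are percentages here, so (5, 95) maps the
# 5th/95th percentiles to (0, 1). Assumes nitorch's `utils.quantile`:
import torch

img = torch.rand(1, 32, 32) * 100
img = _rescale_image(img, (5, 95))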
def rescale(self, image):
    """Affine rescaling of an image by mapping quantiles to (0, 1)

    The quantile bounds are read from the instance:

    self.qmin : (0..1), default=0
        Lower quantile
    self.qmax : (0..1), default=0.95
        Upper quantile

    Parameters
    ----------
    image : tensor
        Input image

    Returns
    -------
    image : tensor
        Rescaled image
    """
    if self.qmin == 0 and self.qmax == 1:
        return image
    qmin, qmax = utils.quantile(image, [self.qmin, self.qmax])
    image -= qmin
    image /= (qmax - qmin)
    return image
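
# Hypothetical usage, assuming a small class that stores the quantile
# bounds (the class name and defaults below are illustrative, not from
# the source):
import torch

class QuantileRescaler:
    def __init__(self, qmin=0.0, qmax=0.95):
        self.qmin, self.qmax = qmin, qmax
    rescale = rescale   # reuse the method defined above

img = QuantileRescaler().rescale(torch.rand(32, 32) * 1000)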
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: 5th percentile of each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: 95th percentile of each batch element.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed
        signal is equalized.
        If float, the signal is taken to that power before being
        equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images.
        Intensities are scaled within [0, 1].
    """
    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    # --- rescale min/max (shared across all inputs) ---
    min = py.make_list(min, len(images))
    max = py.make_list(max, len(images))
    min = [utils.quantile(image, 0.05, bins=2048, dim=[-1, -2], keepdim=True)
           if mn is None else torch.as_tensor(mn, **backend)[None, None]
           for image, mn in zip(images, min)]
    min, *othermin = min
    for mn in othermin:
        min = torch.min(min, mn)
    del othermin
    max = [utils.quantile(image, 0.95, bins=2048, dim=[-1, -2], keepdim=True)
           if mx is None else torch.as_tensor(mx, **backend)[None, None]
           for image, mx in zip(images, max)]
    max, *othermax = max
    for mx in othermax:
        max = torch.max(max, mx)
    del othermax
    images = [torch.max(torch.min(image, max), min) for image in images]
    # affine map to [0, 1]: note that 1/(1 - max/min) == -min/(max - min)
    images = [image.mul_(1 / (max - min + eps)).add_(1 / (1 - max / min))
              for image in images]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # --- reshape and concatenate for joint equalization ---
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        # float exponent: transform the signal before equalizing
        if eq >= 0:
            images = images.pow(eq)
        else:
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()
    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # the transform changed the range: rescale to [0, 1] again
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True)

    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]
    return tuple(images) if len(images) > 1 else images[0]
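
# A minimal usage sketch, assuming the nitorch-style helpers this function
# relies on (`utils`, `py`, `constants`, `histeq`, and the module's `math`)
# are importable; shapes and values are illustrative only:
import torch

a = torch.rand(2, 64, 64) * 500
b = torch.rand(2, 64, 64) * 400
# joint rescaling: a shared (min, max) is computed across both inputs,
# then their histograms are equalized jointly
a, b = intensity_preproc(a, b, eq='linear')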
def preproc_(cls, x):
    """Clamp to the (0.5%, 99.5%) quantiles, then rescale to (0, 1)."""
    mn, mx = utils.quantile(x, [0.005, 0.995], keepdim=True).unbind(-1)
    x = x.max(mn).min(mx).sub_(mn).div_(mx - mn)
    return x
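
# Equivalent self-contained sketch with plain torch (no histogram/mask
# options; exact quantiles over the flattened tensor):
import torch

def preproc_ref(x):
    q = torch.quantile(x.reshape(-1), torch.tensor([0.005, 0.995]))
    mn, mx = q[0].item(), q[1].item()
    return x.clamp(mn, mx).sub_(mn).div_(mx - mn)

out = preproc_ref(torch.randn(16, 16) * 10)
assert 0 <= out.min() and out.max() <= 1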