Beispiel #1
0
def init_gmm2(x, y, bins=6, dim=None, mask=None):
    """Initialize parameters of a 2D GMM by drawing quantiles.

    The intensity range of the first channel of each image is split into
    `bins` intervals delimited by `bins + 1` equally spaced quantile levels
    (from quantile 0 to quantile 1). Each interval defines one cluster:
    its mean is the interval midpoint and its variance is
    ``(width / 2.355) ** 2`` (2.355 is presumably the FWHM-to-standard-
    deviation ratio of a Gaussian). Correlations are initialized to zero
    and priors to ``1 / bins``.

    Parameters
    ----------
    x : (..., *spatial, C) tensor
        Moving image. Only the first channel (``x[..., 0]``) is used.
    y : (..., *spatial, C) tensor
        Fixed image. Only the first channel (``y[..., 0]``) is used.
    bins : int, default=6
        Number of clusters
    dim : int, default=`x.dim()-1`
        Number of spatial dimensions
    mask : (..., *spatial, C) tensor, optional
        Mask or weights. Only the first channel (``mask[..., 0]``) is used.

    Returns
    -------
    dict:
        xmean : (..., *1, bins) tensor
            Mean of the moving (1st channel) image
        ymean : (..., *1, bins) tensor
            Mean of the fixed (1st channel) image
        xvar : (..., *1, bins) tensor
            Variance of the moving (1st channel) image
        yvar : (..., *1, bins) tensor
            Variance of the fixed (1st channel) image
        corr : (..., *1, bins) tensor
            Correlation coefficient (initialized to zero)
        prior : (bins,) tensor
            Proportion of each class (initialized to 1 / bins)
    """
    if mask is not None:
        mask = mask[..., 0]
    # NOTE: `or` also maps dim=0 to the default.
    dim = dim or x.dim() - 1
    spatial = range(-dim, 0)

    # `bins + 1` equally spaced levels in [0, 1] delimit `bins` intervals.
    quantiles = torch.arange(bins + 1, dtype=x.dtype,
                             device=x.device).div_(bins)
    xq = utils.quantile(x[..., 0], quantiles, dim=spatial,
                        keepdim=True, mask=mask)
    yq = utils.quantile(y[..., 0], quantiles, dim=spatial,
                        keepdim=True, mask=mask)

    # Interval widths -> variances; interval midpoints -> means.
    xvar = (xq[..., 1:] - xq[..., :-1]).div_(2.355).square_()
    yvar = (yq[..., 1:] - yq[..., :-1]).div_(2.355).square_()
    xmean = (xq[..., 1:] + xq[..., :-1]).div_(2)
    ymean = (yq[..., 1:] + yq[..., :-1]).div_(2)

    corr = torch.zeros_like(yvar)
    prior = y.new_full([bins], 1 / bins)
    return dict(xmean=xmean,
                ymean=ymean,
                xvar=xvar,
                yvar=yvar,
                corr=corr,
                prior=prior)
Beispiel #2
0
 def load(files, is_label=False):
     """Load one multi-channel multi-file volume.

     Parameters
     ----------
     files : sequence
         File descriptors exposing `fname`, `shape`, `channels` and
         `subchannels` attributes.
     is_label : bool, default=False
         If True, volumes are loaded as int32 labels, are NOT intensity
         normalized, and only the first channel is kept.

     Returns
     -------
     (channels, *spatial) tensor
         Intensity images have their 1st and 95th percentiles (per
         channel) mapped to 0 and 1.
     """
     # NOTE(review): `device` is a free variable, presumably defined in an
     # enclosing scope -- confirm.
     dats = []
     for file in files:
         if is_label:
             dat = io.volumes.load(file.fname,
                                   dtype=torch.int32, device=device)
         else:
             dat = io.volumes.loadf(file.fname, rand=True,
                                    dtype=torch.float32, device=device)
         dat = dat.reshape([*file.shape, file.channels])
         dat = dat[..., file.subchannels]
         dat = utils.movedim(dat, -1, 0)
         if not is_label:
             # Labels are integer indices: rescaling them is meaningless,
             # and in-place true division of an int32 tensor raises.
             dim = dat.dim() - 1
             qt = utils.quantile(dat, (0.01, 0.95),
                                 dim=range(-dim, 0), keepdim=True)
             mn, mx = qt.unbind(-1)
             dat = dat.sub_(mn).div_(mx - mn)
         dats.append(dat)
         del dat
     dats = torch.cat(dats, dim=0)
     if is_label and len(dats) > 1:
         warn('Multi-channel label images are not accepted. '
              'Using only the first channel')
         dats = dats[:1]
     return dats
Beispiel #3
0
 def rescale2d(x):
     """Rescale 2D images to [0, 1] using robust percentiles.

     The 0.5th and 99.5th percentiles over the last two dimensions are
     mapped to 0 and 1, and the result is clamped to [0, 1].

     Parameters
     ----------
     x : (..., H, W) tensor
         Input image(s). Non floating-point inputs are converted to float.

     Returns
     -------
     x : (..., H, W) tensor
         Rescaled image(s), as a new tensor (the input is not modified).
     """
     if not x.dtype.is_floating_point:
         x = x.float()
     mn, mx = utils.quantile(x, [0.005, 0.995],
                             dim=range(-2, 0), bins=1024).unbind(-1)
     # Guard the range itself rather than `mx`: in float32, `mn + 1e-8`
     # rounds back to `mn` once |mn| exceeds ~0.25, which would leave a
     # zero denominator (and NaNs) for constant images.
     span = (mx - mn).clamp_min_(1e-8)
     mn, span = mn[..., None, None], span[..., None, None]
     x = x.sub(mn).div_(span).clamp_(0, 1)
     return x
Beispiel #4
0
 def addnoise_(cls, x):
     """Fill zero and non-finite entries of `x` with small uniform noise, in place.

     The noise amplitude is the 0.5th percentile of the gaps between
     consecutive distinct finite values of `x`. If `x` has a single
     distinct finite value, that value is used as the amplitude; if it
     has none, the amplitude is 0. Masked entries become
     ``U[0, 1) * amplitude``.

     Parameters
     ----------
     x : tensor
         Floating-point tensor, modified in place.

     Returns
     -------
     x : tensor
         The input tensor.
     """
     finite = torch.isfinite(x)
     # Only finite values define the intensity spacing: NaN/inf would
     # yield NaN/inf gaps (each NaN is a distinct `unique` entry) that
     # could leak into the noise amplitude.
     v = x[finite].unique().sort().values
     if v.shape[0] > 1:
         v = v[1:] - v[:-1]
     if v.numel() == 0:
         # No finite value at all: nothing to measure a spacing from.
         v = 0.
     else:
         v = utils.quantile(v, 0.005).item()
     mask = ~finite | (x == 0)
     x.masked_fill_(mask, 0)
     x.addcmul_(mask.to(x.dtype), torch.rand_like(x), value=v)
     return x
Beispiel #5
0
def discretize(dat, nbins=256, mask=None):
    """Discretize an image into a number of bins.

    The 0.05th and 99.95th percentiles, computed over every dimension but
    the first two (presumably batch and channel -- TODO confirm), are
    mapped to 0 and ``nbins - 1``. Values are clamped to that range,
    non-finite values are handled by `make_finite_`, and the result is
    truncated to integers.

    Parameters
    ----------
    dat : tensor
        Floating-point image, modified in place.
    nbins : int, default=256
        Number of bins.
    mask : tensor, optional
        Mask or weights forwarded to `utils.quantile`.

    Returns
    -------
    tensor[long]
        Bin index of each element.
    """
    reduced = range(-(dat.dim() - 2), 0)
    bounds = utils.quantile(dat, (0.0005, 0.9995), dim=reduced,
                            keepdim=True, mask=mask)
    lo, hi = bounds.unbind(-1)
    dat.sub_(lo)
    dat.div_(hi - lo)
    dat.clamp_(0, 1)
    dat.mul_(nbins - 1)
    dat = make_finite_(dat)
    return dat.long()
Beispiel #6
0
def cutoff(dat, cutoff, dim=None):
    """Clip data when outside of a range defined by percentiles

    Parameters
    ----------
    dat : tensor or ndarray
        Input data. Tensors are clipped in place.
    cutoff : max or (min, max)
        Percentile cutoffs (in [0, 1]). A single value only clips the
        upper tail.
    dim : int or sequence[int], optional
        Dimension(s) along which to compute percentiles.
        By default, percentiles are computed over the whole array.

    Returns
    -------
    dat : tensor or ndarray
        Clipped data

    Raises
    ------
    ValueError
        If more than two cutoffs are provided.

    """
    if cutoff is None:
        return dat
    fractions = sorted(py.make_sequence(cutoff))
    if len(fractions) > 2:
        raise ValueError('Maximum two percentiles (min, max) should'
                         ' be provided. Got {}.'.format(len(fractions)))
    if torch.is_tensor(dat):
        dat_for_quantile = dat
        if not dat.dtype.is_floating_point:
            dat_for_quantile = dat_for_quantile.float()

        # Honour `dim` like the numpy branch does. With `keepdim`, the
        # bounds keep singleton reduced dimensions and broadcast against
        # `dat` in `clamp_`.
        opt = dict(bins=1024)
        if dim is not None:
            opt.update(dim=dim, keepdim=True)
        bounds = utils.quantile(dat_for_quantile, fractions, **opt)
        bounds = bounds.unbind(-1)

        if len(bounds) == 1:
            mn, mx = None, bounds[0]
        else:
            mn, mx = bounds
        if not dat.dtype.is_floating_point:
            # Round outwards so integer data is never clipped more than
            # requested.
            mx = mx.ceil()
            if mn is not None:
                mn = mn.floor()

        mx = mx.to(dat.dtype)
        if mn is not None:
            mn = mn.to(dat.dtype)
        dat.clamp_(mn, mx)
        return dat
    else:
        percents = [100 * val for val in fractions]
        pct = np.nanpercentile(dat, percents, axis=dim, keepdims=True)
        if len(pct) == 1:
            dat = np.clip(dat, a_min=None, a_max=pct[0])
        else:
            dat = np.clip(dat, a_min=pct[0], a_max=pct[1])
        return dat
Beispiel #7
0
def _discretize_image(dat, nbins=256):
    """
    Discretize an image into a number of bins

    The 0.05th and 99.95th percentiles of each channel are mapped to 0 and
    ``nbins - 1``; values outside that range are clamped. `dat` is
    modified in place before being converted to integers.

    Input : (C, *spatial) tensor[float]
    Returns: (C, *spatial) tensor[long]
    """
    dim = dat.dim() - 1
    mn, mx = utils.quantile(dat, (0.0005, 0.9995),
                            dim=range(-dim, 0),
                            keepdim=True).unbind(-1)
    # A constant channel has mx == mn: 0 / 0 would produce NaNs, whose
    # conversion to integers is undefined. With the smallest normal float
    # as denominator, values equal to `mn` land in bin 0 and larger ones
    # are clamped into the last bin.
    span = (mx - mn).clamp_min_(torch.finfo(dat.dtype).tiny)
    dat = dat.sub_(mn).div_(span).clamp_(0, 1).mul_(nbins - 1)
    dat = dat.long()
    return dat
Beispiel #8
0
    def forward(self, image, **overload):
        """Affinely map the (qmin, qmax) quantiles of `image` to (vmin, vmax).

        Quantiles are computed over every dimension but the first two
        (presumably batch and channel -- TODO confirm), using `self.bins`.

        Parameters
        ----------
        image : tensor
            Input image.
        **overload
            Optional `qmin`, `qmax`, `vmin`, `vmax` values that take
            precedence over the attributes of the same name.

        Returns
        -------
        tensor
            Rescaled image (a new tensor).
        """
        qmin, qmax, vmin, vmax = (overload.get(key, getattr(self, key))
                                  for key in ('qmin', 'qmax', 'vmin', 'vmax'))
        nb_spatial = image.dim() - 2
        bounds = utils.quantile(image, (qmin, qmax),
                                dim=range(-nb_spatial, 0),
                                keepdim=True,
                                bins=self.bins)
        lo, hi = bounds.unbind(-1)
        scale = (vmax - vmin) / (hi - lo)
        return (image - lo) * scale + vmin
Beispiel #9
0
def _soft_quantize_image(dat, nbins=16):
    """
    Soft-quantize an image into a number of bins

    Intensities are rescaled so that the 0.05th and 99.95th percentiles
    map to 0 and `nbins` (clamped to that range). Each voxel then receives
    Gaussian-shaped memberships ``exp(-2.355**2 * (center - value)**2)``
    to `nbins` bins centered at 0.5, 1.5, ..., nbins - 0.5, stored in
    decreasing order of center along the first dimension. Memberships are
    normalized to sum to one across bins.

    Input : (1, *spatial) tensor[float]
        Only the first channel is used; it is modified in place.
    Returns: (nbins, *spatial) tensor[float]
    """
    dim = dat.dim() - 1
    dat = dat[0]
    mn, mx = utils.quantile(dat, (0.0005, 0.9995),
                            dim=range(-dim, 0),
                            keepdim=True).unbind(-1)
    # Guard against a constant image (mx == mn), which would otherwise
    # fill the output with NaNs.
    span = (mx - mn).clamp_min_(torch.finfo(dat.dtype).tiny)
    dat = dat.sub_(mn).div_(span).clamp_(0, 1).mul_(nbins)
    centers = torch.linspace(0, nbins, nbins + 1, **utils.backend(dat))
    centers = (centers[1:] + centers[:-1]) / 2
    centers = centers.flip(0)
    # (nbins,) -> (nbins, 1, ..., 1) so it broadcasts against (*spatial)
    centers = centers[(Ellipsis, ) + (None, ) * dim]
    # NOTE: `-2.355**2` parses as `-(2.355**2)`.
    dat = (centers - dat).square().mul_(-2.355**2).exp_()
    dat /= dat.sum(0, keepdim=True)
    return dat
Beispiel #10
0
def _rescale_image(dat, quantiles):
    """Rescale an image between (0, 1) based on two quantiles

    Parameters
    ----------
    dat : (C, *spatial) tensor
        Image (the first dimension is kept, presumably channels -- TODO
        confirm). Modified in place.
    quantiles : None, number, or sequence of at most two numbers, in percent
        - ``None`` or ``[]``: use (0, 95)
        - ``max``: use (0, max)
        - ``(min, max)``: use (min, max)

    Returns
    -------
    dat : (C, *spatial) tensor
        Image whose `min` and `max` percentiles (per leading index) are
        mapped to 0 and 1.
    """
    dim = dat.dim() - 1
    if quantiles is None:
        quantiles = []
    elif not isinstance(quantiles, (list, tuple)):
        quantiles = [quantiles]
    if len(quantiles) == 0:
        mn = 0
        mx = 95
    elif len(quantiles) == 1:
        mn = 0
        mx = quantiles[0]
    else:
        mn, mx = quantiles
    # Both bounds are given in percent: convert both to [0, 1].
    mn = mn / 100
    mx = mx / 100
    mn, mx = utils.quantile(dat, (mn, mx),
                            dim=range(-dim, 0),
                            keepdim=True,
                            bins=1024).unbind(-1)
    dat = dat.sub_(mn).div_(mx - mn)
    return dat
Beispiel #11
0
    def rescale(self, image):
        """Affine rescaling of an image by mapping quantiles to (0, 1)

        The `self.qmin` and `self.qmax` quantiles of `image` are mapped to
        0 and 1. No `dim` is passed to `utils.quantile`, so they are
        presumably computed over all elements. When ``self.qmin == 0`` and
        ``self.qmax == 1``, the image is returned untouched.

        Parameters
        ----------
        image : tensor
            Input image, modified in place.

        Returns
        -------
        image : tensor
            Rescaled image (the same object as the input)

        Attributes read
        ---------------
        qmin : float in [0, 1]
            Lower quantile
        qmax : float in [0, 1]
            Upper quantile

        """
        full_range = self.qmin == 0 and self.qmax == 1
        if not full_range:
            lo, hi = utils.quantile(image, [self.qmin, self.qmax])
            image -= lo
            image /= (hi - lo)
        return image
Beispiel #12
0
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Intensities are clipped to [min, max] and mapped affinely to [0, 1].
    When several images are given they share a single mapping: the
    smallest `min` and the largest `max` across images are used.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: 5th percentile of each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: 95th percentile of each batch element.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization (`True` is an alias for 'linear').
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.
        If float, the signal is taken to that power before being equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images (a single tensor if a single image is given).
        Intensities are scaled within [0, 1].

    """

    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    # rescale min/max
    min = py.make_list(min, len(images))
    max = py.make_list(max, len(images))
    min = [
        utils.quantile(image, 0.05, bins=2048, dim=[-1, -2], keepdim=True)
        if mn is None else torch.as_tensor(mn, **backend)[None, None]
        for image, mn in zip(images, min)
    ]
    min, *othermin = min
    for mn in othermin:
        min = torch.min(min, mn)
    del othermin
    max = [
        utils.quantile(image, 0.95, bins=2048, dim=[-1, -2], keepdim=True)
        if mx is None else torch.as_tensor(mx, **backend)[None, None]
        for image, mx in zip(images, max)
    ]
    max, *othermax = max
    for mx in othermax:
        max = torch.max(max, mx)
    del othermax
    images = [torch.max(torch.min(image, max), min) for image in images]
    # Affine map [min, max] -> [0, 1]. Subtracting `min` directly stays
    # finite for min == 0 and min == max (the `max / min` formulation
    # divided by zero); `eps` keeps the denominator nonzero.
    images = [image.sub_(min).div_(max - min + eps) for image in images]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # reshape and concatenate
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        if eq >= 0:
            images = images.pow(eq)
        else:
            # negative powers: avoid 0 ** -p = inf
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()

    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # rescale min/max
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True)

    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]

    return tuple(images) if len(images) > 1 else images[0]
Beispiel #13
0
 def preproc_(cls, x):
     """Clip `x` to its 0.5th-99.5th percentile range and rescale to [0, 1].

     No `dim` is passed to `utils.quantile`, so the percentiles are
     presumably computed over all elements. `x.max(...)` is out of place,
     so a new tensor is returned and `x` itself is not modified.
     """
     bounds = utils.quantile(x, [0.005, 0.995], keepdim=True)
     lo, hi = bounds.unbind(-1)
     clipped = x.max(lo).min(hi)
     return clipped.sub_(lo).div_(hi - lo)