Exemple #1
0
def smooth_iid():
    """Compare empirical and analytic variance reduction of Gaussian smoothing.

    `nrep` white-noise images (lightly pre-smoothed with a 1.5-voxel FWHM)
    are smoothed with Gaussian kernels of increasing sigma using three
    approaches: ``smooth(basis=0)`` (red), ``smooth(basis=1)`` (blue) and
    ``conv`` with ``gauss_kernel`` (green). The variance across repeats at
    the central voxel, relative to the unsmoothed variance, is plotted
    against sigma together with two analytic curves (black). The curves are
    shown on log-log, semi-log and linear axes, then as a linear zoom on
    ``sig[2:]``, and finally as ``ratio ** (-1 / dim)`` against sigma.
    """
    shape = [128, 128]
    center = tuple(s // 2 for s in shape)
    nrep = 100
    dim = len(shape)

    sig = [0.1, 0.5, 1, 2, 4, 8, 16]
    fwhm = [2.355 * s for s in sig]  # FWHM = sqrt(8 ln 2) * sigma ~= 2.355 sigma
    # Analytic curves for a normalised Gaussian of std s in `dim` dimensions:
    #   nc = (2 pi s^2)^(-dim/2)  -- its peak value
    #   yb = (4 pi s^2)^(-dim/2)  -- the integral of its square
    nc = [(2 * pi * (s**2))**(-dim / 2) if s > 0 else 1 for s in sig]
    yb = [(4 * pi * (s**2))**(-dim / 2) if s > 0 else 1 for s in sig]

    dat = torch.randn([nrep, *shape])
    dat = smooth(dat, fwhm=1.5, basis=1, dim=dim)

    # Reference variance at the central voxel (was hard-coded as [64, 64],
    # which is only the centre for a 128x128 shape).
    var0 = dat.var(0)[center].item()
    varb0, varb1, varbd = [], [], []
    for f in fwhm:
        sdat = smooth(dat, fwhm=f, basis=0, dim=dim)
        varb0.append(sdat.var(0)[center].item() / var0)
        sdat = smooth(dat, fwhm=f, basis=1, dim=dim)
        varb1.append(sdat.var(0)[center].item() / var0)
        kernel = gauss_kernel(f, dim)
        sdat = conv(dim, dat, kernel, padding='auto', bound='dct2')
        varbd.append(sdat.var(0)[center].item() / var0)

    def show(fn, curves, ylabel='Variance'):
        # Draw each (x, y, fmt) curve with plotting function `fn`, then show.
        for x, y, fmt in curves:
            fn(x, y, fmt)
        plt.xlabel('Smoothing sigma')
        plt.ylabel(ylabel)
        plt.show()

    measured = [(varb0, 'r-+'), (varb1, 'b-+'), (varbd, 'g-+')]
    curves = [(sig, nc, 'k:'), (sig, yb, 'k--')]
    curves += [(sig, y, fmt) for y, fmt in measured]

    for fn in (plt.loglog, plt.semilogy, plt.plot):
        show(fn, curves)

    # Linear zoom on sig[2:] (sigma >= 1 with the values above).
    show(plt.plot, [(x[2:], y[2:], fmt) for x, y, fmt in curves])

    # Map each ratio v to v ** (-1 / dim): for the analytic `yb` curve this
    # equals sqrt(4 pi) * sigma, i.e. it grows linearly with sigma.
    transformed = [([v ** (-1 / dim) for v in y], fmt) for y, fmt in measured]
    show(plt.plot,
         [(sig, sig, 'k--')] + [(sig[2:], y[2:], fmt) for y, fmt in transformed],
         ylabel='Variance ** (-1/dim)')
Exemple #2
0
    def forward(self, x, noise=None, return_resolution=False):
        """Degrade the resolution of each (batch, channel) image independently.

        For every image a resolution is sampled from ``self.resolution``
        (clamped to be at least 1). The image is smoothed with that FWHM,
        optional noise is added, and it is resized by ``1/resolution`` and
        back to its original spatial shape.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor
        noise : tensor, optional
            Expanded to ``x.shape`` and added after smoothing.
        return_resolution : bool, default=False
            Also return the sampled resolutions (one row of ``dim`` values
            per (batch, channel) pair, in batch-major order).

        Returns
        -------
        out : (batch, channel, *spatial) tensor
        resolutions : tensor, if `return_resolution`
        """
        if noise is not None:
            noise = noise.expand(x.shape)

        ndim = x.dim() - 2
        backend = utils.backend(x)
        nchannels = x.shape[1]
        res_exp = utils.make_vector(self.resolution_exp, nchannels, **backend)
        res_scale = utils.make_vector(self.resolution_scale, nchannels, **backend)

        out = torch.empty_like(x)
        sampled = []
        for batch_index in range(len(x)):
            for chan in range(nchannels):
                dist = self.resolution(res_exp[chan], res_scale[chan])
                res = dist.sample().clamp_min(1)
                img = smooth(x[batch_index, chan], fwhm=[res] * ndim, dim=ndim,
                             padding='same', bound='dct2')
                if noise is not None:
                    img += noise[batch_index, chan]
                # resize needs leading batch and channel dimensions
                low = resize(img[None, None], factor=[1/res] * ndim, anchor='f')
                upsample = [res] * ndim
                sampled.append(upsample)
                high = resize(low, factor=upsample, shape=x.shape[2:], anchor='f')
                out[batch_index, chan] = high[0, 0]

        sampled = utils.as_tensor(sampled, **backend)
        if return_resolution:
            return out, sampled
        return out
Exemple #3
0
    def backward(self, g, x, w=None, min=None, max=None):
        """Backpropagate a gradient with respect to the histogram to the samples.

        Parameters
        ----------
        g : (..., *bins) tensor
            Gradient with respect to the joint histogram.
        x : (..., n, 2) tensor
            Samples that were binned.
        w : (..., n) tensor, optional
            Sample weights, broadcast against ``x[..., 0]``.
        min : (...) tensor_like, optional
            Lower edge of the histogram range, broadcastable to ``(..., 2)``.
            Default: per-batch minimum of ``x``.
        max : (...) tensor_like, optional
            Upper edge of the histogram range, broadcastable to ``(..., 2)``.
            Default: per-batch maximum of ``x``.

        Returns
        -------
        g : (..., n, 2) tensor
            Gradient with respect to ``x``.

        """
        backend = dict(dtype=x.dtype, device=x.device)
        n = x.shape[-2]
        if w is not None:
            _, w = torch.broadcast_tensors(x[..., 0], w)
            batch = w.shape[:-1]
            x = x.expand([*batch, *x.shape[-2:]])
            w = w.reshape([-1, n])
        else:
            batch = x.shape[:-2]
        x = x.reshape([-1, n, 2])

        # Histogram range as (B, 1, 2) tensors, where B matches the flattened
        # batch of `x` (which may have been broadcast against `w`).
        if min is None:
            min = x.min(-2, keepdim=True).values
        else:
            min = torch.as_tensor(min, **backend).expand([*batch, 2]).reshape([-1, 1, 2])
        if max is None:
            max = x.max(-2, keepdim=True).values
        else:
            max = torch.as_tensor(max, **backend).expand([*batch, 2]).reshape([-1, 1, 2])

        # Affine map from data values to continuous bin coordinates:
        #   x -> (x - min) * bins / (max - min) - 0.5
        # `scale` keeps its (B, 1, 2) shape so that it also broadcasts
        # against the (B, n, 2) gradient below.
        bins = torch.as_tensor(self.bins, **backend)
        scale = bins / (max - min)
        x = x.clone().sub_(min).mul_(scale).sub_(0.5)

        g = g.reshape([-1, *self.bins])

        # smooth backward
        if any(self.fwhm):
            g = smooth(g, fwhm=self.fwhm, bound=self.bound, dim=2)

        # backpropagate through the histogram binning
        g = _jhistc_backward(g, x, w, self.order, self.bound, self.extrapolate)
        # chain rule through the affine map to bin coordinates
        g = g.mul_(scale)

        return g.reshape([*batch, n, 2])
Exemple #4
0
    def forward(self, x, affine=None):
        """Preprocess, segment and postprocess a volume.

        Parameters
        ----------
        x : (X, Y, Z) tensor or str or io.MappedArray
            Input volume, or the path of a file to map.
        affine : (4, 4) tensor, optional
            Orientation matrix. If `x` is a mapped file, defaults to the
            file's own affine. If it is still None, the reorientation /
            resampling step is skipped.

        Returns
        -------
        seg : (oX, oY, oZ) tensor
            Segmentation labels (argmax over classes, passed through
            ``self.relabel``).
        resliced : (oX, oY, oZ) tensor
            Input resliced to 1 mm RAS
        affine : (4, 4) tensor
            Output orientation matrix

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        if isinstance(x, str):
            x = io.map(x)
        if isinstance(x, io.MappedArray):
            if affine is None:
                affine = x.affine
            # Always load the data: a MappedArray must not reach the tensor
            # operations below, even when an explicit `affine` was given.
            x = x.fdata()
            x = x.reshape(x.shape[:3])
        x = SynthPreproc.addnoise(x)
        if affine is not None:
            affine, x = spatial.affine_reorient(affine, x, 'RAS')
            # Axes with voxels larger than 1 are not smoothed; the others get
            # a light smoothing (FWHM 0.25 / vx) before resizing by `vx`.
            vx = spatial.voxel_size(affine)
            fwhm = 0.25 * vx.reciprocal()
            fwhm[vx > 1] = 0
            x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
            x, affine = spatial.resize(x[None, None],
                                       vx.tolist(),
                                       affine=affine)
            x = x[0, 0]
        oshape = x.shape
        x, crop = SynthPreproc.crop(x)
        x = SynthPreproc.preproc(x)[None, None]
        if self.verbose:
            print('done.', flush=True)
            print('Segmenting... ', end='', flush=True)
        s, x = super().forward(x)[0], x[0, 0]
        if self.verbose:
            print('done.', flush=True)
            print('Postprocessing... ', end='', flush=True)
        s = self.relabel(s.argmax(0))
        x = SynthPreproc.pad(x, oshape, crop)
        s = SynthPreproc.pad(s, oshape, crop)
        if self.verbose:
            print('done.', flush=True)
        return s, x, affine
Exemple #5
0
    def forward(self, x):
        """Smooth each batch element with a randomly sampled FWHM.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        out : (batch, channel, *spatial) tensor
            Smoothed images. The input tensor is left unmodified.
        """
        dim = x.dim() - 2
        backend = dict(dtype=x.dtype, device=x.device)
        nb_fwhm = 1 if self.iso else dim

        fwhm_exp = utils.make_vector(self.fwhm_exp, nb_fwhm, **backend)
        fwhm_scale = utils.make_vector(self.fwhm_scale, nb_fwhm, **backend)

        # `torch.as_tensor(x)` returns `x` itself, so writing into it would
        # overwrite the caller's tensor: allocate a fresh output instead.
        out = torch.empty_like(x)
        for b in range(len(x)):
            # one FWHM per spatial dimension (a single shared value when iso)
            fwhm = self.fwhm(fwhm_exp, fwhm_scale).sample()
            fwhm = fwhm.clamp_min_(0).expand([dim]).clone()
            out[b] = smooth(x[b], fwhm=fwhm, dim=dim, padding='same', bound='dct2')
        return out
Exemple #6
0
 def _build_pyramid(self, dat, levels, method, dim, bound,
                    mask=None, preview=None):
     levels = list(levels)
     indexed_levels = list(enumerate(levels))
     indexed_levels.sort(key=lambda x: x[1])
     nb_levels = max(levels)
     if mask is not None:
         mask = mask.to(dat.device)
     dats = [dat] * levels.count(0)
     masks = [mask] * levels.count(0)
     previews = [preview] * levels.count(0)
     if mask is not None:
         mask = mask.to(dat.dtype)
     if preview is not None:
         preview = preview.to(dat.dtype)
     for level in range(1, nb_levels+1):
         shape = dat.shape[-dim:]
         kernel_size = [min(2, s) for s in shape]
         if method[0] == 'g':  # gaussian pyramid
             # We assume the original data has a PSF of 1 input voxel.
             # We smooth by an additional 1-vx FWHM so that the data has a
             # PSF of 2 input voxels == 1 output voxel, then subsample.
             smooth = lambda x: spatial.smooth(x, fwhm=1, stride=2,
                                               dim=dim, bound=bound)
         elif method[0] == 'a':  # average window
             smooth = lambda x: spatial.pool(dim, x, kernel_size=kernel_size,
                                             stride=2, reduction='mean')
         elif method[0] == 'm':  # median window
             smooth = lambda x: spatial.pool(dim, x, kernel_size=kernel_size,
                                             stride=2, reduction='median')
         elif method[0] == 's':  # strides
             slicer = [slice(None, None, 2)] * dim
             smooth = lambda x: x[(Ellipsis, *slicer)]
         else:
             raise ValueError(method)
         dat = smooth(dat)
         if mask is not None:
             mask = smooth(mask)
         if preview is not None:
             preview = smooth(preview)
         dats += [dat] * levels.count(level)
         masks += [mask] * levels.count(level)
         previews += [preview] * levels.count(level)
     reordered_dats = [None] * len(levels)
     reordered_masks = [None] * len(levels)
     reordered_previews = [None] * len(levels)
     for (i, level), dat, mask, preview \
             in zip(indexed_levels, dats, masks, previews):
         reordered_dats[i] = dat
         reordered_masks[i] = mask
         reordered_previews[i] = preview
     return reordered_dats, reordered_masks, reordered_previews
Exemple #7
0
 def resize(cls, x, affine, target_vx=1):
     """Resample a volume to a target voxel size.

     Parameters
     ----------
     x : (*spatial) tensor
         Volume without batch or channel dimensions; its rank sets the
         number of spatial dimensions.
     affine : tensor
         Orientation matrix of `x`.
     target_vx : float or sequence[float], default=1
         Target voxel size (same units as ``spatial.voxel_size(affine)``).

     Returns
     -------
     x : (*spatial_out) tensor
         Resampled volume.
     affine : tensor
         Orientation matrix of the resampled volume.
     """
     dim = x.dim()
     target_vx = utils.make_vector(target_vx, dim,
                                   **utils.backend(affine))
     vx = spatial.voxel_size(affine)
     factor = vx / target_vx
     # Axes with factor <= 1 get a light smoothing (FWHM 0.25 / factor)
     # before resizing; axes with factor > 1 are not smoothed.
     fwhm = 0.25 * factor.reciprocal()
     fwhm[factor > 1] = 0
     # Smooth over all spatial dims of `x` (was hard-coded to 3).
     x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=dim)
     x, affine = spatial.resize(x[None, None],
                                factor.tolist(),
                                affine=affine)
     return x[0, 0], affine
Exemple #8
0
    def forward(self, batch=1, **overload):
        """Sample random fields with a squared-exponential (SE) covariance.

        White noise is drawn per (batch, channel) and smoothed with a
        Gaussian kernel whose parameters are derived from the SE
        parameters ``self.mean``, ``self.amplitude`` and ``self.fwhm``.

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)
        nb_dim = len(shape)

        # per-channel parameters
        mean = utils.make_vector(self.mean, channel, **backend)
        amplitude = utils.make_vector(self.amplitude, channel, **backend)
        fwhm = utils.make_vector(self.fwhm, channel, **backend)

        # Convert SE parameters to noise/kernel parameters.
        # Smoothing white noise of std `s` with a normalised Gaussian of std
        # `tau` (in D dims) gives the covariance
        #     s**2 * (4*pi*tau**2)**(-D/2) * exp(-d**2 / (4*tau**2)),
        # i.e. an SE kernel with length-scale lam = tau * sqrt(2) and
        # variance amplitude**2 = s**2 * (2*pi*lam**2)**(-D/2). Hence
        #     tau = lam / sqrt(2),   s = amplitude * (2*pi)**(D/4) * lam**(D/2)
        # (same conversion as `se_sample_smooth`).
        lam = fwhm / math.sqrt(8 * math.log(2))
        amplitude = amplitude * (2 * pi) ** (nb_dim / 4) * lam.pow(nb_dim / 2)
        fwhm = fwhm / math.sqrt(2)

        # smooth
        out = torch.empty([batch, channel, *shape], **backend)
        for b in range(batch):
            for c in range(channel):
                sample = torch.distributions.Normal(mean[c], amplitude[c]).sample(shape)
                # each channel uses its own kernel width
                out[b, c] = spatial.smooth(
                    sample, 'gauss', fwhm[c].item(),
                    basis=self.basis, bound='dct2', dim=nb_dim, padding='same')
        return out
Exemple #9
0
def se_sample_smooth(shape, sigma, lam, mu=None, repeats=1, **backend):
    """Sample random fields with a squared exponential kernel.

    This function uses Gaussian smoothing of white noise. The noise is
    generated on a grid `mul` times finer than the output (plus padding),
    smoothed, then subsampled back to `shape`.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        SE amplitude.
    lam : float
        SE length-scale.
    mu : () or (*shape) tensor_like
        SE mean
    repeats : int, default=1
        Number of sampled fields.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.

    """
    # Convert SE parameters to Gaussian parameters:
    # noise std = sigma * (2 pi)^(dim/4) * lam^(dim/2), kernel std = lam / sqrt(2)
    dim = len(shape)
    sigma = sigma * ((2 * constants.pi)**(dim / 4)) * (lam**(dim / 2))
    lam = lam / (2**0.5)
    fwhm = lam * 2 * ((2 * math.log(2))**0.5)

    # Sample white noise on a padded, `mul`-times finer grid
    mul = 4
    pad = int(math.ceil(2 * fwhm))
    shape2 = [s * mul + 2 * pad for s in shape]
    field = torch.randn([repeats, *shape2], **backend)

    # Gaussian smoothing
    field = spatial.smooth(field, fwhm=fwhm * mul, dim=dim, basis=0)
    # Keep one voxel out of `mul`, skipping the padding. The stop index is
    # explicit because `-pad` would be `-0 == 0` (an empty slice) if pad == 0.
    sub = (slice(None),) + tuple(slice(pad + mul // 2, s2 - pad, mul)
                                 for s2 in shape2)
    field = field[sub]
    # compensate for the per-voxel variance of the finer noise grid
    field *= sigma * (mul**(dim / 2))

    # Add mean
    if mu is not None:
        mu = torch.as_tensor(mu, **backend)
        field += mu

    return field
Exemple #10
0
 def slice_to(self, stack, cache_result=False, recompute=True):
     """Resample `self.dat` into the orientation of `stack` (work in progress).

     Parameters
     ----------
     stack : object
         Target stack; its `layout`, `slice_width`, `affine` and `slices`
         attributes are read.
     cache_result : bool, default=False
         Store the result in `self._sliced` (also forwarded to `self.exp`).
     recompute : bool, default=True
         Forwarded to `self.exp`. When False and `self._sliced` exists, the
         resampling step below is skipped.

     Returns
     -------
     sliced : tensor
         `self.dat` pulled with `spatial.grid_pull`, then smoothed along its
         last dimension with FWHM ``stack.slice_width / voxel_size[-1]``.

     NOTE(review): this method looks unfinished:
       * when the resampling step is skipped (``recompute=False`` with a
         cached `self._sliced`), `sliced` is never assigned and
         ``return sliced`` raises UnboundLocalError; it presumably should
         return `self._sliced` -- TODO confirm.
       * the loop over `stack.slices` calls `spatial.affine_matmul` and
         `spatial.affine_lmdiv` with a single argument (likely a TypeError)
         and discards the result; `slices` is never used.
     """
     aff = self.exp(cache_result=cache_result, recompute=recompute)
     if recompute or not hasattr(self, '_sliced'):
         aff = spatial.affine_matmul(aff, self.affine)
         aff_reorient = spatial.affine_reorient(self.affine, self.shape, stack.layout)
         aff = spatial.affine_lmdiv(aff_reorient, aff)
         aff = spatial.affine_grid(aff, self.shape)
         sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                    extrapolate=self.extrapolate)
         # Smooth only along the last (slice) dimension, by the slice width
         # expressed in voxels of the reoriented space.
         fwhm = [0] * self.dim
         fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
         sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
         # TODO(review): unfinished per-slice loop (incomplete calls, unused
         # result) -- see docstring.
         slices = []
         for stack_slice in stack.slices:
             aff = spatial.affine_matmul(stack.affine, )
             aff = spatial.affine_lmdiv(aff_reorient, )
     if cache_result:
         self._sliced = sliced
     return sliced
Exemple #11
0
def _cli(args):
    """Command-line interface for `smooth` without exception handling.

    Parses `args` (or ``sys.argv[1:]`` when empty), then smooths every
    input file and saves the result under the formatted output name.
    """
    args = args or sys.argv[1:]
    options = parser(args)
    if options.help:
        print(help)
        return

    # A trailing string in the FWHM list is its unit; otherwise 'mm'.
    fwhm = options.fwhm
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    else:
        unit = 'mm'
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for fname, ofname in zip(options.files, options.output):
        mapped = io.map(fname)
        vx = voxel_size(mapped.affine).tolist()
        dim = len(vx)
        if unit == 'mm':
            # divide each width by the voxel size along that axis
            fwhm1 = [width / size for width, size in zip(fwhm, vx)]
        else:
            fwhm1 = fwhm[:dim]

        dat = movedim_front2back(mapped.fdata(), dim)
        dat = smooth(dat, type=options.method, fwhm=fwhm1,
                     basis=options.basis, bound=options.padding, dim=dim)
        dat = movedim_back2front(dat, dim)

        folder, base, ext = fileparts(fname)
        ofname = ofname.format(dir=folder or '.', base=base, ext=ext,
                               sep=os.path.sep)
        io.savef(dat, ofname, like=mapped)
Exemple #12
0
 def forward(self, x):
     """Mark the background close to the foreground as a separate class.

     The foreground is smoothed with FWHM ``self.kernel`` and thresholded at
     1e-3 to obtain a dilated foreground mask. Background voxels inside
     that mask form the extra class.

     Parameters
     ----------
     x : (batch, channel, *spatial) tensor
         Floating point: probabilities, channel 0 being treated as the
         background. Integer: labels, 0 being the background.

     Returns
     -------
     Floating point input: (batch, channel + 1, *spatial) tensor
         ``[background far from foreground, x[:, 1:], background near
         foreground]`` along the channel dimension.
     Integer input: (batch, channel, *spatial) tensor
         Copy of `x` where background voxels near the foreground are set to
         ``self.label`` (or ``x.max() + 1`` when it is falsy).
     """
     if x.dtype.is_floating_point:
         fwd = 1 - x[:, :1]
     else:
         fwd = (x == 0).bitwise_not_().float()
     fwd = spatial.smooth(fwd,
                          fwhm=self.kernel,
                          dim=x.dim() - 2,
                          padding='same',
                          bound='replicate')
     fwd = fwd > 1e-3
     if x.dtype.is_floating_point:
         bag = x[:, :1] * fwd
         # `fwd` is inverted in place only after `bag` has been computed
         bg = x[:, :1] * fwd.bitwise_not_()
         # keep the non-background *channels* (x[:, 1:], not the batch x[1:])
         return torch.cat([bg, x[:, 1:], bag], dim=1)
     x = x.clone()
     bag = (x == 0).bitwise_and_(fwd)
     label = self.label or (x.max() + 1)
     x[bag] = label
     return x
Exemple #13
0
    def forward(self, x, min=None, max=None, mask=None):
        """Build the joint histogram of bivariate samples.

        Parameters
        ----------
        x : (..., N, 2) tensor
            Input multivariate vector
        min : (..., 2) tensor, optional
        max : (..., 2) tensor, optional
        mask : (..., N) tensor, optional
            Sample weights pushed into the histogram instead of counts.

        Returns
        -------
        h : (..., B, B) tensor
            Joint histogram (smoothed when ``self.fwhm`` is set)
        min, max : tensor
            Histogram range as returned by ``self._prepare``.

        """
        batch_shape = x.shape[:-2]
        x, min, max = self._prepare(x, min, max)

        # Hidden feature: extrapolate=2 tells pullpush to use a +/- 0.5
        # tolerance when deciding whether a coordinate is in bounds.
        extrapolate = self.extrapolate or 2
        coords = x[:, None]
        if mask is not None:
            mask = mask.to(x.device, x.dtype)
            h = spatial.grid_push(mask, coords, self.n, self.order,
                                  self.bound, extrapolate)
        else:
            h = spatial.grid_count(coords, self.n, self.order, self.bound,
                                   extrapolate)
        h = h[:, 0].to(x.dtype)
        h = h.reshape([*batch_shape, *h.shape[-2:]])

        if self.fwhm:
            h = spatial.smooth(h, fwhm=self.fwhm, bound=self.bound, dim=2)
        return h, min, max
Exemple #14
0
    def backward2(self, h, x, w=None, min=None, max=None):
        """Backpropagate a Hessian with respect to the histogram to the samples.

        Parameters
        ----------
        h : (..., *bins, [*bins]) tensor
            Hessian with respect to the histogram: either its diagonal
            (``(..., *bins)``) or the full matrix (``(..., *bins, *bins)``).
        x : (..., n, 2) tensor
            Samples that were binned.
        w : (..., n) tensor, optional
            Sample weights, broadcast against ``x[..., 0]``.
        min : (...) tensor_like, optional
            Lower edge of the histogram range, broadcastable to ``(..., 2)``.
            Default: per-batch minimum of ``x``.
        max : (...) tensor_like, optional
            Upper edge of the histogram range, broadcastable to ``(..., 2)``.
            Default: per-batch maximum of ``x``.

        Returns
        -------
        h : (..., n, 2) tensor

        Raises
        ------
        ValueError
            If the shape of `h` matches neither layout.

        """
        backend = dict(dtype=x.dtype, device=x.device)
        n = x.shape[-2]
        if w is not None:
            _, w = torch.broadcast_tensors(x[..., 0], w)
            batch = w.shape[:-1]
            x = x.expand([*batch, *x.shape[-2:]])
            w = w.reshape([-1, n])
        else:
            batch = x.shape[:-2]
        x = x.reshape([-1, n, 2])

        if h.shape[:-2] == batch:
            is_diag = True
        elif h.shape[:-4] == batch:
            is_diag = False
        else:
            raise ValueError('Don\'t know what to do with that shape')

        # Histogram range as (B, 1, 2) tensors, where B matches the flattened
        # batch of `x` (which may have been broadcast against `w`).
        if min is None:
            min = x.min(-2, keepdim=True).values
        else:
            min = torch.as_tensor(min, **backend).expand([*batch, 2]).reshape([-1, 1, 2])
        if max is None:
            max = x.max(-2, keepdim=True).values
        else:
            max = torch.as_tensor(max, **backend).expand([*batch, 2]).reshape([-1, 1, 2])

        # Affine map from data values to continuous bin coordinates:
        #   x -> (x - min) * bins / (max - min) - 0.5
        # `scale` keeps its (B, 1, 2) shape so that it also broadcasts
        # against the (B, n, 2) output below.
        bins = torch.as_tensor(self.bins, **backend)
        scale = bins / (max - min)
        x = x.clone().sub_(min).mul_(scale).sub_(0.5)

        if is_diag:
            h = h.reshape([-1, *self.bins])
        else:
            h = h.reshape([-1, *self.bins, *self.bins])

        # smooth backward
        if any(self.fwhm):
            ker = kernels.smooth(fwhm=self.fwhm)
            if is_diag:
                # the diagonal is propagated with the squared kernel
                ker = [k.square_() for k in ker]
                h = smooth(h, kernel=ker, bound=self.bound, dim=2)
            else:
                # smooth both the row and the column bin dimensions
                h = smooth(h, kernel=ker, bound=self.bound, dim=2)
                h = h.transpose(-4, -2).transpose(-3, -1)
                h = smooth(h, kernel=ker, bound=self.bound, dim=2)
                h = h.transpose(-4, -2).transpose(-3, -1)

        # backpropagate through the histogram binning
        h = _jhistc_backward2(h, x, w, self.order, self.bound,
                              self.extrapolate)
        # second-order chain rule through the affine map: squared factor
        h = h.mul_(scale.square())

        return h.reshape([*batch, n, 2])
Exemple #15
0
def emmi_soft(moving,
              fixed,
              dim=None,
              prior=None,
              fwhm=None,
              max_iter=32,
              weights=None,
              grad=True,
              hess=True,
              return_prior=False):
    """Mutual-information-like loss fitted by EM, with a log-domain E-step.

    A joint prior between fixed bins (J) and moving classes (K) is
    re-estimated by EM. The final prior is used to compute the loss and,
    optionally, its gradient and Hessian with respect to `moving`.

    Parameters
    ----------
    moving : tensor
        Moving image. Its log is added to the E-step logits, so it
        presumably holds per-class probabilities over K classes -- TODO
        confirm.
    fixed : tensor
        Fixed image, consumed by `sample_prior` / `scatter_prior`
        (presumably discretized into J bins -- TODO confirm).
    dim : int, optional
        Number of spatial dimensions. Default: ``moving.dim() - 1``.
    prior : tensor, optional
        Initial joint prior; after `emmi_prepare` it has shape
        ``(*batch, J, K)``.
    fwhm : float, optional
        If truthy, the prior is smoothed along its J axis at every M-step
        (``dim=1`` smoothing of the transposed ``(*batch, K, J)`` tensor).
    max_iter : int, default=32
        Maximum number of EM iterations. The loop also stops once the
        log-likelihood gain falls below ``1e-5 * Nm``.
    weights : tensor, optional
        Voxel weights (as returned by `emmi_prepare`); ``Nm = weights.sum()``.
    grad, hess : bool, default=True
        Whether to return the gradient / Hessian with respect to `moving`.
    return_prior : bool, default=False
        Whether to also return the final prior.

    Returns
    -------
    ll : tensor
        ``-(prior0 * prior.log()).sum() / Nm``, i.e. the negated MI estimate.
    g : tensor, if `grad`
        Negated, weighted gradient, reshaped to ``(..., *shape)``.
    h : tensor, if `hess`
        Weighted Hessian with ``K*(K+1)//2`` packed entries (K diagonal
        terms first, then the upper off-diagonal pairs), reshaped to
        ``(..., *shape)``.
    prior : tensor, if `return_prior`
    Only `ll` is returned if nothing else is requested; otherwise a tuple.

    NOTE(review): `ll` is divided by `Nm` but `g` and `h` are not (the
    hard-assignment variant `emmi_hard` divides both) -- confirm the
    scaling is intended.
    """
    # ------------------------------------------------------------------
    #           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)

    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #           EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll
        # --------------------------------------------------------------
        # E-step
        # --------------------------------------------------------------
        # logits = log prior (sampled at the fixed bins) + log moving;
        # softmax over the class axis (-2) also returns the log-likelihood.
        z = sample_prior(prior.log(), fixed, z)
        z += moving.log()
        z, ll = math.softmax_lse(z, -2, lse=True, weights=weights)

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        z *= weights
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)
        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior,
                                   dim=1,
                                   basis=0,
                                   fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)
        # MI-like normalization
        prior /= prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2,
                                                             keepdim=True)
        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N Σ_{j,k} prior0[j,k] * log(prior[j,k])
    #    = Σ_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y)
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #       - Σ_{xy} p(x,y) log p(x)
    #       - Σ_{xy} p(x,y) log p(y)
    #    = Σ_{x,y} p(x,y) * log p(x,y)
    #       - Σ_{x} p(x) log p(x)
    #       - Σ_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    ll = -(prior0 * prior.log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #           GRADIENTS
    # ------------------------------------------------------------------
    # compute gradients
    # Keeping only terms that depend on y, the mutual information is H[y]-H[x,y]
    # The objective function is \sum_n E[y_n]
    # > ll = Σ_n log p(x[n] == j[n], h)
    #      = Σ_nj \sum_j q(x[n] == j) log \sum_k p(x[n] == j | z[n] == k, h) p(z[n] == k)
    if grad or hess:

        norm = linalg.dot(
            prior.transpose(-1, -2).unsqueeze(-1),
            moving.transpose(-1, -2).unsqueeze(-3))
        norm = norm.add_(tiny).reciprocal_()
        g = sample_prior(prior, fixed * norm)

        if hess:
            norm = norm.square_().mul_(fixed).unsqueeze(-1)
            # packed symmetric storage: K diagonal entries, then (k, kk>k)
            h = moving.new_zeros([*g.shape[:-2], K * (K + 1) // 2, N])
            for j in range(J):
                h[..., :K, :] += prior[..., j, :K, None].square() * norm
                c = K
                for k in range(K):
                    for kk in range(k + 1, K):
                        h[..., c, :] += (prior[..., j, k, None] *
                                         prior[..., j, kk, None] * norm)
                        c += 1

        if grad:
            g *= weights
            g.neg_()
            g = g.reshape([*g.shape[:-1], *shape])
            out.append(g)
        if hess:
            h *= weights
            h = h.reshape([*h.shape[:-1], *shape])
            out.append(h)

    if return_prior:
        out.append(prior)

    return out[0] if len(out) == 1 else tuple(out)
Exemple #16
0
def emmi_hard(moving,
              fixed,
              dim=None,
              prior=None,
              fwhm=None,
              max_iter=32,
              weights=None,
              grad=True,
              hess=True,
              return_prior=False):
    """Mutual-information-like loss fitted by EM, with a linear-domain E-step.

    A joint prior between fixed bins (J) and moving classes (K) is
    re-estimated by EM. The final prior is used to compute the loss and,
    optionally, its gradient and Hessian with respect to `moving`.

    Parameters
    ----------
    moving : tensor
        Moving image; multiplied into the E-step responsibilities, so it
        presumably holds per-class probabilities over K classes (the
        "moving template" p(z[n] == k) in the comments below) -- TODO
        confirm.
    fixed : tensor
        Fixed image, consumed by `sample_prior` / `scatter_prior` (the
        comments below describe it as discretized).
    dim : int, optional
        Number of spatial dimensions. Default: ``moving.dim() - 1``.
    prior : tensor, optional
        Initial joint prior; after `emmi_prepare` it has shape
        ``(*batch, J, K)``.
    fwhm : float, optional
        If truthy, the prior is smoothed along its J axis at every M-step
        (``dim=1`` smoothing of the transposed ``(*batch, K, J)`` tensor).
    max_iter : int, default=32
        Maximum number of EM iterations. The loop also stops once the
        log-likelihood gain falls below ``1e-5 * Nm``.
    weights : tensor, optional
        Voxel weights (as returned by `emmi_prepare`); ``Nm = weights.sum()``.
    grad, hess : bool, default=True
        Whether to return the gradient / Hessian with respect to `moving`.
    return_prior : bool, default=False
        Whether to also return the final prior.

    Returns
    -------
    ll : tensor
        ``-(prior0 * log(prior)).sum() / Nm``, i.e. the negated MI estimate.
    g : tensor, if `grad`
        Weighted gradient divided by ``-Nm``, reshaped to ``(..., *shape)``.
    h : tensor, if `hess`
        ``sym_outer(g, -2)``, weighted and divided by ``Nm``, reshaped to
        ``(..., *shape)`` (layout defined by `sym_outer`).
    prior : tensor, if `return_prior`
    Only `ll` is returned if nothing else is requested; otherwise a tuple.
    """
    # ------------------------------------------------------------------
    #           PREPARATION
    # ------------------------------------------------------------------
    tiny = 1e-16
    dim = dim or (moving.dim() - 1)
    moving, fixed, weights, prior, shape = emmi_prepare(
        moving, fixed, weights, prior, dim)

    *batch, J, K = prior.shape
    Nb = len(batch)
    N = moving.shape[-1]
    Nm = weights.sum()

    # ------------------------------------------------------------------
    #           EM LOOP
    # ------------------------------------------------------------------
    ll = -float('inf')
    z = moving.new_empty([*batch, K, N])
    prior0 = torch.empty_like(prior)
    for n_iter in range(max_iter):
        ll_prev = ll
        # --------------------------------------------------------------
        # E-step
        # ------
        # estimate responsibilities of each moving cluster for each
        # fixed voxel using Bayes' rule:
        #   p(z[n] == k | x[n] == j[n]) ∝ p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #
        # . j[n] is the discretized fixed image
        # . p(z[n] == k) is the moving template
        # . p(x[n] == j[n] | z[n] == k) is the conditional prior evaluated at (j[n], k)
        # --------------------------------------------------------------
        z = sample_prior(prior, fixed, z)
        z *= moving

        # --------------------------------------------------------------
        # compute log-likelihood (log_sum of the posterior)
        # ll = Σ_n log p(x[n] == j[n])
        #    = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k)  p(z[n] == k)
        #    = Σ_n log Σ_k p(z[n] == k | x[n] == j) + constant{\z}
        # --------------------------------------------------------------
        ll = z.sum(-2, keepdim=True)
        ll = add_tiny_(ll, Nb)
        z /= ll
        # accumulate the weighted log-likelihood in double precision
        ll = ll.log_().mul_(weights).sum([-1, -2], dtype=torch.double)

        z *= weights

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) 𝛿(x[n] == j)
        # --------------------------------------------------------------
        prior0 = scatter_prior(prior0, fixed, z)
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= add_tiny_(prior.sum(dim=[-1, -2], keepdim=True), Nb)

        if fwhm:
            # smooth "prior" for the prior
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior,
                                   dim=1,
                                   basis=0,
                                   fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)

        # prior /= prior.sum(dim=[-1, -2], keepdim=True)
        # MI-like normalization
        prior /= add_tiny_(
            prior.sum(dim=-1, keepdim=True) * prior.sum(dim=-2, keepdim=True),
            Nb)
        if ll - ll_prev < 1e-5 * Nm:
            break

    # compute mutual information (times number of observations)
    # > prior contains p(x,y)/(p(x) p(y))
    # > prior0 contains N * p(x,y)
    # >> 1/N \sum_{j,k} prior0[j,k] * log(prior[j,k])
    #    = \sum_{x,y} p(x,y) * (log p(x,y) - log p(x) - log p(y)
    #    = \sum_{x,y} p(x,y) * log p(x,y)
    #       - \sum_{xy} p(x,y) log p(x)
    #       - \sum_{xy} p(x,y) log p(y)
    #    = \sum_{x,y} p(x,y) * log p(x,y)
    #       - \sum_{x} p(x) log p(x)
    #       - \sum_{y} p(y) log p(y)
    #    = -H[x,y] + H[x] + H[y]
    #    = MI[x, y]
    # NOTE(review): if `add_tiny_` works in place (trailing underscore), the
    # returned `prior` includes this extra offset -- confirm.
    ll = -(prior0 * add_tiny_(prior, Nb).log()).sum() / Nm
    out = [ll]

    # ------------------------------------------------------------------
    #           GRADIENTS
    # ------------------------------------------------------------------
    # compute gradients
    # Keeping only terms that depend on y, the mutual information is H[y]-H[x,y]
    # The objective function is \sum_n E[y_n]
    # > ll = \sum_n log p(x[n] == j[n], h)
    #      = \sum_n log \sum_k p(x[n] == j[n] | z[n] == k, h) p(z[n] == k)
    if grad or hess:

        g = sample_prior(prior, fixed)
        norm = linalg.dot(g.transpose(-1, -2), moving.transpose(-1, -2))
        norm = add_tiny_(norm, Nb).unsqueeze(-2).reciprocal_()
        g *= norm
        if hess:
            h = sym_outer(g, -2)

        if grad:
            g *= weights
            g /= -Nm
            # g.neg_()
            g = g.reshape([*g.shape[:-1], *shape])
            out.append(g)
        if hess:
            h *= weights
            h /= Nm
            h = h.reshape([*h.shape[:-1], *shape])
            out.append(h)

    if return_prior:
        out.append(prior)

    return out[0] if len(out) == 1 else tuple(out)
Exemple #17
0
    def forward(self, x):
        """Sample an intensity image from labels or class probabilities.

        Each class k draws Normal(means[:, k], scales[:, k]) intensities,
        one per output channel.

        Parameters
        ----------
        x : (batch, 1 or classes[-1], *shape) tensor
            Labels (integer dtype) or probabilities (floating point dtype).
            With fewer probability channels than classes, class 0 is
            implicit: ``1 - x.sum(1)``.

        Returns
        -------
        x : (batch, channel, *shape) tensor
            For probabilities: sum over classes of the class-k sample
            (smoothed with that class' FWHM when > 0) times its probability.
            For labels: each voxel gets an unsmoothed sample from the
            distribution of its label.

        """
        batch, _, *shape = x.shape
        device = x.device
        dtype = x.dtype
        if not dtype.is_floating_point:
            dtype = self.dtype
        backend = dict(dtype=dtype, device=device)

        means = torch.as_tensor(self.means, **backend)
        scales = torch.as_tensor(self.scales, **backend)
        nb_classes = means.shape[-1]
        if means.dim() == 2:
            channel = len(means)
        elif scales.dim() == 2:
            channel = len(scales)
        else:
            channel = 1
        means = means.expand([channel, nb_classes]).clone()
        scales = scales.expand([channel, nb_classes]).clone()
        # one smoothing FWHM per class
        fwhm = utils.make_vector(self.fwhm, nb_classes, **backend)

        implicit = x.shape[1] < nb_classes
        out = torch.zeros([batch, channel, *shape], **backend)
        for k in range(nb_classes):
            sampler = _get_dist('normal')(means[:, k], scales[:, k])
            if x.dtype.is_floating_point:
                # (batch, *shape, channel) -> (batch, channel, *shape)
                y1 = sampler.sample([batch, *shape])
                y1 = utils.movedim(y1, -1, 1)
                # `fwhm` has one entry per class, so smooth every channel of
                # the class-k sample with fwhm[k] (iterating over `fwhm` as
                # if it were per channel indexes past `channel`).
                f = fwhm[k]
                if f > 0:
                    for c in range(channel):
                        y1[:, c] = spatial.smooth(y1[:, c],
                                                  fwhm=f,
                                                  dim=len(shape),
                                                  padding='same',
                                                  bound='dct2')
                if not implicit:
                    x1 = x[:, k, None]
                elif k > 0:
                    x1 = x[:, k - 1, None]
                else:
                    x1 = x.sum(1, keepdim=True).neg_().add_(1)
                # accumulate in place (`addcmul` without `_` discards its result)
                out.addcmul_(y1, x1)
            else:
                mask = (x.squeeze(1) == k)
                y1 = sampler.sample([mask.sum().long()])
                out = utils.movedim(out, 1, -1)
                out[mask, :] = y1
                out = utils.movedim(out, -1, 1)

        return out
Exemple #18
0
    def backward(self, x, g, min=None, max=None, hess=False, mask=None):
        """Backpropagate a gradient (or Hessian) through the joint histogram.

        Parameters
        ----------
        x : (..., N, 2) tensor
            Input multidimensional vector
        g : (..., B, B) tensor
            Gradient with respect to the histogram
        min : (..., 2) tensor, optional
        max : (..., 2) tensor, optional
        hess : bool, default=False
            If True, propagate a second-order term: the Hessian branch below
            is used and the affine factor ``n / (max - min)`` is squared.
        mask : (..., N) tensor, optional
            Multiplies the output (broadcast over the last dimension).

        Returns
        -------
        g : (..., N, 2) tensor
            Gradient with respect to x

        """
        if self.fwhm:
            g = spatial.smooth(g, fwhm=self.fwhm, bound=self.bound, dim=2)

        shape = x.shape
        x, min, max = self._prepare(x, min, max)
        # (..., 1, 2) so that the range broadcasts against (..., N, 2)
        min = min.unsqueeze(-2)
        max = max.unsqueeze(-2)
        g = g.reshape([-1, *g.shape[-2:]])

        extrapolate = self.extrapolate or 2
        if not hess:
            g = spatial.grid_grad(g[:, None], x[:, None], self.order,
                                  self.bound, extrapolate)
            g = g[:, 0].reshape(shape)
        else:
            # 1) Absolute value of adjoint of gradient
            # we want shapes
            #   o : [batch=1, channel=1, spatial=[1, vox], dim=2]
            #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
            #   x : [batch=1, spatial=[1, vox], dim=2]
            #    -> [batch=1, channel=1, spatial=[B(mov), B(fix)]]
            order = _spatial.inter_to_nitorch([self.order], True)
            bound = _spatial.bound_to_nitorch([self.bound], True)
            o = torch.ones_like(x)
            g.requires_grad_()  # triggers push
            o, = _spatial.grid_grad_backward(o[:, None,
                                               None], g[:, None], x[:, None],
                                             bound, order, extrapolate)
            g.requires_grad_(False)
            g *= o[:, 0]
            # 2) Absolute value of gradient
            #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
            #   x : [batch=1, spatial=[1, vox], dim=2]
            #    -> [batch=1, channel=1, spatial=[1, vox], 2]
            g = _spatial.grid_grad(g[:, None], x[:, None], bound, order,
                                   extrapolate)
            g = g.reshape(shape)

        # adjoint of affine function
        nn = torch.as_tensor(self.n, dtype=x.dtype, device=x.device)
        factor = nn / (max - min)
        if hess:
            factor = factor.square_()
        g = g.mul_(factor)
        if mask is not None:
            g *= mask[..., None]

        return g
Exemple #19
0
def _make_image(option, dim=None, device=None):
    """
    Load an image and build a Gaussian pyramid (if required).

    Processing steps, in order:
      1. load the data (and any mask returned alongside it by `_load_image`);
      2. if `option.mask` is set, load it, check that its spatial shape
         matches the image, and multiply it with the previously loaded
         mask (if any);
      3. overwrite the orientation matrix with `option.world` (if set), then
         fold in every transform of `option.affine` with
         `spatial.affine_lmdiv`;
      4. rescale intensities (only when the image is not discretized);
      5. pad by `option.pad` (voxels, or millimetres if the list ends with
         the string 'mm');
      6. smooth with `option.fwhm` (same unit convention as the padding);
      7. build the pyramid, then optionally soft-quantize (single-channel
         images only) or discretize (non-label images only) every level,
         keeping the pre-conversion data in `level.preview`.

    Parameters
    ----------
    option : object
        Image options. Attributes read: files, label, mask, world, affine,
        discretize, rescale, pad, fwhm, bound, pyramid, extrapolate,
        pyramid_method and, if present, soft_quantize.
    dim : int, optional
        Number of spatial dimensions, forwarded to the loader. It is then
        recomputed as `dat.dim() - 1` (presumably data is laid out as
        (channels, *spatial) -- confirm against `_load_image`).
    device : optional
        Device forwarded to the loader.

    Returns: ImagePyramid
    """

    def split_unit(values):
        # An optional trailing string names the unit of the preceding
        # numbers; without it, the numbers are expressed in voxels.
        if isinstance(values[-1], str):
            *numbers, unit = values
            return numbers, unit
        return values, 'vox'

    def mm_to_vox(values, aff):
        # Convert millimetre values to (fractional) voxel units.
        vx = spatial.voxel_size(aff)
        return torch.as_tensor(values, **utils.backend(vx)) / vx

    dat, mask, affine = _load_image(option.files, dim=dim, device=device,
                                    label=option.label)
    dim = dat.dim() - 1

    # Explicit mask, combined with whatever mask came with the image.
    if option.mask:
        loaded_mask = mask
        mask, _, _ = _load_image([option.mask], dim=dim, device=device,
                                 label=option.label)
        if mask.shape[-dim:] != dat.shape[-dim:]:
            raise ValueError('Mask should have the same shape as the image. '
                             f'Got {mask.shape[-dim:]} and {dat.shape[-dim:]}')
        if loaded_mask is not None:
            mask = mask * loaded_mask

    # Orientation: optional overwrite, then compose extra transforms.
    if option.world:
        affine = io.transforms.map(option.world).fdata().squeeze()
    for path in option.affine or []:
        matrix = io.transforms.map(path).fdata().squeeze()
        affine = spatial.affine_lmdiv(matrix, affine)

    # Intensity rescaling is skipped when the image will be discretized.
    if not option.discretize and any(option.rescale):
        dat = _rescale_image(dat, option.rescale)

    if option.pad:
        pad, unit = split_unit(option.pad)
        if unit == 'mm':
            # Millimetre padding is rounded down to whole voxels.
            pad = mm_to_vox(pad, affine).floor().int().tolist()
        else:
            pad = list(map(int, pad))
        pad = py.make_list(pad, dim)
        if any(pad):
            affine, _ = spatial.affine_pad(affine, dat.shape[-dim:], pad,
                                           side='both')
            dat = utils.pad(dat, pad, side='both', mode=option.bound)
            if mask is not None:
                mask = utils.pad(mask, pad, side='both', mode=option.bound)

    if option.fwhm:
        fwhm, unit = split_unit(option.fwhm)
        if unit == 'mm':
            fwhm = mm_to_vox(fwhm, affine)
        dat = spatial.smooth(dat, dim=dim, fwhm=fwhm, bound=option.bound)

    pyramid = objects.ImagePyramid(
        dat,
        levels=option.pyramid,
        affine=affine,
        dim=dim,
        bound=option.bound,
        mask=mask,
        extrapolate=option.extrapolate,
        method=option.pyramid_method,
    )

    # Optional per-level intensity conversion; the original data of each
    # level is kept in `preview`.
    if getattr(option, 'soft_quantize', False) and len(pyramid[0].dat) == 1:
        convert, setting = _soft_quantize_image, option.soft_quantize
    elif not option.label and option.discretize:
        convert, setting = _discretize_image, option.discretize
    else:
        return pyramid
    for level in pyramid:
        level.preview = level.dat
        level.dat = convert(level.dat, setting)
    return pyramid
Exemple #20
0
def em_prior(moving,
             fixed,
             prior=None,
             weights=None,
             fwhm=None,
             max_iter=32,
             tolerance=1e-5,
             verbose=0):
    """Estimate H_jk = P[x == j, z == k] by Expectation Maximization

    The objective function is
        Π_n p(x ; H, mu) == Π_n Σ_k p(x | z == k; H) p(z == k; mu)

    Parameters
    ----------
    moving : (B, K, *spatial) tensor
        Moving template, used as the per-voxel prior p(z[n] == k).
    fixed : (B, J|1, *spatial) tensor
        Fixed image. The E-step treats it as the discretized index j[n].
    prior : (B, J, K) tensor, optional
        Initial normalized joint histogram. Not modified (it is cloned).
        Default: uniform.
    weights : (B, 1, *spatial) tensor, optional
        Voxel weights.
    fwhm : float, optional
        Width (presumably in bins) of the smoothing applied to the
        histogram along its J (fixed) axis at every iteration.
        Default: J / 64. A falsy value (e.g. 0) disables smoothing.
    max_iter : int, default=32
        Maximum number of EM iterations.
    tolerance : float, default=1e-5
        Stop when the log-likelihood gain is below `tolerance * sum(N)`.
    verbose : int, default=0
        If > 0, print progress: on a single, overwritten line when 1,
        one line per iteration when > 1.

    Returns
    -------
    mi : () tensor
        `Σ prior0 * log(prior)`, summed over every dimension including the
        batch. `prior0` is the histogram scattered from the weighted
        responsibilities at the last iteration, so this is presumably the
        mutual information scaled by the observation weight.
        NOTE(review): this was previously documented as a (B,) tensor;
        the code returns a scalar -- confirm what callers expect.
        It is 0 when `max_iter == 0`, or NaN if `prior` contains zeros.
    N : (B,)
        Total observation weight
    prior : (B, J, K)  tensor
        Regularized Joint histogram P[X,Z] / (P[X]*P[Z])

    """
    # ------------------------------------------------------------------
    #       PREPARATION
    # ------------------------------------------------------------------
    dim = fixed.dim() - 2
    moving, fixed, weights = spatial_prepare(moving, fixed, weights)
    B, K, J, shape = spatial_shapes(moving, fixed, prior)
    N = shape.numel()
    Nm = spatial_sum(weights, dim).squeeze(-1)

    # Flatten
    moving = moving.reshape([*moving.shape[:2], -1])
    fixed = fixed.reshape([*fixed.shape[:2], -1])
    weights = weights.reshape([*weights.shape[:2], -1])

    if fwhm is None:
        fwhm = J / 64

    # ------------------------------------------------------------------
    # initialize normalized histogram
    #   `prior` contains the "normalized" joint histogram p[x,y] / (p[x] p[y])
    #    However, it is used as the conditional histogram p[x|y] during
    #    the E-step. This is because the normalized histogram is equal to
    #    the conditional histogram up to a constant term that does not
    #    depend on x, and therefore disappears after normalization of the
    #    responsibilities.
    # ------------------------------------------------------------------
    if prior is None:
        prior = moving.new_ones([B, J, K])
        prior /= prior.sum(dim=[-1, -2], keepdim=True)
        prior /= prior.sum(dim=-2, keepdim=True) * prior.sum(dim=-1,
                                                             keepdim=True)
    else:
        prior = prior.clone()
    # prior /= prior.sum(dim=-2, keepdim=True)  # conditional X | Z
    # prior /= prior.sum(dim=-2, keepdim=True) * prior.sum(dim=-1, keepdim=True)

    # ------------------------------------------------------------------
    #       INFER PRIOR (EM)
    # ------------------------------------------------------------------
    ll_prev = -float('inf')
    z = moving.new_zeros([B, K, N])
    # Zero-initialized (not `empty_like`) so that `mi` below is well
    # defined even when the loop does not run (max_iter == 0).
    prior0 = torch.zeros_like(prior)
    for n_iter in range(max_iter):
        # --------------------------------------------------------------
        # E-step
        # ------
        # estimate responsibilities of each moving cluster for each
        # fixed voxel using Bayes' rule:
        #   p(z[n] == k | x[n] == j[n]) ∝ p(x[n] == j[n] | z[n] == k) p(z[n] == k)
        #
        # . j[n] is the discretized fixed image
        # . p(z[n] == k) is the moving template
        # . p(x[n] == j[n] | z[n] == k) is the conditional prior evaluated at (j[n], k)
        # --------------------------------------------------------------
        sample_prior(prior, fixed, z)
        z *= moving

        # --------------------------------------------------------------
        # compute log-likelihood (log_sum of the posterior)
        # ll = Σ_n log p(x[n] == j[n])
        #    = Σ_n log Σ_k p(x[n] == j[n] | z[n] == k)  p(z[n] == k)
        #    = Σ_n log Σ_k p(z[n] == k | x[n] == j) + constant{\z}
        # --------------------------------------------------------------
        ll = z.sum(-2, keepdim=True) + tiny
        z /= ll
        ll = ll.log_().mul_(weights).sum([-1, -2], dtype=torch.double)
        z *= weights

        # --------------------------------------------------------------
        # M-step
        # ------
        # estimate joint prior by maximizing Q = E_{Z;H,mu}[ln p(X, Z; H)]
        # => H_jk = p(x == j, z == k) ∝ Σ_n p(z[n] == k | x[n] == j) delta(x[n] == j)
        # --------------------------------------------------------------
        scatter_prior(prior0, fixed, z)  # prior[fixed] <- z
        prior.copy_(prior0).add_(tiny)
        # make it a joint distribution
        prior /= prior.sum(dim=[-1, -2], keepdim=True).add_(tiny)

        if fwhm:
            # smooth "prior" for the prior
            # (transposed so that the J axis is last, i.e. the one
            # smoothed when dim=1)
            prior = prior.transpose(-1, -2)
            prior = spatial.smooth(prior,
                                   dim=1,
                                   basis=0,
                                   fwhm=fwhm,
                                   bound='replicate')
            prior = prior.transpose(-1, -2)

        # normalize by both marginals: p[x,z] / (p[x] p[z])
        prior /= prior.sum(dim=-2, keepdim=True) * prior.sum(dim=-1,
                                                             keepdim=True)

        if verbose > 0:
            success = ll.sum() > ll_prev
            happy = ':D' if success else ':('
            end = '\n' if verbose > 1 else '\r'
            print(
                f'(em)     | {n_iter:02d} | {(ll/Nm).mean():12.6g} | {happy}',
                end=end)
        if ll.sum() - ll_prev < tolerance * Nm.sum():
            break
        ll_prev = ll.sum()

    # NOTE:
    # `prior` is returned normalized by both marginals,
    # `joint_prior / (prior_x * prior_z)`, which corresponds to the mutual
    # information. Normalizing by `prior_z` only (the commented-out
    # conditional normalization above) would give the conditional prior
    # p(x | z) instead. The two differ by a factor that only depends on
    # the fixed image, so the choice has no impact for registration, and
    # both have the same computational cost.

    mi = (prior0 * prior.log()).sum()

    return mi, Nm, prior