Example #1
    def responsibilities(image, means, precisions, proportions):
        # aliases
        x = image
        m = means
        A = precisions
        p = proportions
        nb_dim = image.dim() - 2
        del image, means, precisions, proportions

        # voxel-wise term
        x = channel2last(x).unsqueeze(-2)  # [B, ...,  1, C]
        p = unsqueeze(p, dim=1, ndim=nb_dim)  # [B, ones, K]
        m = unsqueeze(m, dim=1, ndim=nb_dim)  # [B, ones, K, C]
        A = unsqueeze(A, dim=1, ndim=nb_dim)  # [B, ones, K, C, C]
        x = x - m
        z = matvec(A, x)
        z = (z * x).sum(dim=-1)  # [B, ..., K]
        z = -0.5 * z

        # constant term
        twopi = torch.as_tensor(2 * pi, dtype=A.dtype, device=A.device)
        nrm = torch.logdet(A) - A.shape[-1] * twopi.log()
        nrm = 0.5 * nrm + p.log()
        z = z + nrm

        # softmax
        z = last2channel(z)
        logz = torch.nn.functional.log_softmax(z, dim=1)
        z = torch.nn.functional.softmax(z, dim=1)

        return z, logz
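All of these examples lean on a multi-dimensional `unsqueeze(x, dim, ndim)` helper that inserts several singleton dimensions at once. A minimal pure-PyTorch sketch of such a helper, inferred from the shape comments above (an assumption, not the library's actual implementation):

import torch

def unsqueeze(x, dim=0, ndim=1):
    """Insert `ndim` singleton dimensions starting at position `dim`."""
    for _ in range(ndim):
        x = torch.unsqueeze(x, dim)
    return x

p = torch.randn(2, 5)            # [B, K]
p = unsqueeze(p, dim=1, ndim=3)  # [B, 1, 1, 1, K]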
Example #2
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D[+1]) tensor
        Affine matrix (or matrices).
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim + 1):
        raise ValueError(
            'First argument should be matrices of shape '
            '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'.format(
                nb_dim, nb_dim + 1, mat.shape))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    grid = utils.unsqueeze(grid, dim=0, ndim=len(batch_shape))
    mat = utils.unsqueeze(mat, dim=-3, ndim=nb_dim)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    grid = linalg.matvec(lin, grid) + off
    return grid
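A short usage sketch for `affine_grid`, assuming the enclosing library's `identity_grid`, `utils` and `linalg` are importable; the 2D rotation matrix is illustrative:

import math
import torch

theta = math.pi / 6  # 30 degree rotation about the origin
mat = torch.tensor([[math.cos(theta), -math.sin(theta), 0.],
                    [math.sin(theta),  math.cos(theta), 0.],
                    [0.,               0.,              1.]])
grid = affine_grid(mat, shape=[64, 64])  # -> (64, 64, 2)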
Example #3
    def forward(self, x):

        backend = utils.backend(x)

        # compute intensity bounds
        vmin = self.vmin
        if vmin is None:
            vmin = x.reshape([*x.shape[:2], -1]).min(dim=-1).values
        vmax = self.vmax
        if vmax is None:
            vmax = x.reshape([*x.shape[:2], -1]).max(dim=-1).values
        vmin = torch.as_tensor(vmin, **backend).expand(x.shape[:2])
        vmin = unsqueeze(vmin, -1, x.dim() - vmin.dim())
        vmax = torch.as_tensor(vmax, **backend).expand(x.shape[:2])
        vmax = unsqueeze(vmax, -1, x.dim() - vmax.dim())

        # sample factor
        factor_exp = utils.make_vector(self.factor_exp, x.shape[1], **backend)
        factor_scale = utils.make_vector(self.factor_scale, x.shape[1],
                                         **backend)
        factor = self.factor(factor_exp, factor_scale)
        factor = factor.sample([len(x)])
        factor = unsqueeze(factor, -1, x.dim() - 2)

        # apply correction
        x = (x - vmin) / (vmax - vmin)
        x = x.pow(factor)
        x = x * (vmax - vmin) + vmin
        return x
Example #4
 def get_bins(x, min, max, nbins, mask=None, **backend):
     # mask and backend are optional here (they were free variables in
     # the original snippet); nanmin/nanmax are nan-aware reductions
     # provided by the enclosing library
     """Compute the histogram bins."""
     # TODO: It's suboptimal to have bin centers fall at the
     #   min and max. Better to shift them slightly inside.
     if mask is not None:
         # we set masked values to nan so that we can exclude them when
         # computing min/max
         val_nan = torch.as_tensor(float('nan'), **backend)
         x = torch.where(mask, val_nan, x)
         min_fn = nanmin
         max_fn = nanmax
     else:
         min_fn = lambda *a, **k: torch.min(*a, **k).values
         max_fn = lambda *a, **k: torch.max(*a, **k).values
     min = min_fn(x, dim=-1) if min is None else min
     min = torch.as_tensor(min, **backend)
     min = unsqueeze(min, dim=2, ndim=4 - min.dim())
     # -> shape = [B, C, 1, 1]
     max = max_fn(x, dim=-1) if max is None else max
     max = torch.as_tensor(max, **backend)
     max = unsqueeze(max, dim=2, ndim=4 - max.dim())
     # -> shape = [B, C, 1, 1]
     bins = torch.linspace(0, 1, nbins, **backend)
     bins = unsqueeze(bins, dim=0, ndim=3)  # -> [1, 1, 1, nb_bins]
     bins = min + bins * (max - min)  # -> [B, C, 1, nb_bins]
     binwidth = (max - min) / (nbins - 1)  # -> [B, C, 1, 1]
     return bins, binwidth
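A usage sketch for `get_bins` on the mask-free path, using the unsqueeze helper sketched after Example #1 (the random input is illustrative):

x = torch.rand(2, 3, 100)  # [B, C, N] flattened voxels
bins, binwidth = get_bins(x, None, None, nbins=16)
# bins: [2, 3, 1, 16] centers spanning each batch/channel's min..max
# binwidth: [2, 3, 1, 1]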
Example #5
def get_log_confusion(confusion, nb_classes_pred, nb_classes_ref, dim,
                      **backend):
    """Return a well formed (log) confusion matrix"""
    if confusion is None:
        confusion = torch.eye(nb_classes_pred, nb_classes_ref, **backend).exp()
    # normalise across the two class dimensions *before* appending the
    # spatial and batch dimensions, so that the sum runs over classes
    # rather than over singleton spatial dimensions
    confusion = confusion / confusion.sum(dim=[-1, -2], keepdim=True)
    confusion = confusion.clamp(min=1e-7, max=1 - 1e-7).logit()
    confusion = utils.unsqueeze(confusion, -1, dim)  # spatial shape
    if confusion.dim() < dim + 3:
        confusion = utils.unsqueeze(confusion, 0, 1)  # batch shape
    return confusion
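A shape check (the arguments are illustrative):

logC = get_log_confusion(None, 4, 4, dim=3, dtype=torch.float32)
# logC: (1, 4, 4, 1, 1, 1): batch, the two class dims, then
# singleton spatial dims ready to broadcast against an image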
Example #6
def _build_kernel(dim, **backend):
    kernel = torch.as_tensor([0.75, 1., 0.75], **backend)
    normk = kernel
    for d in range(1, dim):
        normk = normk.unsqueeze(-1)
        normk = normk * kernel
    normk = normk.sum()
    normk = normk ** (1/dim)
    kernel /= normk
    kernels = []
    for d in range(dim):
        kernel1 = kernel
        kernel1 = utils.unsqueeze(kernel1, 0, d)
        kernel1 = utils.unsqueeze(kernel1, -1, dim-1-d)
        kernels.append(kernel1)
    return kernels
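For `dim=3`, the returned kernels are separable 1D filters oriented along each spatial axis; a quick shape check (assuming `utils.unsqueeze` behaves as in the sketch after Example #1):

kernels = _build_kernel(3, dtype=torch.float32)
print([tuple(k.shape) for k in kernels])
# [(3, 1, 1), (1, 3, 1), (1, 1, 3)]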
Example #7
    def forward(self, image, **overload):
        backend = utils.backend(image)
        sigma = overload.get('sigma', self.sigma)
        gfactor = overload.get('gfactor', self.gfactor)

        # sample sigma
        if sigma is None:
            sigma = self.default_sigma(*image.shape[:2], **backend)
        if callable(sigma):
            sigma = sigma(image.shape[:2])
        sigma = torch.as_tensor(sigma, **backend)
        sigma = unsqueeze(sigma, -1, 2 - sigma.dim())

        # sample gfactor
        if gfactor is True:
            gfactor = field.RandomMultiplicativeField()
        if callable(gfactor):
            gfactor = gfactor(image.shape)

        # sample noise
        zero = torch.tensor(0, **backend)
        noise = td.Normal(zero, sigma).sample(image.shape[2:])
        noise = utils.movedim(noise, [-2, -1], [0, 1])  # -> [B, C, *spatial]

        if torch.is_tensor(gfactor):
            noise *= gfactor

        image = image + noise
        return image
Example #8
    def load(self, fname, dtype=None, device=None):
        """Load a volume from disk

        Parameters
        ----------
        fname : str
        dtype : torch.dtype, optional
        device : torch.device, optional

        Returns
        -------
        dat : (channels, *spatial) tensor

        """
        dtype = dtype or self.dtype
        device = device or self.device
        if not dtype or dtype.is_floating_point:
            dat = io.loadf(fname, dtype=dtype, device=device)
            dat = self.rescale(dat)
        else:
            dat = io.load(fname, dtype=dtype, device=device)
        dat = dat.squeeze()
        dim = self.dim or dat.dim()
        dat = utils.unsqueeze(dat, -1, max(0, dim - dat.dim()))
        dat = dat.reshape([*dat.shape[:dim], -1])
        dat = utils.movedim(dat, -1, 0)
        dat = self.to_shape(dat)
        return dat
Example #9
def transform_pointset_dense(points, grid, type='grid', bound='dct2'):
    """Transform a pointset

    Points must already be expressed in "grid voxels" coordinates.

    Parameters
    ----------
    points : (n, dim) tensor
        Set of coordinates, in voxel space
    grid : (*spatial, dim) tensor
        Dense transformation or displacement grid, in voxel space
    type : {'grid', 'disp'}, default='grid'
        Transformation or displacement
    bound : str, default='dct2'
        Boundary conditions for out-of-bounds data

    Returns
    -------
    points : (n, dim) tensor
        Transformed coordinates

    """

    dim = grid.shape[-1]
    points = utils.unsqueeze(points, 0, dim)
    grid = utils.movedim(grid, -1, 0)[None]
    delta = spatial.grid_pull(grid, points, bound=bound, extrapolate=True)
    delta = utils.movedim(delta, 1, -1)
    if type == 'disp':
        points = points + delta
    else:
        points = delta
    points = utils.squeeze(points, -2, dim - 1).squeeze(0)
    return points
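A usage sketch, assuming the library's `spatial.identity_grid` is available; with an identity transformation grid, points should map back to themselves (up to interpolation):

grid = spatial.identity_grid([32, 32])        # (*spatial, 2) voxel-to-voxel grid
points = torch.tensor([[4., 5.], [10., 2.]])  # (n, 2) voxel coordinates
out = transform_pointset_dense(points, grid, type='grid')
# out ≈ points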
Example #10
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        # device/dtype
        nb_dim = len(shape)
        mean = utils.make_vector(self.mean, channel, **backend)
        amplitude = utils.make_vector(self.amplitude, channel, **backend)
        fwhm = utils.make_vector(self.fwhm, nb_dim, **backend)

        # sample spline coefficients
        nodes = [(s/f).ceil().int().item()
                 for s, f in zip(shape, fwhm)]
        sample = torch.randn([batch, channel, *nodes], **backend)
        sample *= utils.unsqueeze(amplitude, -1, nb_dim)
        sample = spatial.resize(sample, shape=shape, interpolation=self.basis,
                                bound='dct2', prefilter=False)
        sample += utils.unsqueeze(mean, -1, nb_dim)
        return sample
Example #11
    def nll(image, resp, means, precisions):
        # aliases
        x = image
        z = resp
        m = means
        A = precisions
        nb_dim = image.dim() - 2
        del image, resp, means, precisions

        x = channel2last(x).unsqueeze(-2)  # [B, ...,  1, C]
        z = channel2last(z)  # [B, ..., K]
        m = unsqueeze(m, dim=1, ndim=nb_dim)  # [B, ones, K, C]
        A = unsqueeze(A, dim=1, ndim=nb_dim)  # [B, ones, K, C, C]
        x = x - m
        loss = matvec(A, x)
        loss = (loss * x).sum(dim=-1)  # [B, ..., K]
        loss = (loss * z).sum(dim=-1)  # [B, ...]
        loss = loss * 0.5
        return loss
Example #12
 def forward(self, image, **overload):
     factor = overload.get('factor', self.factor)
     if factor is None:
         factor = self.default_factor(len(image), **utils.backend(image))
     if callable(factor):
         factor = factor(image.shape[0])
     factor = torch.as_tensor(factor, **utils.backend(image))
     factor = unsqueeze(factor, -1, image.dim() - factor.dim())
     image = self.op(image, factor)
     return image
Example #13
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch shape.

        Other Parameters
        ----------------
        shape : sequence[int], optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """
        shape = overload.get('shape', self.grid.velocity.field.shape)
        dtype = overload.get('dtype', self.grid.velocity.field.dtype)
        device = overload.get('device', self.grid.velocity.field.device)
        backend = dict(dtype=dtype, device=device)

        if self.grid.velocity.field.amplitude == 0:
            grid = identity_grid(shape, **backend)
        else:
            grid = self.grid(batch, shape=shape, **backend)
        dtype = grid.dtype
        device = grid.device
        backend = dict(dtype=dtype, device=device)

        shape = grid.shape[1:-1]
        dim = len(shape)
        aff = self.affine(batch, dim=dim, **backend)

        # shift center of rotation
        aff_shift = torch.cat((
            torch.eye(dim, **backend),
            torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
            dim=1)
        aff_shift = as_euclidean(aff_shift)

        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = linalg.matvec(lin, grid) + off

        return grid
Example #14
    def exp(self, velocity, affine=None, displacement=False):
        """Generate a deformation grid from tangent parameters.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim) tensor
            Stationary velocity field
        affine : (batch, nb_prm) tensor, optional
            Affine parameters
        displacement : bool, default=False
            Return a displacement field (voxel to shift) rather than
            a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim) tensor
            Deformation grid (transformation or displacement).

        """
        info = {'dtype': velocity.dtype, 'device': velocity.device}

        # generate grid
        shape = velocity.shape[1:-1]
        velocity_small = self.resize(velocity, type='displacement')
        grid = self.velexp(velocity_small)
        grid = self.resize(grid, shape=shape, type='grid')

        if affine is not None:
            # exponentiate
            affine_prm = affine
            affine = []
            for prm in affine_prm:
                affine.append(self.affexp(prm))
            affine = torch.stack(affine, dim=0)

            # shift center of rotation
            affine_shift = torch.cat(
                (torch.eye(self.dim, **info),
                 -torch.as_tensor(shape, **info)[:, None] / 2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

            # compose
            affine = unsqueeze(affine, dim=-3, ndim=self.dim)
            lin = affine[..., :self.dim, :self.dim]
            off = affine[..., :self.dim, -1]
            grid = matvec(lin, grid) + off

        if displacement:
            grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)

        return grid
Example #15
    def kl(resp, log_resp, proportions):
        # aliases
        z = resp
        logz = log_resp
        p = proportions
        nb_dim = resp.dim() - 2
        del resp, log_resp, proportions

        p = unsqueeze(p, dim=-1, ndim=nb_dim)  # [B, K, ones]
        loss = z * (logz - p.log())  # [B, K, ...]
        loss = loss.sum(dim=1)  # [B, ...]
        return loss
Example #16
def _process_weights(weighted, dim, nb_classes, **backend):
    weighted_channelwise = False
    if weighted is not False:
        weighted = torch.as_tensor(weighted, **backend)
        if weighted.dim() == 1:
            weighted = utils.unsqueeze(weighted, -1, dim)
        if weighted.numel() == nb_classes:
            weighted_channelwise = True
            weighted = weighted.flatten()
    else:
        weighted = None
    return weighted, weighted_channelwise
Example #17
def irls_tukey_reweight(moving,
                        fixed,
                        lam=1,
                        c=4.685,
                        joint=False,
                        dim=None,
                        mask=None):
    """Update iteratively reweighted least-squares weights for Tukey's biweight

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Equivalent to Gaussian noise precision
        (used to standardize the residuals)
    c  : float, default=4.685
        Tukey's threshold.
        Approximately equal to a number of standard deviations above
        which the loss is capped.
    joint : bool, default=False
        Sum the weights across channels (joint reweighting)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : tensor, optional
        Voxel-wise mask applied to the weights

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights

    """
    if lam is None:
        lam = 1
    c = c * c
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    weights = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        weights = weights.mul_(mask)
    if joint:
        weights = weights.sum(dim=-dim - 1, keepdim=True)
    zeromsk = weights > c
    weights = weights.div_(-c).add_(1).square()
    weights = weights.masked_fill_(zeromsk, 0)
    return weights
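A small numeric check under the library context above: with unit residuals and unit precision, the standardized squared residual (1) sits far below c^2 ≈ 21.95, so the weights stay close to 1:

moving = torch.zeros(1, 1, 8, 8)
fixed = torch.ones(1, 1, 8, 8)
w = irls_tukey_reweight(moving, fixed, lam=1., dim=2)
# w = (1 - 1/c**2)**2 ≈ 0.91 everywhere; residuals beyond c**2 get weight 0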
Example #18
def irls_laplace_reweight(moving,
                          fixed,
                          lam=1,
                          joint=False,
                          eps=1e-5,
                          dim=None,
                          mask=None):
    """Update iteratively reweighted least-squares weights for l1

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Inverse-squared scale parameter of the Laplace distribution.
        (equivalent to Gaussian noise precision)
    joint : bool, default=False
        Sum the weights across channels (joint reweighting)
    eps : float, default=1e-5
        Lower bound on the absolute residuals (avoids division by zero)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    mask : tensor, optional
        Voxel-wise mask applied to the weights

    Returns
    -------
    weights : (..., K|1, *spatial) tensor
        IRLS weights

    """
    if lam is None:
        lam = 1
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    weights = (moving - fixed).square_().mul_(lam)
    if mask is not None:
        weights = weights.mul_(mask)
    if joint:
        weights = weights.sum(dim=-dim - 1, keepdim=True)
    weights = weights.sqrt_().clamp_min_(eps).reciprocal_()
    if mask is not None:
        weights = weights.masked_fill_(mask == 0, 0)
    return weights
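The analogous check for the l1 weights, under the same assumptions: they are the reciprocal of the standardized absolute residuals, clamped at `eps` to avoid division by zero:

moving = torch.zeros(1, 1, 8, 8)
fixed = 2 * torch.ones(1, 1, 8, 8)
w = irls_laplace_reweight(moving, fixed, lam=1., dim=2)
# |residual| = 2 everywhere, so w = 1/2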
Example #19
def roi_closing(label, radius=10, dim=None):
    """Performs a multi-label morphological closing.

    Parameters
    ----------
    label : (..., *spatial) tensor[int]
        Volume of labels.
    radius : float, default=10
        Radius of the structuring element (in voxels)
    dim : int, default=label.dim()
        Number of spatial dimensions

    Returns
    -------
    closed_label : tensor[int]

    """
    from scipy.ndimage import distance_transform_edt, binary_closing

    dim = dim or label.dim()
    closest_label = torch.zeros_like(label)
    closest_dist = label.new_full(label.shape, float('inf'), dtype=torch.float)
    dist = torch.empty_like(closest_dist)

    for l in label.unique():
        if l == 0:
            continue
        if label.dim() == dim:
            dist = torch.as_tensor(distance_transform_edt(label != l))
        elif label.dim() == dim + 1:
            for z in range(len(dist)):
                dist[z] = torch.as_tensor(
                    distance_transform_edt(label[z] != l))
        else:
            raise NotImplementedError
        closest_label[dist < closest_dist] = l
        closest_dist = torch.min(closest_dist, dist)

    struct = spatial.identity_grid([2 * radius + 1] * dim).sub_(radius)
    struct = struct.square().sum(-1).sqrt() <= radius
    struct = utils.unsqueeze(struct, 0, label.dim() - dim)
    mask = binary_closing(label > 0, struct)
    mask = torch.as_tensor(mask).bitwise_not_()
    closest_label[mask] = 0

    return closest_label
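A usage sketch (requires scipy; the toy label map is illustrative):

label = torch.zeros(32, 32, dtype=torch.long)
label[8:12, 4:28] = 1
label[8:12, 14:17] = 0  # punch a gap in the ribbon
closed = roi_closing(label, radius=3, dim=2)
# the gap is filled with the surrounding label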
Example #20
    def forward(self, x, output_padding=None, output_shape=None):
        """

        Parameters
        ----------
        x : (batch, channel, *in_spatial) tensor
        output_padding : [sequence of] int, default=self.output_padding
        output_shape : [sequence of] int, default=self.output_shape

        Returns
        -------
        x : (batch, channel, *out_spatial) tensor

        """
        dim = x.dim() - 2
        offset = py.make_list(self.offset, dim)
        stride = py.make_list(self.stride, dim)

        new_shape = self.shape(x, output_padding=output_padding,
                               output_shape=output_shape)
        y = x.new_zeros(new_shape)
        if self.fill:
            z = utils.unfold(y, stride)
            x = utils.unsqueeze(x, -1, dim)
            slicer = [slice(o, o+sz*st) for sz, st, o in
                      zip(x.shape[2:], stride, offset)]
            slicer = [slice(None)]*2 + slicer
            subz = z[tuple(slicer)]
            slicer = [slice(mx) for mx in subz.shape[2:]]
            slicer = [slice(None)]*2 + slicer
            subz.copy_(x[tuple(slicer)])
        else:
            slicer = [slice(o, None, s) for o, s in zip(offset, stride)]
            slicer = [slice(None)]*2 + slicer
            suby = y[tuple(slicer)]
            slicer = [slice(mx) for mx in suby.shape[2:]]
            slicer = [slice(None)]*2 + slicer
            suby.copy_(x[tuple(slicer)])

        return y
Example #21
def spconv(input,
           kernel,
           step=1,
           start=0,
           stop=None,
           inplace=False,
           bound='dct2',
           dim=None):
    """Convolution with a sparse kernel.

    Notes
    -----
    .. This convolution does not support strides, padding, dilation.
    .. The output spatial shape is the same as the input spatial shape.
    .. The output batch shape is the same as the input batch shape.
    .. Data outside the field-of-view is extrapolated according to `bound`
    .. It is implemented as a linear combination of views into the input
       tensor and should therefore be relatively memory-efficient.

    Parameters
    ----------
    input : (..., [channel_in], *spatial) tensor
        Input tensor, to convolve.
    kernel : ([channel_in, [channel_out]], *kernel_size) sparse tensor
        Convolution kernel.
    start : [sequence of] int, default=0
    stop : [sequence of] int, default=None
    step : [sequence of] int, default=1
        Equivalent to spconv(x)[start:stop:step]
    bound : [sequence of] str, default='dct2'
        Boundary condition (per spatial dimension).
    dim : int, default=kernel.dim()
        Number of spatial dimensions.

    Returns
    -------
    output : (..., [channel_out or channel_in], *spatial) tensor

        * If the kernel shape is (channel_in, channel_out, *kernel_size),
          the output shape is (..., channel_out, *spatial) and cross-channel
          convolution happens::
            out[co] = sum_ci conv(inp[ci], ker[ci, co])
        * If the kernel shape is (channel_in, *kernel_size), independent
          single-channel convolutions are applied to each channel::
            out[c] = conv(inp[c], ker[c])
        * If the kernel shape is (*kernel_size), the same convolution
          is applied to all input channels::
            out[c] = conv(inp[c], ker)

    """
    # get kernel dimensions
    dim = dim or kernel.dim()
    if kernel.dim() == dim + 2:
        channel_in, channel_out, *kernel_size = kernel.shape
    elif kernel.dim() == dim + 1:
        channel_in, *kernel_size = kernel.shape
        channel_out = None
    elif kernel.dim() == dim:
        kernel_size = kernel.shape
        channel_in = channel_out = None
    else:
        raise ValueError('Incompatible kernel shape: too many dimensions')
    start = core.py.ensure_list(start or 0, dim)
    stop = core.py.ensure_list(stop, dim)
    step = core.py.ensure_list(step, dim)

    # check input dimensions
    added_dims = max(0, dim + 1 - input.dim())
    input = unsqueeze(input, 0, added_dims)
    if channel_in is not None:
        if input.shape[-dim - 1] not in (1, channel_in):
            raise ValueError('Incompatible kernel shape: input channels')
        spatial_shape = input.shape[-dim:]
        batch_shape = input.shape[:-dim - 1]
        output_shape = tuple(
            [*batch_shape, channel_out or channel_in, *spatial_shape])
    else:
        # add a fake channel dimension
        spatial_shape = input.shape[-dim:]
        batch_shape = input.shape[:-dim]
        input = input.reshape([*batch_shape, 1, *spatial_shape])
        output_shape = input.shape
    output_spatial_shape = spatial_shape
    start = [
        0 if not str else str + sz if str < 0 else str
        for str, sz in zip(start, spatial_shape)
    ]
    stop = [
        sz if stp is None else stp + sz if stp < 0 else stp
        for stp, sz in zip(stop, spatial_shape)
    ]
    stop = [stp - 1 for stp in stop
            ]  # we use an inclusive stop in the rest of the code
    step = [st or 1 for st in step]
    if step:
        output_spatial_shape = [
            int(pymath.floor((stp - str) / float(st) + 1))
            for stp, st, str in zip(stop, step, start)
        ]
        output_shape = [*output_shape[:-dim], *output_spatial_shape]

    slicer = [
        slice(str, stp + 1, st) for str, stp, st in zip(start, stop, step)
    ]
    slicer = tuple([Ellipsis, *slicer])
    identity = input[slicer]
    assert identity.shape[-dim:] == tuple(output_shape[-dim:]), "oops"
    if inplace:
        output = identity
        identity = identity.clone()
        output.zero_()
    else:
        output = input.new_zeros(output_shape)

    # move channel + spatial dimensions to the front
    for d in range(dim + 1):  # +1 for channel dim
        input = core.utils.fast_movedim(input, -1, 0)
        output = core.utils.fast_movedim(output, -1, 0)
        identity = core.utils.fast_movedim(identity, -1, 0)

    # prepare other stuff
    bound = core.py.ensure_list(bound, dim)
    bound = [getattr(_bounds, b, None) for b in bound]
    # shift = torch.as_tensor([int(pymath.floor(k/2)) for k in kernel_size],
    #                         dtype=torch.long, device=kernel.device)
    shift = [int(pymath.floor(k / 2)) for k in kernel_size]
    sides = list(itertools.product([True, False], repeat=dim))

    # Numeric magic to (hopefully) avoid floating point inaccuracy
    subw0 = True
    if subw0:
        kernel, w0 = _split_kernel(kernel, dim)
    else:
        identity = None

    split_idx = _get_idx_split(kernel.dim(), dim)

    # loop across weights in the sparse kernel
    indices = kernel._indices().t().tolist()
    values = kernel._values()
    for idx, weight in zip(indices, values):

        # map input and output channels
        ci, co, idx = split_idx(idx)
        idx = [i - s for i, s in zip(idx, shift)]

        inp = input[ci]
        out = output[co]
        if identity is not None:
            idt = identity[co]
        else:
            idt = None

        # generate slicers
        (input_center_slice, input_side_slice,
         output_center_slice, output_side_slice, transfo_side) = \
            _make_slicers(idx, start, stop, step,
                          output_spatial_shape, spatial_shape, bound)

        # Iterate all combinations of in/out of bounds
        for side in sides:
            input_slicer = tuple(
                input_center_slice[d] if inside else input_side_slice[d]
                for d, inside in enumerate(side))
            output_slicer = tuple(
                output_center_slice[d] if inside else output_side_slice[d]
                for d, inside in enumerate(side))
            transfo = tuple(None if inside else transfo_side[d]
                            for d, inside in enumerate(side))

            if any(sl is None for sl in input_slicer):
                continue
            if any(sl is None for sl in output_slicer):
                continue

            _accumulate(out,
                        inp,
                        output_slicer,
                        input_slicer,
                        transfo,
                        weight,
                        idt=idt,
                        diag=(ci == co))

    # add weighted identity
    if subw0:
        w0 = core.utils.unsqueeze(w0, -1, output.dim() - 1)
        output.addcmul_(identity, w0)

    # move spatial dimensions to the back
    for d in range(dim + 1):
        output = core.utils.fast_movedim(output, 0, -1)

    # remove fake channels
    if channel_in is None:
        output = output.squeeze(len(batch_shape))
    # remove added dimensions
    for _ in range(added_dims):
        output = output.squeeze(-dim - 1)
    return output
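A usage sketch with a channel-free kernel (shape `(*kernel_size)`), assuming the module's private helpers (`_split_kernel`, `_make_slicers`, `_accumulate`, `_bounds`) are available; the sparse 2D Laplacian stencil is illustrative, and the kernel must be coalesced since `_indices()` is called on it:

idx = torch.tensor([[1, 1], [0, 1], [2, 1], [1, 0], [1, 2]]).t()
val = torch.tensor([-4., 1., 1., 1., 1.])
kernel = torch.sparse_coo_tensor(idx, val, (3, 3)).coalesce()
out = spconv(torch.randn(1, 16, 16), kernel, dim=2)
# out keeps the input's shape: (1, 16, 16)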
Example #22
    def forward(self, score, truth, mask=None):
        """

        Parameters
        ----------
        score : (nb_batch, nb_class, *spatial) tensor
            Pre-transformed score vector.
        truth : (nb_batch, nb_class[-1]|1, *spatial) tensor
            Observed classes (or their expectation).
                * If `truth` has a floating point data type (`half`,
                  `float`, `double`) it is assumed to hold one-hot or
                  soft labels, and its channel dimension should be
                  `nb_class` or `nb_class - 1`.
                * If `truth` has an integer or boolean data type, it is
                  assumed to hold hard labels and its channel dimension
                  should be 1.
        mask : (nb_batch, 1, *spatial) tensor, optional
            Loss mask

        Returns
        -------
        loss : scalar or tensor
            The output shape depends on the type of reduction used.
            If 'mean' or 'sum', this function returns a scalar.

        """

        weighted = self.weighted

        score = torch.as_tensor(score)
        truth = torch.as_tensor(truth, device=score.device)
        nb_classes = score.shape[1]  # (includes background)

        if truth.dtype.is_floating_point:
            # soft labels
            truth = truth.to(score.dtype)
            truth_implicit = truth.shape[1] == nb_classes - 1
            truth = get_prob_explicit(truth, implicit=truth_implicit)
            if truth.shape[1] != nb_classes:
                raise ValueError('Number of classes not consistent. '
                                 'Expected {} or {} but got {}.'.format(
                                     nb_classes, nb_classes - 1,
                                     truth.shape[1]))

            loss = score * truth

            if weighted is True:
                weighted = _auto_weighted_soft(truth)
                if mask is not None:
                    weighted = weighted * mask
                loss *= weighted
            elif weighted not in (None, False):
                dim = truth.dim() - 2
                weighted = utils.make_vector(weighted, nb_classes,
                                             **utils.backend(loss))
                weighted = utils.unsqueeze(weighted, -1, dim)
                if mask is not None:
                    weighted = weighted * mask
                loss *= weighted
            elif mask is not None:
                loss *= mask

        else:
            # hard labels
            channelwise = True
            if weighted is True:
                channelwise = False
                weighted = _auto_weighted_hard(truth, nb_classes,
                                               **utils.backend(score))
            elif weighted not in (None, False):
                weighted = utils.make_vector(weighted, **utils.backend(score))
            else:
                weighted = None

            truth = truth.squeeze(1).long()
            # If weights are a list of length C (or none), use nll_loss
            if channelwise and isinstance(self.reduction,
                                          str) and mask is None:
                return F.nll_loss(score,
                                  truth,
                                  weighted,
                                  reduction=self.reduction or 'none').neg_()
            # Otherwise, use our own implementation
            else:
                if weighted is not None:
                    if weighted.dim() == 1:  # per-class weight vector
                        weighted = utils.unsqueeze(weighted, -1,
                                                   truth.dim() - 1)
                    score = score * weighted
                loss = score.gather(dim=1, index=truth.unsqueeze(1))
                loss = loss.squeeze(1)
                if mask is not None:
                    loss *= mask.squeeze(1)

        if mask is not None and self.reduction == 'mean':
            return loss.sum() / mask.sum()
        return super().forward(loss)
Example #23
 def compute_grad(dat):
     med = dat.reshape([dat.shape[0], -1]).median(dim=-1).values
     med = utils.unsqueeze(med, -1, 3)
     dat /= 0.5*med
     dat = spatial.diff(dat, dim=[1, 2, 3]).square().sum(-1)
     return dat
Example #24
def shim(fmap,
         max_order=2,
         mask=None,
         isocenter=None,
         dim=None,
         returns='corrected'):
    """Subtract a linear combination of spherical harmonics that minimize gradients

    Parameters
    ----------
    fmap : (..., *spatial) tensor
        Field map
    max_order : int, default=2
        Maximum order of the spherical harmonics
    mask : tensor, optional
        Mask of voxels to include (typically brain mask)
    isocenter : [sequence of] float, default=shape/2
        Coordinate of isocenter, in voxels
    dim : int, default=fmap.dim()
        Number of spatial dimensions
    returns : combination of {'corrected', 'correction', 'parameters'}, default='corrected'
        Components to return

    Returns
    -------
    corrected : (..., *spatial) tensor, if 'corrected' in `returns`
        Corrected field map (with spherical harmonics subtracted)
    correction : (..., *spatial) tensor, if 'correction' in `returns`
        Linear combination of spherical harmonics.
    parameters : (..., k) tensor, if 'parameters' in `returns`
        Parameters of the linear combination

    """
    fmap = torch.as_tensor(fmap)
    dim = dim or fmap.dim()
    shape = fmap.shape[-dim:]
    batch = fmap.shape[:-dim]
    backend = utils.backend(fmap)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = ~mask  # make it a mask of background voxels

    # compute gradients
    gmap = diff(fmap, dim=dims, side='f', bound='dct2')
    if mask is not None:
        gmap[..., mask, :] = 0
    gmap = gmap.reshape([*batch, -1])

    # compute basis of spherical harmonics
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = diff(b, dim=dims, side='f', bound='dct2')
        if mask is not None:
            b[..., mask, :] = 0
        b = b.reshape([b.shape[0], *batch, -1])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, vox*dim, k)

    # solve system
    prm = linalg.lmdiv(basis, gmap[..., None], method='pinv')[..., 0]
    # > (*batch, k)

    # rebuild basis (without taking gradients)
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = b.reshape([b.shape[0], *batch, *shape])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, *shape, k)

    comb = linalg.matvec(basis.unsqueeze(-2), utils.unsqueeze(prm, -2, dim))
    comb = comb[..., 0]
    fmap = fmap - comb

    returns = returns.split('+')
    out = []
    for ret in returns:
        if ret == 'corrected':
            out.append(fmap)
        elif ret == 'correction':
            out.append(comb)
        elif ret[0] == 'p':
            out.append(prm)
    return out[0] if len(out) == 1 else tuple(out)
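A usage sketch; `returns` is a '+'-separated combination of components, as implemented by the final `split('+')`:

fmap = torch.randn(32, 32, 32)
corrected, prm = shim(fmap, max_order=2, returns='corrected+parameters')
# corrected: (32, 32, 32); prm: (k,) coefficients of the harmonics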
Example #25
def zcorrect_exp_const(x,
                       decay=None,
                       sigma=None,
                       lam=10,
                       mask=None,
                       max_iter=128,
                       tol=1e-6,
                       verbose=False,
                       snr=5):
    """Correct the z signal decay in a SPIM image.

    The signal is modelled as: f(z) = s * exp(-b * z) + eps
    where z=0 is (arbitrarily) the middle slice, s is the intercept
    and b is the decay coefficient.

    Parameters
    ----------
    x : (..., nz) tensor
        SPIM image with the z dimension last
    decay : float, optional
        Initial guess for decay parameter. Default: educated guess.
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float or (float, float), default=10
        Regularisation.
    mask : tensor, optional
        Mask of voxels to include in the fit
    max_iter : int, default=128
    tol : float, default=1e-6
    verbose : int or bool, default=False
    snr : float, default=5
        Assumed signal-to-noise ratio, used to guess `sigma`

    Returns
    -------
    y : tensor
        Fitted intercept map
    decay : tensor
        Fitted decay map
    corrected : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    backend = utils.backend(x)
    shape = x.shape
    dim = x.dim() - 1
    nz = shape[-1]
    b = decay

    x = utils.movedim(x, -1, 0).clone()
    if mask is None:
        mask = torch.isfinite(x) & (x > 0)
    else:
        mask = mask & (torch.isfinite(x) & (x > 0))
    x[~mask] = 0

    # decay educated guess: closed form from two values
    if b is None:
        z1 = 2 * nz // 5
        z2 = 3 * nz // 5
        x1 = x[z1]
        x1 = x1[x1 > 0].median()
        x2 = x[z2]
        x2 = x2[x2 > 0].median()
        z1 = float(z1)
        z2 = float(z2)
        b = (x2.log() - x1.log()) / (z1 - z2)
    y = x[(nz - 1) // 2]
    y = y[y > 0].median().log()

    b = b.item() if torch.is_tensor(b) else b
    y = y.item()
    if verbose:
        print(f'init: y = {y}, b = {b}')

    # noise educated guess: assume SNR=5 at z=1/2
    sigma = sigma or (y / snr)
    lam_y, lam_b = py.make_list(lam, 2)
    lam_y = lam_y**2 * sigma**2
    lam_b = lam_b**2 * sigma**2
    reg = lambda t: spatial.regulariser(
        t, membrane=1, dim=dim, factor=(lam_y, lam_b))
    solve = lambda h, g: spatial.solve_field_fmg(
        h, g, membrane=1, dim=dim, factor=(lam_y, lam_b))

    # init
    z = torch.arange(nz, **backend) - (nz - 1) / 2
    z = utils.unsqueeze(z, -1, dim)
    theta = z.new_empty([2, *x.shape[1:]], **backend)
    logy = theta[0].fill_(y)
    b = theta[1].fill_(b)
    y = logy.exp()
    ll0 = (mask * y *
           (-b * z).exp_() - x).square_().sum() + (theta * reg(theta)).sum()
    ll1 = ll0

    g = torch.zeros_like(theta)
    h = theta.new_zeros([3, *theta.shape[1:]])
    for it in range(max_iter):

        # exponentiate
        y = torch.exp(logy, out=y)
        fit = (b * z).neg_().exp_().mul_(y).mul_(mask)
        res = fit - x

        # compute objective
        reg_theta = reg(theta)
        ll = res.square().sum() + (theta * reg_theta).sum()
        gain = (ll1 - ll) / ll0
        if verbose:
            end = '\n' if verbose > 1 else '\r'
            print(f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}', end=end)
        if it > 0 and gain < tol:
            break
        ll1 = ll

        g[0] = (fit * res).sum(0)
        g[1] = -(fit * res * z).sum(0)
        h[0] = (fit * (fit + res.abs())).sum(0)
        h[1] = (fit * (fit + res.abs()) * (z * z)).sum(0)
        h[2] = -(z * fit * fit).sum(0)

        g += reg_theta
        theta -= solve(h, g)

    y = torch.exp(logy, out=y)
    x = x * (b * z).exp_()
    x = utils.movedim(x, 0, -1)
    x = x.reshape(shape)
    return y, b, x
Example #26
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        mean = overload.get('mean', self.mean)
        amplitude = overload.get('amplitude', self.amplitude)
        fwhm = overload.get('fwhm', self.fwhm)
        channel = overload.get('channel', self.channel)
        basis = overload.get('basis', self.basis)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)

        # sample if parameters are callable
        mean = mean() if callable(mean) else mean
        amplitude = amplitude() if callable(amplitude) else amplitude
        fwhm = fwhm() if callable(fwhm) else fwhm

        # device/dtype
        mean = torch.as_tensor(mean, dtype=dtype, device=device)
        amplitude = torch.as_tensor(amplitude, dtype=dtype, device=device)
        fwhm = torch.as_tensor(fwhm, dtype=dtype, device=device)

        # reshape
        nb_dim = len(shape)
        full_shape = [batch, channel, *shape]
        mean = mean.expand(full_shape)
        amplitude = amplitude.expand(full_shape)
        fwhm = fwhm.expand([batch, channel, nb_dim])

        conv = torch.nn.functional.conv1d if nb_dim == 1 else \
               torch.nn.functional.conv2d if nb_dim == 2 else \
               torch.nn.functional.conv3d if nb_dim == 3 else None

        # convert SE parameters to noise/kernel parameters
        sigma_se = fwhm / math.sqrt(8 * math.log(2))
        sigma_se = unsqueeze(sigma_se.prod(dim=-1), dim=-1, ndim=nb_dim)
        amplitude = amplitude * (2 * pi)**(nb_dim / 4) * sigma_se.sqrt()
        fwhm = fwhm * math.sqrt(2)

        # smooth
        samples_b = []
        for b in range(batch):
            samples_c = []
            for c in range(channel):
                kernel = smooth('gauss',
                                fwhm[b, c],
                                basis=basis,
                                device=device,
                                dtype=dtype)

                # compute input shape
                pad_shape = [
                    shape[d] + kernel[d].shape[d + 2] - 1
                    for d in range(nb_dim)
                ]
                mean1 = ensure_shape(mean[b, c],
                                     pad_shape,
                                     mode='reflect2',
                                     side='both')
                amplitude1 = ensure_shape(amplitude[b, c],
                                          pad_shape,
                                          mode='reflect2',
                                          side='both')

                # generate sample
                sample = torch.distributions.Normal(mean1, amplitude1).sample()
                sample = sample[None, None, ...]

                # convolve
                for ker in kernel:
                    sample = conv(sample, ker)

                samples_c.append(sample)

            samples_b.append(torch.cat(samples_c, dim=1))

        sample = torch.cat(samples_b, dim=0)

        return sample
Example #27
def mse(moving, fixed, lam=1, dim=None, grad=True, hess=True, mask=None):
    """Mean-squared error loss for optimisation-based registration.

    (A factor 1/2 is included, and the loss is averaged across voxels,
    but not across channels or batches)

    Parameters
    ----------
    moving : ([B], K, *spatial) tensor
        Moving image
    fixed : ([B], K, *spatial) tensor
        Fixed image
    lam : float or ([B], K|1, [*spatial]) tensor_like
        Gaussian noise precision (or IRLS weights)
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return Hessian
    mask : ([B], K|1, *spatial) tensor, optional
        Voxel-wise mask applied to the loss

    Returns
    -------
    ll : () tensor
        Negative log-likelihood
    g : (..., K, *spatial) tensor, optional
        Gradient with respect to the moving image
    h : (..., K, *spatial) tensor, optional
        (Diagonal) Hessian with respect to the moving image

    """
    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    if mask is not None:
        mask = mask.to(fixed.device)
    dim = dim or (fixed.dim() - 1)
    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)  # pad spatial dimensions
    nvox = py.prod(fixed.shape[-dim:])

    if moving.requires_grad:
        ll = moving - fixed
        if mask is not None:
            ll = ll.mul_(mask)
        ll = ll.square().mul_(lam).sum() / (2 * nvox)
    else:
        ll = moving - fixed
        if mask is not None:
            ll = ll.mul_(mask)
        ll = ll.square_().mul_(lam).sum() / (2 * nvox)

    out = [ll]
    if grad:
        g = moving - fixed
        if mask is not None:
            g = g.mul_(mask)
        g = g.mul_(lam).div_(nvox)
        out.append(g)
    if hess:
        h = lam / nvox
        if mask is not None:
            h = mask * h
        out.append(h)

    return tuple(out) if len(out) > 1 else out[0]
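A usage sketch, assuming the library context: with the default `grad` and `hess`, the call returns the scalar negative log-likelihood together with its derivatives with respect to the moving image:

moving = torch.randn(1, 1, 16, 16)
fixed = torch.randn(1, 1, 16, 16)
ll, g, h = mse(moving, fixed, lam=1.)
# ll: scalar; g: same shape as moving; h: broadcastable against it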
Example #28
    def forward(self, image, **overload):
        """

        Parameters
        ----------
        image : (batch, channel, *shape) tensor
            Input image
        overload : dict
            All parameters defined at build time can be overridden at call time

        Returns
        -------
        warped : (batch, channel, *shape) tensor
            Deformed image
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """

        image = torch.as_tensor(image)
        dim = image.dim() - 2
        batch, channel, *shape = image.shape
        info = {'dtype': image.dtype, 'device': image.device}

        # get arguments
        opt_grid = {
            'dim': dim,
            'shape': shape,
            'amplitude': overload.get('vel_amplitude', self.grid.amplitude),
            'fwhm': overload.get('vel_fwhm', self.grid.fwhm),
            'bound': overload.get('vel_bound', self.grid.bound),
            'interpolation': overload.get('interpolation',
                                          self.grid.interpolation),
            'dtype': overload.get('dtype', self.grid.dtype),
            'device': overload.get('device', self.grid.device),
        }
        opt_affine = {
            'dim': dim,
            'translation': overload.get('translation',
                                        self.affine.translation),
            'rotation': overload.get('rotation', self.affine.rotation),
            'zoom': overload.get('zoom', self.affine.zoom),
            'shear': overload.get('shear', self.affine.shear),
            'dtype': overload.get('dtype', self.affine.dtype),
            'device': overload.get('device', self.affine.device),
        }
        opt_pull = {
            'bound':
            overload.get('image_bound', self.pull.bound),
            'interpolation':
            overload.get('interpolation', self.pull.interpolation),
        }

        grid = self.grid(batch, **opt_grid)
        aff = self.affine(batch, **opt_affine)

        # shift center of rotation
        aff_shift = torch.cat(
            (torch.eye(dim, **info),
             -torch.as_tensor(opt_grid['shape'], **info)[:, None] / 2),
            dim=1)
        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = matvec(lin, grid) + off

        # pull
        warped = self.pull(image, grid, **opt_pull)

        return warped, grid
Example #29
def lcc(moving,
        fixed,
        dim=None,
        patch=20,
        stride=1,
        lam=1,
        mode='g',
        grad=True,
        hess=True,
        mask=None):
    """Local correlation coefficient (squared)

    This function implements a squared version of Cachier and
    Pennec's local correlation coefficient, so that anti-correlations
    are not penalized.

    Parameters
    ----------
    moving : (..., K, *spatial) tensor
        Moving image with K channels.
    fixed : (..., K, *spatial) tensor
        Fixed image with K channels.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    patch : int, default=20
        Patch size
    stride : int, default=1
        Stride between patches
    lam : float or ([B], K|1, [*spatial]) tensor_like, default=1
        Precision of the NCC distribution
    grad : bool, default=True
        Compute and return gradient
    hess : bool, default=True
        Compute and return approximate Hessian

    Returns
    -------
    ll : () tensor

    References
    ----------
    ..[1] "3D Non-Rigid Registration by Gradient Descent on a Gaussian-
           Windowed Similarity Measure using Convolutions"
          Pascal Cachier, Xavier Pennec
          MMBIA (2000)

    """
    if moving.requires_grad:
        sqrt_ = torch.sqrt
        div_ = torch.div
    else:
        sqrt_ = torch.sqrt_
        div_ = lambda x, y: x.div_(y)

    fixed, moving, lam = utils.to_max_backend(fixed, moving, lam)
    dim = dim or (fixed.dim() - 1)
    shape = fixed.shape[-dim:]
    if mask is not None:
        mask = mask.to(**utils.backend(fixed))
    else:
        mask = fixed.new_ones(fixed.shape[-dim:])

    if lam.dim() <= 2:
        if lam.dim() == 0:
            lam = lam.flatten()
        lam = utils.unsqueeze(lam, -1, dim)

    patch = list(map(float, py.ensure_list(patch)))
    stride = py.ensure_list(stride)
    stride = [s or 0 for s in stride]
    fwd = lambda x: local_mean(
        x, patch, stride, dim=dim, mode=mode, mask=mask, cache=local_cache)
    bwd = lambda x: local_mean(x,
                               patch,
                               stride,
                               dim=dim,
                               mode=mode,
                               mask=mask,
                               backward=True,
                               shape=shape,
                               cache=local_cache)
    sumall = lambda x: x.sum(list(range(-dim, 0)), keepdim=True)

    # compute ncc within each patch
    mom0, moving_mean, fixed_mean, moving_std, fixed_std, corr = \
        _suffstat(fwd, moving, fixed)
    mom0 = mom0.div_(sumall(mom0).clamp_min_(1e-5)).mul_(lam)
    moving_std = sqrt_(moving_std.addcmul_(moving_mean, moving_mean, value=-1))
    fixed_std = sqrt_(fixed_std.addcmul_(fixed_mean, fixed_mean, value=-1))
    moving_std.clamp_min_(1e-5)
    fixed_std.clamp_min_(1e-5)
    corr = div_(
        div_(corr.addcmul_(moving_mean, fixed_mean, value=-1), moving_std),
        fixed_std)
    corr2 = corr.square().neg_().add_(1).clamp_min_(1e-8)

    out = []
    if grad or hess:
        h = (corr / moving_std).square_().mul_(mom0).div_(corr2)
        h = bwd(h)

        if grad:
            # g = G' * (corr.*(corr.*xmean./xstd - ymean./ystd)./xstd)
            #   - x .* (G' * (corr./ xstd).^2)
            #   + y .* (G' * (corr ./ (xstd.*ystd)))
            # g = -2 * g
            fixed_mean = fixed_mean.div_(fixed_std)
            moving_mean = moving_mean.div_(moving_std)
            g = fixed_mean.addcmul_(corr, moving_mean, value=-1)
            fixed_mean = moving_mean = None
            g = g.mul_(corr).div_(moving_std).mul_(mom0).div_(corr2)
            g = bwd(g)
            g = g.addcmul_(h, moving)
            g = g.addcmul_(bwd(
                corr.div_(moving_std).div_(fixed_std).mul_(mom0).div_(corr2)),
                           fixed,
                           value=-1)
            g = g.mul_(2)
            out.append(g)

        if hess:
            # h = 2 * (G' * (corr./ xstd).^2)
            h = h.mul_(2)
            out.append(h)

    # return stuff
    corr = corr2.log_().mul_(mom0)
    corr = corr.sum()
    out = [corr, *out]
    return tuple(out) if len(out) > 1 else out[0]
Example #30
 def preprocess(a):
     a = torch.as_tensor(a)
     a = unsqueeze(a, dim=-1, ndim=opt['channel'] + 1 - a.dim())
     return a