Example 1
def _loss_ssqd_jtv(dat_x,
                   dat_y,
                   tau,
                   lam,
                   voxel_size=1,
                   side='f',
                   bound='dct2'):
    """Computes an image denoising loss function, where:
    * fidelity term: sum-of-squared differences (SSQD)
    * regularisation term: joint total variation (JTV)
    * hyper-parameters: tau, lambda

    Parameters
    ----------
    dat_x : (dmx, dmy, dmz, nchannels) tensor
        Input image
    dat_y : (dmx, dmy, dmz, nchannels) tensor
        Reconstruction image
    tau : (nchannels) tensor
        Channel-specific noise precisions
    lam : (nchannels) tensor
        Channel-specific regularisation values
    voxel_size : float or sequence[float], default=1
        Unit size used in the denominator of the gradient.
    side : {'c', 'f', 'b'}, default='f'
        * 'c': central finite differences
        * 'f': forward finite differences
        * 'b': backward finite differences
    bound : {'dct2', 'dct1', 'dst2', 'dst1', 'dft', 'repeat', 'zero'}, default='dct2'
        Boundary condition.

    Returns
    -------
    nll_yx : tensor
        Loss function value (negative log-posterior)

    """
    # compute negative log-likelihood (SSQD fidelity term)
    nll_xy = 0.5 * torch.sum(tau * torch.sum(
        (dat_x - dat_y)**2, dim=(0, 1, 2)))
    # compute gradients of reconstruction, shape=(dmx, dmy, dmz, nchannels, dmgr)
    nll_y = diff(dat_y,
                 order=1,
                 dim=(0, 1, 2),
                 voxel_size=voxel_size,
                 side=side,
                 bound=bound)
    # modulate channels with regularisation
    nll_y = lam[None, None, None, :, None] * nll_y
    # compute negative log-prior (JTV regularisation term)
    nll_y = torch.sum(
        nll_y**2 + eps(),
        dim=-1)  # to gradient magnitudes (sum over gradient directions)
    nll_y = torch.sum(nll_y, dim=-1)  # sum over reconstruction channels
    nll_y = torch.sqrt(nll_y)
    nll_y = torch.sum(nll_y)  # sum over voxels
    # compute negative log-posterior (loss function)
    nll_yx = nll_xy + nll_y

    return nll_yx
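
A minimal usage sketch for the loss above (hypothetical shapes and
hyper-parameter values; assumes `torch` and the helpers used in the function,
e.g. nitorch's `diff` and `eps`, are already imported):

# Hypothetical usage: denoise a two-channel 32^3 volume.
dat_x = torch.rand(32, 32, 32, 2)               # observed (noisy) image
dat_y = dat_x.clone().requires_grad_(True)      # current reconstruction estimate
tau = torch.tensor([1.0, 2.0])                  # per-channel noise precisions
lam = torch.tensor([10.0, 10.0])                # per-channel regularisation
loss = _loss_ssqd_jtv(dat_x, dat_y, tau, lam, voxel_size=1.5)
loss.backward()                                 # gradients w.r.t. dat_y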
Example 2
def fit_se_log(log_cov, sqdist):
    """Fit the amplitude and length-scale of a squared-exponential kernel

    Parameters
    ----------
    log_cov : (*batch, vox, vox) tensor
        Log of the empirical covariance matrix
    sqdist : tuple[int] or (vox, vox) tensor
        If a tensor -> it is the pre-computed squared distance map
        If a tuple -> it is the shape and we build the distance map

    Returns
    -------
    sig : (*batch,) tensor
        Amplitude of the kernel
    lam : (*batch,) tensor
        Length-scale of the kernel

    """
    log_cov = torch.as_tensor(log_cov).clone()
    backend = utils.backend(log_cov)
    if not torch.is_tensor(sqdist):
        shape = sqdist
        sqdist = dist_map(shape, **backend)
    else:
        sqdist = sqdist.to(**backend).clone()

    # linear regression
    eps = constants.eps(log_cov.dtype)
    y = log_cov.reshape([-1, py.prod(sqdist.shape)])
    msk = torch.isfinite(y)
    y[~msk] = 0
    y0 = y.sum(-1, keepdim=True) / msk.sum(-1, keepdim=True)
    y -= y0
    x = sqdist.flatten() * msk
    x0 = x.sum(-1, keepdim=True) / msk.sum(-1, keepdim=True)
    x -= x0
    b = (x * y).sum(-1) / x.square().sum(-1).clamp_min_(eps)
    # intercept of the log-linear fit, one value per (flattened) batch element
    a = y0[..., 0] - b * x0[..., 0]

    lam = b.reciprocal_().mul_(-0.5).sqrt_()
    sig = a.div_(2).exp_()
    # restore the batch shape promised by the docstring
    batch_shape = log_cov.shape[:-2]
    return sig.reshape(batch_shape), lam.reshape(batch_shape)
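
A quick, torch-only sanity check of the regression above (hypothetical values;
assumes `torch` and `fit_se_log` are in scope). Since
log k(d) = 2*log(sig) - d**2 / (2*lam**2), regressing the log-covariance
against the squared distance should recover the true parameters:

# Build a noise-free 1D squared-exponential covariance with known parameters.
sig, lam = 2.0, 3.0
coord = torch.arange(32, dtype=torch.float64)
sqdist = (coord[:, None] - coord[None, :]) ** 2        # (vox, vox) squared distances
cov = sig ** 2 * torch.exp(-sqdist / (2 * lam ** 2))   # SE covariance matrix
sig_hat, lam_hat = fit_se_log(cov.log(), sqdist)       # sig_hat ~ 2, lam_hat ~ 3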
Example 3
def histeq(x, n=1024, dim=None):
    """Histogram equalization

    Notes
    -----
    .. The minimum and maximum values of the input tensor are preserved.
    .. A piecewise linear transform is applied so that the output
       quantiles match those of a "template" histogram.
    .. By default, the template histogram is flat.

    Parameters
    ----------
    x : tensor
        Input image
    n : int or tensor
        Number of bins or target histogram
    dim : [sequence of] int, optional
        Dimensions along which to compute the histogram. Default: all.

    Returns
    -------
    x : tensor
        Transformed image

    """
    x = torch.as_tensor(x)

    # compute target cumulative histogram
    if torch.is_tensor(n):
        other_hist = n
        n = len(other_hist)
    else:
        other_hist = x.new_full([n], 1 / n)
    other_hist += constants.eps(other_hist.dtype)
    other_hist = other_hist.cumsum(-1) / other_hist.sum(-1, keepdim=True)
    other_hist[..., -1] = 1

    # compute cumulative histogram
    min = math.min(x, dim=dim)
    max = math.max(x, dim=dim)
    batch_shape = min.shape
    hist = utils.histc(x, n, dim=dim, min=min, max=max)
    hist += constants.eps(hist.dtype)
    hist = hist.cumsum(-1) / hist.sum(-1, keepdim=True)
    hist[..., -1] = 1

    # match histograms
    hist = hist.reshape([-1])
    shift = _hist_to_quantile(other_hist[None], hist)
    shift = shift.reshape([-1, n])
    shift /= n

    # reshape
    shift = shift.reshape([*batch_shape, n])

    # interpolate and apply shift
    eps = constants.eps(x.dtype)
    grid = x.clone()
    grid = grid.mul_(n / (max - min + eps)).add_(n / (1 - max / min)).sub_(1)
    grid = grid.flatten()[:, None, None]
    shift = spatial.grid_pull(shift.reshape([-1, 1, n]),
                              grid,
                              bound='zero',
                              extrapolate=True)
    shift = shift.reshape(x.shape)
    x = (x - min) * shift + min

    return x
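
A small usage sketch (hypothetical; assumes `torch`, `histeq` and its nitorch
dependencies are in scope):

img = torch.rand(64, 64) ** 3                    # skewed intensity distribution
flat = histeq(img, n=256)                        # equalise towards a flat histogram
ref = torch.histc(torch.rand(64, 64), bins=256)  # histogram of a template image
matched = histeq(img, n=ref)                     # match the template instead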
Example 4
def intensity_preproc(*images, min=None, max=None, eq=None):
    """(Joint) rescaling and intensity equalizing.

    Parameters
    ----------
    *images : (*batch, H, W) tensor
        Input (batch of) 2d images.
        All batch shapes should be broadcastable together.
    min : tensor_like, optional
        Minimum value. Should be broadcastable to batch.
        Default: 5th percentile of each batch element.
    max : tensor_like, optional
        Maximum value. Should be broadcastable to batch.
        Default: 95th percentile of each batch element.
    eq : {'linear', 'quadratic', 'log', None} or float, default=None
        Apply histogram equalization.
        If 'quadratic' or 'log', the histogram of the transformed signal
        is equalized.
        If float, the signal is taken to that power before being equalized.

    Returns
    -------
    *images : (*batch, H, W) tensor
        Preprocessed images.
        Intensities are scaled within [0, 1].

    """

    if len(images) == 1:
        images = [utils.to_max_backend(*images)]
    else:
        images = utils.to_max_backend(*images)
    backend = utils.backend(images[0])
    eps = constants.eps(images[0].dtype)

    # rescale min/max
    min = py.make_list(min, len(images))
    max = py.make_list(max, len(images))
    min = [
        utils.quantile(image, 0.05, bins=2048, dim=[-1, -2], keepdim=True)
        if mn is None else torch.as_tensor(mn, **backend)[None, None]
        for image, mn in zip(images, min)
    ]
    min, *othermin = min
    for mn in othermin:
        min = torch.min(min, mn)
    del othermin
    max = [
        utils.quantile(image, 0.95, bins=2048, dim=[-1, -2], keepdim=True)
        if mx is None else torch.as_tensor(mx, **backend)[None, None]
        for image, mx in zip(images, max)
    ]
    max, *othermax = max
    for mx in othermax:
        max = torch.max(max, mx)
    del othermax
    images = [torch.max(torch.min(image, max), min) for image in images]
    images = [
        image.mul_(1 / (max - min + eps)).add_(1 / (1 - max / min))
        for image in images
    ]

    if not eq:
        return tuple(images) if len(images) > 1 else images[0]

    # reshape and concatenate
    batch = utils.expanded_shape(*[image.shape[:-2] for image in images])
    images = [image.expand([*batch, *image.shape[-2:]]) for image in images]
    shapes = [image.shape[-2:] for image in images]
    chunks = [py.prod(s) for s in shapes]
    images = [image.reshape([*batch, c]) for image, c in zip(images, chunks)]
    images = torch.cat(images, dim=-1)

    if eq is True:
        eq = 'linear'
    if not isinstance(eq, str):
        if eq >= 0:
            images = images.pow(eq)
        else:
            images = images.clamp_min_(constants.eps(images.dtype)).pow(eq)
    elif eq.startswith('q'):
        images = images.square()
    elif eq.startswith('log'):
        images = images.clamp_min_(constants.eps(images.dtype)).log()

    images = histeq(images, dim=-1)

    if not (isinstance(eq, str) and eq.startswith('lin')):
        # rescale min/max
        images -= math.min(images, dim=-1, keepdim=True)
        images /= math.max(images, dim=-1, keepdim=True)

    images = images.split(chunks, dim=-1)
    images = [image.reshape(*batch, *s) for image, s in zip(images, shapes)]

    return tuple(images) if len(images) > 1 else images[0]
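
A hypothetical usage sketch (assumes `torch` and `intensity_preproc` are in
scope):

a = 100 * torch.rand(4, 128, 128)                  # (*batch, H, W)
b = 50 * torch.rand(4, 128, 128)
a01, b01 = intensity_preproc(a, b)                 # joint rescaling to [0, 1]
a_eq, b_eq = intensity_preproc(a, b, eq='linear')  # rescaling + equalisation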
Example 5
def is_inside(points, vertices, faces=None):
    """Test if a point is inside a polygon/surface.

    The polygon or surface *must* be closed.

    Parameters
    ----------
    points : (..., dim) tensor
        Coordinates of points to test
    vertices : (nv, dim) tensor
        Vertex coordinates
    faces : (nf, dim) tensor[int], optional
        Each face is encoded by the indices of its vertices.
        By default, assume that the vertices are ordered and define a
        closed curve.

    Returns
    -------
    check : (...) tensor[bool]

    """
    # This function uses a ray-tracing technique:
    #
    #   A half-line (ray) is shot from each point. If it crosses an odd
    #   number of faces, the point is inside the shape; if it crosses an
    #   even number of faces, it is outside.
    #
    #   In practice, we loop over faces (we expect far fewer faces than
    #   test points) and compute the intersections between all rays and
    #   each face in a batched fashion. We only want to shoot the rays in
    #   one direction, so we only count intersections whose coordinate
    #   along the ray is non-negative.

    points = torch.as_tensor(points)
    vertices = torch.as_tensor(vertices)
    if faces is None:
        faces = [(i, i + 1) for i in range(len(vertices) - 1)]
        faces += [(len(vertices) - 1, 0)]
        faces = utils.as_tensor(faces, dtype=torch.long)

    points, vertices = utils.to_max_dtype(points, vertices)
    points, vertices, faces = utils.to_max_device(points, vertices, faces)
    backend = utils.backend(points)
    batch = points.shape[:-1]
    dim = points.shape[-1]
    eps = constants.eps(points.dtype)
    cross = points.new_zeros(batch, dtype=torch.long)

    ray = torch.randn(dim, **backend)

    for face in faces:
        face = vertices[face]

        # compute normal vector
        origin = face[0]
        if dim == 3:
            u = face[1] - face[0]
            v = face[2] - face[0]
            norm = torch.stack([
                u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
                u[0] * v[1] - u[1] * v[0]
            ])
        else:
            assert dim == 2
            u = face[1] - face[0]
            norm = torch.stack([-u[1], u[0]])

        # skip faces that are (nearly) parallel to the ray
        # (i.e. the ray is orthogonal to the face normal)
        colinear = (linalg.dot(ray, norm).abs()
                    / (ray.norm() * norm.norm()) < eps)
        if colinear:
            continue

        # compute intersection between ray and plane
        #   plane: <norm, x - origin> = 0
        #   line: x = p + t*u
        #   => <norm, p + t*u - origin> = 0
        intersection = linalg.dot(norm, points - origin)
        intersection /= linalg.dot(norm, ray)
        halfmask = intersection >= 0  # we only want to shoot in one direction
        intersection = intersection[halfmask]
        halfpoints = points[halfmask]
        intersection = intersection[..., None] * (-ray)
        intersection += halfpoints

        # check if the intersection is inside the face
        #   first, we project it onto a frame of dimension `dim-1`
        #   defined by (origin, (u, v))
        intersection -= origin
        if dim == 3:
            interu = linalg.dot(intersection, u)
            interv = linalg.dot(intersection, v)
            intersection = (interu >= 0) & (interv > 0) & (interu + interv < 1)
        else:
            intersection = linalg.dot(intersection, u)
            intersection /= u.norm().square_()
            intersection = (intersection >= 0) & (intersection < 1)

        cross[halfmask] += intersection

    # a point is inside iff its ray crosses the surface an odd number of times
    cross = cross.bitwise_and_(1).bool()
    return cross
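
A hypothetical 2D usage sketch (assumes `torch` and `is_inside` are in scope):
the ordered vertices of a closed unit square and two test points.

square = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
pts = torch.tensor([[0.5, 0.5], [2.0, 2.0]])
inside = is_inside(pts, square)        # -> tensor([True, False])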
Example 6
def pnorm(x, dims=-1):
    """Normalize a tensor so that its sum across `dims` is one."""
    dims = make_list(dims)
    x = x.clamp_min_(eps(x.dtype))
    x = x / nansum(x, dim=dims, keepdim=True)
    return x
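
A small usage sketch (hypothetical; assumes `torch`, `make_list`, `nansum` and
`eps` are in scope as in the surrounding module). Note that the clamp is
in-place, so the input histogram is modified:

h = torch.rand(2, 3, 64, 64)           # e.g. a batch of joint histograms
p = pnorm(h, dims=[-1, -2])            # p.sum(dim=(-1, -2)) is ~1 everywhere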
Example 7
    def forward(self, x, y, **overload):
        """

        Parameters
        ----------
        x : tensor (batch, 1, *spatial)
        y : tensor (batch, 1, *spatial)
        overload : dict
            All parameters defined at build time can be overridden
            at call time.

        Returns
        -------
        loss : scalar or tensor
            The output shape depends on the type of reduction used.
            If 'mean' or 'sum', this function returns a scalar.

        """
        # check inputs
        x = torch.as_tensor(x)
        y = torch.as_tensor(y)
        nb_dim = x.dim() - 2
        if x.shape[1] != 1 or y.shape[1] != 1:
            raise ValueError('Mutual info is only implemented for '
                             'single channel tensors.')
        shape = x.shape[2:]

        # get parameters
        min_val = overload.get('min_val', self.min_val)
        max_val = overload.get('max_val', self.max_val)
        nb_bins = overload.get('nb_bins', self.nb_bins)
        fwhm = overload.get('fwhm', self.fwhm)
        order = overload.get('order', self.order)
        normalize = overload.get('normalize', self.normalize)
        patch_size = overload.get('patch_size', self.patch_size)
        patch_stride = overload.get('patch_stride', self.patch_stride)
        mask = overload.get('mask', self.mask)

        # reshape
        if patch_size:
            # extract patches about each voxel
            patch_size = make_list(patch_size, nb_dim)
            patch_size = [
                min(pch or dim, dim) for pch, dim in zip(patch_size, shape)
            ]
            x = utils.unfold(x[:, 0], patch_size, patch_stride, collapse=True)
            y = utils.unfold(y[:, 0], patch_size, patch_stride, collapse=True)

        # collapse spatial dimensions -> we don't need them anymore
        x = x.reshape((*x.shape[:2], -1))
        y = y.reshape((*y.shape[:2], -1))

        # exclude masked values
        mask_x, mask_y = make_list(mask, 2)
        mask = None
        if callable(mask_x):
            mask = mask_x(x)
        elif mask_x is not None:
            mask = x <= mask_x
        if callable(mask_y):
            mask = (mask & mask_y(y)) if mask is not None else mask_y(y)
        elif mask_y is not None:
            mask = (mask &
                    (y <= mask_y)) if mask is not None else (y <= mask_y)

        if order == 'inf':
            p_xy = joint_hist_gaussian(x, y, nb_bins, min_val, max_val, fwhm,
                                       mask)
        else:
            p_xy = joint_hist_spline(x, y, nb_bins, min_val, max_val, order,
                                     mask)

        def pnorm(x, dims=-1):
            """Normalize a tensor so that it's sum across `dims` is one."""
            dims = make_list(dims)
            x = x.clamp_min_(eps(x.dtype))
            x = x / nansum(x, dim=dims, keepdim=True)
            return x

        # compute probabilities
        p_x = pnorm(p_xy.sum(dim=-2))  # -> [B, C, nb_bins]
        p_y = pnorm(p_xy.sum(dim=-1))  # -> [B, C, nb_bins]
        p_xy = pnorm(p_xy, [-1, -2])

        # compute entropies
        h_x = -(p_x * p_x.log()).sum(dim=-1)  # -> [B, C]
        h_y = -(p_y * p_y.log()).sum(dim=-1)  # -> [B, C]
        h_xy = -(p_xy * p_xy.log()).sum(dim=[-1, -2])  # -> [B, C]

        # negative mutual information
        mi = h_xy - (h_x + h_y)

        # normalize
        if normalize == 'studholme':
            mi = mi / h_xy.clamp_min_(eps(x.dtype))
            mi += 1
        elif normalize not in (None, 'none'):
            normalize = (lambda a, b: (a+b)/2) if normalize == 'arithmetic' else \
                        (lambda a, b: (a*b).sqrt()) if normalize == 'geometric' else \
                        torch.min if normalize == 'min' else \
                        torch.max if normalize == 'max' else \
                        normalize
            mi = mi / normalize(h_x, h_y).clamp_min_(eps(x.dtype))
            mi += 1

        # reduce
        return super().forward(mi)
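
A standalone, torch-only sketch of the entropy arithmetic used above, on an
illustrative joint distribution: mutual information is
I(X;Y) = H(X) + H(Y) - H(X,Y), and the loss returns -I (optionally
normalised), so minimising the loss maximises mutual information.

p_xy = torch.rand(8, 8)
p_xy = p_xy / p_xy.sum()                  # joint probabilities
p_x, p_y = p_xy.sum(-1), p_xy.sum(-2)     # marginals
ent = lambda p: -(p * p.log()).sum()      # Shannon entropy of a distribution
mi = ent(p_x) + ent(p_y) - ent(p_xy)      # >= 0, 0 iff X and Y are independent
loss = -mi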
Example 8
def joint_hist_gaussian(x, y, bins=64, min=None, max=None, fwhm=1, mask=None):
    """Compute joint histogram with Gaussian window

    Parameters
    ----------
    x : (batch, channel, voxels) tensor
    y : (batch, channel, voxels) tensor
    bins : int or (int, int), default=64
    min : float or (float, float), optional
    max : float or (float, float), optional
    fwhm : float or (float, float), default=1
    mask : (batch, channel, voxels) tensor, optional

    Returns
    -------
    h : (batch, channel, bins, bins)

    """
    backend = utils.backend(x)
    x_min, y_min = py.make_list(min, 2)
    x_max, y_max = py.make_list(max, 2)
    x_nbins, y_nbins = py.make_list(bins, 2)
    x_fwhm, y_fwhm = py.make_list(fwhm, 2)

    def get_bins(x, min, max, nbins):
        """Compute the histogram bins."""
        # TODO: It's suboptimal to have bin centers fall at the
        #   min and max. Better to shift them slightly inside.
        if mask is not None:
            # we set masked values to nan so that we can exclude them when
            # computing min/max
            val_nan = torch.as_tensor(nan, **backend)
            x = torch.where(mask, val_nan, x)
            min_fn = nanmin
            max_fn = nanmax
        else:
            min_fn = lambda *a, **k: torch.min(*a, **k).values
            max_fn = lambda *a, **k: torch.max(*a, **k).values
        min = min_fn(x, dim=-1) if min is None else min
        min = torch.as_tensor(min, **backend)
        min = unsqueeze(min, dim=2, ndim=4 - min.dim())
        # -> shape = [B, C, 1, 1]
        max = max_fn(x, dim=-1) if max is None else max
        max = torch.as_tensor(max, **backend)
        max = unsqueeze(max, dim=2, ndim=4 - max.dim())
        # -> shape = [B, C, 1, 1]
        bins = torch.linspace(0, 1, nbins, **backend)
        bins = unsqueeze(bins, dim=0, ndim=3)  # -> [1, 1, 1, nb_bins]
        bins = min + bins * (max - min)  # -> [B, C, 1, nb_bins]
        binwidth = (max - min) / (nbins - 1)  # -> [B, C, 1, 1]
        return bins, binwidth

    # prepare bins
    x_bins, x_binwidth = get_bins(x.detach(), x_min, x_max, x_nbins)
    y_bins, y_binwidth = get_bins(y.detach(), y_min, y_max, y_nbins)

    # we transform our nans into inf so that they get zero-weight
    # in the histogram
    if mask is not None:
        val_inf = torch.as_tensor(inf, **backend)
        x = torch.where(mask, val_inf, x)
        y = torch.where(mask, val_inf, y)

    # compute distances and collapse
    x = x[..., None]  # -> [B, C, N, 1]
    y = y[..., None]  # -> [B, C, N, 1]
    x_var = ((x_fwhm * x_binwidth)**2) / (8 * math.log(2))
    x_var = x_var.clamp(min=eps(x.dtype))
    x = -(x - x_bins).square() / (2 * x_var)
    x = x.exp()
    y_var = ((y_fwhm * y_binwidth)**2) / (8 * math.log(2))
    y_var = y_var.clamp(min=eps(y.dtype))
    y = -(y - y_bins).square() / (2 * y_var)
    y = y.exp()
    # -> [B, C, N, nb_bins]

    x = x.transpose(-1, -2)
    h = torch.matmul(x, y)  # -> [B, C, nb_bins, nb_bins]
    return h
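
A hypothetical usage sketch (assumes `torch` and `joint_hist_gaussian` with
its dependencies are in scope): a soft, Parzen-window joint histogram of two
flattened images.

x = torch.rand(1, 1, 4096)                        # (batch, channel, voxels)
y = torch.rand(1, 1, 4096)
h = joint_hist_gaussian(x, y, bins=32, fwhm=2)    # -> (1, 1, 32, 32)
p_xy = h / h.sum(dim=(-1, -2), keepdim=True)      # joint probabilities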
Example 9
    def forward(self, x, y, **overload):
        """

        Parameters
        ----------
        x : tensor (batch, 1, *spatial)
        y : tensor (batch, 1, *spatial)
        overload : dict
            All parameters defined at build time can be overridden
            at call time.

        Returns
        -------
        loss : scalar or tensor
            The output shape depends on the type of reduction used.
            If 'mean' or 'sum', this function returns a scalar.

        """
        # check inputs
        x = torch.as_tensor(x)
        y = torch.as_tensor(y)
        dtype = x.dtype
        device = x.device
        nb_dim = x.dim() - 2
        if x.shape[1] != 1 or y.shape[1] != 1:
            raise ValueError('Mutual info is only implemented for '
                             'single channel tensors.')
        shape = x.shape[2:]

        # get parameters
        x_min, y_min = make_list(overload.get('min_val', self.min_val), 2)
        x_max, y_max = make_list(overload.get('max_val', self.max_val), 2)
        x_nbins, y_nbins = make_list(overload.get('nb_bins', self.nb_bins), 2)
        x_fwhm, y_fwhm = make_list(overload.get('fwhm', self.fwhm), 2)
        normalize = overload.get('normalize', self.normalize)
        patch_size = overload.get('patch_size', self.patch_size)
        patch_stride = overload.get('patch_stride', self.patch_stride)
        mask = overload.get('mask', self.mask)

        # reshape
        if patch_size:
            # extract patches about each voxel
            patch_size = make_list(patch_size, nb_dim)
            patch_size = [pch or dim for pch, dim in zip(patch_size, shape)]
            patch_stride = make_list(patch_stride, nb_dim)
            patch_stride = [
                sz if st is None else st
                for sz, st in zip(patch_size, patch_stride)
            ]
            x = x[:, 0, ...]
            y = y[:, 0, ...]
            for d, (sz, st) in enumerate(zip(patch_size, patch_stride)):
                x = x.unfold(dimension=d + 1, size=sz, step=st)
                y = y.unfold(dimension=d + 1, size=sz, step=st)
            x = x.reshape((x.shape[0], -1, *patch_size))
            y = y.reshape((y.shape[0], -1, *patch_size))
            # now, the spatial dimension of x and y is `patch_size` and
            # their channel dimension is the number of patches
        # collapse spatial dimensions -> we don't need them anymore
        x = x.reshape((*x.shape[:2], -1))
        y = y.reshape((*y.shape[:2], -1))

        # exclude masked values
        mask_x, mask_y = make_list(mask, 2)
        mask = None
        if callable(mask_x):
            mask = mask_x(x)
        elif mask_x is not None:
            mask = x <= mask_x
        if callable(mask_y):
            mask = (mask & mask_y(y)) if mask is not None else mask_y(y)
        elif mask_y is not None:
            mask = (mask &
                    (y <= mask_y)) if mask is not None else (y <= mask_y)

        def get_bins(x, min, max, nbins):
            """Compute the histogram bins."""
            # TODO: It's suboptimal to have bin centers fall at the
            #   min and max. Better to shift them slightly inside.
            if mask is not None:
                # we set masked values to nan so that we can exclude them when
                # computing min/max
                val_nan = torch.as_tensor(nan, dtype=x.dtype, device=x.device)
                x = torch.where(mask, val_nan, x)
                min_fn = nanmin
                max_fn = nanmax
            else:
                min_fn = torch.min
                max_fn = torch.max
            min = min_fn(x, dim=-1).values if min is None else min
            min = torch.as_tensor(min, dtype=dtype, device=device)
            min = unsqueeze(min, dim=2, ndim=4 - min.dim())
            # -> shape = [B, C, 1, 1]
            max = max_fn(x, dim=-1).values if max is None else max
            max = torch.as_tensor(max, dtype=dtype, device=device)
            max = unsqueeze(max, dim=2, ndim=4 - max.dim())
            # -> shape = [B, C, 1, 1]
            bins = torch.linspace(0, 1, nbins, dtype=dtype, device=device)
            bins = unsqueeze(bins, dim=0, ndim=3)  # -> [1, 1, 1, nb_bins]
            bins = min + bins * (max - min)  # -> [B, C, 1, nb_bins]
            binwidth = (max - min) / (nbins - 1)  # -> [B, C, 1, 1]
            return bins, binwidth

        # prepare bins
        x_bins, x_binwidth = get_bins(x.detach(), x_min, x_max, x_nbins)
        y_bins, y_binwidth = get_bins(y.detach(), y_min, y_max, y_nbins)

        # we transform our nans into inf so that they get zero-weight
        # in the histogram
        if mask is not None:
            val_inf = torch.as_tensor(inf, dtype=x.dtype, device=x.device)
            x = torch.where(mask, val_inf, x)
            y = torch.where(mask, val_inf, y)

        # compute distances and collapse
        x = x[..., None]  # -> [B, C, N, 1]
        y = y[..., None]  # -> [B, C, N, 1]
        x_var = ((x_fwhm * x_binwidth)**2) / (8 * math.log(2))
        x_var = x_var.clamp(min=eps(x.dtype))
        x = -(x - x_bins).square() / (2 * x_var)
        x = x.exp()
        y_var = ((y_fwhm * y_binwidth)**2) / (8 * math.log(2))
        y_var = y_var.clamp(min=eps(y.dtype))
        y = -(y - y_bins).square() / (2 * y_var)
        y = y.exp()

        # -> [B, C, N, nb_bins]

        def pnorm(x, dims=-1):
            """Normalize a tensor so that it's sum across `dims` is one."""
            dims = make_list(dims)
            x = x.clamp(min=eps(x.dtype))
            x = x / nansum(x, dim=dims, keepdim=True)
            return x

        # compute probabilities
        p_x = pnorm(x.sum(dim=2))  # -> [B, C, nb_bins]
        p_y = pnorm(y.sum(dim=2))  # -> [B, C, nb_bins]
        x = x.transpose(-1, -2)  # -> [B, C, nb_bins, N]
        p_xy = torch.matmul(x, y)  # -> [B, C, nb_bins, nb_bins]
        p_xy = pnorm(p_xy, [-1, -2])

        # compute entropies
        h_x = -(p_x * p_x.log()).sum(dim=-1)  # -> [B, C]
        h_y = -(p_y * p_y.log()).sum(dim=-1)  # -> [B, C]
        h_xy = -(p_xy * p_xy.log()).sum(dim=[-1, -2])  # -> [B, C]

        # negative mutual information
        mi = h_xy - (h_x + h_y)

        # normalize
        if normalize not in (None, 'none'):
            normalize = (lambda a, b: (a+b)/2) if normalize == 'arithmetic' else \
                        (lambda a, b: (a*b).sqrt()) if normalize == 'geometric' else \
                        torch.min if normalize == 'min' else \
                        torch.max if normalize == 'max' else \
                        normalize
            mi = mi / normalize(h_x, h_y)
            mi += 1

        # reduce
        return super().forward(mi)
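
A torch-only sketch of the patch extraction used above (illustrative sizes):
`unfold` replaces each spatial dimension by (number of patches, patch size),
and the patches are then folded into the channel dimension, so the histogram,
and hence the mutual information, is computed locally per patch.

x = torch.rand(2, 32, 32)                  # (batch, H, W), channel dim dropped
p = x.unfold(1, 8, 8).unfold(2, 8, 8)      # -> (2, 4, 4, 8, 8)
p = p.reshape(2, -1, 8, 8)                 # -> (batch, n_patches, 8, 8)
p = p.reshape(2, p.shape[1], -1)           # -> (batch, n_patches, voxels_per_patch)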