Exemple #1
0
def geodesic_dist(x, w, conn=1, nb_iter=1, dim=None):
    """Geodesic distance to a label

    Distances start at 0 on the voxels where `x` is nonzero (or True) and
    are propagated to neighbouring voxels for `nb_iter` sweeps. The cost of
    reaching a voxel from a neighbour is
    `|w[neighbour] - w[voxel]| * (dist[neighbour] + step_length)`,
    where `step_length` is the Euclidean length of the neighbour offset.

    Parameters
    ----------
    x : (..., *spatial) tensor
        Label map; nonzero voxels are the sources.
    w : (..., *spatial) tensor
        Weight map. Its dtype/device are used for the output.
    conn : int, default=1
        Connectivity order: only neighbours whose offset has at most
        `conn` nonzero coordinates are used.
    nb_iter : int, default=1
        Number of propagation sweeps.
    dim : int, default=`x.dim()`
        Number of spatial dimensions.

    Returns
    -------
    y : (..., *spatial) tensor
        Distance map. Voxels that were never reached are set to the
        largest finite distance; if no voxel has a finite distance the map
        is returned as-is (all inf).

    """
    if x.dtype is not torch.bool:
        x = x > 0
    dim = dim or x.dim()

    d = torch.full(x.shape, float('inf'), **utils.backend(w))
    d[x] = 0
    # Interior voxels, i.e. the centres of the 3**dim windows below.
    crop = (Ellipsis, *([slice(1, -1)] * dim))
    # NOTE: `dcrop` is assumed to be a view of `d` (sliding windows), so it
    # sees the distances updated in-place during the sweeps.
    dcrop = utils.unfold(d, [3] * dim, stride=1)
    w = utils.unfold(w, [3] * dim, stride=1)
    # Weight at the centre of each window (does not change across sweeps).
    w0 = w[(Ellipsis, *([1] * dim))]

    for _ in range(nb_iter):
        for coord in itertools.product([0, 1], repeat=dim):
            if sum(coord) == 0 or sum(coord) > conn:
                continue
            mini_dist = sum(c * c for c in coord) ** 0.5
            seen = set()
            for sgn in itertools.product([-1, 1], repeat=dim):
                neighbour = tuple(1 + c * s for c, s in zip(coord, sgn))
                # Zero coordinates make several sign patterns identical.
                if neighbour in seen:
                    continue
                seen.add(neighbour)
                neighbour = (Ellipsis, *neighbour)
                new_dist = (w[neighbour] - w0).abs() * (dcrop[neighbour] + mini_dist)
                # Map NaN/-inf (e.g. 0 * inf) to +inf so they never win.
                new_dist.masked_fill_(torch.isfinite(new_dist).bitwise_not_(), float('inf'))
                msk = new_dist < d[crop]
                d[crop][msk] = new_dist[msk]

    unreached = torch.isfinite(d).bitwise_not_()
    # `max()` of an empty tensor raises: only fill when something was reached.
    if not unreached.all():
        d[unreached] = d[~unreached].max()
    return d
Exemple #2
0
    def forward(self, q, k, v, **overload):
        """Local attention: each query attends to a window of keys/values.

        Parameters
        ----------
        q : (b, c, *spatial)
            Queries
        k : (b, c, *spatial)
            Keys
        v : (b, c, *spatial)
            Values
        overload : dict
            Optional `kernel_size`, `stride`, `padding`, `padding_mode`
            overriding the values stored on the module.

        Returns
        -------
        x : (b, c, *spatial)

        """
        kernel_size = overload.pop('kernel_size', self.kernel_size)
        # NOTE(review): the stride defaults to `self.kernel_size`, not to a
        # stride attribute. Unless the stride is 1, the unfolded keys have
        # fewer windows than there are queries, and the dot product below
        # cannot broadcast. Confirm this default is intended.
        stride = overload.pop('stride', self.kernel_size)
        padding = overload.pop('padding', self.padding)
        padding_mode = overload.pop('padding_mode', self.padding_mode)

        nb_dim = q.dim() - 2
        if padding == 'auto':
            keys, values = (spatial.pad_same(nb_dim, t, kernel_size, bound=padding_mode)
                            for t in (k, v))
        elif padding:
            pad_size = [0, 0] + py.make_list(padding, nb_dim)
            keys, values = (utils.pad(t, pad_size, side='both', mode=padding_mode)
                            for t in (k, v))
        else:
            keys, values = k, v

        kernel_size = py.make_list(kernel_size, nb_dim)

        def windows(t):
            # (b, c, *out_spatial, *kernel) -> (b, c, *out_spatial, K)
            t = utils.unfold(t, kernel_size, stride)
            return t.reshape([*t.shape[:nb_dim + 2], -1])

        # attention weights: softmax over the window of <key, query>
        keys = utils.movedim(windows(keys), 1, -1)
        query = utils.movedim(q[..., None], 1, -1)
        weights = math.softmax(linalg.dot(keys, query), dim=-1)[:, None]

        # weighted sum of the values in each window
        return linalg.dot(weights, windows(values))
Exemple #3
0
 def forward(self, x):
     """Randomly shuffle non-overlapping patches of `x`.

     The spatial dimensions are padded (with `utils.ensure_shape`) up to the
     next multiple of `self.kernel`. The result is cut into patches of size
     `self.kernel`, the patches are randomly permuted, and the output is
     cropped back to the input shape.

     Parameters
     ----------
     x : (batch, channel, *spatial) tensor

     Returns
     -------
     x : (batch, channel, *spatial) tensor
     """
     shape = x.shape[2:]
     dim = len(shape)
     # Round each size up to a multiple of the patch size. Using
     # `s + (k - s % k)` would append a whole padding-only patch when `s` is
     # already a multiple of `k`. That patch could be shuffled into the
     # visible region while a real patch gets cropped away.
     pshape = [(s + k - 1) // k * k for s, k in zip(shape, self.kernel)]
     x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(pshape))
     x = utils.unfold(x, self.kernel, collapse=True)
     # dim 2 indexes the (collapsed) patches
     x = x[:, :, torch.randperm(x.shape[2])]
     x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True, shape=pshape)
     x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(shape))
     return x
Exemple #4
0
 def forward(self, x):
     """Randomly swap pairs of non-overlapping patches of `x`.

     The spatial dimensions are padded (with `utils.ensure_shape`) up to the
     next multiple of `self.kernel`, then cut into patches. `self.nb_swap`
     random pairs of patches are swapped, and the result is cropped back to
     the input shape.

     Parameters
     ----------
     x : (batch, channel, *spatial) tensor

     Returns
     -------
     x : (batch, channel, *spatial) tensor
     """
     shape = x.shape[2:]
     dim = len(shape)
     # Round up to a multiple of the patch size (no extra padding-only patch).
     pshape = [(s + k - 1) // k * k for s, k in zip(shape, self.kernel)]
     x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(pshape))
     # Work on a copy: when no padding is needed, `ensure_shape` may hand
     # back the caller's data, and the swaps below are in-place.
     x = utils.unfold(x, self.kernel, collapse=True).clone()
     nb_patches = x.shape[2]
     for _ in range(self.nb_swap):
         # `high` is exclusive, so `nb_patches` makes every patch eligible.
         i1, i2 = torch.randint(low=0, high=nb_patches, size=(2,)).tolist()
         # Advanced indexing on the right-hand side copies, so this is a true
         # swap. Tuple-assigning two basic-index views would copy patch i2
         # into i1 and then write that same data back into i2.
         x[:, :, [i1, i2]] = x[:, :, [i2, i1]]
     x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True, shape=pshape)
     x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(shape))
     return x
Exemple #5
0
def kernel_apply(kspace, patterns, kernel_size, kernels, inplace=False):
    """Apply GRAPPA kernels to an accelerated k-space

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    kspace : ([*batch], coils, *freq) tensor
        Accelerated k-space
    patterns : (*freq) tensor[long]
        Code of sampling pattern about each k-space location
    kernel_size : [sequence of] int
        GRAPPA kernel size. k-space is padded by `(k - 1) // 2` on both
        sides before extracting windows.
        NOTE(review): with even sizes the window grid is one voxel smaller
        than `*freq` (assuming `utils.unfold` itself does not pad). Confirm
        whether even sizes are supported.
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        Dictionary of GRAPPA kernels (keys are pattern codes)
    inplace : bool, default=False
        If True, write the result into `kspace`; otherwise into a copy.

    Returns
    -------
    kspace : ([*batch], coils, *freq) tensor
        k-space in which every location whose pattern code has a kernel
        is replaced by that kernel's prediction.

    """
    ndim = patterns.dim()
    coils = kspace.shape[-ndim - 1]
    batch = kspace.shape[:-ndim - 1]
    kernel_size = py.make_list(kernel_size, ndim)

    kspace_out = kspace if inplace else kspace.clone()
    # Pad so that every location has a full neighbourhood, then extract one
    # window per location: [*B, C, *freq, *K]. Padding copies, so reading
    # from `kspace` below is unaffected by writes into `kspace_out`.
    kspace = utils.pad(kspace, [(k - 1) // 2 for k in kernel_size],
                       side='both')
    kspace = utils.unfold(kspace, kernel_size, stride=1)
    # The mask indexes the `ndim` frequency dims; keep all `ndim` kernel dims.
    kernel_dims = (slice(None),) * ndim

    for code, kernel in kernels.items():
        pattern = code_to_pattern(code, kernel_size, device=kspace.device)
        pattern_size = int(pattern.sum())
        mask = patterns == code
        source = kspace[(Ellipsis, mask, *kernel_dims)][..., pattern]  # [*B, C, M, P]
        source = source.transpose(-2, -3) \
                       .reshape([*batch, -1, coils * pattern_size])  # [*B, M, C*P]
        kernel = kernel.reshape([*batch, coils, coils * pattern_size])  # [*B, C, C*P]
        target = source.matmul(kernel.transpose(-1, -2))  # [*B, M, C]
        kspace_out[..., mask] = target.transpose(-1, -2)  # [*B, C, M]

    return kspace_out
Exemple #6
0
    def forward(self, x):
        """Randomly shuffle patches of a randomly sampled size.

        One patch size per spatial dimension is drawn from
        `self.kernel(kernel_exp, kernel_scale).sample()` and clamped to
        `[4, size of that dimension]`. If the dimension is smaller than 4,
        `torch.clamp` returns the max, i.e. the dimension size. The input is
        then padded (with `utils.ensure_shape`) to a multiple of the patch
        size and cut into non-overlapping patches. The patches are randomly
        permuted and the result is cropped back to the input shape.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        x : (batch, channel, *spatial) tensor
        """
        dim = x.dim() - 2
        backend = utils.backend(x)
        kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
        kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)

        kernel = [self.kernel(k_e, k_s).sample()
                  for k_e, k_s in zip(kernel_exp, kernel_scale)]
        shape = x.shape[2:]
        kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                  for i, k in enumerate(kernel)]
        # Round up to a multiple of the patch size. `s + (k - s % k)` would
        # add a whole padding-only patch when `s` is already a multiple of
        # `k`, and that patch could be shuffled into the visible region.
        pshape = [(s + k - 1) // k * k for s, k in zip(shape, kernel)]
        x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(pshape))
        x = utils.unfold(x, kernel, collapse=True)
        # dim 2 indexes the (collapsed) patches
        x = x[:, :, torch.randperm(x.shape[2])]
        x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
        x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(shape))
        return x
Exemple #7
0
    def forward(self, x):
        """Zero out one random patch, `self.nb_drop` times.

        At each repetition, a patch size per spatial dimension is drawn from
        `self.kernel(kernel_exp, kernel_scale).sample()` and clamped to
        `[4, size of that dimension]`. The input is padded (with
        `utils.ensure_shape`) to a multiple of the patch size and cut into
        non-overlapping patches. One random patch is set to 0 and the result
        is cropped back to the input shape.

        Parameters
        ----------
        x : (batch, channel, *spatial) tensor

        Returns
        -------
        x : (batch, channel, *spatial) tensor
        """
        dim = x.dim() - 2
        backend = utils.backend(x)
        kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
        kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)

        shape = x.shape[2:]
        for _ in range(self.nb_drop):
            kernel = [self.kernel(k_e, k_s).sample()
                      for k_e, k_s in zip(kernel_exp, kernel_scale)]
            kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                      for i, k in enumerate(kernel)]
            # Round up to a multiple of the patch size (no padding-only patch).
            pshape = [(s + k - 1) // k * k for s, k in zip(shape, kernel)]
            x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(pshape))
            # Work on a copy: when no padding is needed, `ensure_shape` may
            # hand back the caller's data, and the drop below is in-place.
            x = utils.unfold(x, kernel, collapse=True).clone()
            # `high` is exclusive: use the patch count so every patch (including
            # the last one, or the only one) can be dropped.
            i1 = torch.randint(low=0, high=x.shape[2], size=(1,)).item()
            x[:, :, i1] = 0
            x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
            x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(shape))
        return x
Exemple #8
0
def get_pattern_codes(sampling_mask, kernel_size):
    """Compute, for every k-space location, the code of its sampling pattern.

    The sampling mask is padded by `(k - 1) // 2` on each side and a
    `kernel_size` window is extracted around every location. Each window
    is then converted to an integer code by `pattern_to_code`.

    Parameters
    ----------
    sampling_mask : (*freq) tensor[bool]
    kernel_size : [sequence of] int

    Returns
    -------
    pattern_mask : (*freq) tensor[long]

    """
    nb_dim = sampling_mask.dim()
    kernel_size = py.make_list(kernel_size, nb_dim)
    half_width = [(k - 1) // 2 for k in kernel_size]

    padded = utils.pad(sampling_mask.long(), half_width, side='both')
    windows = utils.unfold(padded, kernel_size, stride=1)
    return pattern_to_code(windows, nb_dim)
Exemple #9
0
def get_patches(volume, patch=3, stride=1):
    """Extract patches from an image/volume.

    Parameters
    ----------
    volume : (batch, *shape) tensor
        Every dimension after the first is treated as spatial.
    patch : [sequence of] int, default=3
        Patch size along each spatial dimension.
    stride : [sequence of] int, default=1
        Step between consecutive patches.

    Returns
    -------
    patched_volume : (nb_patches, batch, *patch_shape) tensor

    """
    dim = len(volume.shape) - 1
    patch = utils.make_list(patch, dim)

    # 4th positional argument: collapse the patch grid into a single
    # dimension, which the transpose then moves to the front.
    volume = utils.unfold(volume, patch, stride, True)
    return volume.transpose(0, 1)
Exemple #10
0
    def forward(self, x, output_padding=None, output_shape=None):
        """Upsample `x` by writing it into a larger zero tensor.

        The output shape is `self.shape(x, output_padding, output_shape)`.
        If `self.fill` is true, each input voxel is broadcast into a whole
        `stride`-sized block of the output. Otherwise input voxels are
        written every `stride` voxels starting at `offset`, and the rest of
        the output stays zero.

        Parameters
        ----------
        x : (batch, channel, *in_spatial) tensor
        output_padding : [sequence of] int, default=self.output_padding
        output_shape : [sequence of] int, default=self.output_shape

        Returns
        -------
        x : (batch, channel, *out_spatial) tensor

        """
        dim = x.dim() - 2
        offset = py.make_list(self.offset, dim)
        stride = py.make_list(self.stride, dim)

        new_shape = self.shape(x, output_padding=output_padding,
                               output_shape=output_shape)
        y = x.new_zeros(new_shape)
        if self.fill:
            # Blocks of `stride` voxels: (b, c, *grid, *stride).
            # NOTE(review): relies on `utils.unfold` defaulting its stride to
            # the kernel size and returning a view of `y`, so that the
            # `copy_` below writes into `y` — confirm.
            z = utils.unfold(y, stride)
            # Trailing singleton dims so each voxel broadcasts over its block.
            x = utils.unsqueeze(x, -1, dim)
            # NOTE(review): `o` and `sz*st` look like voxel units but are
            # applied to the block-grid dims of `z`; this matches only
            # when offset == 0. Confirm intended offset handling.
            slicer = [slice(o, o+sz*st) for sz, st, o in
                      zip(x.shape[2:], stride, offset)]
            slicer = [slice(None)]*2 + slicer
            subz = z[tuple(slicer)]
            # Crop x to the region that fits in the output.
            slicer = [slice(mx) for mx in subz.shape[2:]]
            slicer = [slice(None)]*2 + slicer
            subz.copy_(x[tuple(slicer)])
        else:
            # Strided view of y: every `s`-th voxel starting at `o`.
            slicer = [slice(o, None, s) for o, s in zip(offset, stride)]
            slicer = [slice(None)]*2 + slicer
            suby = y[tuple(slicer)]
            # Crop x to the number of strided positions available in y.
            slicer = [slice(mx) for mx in suby.shape[2:]]
            slicer = [slice(None)]*2 + slicer
            suby.copy_(x[tuple(slicer)])

        return y
Exemple #11
0
 def forward(self, x, model, **fwdargs):
     """Apply `model` patch-wise to `x` and fold the outputs back together.

     The spatial dimensions are padded (with `utils.ensure_shape`) to the
     smallest size that is at least the input size and is tiled exactly by
     patches of size `self.patch_size` taken every `self.stride` voxels.
     `model` is called on each patch separately. The outputs are folded
     back with `reduction=self.reduction`, and the result is cropped to the
     input shape.

     Parameters
     ----------
     x : (batch, channel, *spatial) tensor
     model : callable
         Called as `model(patch, **fwdargs)` on each
         `(batch, channel, *patch_size)` patch.
     fwdargs : dict
         Extra keyword arguments forwarded to `model`.

     Returns
     -------
     x : (batch, channel_out, *spatial) tensor
         `channel_out` is whatever `model` returns.
     """
     shape = x.shape[2:]
     dim = len(shape)
     patch_size = ([self.patch_size] * dim if isinstance(self.patch_size, int)
                   else list(self.patch_size))
     stride = ([self.stride] * dim if isinstance(self.stride, int)
               else list(self.stride))
     # Smallest p >= s such that (p - k) is a multiple of the stride (p == k
     # when s <= k). Then the patches tile p exactly, with no extra patches
     # made mostly of padding being run through the model.
     pshape = [k + (max(s - k, 0) + st - 1) // st * st
               for s, k, st in zip(shape, patch_size, stride)]
     x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(pshape))
     x = utils.unfold(x, kernel_size=patch_size, stride=stride, collapse=True)
     # dim 2 indexes patches: run the model on each patch and stack the
     # outputs back along that dimension.
     x = torch.stack([model(patch, **fwdargs) for patch in x.unbind(2)], dim=2)
     x = utils.fold(x, dim=dim, stride=stride, collapsed=True, shape=pshape,
                    reduction=self.reduction)
     x = utils.ensure_shape(x, tuple(x.shape[:2]) + tuple(shape))
     return x
Exemple #12
0
def kernel_fit(calib, kernel_size, patterns, lam=0.01):
    """Compute GRAPPA kernels

    All batch elements should have the same sampling pattern

    For each pattern, a regularized least-squares problem predicts the
    centre of every calibration window from its sampled neighbours.
    Patterns that already contain the centre, or that are empty, are skipped.

    Parameters
    ----------
    calib : ([*batch], coils, *freq)
        Fully-sampled calibration data
    kernel_size : sequence[int]
        GRAPPA kernel size
    patterns : (N,) tensor[int] or tensor[bool]
        Code of patterns for which to learn a kernel.
        See `pattern_to_code`. A boolean tensor is converted to codes
        with `pattern_to_code`.
    lam : float, default=0.01
        Tikhonov regularization. The Hessian diagonal is loaded with
        `lam * max|diag|` and then with `lam`.

    Returns
    -------
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        GRAPPA kernels

    """
    kernel_size = py.make_list(kernel_size)
    ndim = len(kernel_size)
    coils = calib.shape[-ndim - 1]
    batch = calib.shape[:-ndim - 1]

    # find all possible patterns
    patterns = utils.as_tensor(patterns, device=calib.device)
    if patterns.dtype is torch.bool:
        patterns = pattern_to_code(patterns, ndim)
    patterns = patterns.flatten()

    # extract all calibration windows
    calib = utils.unfold(calib, kernel_size, collapse=True)  # [*B, C, N, *K]
    calib = utils.movedim(calib, -ndim - 1, -ndim - 2)  # [*B, N, C, *K]

    # The target (window centre) is the same for every pattern.
    center = (Ellipsis, *[(k - 1) // 2 for k in kernel_size])
    calib_target = calib[center]  # [*B, N, C]
    calib_size = calib_target.shape[-2]

    def conjt(x):
        return x.transpose(-1, -2).conj()

    def diag(x):
        return x.diagonal(0, -1, -2)

    kernels = {}
    for pattern_code in patterns:
        if code_has_center(pattern_code, kernel_size):
            continue
        pattern = code_to_pattern(pattern_code,
                                  kernel_size,
                                  device=calib.device)
        pattern_size = int(pattern.sum())
        if pattern_size == 0:
            continue

        flat_shape = [*batch, calib_size, pattern_size * coils]
        calib_source = calib[..., pattern].reshape(flat_shape)  # [*B, N, C*P]
        # solve the regularized normal equations
        H = conjt(calib_source).matmul(calib_source)  # [*B, C*P, C*P]
        diag(H).add_(lam * diag(H).abs().max(-1, keepdim=True).values)
        diag(H).add_(lam)
        g = conjt(calib_source).matmul(calib_target)  # [*B, C*P, C]
        k = linalg.lmdiv(H, g).transpose(-1, -2)  # [*B, C, C*P]
        k = k.reshape([*batch, coils, coils, pattern_size])  # [*B, C, C, P]
        kernels[pattern_code.item()] = k

    return kernels
Exemple #13
0
def dilate_likely_voxels(labels, intensity, label=None, nb_iter=1,
                         dist_ratio=1, half_patch=3, conn=1, dim=None):
    """Dilate labels into voxels with a similar intensity.

    Notes
    -----
    .. Voxels get switched if their intensity is closer to the foreground
       intensity than to the background intensity, in terms of Gaussian
       distance (abs(intensity - class_mean)/class_std) computed in a local
       patch.
    .. Adapted from neurite-sandbox (author: B Fischl)
    .. Only voxels whose whole patch fits inside the volume (at least
       `half_patch` voxels away from every border) can be switched.

    Parameters
    ----------
    labels : (..., *spatial) tensor
        Tensor of labels
    intensity : (..., *spatial) tensor
        Tensor of intensities
    label : int, optional
        Label to dilate. Default: binarize input labels.
    nb_iter : int, default=1
        Number of iterations
    dist_ratio : float, default=1
        Value that decides how much closer from the foreground intensity
        than the background intensity a voxel must be to be flipped.
        Smaller == easier to switch.
    half_patch : int, default=3
        Half-size of the window used to compute intensity statistics.
    conn : int, default=1
        Connectivity order
    dim : int, default=`labels.dim()`
        Number of spatial dimensions

    Returns
    -------
    labels : (..., *spatial) tensor
        Dilated labels

    """
    in_dtype = labels.dtype
    foreground = (labels > 0) if label is None else (labels == label)
    dim = dim or labels.dim()

    patch = [2*half_patch + 1] * dim
    unfold = lambda x: utils.unfold(x, patch, stride=1)
    intensity = unfold(intensity)

    # Centre element of each window.
    center = (Ellipsis, *([half_patch]*dim))
    # Full-grid voxels that are window centres. This assumes `utils.unfold`
    # does not pad, so window `i` is centred on voxel `i + half_patch`.
    interior = (Ellipsis, *[slice(half_patch, s - half_patch)
                            for s in foreground.shape[-dim:]])

    def mean_var(intensity, fg):
        """Masked mean and variance over the last `dim` (patch) dimensions."""
        sum_fg = fg.sum(list(range(-dim, 0)))
        mean_fg = (intensity * fg).sum(list(range(-dim, 0)))
        var_fg = (intensity.square() * fg).sum(list(range(-dim, 0)))
        mean_fg /= sum_fg
        var_fg /= sum_fg
        var_fg -= mean_fg.square()
        return mean_fg, var_fg

    if isinstance(conn, int):
        conn = connectivity_kernel(dim, conn, device=foreground.device,
                                   dtype=torch.uint8)

    for _ in range(nb_iter):
        # candidate voxels: reached by one dilation but not yet foreground
        dilated = dilate(foreground, conn=conn, dim=dim)
        dilated = dilated.bitwise_xor_(foreground)

        # extract the patches centred on candidate voxels
        win_dilated = unfold(dilated)
        msk_dilated = win_dilated[center]
        win_dilated = win_dilated[msk_dilated, ...]
        win_intensity = intensity[msk_dilated, ...]
        win_fg = unfold(foreground)[msk_dilated, ...]
        win_bg = ~(win_fg | win_dilated)

        # compute local class statistics
        mean_fg, var_fg = mean_var(win_intensity, win_fg)
        mean_bg, var_bg = mean_var(win_intensity, win_bg)

        # Gaussian distances of the centre intensity to each class. The
        # criterion must compare these distances, not the raw class means.
        center_intensity = win_intensity[center]
        dist_fg = (mean_fg - center_intensity).abs_().div_(var_fg.sqrt_())
        dist_bg = (mean_bg - center_intensity).abs_().div_(var_bg.sqrt_())
        crit = dist_ratio * dist_fg < dist_bg

        # Write back only the centre voxels. Writing whole (overlapping)
        # windows back through the unfolded view could overwrite a freshly
        # switched voxel with a stale value from a neighbouring window.
        new_center = win_fg[center].masked_fill(crit, True)
        foreground[interior][msk_dilated] = new_center

    if label is None:
        labels = foreground.to(in_dtype)
    else:
        labels = labels.clone()
        labels[foreground] = label
    return labels
Exemple #14
0
def extract_patches(inp, size=64, stride=None, output=None, transform=None):
    """Extract patches from a 3D volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, default=64
        Patch size.
    stride : [sequence of] int, default=size
        Stride between patches.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the patches are not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}_{j}_{k}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        and `i`, `j`, `k` are the coordinates (starting at 1) of the patch.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default. Same placeholders as `output`.

    Returns
    -------
    output : list[str] or (tensor, tensor)
        If the input is a path, the output paths are returned.
        Else, the unfolded data and orientation matrices are returned.
            Data will have shape (nx, ny, nz, *size, *channels).
            Affines will have shape (nx, ny, nz, 4, 4).

    """
    import itertools

    dirname = base = ext = fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}_{j}_{k}{ext}'
        dirname, base, ext = py.fileparts(fname)

    dat, aff0 = inp

    shape = dat.shape[:3]
    size = py.make_list(size, 3)
    stride = py.make_list(stride, 3)
    stride = [st or sz for st, sz in zip(stride, size)]

    # (x, y, z, *channels) -> (*channels, nx, ny, nz, *size)
    #                      -> (nx, ny, nz, *size, *channels)
    dat = utils.movedim(dat, [0, 1, 2], [-3, -2, -1])
    dat = utils.unfold(dat, size, stride)
    dat = utils.movedim(dat, [-6, -5, -4, -3, -2, -1], [0, 1, 2, 3, 4, 5])

    # Patch indices (i, j, k), in the same order as nested i/j/k loops.
    grid = list(itertools.product(*(range(n) for n in dat.shape[:3])))

    aff = aff0.new_empty(dat.shape[:3] + aff0.shape)
    for index in grid:
        sub = tuple(slice(st*idx, st*idx + sz)
                    for st, sz, idx in zip(stride, size, index))
        aff[index], _ = spatial.affine_sub(aff0, shape, sub)

    def format_name(template, i, j, k):
        # Patch coordinates are 1-based in filenames, for every input kind.
        return template.format(dir=dirname or '.', base=base, ext=ext,
                               sep=os.path.sep, i=i+1, j=j+1, k=k+1)

    formatted_output = []
    if output:
        # Iterate instead of popping, so a list passed by the caller is
        # never mutated.
        outputs = py.make_list(output, len(grid))
        for index, out1 in zip(grid, outputs):
            out1 = format_name(out1, *index)
            if is_file:
                io.volumes.savef(dat[index], out1, like=fname,
                                 affine=aff[index])
            else:
                io.volumes.savef(dat[index], out1, affine=aff[index])
            formatted_output.append(out1)

    if transform:
        transforms = py.make_list(transform, len(grid))
        for index, trf1 in zip(grid, transforms):
            io.transforms.savef(torch.eye(4), format_name(trf1, *index),
                                source=aff0, target=aff[index])

    if is_file:
        return formatted_output
    else:
        return dat, aff
Exemple #15
0
def mov2fix(fixed, moving, warped, vel=None, cat=False, dim=None, title=None):
    """Plot registration live

    Draws, for up to three batch elements, the moving, moved (warped),
    checkerboard and fixed images. When `vel` is given, the magnitude of
    `vel` is also drawn. 3D inputs are shown through their central slice
    along the last spatial dimension. Does nothing when `plt is None`.

    Parameters
    ----------
    fixed, moving, warped : ([batch], channel, *spatial) tensor
    vel : ([batch], *spatial, C) tensor, optional
        Its norm over the last dimension is plotted.
    cat : bool, default=False
        Apply an implicit softmax along dim 1 to `moving` and `warped`.
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions.
    title : str, optional
        Figure title.
    """
    if plt is None:
        return

    warped = warped.detach()
    if vel is not None:
        vel = vel.detach()

    # add a batch dimension where missing
    dim = dim or (fixed.dim() - 1)
    if fixed.dim() < dim + 2:
        fixed = fixed[None]
    if moving.dim() < dim + 2:
        moving = moving[None]
    if warped.dim() < dim + 2:
        warped = warped[None]
    if vel is not None and vel.dim() < dim + 2:
        vel = vel[None]
    nb_batch = len(fixed)

    if dim == 3:
        fixed = fixed[..., fixed.shape[-1] // 2]
        moving = moving[..., moving.shape[-1] // 2]
        warped = warped[..., warped.shape[-1] // 2]
        if vel is not None:
            vel = vel[..., vel.shape[-2] // 2, :]
    if vel is not None:
        vel = vel.square().sum(-1).sqrt()

    if cat:
        moving = math.softmax(moving, dim=1, implicit=True)
        warped = math.softmax(warped, dim=1, implicit=True)

    # Checkerboard: copy `warped` into the patches of a copy of `fixed` that
    # sit at even positions along both in-plane axes (patch p, stride 2p).
    # The patch size comes from the two in-plane dims only (not batch or
    # channel) and is clamped to [1, dim size] so small images still work.
    # NOTE: relies on `utils.unfold` returning views so `copy_` writes into
    # `checker`.
    checker = fixed.clone()
    plane = checker.shape[-2:]
    patch = max(max(s // 8 for s in plane), 1)
    patch = [min(patch, s) for s in plane]
    stride = [2 * p for p in patch]
    checker_unfold = utils.unfold(checker, patch, stride)
    warped_unfold = utils.unfold(warped, patch, stride)
    checker_unfold.copy_(warped_unfold)

    nb_rows = min(nb_batch, 3)
    nb_cols = 4 + (vel is not None)
    for b in range(nb_rows):
        plt.subplot(nb_rows, nb_cols, b * nb_cols + 1)
        plt.imshow(moving[b, 0].cpu())
        plt.title('moving')
        plt.axis('off')
        plt.subplot(nb_rows, nb_cols, b * nb_cols + 2)
        plt.imshow(warped[b, 0].cpu())
        plt.title('moved')
        plt.axis('off')
        plt.subplot(nb_rows, nb_cols, b * nb_cols + 3)
        plt.imshow(checker[b, 0].cpu())
        plt.title('checker')
        plt.axis('off')
        plt.subplot(nb_rows, nb_cols, b * nb_cols + 4)
        plt.imshow(fixed[b, 0].cpu())
        plt.title('fixed')
        plt.axis('off')
        if vel is not None:
            plt.subplot(nb_rows, nb_cols, b * nb_cols + 5)
            plt.imshow(vel[b].cpu())
            plt.title('velocity')
            plt.axis('off')
            plt.colorbar()
    if title:
        plt.suptitle(title)
    plt.gcf().canvas.flush_events()
    plt.show(block=False)
Exemple #16
0
    def forward(self, x, y, **overload):
        """Negative mutual information between two single-channel images.

        A joint intensity histogram of `x` and `y` is estimated with
        `joint_hist_gaussian` when `order == 'inf'` and with
        `joint_hist_spline` otherwise. It is normalized into joint and
        marginal probabilities, which are turned into entropies. The value
        `H(x,y) - H(x) - H(y)` (minus the mutual information) is optionally
        normalized, then passed to the parent class's `forward`.

        Parameters
        ----------
        x : tensor (batch, 1, *spatial)
        y : tensor (batch, 1, *spatial)
        overload : dict
            All parameters defined at build time can be overridden
            at call time (`min_val`, `max_val`, `nb_bins`, `fwhm`, `order`,
            `normalize`, `patch_size`, `patch_stride`, `mask`).

        Returns
        -------
        loss : scalar or tensor
            The output shape depends on the type of reduction used.
            If 'mean' or 'sum', this function returns a scalar.

        Raises
        ------
        ValueError
            If `x` or `y` has more than one channel.

        """
        # check inputs
        x = torch.as_tensor(x)
        y = torch.as_tensor(y)
        nb_dim = x.dim() - 2
        if x.shape[1] != 1 or y.shape[1] != 1:
            raise ValueError('Mutual info is only implemented for '
                             'single channel tensors.')
        shape = x.shape[2:]

        # get parameters
        min_val = overload.get('min_val', self.min_val)
        max_val = overload.get('max_val', self.max_val)
        nb_bins = overload.get('nb_bins', self.nb_bins)
        fwhm = overload.get('fwhm', self.fwhm)
        order = overload.get('order', self.order)
        normalize = overload.get('normalize', self.normalize)
        patch_size = overload.get('patch_size', self.patch_size)
        patch_stride = overload.get('patch_stride', self.patch_stride)
        mask = overload.get('mask', self.mask)

        # reshape
        if patch_size:
            # extract patches about each voxel
            # (a falsy/None size along a dim means "the whole dim")
            patch_size = make_list(patch_size, nb_dim)
            patch_size = [
                min(pch or dim, dim) for pch, dim in zip(patch_size, shape)
            ]
            # The channel dim is dropped; presumably `collapse=True` puts the
            # patch index at dim 1, so each patch gets its own MI value
            # (TODO confirm).
            x = utils.unfold(x[:, 0], patch_size, patch_stride, collapse=True)
            y = utils.unfold(y[:, 0], patch_size, patch_stride, collapse=True)

        # collapse spatial dimensions -> we don't need them anymore
        x = x.reshape((*x.shape[:2], -1))
        y = y.reshape((*y.shape[:2], -1))

        # exclude masked values
        # `mask` is either one spec shared by x and y, or a pair. Each spec
        # is a callable producing a boolean tensor, or a threshold.
        # NOTE(review): with a threshold the mask is True where the value is
        # <= threshold; whether the joint_hist_* helpers treat True as
        # "keep" or "exclude" is not visible here — confirm.
        mask_x, mask_y = make_list(mask, 2)
        mask = None
        if callable(mask_x):
            mask = mask_x(x)
        elif mask_x is not None:
            mask = x <= mask_x
        if callable(mask_y):
            mask = (mask & mask_y(y)) if mask is not None else mask_y(y)
        elif mask_y is not None:
            mask = (mask &
                    (y <= mask_y)) if mask is not None else (y <= mask_y)

        if order == 'inf':
            p_xy = joint_hist_gaussian(x, y, nb_bins, min_val, max_val, fwhm,
                                       mask)
        else:
            p_xy = joint_hist_spline(x, y, nb_bins, min_val, max_val, order,
                                     mask)

        def pnorm(x, dims=-1):
            """Normalize a tensor so that it's sum across `dims` is one."""
            dims = make_list(dims)
            # In-place clamp (mutates the argument) avoids log(0) below.
            x = x.clamp_min_(eps(x.dtype))
            x = x / nansum(x, dim=dims, keepdim=True)
            return x

        # compute probabilities
        # (marginals are taken before `p_xy` is clamped in place by pnorm)
        p_x = pnorm(p_xy.sum(dim=-2))  # -> [B, C, nb_bins]
        p_y = pnorm(p_xy.sum(dim=-1))  # -> [B, C, nb_bins]
        p_xy = pnorm(p_xy, [-1, -2])

        # compute entropies
        h_x = -(p_x * p_x.log()).sum(dim=-1)  # -> [B, C]
        h_y = -(p_y * p_y.log()).sum(dim=-1)  # -> [B, C]
        h_xy = -(p_xy * p_xy.log()).sum(dim=[-1, -2])  # -> [B, C]

        # negative mutual information
        mi = h_xy - (h_x + h_y)

        # normalize
        if normalize == 'studholme':
            # = 2 - (H(x) + H(y)) / H(x,y), i.e. 2 minus Studholme's NMI
            mi = mi / h_xy.clamp_min_(eps(x.dtype))
            mi += 1
        elif normalize not in (None, 'none'):
            # = 1 - MI / norm(H(x), H(y)); `normalize` may also be a callable
            normalize = (lambda a, b: (a+b)/2) if normalize == 'arithmetic' else \
                        (lambda a, b: (a*b).sqrt()) if normalize == 'geometric' else \
                        torch.min if normalize == 'min' else \
                        torch.max if normalize == 'max' else \
                        normalize
            mi = mi / normalize(h_x, h_y).clamp_min_(eps(x.dtype))
            mi += 1

        # reduce
        return super().forward(mi)
Exemple #17
0
    def mov2fix(self, fixed, moving, warped, vel=None, cat=False, dim=None, title=None):
        """Plot registration live

        Shows the moving, moved (warped), checkerboard and fixed images for
        up to three batch elements, plus the displacement magnitude when
        `vel` is given. A bottom row plots `self.all_ll` against the
        iteration. 3D inputs are shown through three orthogonal central
        slices. Calls made less than `1/self.framerate` seconds after the
        previous plot are skipped.

        Parameters
        ----------
        fixed, moving, warped : ([batch], channel, *spatial) tensor
        vel : ([batch], *spatial, C) tensor, optional
            Its norm over the last dimension is plotted.
        cat : bool, default=False
            Apply an implicit softmax along dim 1 to `moving` and `warped`.
        dim : int, default=`fixed.dim() - 1`
            Number of spatial dimensions.
        title : str, optional
            Figure title.
        """

        import time
        tic = self._last_plot
        toc = time.time()
        if toc - tic < 1/self.framerate:
            return
        self._last_plot = toc

        import matplotlib.pyplot as plt

        warped = warped.detach()
        if vel is not None:
            vel = vel.detach()

        # add a batch dimension where missing
        dim = dim or (fixed.dim() - 1)
        if fixed.dim() < dim + 2:
            fixed = fixed[None]
        if moving.dim() < dim + 2:
            moving = moving[None]
        if warped.dim() < dim + 2:
            warped = warped[None]
        if vel is not None:
            if vel.dim() < dim + 2:
                vel = vel[None]
        nb_batch = len(fixed)

        def rescale2d(x):
            """Map the [0.5%, 99.5%] quantiles of each 2D image to [0, 1]."""
            if not x.dtype.is_floating_point:
                x = x.float()
            mn, mx = utils.quantile(x, [0.005, 0.995],
                                    dim=range(-2, 0), bins=1024).unbind(-1)
            mx = mx.max(mn + 1e-8)
            mn, mx = mn[..., None, None], mx[..., None, None]
            x = x.sub(mn).div_(mx-mn).clamp_(0, 1)
            return x

        # Each image becomes a list of 2D views (3 views in 3D, 1 in 2D),
        # each view being a (batch, channel, H, W) tensor.
        if dim == 3:
            fixed = [fixed[..., fixed.shape[-1] // 2],
                     fixed[..., fixed.shape[-2] // 2, :],
                     fixed[..., fixed.shape[-3] // 2, :, :]]
            fixed = [rescale2d(f) for f in fixed]
            moving = [moving[..., moving.shape[-1] // 2],
                      moving[..., moving.shape[-2] // 2, :],
                      moving[..., moving.shape[-3] // 2, :, :]]
            moving = [rescale2d(f) for f in moving]
            warped = [warped[..., warped.shape[-1] // 2],
                      warped[..., warped.shape[-2] // 2, :],
                      warped[..., warped.shape[-3] // 2, :, :]]
            warped = [rescale2d(f) for f in warped]
            if vel is not None:
                vel = [vel[..., vel.shape[-2] // 2, :],
                       vel[..., vel.shape[-3] // 2, :, :],
                       vel[..., vel.shape[-4] // 2, :, :, :]]
                vel = [v.square().sum(-1).sqrt() for v in vel]
        else:
            # Wrap the whole batched tensor. Iterating over it would split
            # the batch dimension and break the `[k][b, 0]` indexing below.
            fixed = [rescale2d(fixed)]
            moving = [rescale2d(moving)]
            warped = [rescale2d(warped)]
            vel = [vel.square().sum(-1).sqrt()] if vel is not None else []

        if cat:
            moving = [math.softmax(img, dim=1, implicit=True) for img in moving]
            warped = [math.softmax(img, dim=1, implicit=True) for img in warped]

        # Checkerboard: copy `warped` into the in-plane patches of a copy of
        # `fixed` at even positions along both axes (patch p, stride 2p).
        # Only the last two (in-plane) dims are unfolded; the patch size is
        # clamped to [1, dim size].
        checker = []
        for f, w in zip(fixed, warped):
            broad_shape = utils.expanded_shape(f.shape, w.shape)
            f = f.expand(broad_shape).clone()
            w = w.expand(broad_shape)
            plane = f.shape[-2:]
            patch = max(max(s // 8 for s in plane), 1)
            patch = [min(patch, s) for s in plane]
            stride = [2*p for p in patch]
            checker_unfold = utils.unfold(f, patch, stride)
            warped_unfold = utils.unfold(w, patch, stride)
            checker_unfold.copy_(warped_unfold)
            checker.append(f)

        kdim = 3 if dim == 3 else 1
        bdim = min(nb_batch, 3)
        nb_rows = kdim * bdim + 1
        nb_cols = 4 + bool(vel)

        # NOTE(review): the figure also holds the NLL axes and, when `vel`
        # is plotted, one colorbar axes per row. So this count may never
        # equal nb_rows*nb_cols, and the blitting branch below may never
        # run. Confirm.
        if len(self.figure.axes) != nb_rows*nb_cols:
            self.figure.clf()

            for b in range(bdim):
                for k in range(kdim):
                    plt.subplot(nb_rows, nb_cols, (b + k*bdim) * nb_cols + 1)
                    plt.imshow(moving[k][b, 0].cpu())
                    if b == 0 and k == 0:
                        plt.title('moving')
                    plt.axis('off')
                    plt.subplot(nb_rows, nb_cols, (b + k*bdim) * nb_cols + 2)
                    plt.imshow(warped[k][b, 0].cpu())
                    if b == 0 and k == 0:
                        plt.title('moved')
                    plt.axis('off')
                    plt.subplot(nb_rows, nb_cols, (b + k*bdim) * nb_cols + 3)
                    plt.imshow(checker[k][b, 0].cpu())
                    if b == 0 and k == 0:
                        plt.title('checker')
                    plt.axis('off')
                    plt.subplot(nb_rows, nb_cols, (b + k*bdim) * nb_cols + 4)
                    plt.imshow(fixed[k][b, 0].cpu())
                    if b == 0 and k == 0:
                        plt.title('fixed')
                    plt.axis('off')
                    if vel:
                        plt.subplot(nb_rows, nb_cols, (b + k*bdim) * nb_cols + 5)
                        plt.imshow(vel[k][b].cpu())
                        if b == 0 and k == 0:
                            plt.title('displacement')
                        plt.axis('off')
                        plt.colorbar()
            plt.subplot(nb_rows, 1, nb_rows)
            plt.plot(list(range(1, len(self.all_ll)+1)), self.all_ll)
            plt.ylabel('NLL')
            plt.xlabel('iteration')
            if title:
                plt.suptitle(title)

            self.figure.canvas.draw()
            self.plt_saved = [self.figure.canvas.copy_from_bbox(ax.bbox)
                              for ax in self.figure.axes]
            self.figure.canvas.flush_events()
            plt.show(block=False)

        else:
            self.figure.canvas.draw()
            for elem in self.plt_saved:
                self.figure.canvas.restore_region(elem)

            for b in range(bdim):
                for k in range(kdim):
                    j = (b + k*bdim) * nb_cols
                    self.figure.axes[j].images[0].set_data(moving[k][b, 0].cpu())
                    self.figure.axes[j+1].images[0].set_data(warped[k][b, 0].cpu())
                    self.figure.axes[j+2].images[0].set_data(checker[k][b, 0].cpu())
                    self.figure.axes[j+3].images[0].set_data(fixed[k][b, 0].cpu())
                    # `vel` may be an empty list (2D, no velocity): test truthiness.
                    if vel:
                        self.figure.axes[j+4].images[0].set_data(vel[k][b].cpu())
            lldata = (list(range(1, len(self.all_ll)+1)), self.all_ll)
            self.figure.axes[-1].lines[0].set_data(lldata)
            if title:
                # Public API; updates the existing suptitle or creates one.
                self.figure.suptitle(title)

            for ax in self.figure.axes:
                # Not every axes holds an image (e.g. the NLL line plot).
                for artist in [*ax.images[:1], *ax.lines[:1]]:
                    ax.draw_artist(artist)
                self.figure.canvas.blit(ax.bbox)
            self.figure.canvas.flush_events()