Beispiel #1
0
import torch
from nitorch import spatial
from nitorch.core import utils, math, py, constants, linalg
from nitorch.core.utils import _hist_to_quantile
from nitorch.core.optionals import try_import

plt = try_import('matplotlib.pyplot', _as=True)
mcolors = try_import('matplotlib.colors', _as=True)
from warnings import warn


def _get_colormap_cat(colormap, nb_classes, dtype=None, device=None):
    """Build a table of RGB colors for categorical (label) data.

    Parameters
    ----------
    colormap : None or str or matplotlib.colors.Colormap or array_like
        - None: a default matplotlib map is picked
          ('tab10' up to 10 classes, 'tab20' otherwise).
        - str: name of a matplotlib colormap.
        - Colormap: evaluated at `nb_classes` evenly spaced positions
          in [0, 1]; only the first three components of each color
          are kept.
        - anything else: converted as-is with `torch.as_tensor`.
    nb_classes : int
        Number of classes to color.
    dtype : torch.dtype, optional
        Data type of the returned tensor.
    device : torch.device, optional
        Device of the returned tensor.

    Returns
    -------
    colormap : tensor
        Color table. When built from a matplotlib colormap, it holds
        one 3-component color per class.

    Raises
    ------
    ImportError
        If a matplotlib colormap is required (`colormap` is None or a
        name) but matplotlib could not be imported.

    """
    if colormap is None or isinstance(colormap, str):
        # Both cases need matplotlib to resolve a colormap by name.
        if not plt:
            raise ImportError('Matplotlib not available')
        if colormap is None:
            if nb_classes > 20:
                warn('More than 20 classes: multiple classes will share '
                     'the same color.')
            colormap = 'tab10' if nb_classes <= 10 else 'tab20'
        colormap = plt.get_cmap(colormap)
    # `mcolors` is falsy when matplotlib is missing; in that case the
    # input cannot be a Colormap and is passed straight to torch.
    if mcolors and isinstance(colormap, mcolors.Colormap):
        # Guard the denominator so that a single class does not divide
        # by zero (it is then sampled at position 0).
        step = max(nb_classes - 1, 1)
        colormap = [colormap(i / step)[:3] for i in range(nb_classes)]
    colormap = torch.as_tensor(colormap, dtype=dtype, device=device)
    return colormap
Beispiel #2
0
"""Module for affine image registration.

"""

import numpy as np
import torch
from nitorch.plot import show_slices
from nitorch.core.datasets import fetch_data
from nitorch.spatial import (affine_basis, voxel_size)
from nitorch.core.linalg import expm
from ._costs import (_costs_hist, _compute_cost)
from ._core import (_data_loader, _get_dat_grid, _get_mean_space, _fit_q)
from .._preproc_utils import _format_input
# Try import matplotlib.pyplot
from nitorch.core.optionals import try_import
plt = try_import('matplotlib.pyplot', _as=True)


def _affine_align(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
                  samp=(3, 1.5), optimiser='powell', fix=0, verbose=False,
                  fov=None, mx_int=1023, raw=False, jitter=False, fwhm=7.0):
    """Affine registration of a collection of images.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes.
    mat : [N, ...], tensor_like
        List of affine matrices.
    cost_fun : str, default='nmi'
        Pairwise methods:
Beispiel #3
0
import pytest
import torch
from nitorch.core.optionals import try_import
from nitorch.core import fft as nifft
pyfft = try_import('torch.fft', _as=True)

# Feature flags copied from `nitorch.core.fft` at import time.
# `test_fft` overwrites some of these attributes on the `nifft` module,
# so the copies below keep the values that were detected originally.
_torch_has_old_fft = nifft._torch_has_old_fft
_torch_has_complex = nifft._torch_has_complex
_torch_has_fft_module = nifft._torch_has_fft_module
_torch_has_fftshift = nifft._torch_has_fftshift

# Normalization modes that `test_fft` is parametrized over.
norms = ('forward', 'backward', 'ortho')
# NOTE(review): presumably the numbers of transformed dimensions used by
# tests that are not visible here -- confirm.
ndims = (1, 2, 3, 4)


@pytest.mark.parametrize('norm', norms)
def test_fft(norm):
    """Compare `nifft.fft` on a two-channel real tensor with `torch.fft.fft`.

    The input stores real and imaginary parts in its last dimension
    (size 2). The reference is computed by `torch.fft.fft` on the
    equivalent complex tensor.

    Parameters
    ----------
    norm : {'forward', 'backward', 'ortho'}
        Normalization mode forwarded to both implementations.

    """
    # The comparison is skipped unless both flags are set.
    if not _torch_has_fft_module or not _torch_has_old_fft:
        return True

    # Switch the feature flags off on the nitorch module.
    # NOTE(review): presumably this forces `nifft.fft` onto its legacy
    # (real, two-channel) code path -- confirm against nitorch.core.fft.
    nifft._torch_has_complex = False
    nifft._torch_has_fft_module = False
    nifft._torch_has_fftshift = False

    x = torch.randn([16, 32, 2], dtype=torch.double)
    f1 = pyfft.fft(torch.complex(x[..., 0], x[..., 1]), norm=norm)
    f2 = nifft.fft(x, norm=norm)
    f2 = torch.complex(f2[..., 0], f2[..., 1])
    assert torch.allclose(f1, f2)

    x = torch.randn([16, 32, 2], dtype=torch.double)
Beispiel #4
0
"""
FreeSurfer color table
"""
import os
from nitorch.core.optionals import try_import
wget = try_import('wget')
appdirs = try_import('appdirs')


def wget_check():
    """Return the optional `wget` module.

    Raises
    ------
    ImportError
        If `wget` could not be imported.

    """
    if wget:
        return wget
    raise ImportError('wget needed to download dataset')


# Location of the FreeSurfer color lookup table (raw file on GitHub).
_fs_lut_url = 'https://raw.githubusercontent.com/freesurfer/freesurfer/dev/freeview/resource/FreeSurferColorLUT.txt'
# Directory for local data: the per-user cache directory reported by
# `appdirs` for 'nitorch' when `appdirs` is available, otherwise the
# current working directory.
_datadir = appdirs.user_cache_dir('nitorch') if appdirs else '.'


def read_lut_fs(fname=None):
    """Read a freesurfer lookup table

    Parameters
    ----------
    fname : str, optional
        Path to the LUT file.
        By default, the full FreeSurfer LUT is used (FreeSurferColorLUT.txt)

    Returns
    -------
    lut : dict[int -> str]
Beispiel #5
0
import torch
from nitorch.core.optionals import try_import
from .. import generators, preproc
import os
import math
import gzip
import random
wget = try_import('wget')
appdirs = try_import('appdirs')
np = try_import('numpy')


def wget_check():
    """Return the optional `wget` module.

    Raises
    ------
    ImportError
        If `wget` could not be imported.

    """
    if wget:
        return wget
    raise ImportError('wget needed to download dataset')


class MNIST:
    url_base = 'http://yann.lecun.com/exdb/mnist/'
    base_train_images = 'train-images-idx3-ubyte.gz'
    base_train_labels = 'train-labels-idx1-ubyte.gz'
    base_test_images = 't10k-images-idx3-ubyte.gz'
    base_test_labels = 't10k-labels-idx1-ubyte.gz'
    datadir = appdirs.user_cache_dir('nitorch') if appdirs else '.'

    def _load1(self, base, magic=None):
        fname = os.path.join(self.datadir, base)
        if not os.path.exists(fname):
            os.makedirs(self.datadir, exist_ok=True)
            wget_check().download(self.url_base + base, fname)
Beispiel #6
0
import math
import time
import torch
from nitorch import io
from nitorch import spatial
from nitorch.core import utils, py
from nitorch.core.optionals import try_import
from .volumes import show_orthogonal_slices
# from .menu import Menu, MenuItem

# optional imports
plt = try_import('matplotlib.pyplot', _as=True)
gridspec = try_import('matplotlib.gridspec', _as=True)

__all__ = ['ImageViewer']


def ordered_set(*values):
    """Return the distinct `values` as a tuple, in order of first appearance.

    Values must be hashable; later duplicates of an earlier value are
    dropped.
    """
    return tuple(dict.fromkeys(values))


class ImageArtist:
    def __init__(self, image, parent=None, **kwargs):

        self.parent = parent
        self.show_cursor = kwargs.pop(
            'show_cursor', getattr(self.parent, 'show_cursor', True))
        self.equalize = kwargs.pop('equalize',
                                   getattr(self.parent, 'equalize', False))
        self.mode = kwargs.pop('mode', getattr(self.parent, 'mode',
                                               'intensity'))
Beispiel #7
0
from nitorch.core import math, utils
from nitorch.core.optionals import try_import

plt = try_import('matplotlib.pyplot')


def mov2fix(fixed, moving, warped, vel=None, cat=False, dim=None, title=None):
    """Plot registration live"""

    if plt is None:
        return

    warped = warped.detach()
    if vel is not None:
        vel = vel.detach()

    dim = dim or (fixed.dim() - 1)
    if fixed.dim() < dim + 2:
        fixed = fixed[None]
    if moving.dim() < dim + 2:
        moving = moving[None]
    if warped.dim() < dim + 2:
        warped = warped[None]
    if vel is not None:
        if vel.dim() < dim + 2:
            vel = vel[None]
    nb_channels = fixed.shape[-dim - 1]
    nb_batch = len(fixed)

    if dim == 3:
        fixed = fixed[..., fixed.shape[-1] // 2]
Beispiel #8
0
"""Utilities to build hyper-networks"""
from typing import Sequence, Optional
import copy
import torch
import torch.nn as tnn
from nitorch.core import py
from ..base import Module
from .conv import ActivationLike, NormalizationLike
from .linear import LinearBlock
from nitorch.core.optionals import try_import
functorch = try_import('functorch')


class HyperNet(Module):
    """
    Generic hypernetwork.
    An hyper-network is a network whose weights are generated dynamically
    by a meta-network from a set of input features.
    Its forward pass is: HyperNet(x, feat) = SubNet(MetaNet(feat))(x)
    """

    # TODO: we maybe want to make it easier for people to build
    #   specializations of this class where not all weights are
    #   instantiated by the hyper network, but are trainable instead.
    #   We could define a filter function that selects which submodules
    #   of the main network have hyper-weights. Or have a `parameters`
    #   argument (like in optimizers) that let the user specify
    #   which parameters are dynamic.

    def __init__(self,
                 in_features: int,
Beispiel #9
0
"""Plotting utilities for multi-dimensional tensors.

TODO
* Real-time plotting slows down with number of calls to show_slices!

"""

import torch
from nitorch import spatial
from nitorch.core import utils, py, math, linalg
from nitorch.core.optionals import try_import
# Try import matplotlib
plt = try_import('matplotlib.pyplot', _as=True)
gridspec = try_import('matplotlib.gridspec', _as=True)
make_axes_locatable = try_import(
    'mpl_toolkits.axes_grid1', keys='make_axes_locatable', _as=False)
from . import colormaps as cmap


def get_slice(image, dim=-1, index=None):
    """Extract a 2d slice from a 3d volume

    Parameters
    ----------
    image : (..., *shape3) tensor
        A (batch of) 3d volume
    dim : int, default=-1
        Index of the spatial dimension to slice
    index : int, default=shape//2
        Coordinate (in voxel) of the slice to extract