Example #1
    def shape(self, x):
        """Output shape of the equivalent forward call.

        Parameters
        ----------
        x : (b, c, **spatial) tensor or sequence[int]
            A tensor or its shape.

        Returns
        -------
        shape : tuple[int]
            (b, c, **spatial_out)
            In each dimension, the output shape is `(size-offset)//stride`.

        """
        if torch.is_tensor(x):
            x = x.shape
        x = list(x)

        dim = len(x) - 2
        offset = py.make_list(self.offset, dim)
        stride = py.make_list(self.stride, dim)

        x = x[:2] + [(xx-o)//s for xx, o, s in zip(x[2:], offset, stride)]
        return tuple(x)
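A minimal standalone sketch (not part of nitorch) of the shape arithmetic used above, assuming offset 0 and stride 2 in every spatial dimension:

def subsample_shape(shape, offset=0, stride=2):
    # batch and channel pass through; each spatial dim shrinks to (size - offset) // stride
    b, c, *spatial = shape
    return (b, c, *((s - offset) // stride for s in spatial))

assert subsample_shape((1, 3, 10, 11)) == (1, 3, 5, 5)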
Example #2
    def shape(self, x, output_padding=None, output_shape=None):
        """

        Parameters
        ----------
        x : (batch, channel, *in_spatial) tensor or sequence[int]
        output_padding : [sequence of] int, default=self.output_padding
        output_shape : [sequence of] int, default=self.output_shape

        Returns
        -------
        shape : tuple[int]
            (batch, channel, *out_spatial)

        """
        if torch.is_tensor(x):
            x = x.shape
        x = list(x)

        if output_padding is not None and output_shape is not None:
            raise ValueError('Only one of `output_padding` or `output_shape` '
                             'should be provided.')
        elif output_padding is None and output_shape is None:
            output_padding = self.output_padding
            output_shape = self.output_shape
        dim = len(x) - 2
        offset = py.make_list(self.offset, dim)
        stride = py.make_list(self.stride, dim)
        output_padding = py.make_list(output_padding, dim)

        if output_shape:
            output_shape = py.make_list(output_shape, dim)
        else:
            output_shape = [sz*st + o + p for sz, st, o, p in
                            zip(x[2:], stride, offset, output_padding)]
        return (*x[:2], *output_shape)
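For comparison, a minimal standalone sketch (not tied to the class above) of the default output-shape rule `size*stride + offset + output_padding`:

def transposed_shape(shape, stride=2, offset=0, output_padding=0):
    # inverse of the subsampling rule: each spatial dim grows to size*stride + offset + padding
    b, c, *spatial = shape
    return (b, c, *(s * stride + offset + output_padding for s in spatial))

assert transposed_shape((1, 3, 5, 5)) == (1, 3, 10, 10)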
Example #3
def _patch(patch, affine, shape, level):
    """Compute the patch size in voxels"""
    dim = affine.shape[-1] - 1
    patch = py.make_list(patch)
    unit = 'pct'
    if isinstance(patch[-1], str):
        *patch, unit = patch
    patch = py.make_list(patch, dim)
    unit = unit.lower()
    if unit[0] == 'v':  # voxels
        patch = [float(p) / 2**level for p in patch]
    elif unit in ('m', 'mm', 'cm', 'um'):  # assume RAS orientation
        factor = (1e-3 if unit == 'um' else
                  1e1 if unit == 'cm' else 1e3 if unit == 'm' else 1)
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        patch = [factor * p / v for p, v in zip(patch, vx_ras)]
        patch = _ras_to_layout(patch, affine)
    elif unit[0] in 'p%':  # percentage of shape
        patch = [0.01 * p * s for p, s in zip(patch, shape)]
    else:
        raise ValueError('Unknown patch unit:', unit)

    # round down to zero small patch sizes
    patch = [0 if p < 1e-3 else p for p in patch]
    return patch
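To illustrate the millimetre branch in isolation, a hedged sketch with made-up numbers (0.5 mm isotropic voxels, a 10 mm patch, level 0); the real code additionally reorients the affine to RAS and maps the result back to the native layout:

vx_ras = [0.5, 0.5, 0.5]        # assumed voxel size in mm, RAS order
patch_mm = [10.0, 10.0, 10.0]   # requested patch size in mm
patch_vox = [p / v for p, v in zip(patch_mm, vx_ras)]   # -> [20.0, 20.0, 20.0] voxels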
Example #4
def _get_default_native(affines, shapes):
    """Get default native space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like or None
    shapes : [sequence of] (3,) tensor_like

    Returns
    -------
    affines : (N, 4, 4) tensor
    shapes : (N, 3) tensor

    """
    shapes = utils.as_tensor(shapes).reshape([-1, 3])
    shapes = shapes.unbind(dim=0)
    if torch.is_tensor(affines):
        affines = affines.reshape([-1, 4, 4])
        affines = affines.unbind(dim=0)
    shapes = py.make_list(shapes, max(len(shapes), len(affines)))
    affines = py.make_list(affines, max(len(shapes), len(affines)))

    # default affines
    affines = [spatial.affine_default(shape) if affine is None else affine
               for shape, affine in zip(shapes, affines)]

    affines = utils.as_tensor(affines)
    shapes = utils.as_tensor(shapes)
    affines, shapes = utils.to_max_device(affines, shapes)
    return affines, shapes
Example #5
def get_kernel(kernel, affine, shape, level):
    """Convert the provided kernel size (RAS mm or pct) to native voxels"""
    dim = affine.shape[-1] - 1
    kernel = py.make_list(kernel)
    unit = 'pct'
    if isinstance(kernel[-1], str):
        *kernel, unit = kernel
    kernel = py.make_list(kernel, dim)
    unit = unit.lower()
    if unit[0] == 'v':  # voxels
        kernel = [p / 2**level for p in kernel]
    elif unit in ('m', 'mm', 'cm', 'um'):  # assume RAS orientation
        factor = (1e-3 if unit == 'um' else
                  1e1 if unit == 'cm' else
                  1e3 if unit == 'm' else
                  1)
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        kernel = [factor * p / v for p, v in zip(kernel, vx_ras)]
        kernel = ras_to_layout(kernel, affine)
    elif unit[0] in 'p%':    # percentage of shape
        kernel = [0.01 * p * s for p, s in zip(kernel, shape)]
    else:
        raise ValueError('Unknown patch unit:', unit)

    # ensure kernel size is an integer >= 2 (else, no gradients)
    kernel = list(map(lambda x: max(int(pymath.ceil(x)), 2), kernel))
    return kernel
Example #6
 def __init__(self, dat, affine=None, dim=None, mask=None,
              bound='dct2', extrapolate=False, **backend):
     # I don't call super().__init__() on purpose
     if torch.is_tensor(affine):
         affine = [affine] * len(dat)
     elif affine is None:
         affine = []
         for dat1 in dat:
             dim1 = dim or dat1.dim
             if callable(dim1):
                 dim1 = dim1()
             if hasattr(dat1, 'affine'):
                 aff1 = dat1.affine
             else:
                 shape1 = dat1.shape[-dim1:]
                 aff1 = spatial.affine_default(shape1, **utils.backend(dat1))
             affine.append(aff1)
     affine = py.make_list(affine, len(dat))
     mask = py.make_list(mask, len(dat))
     self._dat = []
     for dat1, aff1, mask1 in zip(dat, affine, mask):
         if not isinstance(dat1, Image):
             dat1 = Image(dat1, aff1, mask=mask1, dim=dim,
                          bound=bound, extrapolate=extrapolate)
         self._dat.append(dat1)
Example #7
 def _affine(self):
     """Affine orientation matrix of a series+level"""
     # TODO: I don't know yet how we should use GeoTiff to encode
     #   affine matrices. In the matrix/zooms, their voxels are ordered
     #   as [x, y, z] even though their dimensions in the returned array
     #   are ordered as [Z, Y, X]. If we want to keep the same convention
     #   as nitorch, I need to permute the matrix/zooms.
     if '_affine' not in self._cache:
         with self.tiffobj() as tiff:
             omexml = tiff.ome_metadata
             geotags = tiff.geotiff_metadata or {}
         zooms, units, axes = ome_zooms(omexml, self.series)
         if zooms:
             # convert to mm + drop non-spatial zooms
             units = [parse_unit(u) for u in units]
             zooms = [z * (f / 1e-3) for z, (f, type) in zip(zooms, units)
                      if type in ('m', 'pixel')]
             if 'ModelPixelScaleTag' in geotags:
                 warn("Both OME and GeoTiff pixel scales are present: "
                      "{} vs {}. Using OME."
                      .format(zooms, geotags['ModelPixelScaleTag']))
         elif 'ModelPixelScaleTag' in geotags:
             zooms = geotags['ModelPixelScaleTag']
             axes = 'XYZ'
         else:
             zooms = 1.
             axes = [ax for ax in self._axes if ax in 'XYZ']
         if 'ModelTransformation' in geotags:
             aff = geotags['ModelTransformation']
             aff = torch.as_tensor(aff, dtype=torch.double).reshape(4, 4)
             self._cache['_affine'] = aff
         elif 'ModelTiepointTag' in geotags:
             # copied from tifffile
             sx, sy, sz = py.make_list(zooms, n=3)
             tiepoints = torch.as_tensor(geotags['ModelTiepointTag'])
             affines = []
             for tiepoint in tiepoints:
                 i, j, k, x, y, z = tiepoint
                 affines.append(torch.as_tensor(
                     [[sx,  0.0, 0.0, x - i * sx],
                      [0.0, -sy, 0.0, y + j * sy],
                      [0.0, 0.0, sz,  z - k * sz],
                      [0.0, 0.0, 0.0, 1.0]], dtype=torch.double))
             affines = torch.stack(affines, dim=0)
             if len(tiepoints) == 1:
                 affines = affines[0]
             self._cache['_affine'] = affines
         else:
             zooms = py.make_list(zooms, n=len(axes))
             ax2zoom = {ax: zoom for ax, zoom in zip(axes, zooms)}
             axes = [ax for ax in self._axes if ax in 'XYZ']
             shape = [shp for shp, msk in zip(self._shape, self._spatial)
                      if msk]
             zooms = [ax2zoom.get(ax, 1.) for ax in axes]
             layout = [('R' if ax == 'Z' else 'P' if ax == 'Y' else 'S')
                       for ax in axes]
             aff = affine_default(shape, zooms, layout=''.join(layout))
             self._cache['_affine'] = aff
     return self._cache['_affine']
Example #8
    def shape(self, x, output_shape=None):
        if torch.is_tensor(x):
            x = x.shape
        x = list(x)

        dim = len(x) - 2
        stride = py.make_list(self.stride, dim)

        if output_shape:
            output_shape = py.make_list(output_shape, dim)
        else:
            output_shape = [sz*st for sz, st in zip(x[2:], stride)]
        return (*x[:2], *output_shape)
Example #9
 def fov(self, value):
     if value is None:
         value = (None, None)
     min, max = value
     if torch.is_tensor(min):
         min = min.flatten().tolist()
     min = py.make_list(min, 3)
     if torch.is_tensor(max):
         max = max.flatten().tolist()
     max = py.make_list(max, 3)
     if any(mn is None for mn in min) or any(mx is None for mx in max):
         min0, max0 = self._max_fov()
         min = [mn if mn is not None else mn0 for mn, mn0 in zip(min, min0)]
         max = [mx if mx is not None else mx0 for mx, mx0 in zip(max, max0)]
     self._fov = (tuple(min), tuple(max))
Example #10
 def fov_size(self, value):
     if torch.is_tensor(value):
         value = value.flatten().tolist()
     value = py.make_list(value, 3)
     min = [i - v / 2 if v else None for i, v in zip(self._index, value)]
     max = [i + v / 2 if v else None for i, v in zip(self._index, value)]
     self.fov = [min, max]
Example #11
    def forward(self, *image, **overload):
        """

        Parameters
        ----------
        image : (batch, channel, *spatial) tensor
        overload : dict
            Optional overrides for `prob` and `dim`.

        Returns
        -------
        image : [tuple of] (batch, channel, *spatial) tensor
            Randomly flipped image(s).

        """

        image = list(image)
        device = image[0].device

        nb_dim = image[0].dim() - 2
        prob = utils.make_vector(overload.get('prob', self.prob),
                                 dtype=torch.float, device=device)
        dim = overload.get('dim', self.dim)
        dim = py.make_list(dim or range(-nb_dim, 0), nb_dim)

        # sample shift
        flip = torch.rand((nb_dim,), device=device) > (1 - prob)
        dim = [d for d, f in zip(dim, flip) if f]

        if dim:
            for i, img in enumerate(image):
                image[i] = img.flip(dim)
        return image[0] if len(image) == 1 else tuple(image)
Example #12
def get_one_hot_map(one_hot_map, nb_classes):
    """Return a well-formed one-hot map"""
    one_hot_map = py.make_list(one_hot_map or [])
    if not one_hot_map:
        one_hot_map = list(range(1, nb_classes))
    if len(one_hot_map) == nb_classes - 1:
        one_hot_map = [*one_hot_map, None]
    if len(one_hot_map) != nb_classes:
        raise ValueError('Number of classes in prior and map '
                         'do not match: {} and {}.'.format(
                             nb_classes, len(one_hot_map)))
    one_hot_map = list(
        map(lambda x: py.make_list(x) if x is not None else x, one_hot_map))
    if sum(elem is None for elem in one_hot_map) > 1:
        raise ValueError('Cannot have more than one implicit class')
    return one_hot_map
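A hedged usage sketch, with values worked out by hand from the function above (assuming it is importable):

print(get_one_hot_map(None, 4))          # [[1], [2], [3], None]  (last class implicit)
print(get_one_hot_map([2, [3, 4]], 3))   # [[2], [3, 4], None]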
Example #13
def _composition_jac(jac, rhs, lhs=None, type='grid', identity=None, **kwargs):
    """Jacobian of the composition `(lhs)o(rhs)`

    Parameters
    ----------
    jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian of input RHS transformation
    rhs : ([batch], *spatial, ndim) tensor
        RHS transformation
    lhs : ([batch], *spatial, ndim) tensor, default=`rhs`
        LHS small displacement
    type : [sequence of] str, default='grid'
        Whether (`rhs`, `lhs`) are transformation grids ('grid') or
        small displacements.
    identity : ([batch], *spatial, ndim) tensor, optional
        Precomputed identity grid (used when `rhs` is a displacement).
    kwargs : dict
        Options to ``grid_pull``

    Returns
    -------
    composed_jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian of composition

    """
    if lhs is None:
        lhs = rhs
    dim = rhs.shape[-1]
    backend = utils.backend(rhs)
    typer, typel = py.make_list(type, 2)
    jac_left = grid_jacobian(lhs, type=typel)
    if typer != 'grid':
        if identity is None:
            identity = identity_grid(rhs.shape[-dim - 1:-1], **backend)
        rhs = rhs + identity
    jac_left = _pull_jac(jac_left, rhs)
    jac = torch.matmul(jac_left, jac)
    return jac
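The underlying identity is the chain rule for the composed map:

    J_{lhs o rhs}(x) = J_{lhs}(rhs(x)) @ J_{rhs}(x)

which is why the left Jacobian is resampled at `rhs` (via `_pull_jac`) before being matrix-multiplied with the input Jacobian `jac`.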
Example #14
def code_to_pattern(code, kernel_size, **backend):
    """Convert a unique code to a sampling pattern

    Parameters
    ----------
    code : int or ([*batch]) tensor[int]
    kernel_size : sequence of int

    Returns
    -------
    pattern : ([*batch], *kernel_size) tensor[bool]

    """
    backend.setdefault('dtype', torch.bool)
    if torch.is_tensor(code):
        backend.setdefault('device', code.device)
    kernel_size = py.make_list(kernel_size)

    def make_pattern(code):
        pattern = torch.zeros(kernel_size, **backend)
        pattern_flat = pattern.flatten()
        for i in range(len(pattern_flat)):
            pattern_flat[i] = bool((code >> i) & 1)
        return pattern

    if torch.is_tensor(code):
        pattern = code.new_zeros([*code.shape, *kernel_size], **backend)
        for code1 in code.unique():
            pattern[code == code1] = make_pattern(code1)
    else:
        pattern = make_pattern(code)
    return pattern
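A small worked example of the bit decoding, assuming torch and the helpers above are importable (bits are read from bit 0 upward and reshaped row-major):

pattern = code_to_pattern(11, (2, 2))
# 11 = 0b1011, bits 0..3 are 1, 1, 0, 1  ->  [[True, True], [False, True]]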
Example #15
 def index(self, value):
     if torch.is_tensor(value):
         value = value.flatten().tolist()
     value = py.make_list(value, 3)
     value = [(mx + mn) / 2 if v is None else v
              for v, mn, mx in zip(value, *self.fov)]
     self._index = tuple(value)
Example #16
    def forward(self, q, k, v, **overload):
        """

        Parameters
        ----------
        q : (b, c, *spatial)
            Queries
        k : (b, c, *spatial)
            Keys
        v : (b, c, *spatial)
            Values

        Returns
        -------
        x : (b, c, *spatial)

        """
        kernel_size = overload.pop('kernel_size', self.kernel_size)
        stride = overload.pop('stride', self.kernel_size)
        padding = overload.pop('padding', self.padding)
        padding_mode = overload.pop('padding_mode', self.padding_mode)

        dim = q.dim() - 2
        if padding == 'auto':
            k = spatial.pad_same(dim, k, kernel_size, bound=padding_mode)
            v = spatial.pad_same(dim, v, kernel_size, bound=padding_mode)
        elif padding:
            padding = [0] * 2 + py.make_list(padding, dim)
            k = utils.pad(k, padding, side='both', mode=padding_mode)
            v = utils.pad(v, padding, side='both', mode=padding_mode)

        # compute weights by query/key dot product
        kernel_size = py.make_list(kernel_size, dim)
        k = utils.unfold(k, kernel_size, stride)
        k = k.reshape([*k.shape[:dim + 2], -1])
        k = utils.movedim(k, 1, -1)
        q = utils.movedim(q[..., None], 1, -1)
        k = math.softmax(linalg.dot(k, q), dim=-1)
        k = k[:, None]  # add back channel dimension

        # compute new values by weight/value dot product
        v = utils.unfold(v, kernel_size, stride)
        v = v.reshape([*v.shape[:dim + 2], -1])
        v = linalg.dot(k, v)

        return v
Example #17
def spherical_harmonics(shape, order=2, isocenter=None, **backend):
    """Generate a basis of spherical harmonics on a lattice

    Notes
    -----
    .. This should be checked!
    .. Only orders 1 and 2 implemented
    .. I tried to implement some sort of "circular" harmonics in
       dimension 2 but I don't know what I am doing.
    .. The basis is not orthogonal

    Parameters
    ----------
    shape : sequence of int
    order : {1, 2}, default=2
    isocenter : [sequence of] int, default=shape/2
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    b : (*shape, 2*order + 1) tensor
        Basis

    """
    shape = py.make_list(shape)
    dim = len(shape)
    if dim not in (2, 3):
        raise ValueError('Dimension must be 2 or 3')
    if order not in (1, 2):
        raise ValueError('Order must be 1 or 2')

    if isocenter is None:
        isocenter = [s / 2 for s in shape]
    isocenter = utils.make_vector(isocenter, **backend)

    ramps = identity_grid(shape, **backend)
    for i, ramp in enumerate(ramps.unbind(-1)):
        ramp -= isocenter[i]
        ramp /= shape[i] / 2

    if order == 1:
        return ramps
    # order == 2
    if dim == 3:
        basis = [
            ramps[..., 0] * ramps[..., 1], ramps[..., 0] * ramps[..., 2],
            ramps[..., 1] * ramps[..., 2],
            ramps[..., 0].square() - ramps[..., 1].square(),
            ramps[..., 0].square() - ramps[..., 2].square()
        ]
        return torch.stack(basis, -1)
    else:  # dim == 2
        basis = [
            ramps[..., 0] * ramps[..., 1],
            ramps[..., 0].square() - ramps[..., 1].square()
        ]
        return torch.stack(basis, -1)
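Hedged usage sketch (shapes only, assuming the helpers used above are importable):

b = spherical_harmonics([32, 32, 32], order=2)   # -> shape (32, 32, 32, 5)
b = spherical_harmonics([64, 64], order=1)       # -> shape (64, 64, 2)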
Example #18
    def loadseg(self,
                fnames,
                segtype='label',
                lookup=None,
                dtype=None,
                device=None):
        """Load a volume from disk

        Parameters
        ----------
        fnames : str or sequence[str]
        segtype : [tuple] {'label', 'implicit', 'explicit'}, default='label'
        lookup : list of [list of] int, optional
        dtype : torch.dtype, optional
        device : torch.device, optional

        Returns
        -------
        dat : (channels | 1, *spatial) tensor

        """
        insegtype, outsegtype = py.make_list(segtype, 2)
        sdtype = dtype if insegtype != 'label' else torch.int
        fnames = py.make_list(fnames)
        channels = []
        for fname in fnames:  # loop across channels
            dat = self.load(fname, dtype=sdtype, device=device)
            channels.append(dat)
        if len(channels) > 1:
            channels = torch.cat(channels, 0)
        else:
            channels = channels[0]
        if insegtype == 'label' and lookup:
            channels = utils.merge_labels(channels, lookup)
        if insegtype == 'label' and outsegtype != 'label':
            channels = utils.one_hot(channels,
                                     dim=0,
                                     implicit=outsegtype == 'implicit',
                                     implicit_index=0,
                                     dtype=dtype)
            if outsegtype == 'implicit':
                channels /= len(channels) + 1
            else:
                channels /= len(channels)
        return channels
Example #19
    def forward(self, x):
        """

        Parameters
        ----------
        x : (b, c, **spatial) tensor

        Returns
        -------
        x : (b, c, **spatial_out) tensor

        """
        dim = x.dim() - 2
        offset = py.make_list(self.offset, dim)
        stride = py.make_list(self.stride, dim)
        slicer = [slice(o, ((sz-o)//st)*st, st) for o, st, sz in
                  zip(offset, stride, x.shape[2:])]
        slicer = [slice(None)]*2 + slicer
        return x[tuple(slicer)]
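Worked example of the slicer, assuming offset=1 and stride=2 on a (1, 1, 7) input:

# slice(1, ((7-1)//2)*2, 2) == slice(1, 6, 2), i.e. x[:, :, 1:6:2]
# keeps spatial indices 1, 3, 5 -> output shape (1, 1, 3)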
Example #20
    def __init__(self, shape, in_channels, out_channels, nb_levels=0,
                 decoder=(32, 32, 32, 32), kernel_size=3,
                 activation=tnn.LeakyReLU(0.2), unpool=None):
        """

        Parameters
        ----------
        shape : sequence[int]
            Output spatial shape
        in_channels : int
            Number of input channels (= meta variables)
        out_channels : int
            Number of output channels
        nb_levels : int, default=0
            Number of levels in the decoder.
            If 0: directly generate the image using a dense layer.
        decoder : sequence[int], default=(32, 32, 32, 32)
            Number of features after each layer.
            If len(decoder) is larger than the number of levels, additional
            stride-1 convolutions are applied.
        kernel_size : [sequence of] int, default=3
        activation : str or callable, default=LeakyReLU(0.2)
        unpool : {'conv', 'up', None}, default=None
                'conv' -> 2x2x2 strided convolution (no bias, no activation)
                'up'   -> linear upsampling
                 None  -> use strided convolutions in the decoder
        """
        super().__init__()
        shape = py.make_list(shape)
        dim = len(shape)
        small_shape = [s // 2**nb_levels for s in shape]
        in_feat, *decoder = decoder
        self.dense = Linear(in_channels, py.prod(small_shape)*in_feat)
        self.reshape = lambda x: x.reshape([-1, in_feat, *small_shape])
        decoder, stack = decoder[:nb_levels], decoder[nb_levels:]
        if decoder:
            self.decoder = Decoder(dim, in_feat, decoder,
                                   kernel_size=kernel_size,
                                   activation=activation,
                                   unpool=unpool)
            in_feat = decoder[-1]
        else:
            self.decoder = lambda x: x
        if stack:
            self.stack = StackedConv(dim, in_feat, stack,
                                     kernel_size=kernel_size,
                                     activation=activation)
            in_feat = stack[-1]
        else:
            self.stack = lambda x: x
        self.final = StackedConv(dim, in_feat, out_channels,
                                 kernel_size=kernel_size,
                                 activation=None)
Example #21
    def forward(self, x, output_padding=None, output_shape=None):
        """

        Parameters
        ----------
        x : (batch, channel, *in_spatial) tensor
        output_padding : [sequence of] int, default=self.output_padding
        output_shape : [sequence of] int, default=self.output_shape

        Returns
        -------
        x : (batch, channel, *out_spatial) tensor

        """
        dim = x.dim() - 2
        offset = py.make_list(self.offset, dim)
        stride = py.make_list(self.stride, dim)

        new_shape = self.shape(x, output_padding=output_padding,
                               output_shape=output_shape)
        y = x.new_zeros(new_shape)
        if self.fill:
            z = utils.unfold(y, stride)
            x = utils.unsqueeze(x, -1, dim)
            slicer = [slice(o, o+sz*st) for sz, st, o in
                      zip(x.shape[2:], stride, offset)]
            slicer = [slice(None)]*2 + slicer
            subz = z[tuple(slicer)]
            slicer = [slice(mx) for mx in subz.shape[2:]]
            slicer = [slice(None)]*2 + slicer
            subz.copy_(x[tuple(slicer)])
        else:
            slicer = [slice(o, None, s) for o, s in zip(offset, stride)]
            slicer = [slice(None)]*2 + slicer
            suby = y[tuple(slicer)]
            slicer = [slice(mx) for mx in suby.shape[2:]]
            slicer = [slice(None)]*2 + slicer
            suby.copy_(x[tuple(slicer)])

        return y
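Hedged illustration of the `fill=False` branch, assuming stride=2, offset=0 and a (1, 1, 3) input holding [a, b, c]:

# new_shape = (1, 1, 3*2) = (1, 1, 6); y starts as zeros
# y[:, :, 0::2] <- [a, b, c]   =>   y = [a, 0, b, 0, c, 0]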
Example #22
    def forward(self, x, output_shape=None):

        if len(self) == 1:
            # strided conv
            conv, = self
            x = conv(x, output_shape=output_shape)
        else:
            # resize + conv
            resize, conv = self
            factor = [1/s for s in py.make_list(self.stride)]
            x = resize(x, output_shape=output_shape, factor=factor)
            x = conv(x)
        return x
Example #23
    def __init__(self,
                 shape=None,
                 nb_classes=None,
                 amplitude='lognormal',
                 amplitude_exp=5,
                 amplitude_scale=2,
                 fwhm='lognormal',
                 fwhm_exp=15,
                 fwhm_scale=5,
                 logits=False,
                 implicit=False,
                 device=None,
                 dtype=None):
        """

        Parameters
        ----------
        shape : sequence[int], optional
            Output shape
        nb_classes : int, optional
            Number of classes (excluding background)
        amplitude : {'normal', 'lognormal', 'uniform', 'gamma'}, default='lognormal'
        amplitude_exp : float or (channel,) vector_like, default=5
        amplitude_scale : float or (channel,) vector_like, default=2
        fwhm : {'normal', 'lognormal', 'uniform', 'gamma'}, default='lognormal'
        fwhm_exp : float or (channel,) vector_like, default=15
        fwhm_scale : float or (channel,) vector_like, default=5
        logits : bool, default=False
            Outputs are log-odds instead of probabilities
        implicit : bool, default=False
            Outputs have an implicit K+1-th class.
        device : torch.device, optional
            Output tensor device.
        dtype : torch.dtype, optional
            Output tensor datatype.
        """

        super().__init__()
        self.logits = logits
        self.implicit = implicit
        shape = py.make_list(shape)
        self.field = HyperRandomFieldSpline(shape=shape,
                                            channel=nb_classes + 1,
                                            amplitude=amplitude,
                                            amplitude_exp=amplitude_exp,
                                            amplitude_scale=amplitude_scale,
                                            fwhm=fwhm,
                                            fwhm_exp=fwhm_exp,
                                            fwhm_scale=fwhm_scale,
                                            device=device,
                                            dtype=dtype)
Example #24
def multitrainer(trainers, keep_on_gpu=True, parallel=False, cuda_pool=(0, )):
    """Train multiple models in parallel.

    Parameters
    ----------
    trainers : sequence[ModelTrainer]
    keep_on_gpu : bool, default=True
        Keep all models on GPU (risk of out-of-memory)
    parallel : bool or int, default=False
        Train model in parallel.
        If an int, only this number of models will be trained in parallel.
    cuda_pool : sequence[int], default=[0]
        IDs of GPUs that can be used to dispatch the models.

    """
    n = len(trainers)
    initial_epoch = max(min(trainer.initial_epoch for trainer in trainers), 1)
    nb_epoch = max(trainer.nb_epoch for trainer in trainers)
    trainers = tuple(trainers)
    if parallel:
        parallel = len(trainers) if parallel is True else parallel
        chunksize = max(len(trainers) // (2 * parallel), 1)
        pool = Pool(parallel)
        cuda_pool = py.make_list(cuda_pool, parallel)

    if not torch.cuda.is_available():
        cuda_pool = []

    if not keep_on_gpu:
        for trainer in trainers:
            trainer.to(device='cpu')

    # --- init ---
    if parallel:
        args = zip(trainers, [keep_on_gpu] * n, [cuda_pool] * n)
        trainers = list(pool.map(_init1, args, chunksize=chunksize))
    else:
        args = zip(trainers, [keep_on_gpu] * n, [cuda_pool] * n)
        trainers = list(map(_init1, args))

    # --- train ---
    for epoch in range(initial_epoch + 1, nb_epoch + 1):
        if parallel:
            args = zip(trainers, [epoch] * n, [keep_on_gpu] * n,
                       [cuda_pool] * n)
            trainers = list(pool.map(_train1, args, chunksize=chunksize))
        else:
            args = zip(trainers, [epoch] * n, [keep_on_gpu] * n,
                       [cuda_pool] * n)
            trainers = list(map(_train1, args))
Example #25
 def _init_from_fname(new,
                      fnames,
                      permission='r',
                      keep_open=False,
                      **attributes):
     fnames = py.make_list(fnames)
     fs = []
     for fname in fnames:
         f = io.map(fname, permission=permission, keep_open=keep_open)
         while f.dim < 4:
             f = f.unsqueeze(-1)
         fs += [f]
     fs = io.cat(fs, -1).permute([-1, 0, 1, 2])
     new._init_from_mapped(fs, **attributes)
Example #26
def kernel_apply(kspace, patterns, kernel_size, kernels, inplace=False):
    """Apply a GRAPPA kernel to an accelerated k-space

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    kspace : ([*batch], coils, *freq)
        Accelerated k-space
    patterns : (*freq) tensor[long]
        Code of the sampling pattern around each k-space location
    kernel_size : sequence of int
        GRAPPA kernel size
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        Dictionary of GRAPPA kernels (keys are pattern codes)

    Returns
    -------
    kspace : ([*batch], coils, *freq)

    """
    ndim = patterns.dim()
    coils, *freq = kspace.shape[-ndim - 1:]
    batch = kspace.shape[:-ndim - 1]
    kernel_size = py.make_list(kernel_size, ndim)

    kspace_out = kspace
    if not inplace:
        kspace_out = kspace_out.clone()
    kspace = utils.pad(kspace, [(k - 1) // 2 for k in kernel_size],
                       side='both')
    kspace = utils.unfold(kspace, kernel_size, stride=1)

    def t(x):
        return x.transpose(-1, -2)

    for code, kernel in kernels.items():
        pattern = code_to_pattern(code, kernel_size, device=kspace.device)
        pattern_size = pattern.sum()
        mask = patterns == code
        kspace1 = kspace[..., mask, :, :][..., pattern]
        kspace1 = kspace1.transpose(-2, -3) \
                         .reshape([*batch, -1, coils * pattern_size])
        kernel = kernel.reshape([*batch, coils, coils * pattern_size])
        kspace1 = t(kspace1.matmul(t(kernel)))
        kspace_out[..., mask] = kspace1

    return kspace_out
Example #27
def _prepare_pyramid_levels(losses, opt, dim=None):
    """
    For each loss, compute the pyramid levels of `fix` and `mov` that
    must be computed.
    """
    if opt.name == 'none':
        return [{'fix': None, 'mov': None}] * len(losses)

    vxs = []
    shapes = []
    for loss in losses:
        # fixed image
        dat, affine = _map_image(loss.fix.files, dim)
        vx = dat.voxel_size[0].tolist()
        shape = dat.shape[:dim]
        if loss.fix.pad:
            pad = py.make_list(loss.fix.pad, len(shape))
            shape = [s + 2 * p for s, p in zip(shape, pad)]
        vxs.append(vx)
        shapes.append(shape)
        # moving image
        dat, affine = _map_image(loss.mov.files, dim)
        vx = dat.voxel_size[0].tolist()
        shape = dat.shape[:dim]
        if loss.mov.pad:
            pad = py.make_list(loss.mov.pad, len(shape))
            shape = [s + 2 * p for s, p in zip(shape, pad)]
        vxs.append(vx)
        shapes.append(shape)

    levels = _pyramid_levels(vxs, shapes, opt)
    levels = [{
        'fix': fix,
        'mov': mov
    } for fix, mov in zip(levels[::2], levels[1::2])]
    return levels
Example #28
def _softmax_fwd(input, dim=-1, implicit=False):
    """ SoftMax (safe).

    Parameters
    ----------
    input : torch.tensor
        Tensor with values.
    dim : int, default=-1
        Dimension to take softmax, defaults to last dimensions.
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the
          softmaxed tensor.

    Returns
    -------
    Z : torch.tensor
        Soft-maxed tensor with values.

    """

    implicit_in, implicit_out = py.make_list(implicit, 2)

    maxval, _ = torch.max(input, dim=dim, keepdim=True)
    if implicit_in:
        maxval.clamp_min_(0)  # don't forget the class full of zeros

    input = input.clone().sub_(maxval).exp_()
    sumval = torch.sum(input,
                       dim=dim,
                       keepdim=True,
                       out=maxval if not implicit_in else None)
    if implicit_in:
        sumval += maxval.neg().exp()  # don't forget the class full of zeros
    input *= sumval.reciprocal_()

    if implicit_in and not implicit_out:
        background = input.sum(dim, keepdim=True).neg_().add_(1)
        input = torch.cat((input, background), dim=dim)
    elif implicit_out and not implicit_in:
        input = utils.slice_tensor(input, slice(-1), dim)

    return input
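Hedged numeric check of the implicit-input behaviour, assuming torch is imported: a single explicit channel of zeros plus the hidden zero class should split the mass evenly:

x = torch.zeros(1, 1)
y = _softmax_fwd(x, dim=1, implicit=(True, False))   # -> tensor([[0.5, 0.5]])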
Example #29
    def __init__(self, dim=1, implicit=False):
        """

        Parameters
        ----------
        dim : int, default=1
            Dimension along which to take the softmax
        implicit : bool or (bool, bool), default=False
            The first value relates to the input tensor and the second
            relates to the output tensor.
            - implicit[0] == True assumes that an additional (hidden) channel
              with value zero exists.
            - implicit[1] == True drops the last class from the
              softmaxed tensor.
        """
        super().__init__()
        self.implicit = py.make_list(implicit, 2)
        self.dim = dim
Example #30
    def forward(self, *image, **overload):
        """

        Parameters
        ----------
        *image : (batch, channel, *spatial)
        **overload : dict

        Returns
        -------
        *image : (batch, channel, *patch_shape)

        """

        image, *other_images = image
        image = torch.as_tensor(image)
        device = image.device

        dim = image.dim() - 2
        shape = py.make_list(overload.get('shape', self.shape), dim)
        shape = [min(s0, s1) for s0, s1 in zip(image.shape[2:], shape)]

        # sample shift
        max_shift = [d0 - d1 for d0, d1 in zip(image.shape[2:], shape)]
        shift = [[torch.randint(0, s, [], device=device) if s > 0 else 0
                  for s in max_shift] for _ in range(len(image))]

        output = image.new_empty([*image.shape[:2], *shape])
        other_outputs = [im.new_empty([*im.shape[:2], *shape])
                         for im in other_images]

        for b in range(len(image)):
            # subslice
            index = (b, slice(None))  # batch, channel
            index = index + tuple(slice(s, s+d) for s, d in zip(shift[b], shape))
            output[b] = image[index]
            for i in range(len(other_images)):
                other_outputs[i][b] = other_images[i][index]

        if len(other_images) > 0:
            return (output, *other_outputs)
        else:
            return output