예제 #1
0
    def data(self,
             dtype=None,
             device=None,
             casting='unsafe',
             rand=True,
             cutoff=None,
             dim=None,
             numpy=False):
        """Load the (sliced / permuted) view from the TIFF file.

        Parameters
        ----------
        dtype : dtype-like, optional
            Output data type. Defaults to `self.dtype`.
        device : torch.device, optional
            Device of the returned tensor (unused when `numpy=True`).
        casting : str, default='unsafe'
            Casting rule forwarded to the first `volutils.cast` call.
        rand : bool, default=True
            Add uniform noise scaled by the cast scale when the stored type
            (`self.dtype`) is not floating point.
        cutoff, dim : optional
            Forwarded to `volutils.cutoff`.
        numpy : bool, default=False
            Return a numpy array instead of a torch tensor.

        Returns
        -------
        np.ndarray or torch.Tensor

        Raises
        ------
        TypeError
            If `numpy` is False and the output dtype has no PyTorch
            equivalent.
        """
        out_dtype = dtypes.dtype(self.dtype if dtype is None else dtype)
        if not numpy and out_dtype.torch is None:
            raise TypeError(
                'Data type {} does not exist in PyTorch.'.format(out_dtype))

        # An empty view has nothing to read: hand back zeros of the right type.
        if py.prod(self.shape) == 0:
            if not numpy:
                return torch.zeros(self.shape,
                                   dtype=out_dtype.torch,
                                   device=device)
            return np.zeros(self.shape, dtype=out_dtype.numpy)

        # Read the raw block, then apply the view's permutation / new axes.
        slicer, perm, newdim = split_operation(
            self.permutation, self.slicer, 'r')
        with self.tiffobj() as tif:
            raw = self._read_data_raw(slicer, tiffobj=tif)
        array = raw.transpose(perm)[newdim]
        stored_dtype = dtypes.dtype(self.dtype)

        array = volutils.cutoff(array, cutoff, dim)

        # When noise will be added to a non-float output, go through float64
        # first so the noise survives until the final cast.
        use_noise = rand and not stored_dtype.is_floating_point
        work_dtype = (dtypes.float64
                      if use_noise and not out_dtype.is_floating_point
                      else out_dtype)
        array, scale = volutils.cast(array,
                                     work_dtype.numpy,
                                     casting,
                                     with_scale=True)

        # Uniform noise in the uncertainty interval; skipped when the cast
        # was unscaled and the requested output is an integer type.
        if use_noise and (scale != 1 or out_dtype.is_floating_point):
            array = volutils.addnoise(array, scale)

        array = volutils.cast(array, out_dtype.numpy, 'unsafe')

        return array if numpy else torch.as_tensor(array, device=device)
예제 #2
0
파일: lta.py 프로젝트: balbasty/nitorch
 def data(self, dtype=None, device=None, numpy=False):
     """Return the stored affine matrix, or None when there is none.

     Parameters
     ----------
     dtype : torch.dtype, optional
         Target data type. Defaults to `torch.get_default_dtype()`.
     device : torch.device, optional
         Device of the returned tensor (unused when `numpy=True`).
     numpy : bool, default=False
         Return a numpy array instead of a torch tensor.
     """
     raw = self._struct.affine
     if raw is None:
         return None
     target = dtype or torch.get_default_dtype()
     mat = cast(raw, target)
     if not numpy:
         return torch.as_tensor(mat, dtype=target, device=device)
     return np.asarray(mat)
예제 #3
0
    def set_data(self, dat, casting='unsafe'):
        """Write `dat` into the region of the file covered by this view.

        Parameters
        ----------
        dat : array_like or torch.Tensor
            Data to write; must have exactly `self.shape`. Tensors are
            detached and moved to CPU before conversion to numpy.
        casting : str, default='unsafe'
            Casting rule forwarded to `volutils.cast` when converting `dat`
            to `self.dtype`.

        Returns
        -------
        self

        Raises
        ------
        RuntimeError
            If the volume was not mapped in a writable ("+") mode.
        ValueError
            If the shape of `dat` differs from `self.shape`.
        """

        if '+' not in self.mode:
            raise RuntimeError('Cannot write into read-only volume. '
                               'Re-map in mode "r+" to allow in-place '
                               'writing.')

        # --- convert to numpy ---
        if torch.is_tensor(dat):
            dat = dat.detach().cpu()
        dat = np.asanyarray(dat)

        # --- sanity check ---
        if dat.shape != self.shape:
            raise ValueError(
                'Expected an array of shape {} but got {}.'.format(
                    self.shape, dat.shape))

        # --- special case if empty view ---
        if dat.size == 0:
            # nothing to write
            return self

        # --- cast ---
        dat = volutils.cast(dat, self.dtype, casting)

        # --- unpermute ---
        # Undo the view's permutation / added axes so `dat` is laid out like
        # the raw array; `slicer` then indexes that raw array.
        drop, perm, slicer = split_operation(self.permutation, self.slicer,
                                             'w')
        dat = dat[drop].transpose(perm)
        covers_everything = all(is_fullslice(slicer, self._shape))

        # --- dispatch ---
        if self.is_compressed('image'):
            # Compressed images are always rewritten in full here (together
            # with the header/footer when they live in the same file).
            if not covers_everything:
                # read-and-write: merge the new values into the full array
                patch = dat
                with self.fileobj('image', 'r') as f:
                    dat = self._read_data_raw(fileobj=f, mmap=False)
                dat[slicer] = patch
            with self.fileobj('image', 'w', seek=0) as f:
                if self.same_file('image', 'header'):
                    self._set_header_raw(fileobj=f)
                self._write_data_raw_full(dat, fileobj=f)
                if self.same_file('image', 'footer'):
                    self._set_footer_raw(fileobj=f)

        elif covers_everything:
            with self.fileobj('image', 'r+') as f:
                self._write_data_raw_full(dat, fileobj=f)

        else:
            with self.fileobj('image', 'r+') as f:
                self._write_data_raw_partial(dat, slicer, fileobj=f)
        return self
예제 #4
0
파일: lta.py 프로젝트: balbasty/nitorch
    def destination_space(self,
                          source='voxel',
                          dest='ras',
                          dtype=None,
                          device=None,
                          numpy=False):
        """Return the space (affine + shape) of the destination image.

        Parameters
        ----------
        source : {'voxel', 'physical', 'ras'}, default='voxel'
            Source space of the affine
        dest : {'voxel', 'physical', 'ras'}, default='ras'
            Destination space of the affine
        dtype : torch.dtype, optional
            Forwarded to `cast` and `torch.as_tensor`.
        device : torch.device, optional
            Device of the returned tensor (unused when `numpy=True`).
        numpy : bool, default=False
            Return the affine as a numpy array instead of a tensor.

        Returns
        -------
        affine : (4, 4) tensor or ndarray, or None
            Affine mapping `source` coordinates to `dest` coordinates of
            the destination image.
        shape : (3,) tuple[int] or None
            The spatial shape of the destination image.

        Both outputs are None when the transform has no destination
        geometry (`self._struct.dst is None`).
        """
        dst = self._struct.dst
        if dst is None:
            return None, None

        mat = fs_to_affine(dst.volume, dst.voxelsize,
                           dst.xras, dst.yras, dst.zras, dst.cras,
                           source=source, dest=dest)
        shape = tuple(dst.volume)
        # NOTE(review): unlike `data`/`fdata`, `dtype` is not defaulted to
        # torch.get_default_dtype() here; confirm this is intended.
        mat = cast(mat, dtype)
        if numpy:
            return np.asarray(mat), shape
        return torch.as_tensor(mat, dtype=dtype, device=device), shape
예제 #5
0
파일: lta.py 프로젝트: balbasty/nitorch
    def fdata(self, dtype=None, device=None, numpy=False):
        """Return the affine, converted to RAS space when possible.

        If the stored affine is not of type 'ras' (see `self.type()`), it
        is composed with the source (RAS -> native) and destination
        (native -> RAS) space affines. When either of those is unavailable,
        the affine is returned in its native space.

        Parameters
        ----------
        dtype : torch.dtype, optional
            Data type. Defaults to `torch.get_default_dtype()`.
        device : torch.device, optional
            Device used for the computation and of the returned tensor.
        numpy : bool, default=False
            Return a numpy array instead of a torch tensor.

        Returns
        -------
        tensor or ndarray or None
            None when the transform holds no affine.
        """
        dtype = dtype or torch.get_default_dtype()
        backend = dict(dtype=dtype, device=device)
        affine = self.data(**backend)
        if affine is None:
            return None

        # we may need to convert from a weird space to RAS space
        afftype = self.type()[0]
        if afftype != 'ras':
            src, _ = self.source_space('ras', afftype, **backend)
            dst, _ = self.destination_space(afftype, 'ras', **backend)
            if src is not None and dst is not None:
                affine = affine_matmul(dst, affine_matmul(affine, src))

        affine = cast(affine, dtype)
        if numpy:
            # `affine` lives on `device`; np.asarray cannot read non-CPU
            # tensors, so bring it back to the CPU first.
            if torch.is_tensor(affine):
                affine = affine.cpu()
            return np.asarray(affine)
        return torch.as_tensor(affine, dtype=dtype, device=device)
예제 #6
0
    def save_new(cls,
                 dat,
                 file_like,
                 like=None,
                 casting='unsafe',
                 _savef=False,
                 **metadata):

        if isinstance(dat, MappedArray):
            if like is None:
                like = dat
            dat = dat.data(numpy=True)
        if torch.is_tensor(dat):
            dat = dat.detach().cpu()
        dat = np.asanyarray(dat)
        if like is not None:
            like = map_array(like)

        # guess data type:
        def guess_dtype():
            dtype = None
            if dtype is None:
                dtype = metadata.get('dtype', None)
            if dtype is None and like is not None:
                dtype = like.dtype
            if dtype is None:
                dtype = dat.dtype
            dtype = dtypes.dtype(dtype).numpy
            return dtype

        dtype = guess_dtype()

        def guess_format():
            # 1) from extension
            ok_klasses = []
            if isinstance(file_like, str):
                base, ext = os.path.splitext(file_like)
                if ext.lower() == '.gz':
                    base, ext = os.path.splitext(base)
                ok_klasses = [
                    klass for klass in all_image_classes
                    if ext in klass.valid_exts
                ]
                if len(ok_klasses) == 1:
                    return ok_klasses[0]
            # 2) from like
                if isinstance(like, BabelArray):
                    return type(like._image)
            # 3) from extension (if conflict)
            if len(ok_klasses) != 0:
                return ok_klasses[0]
            # 4) fallback to nifti-1
            return nib.Nifti1Image

        format = guess_format()

        # build header
        if isinstance(like, BabelArray):
            # defer metadata conversion to nibabel
            header = format.header_class.from_header(
                like._image.dataobj._header)
        else:
            header = format.header_class()
            if like is not None:
                # copy generic metadata
                like_metadata = like.metadata()
                like_metadata.update(metadata)
                metadata = like_metadata
        # set shape now so that we can set zooms/etc
        header.set_data_shape(dat.shape)
        header = metadata_to_header(header, metadata, shape=dat.shape)

        # check endianness
        disk_byteorder = header.endianness
        data_byteorder = dtype.byteorder
        if disk_byteorder == '=':
            disk_byteorder = '<' if sys.byteorder == 'little' else '>'
        if data_byteorder == '=':
            data_byteorder = '<' if sys.byteorder == 'little' else '>'
        if disk_byteorder != data_byteorder:
            dtype = dtype.newbyteorder()

        # set scale
        if hasattr(header, 'set_slope_inter'):
            slope, inter = header.get_slope_inter()
            if slope is None:
                slope = 1
            if inter is None:
                inter = 0
            header.set_slope_inter(slope, inter)

        # unscale
        if _savef:
            assert dtypes.dtype(dat.dtype).is_floating_point
            slope, inter = header.get_slope_inter()
            if inter not in (0, None) or slope not in (1, None):
                dat = dat.copy()
            if inter not in (0, None):
                dat -= inter
            if slope not in (1, None):
                dat /= slope

        # cast + setdtype
        dat = volutils.cast(dat, dtype, casting)
        header.set_data_dtype(dat.dtype)

        # create image object
        image = format(dat, affine=None, header=header)

        # write everything
        file_map = format.filespec_to_file_map(file_like)
        fmap_header = file_map.get('header', file_map.get('image'))
        fmap_image = file_map.get('image')
        fmap_footer = file_map.get('footer', file_map.get('image'))
        fhdr = fmap_header.get_prepare_fileobj('wb')
        if hasattr(header, 'writehdr_to'):
            header.writehdr_to(fhdr)
        elif hasattr(header, 'write_to'):
            header.write_to(fhdr)
        if fmap_image == fmap_header:
            fimg = fhdr
        else:
            fimg = fmap_image.get_prepare_fileobj('wb')
        array_to_file(dat,
                      fimg,
                      dtype,
                      offset=header.get_data_offset(),
                      order=image.ImageArrayProxy.order)
        if fmap_image == fmap_footer:
            fftr = fimg
        else:
            fftr = fmap_footer.get_prepare_fileobj('wb')
        if hasattr(header, 'writeftr_to'):
            header.writeftr_to(fftr)