Example #1
def cast(dat, dtype, casting='unsafe', with_scale=False):
    """Cast an array to a given type.

    Parameters
    ----------
    dat : tensor or ndarray
        Input array
    dtype : dtype
        Output data type (should have the proper on-disk byte order)
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe',
               'rescale', 'rescale_zero'}, default='unsafe'
        Casting method:
        * 'rescale' makes sure that the dynamic range in the array
          matches the range of the output data type
        * 'rescale_zero' does the same, but keeps the mapping `0 -> 0`
          intact.
        * all other options are implemented in numpy. See `np.can_cast`.
    with_scale : bool, default=False
        Return the scaling applied, if any.

    Returns
    -------
    dat : tensor or ndarray
    scale : float, if `with_scale`

    """
    scale = 1.
    indtype = dtypes.dtype(dat.dtype)
    outdtype = dtypes.dtype(dtype)

    if casting.startswith('rescale') and not outdtype.is_floating_point:
        # rescale
        # TODO: I am using float64 as an intermediate to cast
        #       Maybe I can do things in a nicer / more robust way
        minval = astype(min(dat), dtypes.float64)
        maxval = astype(max(dat), dtypes.float64)
        if not dtypes.equivalent(indtype, dtypes.float64):
            dat = dat.astype(np.float64)
        if not writeable(dat):
            dat = copy(dat)
        if casting == 'rescale':
            # affine map sending [minval, maxval] onto [outdtype.min, outdtype.max]
            scale = (outdtype.max - outdtype.min) / (maxval - minval)
            offset = outdtype.min - scale * minval
            dat *= scale
            dat += offset
        else:
            assert casting == 'rescale_zero'
            if minval < 0 and not outdtype.is_signed:
                warn("Converting negative values to an unsigned datatype")
            scale = min(
                abs(outdtype.max / maxval) if maxval else float('inf'),
                abs(outdtype.min / minval) if minval else float('inf'))
            dat *= scale
        indtype = dtypes.dtype(dat.dtype)
        casting = 'unsafe'

    # unsafe cast
    if indtype != outdtype:
        dat = astype(dat, outdtype, casting=casting)

    return (dat, scale) if with_scale else dat
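
For reference, a minimal self-contained sketch of the 'rescale' mapping above, in plain NumPy (the array values and the int16 target are invented for illustration):

import numpy as np

dat = np.array([-3.2, 0.0, 7.5], dtype=np.float64)   # hypothetical data
out = np.iinfo(np.int16)
minval, maxval = dat.min(), dat.max()
# affine map sending [minval, maxval] onto [out.min, out.max]
scale = (out.max - out.min) / (maxval - minval)
offset = out.min - scale * minval
rescaled = (dat * scale + offset).astype(np.int16)
print(rescaled)   # minval -> -32768, maxval -> 32767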
Example #2
def _torch_astype(dat, dtype, casting='unsafe'):
    """Equivalent to `np.astype` but for torch tensors

    Parameters
    ----------
    dat : torch.tensor
    dtype : dtype
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}

    Returns
    -------
    dat : torch.tensor

    """
    def error(indtype, outdtype):
        raise TypeError("Cannot cast tensor from dtype('{}') to "
                        "dtype('{}}') according to the rule '{}'.".format(
                            indtype, outdtype, casting))

    if not torch.is_tensor(dat):
        raise TypeError('Expected torch tensor but got {}'.format(type(dat)))
    outdtype = dtypes.dtype(dtype)
    if outdtype.torch is None:
        raise TypeError('Data type {} does not exist in pytorch'.format(dtype))
    indtype = dtypes.dtype(dat.dtype)

    casting = casting.lower()
    if ((casting == 'no' and indtype != outdtype) or
        (casting == 'equiv' and not dtypes.equivalent(indtype, outdtype))
            or (casting == 'safe' and not indtype <= outdtype) or
        (casting == 'same_kind' and not dtypes.same_kind(indtype, outdtype))):
        error(indtype, outdtype)
    return dat.to(outdtype.torch)
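
The casting rules mirror NumPy's; here is a quick runnable reference of what each rule allows, independent of torch (`np.can_cast` is the NumPy check these rules come from):

import numpy as np

print(np.can_cast(np.float64, np.float32, casting='safe'))       # False
print(np.can_cast(np.float64, np.float32, casting='same_kind'))  # True
print(np.can_cast(np.int16, np.float64, casting='safe'))         # True
print(np.can_cast(np.float32, np.int32, casting='same_kind'))    # False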
Example #3
    def data(self,
             dtype=None,
             device=None,
             casting='unsafe',
             rand=True,
             cutoff=None,
             dim=None,
             numpy=False):

        # --- sanity check before reading ---
        dtype = self.dtype if dtype is None else dtype
        dtype = dtypes.dtype(dtype)
        if not numpy and dtype.torch is None:
            raise TypeError(
                'Data type {} does not exist in PyTorch.'.format(dtype))

        # --- check that view is not empty ---
        if py.prod(self.shape) == 0:
            if numpy:
                return np.zeros(self.shape, dtype=dtype.numpy)
            else:
                return torch.zeros(self.shape,
                                   dtype=dtype.torch,
                                   device=device)

        # --- read native data ---
        slicer, perm, newdim = split_operation(self.permutation, self.slicer,
                                               'r')
        with self.tiffobj() as f:
            dat = self._read_data_raw(slicer, tiffobj=f)
        dat = dat.transpose(perm)[newdim]
        indtype = dtypes.dtype(self.dtype)

        # --- cutoff ---
        dat = volutils.cutoff(dat, cutoff, dim)

        # --- cast ---
        rand = rand and not indtype.is_floating_point
        if rand and not dtype.is_floating_point:
            tmpdtype = dtypes.float64
        else:
            tmpdtype = dtype
        dat, scale = volutils.cast(dat,
                                   tmpdtype.numpy,
                                   casting,
                                   with_scale=True)

        # --- random sample ---
        # uniform noise in the uncertainty interval
        if rand and not (scale == 1 and not dtype.is_floating_point):
            dat = volutils.addnoise(dat, scale)

        # --- final cast ---
        dat = volutils.cast(dat, dtype.numpy, 'unsafe')

        # convert to torch if needed
        if not numpy:
            dat = torch.as_tensor(dat, device=device)
        return dat
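
The `rand` branch implements dithering: integer-stored values are only known up to their quantization step, so uniform noise over that interval is added before the final cast. A standalone sketch of the idea (the step size here is hypothetical; `volutils.addnoise` derives it from the rescaling):

import numpy as np

rng = np.random.default_rng(0)
q = np.array([3, 7, 7, 8], dtype=np.int16).astype(np.float64)
step = 1.0                                  # one quantization bin
dithered = q + rng.uniform(0.0, step, q.shape)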
Example #4
def convert(inp,
            meta=None,
            dtype=None,
            casting='unsafe',
            format=None,
            output=None):
    """Convert a volume.

    Parameters
    ----------
    inp : str
        A path to a volume file.
    meta : sequence of (key, value)
        List of metadata fields to set.
    dtype : str or dtype, optional
        Output data type
    casting : {'unsafe', 'rescale', 'rescale_zero'}, default='unsafe'
        Casting method
    format : {'nii', 'nii.gz', 'mgh', 'mgz'}, optional
        Output format
    output : str, optional
        Output filename. Defaults to the input name, with the
        (possibly updated) extension.

    """

    meta = dict(meta or {})
    if dtype:
        meta['dtype'] = dtype
    fname = inp
    f = io.volumes.map(fname)
    d = f.data(numpy=True)

    dir, base, ext = py.fileparts(fname)
    if format:
        ext = format
        if ext == 'nifti':
            ext = 'nii'
        if ext[0] != '.':
            ext = '.' + ext
    output = output or '{dir}{sep}{base}{ext}'
    output = output.format(dir=dir or '.', sep=os.sep, base=base, ext=ext)

    odtype = meta.get('dtype', None) or f.dtype
    if ext in ('.mgh', '.mgz'):
        from nibabel.freesurfer.mghformat import _dtdefs
        odtype = dtypes.dtype(odtype)
        mgh_dtypes = [dtypes.dtype(dt[2]) for dt in _dtdefs]
        for mgh_dtype in mgh_dtypes:
            if odtype <= mgh_dtype:
                odtype = mgh_dtype
                break
        odtype = odtype.numpy
        meta['dtype'] = odtype

    io.save(d, output, like=f, casting=casting, **meta)
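
A hypothetical call (the file name is invented for illustration): convert a NIfTI volume to MGZ, rescaling intensities into the int16 range:

convert('t1.nii', dtype='int16', casting='rescale', format='mgz')
# writes ./t1.mgz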
Example #5
def astype(dat, dtype, casting='unsafe'):
    """Casting (without rescaling)

    See `np.ndarray.astype`.

    .. warning:: The output `dtype` must exist in `dat`'s framework.

    Parameters
    ----------
    dat : tensor or ndarray
    dtype : dtype
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}

    Returns
    -------
    dat : tensor or ndarray

    Raises
    ------
    TypeError

    """
    if torch.is_tensor(dat):
        return _torch_astype(dat, dtype, casting=casting)
    else:
        dtype = dtypes.dtype(dtype).numpy
        return dat.astype(dtype, casting=casting)
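
On the NumPy path this reduces to `ndarray.astype`, whose `casting` argument behaves as below (a quick runnable check):

import numpy as np

a = np.arange(4, dtype=np.int16)
b = a.astype(np.float64, casting='safe')   # fine: every int16 is representable
try:
    a.astype(np.int8, casting='safe')      # narrowing: refused under 'safe'
except TypeError as e:
    print(e)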
Example #6
File: array.py  Project: wyli/nitorch
def guess_dtype():
    dtype = metadata.get('dtype', None)
    if dtype is None and like is not None:
        dtype = like.dtype
    if dtype is None:
        dtype = dat.dtype
    return dtypes.dtype(dtype).numpy
Example #7
    def fdata(self, dtype=None, device=None, rand=False, cutoff=None,
              dim=None, numpy=False):
        """Load the scaled array in memory

        This function differs from `data` in several ways:
            * The output data type should be a floating point type.
            * If an affine scaling (slope, intercept) is defined in the
              file, it is applied to the data.
            * The default output data type is `torch.get_default_dtype()`.

        Parameters
        ----------
        dtype : dtype_like, optional
            Output data type. By default, use `torch.get_default_dtype()`.
            Should be a floating point type.
        device : torch.device, default='cpu'
            Output device.
        rand : bool, default=False
            If the on-disk dtype is not floating point, sample noise
            in the uncertainty interval.
        cutoff : float or (float, float), default=(0, 1)
            Percentile cutoff. If only one value is provided, it is
            assumed to relate to the upper percentile.
        dim : int or list[int], optional
            Dimensions along which to compute percentiles.
            By default, they are computed on the flattened array.
        numpy : bool, default=False
            Return a numpy array rather than a torch tensor.

        Returns
        -------
        dat : tensor[dtype]

        """
        # --- sanity check ---
        dtype = torch.get_default_dtype() if dtype is None else dtype
        info = dtypes.dtype(dtype)
        if not info.is_floating_point:
            raise TypeError('Output data type should be a floating point '
                            'type but got {}.'.format(dtype))

        # --- get unscaled data ---
        dat = self.data(dtype=dtype, device=device, rand=rand,
                        cutoff=cutoff, dim=dim, numpy=numpy)

        # --- scale ---
        if self.slope != 1:
            dat *= float(self.slope)
        if self.inter != 0:
            dat += float(self.inter)

        return dat
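
A minimal sketch of the scaling `fdata` applies on top of `data`, with made-up slope/intercept values standing in for the file header's:

import numpy as np

raw = np.array([0, 100, 200], dtype=np.int16)   # unscaled on-disk values
slope, inter = 0.5, -10.0                        # hypothetical header fields
fdat = raw.astype(np.float32) * slope + inter    # -> [-10., 40., 90.]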
Example #8
    def set_fdata(self, dat):
        """Write (partial) scaled data to disk.

        Parameters
        ----------
        dat : tensor
            Tensor to write on disk. It should have shape `self.shape`
            and a floating point data type.

        Returns
        -------
        self : type(self)

        """
        # --- sanity check ---
        info = dtypes.dtype(dat.dtype)
        if not info.is_floating_point:
            raise TypeError('Input data type should be a floating point '
                            'type but got {}.'.format(dat.dtype))
        if dat.shape != self.shape:
            raise TypeError('Expected input shape {} but got {}.'.format(
                self.shape, dat.shape))

        # --- detach ---
        if torch.is_tensor(dat):
            dat = dat.detach()

        # --- unscale ---
        if self.inter != 0 or self.slope != 1:
            dat = dat.clone() if torch.is_tensor(dat) else dat.copy()
        if self.inter != 0:
            dat -= float(self.inter)
        if self.slope != 1:
            dat /= float(self.slope)

        # --- set unscaled data ---
        self.set_data(dat)

        return self
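
The unscale step is the exact inverse of the scaling shown in the `fdata` sketch above (reusing the hypothetical `fdat`, `slope` and `inter` from there):

stored = (fdat - inter) / slope   # back to the on-disk value range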
Example #9
File: array.py  Project: wyli/nitorch
    def savef_new(cls, dat, file_like, like=None, **metadata):

        if isinstance(dat, MappedArray):
            if like is None:
                like = dat
            dat = dat.fdata(numpy=True)

        # sanity check
        dtype = dtypes.dtype(dat.dtype)
        if not dtype.is_floating_point:
            raise TypeError('Input data type should be a floating point '
                            'type but got {}.'.format(dat.dtype))

        # detach
        if torch.is_tensor(dat):
            dat = dat.detach().cpu().numpy()

        # defer
        cls.save_new(dat,
                     file_like,
                     like=like,
                     casting='unsafe',
                     _savef=True,
                     **metadata)
Example #10
File: array.py  Project: wyli/nitorch
    def save_new(cls,
                 dat,
                 file_like,
                 like=None,
                 casting='unsafe',
                 _savef=False,
                 **metadata):

        if isinstance(dat, MappedArray):
            if like is None:
                like = dat
            dat = dat.data(numpy=True)
        if torch.is_tensor(dat):
            dat = dat.detach().cpu()
        dat = np.asanyarray(dat)
        if like is not None:
            like = map_array(like)

        # guess data type:
        def guess_dtype():
            dtype = metadata.get('dtype', None)
            if dtype is None and like is not None:
                dtype = like.dtype
            if dtype is None:
                dtype = dat.dtype
            return dtypes.dtype(dtype).numpy

        dtype = guess_dtype()

        def guess_format():
            # 1) from extension
            ok_klasses = []
            if isinstance(file_like, str):
                base, ext = os.path.splitext(file_like)
                if ext.lower() == '.gz':
                    base, ext = os.path.splitext(base)
                ok_klasses = [
                    klass for klass in all_image_classes
                    if ext in klass.valid_exts
                ]
                if len(ok_klasses) == 1:
                    return ok_klasses[0]
            # 2) from like
            if isinstance(like, BabelArray):
                return type(like._image)
            # 3) from extension (if conflict)
            if len(ok_klasses) != 0:
                return ok_klasses[0]
            # 4) fallback to nifti-1
            return nib.Nifti1Image

        format = guess_format()

        # build header
        if isinstance(like, BabelArray):
            # defer metadata conversion to nibabel
            header = format.header_class.from_header(
                like._image.dataobj._header)
        else:
            header = format.header_class()
            if like is not None:
                # copy generic metadata
                like_metadata = like.metadata()
                like_metadata.update(metadata)
                metadata = like_metadata
        # set shape now so that we can set zooms/etc
        header.set_data_shape(dat.shape)
        header = metadata_to_header(header, metadata)

        # check endianness
        disk_byteorder = header.endianness
        data_byteorder = dtype.byteorder
        if disk_byteorder == '=':
            disk_byteorder = '<' if sys.byteorder == 'little' else '>'
        if data_byteorder == '=':
            data_byteorder = '<' if sys.byteorder == 'little' else '>'
        if disk_byteorder != data_byteorder:
            dtype = dtype.newbyteorder()

        # set scale
        if hasattr(header, 'set_slope_inter'):
            slope, inter = header.get_slope_inter()
            if slope is None:
                slope = 1
            if inter is None:
                inter = 0
            header.set_slope_inter(slope, inter)

        # unscale
        if _savef:
            assert dtypes.dtype(dat.dtype).is_floating_point
            slope, inter = header.get_slope_inter()
            if inter not in (0, None) or slope not in (1, None):
                dat = dat.copy()
            if inter not in (0, None):
                dat -= inter
            if slope not in (1, None):
                dat /= slope

        # cast + setdtype
        dat = volutils.cast(dat, dtype, casting)
        header.set_data_dtype(dat.dtype)

        # create image object
        image = format(dat, affine=None, header=header)

        # write everything
        file_map = format.filespec_to_file_map(file_like)
        fmap_header = file_map.get('header', file_map.get('image'))
        fmap_image = file_map.get('image')
        fmap_footer = file_map.get('footer', file_map.get('image'))
        fhdr = fmap_header.get_prepare_fileobj('wb')
        if hasattr(header, 'writehdr_to'):
            header.writehdr_to(fhdr)
        elif hasattr(header, 'write_to'):
            header.write_to(fhdr)
        if fmap_image == fmap_header:
            fimg = fhdr
        else:
            fimg = fmap_image.get_prepare_fileobj('wb')
        array_to_file(dat,
                      fimg,
                      dtype,
                      offset=header.get_data_offset(),
                      order=image.ImageArrayProxy.order)
        if fmap_image == fmap_footer:
            fftr = fimg
        else:
            fftr = fmap_footer.get_prepare_fileobj('wb')
        if hasattr(header, 'writeftr_to'):
            header.writeftr_to(fftr)
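
The endianness check above first resolves '=' (native order) to an explicit '<' or '>' before comparing; `np.dtype.newbyteorder()` with no argument then swaps the byte order. A standalone illustration:

import sys
import numpy as np

native = '<' if sys.byteorder == 'little' else '>'
dt = np.dtype('<i2')                 # little-endian int16
print(dt.newbyteorder())             # dtype('>i2'): byte order swapped
print(np.dtype('=i4').byteorder)     # '=' until resolved against `native`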
Example #11
    def data(self,
             dtype=None,
             device=None,
             casting='unsafe',
             rand=True,
             missing=None,
             cutoff=None,
             dim=None,
             numpy=False):
        """Load the array in memory

        Parameters
        ----------
        dtype : type or torch.dtype or np.dtype, optional
            Output data type. By default, keep the on-disk data type.
        device : torch.device, default='cpu'
            Output device.
        rand : bool, default=True
            If the on-disk dtype is not floating point, sample noise
            in the uncertainty interval.
        missing : float or sequence[float], optional
            Value(s) that correspond to missing values.
            No noise is added to them, and they are converted to NaNs
            (if possible) or zero (otherwise).
        cutoff : float or (float, float), default=(0, 1)
            Percentile cutoff. If only one value is provided, it is
            assumed to relate to the upper percentile.
        dim : int or list[int], optional
            Dimensions along which to compute percentiles.
            By default, they are computed on the flattened array.
        casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe', 'rescale'}, default='unsafe'
            Controls what kind of data casting may occur:
                * 'no': the data types should not be cast at all.
                * 'equiv': only byte-order changes are allowed.
                * 'safe': only casts which can preserve values are allowed.
                * 'same_kind': only safe casts or casts within a kind,
                  like float64 to float32, are allowed.
                * 'unsafe': any data conversions may be done.
                * 'rescale': the input data is rescaled to match the dynamic
                  range of the output type. The minimum value in the data
                  is mapped to the minimum value of the data type and the
                  maximum value in the data is mapped to the maximum value
                  of the data type.
                * 'rescale_zero': the input data is rescaled to match the
                  dynamic range of the output type, but ensuring that
                  zero maps to zero.
                  > If the data is signed and cast to a signed datatype,
                    zero maps to zero, and the scaling is chosen so that
                    both the maximum and minimum value in the data fit
                    in the output dynamic range.
                  > If the data is signed and cast to an unsigned datatype,
                    negative values "wrap around" (as with an unsafe cast).
                  > If the data is unsigned and cast to a signed datatype,
                    values are kept positive (the negative range is unused).
        numpy : bool, default=False
            Return a numpy array rather than a torch tensor.

        Returns
        -------
        dat : tensor[dtype]

        """
        # --- sanity check before reading ---
        dtype = self.dtype if dtype is None else dtype
        dtype = dtypes.dtype(dtype)
        if not numpy and dtype.torch is None:
            raise TypeError(
                'Data type {} does not exist in PyTorch.'.format(dtype))

        # --- load raw data ---
        dat = self.raw_data()

        # --- move to tensor ---
        indtype = dtypes.dtype(self.dtype)
        if not numpy:
            tmpdtype = dtypes.dtype(indtype.torch_upcast)
            dat = dat.astype(tmpdtype.numpy)
            dat = torch.as_tensor(dat,
                                  dtype=indtype.torch_upcast,
                                  device=device)

        # --- mask of missing values ---
        if missing is not None:
            missing = volutils.missing(dat, missing)
            present = ~missing
        else:
            present = (Ellipsis, )

        # --- cutoff ---
        if cutoff is not None:
            dat[present] = volutils.cutoff(dat[present], cutoff, dim)

        # --- cast + rescale ---
        rand = rand and not indtype.is_floating_point
        tmpdtype = dtypes.float64 if (
            rand and not dtype.is_floating_point) else dtype
        dat, scale = volutils.cast(dat,
                                   tmpdtype,
                                   casting,
                                   indtype=indtype,
                                   returns='dat+scale',
                                   mask=present)

        # --- random sample ---
        # uniform noise in the uncertainty interval
        if rand and not (scale == 1 and not dtype.is_floating_point):
            dat[present] = volutils.addnoise(dat[present], scale)

        # --- final cast ---
        dat = volutils.cast(dat, dtype, 'unsafe')

        # --- replace missing values ---
        if missing is not None:
            if dtype.is_floating_point:
                dat[missing] = float('nan')
            else:
                dat[missing] = 0
        return dat
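
A worked example of the 'rescale_zero' rule documented above: a single positive scale is chosen so that both extrema fit the output range, and zero stays at zero (the values and the int8 target are invented for illustration):

import numpy as np

dat = np.array([-2.0, 0.0, 5.0])
out = np.iinfo(np.int8)                   # [-128, 127]
scale = min(abs(out.max / dat.max()),     # 127 / 5 = 25.4
            abs(out.min / dat.min()))     # 128 / 2 = 64.0
print((dat * scale).astype(np.int8))      # [-50   0 127]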
Example #12
File: main.py  Project: balbasty/nitorch
def chunk(inp, chunk_size=1, dim=-1, output=None, transform=None):
    """Unstack a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    chunk_size : int, default=1
        Size of one chunk.
    dim : int, default=-1
        Dimension along which to unstack.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        `i` is the coordinate (starting at 1) of the slice.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default.

    Returns
    -------
    output : list[str or (tensor, tensor)]
        If the input is a path, the output paths are returned.
        Else, the unstacked data and orientation matrices are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(dtype=dtype(f.dtype).torch_upcast), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}{ext}'
        dir, base, ext = py.fileparts(fname)

    dat, aff0 = inp
    ndim = aff0.shape[-1] - 1
    nchunks = math.ceil(dat.shape[dim] / chunk_size)
    if dim >= ndim:
        # we didn't touch the spatial dimensions
        aff = [aff0.clone() for _ in range(nchunks)]
    else:
        aff = []
        slicer = [slice(None) for _ in range(ndim)]
        shape = dat.shape[:ndim]
        for z in range(nchunks):
            slicer[dim] = slice(z * chunk_size, (z + 1) * chunk_size)
            aff1, _ = spatial.affine_sub(aff0, shape, tuple(slicer))
            aff.append(aff1)
    dat = torch.split(dat, chunk_size, dim)
    dat = list(zip(dat, aff))

    formatted_output = []
    if output:
        output = py.make_list(output, len(dat))
        formatted_output = []
        for i, ((dat1, aff1), out1) in enumerate(zip(dat, output)):
            if is_file:
                out1 = out1.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   i=i + 1)
                io.volumes.save(dat1, out1, like=fname, affine=aff1)
            else:
                out1 = out1.format(sep=os.path.sep, i=i + 1)
                io.volumes.save(dat1, out1, affine=aff1)
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(dat))
        for i, ((_, aff1), trf1) in enumerate(zip(dat, transform)):
            if is_file:
                trf1 = trf1.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   i=i + 1)
            else:
                trf1 = trf1.format(sep=os.path.sep, i=i + 1)
            io.transforms.savef(torch.eye(4), trf1, source=aff0, target=aff1)

    if is_file:
        return formatted_output
    else:
        return dat, aff
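
Hypothetical usage (the file name is invented): split a 3D volume into slabs of 8 slices along its last spatial axis:

paths = chunk('t1.nii', chunk_size=8, dim=-1)
# -> ['./t1.1.nii', './t1.2.nii', ...]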
Example #13
def metadata_to_header(header, metadata, shape=None, dtype=None):
    """Register metadata into a nibabel Header object

    Parameters
    ----------
    header : Header or type
        Original header _or_ Header class.
        If it is a class, default values are used to populate the
        missing fields.
    metadata : dict
        Dictionary of metadata

    Returns
    -------
    header : Header

    """
    is_empty = type(header) is type
    if is_empty:
        header = header()

    # --- generic fields ---

    if metadata.get('voxel_size', None) is not None:
        header = set_voxel_size(header, metadata['voxel_size'], shape)

    if metadata.get('affine', None) is not None:
        header = set_affine(header, metadata['affine'], shape)

    if (metadata.get('slope', None) is not None or
        metadata.get('inter', None) is not None):
        slope = metadata.get('slope', 1.)
        inter = metadata.get('inter', None)
        if isinstance(header, (Spm99AnalyzeHeader, Nifti1Header)):
            header.set_slope_inter(slope, inter)
        else:
            if slope not in (1, None) or inter not in (0, None):
                format_name = type(header).__name__.split('Header')[0]
                warn('Format {} does not accept intensity transforms. '
                     'It will be discarded.'.format(format_name),
                     RuntimeWarning)

    if (metadata.get('time_step', None) is not None or
        metadata.get('tr', None) is not None):
        time_step = metadata.get('time_step', None) or metadata['tr']
        if isinstance(header, (MGHHeader, Nifti1Header)):
            zooms = header.get_zooms()[:3]
            zooms = (*zooms, time_step)
            header.set_zooms(zooms)
        else:
            warn('Format {} does not accept time steps. '
                 'It will be discarded.'.format(type(header).__name__),
                 RuntimeWarning)

    # TODO: time offset / intent for nifti format
    # TODO: te/ti/fa for MGH format
    #       maybe also nifti from description field?

    if dtype is not None or metadata.get('dtype', None) is not None:
        dtype = dtype or metadata.get('dtype', None)
        dtype = dtypes.dtype(dtype).numpy
        header.set_data_dtype(dtype)

    return header
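
A hedged usage sketch (assuming nibabel is available and this module's helpers such as `set_voxel_size` are in scope): passing the header class rather than an instance lets defaults fill the remaining fields:

from nibabel import Nifti1Header

header = metadata_to_header(Nifti1Header,
                            {'voxel_size': (1.0, 1.0, 1.2),
                             'dtype': 'float32'},
                            shape=(192, 256, 256))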
Example #14
def metadata_to_header(header, metadata, shape=None, dtype=None):
    """Register metadata into a nibabel Header object

    Parameters
    ----------
    header : Header or type
        Original header _or_ Header class.
        If it is a class, default values are used to populate the
        missing fields.
    metadata : dict
        Dictionary of metadata

    Returns
    -------
    header : Header

    """
    is_empty = type(header) is type
    if is_empty:
        header = header()

    # --- generic fields ---

    if metadata.get('voxel_size_unit', None) is not None:
        val = metadata['voxel_size_unit']
        if hasattr(header, 'set_xyzt_units'):
            header.set_xyzt_units(val, None)
        else:
            format_name = type(header).__name__.split('Header')[0]
            warn('Format {} does not accept voxel size units. '
                 'It will be discarded.'.format(format_name),
                 RuntimeWarning)

    if metadata.get('voxel_size', None) is not None:
        val = metadata['voxel_size']
        if isinstance(val, str):
            val = ast.literal_eval(val)
        header = set_voxel_size(header, val, shape)

    if metadata.get('affine', None) is not None:
        val = metadata['affine']
        if isinstance(val, str):
            val = ast.literal_eval(val)
        header = set_affine(header, val, shape)

    if (metadata.get('slope', None) is not None or
        metadata.get('inter', None) is not None):
        slope = metadata.get('slope', 1.)
        inter = metadata.get('inter', None)
        if isinstance(slope, str):
            slope = float(ast.literal_eval(slope))
        if isinstance(inter, str):
            inter = float(ast.literal_eval(inter))
        if isinstance(header, (Spm99AnalyzeHeader, Nifti1Header)):
            header.set_slope_inter(slope, inter)
        else:
            if slope not in (1, None) or inter not in (0, None):
                format_name = type(header).__name__.split('Header')[0]
                warn('Format {} does not accept intensity transforms. '
                     'It will be discarded.'.format(format_name),
                     RuntimeWarning)

    if (metadata.get('time_step', None) is not None or
        metadata.get('tr', None) is not None):
        time_step = metadata.get('time_step', None) or metadata['tr']
        unit = None
        if isinstance(time_step, str):
            if time_step.endswith('sec'):
                time_step = time_step[:-3]
                unit = 'sec'
            elif time_step.endswith('ms'):
                time_step = time_step[:-2]
                unit = 'ms'
            elif time_step.endswith('s'):
                time_step = time_step[:-1]
                unit = 'sec'
            time_step = float(ast.literal_eval(time_step))
        unit = unit or metadata.get('te_unit', 'sec')
        if isinstance(header, (MGHHeader, Nifti1Header)):
            if unit == 'sec':
                time_step = time_step * 1e3  # TODO: unit for niftis?
            zooms = header.get_zooms()[:3]
            zooms = (*zooms, time_step)
            try:
                # only possible if 4-th dimension is explicit
                header.set_zooms(zooms)
            except HeaderDataError:
                if isinstance(header, MGHHeader):
                    # set tr manually
                    header['tr'] = time_step

        else:
            warn('Format {} does not accept time steps. '
                 'It will be discarded.'.format(type(header).__name__),
                 RuntimeWarning)

    # TODO: time offset / intent for nifti format

    # TODO: te/ti/fa for MGH format
    #       maybe also nifti from description field?
    if metadata.get('te', None) is not None:
        val = metadata['te']
        unit = None
        if isinstance(val, str):
            if val.endswith('sec'):
                val = val[:-3]
                unit = 'sec'
            elif val.endswith('ms'):
                val = val[:-2]
                unit = 'ms'
            elif val.endswith('s'):
                val = val[:-1]
                unit = 'sec'
            val = float(ast.literal_eval(val))
        unit = unit or metadata.get('te_unit', 'sec')
        if isinstance(header, MGHHeader):
            if unit == 'sec':
                val = val * 1e3
            header['te'] = val

    if metadata.get('ti', None) is not None:
        val = metadata['ti']
        unit = None
        if isinstance(val, str):
            if val.endswith('sec'):
                val = val[:-3]
                unit = 'sec'
            elif val.endswith('ms'):
                val = val[:-2]
                unit = 'ms'
            elif val.endswith('s'):
                val = val[:-1]
                unit = 'sec'
            val = float(ast.literal_eval(val))
        unit = unit or metadata.get('ti_unit', 'sec')
        if isinstance(header, MGHHeader):
            if unit == 'sec':
                val = val * 1e3
            header['ti'] = val

    if metadata.get('fa', None) is not None:
        val = metadata['fa']
        unit = None
        if isinstance(val, str):
            if val.endswith('deg'):
                val = val[:-3]
                unit = 'deg'
            elif val.endswith('rad'):
                val = val[:-3]
                unit = 'rad'
            val = float(ast.literal_eval(val))
        unit = unit or metadata.get('fa_unit', 'deg')
        if isinstance(header, MGHHeader):
            if unit == 'deg':
                val = val * constants.pi / 180.
            header['flip_angle'] = val

    if dtype is not None or metadata.get('dtype', None) is not None:
        dtype = dtype or metadata.get('dtype', None)
        dtype = dtypes.dtype(dtype).numpy
        header.set_data_dtype(dtype)

    return header
Example #15
def cast(dat, dtype, casting='unsafe', returns='dat', indtype=None, mask=None):
    """Cast an array to a given type.

    Parameters
    ----------
    dat : tensor or ndarray
        Input array
    dtype : dtype
        Output data type (should have the proper on-disk byte order)
    indtype : dtype, default=dat.dtype
        Original input dtype
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe',
               'rescale', 'rescale_zero'}, default='unsafe'
        Casting method:
        * 'rescale' makes sure that the dynamic range in the array
          matches the range of the output data type
        * 'rescale_zero' does the same, but keeps the mapping `0 -> 0`
          intact.
        * all other options are implemented in numpy. See `np.can_cast`.
    returns : [combination of] {'dat', 'scale', 'offset'}, default='dat'
        Components to return, joined by '+' (e.g. 'dat+scale').
    mask : tensor or ndarray, optional
        Mask of voxels to use to compute min/max

    Returns
    -------
    dat : tensor or ndarray, if 'dat' in `returns`
    scale : float, if 'scale' in `returns`
    offset : float, if 'offset' in `returns`

    """
    scale = 1.
    offset = 0.
    indtype = dtypes.dtype(indtype or dat.dtype)
    outdtype = dtypes.dtype(dtype)

    if mask is None:
        mask = (Ellipsis, )

    if casting.startswith('rescale') and not outdtype.is_floating_point:
        # rescale
        # TODO: I am using float64 as an intermediate to cast
        #       Maybe I can do things in a nicer / more robust way
        minval = astype(min(dat[mask]), dtypes.float64)
        maxval = astype(max(dat[mask]), dtypes.float64)
        if 'dat' in returns:
            if not dtypes.equivalent(indtype, dtypes.float64):
                dat = dat.astype(np.float64)
            if not writeable(dat):
                dat = copy(dat)
        if casting == 'rescale':
            # affine map sending [minval, maxval] onto [outdtype.min, outdtype.max]
            scale = (outdtype.max - outdtype.min) / (maxval - minval)
            offset = outdtype.min - scale * minval
            if 'dat' in returns:
                dat *= scale
                dat += offset
        else:
            assert casting == 'rescale_zero'
            if minval < 0 and not outdtype.is_signed:
                warn("Converting negative values to an unsigned datatype")
            scale = min(
                abs(outdtype.max / maxval) if maxval else float('inf'),
                abs(outdtype.min / minval) if minval else float('inf'))
            if 'dat' in returns:
                dat *= scale
        indtype = dtypes.dtype(dat.dtype)
        casting = 'unsafe'

    # unsafe cast
    if 'dat' in returns and indtype != outdtype:
        dat = astype(dat, outdtype, casting=casting)

    output = []
    for component in returns.split('+'):
        if component == 'dat':
            output.append(dat)
        elif component == 'scale':
            output.append(scale)
        elif component == 'offset':
            output.append(offset)
        else:
            output.append(None)
    return tuple(output) if len(output) > 1 else output[0] if output else None
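
Hypothetical usage of the `returns` spec: request the data together with the affine intensity transform that was applied (the input array and dtype are invented):

import numpy as np

dat = np.random.rand(8, 8).astype(np.float32)   # hypothetical input
dat16, scale, offset = cast(dat, np.int16, casting='rescale',
                            returns='dat+scale+offset')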