def cast(dat, dtype, casting='unsafe', with_scale=False):
    """Cast an array to a given type.

    Parameters
    ----------
    dat : tensor or ndarray
        Input array
    dtype : dtype
        Output data type (should have the proper on-disk byte order)
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe', 'rescale', 'rescale_zero'}, default='unsafe'
        Casting method:
        * 'rescale' makes sure that the dynamic range in the array
          matches the range of the output data type: the minimum of the
          data maps to the minimum of the output type and the maximum of
          the data maps to its maximum (`out = dat * scale + offset`).
        * 'rescale_zero' does the same, but keeps the mapping `0 -> 0`
          intact (`out = dat * scale`). For unsigned output types, negative
          values are not taken into account when computing the scale
          (a warning is emitted) and are then cast unsafely.
        * For floating point output types, 'rescale' and 'rescale_zero'
          behave like 'unsafe' (no rescaling is applied).
        * all other options are implemented in numpy. See `np.can_cast`.
    with_scale : bool, default=False
        Return the scaling applied, if any.

    Returns
    -------
    dat : tensor or ndarray
    scale : float, if `with_scale`

    Raises
    ------
    ValueError
        If `casting` starts with 'rescale' but is not a known rescale rule.
    """
    scale = 1.
    indtype = dtypes.dtype(dat.dtype)
    outdtype = dtypes.dtype(dtype)

    if casting.startswith('rescale') and casting not in ('rescale',
                                                         'rescale_zero'):
        raise ValueError('Unknown casting rule: {}'.format(casting))

    if casting.startswith('rescale') and outdtype.is_floating_point:
        # Floating point outputs need no rescaling, and numpy's `astype`
        # does not know the 'rescale*' rules.
        casting = 'unsafe'

    if casting.startswith('rescale'):
        # TODO: I am using float64 as an intermediate to cast
        # Maybe I can do things in a nicer / more robust way
        minval = float(astype(min(dat), dtypes.float64))
        maxval = float(astype(max(dat), dtypes.float64))
        if not dtypes.equivalent(indtype, dtypes.float64):
            # module-level `astype` handles both tensors and arrays
            dat = astype(dat, dtypes.float64)
        if not writeable(dat):
            dat = copy(dat)

        if casting == 'rescale':
            # linear map [minval, maxval] -> [outmin, outmax]
            outmin = float(outdtype.min)
            outmax = float(outdtype.max)
            if maxval > minval:
                scale = (outmax - outmin) / (maxval - minval)
            # constant data keeps scale == 1 and is shifted to outmin
            offset = outmin - minval * scale
            dat *= scale
            dat += offset
        else:  # 'rescale_zero'
            if minval < 0 and not outdtype.is_signed:
                warn("Converting negative values to an unsigned datatype")
            scale = float('inf')
            if maxval > 0:
                scale = float(outdtype.max) / maxval
            if minval < 0 and outdtype.is_signed:
                neg_scale = float(outdtype.min) / minval
                if neg_scale < scale:
                    scale = neg_scale
            if scale == float('inf'):
                # no usable range (e.g. all-zero data): avoid 0 * inf = nan
                scale = 1.
            dat *= scale

        indtype = dtypes.dtype(dat.dtype)
        casting = 'unsafe'

    # unsafe cast
    if indtype != outdtype:
        dat = astype(dat, outdtype, casting=casting)

    return (dat, scale) if with_scale else dat
def _torch_astype(dat, dtype, casting='unsafe'):
    """Equivalent to `np.astype` but for torch tensors

    Parameters
    ----------
    dat : torch.tensor
        Input tensor.
    dtype : dtype
        Output data type. It must have a PyTorch equivalent.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}
        Casting rule (case-insensitive). Other values are not checked and
        behave like 'unsafe'.

    Returns
    -------
    dat : torch.as_tensor

    Raises
    ------
    TypeError
        If `dat` is not a tensor, if `dtype` does not exist in PyTorch,
        or if the cast is not allowed by the `casting` rule.
    """
    def error(indtype, outdtype):
        # `casting` is looked up when `error` is called, i.e. after it
        # has been lower-cased below.
        raise TypeError("Cannot cast tensor from dtype('{}') to "
                        "dtype('{}') according to the rule '{}'.".format(
                            indtype, outdtype, casting))

    if not torch.is_tensor(dat):
        raise TypeError('Expected torch tensor but got {}'.format(type(dat)))
    outdtype = dtypes.dtype(dtype)
    if outdtype.torch is None:
        raise TypeError('Data type {} does not exist in pytorch'.format(dtype))
    indtype = dtypes.dtype(dat.dtype)

    casting = casting.lower()
    if ((casting == 'no' and indtype != outdtype) or
            (casting == 'equiv' and not dtypes.equivalent(indtype, outdtype)) or
            (casting == 'safe' and not indtype <= outdtype) or
            (casting == 'same_kind' and not dtypes.same_kind(indtype, outdtype))):
        error(indtype, outdtype)

    return dat.to(outdtype.torch)
def data(self, dtype=None, device=None, casting='unsafe', rand=True,
         cutoff=None, dim=None, numpy=False):
    """Load the (sliced and permuted) array in memory.

    Parameters
    ----------
    dtype : dtype_like, optional
        Output data type. By default, `self.dtype`.
    device : torch.device, optional
        Output device (only used when `numpy=False`).
    casting : str, default='unsafe'
        Casting rule forwarded to `volutils.cast`.
    rand : bool, default=True
        If the on-disk dtype is not floating point, add noise with
        `volutils.addnoise`.
    cutoff : float or (float, float), optional
        Forwarded to `volutils.cutoff`.
    dim : int or list[int], optional
        Forwarded to `volutils.cutoff`.
    numpy : bool, default=False
        Return a numpy array rather than a torch tensor.

    Returns
    -------
    dat : tensor or ndarray

    Raises
    ------
    TypeError
        If `numpy=False` and `dtype` has no PyTorch equivalent.
    """
    out_dtype = dtypes.dtype(self.dtype if dtype is None else dtype)
    if out_dtype.torch is None and not numpy:
        raise TypeError(
            'Data type {} does not exist in PyTorch.'.format(out_dtype))

    # Empty view: nothing to read from disk.
    if py.prod(self.shape) == 0:
        if numpy:
            return np.zeros(self.shape, dtype=out_dtype.numpy)
        return torch.zeros(self.shape, dtype=out_dtype.torch, device=device)

    # Read the native data, then apply the permutation / new axes.
    slicer, perm, newdim = split_operation(self.permutation, self.slicer, 'r')
    with self.tiffobj() as tiff:
        arr = self._read_data_raw(slicer, tiffobj=tiff)
    arr = arr.transpose(perm)[newdim]
    disk_dtype = dtypes.dtype(self.dtype)

    arr = volutils.cutoff(arr, cutoff, dim)

    # Noise only makes sense for integer on-disk data; go through float64
    # when the final output is itself an integer type.
    add_noise = rand and not disk_dtype.is_floating_point
    if add_noise and not out_dtype.is_floating_point:
        work_dtype = dtypes.float64
    else:
        work_dtype = out_dtype
    arr, scale = volutils.cast(arr, work_dtype.numpy, casting,
                               with_scale=True)

    if add_noise:
        # uniform noise in the uncertainty interval
        unscaled_int = scale == 1 and not out_dtype.is_floating_point
        if not unscaled_int:
            arr = volutils.addnoise(arr, scale)

    arr = volutils.cast(arr, out_dtype.numpy, 'unsafe')

    if numpy:
        return arr
    return torch.as_tensor(arr, device=device)
def convert(inp, meta=None, dtype=None, casting='unsafe', format=None,
            output=None):
    """Convert a volume.

    Parameters
    ----------
    inp : str
        A path to a volume file.
    meta : sequence of (key, value)
        List of metadata fields to set.
    dtype : str or dtype, optional
        Output data type
    casting : {'unsafe', 'rescale', 'rescale_zero'}, default='unsafe'
        Casting method
    format : {'nii', 'nii.gz', 'mgh', 'mgz'}, optional
        Output format. 'nifti' is accepted as an alias for 'nii'.
        By default, the extension of the input file is kept.
    output : str, optional
        Output filename. It may contain the placeholders '{dir}', '{sep}',
        '{base}' and '{ext}', which are filled with the directory, path
        separator, base name and (possibly overridden) extension of the
        input. Default: '{dir}{sep}{base}{ext}'. Note that, when `format`
        is not given, this default is built from the input's own
        directory, base name and extension.
    """
    meta = dict(meta or {})
    # `is not None` rather than truthiness: `np.dtype` objects are falsy
    if dtype is not None:
        meta['dtype'] = dtype

    fname = inp
    f = io.volumes.map(fname)
    d = f.data(numpy=True)

    dirname, base, ext = py.fileparts(fname)
    if format:
        ext = format
    if ext == 'nifti':
        ext = 'nii'
    if ext and not ext.startswith('.'):
        ext = '.' + ext
    output = output or '{dir}{sep}{base}{ext}'
    output = output.format(dir=dirname or '.', sep=os.sep, base=base, ext=ext)

    odtype = meta.get('dtype', None)
    if odtype is None:
        odtype = f.dtype
    if ext in ('.mgh', '.mgz'):
        # MGH only supports a few data types: pick the first one that can
        # safely hold the requested type.
        from nibabel.freesurfer.mghformat import _dtdefs
        odtype = dtypes.dtype(odtype)
        mgh_dtypes = [dtypes.dtype(dt[2]) for dt in _dtdefs]
        for mgh_dtype in mgh_dtypes:
            if odtype <= mgh_dtype:
                odtype = mgh_dtype
                break
        odtype = odtype.numpy
        meta['dtype'] = odtype

    io.save(d, output, like=f, casting=casting, **meta)
def astype(dat, dtype, casting='unsafe'):
    """Cast without rescaling (see `np.ndarray.astype`).

    ..warning:: The output `dtype` must exist in the framework of `dat`
                (torch for tensors, numpy otherwise).

    Parameters
    ----------
    dat : tensor or ndarray
        Input array.
    dtype : dtype
        Output data type.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}
        Casting rule.

    Returns
    -------
    dat : tensor or ndarray

    Raises
    ------
    TypeError
    """
    if not torch.is_tensor(dat):
        np_dtype = dtypes.dtype(dtype).numpy
        return dat.astype(np_dtype, casting=casting)
    return _torch_astype(dat, dtype, casting=casting)
def guess_dtype():
    """Pick the output numpy dtype.

    Priority: `metadata['dtype']`, then `like.dtype` (if `like` is not
    None), then `dat.dtype`. `metadata`, `like` and `dat` are free names
    resolved from the enclosing/global scope.
    """
    candidate = metadata.get('dtype', None)
    if candidate is None and like is not None:
        candidate = like.dtype
    if candidate is None:
        candidate = dat.dtype
    return dtypes.dtype(candidate).numpy
def fdata(self, dtype=None, device=None, rand=False, cutoff=None,
          dim=None, numpy=False):
    """Load the scaled array in memory

    Compared to `data`:
        * the output data type must be a floating point type;
        * the affine intensity transform (`self.slope`, `self.inter`)
          is applied to the loaded values;
        * the default output data type is `torch.get_default_dtype()`.

    Parameters
    ----------
    dtype : dtype_like, optional
        Floating point output data type.
        Default: `torch.get_default_dtype()`.
    device : torch.device, optional
        Output device.
    rand : bool, default=False
        If the on-disk dtype is not floating point, sample noise
        in the uncertainty interval.
    cutoff : float or (float, float), optional
        Percentile cutoff, forwarded to `data`.
    dim : int or list[int], optional
        Dimensions along which to compute percentiles, forwarded to `data`.
    numpy : bool, default=False
        Return a numpy array rather than a torch tensor.

    Returns
    -------
    dat : tensor[dtype]

    Raises
    ------
    TypeError
        If `dtype` is not a floating point type.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if not dtypes.dtype(dtype).is_floating_point:
        raise TypeError('Output data type should be a floating point '
                        'type but got {}.'.format(dtype))

    values = self.data(dtype=dtype, device=device, rand=rand,
                       cutoff=cutoff, dim=dim, numpy=numpy)

    # apply the intensity transform in place on the freshly loaded array
    slope = self.slope
    if slope != 1:
        values *= float(slope)
    inter = self.inter
    if inter != 0:
        values += float(inter)
    return values
def set_fdata(self, dat):
    """Write (partial) scaled data to disk.

    The affine intensity transform (`self.inter`, `self.slope`) is
    inverted before handing the values to `set_data`. The input is copied
    first if the transform is not the identity, so the caller's array is
    not modified.

    Parameters
    ----------
    dat : tensor or ndarray
        Floating point data of shape `self.shape`.

    Returns
    -------
    self : type(self)

    Raises
    ------
    TypeError
        If `dat` is not floating point or its shape differs from
        `self.shape`.
    """
    if not dtypes.dtype(dat.dtype).is_floating_point:
        raise TypeError('Input data type should be a floating point '
                        'type but got {}.'.format(dat.dtype))
    if dat.shape != self.shape:
        raise TypeError('Expected input shape {} but got {}.'.format(
            self.shape, dat.shape))

    is_tensor = torch.is_tensor(dat)
    if is_tensor:
        dat = dat.detach()

    inter, slope = self.inter, self.slope
    if inter != 0 or slope != 1:
        dat = dat.clone() if is_tensor else dat.copy()
        if inter != 0:
            dat -= float(inter)
        if slope != 1:
            dat /= float(slope)

    self.set_data(dat)
    return self
def savef_new(cls, dat, file_like, like=None, **metadata):
    """Save floating point (scaled) data to a new file.

    A `MappedArray` input is loaded with `fdata(numpy=True)` and, unless
    `like` is given, also serves as the template. Tensors are detached and
    moved to numpy. Writing is delegated to `cls.save_new` with
    `casting='unsafe'` and `_savef=True`.

    Raises
    ------
    TypeError
        If the data is not floating point.
    """
    if isinstance(dat, MappedArray):
        like = dat if like is None else like
        dat = dat.fdata(numpy=True)

    if not dtypes.dtype(dat.dtype).is_floating_point:
        raise TypeError('Input data type should be a floating point '
                        'type but got {}.'.format(dat.dtype))

    if torch.is_tensor(dat):
        dat = dat.detach().cpu().numpy()

    cls.save_new(dat, file_like, like=like, casting='unsafe',
                 _savef=True, **metadata)
def save_new(cls, dat, file_like, like=None, casting='unsafe',
             _savef=False, **metadata):
    """Write an array to a new file in a nibabel-supported format.

    Parameters
    ----------
    dat : MappedArray or tensor or array_like
        Data to write. A `MappedArray` is loaded with `data(numpy=True)`
        and, unless `like` is given, also used as template.
    file_like : str
        Output file specification, passed to `filespec_to_file_map`.
        Its extension is used to guess the output format.
    like : optional
        Template whose metadata (and, for nibabel-backed arrays, header
        and format) is copied. Passed through `map_array`.
    casting : str, default='unsafe'
        Casting rule forwarded to `volutils.cast`.
    _savef : bool, default=False
        Private. If True, `dat` is floating point and is unscaled with the
        header's slope/intercept before casting.
    **metadata
        Metadata fields forwarded to `metadata_to_header`. `dtype` is used
        to choose the on-disk data type.
    """
    import contextlib

    if isinstance(dat, MappedArray):
        if like is None:
            like = dat
        dat = dat.data(numpy=True)
    if torch.is_tensor(dat):
        dat = dat.detach().cpu()
    dat = np.asanyarray(dat)
    if like is not None:
        like = map_array(like)

    # guess data type: metadata > like > dat
    def guess_dtype():
        dtype = metadata.get('dtype', None)
        if dtype is None and like is not None:
            dtype = like.dtype
        if dtype is None:
            dtype = dat.dtype
        return dtypes.dtype(dtype).numpy
    dtype = guess_dtype()

    def guess_format():
        # 1) from extension
        ok_klasses = []
        if isinstance(file_like, str):
            base, ext = os.path.splitext(file_like)
            if ext.lower() == '.gz':
                base, ext = os.path.splitext(base)
            ok_klasses = [klass for klass in all_image_classes
                          if ext in klass.valid_exts]
            if len(ok_klasses) == 1:
                return ok_klasses[0]
        # 2) from like
        if isinstance(like, BabelArray):
            return type(like._image)
        # 3) from extension (if conflict)
        if len(ok_klasses) != 0:
            return ok_klasses[0]
        # 4) fallback to nifti-1
        return nib.Nifti1Image
    image_class = guess_format()

    # build header
    if isinstance(like, BabelArray):
        # defer metadata conversion to nibabel
        header = image_class.header_class.from_header(
            like._image.dataobj._header)
    else:
        header = image_class.header_class()

    if like is not None:
        # copy generic metadata
        like_metadata = like.metadata()
        like_metadata.update(metadata)
        metadata = like_metadata
    # set shape now so that we can set zooms/etc
    header.set_data_shape(dat.shape)
    header = metadata_to_header(header, metadata)

    # check endianness
    disk_byteorder = header.endianness
    data_byteorder = dtype.byteorder
    if disk_byteorder == '=':
        disk_byteorder = '<' if sys.byteorder == 'little' else '>'
    if data_byteorder == '=':
        data_byteorder = '<' if sys.byteorder == 'little' else '>'
    if disk_byteorder != data_byteorder:
        dtype = dtype.newbyteorder()

    # set scale
    if hasattr(header, 'set_slope_inter'):
        slope, inter = header.get_slope_inter()
        if slope is None:
            slope = 1
        if inter is None:
            inter = 0
        header.set_slope_inter(slope, inter)

    # unscale
    if _savef:
        assert dtypes.dtype(dat.dtype).is_floating_point
        slope, inter = header.get_slope_inter()
        if inter not in (0, None) or slope not in (1, None):
            dat = dat.copy()
        if inter not in (0, None):
            dat -= inter
        if slope not in (1, None):
            dat /= slope

    # cast + setdtype
    dat = volutils.cast(dat, dtype, casting)
    header.set_data_dtype(dat.dtype)

    # create image object
    image = image_class(dat, affine=None, header=header)

    # write everything
    file_map = image_class.filespec_to_file_map(file_like)
    fmap_header = file_map.get('header', file_map.get('image'))
    fmap_image = file_map.get('image')
    fmap_footer = file_map.get('footer', file_map.get('image'))

    # Each distinct file is opened once and closed deterministically on
    # exit (which also finalizes compressed streams). Exiting the nibabel
    # file objects only closes the files they opened themselves.
    with contextlib.ExitStack() as stack:
        fhdr = stack.enter_context(fmap_header.get_prepare_fileobj('wb'))
        if hasattr(header, 'writehdr_to'):
            header.writehdr_to(fhdr)
        elif hasattr(header, 'write_to'):
            header.write_to(fhdr)
        if fmap_image == fmap_header:
            fimg = fhdr
        else:
            fimg = stack.enter_context(fmap_image.get_prepare_fileobj('wb'))
        array_to_file(dat, fimg, dtype,
                      offset=header.get_data_offset(),
                      order=image.ImageArrayProxy.order)
        if fmap_image == fmap_footer:
            fftr = fimg
        else:
            fftr = stack.enter_context(fmap_footer.get_prepare_fileobj('wb'))
        if hasattr(header, 'writeftr_to'):
            header.writeftr_to(fftr)
def data(self, dtype=None, device=None, casting='unsafe', rand=True,
         missing=None, cutoff=None, dim=None, numpy=False):
    """Load the array in memory

    Parameters
    ----------
    dtype : type or torch.dtype or np.dtype, optional
        Output data type. By default, keep the on-disk data type.
    device : torch.device, default='cpu'
        Output device.
    rand : bool, default=True
        If the on-disk dtype is not floating point, sample noise
        in the uncertainty interval.
    missing : float or sequence[float], optional
        Value(s) that correspond to missing values.
        No noise is added to them, and they are converted to NaNs
        (if possible) or zero (otherwise).
    cutoff : float or (float, float), optional
        Percentile cutoff. If only one value is provided, it is
        assumed to relate to the upper percentile.
    dim : int or list[int], optional
        Dimensions along which to compute percentiles.
        By default, they are computed on the flattened array.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe', 'rescale', 'rescale_zero'}, default='unsafe'
        Controls what kind of data casting may occur:
            * 'no': the data types should not be cast at all.
            * 'equiv': only byte-order changes are allowed.
            * 'safe': only casts which can preserve values are allowed.
            * 'same_kind': only safe casts or casts within a kind,
              like float64 to float32, are allowed.
            * 'unsafe': any data conversions may be done.
            * 'rescale': the input data is rescaled to match the dynamic
              range of the output type. The minimum value in the data
              is mapped to the minimum value of the data type and the
              maximum value in the data is mapped to the maximum value
              of the data type.
            * 'rescale_zero': the input data is rescaled to match the
              dynamic range of the output type, but ensuring that
              zero maps to zero.
              > If the data is signed and cast to a signed datatype,
                zero maps to zero, and the scaling is chosen so that
                both the maximum and minimum value in the data fit in
                the output dynamic range.
              > If the data is signed and cast to an unsigned datatype,
                negative values "wrap around" (as with an unsafe cast).
              > If the data is unsigned and cast to a signed datatype,
                values are kept positive (the negative range is unused).
        Rescaling (and the min/max it uses) is delegated to `volutils.cast`
        and only considers non-missing values.
    numpy : bool, default=False
        Return a numpy array rather than a torch tensor.

    Returns
    -------
    dat : tensor[dtype]

    Raises
    ------
    TypeError
        If `numpy=False` and `dtype` does not exist in PyTorch.
    """
    # --- sanity check before reading ---
    dtype = self.dtype if dtype is None else dtype
    dtype = dtypes.dtype(dtype)
    if not numpy and dtype.torch is None:
        raise TypeError(
            'Data type {} does not exist in PyTorch.'.format(dtype))

    # --- load raw data ---
    # NOTE(review): when `numpy=True`, the steps below may write into the
    # array returned by `raw_data()` in place — confirm it is a private copy.
    dat = self.raw_data()

    # --- move to tensor ---
    indtype = dtypes.dtype(self.dtype)
    if not numpy:
        # go through numpy first so that torch receives a type it supports
        tmpdtype = dtypes.dtype(indtype.torch_upcast)
        dat = dat.astype(tmpdtype.numpy)
        dat = torch.as_tensor(dat, dtype=indtype.torch_upcast, device=device)

    # --- mask of missing values ---
    if missing is not None:
        missing = volutils.missing(dat, missing)
        present = ~missing
    else:
        # indexing with (Ellipsis,) selects the whole array
        present = (Ellipsis, )

    # --- cutoff ---
    if cutoff is not None:
        dat[present] = volutils.cutoff(dat[present], cutoff, dim)

    # --- cast + rescale ---
    rand = rand and not indtype.is_floating_point
    tmpdtype = dtypes.float64 if (
        rand and not dtype.is_floating_point) else dtype
    dat, scale = volutils.cast(dat, tmpdtype, casting, indtype=indtype,
                               returns='dat+scale', mask=present)

    # --- random sample ---
    # uniform noise in the uncertainty interval
    if rand and not (scale == 1 and not dtype.is_floating_point):
        dat[present] = volutils.addnoise(dat[present], scale)

    # --- final cast ---
    dat = volutils.cast(dat, dtype, 'unsafe')

    # --- replace missing values ---
    if missing is not None:
        if dtype.is_floating_point:
            dat[missing] = float('nan')
        else:
            dat[missing] = 0
    return dat
def chunk(inp, chunk_size=1, dim=-1, output=None, transform=None):
    """Unstack a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    chunk_size : int, default=1
        Size of one chunk.
    dim : int, default=-1
        Dimension along which to unstack. Negative values count from the
        last dimension of the data.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        `i` is the coordinate (starting at 1) of the slice.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default.

    Returns
    -------
    output : list[str] or (list[(tensor, tensor)], list[tensor])
        If the input is a path, the output paths are returned.
        Else, a tuple `(chunks, affines)` is returned, where `chunks` is
        the list of `(dat, affine)` pairs and `affines` the list of
        orientation matrices.
    """
    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(dtype=dtype(f.dtype).torch_upcast), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}{ext}'
        dirname, base, ext = py.fileparts(fname)

    dat, aff0 = inp
    ndim = aff0.shape[-1] - 1
    # Normalize negative dimensions (e.g. the default -1) so that they are
    # compared against the number of spatial dimensions correctly.
    axis = dim if dim >= 0 else dat.dim() + dim
    nchunks = math.ceil(dat.shape[axis] / chunk_size)
    if axis >= ndim:
        # we didn't touch the spatial dimensions: one copy per chunk
        aff = [aff0.clone() for _ in range(nchunks)]
    else:
        aff = []
        slicer = [slice(None) for _ in range(ndim)]
        shape = dat.shape[:ndim]
        for z in range(nchunks):
            slicer[axis] = slice(z * chunk_size, (z + 1) * chunk_size)
            aff1, _ = spatial.affine_sub(aff0, shape, tuple(slicer))
            aff.append(aff1)
    dat = torch.split(dat, chunk_size, axis)
    dat = list(zip(dat, aff))

    formatted_output = []
    if output:
        output = py.make_list(output, len(dat))
        for i, ((dat1, aff1), out1) in enumerate(zip(dat, output)):
            if is_file:
                out1 = out1.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep, i=i + 1)
                io.volumes.save(dat1, out1, like=fname, affine=aff1)
            else:
                out1 = out1.format(sep=os.path.sep, i=i + 1)
                io.volumes.save(dat1, out1, affine=aff1)
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(dat))
        for i, ((_, aff1), trf1) in enumerate(zip(dat, transform)):
            if is_file:
                trf1 = trf1.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep, i=i + 1)
            else:
                trf1 = trf1.format(sep=os.path.sep, i=i + 1)
            io.transforms.savef(torch.eye(4), trf1, source=aff0, target=aff1)

    if is_file:
        return formatted_output
    else:
        return dat, aff
def metadata_to_header(header, metadata, shape=None, dtype=None):
    """Register metadata into a nibabel Header object

    Parameters
    ----------
    header : Header or type
        Original header _or_ Header class.
        If it is a class, default values are used to populate the
        missing fields.
    metadata : dict
        Dictionary of metadata. Keys used here: 'voxel_size', 'affine',
        'slope', 'inter', 'time_step' (or 'tr') and 'dtype'.
    shape : sequence of int, optional
        Passed to `set_voxel_size` and `set_affine`.
    dtype : dtype_like, optional
        Data type to set. Takes precedence over `metadata['dtype']`.

    Returns
    -------
    header : Header
    """
    is_empty = type(header) is type
    if is_empty:
        header = header()

    # --- generic fields ---
    if metadata.get('voxel_size', None) is not None:
        header = set_voxel_size(header, metadata['voxel_size'], shape)
    if metadata.get('affine', None) is not None:
        header = set_affine(header, metadata['affine'], shape)
    if (metadata.get('slope', None) is not None or
            metadata.get('inter', None) is not None):
        slope = metadata.get('slope', 1.)
        inter = metadata.get('inter', None)
        if isinstance(header, (Spm99AnalyzeHeader, Nifti1Header)):
            header.set_slope_inter(slope, inter)
        elif slope not in (1, None) or inter not in (0, None):
            format_name = type(header).__name__.split('Header')[0]
            warn('Format {} does not accept intensity transforms. '
                 'It will be discarded.'.format(format_name),
                 RuntimeWarning)
    if (metadata.get('time_step', None) is not None or
            metadata.get('tr', None) is not None):
        # 'time_step' takes precedence; fall back to 'tr' only if absent
        time_step = metadata.get('time_step', None)
        if time_step is None:
            time_step = metadata['tr']
        if isinstance(header, (MGHHeader, Nifti1Header)):
            zooms = header.get_zooms()[:3]
            zooms = (*zooms, time_step)
            header.set_zooms(zooms)
        else:
            warn('Format {} does not accept time steps. '
                 'It will be discarded.'.format(type(header).__name__),
                 RuntimeWarning)
    # TODO: time offset / intent for nifti format
    # TODO: te/ti/fa for MGH format
    #       maybe also nifti from description field?

    # `is None` checks: `np.dtype` objects are falsy
    if dtype is None:
        dtype = metadata.get('dtype', None)
    if dtype is not None:
        dtype = dtypes.dtype(dtype).numpy
        header.set_data_dtype(dtype)

    return header
def metadata_to_header(header, metadata, shape=None, dtype=None):
    """Register metadata into a nibabel Header object

    Parameters
    ----------
    header : Header or type
        Original header _or_ Header class.
        If it is a class, default values are used to populate the
        missing fields.
    metadata : dict
        Dictionary of metadata. Keys used here: 'voxel_size_unit',
        'voxel_size', 'affine', 'slope', 'inter', 'time_step' (or 'tr')
        with 'tr_unit', 'te' with 'te_unit', 'ti' with 'ti_unit',
        'fa' with 'fa_unit', and 'dtype'.
        String values are parsed with `ast.literal_eval`. Durations may
        carry a 'sec', 'ms' or 's' suffix (default unit: seconds) and flip
        angles a 'deg' or 'rad' suffix (default unit: degrees).
    shape : sequence of int, optional
        Passed to `set_voxel_size` and `set_affine`.
    dtype : dtype_like, optional
        Data type to set. Takes precedence over `metadata['dtype']`.

    Returns
    -------
    header : Header
    """
    is_empty = type(header) is type
    if is_empty:
        header = header()

    def parse_with_unit(val, suffixes, default_unit):
        # Strip a known unit suffix from a string value and evaluate the
        # number. Non-string values are returned unchanged.
        unit = None
        if isinstance(val, str):
            for suffix, suffix_unit in suffixes:
                if val.endswith(suffix):
                    val = val[:-len(suffix)]
                    unit = suffix_unit
                    break
            val = float(ast.literal_eval(val))
        return val, unit or default_unit

    # 'ms' must be tested before 's'
    time_suffixes = (('sec', 'sec'), ('ms', 'ms'), ('s', 'sec'))
    angle_suffixes = (('deg', 'deg'), ('rad', 'rad'))

    # --- generic fields ---
    if metadata.get('voxel_size_unit', None) is not None:
        val = metadata['voxel_size_unit']
        if hasattr(header, 'set_xyzt_units'):
            header.set_xyzt_units(val, None)
        else:
            format_name = type(header).__name__.split('Header')[0]
            warn('Format {} does not accept voxel size units. '
                 'It will be discarded.'.format(format_name),
                 RuntimeWarning)
    if metadata.get('voxel_size', None) is not None:
        val = metadata['voxel_size']
        if isinstance(val, str):
            val = ast.literal_eval(val)
        header = set_voxel_size(header, val, shape)
    if metadata.get('affine', None) is not None:
        val = metadata['affine']
        if isinstance(val, str):
            val = ast.literal_eval(val)
        header = set_affine(header, val, shape)
    if (metadata.get('slope', None) is not None or
            metadata.get('inter', None) is not None):
        slope = metadata.get('slope', 1.)
        inter = metadata.get('inter', None)
        if isinstance(slope, str):
            slope = float(ast.literal_eval(slope))
        if isinstance(inter, str):
            inter = float(ast.literal_eval(inter))
        if isinstance(header, (Spm99AnalyzeHeader, Nifti1Header)):
            header.set_slope_inter(slope, inter)
        elif slope not in (1, None) or inter not in (0, None):
            format_name = type(header).__name__.split('Header')[0]
            warn('Format {} does not accept intensity transforms. '
                 'It will be discarded.'.format(format_name),
                 RuntimeWarning)

    if (metadata.get('time_step', None) is not None or
            metadata.get('tr', None) is not None):
        # 'time_step' takes precedence; fall back to 'tr' only if absent
        time_step = metadata.get('time_step', None)
        if time_step is None:
            time_step = metadata['tr']
        time_step, unit = parse_with_unit(
            time_step, time_suffixes, metadata.get('tr_unit', 'sec'))
        if isinstance(header, (MGHHeader, Nifti1Header)):
            if unit == 'sec':
                time_step = time_step * 1e3
            # TODO: unit for niftis?
            zooms = header.get_zooms()[:3]
            zooms = (*zooms, time_step)
            try:
                # only possible if 4-th dimension is explicit
                header.set_zooms(zooms)
            except HeaderDataError:
                if isinstance(header, MGHHeader):
                    # set tr manually
                    header['tr'] = time_step
                else:
                    warn('Could not set the time step of a {} header '
                         '(no time dimension?). It will be discarded.'
                         .format(type(header).__name__), RuntimeWarning)
        else:
            warn('Format {} does not accept time steps. '
                 'It will be discarded.'.format(type(header).__name__),
                 RuntimeWarning)
    # TODO: time offset / intent for nifti format
    # TODO: te/ti/fa for MGH format
    #       maybe also nifti from description field?

    if metadata.get('te', None) is not None:
        val, unit = parse_with_unit(
            metadata['te'], time_suffixes, metadata.get('te_unit', 'sec'))
        if isinstance(header, MGHHeader):
            if unit == 'sec':
                val = val * 1e3
            header['te'] = val

    if metadata.get('ti', None) is not None:
        val, unit = parse_with_unit(
            metadata['ti'], time_suffixes, metadata.get('ti_unit', 'sec'))
        if isinstance(header, MGHHeader):
            if unit == 'sec':
                val = val * 1e3
            header['ti'] = val

    if metadata.get('fa', None) is not None:
        val, unit = parse_with_unit(
            metadata['fa'], angle_suffixes, metadata.get('fa_unit', 'deg'))
        if isinstance(header, MGHHeader):
            # MGH headers get the flip angle converted to radians here
            if unit == 'deg':
                val = val * constants.pi / 180.
            header['flip_angle'] = val

    # `is None` checks: `np.dtype` objects are falsy
    if dtype is None:
        dtype = metadata.get('dtype', None)
    if dtype is not None:
        dtype = dtypes.dtype(dtype).numpy
        header.set_data_dtype(dtype)

    return header
def cast(dat, dtype, casting='unsafe', returns='dat', indtype=None,
         mask=None):
    """Cast an array to a given type.

    Parameters
    ----------
    dat : tensor or ndarray
        Input array
    dtype : dtype
        Output data type (should have the proper on-disk byte order)
    indtype : dtype, default=dat.dtype
        Original input dtype. When no rescaling happens, the final cast
        is only performed if `indtype` differs from `dtype`.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe', 'rescale', 'rescale_zero'}, default='unsafe'
        Casting method:
        * 'rescale' makes sure that the dynamic range in the array
          matches the range of the output data type: the minimum of the
          data maps to the minimum of the output type and the maximum of
          the data maps to its maximum (`out = dat * scale + offset`).
        * 'rescale_zero' does the same, but keeps the mapping `0 -> 0`
          intact (`out = dat * scale`). For unsigned output types,
          negative values are not taken into account when computing the
          scale (a warning is emitted) and are then cast unsafely.
        * For floating point output types, 'rescale' and 'rescale_zero'
          behave like 'unsafe' (no rescaling is applied).
        * all other options are implemented in numpy. See `np.can_cast`.
    returns : [combination of] {'dat', 'scale', 'offset'}, default='dat'
        Return the scaling/offset applied, if any. Components are joined
        with '+', e.g. 'dat+scale'.
    mask : tensor or ndarray, optional
        Mask of voxels to use to compute min/max

    Returns
    -------
    dat : tensor or ndarray, if 'dat' in `returns`
    scale : float, if 'scale' in `returns`
    offset : float, if 'offset' in `returns`

    Raises
    ------
    ValueError
        If `casting` starts with 'rescale' but is not a known rescale rule.
    """
    scale = 1.
    offset = 0.
    # `is not None` rather than `or`: `np.dtype` objects are falsy
    indtype = dtypes.dtype(indtype if indtype is not None else dat.dtype)
    outdtype = dtypes.dtype(dtype)
    if mask is None:
        mask = (Ellipsis, )

    if casting.startswith('rescale') and casting not in ('rescale',
                                                         'rescale_zero'):
        raise ValueError('Unknown casting rule: {}'.format(casting))

    if casting.startswith('rescale') and outdtype.is_floating_point:
        # Floating point outputs need no rescaling, and numpy's `astype`
        # does not know the 'rescale*' rules.
        casting = 'unsafe'

    if casting.startswith('rescale'):
        # TODO: I am using float64 as an intermediate to cast
        # Maybe I can do things in a nicer / more robust way
        minval = float(astype(min(dat[mask]), dtypes.float64))
        maxval = float(astype(max(dat[mask]), dtypes.float64))
        if 'dat' in returns:
            # check the actual dtype of `dat` (it may differ from
            # `indtype`); module-level `astype` handles tensors and arrays
            if not dtypes.equivalent(dtypes.dtype(dat.dtype), dtypes.float64):
                dat = astype(dat, dtypes.float64)
            if not writeable(dat):
                dat = copy(dat)

        if casting == 'rescale':
            # linear map [minval, maxval] -> [outmin, outmax]
            outmin = float(outdtype.min)
            outmax = float(outdtype.max)
            if maxval > minval:
                scale = (outmax - outmin) / (maxval - minval)
            # constant data keeps scale == 1 and is shifted to outmin
            offset = outmin - minval * scale
            if 'dat' in returns:
                dat *= scale
                dat += offset
        else:  # 'rescale_zero'
            if minval < 0 and not outdtype.is_signed:
                warn("Converting negative values to an unsigned datatype")
            scale = float('inf')
            if maxval > 0:
                scale = float(outdtype.max) / maxval
            if minval < 0 and outdtype.is_signed:
                neg_scale = float(outdtype.min) / minval
                if neg_scale < scale:
                    scale = neg_scale
            if scale == float('inf'):
                # no usable range (e.g. all-zero data): avoid 0 * inf = nan
                scale = 1.
            if 'dat' in returns:
                dat *= scale

        indtype = dtypes.dtype(dat.dtype)
        casting = 'unsafe'

    # unsafe cast
    if 'dat' in returns and indtype != outdtype:
        dat = astype(dat, outdtype, casting=casting)

    output = []
    for component in returns.split('+'):
        if component == 'dat':
            output.append(dat)
        elif component == 'scale':
            output.append(scale)
        elif component == 'offset':
            output.append(offset)
        else:
            output.append(None)
    return tuple(output) if len(output) > 1 else output[0] if output else None