示例#1
0
    def translated(self, coord1, coord2=None, coord3=None):
        """
        Make a copy of this scatterer translated to a new location

        Parameters
        ----------
        coord1 : float or listlike of 3 floats
            Translation along the first axis, or the whole (x, y, z)
            translation vector when `coord2` is omitted.
        coord2, coord3 : float, optional
            Translation along the second and third axes; both are required
            when the components are given separately.

        Returns
        -------
        translated : Scatterer
            A copy of this scatterer translated to a new location

        Raises
        ------
        InvalidScatterer
            If the arguments are neither a length-3 vector nor three values.
        """
        # len() must wrap ensure_array(...) itself; comparing inside len()
        # made every non-empty input (even a scalar) look like a 3-vector.
        if coord2 is None and len(ensure_array(coord1)) == 3:
            # entered translation vector
            trans_coords = ensure_array(coord1)
        elif coord2 is not None and coord3 is not None:
            # entered 3 coords
            trans_coords = np.array([coord1, coord2, coord3])
        else:
            raise InvalidScatterer(
                self, "Cannot interpret translation coordinates")
        new = copy(self)
        new.center = self.center + trans_coords
        return new
示例#2
0
 def __init__(self, spheres, translation=(0, 0, 0), rotation=(0, 0, 0)):
     """
     Parameters
     ----------
     spheres : Spheres
         The cluster members; any other type raises InvalidScatterer.
     translation, rotation : listlike of length 3, optional
         Stored as given (not converted). Anything whose ``ensure_array``
         form does not have length 3 raises ValueError.
     """
     if not isinstance(spheres, Spheres):
         raise InvalidScatterer(
             self,
             "RigidCluster only accepts a scatterer of class Spheres.")
     self.spheres = spheres

     wrong_length = any(len(ensure_array(vec)) != 3
                        for vec in (translation, rotation))
     if wrong_length:
         raise ValueError(
             'translation and rotation must be listlike of len 3')
     self.translation = translation
     self.rotation = rotation
示例#3
0
 def _calculate_multiple_color_scattered_field(self, scatterer, schema):
     """
     Compute the scattered field once per illumination and join the results.

     For every value of the ``illumination`` coordinate of
     ``schema.illum_wavelen``, a single-colour schema is built with that
     illumination's wavelength and polarization. The matching scatterer
     values are picked with ``scatterer.select``, and
     ``_calculate_single_color_scattered_field`` is called.

     Parameters
     ----------
     scatterer : scatterer object
         Scatterer whose per-illumination values are selected with
         ``select({illumination: value})``.
     schema : detector
         Must carry ``illum_wavelen`` and ``illum_polarization`` metadata
         indexed by an ``illumination`` coordinate.

     Returns
     -------
     The per-illumination fields concatenated with ``clean_concat`` along
     the illumination coordinate.
     """
     wavelens = schema.illum_wavelen
     per_color_fields = []
     for color in wavelens.illumination.values:
         color_schema = update_metadata(
             schema,
             illum_wavelen=ensure_array(
                 wavelens.sel(illumination=color).values)[0],
             illum_polarization=ensure_array(
                 schema.illum_polarization.sel(illumination=color).values))
         color_scatterer = scatterer.select({illumination: color})
         per_color_fields.append(
             self._calculate_single_color_scattered_field(
                 color_scatterer, color_schema))
     return clean_concat(per_color_fields, dim=wavelens.illumination)
示例#4
0
文件: dda.py 项目: sid6155330/holopy
    def _adda_discretized(self, scatterer, medium_wavelen, medium_index, temp_dir):
        """
        Write `scatterer` to an ADDA shape file and build the ADDA arguments.

        The scatterer is voxelated at ``self.required_spacing(...)``. Every
        occupied voxel is written as integer coordinates, plus its domain
        number when there is more than one domain. The file goes to a
        temporary file in `temp_dir` and is created with ``delete=False``,
        so it stays on disk after this method returns; its name is passed
        to ADDA.

        Parameters
        ----------
        scatterer : scatterer object
            Must provide ``n`` and ``voxelate_domains(spacing)``.
        medium_wavelen : float
            Wavelength in the medium.
        medium_index : float
            Refractive index of the medium.
        temp_dir : str
            Directory in which the shape file is created.

        Returns
        -------
        list of str
            ``-shape read <file>``, ``-dpl <value>``, and ``-m`` followed by
            one (real, imaginary) relative-index pair per domain.
        """
        spacing = self.required_spacing(medium_wavelen, medium_index, scatterer.n)
        vox = scatterer.voxelate_domains(spacing)
        # one row of integer voxel coordinates per voxel, in the same C order
        # as vox.flatten(); works for any number of dimensions
        idx = np.indices(vox.shape).reshape((vox.ndim, -1)).T
        vox = vox.flatten()
        ns = ensure_array(scatterer.n)
        n_domains = len(ns)
        if n_domains > 1:
            # multi-material shapes carry the domain number as an extra column
            out = np.hstack((idx, vox[..., np.newaxis]))
        else:
            out = idx

        # Create the file only once there is something to write. The context
        # manager closes it even if writing fails; delete=False keeps it on
        # disk for ADDA.
        with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) as outf:
            if n_domains > 1:
                outf.write("Nmat={0}\n".format(n_domains).encode('utf-8'))
            np.savetxt(outf, out[np.nonzero(vox)], fmt='%d')

        cmd = ['-shape', 'read', outf.name]
        cmd.extend(
            ['-dpl', str(self._dpl(medium_wavelen, medium_index, scatterer.n))])
        cmd.append('-m')
        for n in ns:
            m = n.real / medium_index
            if m == 1:
                # an index exactly matching the medium is nudged slightly
                warnings.warn("Adda cannot compute particles with index equal to medium index, adjusting particle index {} to {}".format(m, m+1e-6))
                m += 1e-6
            cmd.extend([str(m), str(n.imag / medium_index)])
        return cmd
示例#5
0
    def rotated(self, ang1, ang2=None, ang3=None):
        """
        Make a copy of this composite scatterer rotated as a rigid body.

        The member centers are rotated about their mean position with
        ``rotate_points``. Each member is then moved to its new center and
        rotated by the same angles.

        Parameters
        ----------
        ang1 : float or listlike of 3 floats
            First rotation angle, or all three angles (alpha, beta, gamma)
            when `ang2` is omitted.
        ang2, ang3 : float, optional
            Second and third angles; both are required when the angles are
            given separately.

        Returns
        -------
        A copy of this scatterer whose ``scatterers`` are the rotated members.

        Raises
        ------
        InvalidScatterer
            If the angles cannot be interpreted.
        """
        # len() must wrap ensure_array(...) itself; comparing inside len()
        # accepted any non-empty input and then failed while unpacking.
        if ang2 is None and len(ensure_array(ang1)) == 3:
            # entered rotation angle tuple
            alpha, beta, gamma = ang1
        elif ang2 is not None and ang3 is not None:
            # entered 3 angles
            alpha, beta, gamma = ang1, ang2, ang3
        else:
            raise InvalidScatterer(self, "Cannot interpret rotation coordinates")

        centers = np.array([s.center for s in self.scatterers])
        com = centers.mean(0)
        new_centers = com + rotate_points(centers - com, alpha, beta, gamma)

        # move each member to its rotated position, then rotate it in place
        scatterers = [
            member.translated(*(new_c - old_c)).rotated(alpha, beta, gamma)
            for member, new_c, old_c in zip(self.scatterers, new_centers,
                                            centers)]

        new = copy(self)
        new.scatterers = scatterers
        return new
示例#6
0
    def select(self, keys):
        """
        Select certain parts of a Scatterer with multiple parameter values

        Parameters
        ----------
        keys : dict
            Values to select, of the form ``{dim: val(s)}``.
            DataArray-valued attributes are reduced with ``.sel(**keys)``.
            Dict-valued parameters are indexed by the requested value(s).

        Returns
        -------
        scatterer : Scatterer class
            A new scatterer of the same type built from the selected values.
            A selection of a single value is unwrapped from its list.
        """
        params = _interpret_parameters(self.parameters)
        for name in params.keys():
            current = getattr(self, name)
            if isinstance(current, xr.DataArray):
                params[name] = current.sel(**keys).item()
                continue
            if not isinstance(params[name], dict):
                continue
            # NOTE(review): each entry of `keys` re-indexes params[name]; with
            # more than one key the later passes index the already-selected
            # value. Presumably only one dim is expected — confirm.
            for wanted in keys.values():
                picked = [params[name][k] for k in ensure_array(wanted)]
                params[name] = picked[0] if len(picked) == 1 else picked
        return type(self)(**params)
示例#7
0
文件: mie.py 项目: sid6155330/holopy
    def _scat_coeffs_internal(self, s, medium_wavevec, medium_index):
        '''
        Calculate expansion coefficients for Lorenz-Mie electric field
        inside a sphere.

        Parameters
        ----------
        s : sphere scatterer
            Provides radius ``r`` and refractive index ``n``.
        medium_wavevec : float
            Wave vector in the medium.
        medium_index : float
            Medium refractive index.

        Returns
        -------
        The result of ``miescatlib.internal_coeffs`` for a single-layer sphere.

        Raises
        ------
        InvalidScatterer
            If the largest size parameter exceeds 1e3.
        NotImplementedError
            If the sphere has more than one radius or index (layered sphere),
            a case this method does not compute.
        '''
        x_arr = medium_wavevec * ensure_array(s.r)
        m_arr = ensure_array(s.n) / medium_index

        # Check that the scatterer is in a range we can compute for
        if x_arr.max() > 1e3:
            msg = "radius too large, field calculation would take forever"
            raise InvalidScatterer(s, msg)

        if len(x_arr) != 1 or len(m_arr) != 1:
            # Fail loudly instead of silently returning None.
            raise NotImplementedError(
                "Internal field coefficients are only implemented for "
                "single-layer spheres")

        # Could just use scatcoeffs_multi here, but jerome is in favor of
        # keeping the simpler single layer code here
        lmax = miescatlib.nstop(x_arr[0])
        return miescatlib.internal_coeffs(m_arr[0], x_arr[0], lmax)
示例#8
0
def prep_schema(detector, medium_index, illum_wavelen, illum_polarization):
    """
    Merge optics metadata into `detector` and normalise multi-illumination
    metadata.

    Parameters
    ----------
    detector : detector/schema
        Object updated with ``update_metadata``.
    medium_index, illum_wavelen, illum_polarization
        Optics values merged into the detector. Passing
        ``illum_polarization=False`` skips the check that a polarization is
        present.

    Returns
    -------
    The updated detector. With several illuminations (more than one
    wavelength, or a 2-D polarization), ``illum_wavelen`` and
    ``illum_polarization`` become DataArrays sharing an ``illumination``
    dimension. Any ``illumination`` dimension on the detector itself is
    dropped by keeping its first entry.

    Raises
    ------
    MissingParameter
        If the wavelength, the medium index or (unless disabled) the
        polarization is missing.
    """
    detector = update_metadata(
        detector, medium_index, illum_wavelen, illum_polarization)

    if detector.illum_wavelen is None:
        raise MissingParameter("wavelength")
    if detector.medium_index is None:
        raise MissingParameter("medium refractive index")
    polarization_needed = illum_polarization is not False
    if polarization_needed and detector.illum_polarization is None:
        raise MissingParameter("polarization")

    wavelens = ensure_array(detector.illum_wavelen)
    polarization = detector.illum_polarization

    several_illuminations = (len(wavelens) > 1
                             or ensure_array(polarization).ndim == 2)
    if not several_illuminations:
        return detector

    if illumination in polarization.dims:
        # the polarization already labels the illuminations: align the
        # wavelengths to those labels unless they are labelled already
        if not isinstance(wavelens, xr.DataArray):
            if len(wavelens) == 1:
                wavelens = wavelens.repeat(len(polarization.illumination))
            wavelens = xr.DataArray(
                wavelens, dims=illumination,
                coords={illumination: polarization.illumination})
    else:
        # label the illuminations by wavelength, then broadcast the
        # polarization over that dimension
        if not isinstance(wavelens, xr.DataArray):
            wavelens = xr.DataArray(
                wavelens, dims=illumination,
                coords={illumination: wavelens})
        polarization = xr.broadcast(
            polarization, wavelens, exclude=[vector])[0]

    if illumination in detector.dims:
        detector = detector.sel(
            illumination=detector.illumination[0], drop=True)
    return update_metadata(
        detector, illum_wavelen=wavelens,
        illum_polarization=polarization)
示例#9
0
def make_coords(shape, spacing, z=0):
    """
    Build the z, x and y coordinate arrays for a detector grid.

    Parameters
    ----------
    shape : int or tuple of int
        Grid shape. A scalar means an ``(n, n)`` grid. A 2-tuple is read as
        ``(nx, ny)``. Longer shapes take the x and y sizes from ``shape[1]``
        and ``shape[2]``, e.g. a ``(nz, nx, ny)`` array shape.
    spacing : float or (float, float)
        Pixel spacing along x and y; a scalar is used for both.
    z : float or listlike, optional
        z coordinate(s), passed through ``ensure_array``.

    Returns
    -------
    OrderedDict
        Coordinate arrays keyed ``'z'``, ``'x'``, ``'y'`` in that order.
    """
    if np.isscalar(shape):
        shape = np.repeat(shape, 2)
    if np.isscalar(spacing):
        spacing = np.repeat(spacing, 2)
    # A scalar (expanded above) or 2-tuple shape holds only (nx, ny).
    # Indexing shape[2] for it used to raise IndexError.
    if len(shape) == 2:
        nx, ny = shape[0], shape[1]
    else:
        nx, ny = shape[1], shape[2]
    return OrderedDict([
        ('z', ensure_array(z)),
        ('x', np.arange(nx) * spacing[0]),
        ('y', np.arange(ny) * spacing[1]),
        ])
示例#10
0
 def _lnlike(self, pars, data):
     """
     Internal function taking pars as a list only

     Returns ``ensure_scalar`` of
     ``-N/2 log(2 pi) - N mean(log noise_sd) - 0.5 sum(residuals**2)``, where
     N is ``data.size``. The noise comes from ``self._find_noise`` and the
     residuals from ``self._residuals`` (which also receives noise_sd).
     """
     noise_sd = self._find_noise(pars, data)
     size = data.size
     constant = -size / 2 * np.log(2 * np.pi)
     noise_penalty = size * np.mean(np.log(ensure_array(noise_sd)))
     sq_residuals = (self._residuals(pars, data, noise_sd)**2).sum()
     return ensure_scalar(constant - noise_penalty - 0.5 * sq_residuals)
示例#11
0
 def index_at(self, points, background=0):
     """
     Return the refractive index at each of `points`.

     Parameters
     ----------
     points : array
         Positions passed to ``self.in_domain``.
     background : float or complex, optional
         Index assigned to points not in any of the listed domains.

     Returns
     -------
     ndarray
         Same shape as ``self.in_domain(points)``. Points labelled ``i+1``
         get ``n[i]``. The dtype is complex only if some index (or the
         background) has a nonzero imaginary part.
     """
     domains = self.in_domain(points)
     ns = ensure_array(self.n)
     # np.complex / np.float were removed in NumPy 1.24; use the builtins
     if np.iscomplex(np.append(ns, background)).any():
         dtype = complex
     else:
         dtype = float
     index = np.ones_like(domains, dtype=dtype) * background
     for i, n in enumerate(ns):
         index[domains == i + 1] = n
     return index
示例#12
0
文件: mie.py 项目: sid6155330/holopy
    def _scat_coeffs(self, s, medium_wavevec, medium_index):
        '''
        Calculate Mie scattering coefficients.

        Parameters
        ----------
        s : :mod:`scatterer.Sphere` object
            Single or layered sphere providing radius ``r`` and index ``n``.
        medium_wavevec : float
            Wave vector in the medium, k = 2 * pi * n_med / lambda_0
        medium_index : float
            Medium refractive index

        Returns
        -------
        ndarray (2, n), complex
           Lorenz-Mie scattering coefficients a_n and b_n

        Raises
        ------
        InvalidScatterer
            If any radius is zero or the largest size parameter exceeds 1e3.

        Notes
        -----
        See Bohren & Huffman for mathematical description.
        Single-layer spheres go through ``miescatlib.scatcoeffs``.
        Anything else goes through ``scatcoeffs_multi``.
        '''
        radii = ensure_array(s.r)
        if (radii == 0).any():
            raise InvalidScatterer(s, "Radius is zero")
        size_params = ensure_array(medium_wavevec * radii)
        rel_indices = ensure_array(ensure_array(s.n) / medium_index)

        # refuse sizes whose series would be too long to evaluate
        if size_params.max() > 1e3:
            raise InvalidScatterer(
                s, "radius too large, field calculation would take forever")

        is_layered = not (len(size_params) == 1 and len(rel_indices) == 1)
        if is_layered:
            return scatcoeffs_multi(rel_indices, size_params,
                                    self.eps1, self.eps2)

        # scatcoeffs_multi would also work, but the simpler single-layer
        # routine is kept on purpose
        lmax = miescatlib.nstop(size_params[0])
        return miescatlib.scatcoeffs(rel_indices[0], size_params[0], lmax,
                                     self.eps1, self.eps2)
示例#13
0
 def __init__(self,
              scatterer,
              noise_sd,
              medium_index=None,
              illum_wavelen=None,
              illum_polarization=None,
              theory='auto'):
     """
     Parameters
     ----------
     scatterer : scatterer object
         Forwarded to the parent initializer.
     noise_sd : float or array-like
         Normalised with ``ensure_array`` and registered as the
         ``'noise_sd'`` parameter.
     medium_index, illum_wavelen, illum_polarization, theory
         Forwarded unchanged to the parent initializer.
     """
     optics = {'medium_index': medium_index,
               'illum_wavelen': illum_wavelen,
               'illum_polarization': illum_polarization,
               'theory': theory}
     super().__init__(scatterer, **optics)
     # noise_sd goes through ensure_array (there is no float cast) before
     # being registered
     self._use_parameter(ensure_array(noise_sd), 'noise_sd')
示例#14
0
    def _lnlike(self, pars, data):
        """
        Compute the likelihood for pars given data

        Evaluates ``-N/2 log(2 pi) - N mean(log noise_sd)
        - sum((forward - data)**2 / (2 noise_sd**2))`` with N = ``data.size``.
        noise_sd comes from ``get_par('noise_sd', ...)`` converted by
        ``dict_to_array``.

        Parameters
        -----------
        pars: dict(string, float)
            Dictionary containing values for each parameter
        data: xarray
            The data to compute likelihood against

        Returns
        -------
        The log-likelihood value.
        """
        noise_sd = dict_to_array(data, self.get_par('noise_sd', pars, data))
        model = self._forward(pars, data)
        n_points = data.size
        normalisation = -n_points / 2 * np.log(2 * np.pi)
        log_noise = n_points * np.mean(np.log(ensure_array(noise_sd)))
        misfit = ((model - data)**2 / (2 * noise_sd**2)).values.sum()
        return normalisation - log_noise - misfit
示例#15
0
 def __init__(self, indicators, n, center):
     """
     Parameters
     ----------
     indicators : function or list of functions
         Function or functions returning true for points inside the
         scatterer (or inside a specific domain) and false outside.
         Wrapped in ``Indicators`` unless it already is one.
     n : complex
         Index of refraction of the scatterer or each domain.
         Stored via ``ensure_array``.
     center : (float, float, float)
         The center of mass of the scatterer. Stored as a numpy array.
     """
     already_wrapped = isinstance(indicators, Indicators)
     self.indicators = indicators if already_wrapped else Indicators(indicators)
     self.n = ensure_array(n)
     self.center = np.array(center)
示例#16
0
    def _lnlike(self, pars, data):
        """
        Compute the likelihood for pars given data

        The result is ``-N/2 log(2 pi) - N mean(log noise_sd)
        - sum(residual**2 / (2 noise_sd**2))``. N is ``data.size`` and
        residual is ``forward - data``.

        Parameters
        -----------
        pars: dict(string, float)
            Dictionary containing values for each parameter
        data: xarray
            The data to compute likelihood against

        Returns
        -------
        The log-likelihood value.
        """
        noise_sd = dict_to_array(data, self.get_par('noise_sd', pars, data))
        residual = self._forward(pars, data) - data
        count = data.size
        gauss_const = -count / 2 * np.log(2 * np.pi)
        sd_penalty = count * np.mean(np.log(ensure_array(noise_sd)))
        chi2_half = (residual**2 / (2 * noise_sd**2)).values.sum()
        return gauss_const - sd_penalty - chi2_half
示例#17
0
 def __init__(self, scatterer, noise_sd=None, medium_index=None,
              illum_wavelen=None, illum_polarization=None, theory='auto',
              constraints=None):
     """
     Parameters
     ----------
     scatterer : scatterer object
         Stored as ``self.scatterer``; its ``parameters`` are registered.
     noise_sd : optional
         Scalars, ``Prior`` objects and dicts are kept as given. Anything
         else goes through ``ensure_array``.
     medium_index, illum_wavelen, illum_polarization, theory : optional
         Registered as parameters after checking they are not xarrays.
     constraints : constraint or list of constraints, optional
         Defaults to no constraints.
     """
     self.scatterer = scatterer
     # None instead of a shared `[]` default: the list given to
     # ensure_listlike may be stored as-is, so a mutable default could be
     # shared (and mutated) across instances.
     if constraints is None:
         constraints = []
     self.constraints = ensure_listlike(constraints)
     self._parameters = []
     self._use_parameters(scatterer.parameters, False)
     if not (np.isscalar(noise_sd) or isinstance(noise_sd, (Prior, dict))):
         noise_sd = ensure_array(noise_sd)
     parameters_to_use = {
         'medium_index': medium_index,
         'illum_wavelen': illum_wavelen,
         'illum_polarization': illum_polarization,
         'theory': theory,
         'noise_sd': noise_sd,
         }
     self._check_parameters_are_not_xarray(parameters_to_use)
     self._use_parameters(parameters_to_use)
示例#18
0
def pack_attrs(a, do_spacing=False):
    """
    Flatten the name, spacing and attrs of DataArray `a` into a plain dict.

    DataArray-valued attrs are stored as lists of their values, and their
    coordinates are recorded under the ``attr_coords`` key. Other non-None
    attrs are stored as YAML text and get ``False`` under ``attr_coords``.
    The ``attr_coords`` entry itself is serialised with ``yaml.dump``
    (flow style).

    Parameters
    ----------
    a : xarray.DataArray
        Array whose metadata is packed.
    do_spacing : bool, optional
        Also store ``list(get_spacing(a))`` under ``'spacing'``.

    Returns
    -------
    dict
    """
    packed = {attr_coords: {}}
    if a.name is not None:
        packed['name'] = a.name
    if do_spacing:
        packed['spacing'] = list(get_spacing(a))

    for attr, val in a.attrs.items():
        if not isinstance(val, xr.DataArray):
            # plain attr: no coordinates; keep YAML text only when set
            packed[attr_coords][attr] = False
            if val is not None:
                packed[attr] = yaml.dump(val)
            continue
        packed[attr_coords][attr] = OrderedDict(
            (str(dim), val[dim].values) for dim in val.dims)
        packed[attr] = list(ensure_array(val.values))

    packed[attr_coords] = yaml.dump(packed[attr_coords],
                                    default_flow_style=True)
    return packed
示例#19
0
 def __init__(self,
              scatterer,
              noise_sd=None,
              medium_index=None,
              illum_wavelen=None,
              illum_polarization=None,
              theory='auto',
              constraints=None):
     """
     Parameters
     ----------
     scatterer : scatterer object
         Each of its ``parameters`` is replaced by ``[0, 0, 0]`` to build a
         dummy scatterer via ``from_parameters``; the parameters are also
         converted with ``_convert_to_map``.
     noise_sd : optional
         Scalars, ``Prior`` objects and dicts are kept as given. Anything
         else goes through ``ensure_array``.
     medium_index, illum_wavelen, illum_polarization : optional
         Optics values, mapped together with noise_sd by ``OPTICS_KEYS``.
     theory : optional
         Stored as ``self.theory``.
     constraints : constraint or list of constraints, optional
         Defaults to no constraints.
     """
     dummy_parameters = {key: [0, 0, 0] for key in scatterer.parameters}
     self._dummy_scatterer = scatterer.from_parameters(dummy_parameters)
     self.theory = theory
     # None instead of a shared `[]` default: the list given to
     # ensure_listlike may be stored as-is, so a mutable default could be
     # shared (and mutated) across instances.
     if constraints is None:
         constraints = []
     self.constraints = ensure_listlike(constraints)
     if not (np.isscalar(noise_sd) or isinstance(noise_sd, (Prior, dict))):
         noise_sd = ensure_array(noise_sd)
     optics = [medium_index, illum_wavelen, illum_polarization, noise_sd]
     optics_parameters = {key: val for key, val in zip(OPTICS_KEYS, optics)}
     self._parameters = []
     self._parameter_names = []
     self._maps = {
         'scatterer': self._convert_to_map(scatterer.parameters),
         'optics': self._convert_to_map(optics_parameters)
     }
示例#20
0
    def lnlike(self, pars, data):
        """
        Compute the log-likelihood for pars given data

        Returns ``ensure_scalar`` of
        ``-N/2 log(2 pi) - N mean(log noise_sd) - 0.5 sum(residuals**2)``.
        N is ``data.size``, noise_sd comes from ``self._find_noise`` and the
        residuals come from ``self._residuals``, which is also given noise_sd.

        Parameters
        -----------
        pars: dict(string, float)
            Dictionary containing values for each parameter
        data: xarray
            The data to compute likelihood against

        Returns
        --------
        lnlike: float
        """
        noise_sd = self._find_noise(pars, data)
        n_points = data.size
        const_term = -n_points / 2 * np.log(2 * np.pi)
        sd_term = n_points * np.mean(np.log(ensure_array(noise_sd)))
        residuals = self._residuals(pars, data, noise_sd)
        return ensure_scalar(const_term - sd_term - 0.5 * (residuals**2).sum())
示例#21
0
    def calculate_scattered_field(self, scatterer, schema):
        """
        Compute the field scattered by `scatterer` at the points of `schema`.

        With more than one illumination wavelength this delegates to
        ``_calculate_multiple_color_scattered_field``; otherwise it delegates
        to ``_calculate_single_color_scattered_field``.

        Parameters
        ----------
        scatterer : :mod:`.scatterer` object
            (possibly composite) scatterer for which to compute scattering
        schema : detector
            Points at which to compute the field; its ``illum_wavelen``
            decides single- vs multi-colour calculation.

        Returns
        -------
        e_field : :mod:`.VectorGrid`
            scattered electric field

        Raises
        ------
        MissingParameter
            If the scatterer has no center.
        """
        if scatterer.center is None:
            raise MissingParameter("center")
        if len(ensure_array(schema.illum_wavelen)) > 1:
            return self._calculate_multiple_color_scattered_field(
                scatterer, schema)
        return self._calculate_single_color_scattered_field(scatterer, schema)
示例#22
0
 def test_xarrays_without_coords(self):
     # 0-d and 1-d coordinate-free DataArrays both become [1]
     expected = np.array([1])
     for wrapped in (xr.DataArray(1), xr.DataArray([1])):
         self.assertEqual(ensure_array(wrapped), expected)
示例#23
0
def load_average(filepath,
                 refimg=None,
                 spacing=None,
                 medium_index=None,
                 illum_wavelen=None,
                 illum_polarization=None,
                 normals=None,
                 noise_sd=None,
                 channel=None,
                 image_glob='*.tif'):
    """
    Average a set of images (usually as a background)

    Parameters
    ----------
    filepath : string or list(string)
        Directory or list of filenames or filepaths. If filename is a directory,
        it will average all images matching image_glob.
    refimg : xarray.DataArray
        reference image to provide spacing and metadata for the new image.
    spacing : float
        Spacing between pixels in the images. Used preferentially over refimg value if both are provided.
    medium_index : float
        Refractive index of the medium in the images. Used preferentially over refimg value if both are provided.
    illum_wavelen : float
        Wavelength of illumination in the images. Used preferentially over refimg value if both are provided.
    illum_polarization : list-like
        Polarization of illumination in the images. Used preferentially over refimg value if both are provided.
    normals : None
        Deprecated; anything other than None raises ValueError.
    noise_sd : float (optional)
        Noise level to attach to the result. If None and more than one image
        is averaged, it is estimated from the images.
    channel : int or tuple of ints (optional)
        Colour channel(s) to load, passed to load_image. If None and refimg
        has an illumination dimension, the channels are chosen from its
        'red'/'green'/'blue' labels.
    image_glob : string
        Glob used to select images (if images is a directory)

    Returns
    -------
    averaged_image : xarray.DataArray
        Image which is an average of images
        noise_sd attribute contains average pixel stdev normalized by total image intensity
    """
    if normals is not None:
        raise ValueError(NORMALS_DEPRECATION_MESSAGE)

    if isinstance(filepath, str):
        if os.path.isdir(filepath):
            filepath = glob.glob(os.path.join(filepath, image_glob))
        else:
            #only a single image
            filepath = [filepath]

    if len(filepath) < 1:
        raise LoadError(filepath, "No images found")

    # read spacing from refimg if none provided
    if spacing is None:
        spacing = get_spacing(refimg)

    # read colour channels from refimg
    if channel is None and refimg is not None and illumination in refimg.dims:
        channel = [
            i for i, col in enumerate(['red', 'green', 'blue'])
            if col in refimg[illumination].values
        ]

    if np.isscalar(spacing):
        spacing = np.repeat(spacing, 2)

    # calculate the average
    accumulator = Accumulator()
    for path in filepath:
        accumulator.push(load_image(path, spacing, channel=channel))
    mean_image = accumulator.mean()

    # calculate average noise from image
    if noise_sd is None and len(filepath) > 1:
        if illumination in mean_image.dims:
            # One value per colour channel, labelled exactly like the image's
            # own illumination coordinate as set by load_image. Iterating
            # `channel` directly failed for an int channel ('int' is not
            # iterable), for 'all', and for channel numbers above 2.
            noise_sd = xr.DataArray(accumulator.cv(),
                                    [mean_image[illumination].values],
                                    [illumination])
        else:
            noise_sd = ensure_array(accumulator.cv())

    # crop according to refimg dimensions
    if refimg is not None:

        def extent(i):
            name = ['x', 'y'][i]
            return np.around(refimg[name].values / spacing[i]).astype('int')

        mean_image = mean_image.isel(x=extent(0), y=extent(1))
        mean_image['x'] = refimg.x
        mean_image['y'] = refimg.y

    # copy metadata from refimg
    if refimg is not None:
        mean_image = copy_metadata(refimg, mean_image, do_coords=False)

    # overwrite metadata from refimg with provided values
    return update_metadata(mean_image, medium_index, illum_wavelen,
                           illum_polarization, normals, noise_sd)
示例#24
0
def load_image(inf,
               spacing=None,
               medium_index=None,
               illum_wavelen=None,
               illum_polarization=None,
               normals=None,
               noise_sd=None,
               channel=None,
               name=None):
    """
    Load data or results

    Parameters
    ----------
    inf : string
        File to load.
    spacing : float or (float, float) (optional)
        pixel size of images in each dimension - assumes square pixels if single value.
        set equal to 1 if not passed in and issues warning.
    medium_index : float (optional)
        refractive index of the medium
    illum_wavelen : float (optional)
        wavelength (in vacuum) of illuminating light
    illum_polarization : (float, float) (optional)
        (x, y) polarization vector of the illuminating light
    normals : None
        Deprecated; anything other than None raises ValueError.
    noise_sd : float (optional)
        noise level in the image, normalized to image intensity
    channel : int or tuple of ints or 'all' (optional)
        number(s) of channel to load for a color image (in general 0=red,
        1=green, 2=blue)
    name : str (optional)
        name to assign the xr.DataArray object resulting from load_image.
        Defaults to the file name without its extension.

    Returns
    -------
    obj : xarray.DataArray representation of the image with associated metadata

    Raises
    ------
    BadImage
        If the image has more than two dimensions and no channel is given.
    LoadError
        If a requested channel number does not exist in the image.
    """
    if normals is not None:
        raise ValueError(NORMALS_DEPRECATION_MESSAGE)
    if name is None:
        name = os.path.splitext(os.path.split(inf)[-1])[0]

    with open(inf, 'rb') as pi_raw:
        pi = pilimage.open(pi_raw)
        arr = np.asarray(pi).astype('d')
        try:
            if isinstance(yaml.safe_load(pi.tag[270][0]), dict):
                warnings.warn(
                    "Metadata detected but ignored. Use hp.load to read it.")
        except (AttributeError, KeyError, yaml.YAMLError):
            # No ImageDescription tag, or a description that is not valid
            # YAML: there is no embedded metadata to warn about.
            pass

    # compare with 'all' only for strings; `array == 'all'` is ambiguous
    wants_all = isinstance(channel, str) and channel == 'all'

    extra_dims = None
    if channel is None:
        if arr.ndim > 2:
            raise BadImage(
                'Not a greyscale image. You must specify which channel(s) to use'
            )
    elif arr.ndim == 2:
        if not wants_all:
            warnings.warn("Not a color image (channel number ignored)")
    else:
        # color image with specified channel(s)
        if wants_all:
            channel = range(arr.shape[2])
        channel = ensure_array(channel)
        if channel.max() >= arr.shape[2]:
            # `filename` was undefined here; the path being loaded is `inf`
            raise LoadError(
                inf, "The image doesn't have a channel number {0}".format(
                    channel.max()))
        else:
            arr = arr[:, :, channel].squeeze()

            if len(channel) > 1:
                # multiple channels. increase output dimensionality
                if channel.max() <= 2:
                    channel = [['red', 'green', 'blue'][c] for c in channel]
                extra_dims = {illumination: channel}
                if illum_wavelen is not None and not isinstance(
                        illum_wavelen, dict) and len(
                            ensure_array(illum_wavelen)) == len(channel):
                    illum_wavelen = xr.DataArray(ensure_array(illum_wavelen),
                                                 dims=illumination,
                                                 coords=extra_dims)
                if not isinstance(illum_polarization, dict) and np.array(
                        illum_polarization).ndim == 2:
                    pol_index = xr.DataArray(channel,
                                             dims=illumination,
                                             name=illumination)
                    illum_polarization = xr.concat(
                        [to_vector(pol) for pol in illum_polarization],
                        pol_index)

    image = data_grid(arr,
                      spacing=spacing,
                      medium_index=medium_index,
                      illum_wavelen=illum_wavelen,
                      illum_polarization=illum_polarization,
                      noise_sd=noise_sd,
                      name=name,
                      extra_dims=extra_dims)
    return image
示例#25
0
 def test_xarray_is_unchanged(self):
     # a DataArray that is already 1-d with coords must come back equal
     original = xr.DataArray([2], dims='a', coords={'a': ['b']})
     result = ensure_array(original)
     self.assertTrue(original.equals(result))
示例#26
0
 def __init__(self, n=None, t=None, center=None):
     """
     Parameters
     ----------
     n : optional
         Refractive index value(s); stored via ``ensure_array``.
     t : optional
         Stored via ``ensure_array``. Presumably layer thickness(es) —
         confirm against callers.
     center : optional
         Stored unchanged.
     """
     self.n, self.t = ensure_array(n), ensure_array(t)
     self.center = center
示例#27
0
 def __init__(self, scatterer, noise_sd, medium_index=None, illum_wavelen=None, illum_polarization=None, theory='auto'):
     """
     Forward the scatterer and optics to the parent initializer, then
     register `noise_sd` (normalised with ``ensure_array``) as the
     ``'noise_sd'`` parameter.
     """
     super().__init__(
         scatterer,
         theory=theory,
         medium_index=medium_index,
         illum_wavelen=illum_wavelen,
         illum_polarization=illum_polarization)
     # ensure_array, not a float cast, is what normalises noise_sd here
     normalised_noise = ensure_array(noise_sd)
     self._use_parameter(normalised_noise, 'noise_sd')
示例#28
0
 def indicators(self):
     """
     Build one indicator function per radius in ``self.r``.

     Each function returns ``(points**2).sum(-1) < ri**2``, i.e. whether a
     point lies strictly inside radius ``ri`` of the origin. The bounding box
     is ``[-r, r]`` on each of the three axes, where r is the largest radius.
     """
     radii = ensure_array(self.r)
     outer = max(radii)
     # `ri=ri` binds each radius at creation time (avoids late binding)
     tests = [lambda points, ri=ri: (points**2).sum(-1) < ri**2
              for ri in radii]
     bounds = [[-outer, outer] for _ in range(3)]
     return Indicators(tests, bounds)
示例#29
0
def display_image(im,
                  scaling='auto',
                  vert_axis='x',
                  horiz_axis='y',
                  depth_axis='z',
                  colour_axis='illumination'):
    """
    Prepare an image for display as a (depth, vertical, horizontal[, colour])
    DataArray.

    Parameters
    ----------
    im : xarray.DataArray or array-like
        Image to display. Plain arrays must be 2D or 3D; they are wrapped with
        ``data_grid`` using unit spacing.
    scaling : 'auto', (min, max) or None
        Range that is clipped and mapped onto [0, 1]. 'auto' uses the image
        min and max; None leaves values unscaled.
    vert_axis, horiz_axis, depth_axis : str or int
        Dimension names (DataArray) or axis numbers (array). For arrays, axes
        not given as ints are inferred, with the smallest remaining axis used
        as depth.
    colour_axis : str
        Name of the dimension holding colour channels.

    Returns
    -------
    xarray.DataArray
        The transposed image, with ``attrs['_image_scaling']`` set to the
        scaling used.

    Raises
    ------
    BadImage
        If the image has too many or too few dimensions, or more than three
        colour channels.
    ValueError
        If integer axis numbers are repeated or outside 0-2.
    """
    im = im.copy()
    if isinstance(im, xr.DataArray):
        # a length-1 z dimension is dropped unless z is the depth axis
        if hasattr(im, 'z') and len(im['z']) == 1 and depth_axis != 'z':
            im = im[{'z': 0}]
        if depth_axis == 'z' and 'z' not in im.dims:
            im = im.expand_dims('z')
        if im.ndim > 3 + (colour_axis in im.dims):
            raise BadImage("Too many dims on DataArray to output properly.")
        attrs = im.attrs
    else:
        attrs = {}
        im = ensure_array(im)
        if im.ndim > 3:
            raise BadImage("Too many dims on ndarray to output properly.")
        elif im.ndim == 2:
            im = np.array([im])
        elif im.ndim < 2:
            raise BadImage("Too few dims on ndarray to output properly.")
        axes = [0, 1, 2]
        for axis in [vert_axis, horiz_axis, depth_axis]:
            if isinstance(axis, int):
                try:
                    axes.remove(axis)
                except ValueError:
                    # list.remove raises ValueError (not KeyError) for a
                    # repeated or out-of-range axis number
                    raise ValueError(
                        "Cannot interpret axis specifications.") from None
        if len(axes) > 0:
            if not isinstance(depth_axis, int):
                depth_axis = axes[np.argmin([im.shape[i] for i in axes])]
                axes.remove(depth_axis)
            if not isinstance(vert_axis, int):
                vert_axis = axes[0]
                axes.pop(0)
            if not isinstance(horiz_axis, int):
                horiz_axis = axes[0]
        im = im.transpose([depth_axis, vert_axis, horiz_axis])
        depth_axis = 'z'
        vert_axis = 'x'
        horiz_axis = 'y'
        im = data_grid(im, spacing=1, z=range(len(im)))
    if np.iscomplex(im).any():
        warn("Image contains complex values. Taking image magnitude.")
        im = np.abs(im)
    # compare by value: `is` on a string literal tests identity, not equality
    if isinstance(scaling, str) and scaling == 'auto':
        scaling = (ensure_scalar(im.min()), ensure_scalar(im.max()))
    if scaling is not None:
        # clip to the scaling range, then map it onto [0, 1]
        im = np.maximum(im, scaling[0])
        im = np.minimum(im, scaling[1])
        im = (im - scaling[0]) / (scaling[1] - scaling[0])
    im.attrs = attrs
    im.attrs['_image_scaling'] = scaling

    if colour_axis in im.dims:
        cols = [
            col[0].capitalize() if isinstance(col, str) else ' '
            for col in im[colour_axis].values
        ]
        RGB_names = np.all([letter in 'RGB' for letter in cols])
        if len(im[colour_axis]) == 1:
            im = im.squeeze(dim=colour_axis)
        elif len(im[colour_axis]) > 3:
            raise BadImage('Cannot output more than 3 colour channels')
        elif RGB_names:
            channels = {
                col: im[{
                    colour_axis: i
                }]
                for i, col in enumerate(cols)
            }
            if len(channels) == 2:
                # fill the one missing R/G/B channel with the image minimum
                dummy = im[{colour_axis: 0}].copy()
                dummy[:] = im.min()
                for i, col in enumerate('RGB'):
                    if col not in cols:
                        dummy[colour_axis] = col
                        channels[col] = dummy
                        channels['R'].attrs['_dummy_channel'] = i
                        break
            channels = [channels[col] for col in 'RGB']
            im = clean_concat(channels, colour_axis)
        elif len(im[colour_axis]) == 2:
            dummy = xr.full_like(im[{colour_axis: 0}], fill_value=im.min())
            # np.NaN was removed in NumPy 2.0; np.nan works in all versions
            dummy = dummy.expand_dims({colour_axis: [np.nan]})
            im.attrs['_dummy_channel'] = -1
            im = clean_concat([im, dummy], colour_axis)
    dim_order = [depth_axis, vert_axis, horiz_axis, colour_axis][:im.ndim]
    return im.transpose(*dim_order)
示例#30
0
 def test_listlike(self):
     # list, tuple and 1-d ndarray inputs all become a 1-d array.
     # (1,) is a one-element tuple; the former `(1)` was just the int 1, so
     # the tuple case was never exercised.
     self.assertEqual(ensure_array([1]), np.array([1]))
     self.assertEqual(ensure_array((1,)), np.array([1]))
     self.assertEqual(ensure_array(np.array([1])), np.array([1]))
示例#31
0
 def test_None_is_unchanged(self):
     # None passes through ensure_array untouched
     self.assertIsNone(ensure_array(None))
示例#32
0
 def test_zero_d_objects(self):
     # Python scalars and 0-d ndarrays are promoted to a 1-element array
     for scalar_like in (1, np.array(1)):
         self.assertEqual(ensure_array(scalar_like), np.array([1]))
     # a 0-d DataArray with a scalar coord becomes a 1-d DataArray over it
     zero_d_xarray = xr.DataArray(2, coords={'a': 'b'})
     expected = xr.DataArray([2], dims='a', coords={'a': ['b']})
     self.assertTrue(expected.equals(ensure_array(zero_d_xarray)))