# Example No. 1
# 0
class Signal(t.HasTraits, MVA):
    """Base container for a data array together with its axes and metadata.

    All data interaction is made through this class or its subclasses.
    """
    data = t.Any()
    axes_manager = t.Instance(AxesManager)
    original_parameters = t.Instance(Parameters)
    mapped_parameters = t.Instance(Parameters)
    physical_property = t.Str()

    def __init__(self, file_data_dict=None, *args, **kw):
        """All data interaction is made through this class or its subclasses

        Parameters
        ----------
        file_data_dict : dict or None
            If it is a dictionary it is loaded with `load_dictionary` (see
            that method for the format); otherwise no data is loaded.
        *args, **kw
            Accepted but not used by this method.
        """
        super(Signal, self).__init__()
        self.mapped_parameters = Parameters()
        self.original_parameters = Parameters()
        # isinstance also accepts dict subclasses (e.g. OrderedDict), which
        # the former ``type(...).__name__ == "dict"`` test silently ignored.
        if isinstance(file_data_dict, dict):
            self.load_dictionary(file_data_dict)
        self._plot = None
        self.mva_results = MVA_Results()
        # Set by _unfold and restored/cleared by fold.
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None

    def load_dictionary(self, file_data_dict):
        """Load the data, axes and metadata stored in a dictionary.

        Parameters
        ----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a list of dictionaries, one per axis (see the
                    AxesManager class). If missing, one 'undefined' axis is
                    created for each dimension of the data.
                attributes: a dictionary which keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

        The input dictionary is not modified.
        """
        self.data = file_data_dict['data']
        # Optional entries are resolved locally instead of inserting
        # defaults into the caller's dictionary.
        if 'axes' in file_data_dict:
            axes = file_data_dict['axes']
        else:
            # Uses self.data, which is why the data is assigned first.
            axes = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(axes)
        for key, value in file_data_dict.get('attributes', {}).items():
            setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict.get('original_parameters', {}))
        self.mapped_parameters.load_dictionary(
            file_data_dict.get('mapped_parameters', {}))

    def _get_signal_dict(self):
        """Return a dictionary in the format accepted by `load_dictionary`.

        The data array is copied; axes and parameters are exported with their
        respective ``_get_*_dict(s)``/``_get_parameters_dictionary`` helpers.
        """
        return {
            'data': self.data.copy(),
            'axes': self.axes_manager._get_axes_dicts(),
            'mapped_parameters':
                self.mapped_parameters._get_parameters_dictionary(),
            'original_parameters':
                self.original_parameters._get_parameters_dictionary(),
        }

    def _get_undefined_axes_list(self):
        """Return one default axis dictionary per dimension of `self.data`
        (unit scale, zero offset, 'undefined' name and units)."""
        return [{'name': 'undefined',
                 'scale': 1.,
                 'offset': 0.,
                 'size': int(size),
                 'units': 'undefined',
                 'index_in_array': i, }
                for i, size in enumerate(self.data.shape)]

    def __call__(self, axes_manager=None):
        """Return the data indexed by ``axes_manager._getitem_tuple``.

        `axes_manager` defaults to the signal's own axes manager.
        """
        if axes_manager is None:
            axes_manager = self.axes_manager
        return self.data[axes_manager._getitem_tuple]

    def _get_hse_1D_explorer(self, *args, **kwargs):
        """Return the squeezed data, transposed when the signal axis comes
        before the navigation axis in the array."""
        islice = self.axes_manager._slicing_axes[0].index_in_array
        inslice = self.axes_manager._non_slicing_axes[0].index_in_array
        if islice > inslice:
            return self.data.squeeze()
        else:
            return self.data.squeeze().T

    def _get_hse_2D_explorer(self, *args, **kwargs):
        """Return the data summed over the (single) signal axis."""
        islice = self.axes_manager._slicing_axes[0].index_in_array
        return self.data.sum(islice)

    def _get_hie_explorer(self, *args, **kwargs):
        """Return the data summed over both signal axes."""
        isslice = sorted([self.axes_manager._slicing_axes[0].index_in_array,
                          self.axes_manager._slicing_axes[1].index_in_array])
        # Sum over the higher index first so the lower one stays valid.
        return self.data.sum(isslice[1]).sum(isslice[0])

    def _get_explorer(self, *args, **kwargs):
        """Return the navigator data for the current dimensionality, or None
        when no explorer is implemented for it."""
        nav_dim = self.axes_manager.navigation_dimension
        sig_dim = self.axes_manager.signal_dimension
        if sig_dim == 1:
            if nav_dim == 1:
                return self._get_hse_1D_explorer(*args, **kwargs)
            if nav_dim == 2:
                return self._get_hse_2D_explorer(*args, **kwargs)
        elif sig_dim == 2 and nav_dim in (1, 2):
            return self._get_hie_explorer(*args, **kwargs)
        return None

    def plot(self, axes_manager=None):
        """Plot the signal with the explorer matching its signal dimension.

        Signal dimension 1 uses MPL_HyperSpectrum_Explorer, dimension 2 uses
        MPL_HyperImage_Explorer; any other value triggers
        ``messages.warning_exit``. A previously opened plot is closed first.

        Parameters
        ----------
        axes_manager : AxesManager or None
            Axes manager driving the plot; defaults to `self.axes_manager`.
        """
        if self._plot is not None:
            try:
                self._plot.close()
            except Exception:
                # If it was already closed it will raise an exception,
                # but we want to carry on...
                pass

        if axes_manager is None:
            axes_manager = self.axes_manager

        if axes_manager.signal_dimension == 1:
            # Hyperspectrum
            self._plot = mpl_hse.MPL_HyperSpectrum_Explorer()
            signal_axis = self.axes_manager._slicing_axes[0]
            self._plot.spectrum_data_function = self.__call__
            self._plot.spectrum_title = self.mapped_parameters.name
            self._plot.xlabel = '%s (%s)' % (signal_axis.name,
                                             signal_axis.units)
            self._plot.ylabel = 'Intensity'
            self._plot.axes_manager = axes_manager
            self._plot.axis = signal_axis.axis

            # Image properties
            if self.axes_manager._non_slicing_axes:
                navigation_axis = self.axes_manager._non_slicing_axes[0]
                self._plot.image_data_function = self._get_explorer
                self._plot.image_title = ''
                self._plot.pixel_size = navigation_axis.scale
                self._plot.pixel_units = navigation_axis.units
            self._plot.plot()

        elif axes_manager.signal_dimension == 2:
            self._plot = mpl_hie.MPL_HyperImage_Explorer()
            self._plot.image_data_function = self.__call__
            self._plot.navigator_data_function = self._get_explorer
            self._plot.axes_manager = axes_manager
            self._plot.plot()

        else:
            messages.warning_exit('Plotting is not supported for this view')

    traits_view = tui.View(
        tui.Item('name'),
        tui.Item('physical_property'),
        tui.Item('units'),
        tui.Item('offset'),
        tui.Item('scale'),)

    def plot_residual(self, axes_manager=None):
        """Plot the residual between original data and reconstructed data

        Requires you to have already run PCA or ICA, and to reconstruct data
        using either the pca_build_SI or ica_build_SI methods.
        """
        if hasattr(self, 'residual'):
            self.residual.plot(axes_manager)
        else:
            print("Object does not have any residual information.  Is it a "
                  "reconstruction created using either pca_build_SI or "
                  "ica_build_SI methods?")

    def save(self, filename, only_view = False, **kwds):
        """Saves the signal in the specified format.

        The function gets the format from the extension. You can use:
            - hdf5 for HDF5
            - nc for NetCDF
            - msa for EMSA/MSA single spectrum saving.
            - bin to produce a raw binary file
            - Many image formats such as png, tiff, jpeg...

        Please note that not all the formats support saving datasets of
        arbitrary dimensions, e.g. msa only supports 1D data.

        Parameters
        ----------
        filename : str
        msa_format : {'Y', 'XY'}
            Passed through ``**kwds``. 'Y' will produce a file without the
            energy axis. 'XY' will also save another column with the energy
            axis. For compatibility with Gatan Digital Micrograph 'Y' is the
            default.
        only_view : bool
            If True, only the current view will be saved. Otherwise the full
            dataset is saved.
            NOTE(review): this argument is currently not forwarded to
            ``io.save``, so it has no effect here.
        """
        io.save(filename, self, **kwds)

    def _replot(self):
        """Redraw the signal if a plot exists and reports itself active."""
        if self._plot is not None:
            if self._plot.is_active() is True:
                self.plot()

    def get_dimensions_from_data(self):
        """Get the dimension parameters from the data_cube. Useful when the
        data_cube was externally modified, or when the SI was not loaded from
        a file
        """
        dc = self.data
        for axis in self.axes_manager.axes:
            axis.size = int(dc.shape[axis.index_in_array])
            print("%s size: %i" %
                  (axis.name, dc.shape[axis.index_in_array]))
        self._replot()

    def crop_in_pixels(self, axis, i1 = None, i2 = None):
        """Crops the data in a given axis. The range is given in pixels

        Parameters
        ----------
        axis : int
            Axis to crop; negative values count from the end.
        i1 : int or None
            Start index
        i2 : int or None
            End index (exclusive, as in slicing)

        See also
        --------
        crop_in_units
        """
        axis = self._get_positive_axis_index_index(axis)
        if i1 is not None:
            new_offset = self.axes_manager.axes[axis].axis[i1]
        # We take a copy to guarantee the continuity of the data
        self.data = self.data[
            (slice(None),) * axis + (slice(i1, i2), Ellipsis)].copy()

        if i1 is not None:
            self.axes_manager.axes[axis].offset = new_offset
        self.get_dimensions_from_data()

    def crop_in_units(self, axis, x1 = None, x2 = None):
        """Crops the data in a given axis. The range is given in the units of
        the axis

        Parameters
        ----------
        axis : int
        x1 : float or None
            Start value, converted to an index with the axis' value2index
        x2 : float or None
            End value, converted to an index with the axis' value2index

        See also
        --------
        crop_in_pixels
        """
        i1 = self.axes_manager.axes[axis].value2index(x1)
        i2 = self.axes_manager.axes[axis].value2index(x2)
        self.crop_in_pixels(axis, i1, i2)

    def roll_xy(self, n_x, n_y = 1):
        """Roll over the x axis n_x positions and n_y positions the former rows

        This method has the purpose of "fixing" a bug in the acquisition of the
        Orsay's microscopes and probably it does not have general interest

        Parameters
        ----------
        n_x : int
        n_y : int

        Note: Useful to correct the SI column storing bug in Marcel's
        acquisition routines.
        """
        self.data = np.roll(self.data, n_x, 0)
        self.data[:n_x, ...] = np.roll(self.data[:n_x, ...], n_y, 1)
        self._replot()

    # TODO: After using this function the plotting does not work
    def swap_axis(self, axis1, axis2):
        """Swaps the axes

        Parameters
        ----------
        axis1 : positive int
        axis2 : positive int
        """
        self.data = self.data.swapaxes(axis1, axis2)
        c1 = self.axes_manager.axes[axis1]
        c2 = self.axes_manager.axes[axis2]
        c1.index_in_array, c2.index_in_array = \
            c2.index_in_array, c1.index_in_array
        self.axes_manager.axes[axis1] = c2
        self.axes_manager.axes[axis2] = c1
        self.axes_manager.set_signal_dimension()
        self._replot()

    def rebin(self, new_shape):
        """
        Rebins the data to the new shape

        Parameters
        ----------
        new_shape: tuple of ints
            The new shape must be a divisor of the original shape
        """
        factors = np.array(self.data.shape) / np.array(new_shape)
        self.data = utils.rebin(self.data, new_shape)
        for axis in self.axes_manager.axes:
            axis.scale *= factors[axis.index_in_array]
        self.get_dimensions_from_data()

    def split_in(self, axis, number_of_parts = None, steps = None):
        """Splits the data

        The split can be defined either by the `number_of_parts` or by the
        `steps` size.

        Parameters
        ----------
        axis : int
            The splitting axis
        number_of_parts : int or None
            Number of parts in which the SI will be split. Trailing elements
            that do not fill a whole part are dropped.
        steps : sequence of int or None
            Size of the split parts. Takes precedence over
            `number_of_parts`. If both are None, a `_splitting_steps`
            attribute is used when it is set.

        Returns
        -------
        list of Signal
        """
        axis = self._get_positive_axis_index_index(axis)
        if number_of_parts is None and steps is None:
            # _splitting_steps is never set by __init__; treat a missing
            # attribute like an empty one instead of raising AttributeError.
            default_steps = getattr(self, '_splitting_steps', None)
            if not default_steps:
                messages.warning_exit(
                    "Please provide either number_of_parts or a steps list")
            else:
                steps = default_steps
                print("Splitting in  %s" % (steps,))
        elif number_of_parts is not None and steps is not None:
            print("Using the given steps list. number_of_parts dimissed")
        shape = self.data.shape

        if steps is None:
            rounded = shape[axis] - (shape[axis] % number_of_parts)
            # Floor division keeps the cut points integral under true
            # division as well.
            step = rounded // number_of_parts
            cut_node = range(0, rounded + step, step)
        else:
            cut_node = np.array([0] + list(steps)).cumsum()

        splitted = []
        for i in range(len(cut_node) - 1):
            data = self.data[
                (slice(None), ) * axis +
                (slice(cut_node[i], cut_node[i + 1]), Ellipsis)]
            s = Signal({'data': data})
            # TODO: When copying the axes manager the plotting does not work,
            # so the parts get default axes.
            s.get_dimensions_from_data()
            splitted.append(s)
        return splitted

    def unfold_if_multidim(self):
        """Unfold the datacube if it has more than two axes

        Returns
        -------
        Boolean. True if the data was unfolded by the function.
        """
        if len(self.axes_manager.axes) > 2:
            print("Automatically unfolding the data")
            self.unfold()
            return True
        else:
            return False

    def _unfold(self, steady_axes, unfolded_axis):
        """Modify the shape of the data by specifying the axes whose
        dimension does not change and the axis over which the remaining axes
        will be unfolded

        Parameters
        ----------
        steady_axes : list
            The indexes of the axes which dimensions do not change
        unfolded_axis : int
            The index of the axis over which all the rest of the axes (except
            the steady axes) will be unfolded

        Returns
        -------
        False if the squeezed data has fewer than 3 dimensions (nothing is
        done), None otherwise.

        See also
        --------
        fold
        """
        # It doesn't make sense unfolding when dim < 3
        if len(self.data.squeeze().shape) < 3:
            return False

        # We need to store the original shape and coordinates to be used by
        # the fold function only if it has not been already stored by a
        # previous unfold
        if self._shape_before_unfolding is None:
            self._shape_before_unfolding = self.data.shape
            self._axes_manager_before_unfolding = self.axes_manager

        new_shape = [1] * len(self.data.shape)
        for index in steady_axes:
            new_shape[index] = self.data.shape[index]
        new_shape[unfolded_axis] = -1
        self.data = self.data.reshape(new_shape)
        # Work on a copy so the manager stored for fold() stays untouched.
        self.axes_manager = self.axes_manager.deepcopy()
        i = 0
        uname = ''
        uunits = ''
        to_remove = []
        for axis, dim in zip(self.axes_manager.axes, new_shape):
            if dim == 1:
                # Names and units of every merged axis are appended to the
                # unfolded axis (units used to be overwritten, keeping only
                # the last merged axis' units).
                uname += ',' + axis.name
                uunits += ',' + axis.units
                to_remove.append(axis)
            else:
                axis.index_in_array = i
                i += 1
        self.axes_manager.axes[unfolded_axis].name += uname
        self.axes_manager.axes[unfolded_axis].units += uunits
        self.axes_manager.axes[unfolded_axis].size = \
            self.data.shape[unfolded_axis]
        for axis in to_remove:
            self.axes_manager.axes.remove(axis)

        self.data = self.data.squeeze()
        self._replot()

    def unfold(self):
        """Modifies the shape of the data by unfolding the signal and
        navigation dimensions separately

        """
        self.unfold_navigation_space()
        self.unfold_signal_space()

    def unfold_navigation_space(self):
        """Modify the shape of the data to obtain a navigation space of
        dimension 1

        Returns False (after an information message) when the navigation
        dimension is already smaller than 2.
        """
        if self.axes_manager.navigation_dimension < 2:
            messages.information('Nothing done, the navigation dimension was '
                                 'already 1')
            return False
        steady_axes = [axis.index_in_array
                       for axis in self.axes_manager._slicing_axes]
        unfolded_axis = self.axes_manager._non_slicing_axes[-1].index_in_array
        self._unfold(steady_axes, unfolded_axis)

    def unfold_signal_space(self):
        """Modify the shape of the data to obtain a signal space of
        dimension 1

        Returns False (after an information message) when the signal
        dimension is already smaller than 2.
        """
        if self.axes_manager.signal_dimension < 2:
            messages.information('Nothing done, the signal dimension was '
                                 'already 1')
            return False
        steady_axes = [axis.index_in_array
                       for axis in self.axes_manager._non_slicing_axes]
        unfolded_axis = self.axes_manager._slicing_axes[-1].index_in_array
        self._unfold(steady_axes, unfolded_axis)

    def fold(self):
        """If the signal was previously unfolded, folds it back"""
        if self._shape_before_unfolding is not None:
            self.data = self.data.reshape(self._shape_before_unfolding)
            self.axes_manager = self._axes_manager_before_unfolding
            self._shape_before_unfolding = None
            self._axes_manager_before_unfolding = None
            self._replot()

    def _get_positive_axis_index_index(self, axis):
        """Convert a negative axis index into the equivalent positive one,
        relative to the number of dimensions of `self.data`."""
        if axis < 0:
            axis = len(self.data.shape) + axis
        return axis

    def iterate_axis(self, axis = -1):
        """Iterate over 1D slices of the data taken along `axis`.

        Side effect: `self.data` is replaced by a copy of itself first.
        """
        # We make a copy to guarantee that the data in contiguous, otherwise
        # it will not return a view of the data
        self.data = self.data.copy()
        axis = self._get_positive_axis_index_index(axis)
        unfolded_axis = axis - 1
        new_shape = [1] * len(self.data.shape)
        new_shape[axis] = self.data.shape[axis]
        new_shape[unfolded_axis] = -1
        # Warning! if the data is not contigous it will make a copy!!
        data = self.data.reshape(new_shape)
        for i in range(data.shape[unfolded_axis]):
            getitem = [0] * len(data.shape)
            getitem[axis] = slice(None)
            getitem[unfolded_axis] = i
            # NumPy requires a tuple for multidimensional indexing; a list
            # is interpreted differently (or rejected) by recent versions.
            yield data[tuple(getitem)]

    def _reduce_axis(self, operation, axis, return_signal):
        """Apply the ndarray reduction named `operation` ('sum' or 'mean')
        over `axis` and drop that axis from the axes manager.

        Shared implementation of `sum` and `mean`; see them for the meaning
        of `return_signal`.
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # Normalise negative indices before reducing: the index_in_array
        # bookkeeping below compares against `axis`, which is wrong for
        # e.g. axis=-1 (every remaining index would be decremented).
        axis = s._get_positive_axis_index_index(axis)
        s.data = getattr(s.data, operation)(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def sum(self, axis, return_signal = False):
        """Sum the data over the specified axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        mean

        Examples
        --------
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.sum(-1)
        >>> s.data.shape
        (64, 64)
        >>> s.sum(-1, True).plot()  # plot the result, leaving s unchanged
        """
        return self._reduce_axis('sum', axis, return_signal)

    def mean(self, axis, return_signal = False):
        """Average the data over the specified axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum

        Examples
        --------
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.mean(-1)
        >>> s.data.shape
        (64, 64)
        >>> s.mean(-1, True).plot()  # plot the result, leaving s unchanged
        """
        return self._reduce_axis('mean', axis, return_signal)

    def copy(self):
        """Return a shallow copy (``copy.copy``) of the signal."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy (``copy.deepcopy``) of the signal."""
        return copy.deepcopy(self)

    def _correct_navigation_mask_when_unfolded(self, navigation_mask = None,):
        """Flatten `navigation_mask` to 1D; None is returned unchanged.

        NOTE(review): this was meant to apply only when the signal is
        unfolded (``if 'unfolded' in self.history:``), but that check is
        disabled, so the mask is always flattened.
        """
        if navigation_mask is not None:
            navigation_mask = navigation_mask.reshape((-1,))
        return navigation_mask
# Example No. 2
# 0
class TestAxesManager:
    """Unit-conversion tests for ``AxesManager.convert_units``."""

    @staticmethod
    def _axis(name, navigate, scale, size, units):
        """Build one axis description dictionary with a zero offset."""
        return {
            'name': name,
            'navigate': navigate,
            'offset': 0.0,
            'scale': scale,
            'size': size,
            'units': units
        }

    @staticmethod
    def _check_axis(manager, key, scale, units):
        """Assert the scale and units of ``manager[key]``."""
        nt.assert_almost_equal(manager[key].scale, scale)
        assert manager[key].units == units

    def setup_method(self, method):
        make = self._axis
        # Two navigation axes in metres and one signal axis in eV.
        self.axes_list = [
            make('x', True, 1.5E-9, 1024, 'm'),
            make('y', True, 0.5E-9, 1024, 'm'),
            make('energy', False, 5.0, 4096, 'eV'),
        ]
        self.am = AxesManager(self.axes_list)

        # One navigation axis and two signal axes, both in eV.
        self.axes_list2 = [
            make('x', True, 1.5E-9, 1024, 'm'),
            make('energy', False, 2.5, 4096, 'eV'),
            make('energy2', False, 5.0, 4096, 'eV'),
        ]
        self.am2 = AxesManager(self.axes_list2)

    def test_compact_unit(self):
        self.am.convert_units()
        expected = (('x', 'nm', 1.5),
                    ('y', 'nm', 0.5),
                    ('energy', 'keV', 0.005))
        for key, units, scale in expected:
            self._check_axis(self.am, key, scale, units)

    def test_convert_to_navigation_units(self):
        self.am.convert_units(axes='navigation', units='mm')
        self._check_axis(self.am, 'x', 1.5E-6, 'mm')
        self._check_axis(self.am, 'y', 0.5E-6, 'mm')
        # The signal axis keeps its original scale.
        nt.assert_almost_equal(self.am['energy'].scale,
                               self.axes_list[-1]['scale'])

    def test_convert_units_axes_integer(self):
        # convert only the first axis
        self.am.convert_units(axes=0, units='nm', same_units=False)
        self._check_axis(self.am, 0, 0.5, 'nm')
        self._check_axis(self.am, 'x', 1.5E-9, 'm')
        nt.assert_almost_equal(self.am['energy'].scale,
                               self.axes_list[-1]['scale'])

        self.am.convert_units(axes=0, units='nm', same_units=True)
        self._check_axis(self.am, 0, 0.5, 'nm')
        self._check_axis(self.am, 'x', 1.5, 'nm')

    def test_convert_to_navigation_units_list(self):
        self.am.convert_units(axes='navigation',
                              units=['mm', 'nm'],
                              same_units=False)
        self._check_axis(self.am, 'x', 1.5, 'nm')
        self._check_axis(self.am, 'y', 0.5E-6, 'mm')
        nt.assert_almost_equal(self.am['energy'].scale,
                               self.axes_list[-1]['scale'])

    def test_convert_to_navigation_units_list_same_units(self):
        self.am.convert_units(axes='navigation',
                              units=['mm', 'nm'],
                              same_units=True)
        self._check_axis(self.am, 'x', 1.5e-6, 'mm')
        self._check_axis(self.am, 'y', 0.5e-6, 'mm')
        self._check_axis(self.am, 'energy', 5, 'eV')

    def test_convert_to_navigation_units_different(self):
        # Don't convert the units since the units of the navigation axes are
        # different
        self.axes_list.insert(0, self._axis('time', True, 1.5, 20, 's'))
        manager = AxesManager(self.axes_list)
        manager.convert_units(axes='navigation', same_units=True)
        expected = (('time', 1.5, 's'),
                    ('x', 1.5, 'nm'),
                    ('y', 0.5, 'nm'),
                    ('energy', 5, 'eV'))
        for key, scale, units in expected:
            self._check_axis(manager, key, scale, units)

    def test_convert_to_navigation_units_Undefined(self):
        self.axes_list[0]['units'] = t.Undefined
        manager = AxesManager(self.axes_list)
        manager.convert_units(axes='navigation', same_units=True)
        self._check_axis(manager, 'x', 1.5E-9, t.Undefined)
        self._check_axis(manager, 'y', 0.5E-9, 'm')
        self._check_axis(manager, 'energy', 5, 'eV')

    def test_convert_to_signal_units(self):
        self.am.convert_units(axes='signal', units='keV')
        # The navigation axes keep their original calibration.
        for index, key in enumerate(('x', 'y')):
            original = self.axes_list[index]
            self._check_axis(self.am, key,
                             original['scale'], original['units'])
        self._check_axis(self.am, 'energy', 0.005, 'keV')

    def test_convert_to_units_list(self):
        self.am.convert_units(units=['µm', 'nm', 'meV'], same_units=False)
        self._check_axis(self.am, 'x', 1.5, 'nm')
        self._check_axis(self.am, 'y', 0.5E-3, 'um')
        self._check_axis(self.am, 'energy', 5E3, 'meV')

    def test_convert_to_units_list_same_units(self):
        self.am2.convert_units(units=['µm', 'eV', 'meV'], same_units=True)
        self._check_axis(self.am2, 'x', 0.0015, 'um')
        for index, key in ((1, 'energy'), (2, 'energy2')):
            original = self.axes_list2[index]
            self._check_axis(self.am2, key,
                             original['scale'], original['units'])

    def test_convert_to_units_list_signal2D(self):
        self.am2.convert_units(units=['µm', 'eV', 'meV'], same_units=False)
        self._check_axis(self.am2, 'x', 0.0015, 'um')
        self._check_axis(self.am2, 'energy', 2500, 'meV')
        self._check_axis(self.am2, 'energy2', 5.0, 'eV')

    @pytest.mark.parametrize("same_units", (True, False))
    def test_convert_to_units_unsupported_units(self, same_units):
        with assert_warns(message="not supported for conversion.",
                          category=UserWarning):
            self.am.convert_units('navigation',
                                  units='toto',
                                  same_units=same_units)
        # The axes must still match the original descriptions.
        assert_deep_almost_equal(self.am._get_axes_dicts(), self.axes_list)
# Example No. 3
# 0
class TestAxesManager:
    """Tests for ``AxesManager.convert_units``."""

    @staticmethod
    def _axis(name, navigate, scale, size, units):
        """Return an axis definition dictionary with a zero offset."""
        return {'name': name,
                'navigate': navigate,
                'offset': 0.0,
                'scale': scale,
                'size': size,
                'units': units}

    @staticmethod
    def _check(am, key, scale, units=None):
        """Assert the scale of ``am[key]`` and, if given, its units."""
        axis = am[key]
        nt.assert_almost_equal(axis.scale, scale)
        if units is not None:
            assert axis.units == units

    def setup_method(self, method):
        axis = self._axis
        # Two navigation length axes (in m) and one signal energy axis.
        self.axes_list = [
            axis('x', True, 1.5E-9, 1024, 'm'),
            axis('y', True, 0.5E-9, 1024, 'm'),
            axis('energy', False, 5.0, 4096, 'eV'),
        ]
        self.am = AxesManager(self.axes_list)

        # One navigation axis and two signal energy axes.
        self.axes_list2 = [
            axis('x', True, 1.5E-9, 1024, 'm'),
            axis('energy', False, 2.5, 4096, 'eV'),
            axis('energy2', False, 5.0, 4096, 'eV'),
        ]
        self.am2 = AxesManager(self.axes_list2)

    def test_compact_unit(self):
        self.am.convert_units()
        self._check(self.am, 'x', 1.5, 'nm')
        self._check(self.am, 'y', 0.5, 'nm')
        self._check(self.am, 'energy', 0.005, 'keV')

    def test_convert_to_navigation_units(self):
        self.am.convert_units(axes='navigation', units='mm')
        self._check(self.am, 'x', 1.5E-6, 'mm')
        self._check(self.am, 'y', 0.5E-6, 'mm')
        # The signal axis is left alone.
        self._check(self.am, 'energy', self.axes_list[-1]['scale'])

    def test_convert_units_axes_integer(self):
        # convert only the first axis
        self.am.convert_units(axes=0, units='nm', same_units=False)
        self._check(self.am, 0, 0.5, 'nm')
        self._check(self.am, 'x', 1.5E-9, 'm')
        self._check(self.am, 'energy', self.axes_list[-1]['scale'])

        self.am.convert_units(axes=0, units='nm', same_units=True)
        self._check(self.am, 0, 0.5, 'nm')
        self._check(self.am, 'x', 1.5, 'nm')

    def test_convert_to_navigation_units_list(self):
        self.am.convert_units(axes='navigation', units=['mm', 'nm'],
                              same_units=False)
        self._check(self.am, 'x', 1.5, 'nm')
        self._check(self.am, 'y', 0.5E-6, 'mm')
        self._check(self.am, 'energy', self.axes_list[-1]['scale'])

    def test_convert_to_navigation_units_list_same_units(self):
        self.am.convert_units(axes='navigation', units=['mm', 'nm'],
                              same_units=True)
        self._check(self.am, 'x', 1.5e-6, 'mm')
        self._check(self.am, 'y', 0.5e-6, 'mm')
        self._check(self.am, 'energy', 5, 'eV')

    def test_convert_to_navigation_units_different(self):
        # The navigation axes mix time ('s') and length ('m') units: the
        # length axes are still converted (to nm) while the time axis keeps
        # its original units and scale.
        self.axes_list.insert(0, self._axis('time', True, 1.5, 20, 's'))
        am = AxesManager(self.axes_list)
        am.convert_units(axes='navigation', same_units=True)
        self._check(am, 'time', 1.5, 's')
        self._check(am, 'x', 1.5, 'nm')
        self._check(am, 'y', 0.5, 'nm')
        self._check(am, 'energy', 5, 'eV')

    def test_convert_to_navigation_units_Undefined(self):
        # An axis without units blocks the navigation conversion.
        self.axes_list[0]['units'] = t.Undefined
        am = AxesManager(self.axes_list)
        am.convert_units(axes='navigation', same_units=True)
        self._check(am, 'x', 1.5E-9, t.Undefined)
        self._check(am, 'y', 0.5E-9, 'm')
        self._check(am, 'energy', 5, 'eV')

    def test_convert_to_signal_units(self):
        self.am.convert_units(axes='signal', units='keV')
        # Navigation axes keep their original calibration.
        for index, name in enumerate(('x', 'y')):
            original = self.axes_list[index]
            self._check(self.am, name, original['scale'], original['units'])
        self._check(self.am, 'energy', 0.005, 'keV')

    def test_convert_to_units_list(self):
        self.am.convert_units(units=['µm', 'nm', 'meV'], same_units=False)
        self._check(self.am, 'x', 1.5, 'nm')
        self._check(self.am, 'y', 0.5E-3, 'um')
        self._check(self.am, 'energy', 5E3, 'meV')

    def test_convert_to_units_list_same_units(self):
        self.am2.convert_units(units=['µm', 'eV', 'meV'], same_units=True)
        self._check(self.am2, 'x', 0.0015, 'um')
        for index, name in ((1, 'energy'), (2, 'energy2')):
            original = self.axes_list2[index]
            self._check(self.am2, name, original['scale'], original['units'])

    def test_convert_to_units_list_signal2D(self):
        self.am2.convert_units(units=['µm', 'eV', 'meV'], same_units=False)
        self._check(self.am2, 'x', 0.0015, 'um')
        self._check(self.am2, 'energy', 2500, 'meV')
        self._check(self.am2, 'energy2', 5.0, 'eV')

    @pytest.mark.parametrize("same_units", (True, False))
    def test_convert_to_units_unsupported_units(self, same_units):
        expected_warning = {'message': "not supported for conversion.",
                            'category': UserWarning}
        with assert_warns(**expected_warning):
            self.am.convert_units('navigation', units='toto',
                                  same_units=same_units)
        # Nothing changed: the axes still match their definitions.
        assert_deep_almost_equal(self.am._get_axes_dicts(),
                                 self.axes_list)
Exemplo n.º 4
0
class Signal(t.HasTraits, MVA):
    data = t.Any()
    axes_manager = t.Instance(AxesManager)
    original_parameters = t.Instance(Parameters)
    mapped_parameters = t.Instance(Parameters)
    physical_property = t.Str()

    def __init__(self, file_data_dict=None, *args, **kw):
        """All data interaction is made through this class or its subclasses


        Parameters:
        -----------
        file_data_dict : dictionary or None
           When it is a dictionary (dict subclasses included) it is passed to
           `load_dictionary`; see that method for the format. Any other value
           is ignored.
        *args, **kw
           Accepted but not used here.
        """
        super(Signal, self).__init__()
        self.mapped_parameters = Parameters()
        self.original_parameters = Parameters()
        # isinstance also accepts dict subclasses (e.g. OrderedDict), which a
        # comparison of the type name would silently skip.
        if isinstance(file_data_dict, dict):
            self.load_dictionary(file_data_dict)
        self._plot = None
        self.mva_results = MVA_Results()
        # State used by the unfold/fold methods to restore the data.
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None

    def load_dictionary(self, file_data_dict):
        """Load the data, axes and parameters of the signal from a dictionary.

        Parameters:
        -----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a list of dictionaries that defines the axes (see the
                    AxesManager class). If missing, one 'undefined' axis is
                    created per dimension of the data.
                attributes: a dictionary which keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

            The dictionary itself is not modified: defaults for missing keys
            are only used locally.
        """
        self.data = file_data_dict['data']
        if 'axes' in file_data_dict:
            axes = file_data_dict['axes']
        else:
            # Must come after self.data is set: the defaults follow its shape.
            axes = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(axes)
        mapped_parameters = file_data_dict.get('mapped_parameters', {})
        original_parameters = file_data_dict.get('original_parameters', {})
        for key, value in file_data_dict.get('attributes', {}).items():
            setattr(self, key, value)
        self.original_parameters.load_dictionary(original_parameters)
        self.mapped_parameters.load_dictionary(mapped_parameters)

    def _get_signal_dict(self):
        """Return a dictionary describing this signal.

        It holds a copy of the data array, the axes dictionaries and the
        mapped/original parameter dictionaries, under the same keys that
        `load_dictionary` reads.
        """
        return {
            'data': self.data.copy(),
            'axes': self.axes_manager._get_axes_dicts(),
            'mapped_parameters':
                self.mapped_parameters._get_parameters_dictionary(),
            'original_parameters':
                self.original_parameters._get_parameters_dictionary(),
        }

    def _get_undefined_axes_list(self):
        """Build one default axis dictionary per dimension of ``self.data``.

        Each axis is named 'undefined' with units 'undefined', has scale 1,
        offset 0, the size of the matching data dimension and that
        dimension's position as ``index_in_array``.
        """
        return [{'name': 'undefined',
                 'scale': 1.,
                 'offset': 0.,
                 'size': int(dimension_size),
                 'units': 'undefined',
                 'index_in_array': position}
                for position, dimension_size in enumerate(self.data.shape)]

    def __call__(self, axes_manager=None):
        """Return the part of the data selected by an axes manager.

        Parameters
        ----------
        axes_manager : AxesManager or None
            Its ``_getitem_tuple`` is used to index ``self.data``. Defaults to
            this signal's own axes manager.
        """
        manager = self.axes_manager if axes_manager is None else axes_manager
        return self.data[manager._getitem_tuple]

    def _get_hse_1D_explorer(self, *args, **kwargs):
        """Return the squeezed data, transposed when needed.

        The data is returned as is when the first slicing axis comes after
        the first non-slicing axis in the array, and transposed otherwise.
        Assuming the squeezed data is 2D, the slicing axis therefore ends up
        as the second dimension. Extra arguments are ignored.
        """
        axes = self.axes_manager
        slicing_index = axes._slicing_axes[0].index_in_array
        non_slicing_index = axes._non_slicing_axes[0].index_in_array
        squeezed = self.data.squeeze()
        return squeezed if slicing_index > non_slicing_index else squeezed.T

    def _get_hse_2D_explorer(self, *args, **kwargs):
        """Return the data summed over the first slicing axis.

        Extra arguments are ignored.
        """
        slicing_index = self.axes_manager._slicing_axes[0].index_in_array
        return self.data.sum(slicing_index)

    def _get_hie_explorer(self, *args, **kwargs):
        """Return the data summed over the first two slicing axes.

        Extra arguments are ignored.
        """
        slicing = self.axes_manager._slicing_axes
        low, high = sorted([slicing[0].index_in_array,
                            slicing[1].index_in_array])
        # Reduce the higher index first so that the lower one stays valid.
        return self.data.sum(high).sum(low)

    def _get_explorer(self, *args, **kwargs):
        """Return the navigator data suited to the current dimensions.

        signal_dimension 1 with navigation_dimension 1 or 2 uses the
        hyperspectrum explorers. signal_dimension 2 with navigation_dimension
        1 or 2 uses the hyperimage explorer. Any other combination returns
        None. Arguments are forwarded to the chosen helper.
        """
        nav_dim = self.axes_manager.navigation_dimension
        sig_dim = self.axes_manager.signal_dimension
        if sig_dim == 1 and nav_dim == 1:
            explorer = self._get_hse_1D_explorer
        elif sig_dim == 1 and nav_dim == 2:
            explorer = self._get_hse_2D_explorer
        elif sig_dim == 2 and nav_dim in (1, 2):
            explorer = self._get_hie_explorer
        else:
            return None
        return explorer(*args, **kwargs)

    def plot(self, axes_manager=None):
        """Plot the signal.

        Any plot previously opened by this signal is closed first. A signal
        dimension of 1 is shown with ``mpl_hse.MPL_HyperSpectrum_Explorer``
        and a signal dimension of 2 with ``mpl_hie.MPL_HyperImage_Explorer``.
        For any other dimension ``messages.warning_exit`` is called.

        Parameters
        ----------
        axes_manager : AxesManager or None
            Axes manager given to the plotting object. Defaults to
            ``self.axes_manager``.
        """
        if self._plot is not None:
            try:
                self._plot.close()
            except Exception:
                # If it was already closed it will raise an exception, but we
                # want to carry on. Exception (instead of a bare except) still
                # lets KeyboardInterrupt and SystemExit through.
                pass

        if axes_manager is None:
            axes_manager = self.axes_manager

        if axes_manager.signal_dimension == 1:
            # Hyperspectrum
            # NOTE(review): the labels, spectrum axis and pixel size below are
            # read from self.axes_manager even when another axes_manager is
            # passed in -- confirm this is intended.
            self._plot = mpl_hse.MPL_HyperSpectrum_Explorer()
            signal_axis = self.axes_manager._slicing_axes[0]
            self._plot.spectrum_data_function = self.__call__
            self._plot.spectrum_title = self.mapped_parameters.name
            self._plot.xlabel = '%s (%s)' % (signal_axis.name,
                                             signal_axis.units)
            self._plot.ylabel = 'Intensity'
            self._plot.axes_manager = axes_manager
            self._plot.axis = signal_axis.axis

            # Image properties
            if self.axes_manager._non_slicing_axes:
                navigation_axis = self.axes_manager._non_slicing_axes[0]
                self._plot.image_data_function = self._get_explorer
                self._plot.image_title = ''
                self._plot.pixel_size = navigation_axis.scale
                self._plot.pixel_units = navigation_axis.units
            self._plot.plot()

        elif axes_manager.signal_dimension == 2:
            # Mike's playground with new plotting toolkits - needs to be a
            # branch. Kept for reference only (it was never executed):
            #     if len(self.data.shape) == 2:
            #         from drawing.guiqwt_hie import image_plot_2D
            #         image_plot_2D(self)
            #
            #     import drawing.chaco_hie
            #     self._plot = drawing.chaco_hie.Chaco_HyperImage_Explorer(self)
            #     self._plot.configure_traits()
            self._plot = mpl_hie.MPL_HyperImage_Explorer()
            self._plot.image_data_function = self.__call__
            self._plot.navigator_data_function = self._get_explorer
            self._plot.axes_manager = axes_manager
            self._plot.plot()

        else:
            messages.warning_exit('Plotting is not supported for this view')

    # Presumably the TraitsUI view used when this object's traits are edited.
    # NOTE(review): 'name', 'units', 'offset' and 'scale' are not among the
    # traits declared in the visible part of this class (data, axes_manager,
    # original_parameters, mapped_parameters, physical_property) -- confirm
    # these items belong here and were not copied from an axis class.
    traits_view = tui.View(
        tui.Item('name'),
        tui.Item('physical_property'),
        tui.Item('units'),
        tui.Item('offset'),
        tui.Item('scale'),
    )

    def plot_residual(self, axes_manager=None):
        """Plot the residual between original data and reconstructed data.

        Only works on objects that have a ``residual`` attribute, i.e.
        reconstructions built with the pca_build_SI or ica_build_SI methods
        after running PCA or ICA. Otherwise a message is printed instead.
        """
        if not hasattr(self, 'residual'):
            print("Object does not have any residual information.  Is it a "
                  "reconstruction created using either pca_build_SI or "
                  "ica_build_SI methods?")
            return
        self.residual.plot(axes_manager)

    def save(self, filename, only_view=False, **kwds):
        """Saves the signal in the specified format.

        The function gets the format from the extension. You can use:
            - hdf5 for HDF5
            - nc for NetCDF
            - msa for EMSA/MSA single spectrum saving.
            - bin to produce a raw binary file
            - Many image formats such as png, tiff, jpeg...

        Please note that not all the formats support saving datasets of
        arbitrary dimensions, e.g. msa only supports 1D data.

        Parameters
        ----------
        filename : str
        msa_format : {'Y', 'XY'}
            Passed to ``io.save`` through ``**kwds``. 'Y' will produce a file
            without the energy axis. 'XY' will also save another column with
            the energy axis. For compatibility with Gatan Digital Micrograph
            'Y' is the default.
        only_view : bool
            Currently has no effect: it is accepted but not forwarded to
            ``io.save``.
        **kwds
            Any other keyword arguments are forwarded to ``io.save``.
        """
        # NOTE(review): only_view is dropped here; forwarding it would first
        # require confirming that io.save and its writers accept it.
        io.save(filename, self, **kwds)

    def _replot(self):
        """Plot the signal again if its current plot is active."""
        current = self._plot
        if current is not None and current.is_active() is True:
            self.plot()

    def get_dimensions_from_data(self):
        """Update every axis size from the current shape of ``self.data``.

        Useful after the data has been modified externally, or when the
        signal was not loaded from a file. Each new size is printed and the
        plot is refreshed.
        """
        shape = self.data.shape
        for data_axis in self.axes_manager.axes:
            length = shape[data_axis.index_in_array]
            data_axis.size = int(length)
            print("%s size: %i" % (data_axis.name, length))
        self._replot()

    def crop_in_pixels(self, axis, i1=None, i2=None):
        """Crop the data along one axis using index limits.

        Parameters
        ----------
        axis : int
            Array index of the axis to crop. Negative values count from the
            last dimension.
        i1 : int or None
            Start index. When given, the axis offset is moved to the value
            at that index.
        i2 : int or None
            End index (exclusive, as in a Python slice).

        See also:
        ---------
        crop_in_units
        """
        axis = self._get_positive_axis_index_index(axis)
        has_start = i1 is not None
        if has_start:
            new_offset = self.axes_manager.axes[axis].axis[i1]
        selection = (slice(None), ) * axis + (slice(i1, i2), Ellipsis)
        # Copy so that the cropped data is contiguous.
        self.data = self.data[selection].copy()

        if has_start:
            self.axes_manager.axes[axis].offset = new_offset
        self.get_dimensions_from_data()

    def crop_in_units(self, axis, x1=None, x2=None):
        """Crop the data along one axis using limits in the axis units.

        Parameters
        ----------
        axis : int
            Index of the axis to crop.
        x1 : number or None
            Start value, in the units of the axis. None keeps the start.
        x2 : number or None
            End value, in the units of the axis. None keeps the end.

        See also:
        ---------
        crop_in_pixels

        """
        data_axis = self.axes_manager.axes[axis]
        # Only convert the limits that were given; None means "no limit" for
        # crop_in_pixels, so it must not go through value2index.
        i1 = None if x1 is None else data_axis.value2index(x1)
        i2 = None if x2 is None else data_axis.value2index(x2)
        self.crop_in_pixels(axis, i1, i2)

    def roll_xy(self, n_x, n_y=1):
        """Roll along the first axis by n_x, then roll those first rows by n_y.

        The whole data is rolled by `n_x` positions along axis 0. The first
        `n_x` entries along that axis are then rolled by `n_y` positions
        along axis 1.

        This method has the purpose of "fixing" a bug in the acquisition of
        the Orsay's microscopes and probably it does not have general
        interest.

        Parameters
        ----------
        n_x : int
        n_y : int

        Note: Useful to correct the SI column storing bug in Marcel's
        acquisition routines.
        """
        self.data = np.roll(self.data, n_x, 0)
        leading_rows = self.data[:n_x, ...]
        self.data[:n_x, ...] = np.roll(leading_rows, n_y, 1)
        self._replot()

    # TODO: After using this function the plotting does not work
    def swap_axis(self, axis1, axis2):
        """Swap two axes of the data and of the axes manager.

        Parameters
        ----------
        axis1 : positive int
        axis2 : positive int
        """
        self.data = self.data.swapaxes(axis1, axis2)
        axes = self.axes_manager.axes
        first, second = axes[axis1], axes[axis2]
        # Exchange the array positions recorded on the axis objects...
        first.index_in_array, second.index_in_array = \
            second.index_in_array, first.index_in_array
        # ...and their positions in the axes list.
        axes[axis1], axes[axis2] = second, first
        self.axes_manager.set_signal_dimension()
        self._replot()

    def rebin(self, new_shape):
        """Rebin the data to `new_shape` and scale the axes to match.

        Parameters
        ----------
        new_shape: tuple of ints
            The new shape must be a divisor of the original shape
        """
        old_shape = np.array(self.data.shape)
        factors = old_shape / np.array(new_shape)
        self.data = utils.rebin(self.data, new_shape)
        for data_axis in self.axes_manager.axes:
            # Each new pixel spans `factor` old pixels along this axis.
            factor = factors[data_axis.index_in_array]
            data_axis.scale *= factor
        self.get_dimensions_from_data()

    def split_in(self, axis, number_of_parts=None, steps=None):
        """Splits the data

        The split can be defined either by the `number_of_parts` or by the
        `steps` size.

        Parameters
        ----------
        number_of_parts : int or None
            Number of parts in which the SI will be splitted
        steps : int or None
            Size of the splitted parts
        axis : int
            The splitting axis

        Return
        ------
        tuple with the splitted signals
        """
        axis = self._get_positive_axis_index_index(axis)
        if number_of_parts is None and steps is None:
            if not self._splitting_steps:
                messages.warning_exit(
                    "Please provide either number_of_parts or a steps list")
            else:
                steps = self._splitting_steps
                print "Splitting in ", steps
        elif number_of_parts is not None and steps is not None:
            print "Using the given steps list. number_of_parts dimissed"
        splitted = []
        shape = self.data.shape

        if steps is None:
            rounded = (shape[axis] - (shape[axis] % number_of_parts))
            step = rounded / number_of_parts
            cut_node = range(0, rounded + step, step)
        else:
            cut_node = np.array([0] + steps).cumsum()
        for i in xrange(len(cut_node) - 1):
            data = self.data[(slice(None), ) * axis +
                             (slice(cut_node[i], cut_node[i + 1]), Ellipsis)]
            s = Signal({'data': data})
            # TODO: When copying plotting does not work
            #            s.axes = copy.deepcopy(self.axes_manager)
            s.get_dimensions_from_data()
            splitted.append(s)
        return splitted

    def unfold_if_multidim(self):
        """Unfold the data when there are more than two axes.

        Returns
        -------
        bool
            True if `unfold` was called, False otherwise.
        """
        if len(self.axes_manager.axes) <= 2:
            return False
        print("Automatically unfolding the data")
        self.unfold()
        return True

    def _unfold(self, steady_axes, unfolded_axis):
        """Reshape the data, keeping some axes and merging all the others.

        Parameters
        ----------
        steady_axes : list
            The indexes of the axes whose dimensions do not change.
        unfolded_axis : int
            The index of the axis over which all the rest of the axes (except
            the steady axes) will be unfolded. Its name and units get the
            names and units of the merged axes appended, comma-separated.

        Returns
        -------
        False when nothing is done because the squeezed data has fewer than
        3 dimensions, None otherwise.

        See also
        --------
        fold
        """

        # It doesn't make sense unfolding when dim < 3
        if len(self.data.squeeze().shape) < 3:
            return False

        # We need to store the original shape and coordinates to be used by
        # the fold function only if it has not been already stored by a
        # previous unfold
        if self._shape_before_unfolding is None:
            self._shape_before_unfolding = self.data.shape
            self._axes_manager_before_unfolding = self.axes_manager

        new_shape = [1] * len(self.data.shape)
        for index in steady_axes:
            new_shape[index] = self.data.shape[index]
        new_shape[unfolded_axis] = -1
        self.data = self.data.reshape(new_shape)
        # Work on a copy so the stored pre-unfolding axes manager is untouched.
        self.axes_manager = self.axes_manager.deepcopy()
        i = 0
        uname = ''
        uunits = ''
        to_remove = []
        for axis, dim in zip(self.axes_manager.axes, new_shape):
            if dim == 1:
                # This axis is merged into the unfolded axis: accumulate its
                # name and units (both, so no merged axis is lost) and drop it.
                uname += ',' + axis.name
                uunits += ',' + axis.units
                to_remove.append(axis)
            else:
                axis.index_in_array = i
                i += 1
        unfolded = self.axes_manager.axes[unfolded_axis]
        unfolded.name += uname
        unfolded.units += uunits
        unfolded.size = self.data.shape[unfolded_axis]
        for axis in to_remove:
            self.axes_manager.axes.remove(axis)

        self.data = self.data.squeeze()
        self._replot()

    def unfold(self):
        """Unfold the navigation and the signal spaces, one after the other.

        The navigation space is unfolded first, then the signal space.
        """
        for unfold_space in (self.unfold_navigation_space,
                             self.unfold_signal_space):
            unfold_space()

    def unfold_navigation_space(self):
        """Reshape the data so that the navigation space has one dimension.

        The ``_slicing_axes`` are kept and the other axes are merged into the
        last of the ``_non_slicing_axes``. When the navigation dimension is
        below 2 the user is informed and False is returned.
        """
        manager = self.axes_manager
        if manager.navigation_dimension < 2:
            messages.information('Nothing done, the navigation dimension was '
                                 'already 1')
            return False
        kept = [item.index_in_array for item in manager._slicing_axes]
        target = manager._non_slicing_axes[-1].index_in_array
        self._unfold(kept, target)

    def unfold_signal_space(self):
        """Reshape the data so that the signal space has one dimension.

        The ``_non_slicing_axes`` are kept and the other axes are merged into
        the last of the ``_slicing_axes``. When the signal dimension is below
        2 the user is informed and False is returned.
        """
        manager = self.axes_manager
        if manager.signal_dimension < 2:
            messages.information('Nothing done, the signal dimension was '
                                 'already 1')
            return False
        kept = [item.index_in_array for item in manager._non_slicing_axes]
        target = manager._slicing_axes[-1].index_in_array
        self._unfold(kept, target)

    def fold(self):
        """Restore the shape and axes saved by a previous unfold, if any."""
        if self._shape_before_unfolding is None:
            # Not unfolded: nothing to restore.
            return
        self.data = self.data.reshape(self._shape_before_unfolding)
        self.axes_manager = self._axes_manager_before_unfolding
        # Clear the saved state so a later unfold records it afresh.
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None
        self._replot()

    def _get_positive_axis_index_index(self, axis):
        """Turn a possibly negative axis index into a non-negative one.

        Negative values count back from the number of dimensions of
        ``self.data``; other values are returned unchanged.
        """
        return len(self.data.shape) + axis if axis < 0 else axis

    def iterate_axis(self, axis=-1):
        """Yield the data one 1D slice at a time along `axis`.

        All the other dimensions are merged into dimension ``axis - 1`` with a
        reshape, and one slice is yielded per position of that merged
        dimension.

        Parameters
        ----------
        axis : int
            The axis each yielded slice runs along. Negative values count from
            the last dimension.

        Yields
        ------
        numpy.ndarray
            Slices of the reshaped data. ``self.data`` is first replaced by a
            contiguous copy so that the reshape can return a view, which
            presumably makes the slices views into ``self.data``.

        Notes
        -----
        NOTE(review): for data with 3 or more dimensions, a plain reshape only
        keeps elements of one slice together when `axis` is the first or the
        last dimension. For intermediate axes the slices mix elements of
        different axes. Confirm callers only use those axes.
        """
        # We make a copy to guarantee that the data is contiguous, otherwise
        # it will not return a view of the data
        self.data = self.data.copy()
        axis = self._get_positive_axis_index_index(axis)
        unfolded_axis = axis - 1
        new_shape = [1] * len(self.data.shape)
        new_shape[axis] = self.data.shape[axis]
        new_shape[unfolded_axis] = -1
        # Warning! if the data is not contiguous it will make a copy!!
        data = self.data.reshape(new_shape)
        for i in xrange(data.shape[unfolded_axis]):
            getitem = [0] * len(data.shape)
            getitem[axis] = slice(None)
            getitem[unfolded_axis] = i
            # NumPy needs a tuple for multidimensional indexing; a list of
            # slices/ints is deprecated and rejected by recent versions.
            yield data[tuple(getitem)]

    def sum(self, axis, return_signal=False):
        """Sum the data over the specified axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last dimension.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        mean

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.sum(-1)
        >>> s.data.shape
        (64, 64)
        >>> # If we just want to plot the result of the operation
        >>> s.sum(-1, True).plot()
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # Normalise negative indexes while the data still has the dimension:
        # the index_in_array comparison below needs a non-negative axis
        # (with axis=-1 every remaining axis would wrongly be shifted).
        axis = s._get_positive_axis_index_index(axis)
        s.data = s.data.sum(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def mean(self, axis, return_signal=False):
        """Average the data over the specified axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last dimension.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.mean(-1)
        >>> s.data.shape
        (64, 64)
        >>> # If we just want to plot the result of the operation
        >>> s.mean(-1, True).plot()
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # Normalise negative indexes while the data still has the dimension:
        # the index_in_array comparison below needs a non-negative axis
        # (with axis=-1 every remaining axis would wrongly be shifted).
        axis = s._get_positive_axis_index_index(axis)
        s.data = s.data.mean(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def copy(self):
        """Return a shallow copy of the signal, made with ``copy.copy``."""
        shallow = copy.copy(self)
        return shallow

    def deepcopy(self):
        """Return a deep copy of the signal, made with ``copy.deepcopy``."""
        duplicate = copy.deepcopy(self)
        return duplicate

#    def sum_in_mask(self, mask):
#        """Returns the result of summing all the spectra in the mask.
#
#        Parameters
#        ----------
#        mask : boolean numpy array
#
#        Returns
#        -------
#        Spectrum
#        """
#        dc = self.data_cube.copy()
#        mask3D = mask.reshape([1,] + list(mask.shape)) * np.ones(dc.shape)
#        dc = (mask3D*dc).sum(1).sum(1) / mask.sum()
#        s = Spectrum()
#        s.data_cube = dc.reshape((-1,1,1))
#        s.get_dimensions_from_cube()
#        utils.copy_energy_calibration(self,s)
#        return s
#
#    def mean(self, axis):
#        """Average the SI over the given axis
#
#        Parameters
#        ----------
#        axis : int
#        """
#        dc = self.data_cube
#        dc = dc.mean(axis)
#        dc = dc.reshape(list(dc.shape) + [1,])
#        self.data_cube = dc
#        self.get_dimensions_from_cube()
#
#    def roll(self, axis = 2, shift = 1):
#        """Roll the SI. see numpy.roll
#
#        Parameters
#        ----------
#        axis : int
#        shift : int
#        """
#        self.data_cube = np.roll(self.data_cube, shift, axis)
#        self._replot()
#

#
#    def get_calibration_from(self, s):
#        """Copy the calibration from another Spectrum instance
#        Parameters
#        ----------
#        s : spectrum instance
#        """
#        utils.copy_energy_calibration(s, self)
#
#    def estimate_variance(self, dc = None, gaussian_noise_var = None):
#        """Variance estimation supposing Poissonian noise
#
#        Parameters
#        ----------
#        dc : None or numpy array
#            If None the SI is used to estimate its variance. Otherwise, the
#            provided array will be used.
#        Note
#        ----
#        The gain_factor and gain_offset from the acquisition parameters are used
#        """
#        print "Variace estimation using the following values:"
#        print "Gain factor = ", self.acquisition_parameters.gain_factor
#        print "Gain offset = ", self.acquisition_parameters.gain_offset
#        if dc is None:
#            dc = self.data_cube
#        gain_factor = self.acquisition_parameters.gain_factor
#        gain_offset = self.acquisition_parameters.gain_offset
#        self.variance = dc*gain_factor + gain_offset
#        if self.variance.min() < 0:
#            if gain_offset == 0 and gaussian_noise_var is None:
#                print "The variance estimation results in negative values"
#                print "Maybe the gain_offset is wrong?"
#                self.variance = None
#                return
#            elif gaussian_noise_var is None:
#                print "Clipping the variance to the gain_offset value"
#                self.variance = np.clip(self.variance, np.abs(gain_offset),
#                np.Inf)
#            else:
#                print "Clipping the variance to the gaussian_noise_var"
#                self.variance = np.clip(self.variance, gaussian_noise_var,
#                np.Inf)
#
#    def calibrate(self, lcE = 642.6, rcE = 849.7, lc = 161.9, rc = 1137.6,
#    modify_calibration = True):
#        dispersion = (rcE - lcE) / (rc - lc)
#        origin = lcE - dispersion * lc
#        print "Energy step = ", dispersion
#        print "Energy origin = ", origin
#        if modify_calibration is True:
#            self.set_new_calibration(origin, dispersion)
#        return origin, dispersion
#

    def _correct_navigation_mask_when_unfolded(
        self,
        navigation_mask=None,
    ):
        #if 'unfolded' in self.history:
        if navigation_mask is not None:
            navigation_mask = navigation_mask.reshape((-1, ))
        return navigation_mask