def setup_method(self, method):
        """Build a four-axis ``AxesManager`` and store it on ``self.am``.

        Axes 'a' and 'd' are created with ``navigate=True``; 'b' and 'c'
        with ``navigate=False``.
        """
        keys = ('name', 'navigate', 'offset', 'scale', 'size', 'units')
        rows = (
            ('a', True, 0.0, 1.3, 2, 'aa'),
            ('b', False, 1.0, 6.0, 3, 'bb'),
            ('c', False, 2.0, 100.0, 4, 'cc'),
            ('d', True, 3.0, 1000000.0, 5, 'dd'),
        )
        # One dictionary per axis, in the order listed above.
        axes_list = [dict(zip(keys, row)) for row in rows]

        self.am = AxesManager(axes_list)
 def test_convert_to_navigation_units_Undefined(self):
     """Convert navigation units when the first axis has undefined units.

     The first axis is given ``t.Undefined`` units before the manager is
     built; the units and scale of all three axes are then checked.
     """
     self.axes_list[0]['units'] = t.Undefined
     manager = AxesManager(self.axes_list)
     manager.convert_units(axes='navigation', same_units=True)
     # (axis name, expected units, expected scale)
     expected = (
         ('x', t.Undefined, 1.5E-9),
         ('y', 'm', 0.5E-9),
         ('energy', 'eV', 5),
     )
     for axis_name, units, scale in expected:
         axis = manager[axis_name]
         assert axis.units == units
         nt.assert_almost_equal(axis.scale, scale)
Exemple #3
0
 def test_convert_to_navigation_units_Undefined(self):
     """Check units and scales after converting with an undefined unit.

     ``self.axes_list[0]`` gets ``t.Undefined`` units, a manager is built
     from the list and ``convert_units`` is run on the navigation axes with
     ``same_units=True``.
     """
     self.axes_list[0]['units'] = t.Undefined
     converted = AxesManager(self.axes_list)
     converted.convert_units(axes='navigation', same_units=True)
     expected_units = {'x': t.Undefined, 'y': 'm', 'energy': 'eV'}
     expected_scale = {'x': 1.5E-9, 'y': 0.5E-9, 'energy': 5}
     for axis_name in ('x', 'y', 'energy'):
         assert converted[axis_name].units == expected_units[axis_name]
         nt.assert_almost_equal(converted[axis_name].scale,
                                expected_scale[axis_name])
Exemple #4
0
    def setup_method(self, method):
        """Create the two axes lists and managers used by the tests.

        ``self.axes_list``/``self.am`` hold the axes x, y and energy;
        ``self.axes_list2``/``self.am2`` hold x, energy and energy2.
        """
        def make_axis(name, navigate, scale, size, units):
            # Every axis here has is_binned=False and offset=0.0.
            return {
                'name': name,
                'navigate': navigate,
                'is_binned': False,
                'offset': 0.0,
                'scale': scale,
                'size': size,
                'units': units,
            }

        self.axes_list = [
            make_axis('x', True, 1.5E-9, 1024, 'm'),
            make_axis('y', True, 0.5E-9, 1024, 'm'),
            make_axis('energy', False, 5.0, 4096, 'eV'),
        ]
        self.am = AxesManager(self.axes_list)

        self.axes_list2 = [
            make_axis('x', True, 1.5E-9, 1024, 'm'),
            make_axis('energy', False, 2.5, 4096, 'eV'),
            make_axis('energy2', False, 5.0, 4096, 'eV'),
        ]
        self.am2 = AxesManager(self.axes_list2)
Exemple #5
0
        def do_ffts():
            """Generate the shifted FFT of every signal in ``signals``.

            For each signal a complex-valued deep copy is filled with the
            ``fftn``/``fftshift`` of the data selected at every step of the
            iteration over an axes manager built from the signal's axes. The
            signal axes of the copy are then recalibrated (scale, offset and
            inverted units), its title is prefixed with 'FFT of ' and the
            copy is appended to ``fftsignals``.

            Yields
            ------
            int
                Running count of the transforms computed so far.
            """
            j = 0
            for s in signals:
                ffts = s.deepcopy()
                # Items of up to 4 bytes become complex64, wider ones
                # complex128.
                if ffts.data.itemsize <= 4:
                    ffts.change_dtype(np.complex64)
                else:
                    ffts.change_dtype(np.complex128)

                am = AxesManager(s.axes_manager._get_axes_dicts())
                # Only the side effect of iterating matters: each step
                # changes what ``am._getitem_tuple`` selects.
                for _ in am:
                    fftdata = s.data[am._getitem_tuple]
                    fftdata = scipy.fftpack.fftn(fftdata)
                    fftdata = scipy.fftpack.fftshift(fftdata)
                    ffts.data[am._getitem_tuple] = fftdata
                    j += 1
                    yield j

                for i in range(ffts.axes_manager.signal_dimension):
                    axis = ffts.axes_manager.signal_axes[i]
                    s_axis = s.axes_manager.signal_axes[i]
                    axis.scale = 1 / (s_axis.size * s_axis.scale)
                    # Move the offset back by half of the new axis range.
                    shift = (axis.high_value - axis.low_value) / 2
                    axis.offset -= shift
                    u = s_axis.units
                    # Units can be undefined (not a string); only strings
                    # can be inverted, anything else is passed through.
                    if hasattr(u, 'endswith'):
                        if u.endswith('-1'):
                            u = u[:-2]
                        else:
                            u += '-1'
                    axis.units = u
                indstr = ' ' + str(s.axes_manager.indices) \
                    if len(s.axes_manager.indices) > 0 else ''
                ffts.metadata.General.title = 'FFT of ' + \
                    ffts.metadata.General.title + indstr
                fftsignals.append(ffts)
Exemple #6
0
    def setup_method(self, method):
        """Store on ``self.am`` an ``AxesManager`` with the axes 'a' to 'd'.

        'a' and 'd' have ``navigate=True``, 'b' and 'c' ``navigate=False``.
        """
        fields = ('name', 'navigate', 'offset', 'scale', 'size', 'units')
        values = [
            ('a', True, 0.0, 1.3, 2, 'aa'),
            ('b', False, 1.0, 6.0, 3, 'bb'),
            ('c', False, 2.0, 100.0, 4, 'cc'),
            ('d', True, 3.0, 1000000.0, 5, 'dd'),
        ]
        axes_list = []
        for row in values:
            axes_list.append(dict(zip(fields, row)))

        self.am = AxesManager(axes_list)
Exemple #7
0
    def append(self, *args):
        """Add the data of each given object to the aggregate.

        An object whose ``mapped_parameters.original_filename`` is already
        registered in ``original_files`` is skipped with a message. Otherwise
        it is registered and its data (made at least 3D) is appended along
        axis 2 of ``self.data``. When at least one argument is given, the
        axes manager and the aggregate's ``original_filename`` are rebuilt
        and ``summary()`` is called. Calling with no arguments does nothing.
        """
        if not args:
            return
        smp = self.mapped_parameters
        for arg in args:
            # object parameters
            mp = arg.mapped_parameters

            if mp.original_filename in smp.original_files:
                print("Data from file %s already in this aggregate. \n "
                      "    Delete it first if you want to update it."
                      % mp.original_filename)
                continue
            smp.original_files[mp.original_filename] = arg
            # add the data to the aggregate array
            # Identity test: ``== None`` on an array is an element-wise
            # comparison, not a check for "no data yet".
            if self.data is None:
                self.data = np.atleast_3d(arg.data)
            else:
                self.data = np.append(self.data,
                                      np.atleast_3d(arg.data),
                                      axis=2)
            print("File %s added to aggregate." % mp.original_filename)
        # refresh the axes for the new sized data
        self.axes_manager = AxesManager(self._get_undefined_axes_list())
        smp.original_filename = "Aggregate Image: %s" % smp.original_files.keys(
        )
        self.summary()
Exemple #8
0
    def append(self, *args):
        """Add the data of each given image stack to the aggregate.

        Stacks are keyed by the ``original_filename`` of their parent's
        mapped parameters. A stack whose key is already in ``locations`` is
        skipped with a message. Otherwise its locations, parent, stack and
        address range are recorded and its data is appended along axis 2 of
        ``self.data``. When at least one argument is given, the axes manager
        and the aggregate's ``original_filename`` are rebuilt and
        ``summary()`` is called. Calling with no arguments does nothing.
        """
        if not args:
            return
        smp = self.mapped_parameters
        for arg in args:
            # object parameters
            mp = arg.mapped_parameters
            pmp = mp.parent.mapped_parameters

            # NOTE(review): the messages below report ``mp.original_filename``
            # while the bookkeeping is keyed on ``pmp.original_filename`` --
            # confirm which name is meant to be shown.
            if pmp.original_filename in smp.locations:
                print("Data from file %s already in this aggregate. \n "
                      "    Delete it first if you want to update it."
                      % mp.original_filename)
                continue
            smp.locations[pmp.original_filename] = mp.locations
            smp.original_files[pmp.original_filename] = mp.parent
            smp.image_stacks[pmp.original_filename] = arg
            # (first index, last index) of this stack along axis 2; the
            # end is inclusive.
            smp.aggregate_address[pmp.original_filename] = (
                smp.aggregate_end_pointer,
                smp.aggregate_end_pointer + arg.data.shape[-1] - 1)
            # add the data to the aggregate array
            # Identity test: ``== None`` on an array is an element-wise
            # comparison, not a check for "no data yet".
            if self.data is None:
                self.data = np.atleast_3d(arg.data)
            else:
                self.data = np.append(self.data, arg.data, axis=2)
            print("File %s added to aggregate." % mp.original_filename)
            smp.aggregate_end_pointer = self.data.shape[2]
        # refresh the axes for the new sized data
        self.axes_manager = AxesManager(self._get_undefined_axes_list())
        smp.original_filename = "Aggregate Cells: %s" % smp.locations.keys(
        )
        self.summary()
Exemple #9
0
class TestAxesManager:
    """Tests run against an ``AxesManager`` built from four axes."""

    def setup_method(self, method):
        """Create the manager under test and store it on ``self.am``."""
        axes_list = [
            {
                "name": "a",
                "navigate": True,
                "offset": 0.0,
                "scale": 1.3,
                "size": 2,
                "units": "aa",
            },
            {
                "name": "b",
                "navigate": False,
                "offset": 1.0,
                "scale": 6.0,
                "size": 3,
                "units": "bb",
            },
            {
                "name": "c",
                "navigate": False,
                "offset": 2.0,
                "scale": 100.0,
                "size": 4,
                "units": "cc",
            },
            {
                "name": "d",
                "navigate": True,
                "offset": 3.0,
                "scale": 1000000.0,
                "size": 5,
                "units": "dd",
            },
        ]

        self.am = AxesManager(axes_list)

    def test_reprs(self):
        """Both the plain and the HTML representation can be generated."""
        repr(self.am)
        # Call the method: a bare attribute lookup would not render
        # anything and so would test nothing.
        self.am._repr_html_()

    def test_update_from(self):
        """Only the requested attributes are copied and changes are signalled.

        Updating from an identical copy must not trigger
        ``any_axis_changed``; updating "units" and "scale" from a modified
        copy must trigger it and leave offset and size untouched.
        """
        am = self.am
        am2 = self.am.deepcopy()
        m = mock.Mock()
        am.events.any_axis_changed.connect(m.changed)
        am.update_axes_attributes_from(am2._axes)
        assert not m.changed.called
        am2[0].scale = 0.5
        am2[1].units = "km"
        am2[2].offset = 50
        am2[3].size = 1
        am.update_axes_attributes_from(am2._axes, attributes=["units", "scale"])
        assert m.changed.called
        assert am2[0].scale == am[0].scale
        assert am2[1].units == am[1].units
        assert am2[2].offset != am[2].offset
        assert am2[3].size != am[3].size
Exemple #10
0
 def _add_object(self, arg):
     #object parameters
     mp = arg.mapped_parameters
     smp = self.mapped_parameters
     if mp.original_filename not in smp.original_files.keys():
         smp.original_files[mp.original_filename] = arg
         # save the original data shape to the mva_results for later use
         smp.original_files[
             mp.
             original_filename].mva_results.original_shape = arg.data.shape[:
                                                                            -1]
         arg.unfold()
         smp.aggregate_address[mp.original_filename] = (
             smp.aggregate_end_pointer,
             smp.aggregate_end_pointer + arg.data.shape[0] - 1)
         # add the data to the aggregate array
         if self.data == None:
             self.data = np.atleast_2d(arg.data)
             # copy the axes for the sake of calibration
             axes = [
                 arg.axes_manager.axes[i].get_axis_dictionary()
                 for i in xrange(len(arg.axes_manager.axes))
             ]
             self.axes_manager = AxesManager(axes)
         else:
             self.data = np.append(self.data, arg.data, axis=0)
             smp.aggregate_end_pointer = self.data.shape[0]
             print "File %s added to aggregate." % mp.original_filename
     else:
         print "Data from file %s already in this aggregate. \n \
 Delete it first if you want to update it." % mp.original_filename
Exemple #11
0
def hdfgroup2dict(group, dictionary=None):
    """Recursively convert an HDF5 group (or dataset) into a dictionary.

    Attributes are copied first: '_None_' strings become ``None``, numpy
    booleans become ``bool`` and '|S1' arrays become lists. Attributes
    prefixed with '_datetime_' are evaluated and stored without the prefix;
    attributes prefixed with '_sig_' are skipped. Unless ``group`` is an
    ``h5py.Dataset``, its members are then converted: '_sig_' members to
    signals, datasets to arrays, '_hspy_AxesManager_' members to an
    ``AxesManager`` and anything else to a nested dictionary.

    Parameters
    ----------
    group : h5py group or dataset
        Object whose attributes and members are read.
    dictionary : dict, optional
        Dictionary to fill in place. A new one is created when omitted.

    Returns
    -------
    dict
        ``dictionary``, filled with the contents of ``group``.
    """
    if dictionary is None:
        # A fresh dict per call: a ``{}`` default would be shared between
        # calls and carry entries from one group over to the next.
        dictionary = {}
    for key, value in group.attrs.items():
        if isinstance(value, (np.string_, str)):
            if value == '_None_':
                value = None
        elif isinstance(value, np.bool_):
            value = bool(value)

        elif isinstance(value, np.ndarray) and \
                value.dtype == np.dtype('|S1'):
            value = value.tolist()
        # skip signals - these are handled below.
        if key.startswith('_sig_'):
            pass
        elif key.startswith('_datetime_'):
            # NOTE(review): ``eval`` executes whatever string the file
            # contains; this is only safe for trusted files.
            dictionary[key.replace("_datetime_", "")] = eval(value)
        else:
            dictionary[key] = value
    if not isinstance(group, h5py.Dataset):
        for key in group.keys():
            if key.startswith('_sig_'):
                from hyperspy.io import dict2signal
                dictionary[key[len('_sig_'):]] = (dict2signal(
                    hdfgroup2signaldict(group[key])))
            elif isinstance(group[key], h5py.Dataset):
                dictionary[key] = np.array(group[key])
            elif key.startswith('_hspy_AxesManager_'):
                # The axes are passed in the sorted order of their keys.
                axes = [i for k, i in
                        sorted(hdfgroup2dict(group[key]).items())]
                dictionary[key[len('_hspy_AxesManager_'):]] = \
                    AxesManager(axes)
            else:
                dictionary[key] = {}
                hdfgroup2dict(group[key], dictionary[key])
    return dictionary
Exemple #12
0
 def setup_method(self, method):
     """Create a one-parameter component with a two-axis axes manager.

     Both axes have ``navigate=True``; their sizes are 3 and 2.
     """
     self.c = Component(["parameter"])
     navigation_axes = [{"size": size, "navigate": True} for size in (3, 2)]
     self.c._axes_manager = AxesManager(navigation_axes)
Exemple #13
0
class TestAxesManager:
    """Tests run against an ``AxesManager`` built from four axes."""

    def setup_method(self, method):
        """Create the manager under test and store it on ``self.am``."""
        axes_list = [{
            'name': 'a',
            'navigate': True,
            'offset': 0.0,
            'scale': 1.3,
            'size': 2,
            'units': 'aa'
        }, {
            'name': 'b',
            'navigate': False,
            'offset': 1.0,
            'scale': 6.0,
            'size': 3,
            'units': 'bb'
        }, {
            'name': 'c',
            'navigate': False,
            'offset': 2.0,
            'scale': 100.0,
            'size': 4,
            'units': 'cc'
        }, {
            'name': 'd',
            'navigate': True,
            'offset': 3.0,
            'scale': 1000000.0,
            'size': 5,
            'units': 'dd'
        }]

        self.am = AxesManager(axes_list)

    def test_reprs(self):
        """Both the plain and the HTML representation can be generated."""
        repr(self.am)
        # Call the method: a bare attribute lookup would not render
        # anything and so would test nothing.
        self.am._repr_html_()

    def test_update_from(self):
        """Only the requested attributes are copied and changes are signalled.

        Updating from an identical copy must not trigger
        ``any_axis_changed``; updating "units" and "scale" from a modified
        copy must trigger it and leave offset and size untouched.
        """
        am = self.am
        am2 = self.am.deepcopy()
        m = mock.Mock()
        am.events.any_axis_changed.connect(m.changed)
        am.update_axes_attributes_from(am2._axes)
        assert not m.changed.called
        am2[0].scale = 0.5
        am2[1].units = "km"
        am2[2].offset = 50
        am2[3].size = 1
        am.update_axes_attributes_from(am2._axes,
                                       attributes=["units", "scale"])
        assert m.changed.called
        assert am2[0].scale == am[0].scale
        assert am2[1].units == am[1].units
        assert am2[2].offset != am[2].offset
        assert am2[3].size != am[3].size
    def setup_method(self, method):
        """Create the two axes lists and managers used by the tests.

        ``self.axes_list``/``self.am`` hold the axes x, y and energy;
        ``self.axes_list2``/``self.am2`` hold x, energy and energy2.
        """
        keys = ('name', 'navigate', 'offset', 'scale', 'size', 'units')

        def build(*rows):
            # One axis dictionary per row, keys in the order given above.
            return [dict(zip(keys, row)) for row in rows]

        self.axes_list = build(
            ('x', True, 0.0, 1.5E-9, 1024, 'm'),
            ('y', True, 0.0, 0.5E-9, 1024, 'm'),
            ('energy', False, 0.0, 5.0, 4096, 'eV'),
        )
        self.am = AxesManager(self.axes_list)

        self.axes_list2 = build(
            ('x', True, 0.0, 1.5E-9, 1024, 'm'),
            ('energy', False, 0.0, 2.5, 4096, 'eV'),
            ('energy2', False, 0.0, 5.0, 4096, 'eV'),
        )
        self.am2 = AxesManager(self.axes_list2)
Exemple #15
0
 def remove(self, *keys):
     """Drop the given files from the aggregate.

     For every key the matching slice along axis 2 of ``self.data`` is
     deleted together with its ``original_files`` entry. Afterwards the
     axes manager and the aggregate's ``original_filename`` are rebuilt and
     ``summary()`` is called.
     """
     params = self.mapped_parameters
     for name in keys:
         # The position of the file among the keys selects the slice to
         # delete.
         position = params.original_files.keys().index(name)
         single_slice = np.s_[position:position + 1:1]
         self.data = np.delete(self.data, single_slice, 2)
         del params.original_files[name]
         print("File %s removed from aggregate." % name)
     self.axes_manager = AxesManager(self._get_undefined_axes_list())
     remaining = params.original_files.keys()
     params.original_filename = "Aggregate Image: %s" % remaining
     self.summary()
 def test_convert_to_navigation_units_different(self):
     """Convert navigation units after prepending a 'time' axis in seconds.

     The 'time' axis is expected to keep 's' and a scale of 1.5, 'x' and 'y'
     to be in 'nm' and 'energy' to keep 'eV' and a scale of 5.
     """
     # Don't convert the units since the units of the navigation axes are
     # different
     time_axis = {
         'name': 'time',
         'navigate': True,
         'offset': 0.0,
         'scale': 1.5,
         'size': 20,
         'units': 's',
     }
     self.axes_list.insert(0, time_axis)
     manager = AxesManager(self.axes_list)
     manager.convert_units(axes='navigation', same_units=True)
     # (axis name, expected units, expected scale)
     expected = (
         ('time', 's', 1.5),
         ('x', 'nm', 1.5),
         ('y', 'nm', 0.5),
         ('energy', 'eV', 5),
     )
     for axis_name, units, scale in expected:
         axis = manager[axis_name]
         assert axis.units == units
         nt.assert_almost_equal(axis.scale, scale)
Exemple #17
0
class TestAxesManager:
    """Tests run against an ``AxesManager`` built from four axes."""

    def setup_method(self, method):
        """Create the manager under test and store it on ``self.am``."""
        axes_list = [
            {'name': 'a',
             'navigate': True,
             'offset': 0.0,
             'scale': 1.3,
             'size': 2,
             'units': 'aa'},
            {'name': 'b',
             'navigate': False,
             'offset': 1.0,
             'scale': 6.0,
             'size': 3,
             'units': 'bb'},
            {'name': 'c',
             'navigate': False,
             'offset': 2.0,
             'scale': 100.0,
             'size': 4,
             'units': 'cc'},
            {'name': 'd',
             'navigate': True,
             'offset': 3.0,
             'scale': 1000000.0,
             'size': 5,
             'units': 'dd'}]

        self.am = AxesManager(axes_list)

    def test_reprs(self):
        """Both the plain and the HTML representation can be generated."""
        repr(self.am)
        # Call the method: a bare attribute lookup would not render
        # anything and so would test nothing.
        self.am._repr_html_()

    def test_update_from(self):
        """Only the requested attributes are copied and changes are signalled.

        Updating from an identical copy must not trigger
        ``any_axis_changed``; updating "units" and "scale" from a modified
        copy must trigger it and leave offset and size untouched.
        """
        am = self.am
        am2 = self.am.deepcopy()
        m = mock.Mock()
        am.events.any_axis_changed.connect(m.changed)
        am.update_axes_attributes_from(am2._axes)
        assert not m.changed.called
        am2[0].scale = 0.5
        am2[1].units = "km"
        am2[2].offset = 50
        am2[3].size = 1
        am.update_axes_attributes_from(am2._axes,
                                       attributes=["units", "scale"])
        assert m.changed.called
        assert am2[0].scale == am[0].scale
        assert am2[1].units == am[1].units
        assert am2[2].offset != am[2].offset
        assert am2[3].size != am[3].size
Exemple #18
0
 def __init__(self, *args, **kw):
     """Start an empty aggregate.

     A single placeholder axis is installed before the parent initializer
     runs; afterwards ``data`` is reset to ``None`` and ``original_files``
     is set to an empty ``OrderedDict``.
     """
     # this axes_manager isn't really ideal for Aggregates.
     placeholder_axis = {
         'name': 'undefined',
         'scale': 1.,
         'offset': 0.,
         'size': 1,
         'units': 'undefined',
         'index_in_array': 0,
     }
     self.axes_manager = AxesManager([placeholder_axis])
     super(Aggregate, self).__init__(*args, **kw)
     self.data = None
     self.mapped_parameters.original_files = OrderedDict()
Exemple #19
0
 def test_convert_to_navigation_units_different(self):
     """Convert navigation units after prepending a 'time' axis in seconds.

     The 'time' axis is expected to keep 's' and a scale of 1.5, 'x' and 'y'
     to be in 'nm' and 'energy' to keep 'eV' and a scale of 5.
     """
     # Don't convert the units since the units of the navigation axes are
     # different
     keys = ('name', 'navigate', 'offset', 'scale', 'size', 'units')
     time_axis = dict(zip(keys, ('time', True, 0.0, 1.5, 20, 's')))
     self.axes_list.insert(0, time_axis)
     converted = AxesManager(self.axes_list)
     converted.convert_units(axes='navigation', same_units=True)
     expected_units = {'time': 's', 'x': 'nm', 'y': 'nm', 'energy': 'eV'}
     expected_scale = {'time': 1.5, 'x': 1.5, 'y': 0.5, 'energy': 5}
     for axis_name in ('time', 'x', 'y', 'energy'):
         assert converted[axis_name].units == expected_units[axis_name]
         nt.assert_almost_equal(converted[axis_name].scale,
                                expected_scale[axis_name])
Exemple #20
0
 def remove(self, *keys):
     """Drop the given files from the aggregate.

     For every key the frames recorded in ``aggregate_address`` are deleted
     along axis 2 of ``self.data`` and the key is removed from
     ``locations``, ``original_files`` and ``image_stacks``. Afterwards the
     axes manager, the end pointer and the aggregate's
     ``original_filename`` are rebuilt and ``summary()`` is called.
     """
     smp = self.mapped_parameters
     for key in keys:
         # Look the address up before deleting anything, so that a missing
         # entry does not leave the other mappings half updated.
         address = smp.aggregate_address[key]
         del smp.locations[key]
         del smp.original_files[key]
         del smp.image_stacks[key]
         # Addresses are recorded as (first, last) with an inclusive end,
         # so the slice has to stop at last + 1 to remove every frame.
         self.data = np.delete(self.data,
                               np.s_[address[0]:address[1] + 1:1], 2)
         # NOTE(review): the entry for ``key`` stays in
         # ``aggregate_address`` and the addresses of the remaining files
         # are not shifted -- confirm whether they should be.
         print("File %s removed from aggregate." % key)
     self.axes_manager = AxesManager(self._get_undefined_axes_list())
     smp.aggregate_end_pointer = self.data.shape[2]
     smp.original_filename = "Aggregate Cells: %s" % smp.locations.keys()
     self.summary()
Exemple #21
0
    def setup_method(self, method):
        """Store on ``self.am`` an ``AxesManager`` with the axes 'a' to 'd'.

        'a' and 'd' have ``navigate=True``, 'b' and 'c' ``navigate=False``.
        """
        keys = ("name", "navigate", "offset", "scale", "size", "units")
        rows = (
            ("a", True, 0.0, 1.3, 2, "aa"),
            ("b", False, 1.0, 6.0, 3, "bb"),
            ("c", False, 2.0, 100.0, 4, "cc"),
            ("d", True, 3.0, 1000000.0, 5, "dd"),
        )
        # One dictionary per axis, in the order listed above.
        axes_list = [dict(zip(keys, row)) for row in rows]

        self.am = AxesManager(axes_list)
Exemple #22
0
class TestAxesManager:
    """Tests run against an ``AxesManager`` built from four axes."""

    def setup(self):
        """Create the manager under test and store it on ``self.am``."""
        keys = ('name', 'navigate', 'offset', 'scale', 'size', 'units')
        rows = (
            ('a', True, 0.0, 1.3, 2, 'aa'),
            ('b', False, 1.0, 6.0, 3, 'bb'),
            ('c', False, 2.0, 100.0, 4, 'cc'),
            ('d', True, 3.0, 1000000.0, 5, 'dd'),
        )
        axes_list = [dict(zip(keys, row)) for row in rows]

        self.am = AxesManager(axes_list)

    def test_update_from(self):
        """Only the requested attributes are copied and changes are signalled.

        Updating from an identical copy must not trigger
        ``any_axis_changed``; updating "units" and "scale" from a modified
        copy must trigger it and leave offset and size untouched.
        """
        target = self.am
        source = self.am.deepcopy()
        listener = mock.Mock()
        target.events.any_axis_changed.connect(listener.changed)

        # Nothing differs yet, so no change may be reported.
        target.update_axes_attributes_from(source._axes)
        nt.assert_false(listener.changed.called)

        source[0].scale = 0.5
        source[1].units = "km"
        source[2].offset = 50
        source[3].size = 1
        requested = ["units", "scale"]
        target.update_axes_attributes_from(source._axes,
                                           attributes=requested)
        nt.assert_true(listener.changed.called)
        nt.assert_equal(source[0].scale, target[0].scale)
        nt.assert_equal(source[1].units, target[1].units)
        nt.assert_not_equal(source[2].offset, target[2].offset)
        nt.assert_not_equal(source[3].size, target[3].size)
Exemple #23
0
    def load_dictionary(self, file_data_dict):
        """Load the data, axes and parameters of the signal from a dictionary.

        Missing 'axes', 'mapped_parameters' and 'original_parameters'
        entries are added to ``file_data_dict`` in place.

        Parameters
        ----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a dictionary that defines the axes (see the
                    AxesManager class)
                attributes: a dictionary whose keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

        """
        self.data = file_data_dict['data']
        if 'axes' not in file_data_dict:
            file_data_dict['axes'] = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(file_data_dict['axes'])
        if 'mapped_parameters' not in file_data_dict:
            file_data_dict['mapped_parameters'] = {}
        if 'original_parameters' not in file_data_dict:
            file_data_dict['original_parameters'] = {}
        if 'attributes' in file_data_dict:
            # ``items()`` exists on both Python 2 and 3, whereas
            # ``iteritems()`` is missing from Python 3 dictionaries.
            for key, value in file_data_dict['attributes'].items():
                setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict['original_parameters'])
        self.mapped_parameters.load_dictionary(
            file_data_dict['mapped_parameters'])
Exemple #24
0
    def __init__(self, signal):
        """Plot ``signal`` and set up a one-axis manager for the line.

        Raises ``SignalDimensionError`` unless the signal dimension is 1.
        The manager is built from the dictionary of the signal's first
        signal axis, with that axis switched to ``navigate=True``.
        """
        signal_dimension = signal.axes_manager.signal_dimension
        if signal_dimension != 1:
            raise SignalDimensionError(signal_dimension, 1)

        self.signal = signal
        self.signal.plot()
        spectral_axis = signal.axes_manager.signal_axes[0]
        manager = AxesManager([spectral_axis.get_axis_dictionary()])
        line_axis = manager._axes[0]
        line_axis.navigate = True
        # Set the position of the line in the middle of the spectral
        # range by default
        line_axis.index = int(round(line_axis.size / 2))
        self.axes_manager = manager
        self.axes_manager.connect(self.update_position)
        self.on_trait_change(self.switch_on_off, 'on')
Exemple #25
0
def hdfgroup2dict(group, dictionary=None):
    """Recursively convert an HDF5 group (or dataset) into a dictionary.

    Attributes are copied first: '_None_' strings become ``None``, other
    numpy strings are decoded (UTF-8, falling back to Latin-1), numpy
    booleans become ``bool`` and '|S1' arrays become lists. Attributes
    prefixed with '_sig_' are skipped. Unless ``group`` is an
    ``h5py.Dataset``, its members are then converted: '_sig_' members to
    signals, datasets to arrays, '_hspy_AxesManager_' members to an
    ``AxesManager`` and anything else to a nested dictionary.

    Parameters
    ----------
    group : h5py group or dataset
        Object whose attributes and members are read.
    dictionary : dict, optional
        Dictionary to fill in place. A new one is created when omitted.

    Returns
    -------
    dict
        ``dictionary``, filled with the contents of ``group``.
    """
    if dictionary is None:
        # A fresh dict per call: a ``{}`` default would be shared between
        # calls and carry entries from one group over to the next.
        dictionary = {}
    for key, value in group.attrs.items():
        if type(value) is np.string_:
            if value == '_None_':
                value = None
            else:
                try:
                    value = value.decode('utf8')
                except UnicodeError:
                    # For old files
                    value = value.decode('latin-1')
        elif type(value) is np.bool_:
            value = bool(value)

        elif type(value) is np.ndarray and \
                value.dtype == np.dtype('|S1'):
            value = value.tolist()
        # skip signals - these are handled below.
        if key.startswith('_sig_'):
            pass
        else:
            dictionary[key] = value
    if not isinstance(group, h5py.Dataset):
        for key in group.keys():
            if key.startswith('_sig_'):
                from hyperspy.io import dict2signal
                dictionary[key[len('_sig_'):]] = (dict2signal(
                    hdfgroup2signaldict(group[key])))
            elif isinstance(group[key], h5py.Dataset):
                dictionary[key] = np.array(group[key])
            elif key.startswith('_hspy_AxesManager_'):
                # The axes are passed in the sorted order of their keys.
                axes = [i for k, i in
                        sorted(hdfgroup2dict(group[key]).items())]
                dictionary[key[len('_hspy_AxesManager_'):]] = \
                    AxesManager(axes)
            else:
                dictionary[key] = {}
                hdfgroup2dict(group[key], dictionary[key])
    return dictionary
    def setup(self):
        """Store the x, y and energy axes and a manager built from them.

        The list is kept on ``self.axes_list`` and the manager on
        ``self.am``.
        """
        keys = ('name', 'navigate', 'offset', 'scale', 'size', 'units')
        rows = (
            ('x', True, 0.0, 1E-6, 1024, 'm'),
            ('y', True, 0.0, 0.5E-9, 1024, 'm'),
            ('energy', False, 0.0, 5.0, 4096, 'eV'),
        )
        self.axes_list = [dict(zip(keys, row)) for row in rows]

        self.am = AxesManager(self.axes_list)
Exemple #27
0
        def do_iffts():
            """Generate the inverse FFT of every signal in ``signals``.

            For each signal a real-valued deep copy has its signal axes
            recalibrated (offset, scale and inverted units) and is then
            filled with the magnitude of the ``ifftshift``/``ifftn`` of the
            data selected at every step of the iteration over an axes
            manager built from the signal's axes. Its title is prefixed
            with 'Inverse FFT of ' and the copy is appended to
            ``fftsignals``.

            Yields
            ------
            int
                Running count of the transforms computed so far.
            """
            count = 0
            for source in signals:
                result = source.deepcopy()
                # Items of up to 4 bytes become float32, wider ones float64.
                if result.data.itemsize > 4:
                    result.change_dtype(np.float64)
                else:
                    result.change_dtype(np.float32)

                iterator = AxesManager(source.axes_manager._get_axes_dicts())

                for i in range(result.axes_manager.signal_dimension):
                    out_axis = result.axes_manager.signal_axes[i]
                    in_axis = source.axes_manager.signal_axes[i]
                    # The offset is shifted by half of the axis range as
                    # measured before the scale is replaced.
                    half_range = (out_axis.high_value - out_axis.low_value) / 2
                    out_axis.offset += half_range
                    out_axis.scale = 1 / (in_axis.size * in_axis.scale)
                    units = in_axis.units
                    # Undefined units are passed through unchanged.
                    if units is not traits.Undefined:
                        if units.endswith('-1'):
                            units = units[:-2]
                        else:
                            units += '-1'
                    out_axis.units = units

                # Only the side effect of iterating matters: each step
                # changes what ``iterator._getitem_tuple`` selects.
                for _ in iterator:
                    block = source.data[iterator._getitem_tuple]
                    unshifted = scipy.fftpack.ifftshift(block)
                    transformed = scipy.fftpack.ifftn(unshifted)
                    result.data[iterator._getitem_tuple] = np.abs(transformed)
                    count += 1
                    yield count
                if len(source.axes_manager.indices) > 0:
                    suffix = ' ' + str(source.axes_manager.indices)
                else:
                    suffix = ''
                result.metadata.General.title = (
                    'Inverse FFT of ' + result.metadata.General.title + suffix)
                fftsignals.append(result)
Exemple #28
0
    def load_dictionary(self, file_data_dict):
        """Load the data, axes and parameters of the signal from a dictionary.

        Missing 'axes', 'mapped_parameters' and 'original_parameters'
        entries are added to ``file_data_dict`` in place.

        Parameters
        ----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a dictionary that defines the axes (see the
                    AxesManager class)
                attributes: a dictionary whose keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

        """
        self.data = file_data_dict['data']
        if 'axes' not in file_data_dict:
            file_data_dict['axes'] = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(
            file_data_dict['axes'])
        if 'mapped_parameters' not in file_data_dict:
            file_data_dict['mapped_parameters'] = {}
        if 'original_parameters' not in file_data_dict:
            file_data_dict['original_parameters'] = {}
        if 'attributes' in file_data_dict:
            # ``items()`` exists on both Python 2 and 3, whereas
            # ``iteritems()`` is missing from Python 3 dictionaries.
            for key, value in file_data_dict['attributes'].items():
                setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict['original_parameters'])
        self.mapped_parameters.load_dictionary(
            file_data_dict['mapped_parameters'])
class TestAxesManager:
    """Checks of ``AxesManager.convert_units`` on two navigation axes
    ('x', 'y') and one signal axis ('energy')."""

    def setup(self):
        def describe(name, navigate, scale, size, units):
            # Every axis of the fixture starts at offset 0.0.
            return {'name': name,
                    'navigate': navigate,
                    'offset': 0.0,
                    'scale': scale,
                    'size': size,
                    'units': units}

        self.axes_list = [
            describe('x', True, 1E-6, 1024, 'm'),
            describe('y', True, 0.5E-9, 1024, 'm'),
            describe('energy', False, 5.0, 4096, 'eV')]

        self.am = AxesManager(self.axes_list)

    def _check_axis(self, name, scale, units):
        """Assert the scale (to 5 places) and then the units of one axis."""
        axis = self.am[name]
        nt.assert_almost_equal(axis.scale, scale, places=5)
        nt.assert_equal(axis.units, units)

    def _check_unchanged(self, name, position):
        """Assert that an axis still has the values it was created with."""
        original = self.axes_list[position]
        self._check_axis(name, original['scale'], original['units'])

    def test_convenient_scale_unit(self):
        self.am.convert_units()
        self._check_axis('x', 1.0, 'µm')
        self._check_axis('y', 0.5, 'nm')
        self._check_axis('energy', 0.005, 'keV')

    def test_convert_to_navigation_units(self):
        self.am.convert_units(axes='navigation', units='µm')
        self._check_axis('x', 1.0, 'µm')
        self._check_axis('y', 0.5E-3, 'µm')
        self._check_unchanged('energy', -1)

    def test_convert_to_navigation_units_list(self):
        self.am.convert_units(axes='navigation', units=['µm', 'nm'])
        self._check_axis('x', 1000.0, 'nm')
        self._check_axis('y', 0.5E-3, 'µm')
        self._check_unchanged('energy', -1)

    def test_convert_to_signal_units(self):
        self.am.convert_units(axes='signal', units='keV')
        self._check_unchanged('x', 0)
        self._check_unchanged('y', 1)
        self._check_axis('energy', 0.005, 'keV')

    def test_convert_to_units_list(self):
        self.am.convert_units(units=['µm', 'nm', 'meV'])
        self._check_axis('x', 1000.0, 'nm')
        self._check_axis('y', 0.5E-3, 'µm')
        self._check_axis('energy', 5E3, 'meV')
Exemple #30
0
class Signal(t.HasTraits, MVA):
    # The data array; load_dictionary fills it from file_data_dict['data'].
    data = t.Any()
    # AxesManager built by load_dictionary from the 'axes' entry.
    axes_manager = t.Instance(AxesManager)
    # Parameters loaded from the 'original_parameters' entry.
    original_parameters = t.Instance(Parameters)
    # Parameters loaded from the 'mapped_parameters' entry.
    mapped_parameters = t.Instance(Parameters)
    # NOTE(review): not read or written anywhere in the visible methods;
    # only listed in traits_view.
    physical_property = t.Str()

    def __init__(self, file_data_dict=None, *args, **kw):
        """All data interaction is made through this class or its subclasses

        Parameters
        ----------
        file_data_dict : dictionary or None
            See load_dictionary for the format. Any value that is not a
            dictionary is ignored and the signal is left without data.
        *args, **kw
            Accepted but not used by this constructor.
        """
        super(Signal, self).__init__()
        self.mapped_parameters = Parameters()
        self.original_parameters = Parameters()
        # isinstance (rather than comparing the type name) also accepts
        # dictionary subclasses such as OrderedDict.
        if isinstance(file_data_dict, dict):
            self.load_dictionary(file_data_dict)
        self._plot = None
        self.mva_results = MVA_Results()
        # State saved by _unfold so that fold can restore the original shape.
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None
        self._axes_manager_before_unfolding = None

    def load_dictionary(self, file_data_dict):
        """Populate the signal from a dictionary.

        Parameters
        ----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: the axes description handed to the AxesManager
                    constructor (see the AxesManager class)
                attributes: a dictionary which keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

        Notes
        -----
        Missing 'axes', 'mapped_parameters' and 'original_parameters' entries
        are added to `file_data_dict` in place.
        """
        self.data = file_data_dict['data']
        if 'axes' not in file_data_dict:
            # The default axes are derived from the shape of the data that was
            # just stored, so this must come after the assignment above.
            file_data_dict['axes'] = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(file_data_dict['axes'])
        file_data_dict.setdefault('mapped_parameters', {})
        file_data_dict.setdefault('original_parameters', {})
        # ``items`` is available on both Python 2 and Python 3, whereas
        # ``iteritems`` only exists on Python 2.
        for key, value in file_data_dict.get('attributes', {}).items():
            setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict['original_parameters'])
        self.mapped_parameters.load_dictionary(
            file_data_dict['mapped_parameters'])

    def _get_signal_dict(self):
        dic = {}
        dic['data'] = self.data.copy()
        dic['axes'] = self.axes_manager._get_axes_dicts()
        dic['mapped_parameters'] = \
        self.mapped_parameters._get_parameters_dictionary()
        dic['original_parameters'] = \
        self.original_parameters._get_parameters_dictionary()
        return dic

    def _get_undefined_axes_list(self):
        axes = []
        for i in xrange(len(self.data.shape)):
            axes.append({
                        'name': 'undefined',
                        'scale': 1.,
                        'offset': 0.,
                        'size': int(self.data.shape[i]),
                        'units': 'undefined',
                        'index_in_array': i, })
        return axes

    def __call__(self, axes_manager=None):
        if axes_manager is None:
            axes_manager = self.axes_manager
        return self.data.__getitem__(axes_manager._getitem_tuple)

    def _get_hse_1D_explorer(self, *args, **kwargs):
        islice = self.axes_manager._slicing_axes[0].index_in_array
        inslice = self.axes_manager._non_slicing_axes[0].index_in_array
        if islice > inslice:
            return self.data.squeeze()
        else:
            return self.data.squeeze().T

    def _get_hse_2D_explorer(self, *args, **kwargs):
        islice = self.axes_manager._slicing_axes[0].index_in_array
        data = self.data.sum(islice)
        return data

    def _get_hie_explorer(self, *args, **kwargs):
        isslice = [self.axes_manager._slicing_axes[0].index_in_array,
                   self.axes_manager._slicing_axes[1].index_in_array]
        isslice.sort()
        data = self.data.sum(isslice[1]).sum(isslice[0])
        return data

    def _get_explorer(self, *args, **kwargs):
        nav_dim = self.axes_manager.navigation_dimension
        if self.axes_manager.signal_dimension == 1:
            if nav_dim == 1:
                return self._get_hse_1D_explorer(*args, **kwargs)
            elif nav_dim == 2:
                return self._get_hse_2D_explorer(*args, **kwargs)
            else:
                return None
        if self.axes_manager.signal_dimension == 2:
            if nav_dim == 1 or nav_dim == 2:
                return self._get_hie_explorer(*args, **kwargs)
            else:
                return None
        else:
            return None

    def plot(self, axes_manager=None):
        """Plot the signal with a matplotlib explorer.

        Any explorer opened by a previous call is closed first. A signal
        dimension of 1 opens an MPL_HyperSpectrum_Explorer and a signal
        dimension of 2 an MPL_HyperImage_Explorer; any other value is
        reported through messages.warning_exit.

        Parameters
        ----------
        axes_manager : AxesManager or None
            The axes manager handed to the explorer. If None the signal's own
            axes manager is used.
        """
        if self._plot is not None:
            try:
                self._plot.close()
            except Exception:
                # If it was already closed it will raise an exception,
                # but we want to carry on... Only Exception is caught so that
                # KeyboardInterrupt and SystemExit still propagate.
                pass

        if axes_manager is None:
            axes_manager = self.axes_manager

        if axes_manager.signal_dimension == 1:
            # Hyperspectrum

            self._plot = mpl_hse.MPL_HyperSpectrum_Explorer()
            self._plot.spectrum_data_function = self.__call__
            self._plot.spectrum_title = self.mapped_parameters.name
            # NOTE(review): the labels, axis and pixel properties below are
            # read from self.axes_manager even when another axes_manager was
            # passed in -- confirm that this is intended.
            self._plot.xlabel = '%s (%s)' % (
                self.axes_manager._slicing_axes[0].name,
                self.axes_manager._slicing_axes[0].units)
            self._plot.ylabel = 'Intensity'
            self._plot.axes_manager = axes_manager
            self._plot.axis = self.axes_manager._slicing_axes[0].axis

            # Image properties
            if self.axes_manager._non_slicing_axes:
                self._plot.image_data_function = self._get_explorer
                self._plot.image_title = ''
                self._plot.pixel_size = \
                    self.axes_manager._non_slicing_axes[0].scale
                self._plot.pixel_units = \
                    self.axes_manager._non_slicing_axes[0].units
            self._plot.plot()

        elif axes_manager.signal_dimension == 2:

            # Mike's playground with new plotting toolkits - needs to be a
            # branch.
            #
            # if len(self.data.shape)==2:
            #     from drawing.guiqwt_hie import image_plot_2D
            #     image_plot_2D(self)
            #
            # import drawing.chaco_hie
            # self._plot = drawing.chaco_hie.Chaco_HyperImage_Explorer(self)
            # self._plot.configure_traits()
            self._plot = mpl_hie.MPL_HyperImage_Explorer()
            self._plot.image_data_function = self.__call__
            self._plot.navigator_data_function = self._get_explorer
            self._plot.axes_manager = axes_manager
            self._plot.plot()

        else:
            messages.warning_exit('Plotting is not supported for this view')

    # TraitsUI view definition for the class.
    # NOTE(review): 'name', 'units', 'offset' and 'scale' are not declared as
    # traits in this class body (only 'physical_property' is) -- they look
    # like axis attributes; confirm that this view belongs here.
    traits_view = tui.View(
        tui.Item('name'),
        tui.Item('physical_property'),
        tui.Item('units'),
        tui.Item('offset'),
        tui.Item('scale'),)

    def plot_residual(self, axes_manager=None):
        """Plot the residual between original data and reconstructed data

        Requires you to have already run PCA or ICA, and to reconstruct data
        using either the pca_build_SI or ica_build_SI methods.

        Parameters
        ----------
        axes_manager : AxesManager or None
            Passed on to the `plot` method of the residual.
        """
        if hasattr(self, 'residual'):
            self.residual.plot(axes_manager)
        else:
            # Single-argument call form: prints the same text on Python 2
            # and Python 3.
            print("Object does not have any residual information.  Is it a "
                  "reconstruction created using either pca_build_SI or "
                  "ica_build_SI methods?")

    def save(self, filename, only_view = False, **kwds):
        """Saves the signal in the specified format.

        The function gets the format from the extension. You can use:
            - hdf5 for HDF5
            - nc for NetCDF
            - msa for EMSA/MSA single spectrum saving.
            - bin to produce a raw binary file
            - Many image formats such as png, tiff, jpeg...

        Please note that not all the formats support saving datasets of
        arbitrary dimensions, e.g. msa only supports 1D data.

        Parameters
        ----------
        filename : str
        msa_format : {'Y', 'XY'}
            Passed through **kwds. 'Y' will produce a file without the energy
            axis. 'XY' will also save another column with the energy axis.
            For compatibility with Gatan Digital Micrograph 'Y' is the
            default.
        only_view : bool
            If True, only the current view will be saved. Otherwise the full
            dataset is saved. Please note that not all the formats support this
            option at the moment.
        **kwds
            Forwarded unchanged to io.save.
        """
        # NOTE(review): only_view is accepted but not forwarded to io.save,
        # so it has no effect here -- confirm whether it should be passed on.
        io.save(filename, self, **kwds)

    def _replot(self):
        if self._plot is not None:
            if self._plot.is_active() is True:
                self.plot()

    def get_dimensions_from_data(self):
        """Set the size of every axis from the current shape of the data,
        print the new sizes and refresh the plot. Useful when the data was
        externally modified, or when the signal was not loaded from a file.
        """
        shape = self.data.shape
        for axis in self.axes_manager.axes:
            length = shape[axis.index_in_array]
            axis.size = int(length)
            print("%s size: %i" % (axis.name, length))
        self._replot()

    def crop_in_pixels(self, axis, i1 = None, i2 = None):
        """Crop the data along one axis, with the range given in pixels.

        Parameters
        ----------
        axis : int
        i1 : int
            Start index
        i2 : int
            End index

        See also:
        ---------
        crop_in_units
        """
        axis = self._get_positive_axis_index_index(axis)
        shifts_offset = i1 is not None
        if shifts_offset:
            # Read the new offset before anything is modified.
            new_offset = self.axes_manager.axes[axis].axis[i1]
        selection = (slice(None),) * axis + (slice(i1, i2), Ellipsis)
        # We take a copy to guarantee the continuity of the data
        self.data = self.data[selection].copy()

        if shifts_offset:
            self.axes_manager.axes[axis].offset = new_offset
        self.get_dimensions_from_data()

    def crop_in_units(self, axis, x1 = None, x2 = None):
        """Crop the data along one axis, with the range given in the units
        of the axis.

        Both values are converted with the `value2index` method of the axis
        and the result is handed to crop_in_pixels.

        Parameters
        ----------
        axis : int
        x1 : value in the units of the axis
            Start of the range
        x2 : value in the units of the axis
            End of the range

        See also:
        ---------
        crop_in_pixels

        """
        target = self.axes_manager.axes[axis]
        start = target.value2index(x1)
        stop = target.value2index(x2)
        self.crop_in_pixels(axis, start, stop)

    def roll_xy(self, n_x, n_y = 1):
        """Roll the data n_x positions along the first array axis and then
        roll the first n_x entries n_y positions along the second axis.

        This method has the purpose of "fixing" a bug in the acquisition of the
        Orsay's microscopes and probably it does not have general interest

        Parameters
        ----------
        n_x : int
        n_y : int

        Note: Useful to correct the SI column storing bug in Marcel's
        acquisition routines.
        """
        self.data = np.roll(self.data, n_x, 0)
        # Basic slicing gives a view, so writing into it updates self.data.
        wrapped = self.data[:n_x, ...]
        wrapped[...] = np.roll(wrapped, n_y, 1)
        self._replot()

    # TODO: After using this function the plotting does not work
    def swap_axis(self, axis1, axis2):
        """Swap two axes of the data and the matching entries of the axes
        manager.

        Parameters
        ----------
        axis1 : positive int
        axis2 : positive int
        """
        self.data = self.data.swapaxes(axis1, axis2)
        axes = self.axes_manager.axes
        first = axes[axis1]
        second = axes[axis2]
        first.index_in_array, second.index_in_array = \
            second.index_in_array, first.index_in_array
        axes[axis1] = second
        axes[axis2] = first
        self.axes_manager.set_signal_dimension()
        self._replot()

    def rebin(self, new_shape):
        """
        Rebin the data to the new shape and multiply the scale of every axis
        by the ratio between the old and the new size.

        Parameters
        ----------
        new_shape: tuple of ints
            The new shape must be a divisor of the original shape
        """
        # The ratios must be taken before the data is replaced.
        old_shape = np.array(self.data.shape)
        factors = old_shape / np.array(new_shape)
        self.data = utils.rebin(self.data, new_shape)
        for axis in self.axes_manager.axes:
            axis.scale *= factors[axis.index_in_array]
        self.get_dimensions_from_data()

    def split_in(self, axis, number_of_parts = None, steps = None):
        """Splits the data

        The split can be defined either by the `number_of_parts` or by the
        `steps` size.

        Parameters
        ----------
        number_of_parts : int or None
            Number of parts in which the SI will be splitted
        steps : int or None
            Size of the splitted parts
        axis : int
            The splitting axis

        Return
        ------
        tuple with the splitted signals
        """
        axis = self._get_positive_axis_index_index(axis)
        if number_of_parts is None and steps is None:
            if not self._splitting_steps:
                messages.warning_exit(
                "Please provide either number_of_parts or a steps list")
            else:
                steps = self._splitting_steps
                print "Splitting in ", steps
        elif number_of_parts is not None and steps is not None:
            print "Using the given steps list. number_of_parts dimissed"
        splitted = []
        shape = self.data.shape

        if steps is None:
            rounded = (shape[axis] - (shape[axis] % number_of_parts))
            step = rounded / number_of_parts
            cut_node = range(0, rounded+step, step)
        else:
            cut_node = np.array([0] + steps).cumsum()
        for i in xrange(len(cut_node)-1):
            data = self.data[
            (slice(None), ) * axis + (slice(cut_node[i], cut_node[i + 1]),
            Ellipsis)]
            s = Signal({'data': data})
            # TODO: When copying plotting does not work
#            s.axes = copy.deepcopy(self.axes_manager)
            s.get_dimensions_from_data()
            splitted.append(s)
        return splitted

    def unfold_if_multidim(self):
        """Unfold the datacube if it has more than two axes

        Returns
        -------

        Boolean. True if the data was unfolded by the function.
        """
        if len(self.axes_manager.axes) > 2:
            # Single-argument call form: prints the same text on Python 2
            # and Python 3.
            print("Automatically unfolding the data")
            self.unfold()
            return True
        return False

    def _unfold(self, steady_axes, unfolded_axis):
        """Modify the shape of the data by specifying the axes the axes which
        dimension do not change and the axis over which the remaining axes will
        be unfolded

        Parameters
        ----------
        steady_axes : list
            The indexes of the axes which dimensions do not change
        unfolded_axis : int
            The index of the axis over which all the rest of the axes (except
            the steady axes) will be unfolded

        See also
        --------
        fold
        """

        # It doesn't make sense unfolding when dim < 3
        if len(self.data.squeeze().shape) < 3:
            return False

        # We need to store the original shape and coordinates to be used by
        # the fold function only if it has not been already stored by a
        # previous unfold
        if self._shape_before_unfolding is None:
            self._shape_before_unfolding = self.data.shape
            self._axes_manager_before_unfolding = self.axes_manager

        new_shape = [1] * len(self.data.shape)
        for index in steady_axes:
            new_shape[index] = self.data.shape[index]
        new_shape[unfolded_axis] = -1
        self.data = self.data.reshape(new_shape)
        self.axes_manager = self.axes_manager.deepcopy()
        i = 0
        uname = ''
        uunits = ''
        to_remove = []
        for axis, dim in zip(self.axes_manager.axes, new_shape):
            if dim == 1:
                uname += ',' + axis.name
                uunits = ',' + axis.units
                to_remove.append(axis)
            else:
                axis.index_in_array = i
                i += 1
        self.axes_manager.axes[unfolded_axis].name += uname
        self.axes_manager.axes[unfolded_axis].units += uunits
        self.axes_manager.axes[unfolded_axis].size = \
                                                self.data.shape[unfolded_axis]
        for axis in to_remove:
            self.axes_manager.axes.remove(axis)

        self.data = self.data.squeeze()
        self._replot()

    def unfold(self):
        """Modifies the shape of the data by unfolding the signal and
        navigation dimensions separately

        The navigation space is unfolded first, then the signal space.
        """
        self.unfold_navigation_space()
        self.unfold_signal_space()

    def unfold_navigation_space(self):
        """Unfold the non-slicing axes over the last of them so that the
        navigation space has dimension 1.

        Returns False, after an information message, when the navigation
        dimension is already smaller than 2.
        """
        manager = self.axes_manager
        if manager.navigation_dimension < 2:
            messages.information('Nothing done, the navigation dimension was '
                                'already 1')
            return False
        # The slicing axes keep their size while the rest is unfolded.
        kept = [ax.index_in_array for ax in manager._slicing_axes]
        target = manager._non_slicing_axes[-1].index_in_array
        self._unfold(kept, target)

    def unfold_signal_space(self):
        """Unfold the slicing axes over the last of them so that the signal
        space has dimension 1.

        Returns False, after an information message, when the signal
        dimension is already smaller than 2.
        """
        manager = self.axes_manager
        if manager.signal_dimension < 2:
            messages.information('Nothing done, the signal dimension was '
                                'already 1')
            return False
        # The non-slicing axes keep their size while the rest is unfolded.
        kept = [ax.index_in_array for ax in manager._non_slicing_axes]
        target = manager._slicing_axes[-1].index_in_array
        self._unfold(kept, target)

    def fold(self):
        """Restore the shape and the axes manager stored by a previous
        unfold; do nothing if the signal is not unfolded."""
        stored_shape = self._shape_before_unfolding
        if stored_shape is None:
            return
        self.data = self.data.reshape(stored_shape)
        self.axes_manager = self._axes_manager_before_unfolding
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None
        self._replot()

    def _get_positive_axis_index_index(self, axis):
        if axis < 0:
            axis = len(self.data.shape) + axis
        return axis

    def iterate_axis(self, axis = -1):
        """Yield, one by one, the 1D arrays found along `axis`.

        All the other dimensions are collapsed into the dimension that
        precedes `axis` and iterated over. The data is replaced by a copy of
        itself when the iteration starts.

        Parameters
        ----------
        axis : int
            The axis along which the yielded arrays run.
        """
        # We make a copy to guarantee that the data in contiguous, otherwise
        # it will not return a view of the data
        self.data = self.data.copy()
        axis = self._get_positive_axis_index_index(axis)
        unfolded_axis = axis - 1
        new_shape = [1] * len(self.data.shape)
        new_shape[axis] = self.data.shape[axis]
        new_shape[unfolded_axis] = -1
        # Warning! if the data is not contigous it will make a copy!!
        data = self.data.reshape(new_shape)
        for i in range(data.shape[unfolded_axis]):
            getitem = [0] * len(data.shape)
            getitem[axis] = slice(None)
            getitem[unfolded_axis] = i
            # NumPy requires a tuple for multidimensional indexing; a list
            # that contains a slice is no longer accepted.
            yield data[tuple(getitem)]

    def sum(self, axis, return_signal = False):
        """Sum the data over the specify axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
             will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum_in_mask, mean

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64,64,1024)
        >>> s.sum(-1)
        >>> s.data.shape
        (64,64)
        # If we just want to plot the result of the operation
        s.sum(-1, True).plot()
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # A negative axis must be converted first: compared as is, every
        # index_in_array is greater than it and all of them would be shifted.
        axis = s._get_positive_axis_index_index(axis)
        s.data = s.data.sum(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        # The axes stored after the removed one move one position down.
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def mean(self, axis, return_signal = False):
        """Average the data over the specify axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum_in_mask, sum

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64,64,1024)
        >>> s.mean(-1)
        >>> s.data.shape
        (64,64)
        # If we just want to plot the result of the operation
        s.mean(-1, True).plot()
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # A negative axis must be converted first: compared as is, every
        # index_in_array is greater than it and all of them would be shifted.
        axis = s._get_positive_axis_index_index(axis)
        s.data = s.data.mean(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        # The axes stored after the removed one move one position down.
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def copy(self):
        return(copy.copy(self))

    def deepcopy(self):
        return(copy.deepcopy(self))

#    def sum_in_mask(self, mask):
#        """Returns the result of summing all the spectra in the mask.
#
#        Parameters
#        ----------
#        mask : boolean numpy array
#
#        Returns
#        -------
#        Spectrum
#        """
#        dc = self.data_cube.copy()
#        mask3D = mask.reshape([1,] + list(mask.shape)) * np.ones(dc.shape)
#        dc = (mask3D*dc).sum(1).sum(1) / mask.sum()
#        s = Spectrum()
#        s.data_cube = dc.reshape((-1,1,1))
#        s.get_dimensions_from_cube()
#        utils.copy_energy_calibration(self,s)
#        return s
#
#    def mean(self, axis):
#        """Average the SI over the given axis
#
#        Parameters
#        ----------
#        axis : int
#        """
#        dc = self.data_cube
#        dc = dc.mean(axis)
#        dc = dc.reshape(list(dc.shape) + [1,])
#        self.data_cube = dc
#        self.get_dimensions_from_cube()
#
#    def roll(self, axis = 2, shift = 1):
#        """Roll the SI. see numpy.roll
#
#        Parameters
#        ----------
#        axis : int
#        shift : int
#        """
#        self.data_cube = np.roll(self.data_cube, shift, axis)
#        self._replot()
#

#
#    def get_calibration_from(self, s):
#        """Copy the calibration from another Spectrum instance
#        Parameters
#        ----------
#        s : spectrum instance
#        """
#        utils.copy_energy_calibration(s, self)
#
#    def estimate_variance(self, dc = None, gaussian_noise_var = None):
#        """Variance estimation supposing Poissonian noise
#
#        Parameters
#        ----------
#        dc : None or numpy array
#            If None the SI is used to estimate its variance. Otherwise, the
#            provided array will be used.
#        Note
#        ----
#        The gain_factor and gain_offset from the aquisition parameters are used
#        """
#        print "Variace estimation using the following values:"
#        print "Gain factor = ", self.acquisition_parameters.gain_factor
#        print "Gain offset = ", self.acquisition_parameters.gain_offset
#        if dc is None:
#            dc = self.data_cube
#        gain_factor = self.acquisition_parameters.gain_factor
#        gain_offset = self.acquisition_parameters.gain_offset
#        self.variance = dc*gain_factor + gain_offset
#        if self.variance.min() < 0:
#            if gain_offset == 0 and gaussian_noise_var is None:
#                print "The variance estimation results in negative values"
#                print "Maybe the gain_offset is wrong?"
#                self.variance = None
#                return
#            elif gaussian_noise_var is None:
#                print "Clipping the variance to the gain_offset value"
#                self.variance = np.clip(self.variance, np.abs(gain_offset),
#                np.Inf)
#            else:
#                print "Clipping the variance to the gaussian_noise_var"
#                self.variance = np.clip(self.variance, gaussian_noise_var,
#                np.Inf)
#
#    def calibrate(self, lcE = 642.6, rcE = 849.7, lc = 161.9, rc = 1137.6,
#    modify_calibration = True):
#        dispersion = (rcE - lcE) / (rc - lc)
#        origin = lcE - dispersion * lc
#        print "Energy step = ", dispersion
#        print "Energy origin = ", origin
#        if modify_calibration is True:
#            self.set_new_calibration(origin, dispersion)
#        return origin, dispersion
#
    def _correct_navigation_mask_when_unfolded(self, navigation_mask = None,):
        #if 'unfolded' in self.history:
        if navigation_mask is not None:
            navigation_mask = navigation_mask.reshape((-1,))
        return navigation_mask
class TestAxesManager:
    """Tests of ``AxesManager.convert_units`` on small in-memory axes lists."""

    @staticmethod
    def _axis(name, navigate, scale, size, units):
        """Build one axis dictionary with a zero offset."""
        return {'name': name,
                'navigate': navigate,
                'offset': 0.0,
                'scale': scale,
                'size': size,
                'units': units}

    @staticmethod
    def _check(axis, units, scale):
        """Assert that ``axis`` has ``units`` and (almost) ``scale``."""
        assert axis.units == units
        nt.assert_almost_equal(axis.scale, scale)

    def setup_method(self, method):
        """Create two axes managers: 2 nav + 1 signal, and 1 nav + 2 signal."""
        make = self._axis

        self.axes_list = [
            make('x', True, 1.5E-9, 1024, 'm'),
            make('y', True, 0.5E-9, 1024, 'm'),
            make('energy', False, 5.0, 4096, 'eV'),
        ]
        self.am = AxesManager(self.axes_list)

        self.axes_list2 = [
            make('x', True, 1.5E-9, 1024, 'm'),
            make('energy', False, 2.5, 4096, 'eV'),
            make('energy2', False, 5.0, 4096, 'eV'),
        ]
        self.am2 = AxesManager(self.axes_list2)

    def test_compact_unit(self):
        """Without arguments every axis is converted to a compact unit."""
        self.am.convert_units()
        self._check(self.am['x'], 'nm', 1.5)
        self._check(self.am['y'], 'nm', 0.5)
        self._check(self.am['energy'], 'keV', 0.005)

    def test_convert_to_navigation_units(self):
        """Only the navigation axes are converted to the requested unit."""
        self.am.convert_units(axes='navigation', units='mm')
        self._check(self.am['x'], 'mm', 1.5E-6)
        self._check(self.am['y'], 'mm', 0.5E-6)
        # the signal axis must keep its original scale
        nt.assert_almost_equal(self.am['energy'].scale,
                               self.axes_list[-1]['scale'])

    def test_convert_units_axes_integer(self):
        """An integer ``axes`` argument selects a single axis."""
        # convert only the first axis
        self.am.convert_units(axes=0, units='nm', same_units=False)
        self._check(self.am[0], 'nm', 0.5)
        self._check(self.am['x'], 'm', 1.5E-9)
        nt.assert_almost_equal(self.am['energy'].scale,
                               self.axes_list[-1]['scale'])

        # with same_units=True the other navigation axis follows
        self.am.convert_units(axes=0, units='nm', same_units=True)
        self._check(self.am[0], 'nm', 0.5)
        self._check(self.am['x'], 'nm', 1.5)

    def test_convert_to_navigation_units_list(self):
        """A list of units is applied axis by axis when same_units=False."""
        self.am.convert_units(axes='navigation', units=['mm', 'nm'],
                              same_units=False)
        self._check(self.am['x'], 'nm', 1.5)
        self._check(self.am['y'], 'mm', 0.5E-6)
        nt.assert_almost_equal(self.am['energy'].scale,
                               self.axes_list[-1]['scale'])

    def test_convert_to_navigation_units_list_same_units(self):
        """With same_units=True both navigation axes end up in 'mm'."""
        self.am.convert_units(axes='navigation', units=['mm', 'nm'],
                              same_units=True)
        self._check(self.am['x'], 'mm', 1.5e-6)
        self._check(self.am['y'], 'mm', 0.5e-6)
        self._check(self.am['energy'], 'eV', 5)

    def test_convert_to_navigation_units_different(self):
        """Navigation axes with different units are converted separately."""
        # Don't convert the units since the units of the navigation axes are
        # different
        self.axes_list.insert(0, self._axis('time', True, 1.5, 20, 's'))
        am = AxesManager(self.axes_list)
        am.convert_units(axes='navigation', same_units=True)
        self._check(am['time'], 's', 1.5)
        self._check(am['x'], 'nm', 1.5)
        self._check(am['y'], 'nm', 0.5)
        self._check(am['energy'], 'eV', 5)

    def test_convert_to_navigation_units_Undefined(self):
        """An undefined unit on one navigation axis blocks the conversion."""
        self.axes_list[0]['units'] = t.Undefined
        am = AxesManager(self.axes_list)
        am.convert_units(axes='navigation', same_units=True)
        self._check(am['x'], t.Undefined, 1.5E-9)
        self._check(am['y'], 'm', 0.5E-9)
        self._check(am['energy'], 'eV', 5)

    def test_convert_to_signal_units(self):
        """Only the signal axis is converted; navigation axes are untouched."""
        self.am.convert_units(axes='signal', units='keV')
        for index, name in enumerate(('x', 'y')):
            reference = self.axes_list[index]
            nt.assert_almost_equal(self.am[name].scale, reference['scale'])
            assert self.am[name].units == reference['units']
        self._check(self.am['energy'], 'keV', 0.005)

    def test_convert_to_units_list(self):
        """A list of units covering all axes, applied with same_units=False."""
        self.am.convert_units(units=['µm', 'nm', 'meV'], same_units=False)
        self._check(self.am['x'], 'nm', 1.5)
        self._check(self.am['y'], 'um', 0.5E-3)
        self._check(self.am['energy'], 'meV', 5E3)

    def test_convert_to_units_list_same_units(self):
        """With same_units=True the two signal axes keep their values."""
        self.am2.convert_units(units=['µm', 'eV', 'meV'], same_units=True)
        self._check(self.am2['x'], 'um', 0.0015)
        for index, name in ((1, 'energy'), (2, 'energy2')):
            reference = self.axes_list2[index]
            nt.assert_almost_equal(self.am2[name].scale, reference['scale'])
            assert self.am2[name].units == reference['units']

    def test_convert_to_units_list_signal2D(self):
        """With same_units=False each signal axis gets its own unit."""
        self.am2.convert_units(units=['µm', 'eV', 'meV'], same_units=False)
        self._check(self.am2['x'], 'um', 0.0015)
        self._check(self.am2['energy'], 'meV', 2500)
        self._check(self.am2['energy2'], 'eV', 5.0)

    @pytest.mark.parametrize("same_units", (True, False))
    def test_convert_to_units_unsupported_units(self, same_units):
        """An unsupported unit warns and leaves every axis unchanged."""
        expected_warning = assert_warns(
            message="not supported for conversion.",
            category=UserWarning)
        with expected_warning:
            self.am.convert_units('navigation', units='toto',
                                  same_units=same_units)
        assert_deep_almost_equal(self.am._get_axes_dicts(),
                                 self.axes_list)
Exemple #32
0
def hdfgroup2dict(group, dictionary=None, load_to_memory=True):
    """Recursively convert an HDF5 group into a (nested) dictionary.

    The attributes of ``group`` are copied first; then, unless ``group`` is
    itself a dataset, each of its members is converted. Key prefixes select
    how an entry is decoded (``_sig_``, ``_list_empty_``, ``_tuple_empty_``,
    ``_bs_``, ``_datetime_``, ``_list_``, ``_tuple_``,
    ``_hspy_AxesManager_``).

    Parameters
    ----------
    group : h5py group or dataset
    dictionary : dict or None
        Dictionary filled in place; a new one is created when None.
    load_to_memory : bool
        If False, datasets without a list/tuple prefix are stored as h5py
        datasets instead of numpy arrays.

    Returns
    -------
    dict
        The filled dictionary.
    """
    result = {} if dictionary is None else dictionary

    def _values_sorted_by_key(subgroup):
        # Convert the sub-group and return its values ordered by key.
        converted = hdfgroup2dict(subgroup, load_to_memory=load_to_memory)
        return [item for _, item in sorted(converted.iteritems())]

    for name, attr in group.attrs.iteritems():
        # Normalise the raw attribute value first.
        if isinstance(attr, (np.string_, str)):
            if attr == '_None_':
                attr = None
        elif isinstance(attr, np.bool_):
            attr = bool(attr)
        elif (isinstance(attr, np.ndarray) and
                attr.dtype == np.dtype('|S1')):
            attr = attr.tolist()

        if name.startswith('_sig_'):
            # skip signals - these are handled below.
            continue
        if name.startswith('_list_empty_'):
            result[name[len('_list_empty_'):]] = []
        elif name.startswith('_tuple_empty_'):
            result[name[len('_tuple_empty_'):]] = ()
        elif name.startswith('_bs_'):
            result[name[len('_bs_'):]] = attr.tostring()
        elif name.startswith('_datetime_'):
            # NOTE(security): eval() executes whatever expression is stored
            # in the file, so this is only safe for trusted files.
            result[name.replace("_datetime_", "")] = eval(attr)
        else:
            result[name] = attr

    if isinstance(group, h5py.Dataset):
        # A dataset has attributes only, no members to descend into.
        return result

    for name in group.keys():
        member = group[name]
        if name.startswith('_sig_'):
            from hyperspy.io import dict2signal
            signal_dict = hdfgroup2signaldict(
                member, load_to_memory=load_to_memory)
            result[name[len('_sig_'):]] = dict2signal(signal_dict)
        elif isinstance(member, h5py.Dataset):
            if name.startswith("_list_"):
                target = name[len("_list_"):]
                content = np.array(member).tolist()
            elif name.startswith("_tuple_"):
                target = name[len("_tuple_"):]
                content = tuple(np.array(member).tolist())
            elif load_to_memory:
                target = name
                content = np.array(member)
            else:
                # leave as h5py dataset
                target = name
                content = member
            result[target] = content
        elif name.startswith('_hspy_AxesManager_'):
            result[name[len('_hspy_AxesManager_'):]] = AxesManager(
                _values_sorted_by_key(member))
        elif name.startswith('_list_'):
            result[name[7 + name[6:].find('_'):]] = \
                _values_sorted_by_key(member)
        elif name.startswith('_tuple_'):
            result[name[8 + name[7:].find('_'):]] = tuple(
                _values_sorted_by_key(member))
        else:
            # plain sub-group: recurse into a fresh nested dictionary
            child = {}
            result[name] = child
            hdfgroup2dict(member, child, load_to_memory=load_to_memory)
    return result
Exemple #33
0
    def _group2dict(self, group, dictionary=None, lazy=False):
        """Recursively convert an HDF5 group into a (nested) dictionary.

        The attributes of ``group`` are copied first; then, unless ``group``
        is an instance of ``self.Dataset``, every member is converted. Key
        prefixes select the decoding:

        * ``_sig_``: member converted to a signal with ``dict2signal``;
        * ``_list_empty_`` / ``_tuple_empty_``: attribute -> ``[]`` / ``()``;
        * ``_bs_``: attribute array -> its raw bytes;
        * ``_datetime_date`` / ``_datetime_time``: attribute string -> ISO
          formatted date / time string;
        * ``_list_`` / ``_tuple_``: dataset or sub-group -> list / tuple;
        * ``_hspy_AxesManager_``: sub-group -> ``AxesManager``.

        Parameters
        ----------
        group : HDF5 group or dataset
        dictionary : dict or None
            Dictionary filled in place; a new one is created when None.
        lazy : bool
            If True, datasets that are neither list/tuple-prefixed nor byte
            strings are wrapped with ``da.from_array`` instead of being read
            into a numpy array.

        Returns
        -------
        dict
            The filled dictionary.
        """
        if dictionary is None:
            dictionary = {}
        for key, value in group.attrs.items():
            if isinstance(value, bytes):
                value = value.decode()
            # Byte strings (including NumPy byte scalars, which subclass
            # ``bytes``) were decoded just above, so testing for ``str`` is
            # enough. The former ``np.string_`` alias no longer exists in
            # NumPy >= 2.0 and would raise AttributeError here.
            if isinstance(value, str):
                if value == '_None_':
                    value = None
            elif isinstance(value, np.bool_):
                value = bool(value)
            elif isinstance(value, np.ndarray) and value.dtype.char == "S":
                # Convert strings to unicode
                value = value.astype("U")
                if value.dtype.str.endswith("U1"):
                    value = value.tolist()
            # skip signals - these are handled below.
            if key.startswith('_sig_'):
                pass
            elif key.startswith('_list_empty_'):
                dictionary[key[len('_list_empty_'):]] = []
            elif key.startswith('_tuple_empty_'):
                dictionary[key[len('_tuple_empty_'):]] = ()
            elif key.startswith('_bs_'):
                dictionary[key[len('_bs_'):]] = value.tobytes()
            # The following two elif statements enable reading date and time
            # from v < 2 of HyperSpy's metadata specifications. Only the
            # parenthesised argument tuple of the stored string is parsed,
            # with ast.literal_eval, so no arbitrary code is evaluated.
            elif key.startswith('_datetime_date'):
                date_iso = datetime.date(
                    *ast.literal_eval(value[value.index("("):])).isoformat()
                dictionary[key.replace("_datetime_", "")] = date_iso
            elif key.startswith('_datetime_time'):
                date_iso = datetime.time(
                    *ast.literal_eval(value[value.index("("):])).isoformat()
                dictionary[key.replace("_datetime_", "")] = date_iso
            else:
                dictionary[key] = value
        if not isinstance(group, self.Dataset):
            for key in group.keys():
                if key.startswith('_sig_'):
                    from hyperspy.io import dict2signal
                    dictionary[key[len('_sig_'):]] = (
                        dict2signal(self.group2signaldict(
                            group[key], lazy=lazy)))
                elif isinstance(group[key], self.Dataset):
                    dat = group[key]
                    kn = key
                    if key.startswith("_list_"):
                        if (h5py.check_string_dtype(dat.dtype) and
                                hasattr(dat, 'asstr')):
                            # h5py 3.0 and newer
                            # https://docs.h5py.org/en/3.0.0/strings.html
                            dat = dat.asstr()[:]
                        ans = np.array(dat)
                        ans = ans.tolist()
                        kn = key[6:]
                    elif key.startswith("_tuple_"):
                        ans = np.array(dat)
                        ans = tuple(ans.tolist())
                        kn = key[7:]
                    elif dat.dtype.char == "S":
                        ans = np.array(dat)
                        try:
                            ans = ans.astype("U")
                        except UnicodeDecodeError:
                            # There are some strings that must stay in binary,
                            # for example dill pickles. This will obviously also
                            # let "wrong" binary string fail somewhere else...
                            pass
                    elif lazy:
                        ans = da.from_array(dat, chunks=dat.chunks)
                    else:
                        ans = np.array(dat)
                    dictionary[kn] = ans
                elif key.startswith('_hspy_AxesManager_'):
                    # Sub-group members are ordered by key before being
                    # handed to AxesManager.
                    dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager(
                        [i for k, i in sorted(iter(
                            self._group2dict(
                                group[key], lazy=lazy).items()
                        ))])
                elif key.startswith('_list_'):
                    dictionary[key[7 + key[6:].find('_'):]] = \
                        [i for k, i in sorted(iter(
                            self._group2dict(
                                group[key], lazy=lazy).items()
                        ))]
                elif key.startswith('_tuple_'):
                    dictionary[key[8 + key[7:].find('_'):]] = tuple(
                        [i for k, i in sorted(iter(
                            self._group2dict(
                                group[key], lazy=lazy).items()
                        ))])
                else:
                    # plain sub-group: recurse into a nested dictionary
                    dictionary[key] = {}
                    self._group2dict(
                        group[key],
                        dictionary[key],
                        lazy=lazy)

        return dictionary
Exemple #34
0
def hdfgroup2dict(group, dictionary=None, load_to_memory=True):
    """Recursively convert an HDF5 group into a (nested) dictionary.

    The attributes of ``group`` are copied first; then, unless ``group`` is
    itself an ``h5py.Dataset``, every member is converted. Key prefixes
    select the decoding:

    * ``_sig_``: member converted to a signal with ``dict2signal``;
    * ``_list_empty_`` / ``_tuple_empty_``: attribute -> ``[]`` / ``()``;
    * ``_bs_``: attribute array -> its raw bytes;
    * ``_datetime_date`` / ``_datetime_time``: attribute string -> ISO
      formatted date / time string;
    * ``_list_`` / ``_tuple_``: dataset or sub-group -> list / tuple;
    * ``_hspy_AxesManager_``: sub-group -> ``AxesManager``.

    Parameters
    ----------
    group : h5py group or dataset
    dictionary : dict or None
        Dictionary filled in place; a new one is created when None.
    load_to_memory : bool
        If False, datasets without a list/tuple prefix are stored as h5py
        datasets and are not read from the file.

    Returns
    -------
    dict
        The filled dictionary.
    """
    if dictionary is None:
        dictionary = {}
    for key, value in group.attrs.items():
        if isinstance(value, bytes):
            value = value.decode()
        # Byte strings (including NumPy byte scalars, which subclass
        # ``bytes``) were decoded just above, so testing for ``str`` is
        # enough. The former ``np.string_`` alias no longer exists in
        # NumPy >= 2.0 and would raise AttributeError here.
        if isinstance(value, str):
            if value == '_None_':
                value = None
        elif isinstance(value, np.bool_):
            value = bool(value)
        elif isinstance(value, np.ndarray) and value.dtype.char == "S":
            # Convert strings to unicode
            value = value.astype("U")
            if value.dtype.str.endswith("U1"):
                value = value.tolist()
        # skip signals - these are handled below.
        if key.startswith('_sig_'):
            pass
        elif key.startswith('_list_empty_'):
            dictionary[key[len('_list_empty_'):]] = []
        elif key.startswith('_tuple_empty_'):
            dictionary[key[len('_tuple_empty_'):]] = ()
        elif key.startswith('_bs_'):
            # tobytes() returns the same bytes as the deprecated tostring()
            dictionary[key[len('_bs_'):]] = value.tobytes()
        # The following two elif statements enable reading date and time from
        # v < 2 of HyperSpy's metadata specifications. Only the parenthesised
        # argument tuple of the stored string is parsed, with
        # ast.literal_eval, so no arbitrary code is evaluated.
        elif key.startswith('_datetime_date'):
            date_iso = datetime.date(
                *ast.literal_eval(value[value.index("("):])).isoformat()
            dictionary[key.replace("_datetime_", "")] = date_iso
        elif key.startswith('_datetime_time'):
            date_iso = datetime.time(
                *ast.literal_eval(value[value.index("("):])).isoformat()
            dictionary[key.replace("_datetime_", "")] = date_iso
        else:
            dictionary[key] = value
    if not isinstance(group, h5py.Dataset):
        for key in group.keys():
            if key.startswith('_sig_'):
                from hyperspy.io import dict2signal
                dictionary[key[len('_sig_'):]] = (dict2signal(
                    hdfgroup2signaldict(group[key],
                                        load_to_memory=load_to_memory)))
            elif isinstance(group[key], h5py.Dataset):
                is_list = key.startswith("_list_")
                is_tuple = key.startswith("_tuple_")
                if not (is_list or is_tuple or load_to_memory):
                    # leave as h5py dataset; do not read the data at all
                    # (it used to be loaded and then thrown away)
                    dictionary[key] = group[key]
                    continue
                ans = np.array(group[key])
                if ans.dtype.char == "S":
                    try:
                        ans = ans.astype("U")
                    except UnicodeDecodeError:
                        # There are some strings that must stay in binary,
                        # for example dill pickles. This will obviously also
                        # let "wrong" binary string fail somewhere else...
                        pass
                kn = key
                if is_list:
                    ans = ans.tolist()
                    kn = key[6:]
                elif is_tuple:
                    ans = tuple(ans.tolist())
                    kn = key[7:]
                dictionary[kn] = ans
            elif key.startswith('_hspy_AxesManager_'):
                # Sub-group members are ordered by key before being handed
                # to AxesManager.
                dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager([
                    i for k, i in sorted(
                        iter(
                            hdfgroup2dict(
                                group[key],
                                load_to_memory=load_to_memory).items()))
                ])
            elif key.startswith('_list_'):
                dictionary[key[7 + key[6:].find('_'):]] = \
                    [i for k, i in sorted(iter(
                        hdfgroup2dict(
                            group[key], load_to_memory=load_to_memory).items()
                    ))]
            elif key.startswith('_tuple_'):
                dictionary[key[8 + key[7:].find('_'):]] = tuple([
                    i for k, i in sorted(
                        iter(
                            hdfgroup2dict(
                                group[key],
                                load_to_memory=load_to_memory).items()))
                ])
            else:
                # plain sub-group: recurse into a nested dictionary
                dictionary[key] = {}
                hdfgroup2dict(group[key],
                              dictionary[key],
                              load_to_memory=load_to_memory)
    return dictionary
Exemple #35
0
def hdfgroup2dict(group, dictionary=None, load_to_memory=True):
    """Recursively convert an HDF5 group into a (nested) dictionary.

    The attributes of ``group`` are copied first; then, unless ``group`` is
    itself an ``h5py.Dataset``, every member is converted. Key prefixes
    select the decoding:

    * ``_sig_``: member converted to a signal with ``dict2signal``;
    * ``_list_empty_`` / ``_tuple_empty_``: attribute -> ``[]`` / ``()``;
    * ``_bs_``: attribute array -> its raw bytes;
    * ``_list_`` / ``_tuple_``: dataset or sub-group -> list / tuple;
    * ``_hspy_AxesManager_``: sub-group -> ``AxesManager``.

    Parameters
    ----------
    group : h5py group or dataset
    dictionary : dict or None
        Dictionary filled in place; a new one is created when None.
    load_to_memory : bool
        If False, datasets without a list/tuple prefix are stored as h5py
        datasets and are not read from the file.

    Returns
    -------
    dict
        The filled dictionary.
    """
    if dictionary is None:
        dictionary = {}
    for key, value in group.attrs.items():
        if isinstance(value, bytes):
            value = value.decode()
        # Byte strings (including NumPy byte scalars, which subclass
        # ``bytes``) were decoded just above, so testing for ``str`` is
        # enough. The former ``np.string_`` alias no longer exists in
        # NumPy >= 2.0 and would raise AttributeError here.
        if isinstance(value, str):
            if value == '_None_':
                value = None
        elif isinstance(value, np.bool_):
            value = bool(value)
        elif isinstance(value, np.ndarray) and value.dtype.char == "S":
            # Convert strings to unicode
            value = value.astype("U")
            if value.dtype.str.endswith("U1"):
                value = value.tolist()
        # skip signals - these are handled below.
        if key.startswith('_sig_'):
            pass
        elif key.startswith('_list_empty_'):
            dictionary[key[len('_list_empty_'):]] = []
        elif key.startswith('_tuple_empty_'):
            dictionary[key[len('_tuple_empty_'):]] = ()
        elif key.startswith('_bs_'):
            # tobytes() returns the same bytes as the deprecated tostring()
            dictionary[key[len('_bs_'):]] = value.tobytes()
        # '_datetime_' attributes are deliberately NOT evaluated: doing so
        # with eval() could run arbitrary code stored in the file, i.e. it
        # was a security flaw. They fall through to the branch below and are
        # stored unchanged under their full key. A standard string should be
        # used for date and time instead.
        else:
            dictionary[key] = value
    if not isinstance(group, h5py.Dataset):
        for key in group.keys():
            if key.startswith('_sig_'):
                from hyperspy.io import dict2signal
                dictionary[key[len('_sig_'):]] = (
                    dict2signal(hdfgroup2signaldict(
                        group[key], load_to_memory=load_to_memory)))
            elif isinstance(group[key], h5py.Dataset):
                is_list = key.startswith("_list_")
                is_tuple = key.startswith("_tuple_")
                if not (is_list or is_tuple or load_to_memory):
                    # leave as h5py dataset; do not read the data at all
                    # (it used to be loaded and then thrown away)
                    dictionary[key] = group[key]
                    continue
                ans = np.array(group[key])
                if ans.dtype.char == "S":
                    try:
                        ans = ans.astype("U")
                    except UnicodeDecodeError:
                        # There are some strings that must stay in binary,
                        # for example dill pickles. This will obviously also
                        # let "wrong" binary string fail somewhere else...
                        pass
                kn = key
                if is_list:
                    ans = ans.tolist()
                    kn = key[6:]
                elif is_tuple:
                    ans = tuple(ans.tolist())
                    kn = key[7:]
                dictionary[kn] = ans
            elif key.startswith('_hspy_AxesManager_'):
                # Sub-group members are ordered by key before being handed
                # to AxesManager.
                dictionary[key[len('_hspy_AxesManager_'):]] = AxesManager(
                    [i for k, i in sorted(iter(
                        hdfgroup2dict(
                            group[key], load_to_memory=load_to_memory).items()
                    ))])
            elif key.startswith('_list_'):
                dictionary[key[7 + key[6:].find('_'):]] = \
                    [i for k, i in sorted(iter(
                        hdfgroup2dict(
                            group[key], load_to_memory=load_to_memory).items()
                    ))]
            elif key.startswith('_tuple_'):
                dictionary[key[8 + key[7:].find('_'):]] = tuple(
                    [i for k, i in sorted(iter(
                        hdfgroup2dict(
                            group[key], load_to_memory=load_to_memory).items()
                    ))])
            else:
                # plain sub-group: recurse into a nested dictionary
                dictionary[key] = {}
                hdfgroup2dict(
                    group[key],
                    dictionary[key],
                    load_to_memory=load_to_memory)
    return dictionary
Exemple #36
0
class TestAxesManager:
    """Tests of ``AxesManager.convert_units`` on small in-memory axes lists."""

    def setup_method(self, method):
        """Create two axes managers: 2 nav + 1 signal, and 1 nav + 2 signal."""
        self.axes_list = [
            dict(name='x', navigate=True, offset=0.0,
                 scale=1.5E-9, size=1024, units='m'),
            dict(name='y', navigate=True, offset=0.0,
                 scale=0.5E-9, size=1024, units='m'),
            dict(name='energy', navigate=False, offset=0.0,
                 scale=5.0, size=4096, units='eV'),
        ]
        self.am = AxesManager(self.axes_list)

        self.axes_list2 = [
            dict(name='x', navigate=True, offset=0.0,
                 scale=1.5E-9, size=1024, units='m'),
            dict(name='energy', navigate=False, offset=0.0,
                 scale=2.5, size=4096, units='eV'),
            dict(name='energy2', navigate=False, offset=0.0,
                 scale=5.0, size=4096, units='eV'),
        ]
        self.am2 = AxesManager(self.axes_list2)

    def test_compact_unit(self):
        """Without arguments every axis is converted to a compact unit."""
        am = self.am
        am.convert_units()
        expected = (('x', 'nm', 1.5), ('y', 'nm', 0.5),
                    ('energy', 'keV', 0.005))
        for key, units, scale in expected:
            axis = am[key]
            assert axis.units == units
            nt.assert_almost_equal(axis.scale, scale)

    def test_convert_to_navigation_units(self):
        """Only the navigation axes are converted to the requested unit."""
        am = self.am
        am.convert_units(axes='navigation', units='mm')
        for key, scale in (('x', 1.5E-6), ('y', 0.5E-6)):
            axis = am[key]
            nt.assert_almost_equal(axis.scale, scale)
            assert axis.units == 'mm'
        # the signal axis must keep its original scale
        nt.assert_almost_equal(am['energy'].scale,
                               self.axes_list[-1]['scale'])

    def test_convert_units_axes_integer(self):
        """An integer ``axes`` argument selects a single axis."""
        am = self.am

        # convert only the first axis
        am.convert_units(axes=0, units='nm', same_units=False)
        for key, units, scale in ((0, 'nm', 0.5), ('x', 'm', 1.5E-9)):
            axis = am[key]
            nt.assert_almost_equal(axis.scale, scale)
            assert axis.units == units
        nt.assert_almost_equal(am['energy'].scale,
                               self.axes_list[-1]['scale'])

        # with same_units=True the other navigation axis follows
        am.convert_units(axes=0, units='nm', same_units=True)
        for key, units, scale in ((0, 'nm', 0.5), ('x', 'nm', 1.5)):
            axis = am[key]
            nt.assert_almost_equal(axis.scale, scale)
            assert axis.units == units

    def test_convert_to_navigation_units_list(self):
        """A list of units is applied axis by axis when same_units=False."""
        am = self.am
        am.convert_units(axes='navigation',
                         units=['mm', 'nm'],
                         same_units=False)
        for key, units, scale in (('x', 'nm', 1.5), ('y', 'mm', 0.5E-6)):
            axis = am[key]
            nt.assert_almost_equal(axis.scale, scale)
            assert axis.units == units
        nt.assert_almost_equal(am['energy'].scale,
                               self.axes_list[-1]['scale'])

    def test_convert_to_navigation_units_list_same_units(self):
        """With same_units=True both navigation axes end up in 'mm'."""
        am = self.am
        am.convert_units(axes='navigation',
                         units=['mm', 'nm'],
                         same_units=True)
        expected = (('x', 'mm', 1.5e-6), ('y', 'mm', 0.5e-6),
                    ('energy', 'eV', 5))
        for key, units, scale in expected:
            axis = am[key]
            assert axis.units == units
            nt.assert_almost_equal(axis.scale, scale)

    def test_convert_to_navigation_units_different(self):
        """Navigation axes with different units are converted separately."""
        # Don't convert the units since the units of the navigation axes are
        # different
        time_axis = dict(name='time', navigate=True, offset=0.0,
                         scale=1.5, size=20, units='s')
        self.axes_list.insert(0, time_axis)
        am = AxesManager(self.axes_list)
        am.convert_units(axes='navigation', same_units=True)
        expected = (('time', 's', 1.5), ('x', 'nm', 1.5), ('y', 'nm', 0.5),
                    ('energy', 'eV', 5))
        for key, units, scale in expected:
            axis = am[key]
            assert axis.units == units
            nt.assert_almost_equal(axis.scale, scale)

    def test_convert_to_navigation_units_Undefined(self):
        """An undefined unit on one navigation axis blocks the conversion."""
        self.axes_list[0]['units'] = t.Undefined
        am = AxesManager(self.axes_list)
        am.convert_units(axes='navigation', same_units=True)
        expected = (('x', t.Undefined, 1.5E-9), ('y', 'm', 0.5E-9),
                    ('energy', 'eV', 5))
        for key, units, scale in expected:
            axis = am[key]
            assert axis.units == units
            nt.assert_almost_equal(axis.scale, scale)

    def test_convert_to_signal_units(self):
        """Only the signal axis is converted; navigation axes are untouched."""
        am = self.am
        am.convert_units(axes='signal', units='keV')
        for index, key in enumerate(('x', 'y')):
            axis, reference = am[key], self.axes_list[index]
            nt.assert_almost_equal(axis.scale, reference['scale'])
            assert axis.units == reference['units']
        energy = am['energy']
        nt.assert_almost_equal(energy.scale, 0.005)
        assert energy.units == 'keV'

    def test_convert_to_units_list(self):
        """A list of units covering all axes, applied with same_units=False."""
        am = self.am
        am.convert_units(units=['µm', 'nm', 'meV'], same_units=False)
        expected = (('x', 'nm', 1.5), ('y', 'um', 0.5E-3),
                    ('energy', 'meV', 5E3))
        for key, units, scale in expected:
            axis = am[key]
            nt.assert_almost_equal(axis.scale, scale)
            assert axis.units == units

    def test_convert_to_units_list_same_units(self):
        """With same_units=True the two signal axes keep their values."""
        am2 = self.am2
        am2.convert_units(units=['µm', 'eV', 'meV'], same_units=True)
        x_axis = am2['x']
        nt.assert_almost_equal(x_axis.scale, 0.0015)
        assert x_axis.units == 'um'
        for index, key in ((1, 'energy'), (2, 'energy2')):
            axis, reference = am2[key], self.axes_list2[index]
            nt.assert_almost_equal(axis.scale, reference['scale'])
            assert axis.units == reference['units']

    def test_convert_to_units_list_signal2D(self):
        """With same_units=False each signal axis gets its own unit."""
        am2 = self.am2
        am2.convert_units(units=['µm', 'eV', 'meV'], same_units=False)
        expected = (('x', 'um', 0.0015), ('energy', 'meV', 2500),
                    ('energy2', 'eV', 5.0))
        for key, units, scale in expected:
            axis = am2[key]
            nt.assert_almost_equal(axis.scale, scale)
            assert axis.units == units

    @pytest.mark.parametrize("same_units", (True, False))
    def test_convert_to_units_unsupported_units(self, same_units):
        """An unsupported unit warns and leaves every axis unchanged."""
        expected_warning = assert_warns(
            message="not supported for conversion.",
            category=UserWarning)
        with expected_warning:
            self.am.convert_units('navigation',
                                  units='toto',
                                  same_units=same_units)
        assert_deep_almost_equal(self.am._get_axes_dicts(), self.axes_list)
Exemple #37
0
class Signal(t.HasTraits, MVA):
    data = t.Any()
    axes_manager = t.Instance(AxesManager)
    original_parameters = t.Instance(Parameters)
    mapped_parameters = t.Instance(Parameters)
    physical_property = t.Str()

    def __init__(self, file_data_dict=None, *args, **kw):
        """All data interaction is made through this class or its subclasses

        Parameters
        ----------
        file_data_dict : dictionary or None
            If a dictionary (including instances of dict subclasses) is
            given, it is loaded with `load_dictionary`; see that method for
            the format. Any other value is ignored.
        *args, **kw :
            Accepted but not used by this constructor.
        """
        super(Signal, self).__init__()
        self.mapped_parameters = Parameters()
        self.original_parameters = Parameters()
        # isinstance, unlike comparing the type name with "dict", also
        # accepts dict subclasses such as collections.OrderedDict.
        if isinstance(file_data_dict, dict):
            self.load_dictionary(file_data_dict)
        self._plot = None
        self.mva_results = MVA_Results()
        # State saved by `_unfold` so that `fold` can restore the signal.
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None

    def load_dictionary(self, file_data_dict):
        """Load the data, axes, attributes and parameters from a dictionary.

        The dictionary is only read; missing optional keys are replaced by
        defaults locally instead of being inserted into the caller's object.

        Parameters
        ----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a dictionary that defines the axes (see the
                    AxesManager class). When missing, one undefined axis is
                    created for every dimension of the data.
                attributes: a dictionary which keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

        Raises
        ------
        KeyError
            If `file_data_dict` has no 'data' key.
        """
        self.data = file_data_dict['data']
        if 'axes' in file_data_dict:
            axes = file_data_dict['axes']
        else:
            # Needs self.data, which has been set just above.
            axes = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(axes)
        # items() exists on both Python 2 and 3 dictionaries.
        for key, value in file_data_dict.get('attributes', {}).items():
            setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict.get('original_parameters', {}))
        self.mapped_parameters.load_dictionary(
            file_data_dict.get('mapped_parameters', {}))

    def _get_signal_dict(self):
        """Return a dictionary describing the signal.

        The 'data' entry holds a copy of the data. 'axes',
        'mapped_parameters' and 'original_parameters' hold what the axes
        manager and the two parameters objects export as dictionaries.
        """
        return {
            'data': self.data.copy(),
            'axes': self.axes_manager._get_axes_dicts(),
            'mapped_parameters':
                self.mapped_parameters._get_parameters_dictionary(),
            'original_parameters':
                self.original_parameters._get_parameters_dictionary(),
        }

    def _get_undefined_axes_list(self):
        """Return one placeholder axis dictionary per dimension of the data.

        Every axis gets the name and units 'undefined', unit scale, zero
        offset, the size of the matching data dimension and its position in
        the array.
        """
        return [
            {
                'name': 'undefined',
                'scale': 1.,
                'offset': 0.,
                'size': int(length),
                'units': 'undefined',
                'index_in_array': position,
            }
            for position, length in enumerate(self.data.shape)
        ]

    def __call__(self, axes_manager=None):
        """Return the data indexed with the axes manager's `_getitem_tuple`.

        Parameters
        ----------
        axes_manager : AxesManager instance or None
            If None, the signal's own axes manager is used.
        """
        manager = self.axes_manager if axes_manager is None else axes_manager
        return self.data[manager._getitem_tuple]

    def _get_hse_1D_explorer(self, *args, **kwargs):
        """Return the squeezed data, transposed when needed.

        The data is returned as is when the first slicing axis comes after
        the first non-slicing axis in the array, and transposed otherwise.
        The arguments are ignored.
        """
        slicing_pos = self.axes_manager._slicing_axes[0].index_in_array
        non_slicing_pos = \
            self.axes_manager._non_slicing_axes[0].index_in_array
        squeezed = self.data.squeeze()
        if slicing_pos > non_slicing_pos:
            return squeezed
        return squeezed.T

    def _get_hse_2D_explorer(self, *args, **kwargs):
        """Return the data summed over the first slicing axis.

        The arguments are ignored.
        """
        return self.data.sum(
            self.axes_manager._slicing_axes[0].index_in_array)

    def _get_hie_explorer(self, *args, **kwargs):
        """Return the data summed over the first two slicing axes.

        The arguments are ignored.
        """
        slicing = self.axes_manager._slicing_axes
        low, high = sorted(
            (slicing[0].index_in_array, slicing[1].index_in_array))
        # Sum over the higher index first so that the lower one still
        # refers to the same dimension afterwards.
        return self.data.sum(high).sum(low)

    def _get_explorer(self, *args, **kwargs):
        """Return the explorer data for the current dimensions, or None.

        The helper is chosen as follows:

        * signal dimension 1, navigation dimension 1: `_get_hse_1D_explorer`
        * signal dimension 1, navigation dimension 2: `_get_hse_2D_explorer`
        * signal dimension 2, navigation dimension 1 or 2:
          `_get_hie_explorer`

        Any other combination returns None. The arguments are passed on to
        the helper.
        """
        nav_dim = self.axes_manager.navigation_dimension
        signal_dim = self.axes_manager.signal_dimension
        if signal_dim == 1 and nav_dim == 1:
            return self._get_hse_1D_explorer(*args, **kwargs)
        if signal_dim == 1 and nav_dim == 2:
            return self._get_hse_2D_explorer(*args, **kwargs)
        if signal_dim == 2 and nav_dim in (1, 2):
            return self._get_hie_explorer(*args, **kwargs)
        return None

    def plot(self, axes_manager=None):
        """Plot the signal, closing first any plot it had already opened.

        A spectrum explorer is used when the signal dimension is 1 and an
        image explorer when it is 2. For any other signal dimension
        `messages.warning_exit` is called.

        Parameters
        ----------
        axes_manager : AxesManager instance or None
            Axes manager handed to the plot and used to pick the explorer.
            If None, the signal's own axes manager is used. Note that the
            labels, the axis and the pixel size and units of the spectrum
            explorer are always read from the signal's own axes manager.
        """
        if self._plot is not None:
            try:
                self._plot.close()
            except Exception:
                # If it was already closed it will raise an exception,
                # but we want to carry on. Exception, unlike a bare except,
                # lets KeyboardInterrupt and SystemExit propagate.
                pass

        if axes_manager is None:
            axes_manager = self.axes_manager

        if axes_manager.signal_dimension == 1:
            # Hyperspectrum

            self._plot = mpl_hse.MPL_HyperSpectrum_Explorer()
            self._plot.spectrum_data_function = self.__call__
            self._plot.spectrum_title = self.mapped_parameters.name
            self._plot.xlabel = '%s (%s)' % (
                self.axes_manager._slicing_axes[0].name,
                self.axes_manager._slicing_axes[0].units)
            self._plot.ylabel = 'Intensity'
            self._plot.axes_manager = axes_manager
            self._plot.axis = self.axes_manager._slicing_axes[0].axis

            # Image properties
            if self.axes_manager._non_slicing_axes:
                self._plot.image_data_function = self._get_explorer
                self._plot.image_title = ''
                self._plot.pixel_size = \
                    self.axes_manager._non_slicing_axes[0].scale
                self._plot.pixel_units = \
                    self.axes_manager._non_slicing_axes[0].units
            self._plot.plot()

        elif axes_manager.signal_dimension == 2:

            # Mike's playground with new plotting toolkits - needs to be a
            # branch. Disabled experiments kept for reference:
            #
            # if len(self.data.shape)==2:
            #     from drawing.guiqwt_hie import image_plot_2D
            #     image_plot_2D(self)
            #
            # import drawing.chaco_hie
            # self._plot = drawing.chaco_hie.Chaco_HyperImage_Explorer(self)
            # self._plot.configure_traits()
            self._plot = mpl_hie.MPL_HyperImage_Explorer()
            self._plot.image_data_function = self.__call__
            self._plot.navigator_data_function = self._get_explorer
            self._plot.axes_manager = axes_manager
            self._plot.plot()

        else:
            messages.warning_exit('Plotting is not supported for this view')

    # View definition built from tui.View / tui.Item entries.
    # NOTE(review): of the listed items only 'physical_property' is declared
    # as a trait in the visible part of this class; 'name', 'units',
    # 'offset' and 'scale' are not - confirm they exist elsewhere.
    traits_view = tui.View(
        tui.Item('name'),
        tui.Item('physical_property'),
        tui.Item('units'),
        tui.Item('offset'),
        tui.Item('scale'),
    )

    def plot_residual(self, axes_manager=None):
        """Plot the signal stored in the `residual` attribute, if any.

        The docstring this replaces describes the residual as the
        difference between the original and the reconstructed data,
        available after running PCA or ICA and reconstructing with
        pca_build_SI or ica_build_SI. When the object has no `residual`
        attribute a message is printed instead.

        Parameters
        ----------
        axes_manager : AxesManager instance or None
            Passed on to the `plot` method of the residual.
        """
        if not hasattr(self, 'residual'):
            print("Object does not have any residual information.  Is it a "
                  "reconstruction created using either pca_build_SI or "
                  "ica_build_SI methods?")
            return
        self.residual.plot(axes_manager)

    def save(self, filename, only_view=False, **kwds):
        """Saves the signal in the specified format.

        The function gets the format from the extension. You can use:
            - hdf5 for HDF5
            - nc for NetCDF
            - msa for EMSA/MSA single spectrum saving.
            - bin to produce a raw binary file
            - Many image formats such as png, tiff, jpeg...

        Please note that not all the formats support saving datasets of
        arbitrary dimensions, e.g. msa only supports 1D data.

        The actual writing is delegated to `io.save`, which receives the
        file name, the signal and the extra keyword arguments.

        Parameters
        ----------
        filename : str
        msa_format : {'Y', 'XY'}
            'Y' will produce a file without the energy axis. 'XY' will also
            save another column with the energy axis. For compatibility with
            Gatan Digital Micrograph 'Y' is the default.
        only_view : bool
            If True, only the current view will be saved. Otherwise the full
            dataset is saved. Please note that not all the formats support this
            option at the moment.
            NOTE(review): this method does not pass `only_view` on to
            `io.save`, so the option currently has no effect here - confirm
            whether it should be forwarded.
        """
        io.save(filename, self, **kwds)

    def _replot(self):
        """Plot the signal again if it has a plot that reports being active."""
        current_plot = self._plot
        if current_plot is None:
            return
        if current_plot.is_active() is True:
            self.plot()

    def get_dimensions_from_data(self):
        """Update the size of every axis from the shape of the data.

        Useful when the data was modified externally or when the signal was
        not loaded from a file. The new size of each axis is printed and the
        plot, if active, is redrawn.
        """
        shape = self.data.shape
        for axis in self.axes_manager.axes:
            length = shape[axis.index_in_array]
            axis.size = int(length)
            print("%s size: %i" % (axis.name, length))
        self._replot()

    def crop_in_pixels(self, axis, i1=None, i2=None):
        """Crop the data along one axis, with the range given in pixels.

        The data is replaced by a copy of the slice ``i1:i2`` along `axis`.
        When `i1` is given, the offset of the axis is set to the axis value
        found at index `i1`. The axis sizes are then refreshed from the data.

        Parameters
        ----------
        axis : int
            Index of the axis to crop; negative values count from the end.
        i1 : int or None
            Start index
        i2 : int or None
            End index

        See also:
        ---------
        crop_in_units
        """
        axis = self._get_positive_axis_index_index(axis)
        has_start = i1 is not None
        if has_start:
            # Read the new offset before the data is replaced.
            new_offset = self.axes_manager.axes[axis].axis[i1]
        selection = (slice(None), ) * axis + (slice(i1, i2), Ellipsis)
        # We take a copy to guarantee the continuity of the data
        self.data = self.data[selection].copy()
        if has_start:
            self.axes_manager.axes[axis].offset = new_offset
        self.get_dimensions_from_data()

    def crop_in_units(self, axis, x1=None, x2=None):
        """Crops the data in a given axis. The range is given in the units of
        the axis

        The values are converted to indices with the `value2index` method of
        the axis and the cropping itself is done by `crop_in_pixels`.

        Parameters
        ----------
        axis : int
            Index of the axis to crop.
        x1 : number or None
            Start of the range in axis units. None leaves the start open.
        x2 : number or None
            End of the range in axis units. None leaves the end open.

        See also:
        ---------
        crop_in_pixels

        """
        axis_obj = self.axes_manager.axes[axis]
        # None means "no limit" for crop_in_pixels, so it is passed through
        # instead of being handed to value2index.
        i1 = None if x1 is None else axis_obj.value2index(x1)
        i2 = None if x2 is None else axis_obj.value2index(x2)
        self.crop_in_pixels(axis, i1, i2)

    def roll_xy(self, n_x, n_y=1):
        """Roll the data n_x positions along axis 0, then roll the first
        n_x entries of axis 0 by n_y positions along axis 1.

        According to the original author this "fixes" a column storing bug
        in the acquisition routines of the Orsay microscopes (Marcel's
        acquisition routines) and probably has no general interest.

        Parameters
        ----------
        n_x : int
        n_y : int
        """
        rolled = np.roll(self.data, n_x, 0)
        self.data = rolled
        wrapped = rolled[:n_x, ...]
        rolled[:n_x, ...] = np.roll(wrapped, n_y, 1)
        self._replot()

    # TODO: After using this function the plotting does not work
    def swap_axis(self, axis1, axis2):
        """Swap two axes of the data and of the axes manager.

        The two axis objects exchange their `index_in_array` and their
        places in the list of axes; the signal dimension is then set again
        and the plot refreshed.

        Parameters
        ----------
        axis1 : positive int
        axis2 : positive int
        """
        self.data = self.data.swapaxes(axis1, axis2)
        axes = self.axes_manager.axes
        first, second = axes[axis1], axes[axis2]
        first.index_in_array, second.index_in_array = \
            second.index_in_array, first.index_in_array
        axes[axis1] = second
        axes[axis2] = first
        self.axes_manager.set_signal_dimension()
        self._replot()

    def rebin(self, new_shape):
        """Rebin the data to a new shape and rescale the axes accordingly.

        The scale of every axis is multiplied by the ratio between the old
        and the new length of its dimension; the axis sizes are then
        refreshed from the data.

        Parameters
        ----------
        new_shape: tuple of ints
            The new shape must be a divisor of the original shape
        """
        old_shape = np.array(self.data.shape)
        factors = old_shape / np.array(new_shape)
        self.data = utils.rebin(self.data, new_shape)
        for axis in self.axes_manager.axes:
            axis.scale *= factors[axis.index_in_array]
        self.get_dimensions_from_data()

    def split_in(self, axis, number_of_parts=None, steps=None):
        """Splits the data

        The split can be defined either by the `number_of_parts` or by the
        `steps` size. When both are given, `steps` is used. When neither is
        given, the `_splitting_steps` attribute is used if it exists and is
        not empty.

        Parameters
        ----------
        axis : int
            The splitting axis; negative values count from the end.
        number_of_parts : int or None
            Number of parts in which the SI will be split. All the parts
            have the same size; if the length of the axis is not a multiple
            of `number_of_parts` the remaining entries at the end are left
            out.
        steps : sequence of ints or None
            Sizes of the consecutive parts (list, tuple or array).

        Return
        ------
        list with the split signals. Each one is a new Signal built from a
        slice of the data only; the axes calibration is not copied.

        Raises
        ------
        ValueError
            If `number_of_parts` is larger than the length of the axis.
        """
        axis = self._get_positive_axis_index_index(axis)
        if number_of_parts is None and steps is None:
            # `_splitting_steps` is not created by the __init__ visible in
            # this class, so do not assume that it exists.
            stored_steps = getattr(self, '_splitting_steps', None)
            if not stored_steps:
                messages.warning_exit(
                    "Please provide either number_of_parts or a steps list")
            else:
                steps = stored_steps
                print("%s %s" % ("Splitting in ", steps))
        elif number_of_parts is not None and steps is not None:
            print("Using the given steps list. number_of_parts dimissed")
        splitted = []
        length = self.data.shape[axis]

        if steps is None:
            rounded = length - (length % number_of_parts)
            # Floor division keeps the step an integer under true division.
            step = rounded // number_of_parts
            if step == 0:
                raise ValueError(
                    "number_of_parts (%s) is larger than the length of the "
                    "axis (%s)" % (number_of_parts, length))
            cut_node = range(0, rounded + step, step)
        else:
            # list(steps) makes tuples and arrays behave like lists;
            # `[0] + array` would add element-wise instead of prepending.
            cut_node = np.cumsum([0] + list(steps))
        for start, stop in zip(cut_node[:-1], cut_node[1:]):
            data = self.data[(slice(None), ) * axis +
                             (slice(start, stop), Ellipsis)]
            s = Signal({'data': data})
            # TODO: When copying plotting does not work
            #            s.axes = copy.deepcopy(self.axes_manager)
            s.get_dimensions_from_data()
            splitted.append(s)
        return splitted

    def unfold_if_multidim(self):
        """Unfold the signal when it has more than two axes.

        Returns
        -------

        Boolean. True if `unfold` was called by the function, False if the
        signal has two axes or fewer and nothing was done.
        """
        if len(self.axes_manager.axes) <= 2:
            return False
        print("Automatically unfolding the data")
        self.unfold()
        return True

    def _unfold(self, steady_axes, unfolded_axis):
        """Modify the shape of the data by specifying the axes which
        dimension do not change and the axis over which the remaining axes
        will be unfolded

        The shape and the axes manager found before the first unfolding are
        stored so that `fold` can restore them. The names and units of the
        removed axes are appended, comma separated, to those of the unfolded
        axis.

        Parameters
        ----------
        steady_axes : list
            The indexes of the axes which dimensions do not change
        unfolded_axis : int
            The index of the axis over which all the rest of the axes (except
            the steady axes) will be unfolded

        Returns
        -------
        False if the squeezed data has fewer than three dimensions (nothing
        is done in that case), None otherwise.

        See also
        --------
        fold
        """

        # It doesn't make sense unfolding when dim < 3
        if len(self.data.squeeze().shape) < 3:
            return False

        # We need to store the original shape and coordinates to be used by
        # the fold function only if it has not been already stored by a
        # previous unfold
        if self._shape_before_unfolding is None:
            self._shape_before_unfolding = self.data.shape
            self._axes_manager_before_unfolding = self.axes_manager

        # Steady axes keep their length, the unfolded axis absorbs the rest
        # (-1) and every other axis is reduced to length 1.
        new_shape = [1] * len(self.data.shape)
        for index in steady_axes:
            new_shape[index] = self.data.shape[index]
        new_shape[unfolded_axis] = -1
        self.data = self.data.reshape(new_shape)
        # Work on a copy so that the stored axes manager stays untouched.
        self.axes_manager = self.axes_manager.deepcopy()
        i = 0
        uname = ''
        uunits = ''
        to_remove = []
        for axis, dim in zip(self.axes_manager.axes, new_shape):
            if dim == 1:
                uname += ',' + axis.name
                # Accumulate like the names: a plain assignment kept only
                # the units of the last removed axis.
                uunits += ',' + axis.units
                to_remove.append(axis)
            else:
                axis.index_in_array = i
                i += 1
        self.axes_manager.axes[unfolded_axis].name += uname
        self.axes_manager.axes[unfolded_axis].units += uunits
        self.axes_manager.axes[unfolded_axis].size = \
            self.data.shape[unfolded_axis]
        for axis in to_remove:
            self.axes_manager.axes.remove(axis)

        self.data = self.data.squeeze()
        self._replot()

    def unfold(self):
        """Modifies the shape of the data by unfolding the signal and
        navigation dimensions separately

        The navigation space is unfolded first, then the signal space.

        See also
        --------
        unfold_navigation_space, unfold_signal_space, fold
        """
        self.unfold_navigation_space()
        self.unfold_signal_space()

    def unfold_navigation_space(self):
        """Unfold the non-slicing axes over the last one of them.

        The slicing axes keep their dimensions. Nothing is done, and False
        is returned after an information message, when the navigation
        dimension is lower than 2; otherwise nothing is returned.
        """
        manager = self.axes_manager
        if manager.navigation_dimension < 2:
            messages.information('Nothing done, the navigation dimension was '
                                 'already 1')
            return False
        kept = [axis.index_in_array for axis in manager._slicing_axes]
        target = manager._non_slicing_axes[-1].index_in_array
        self._unfold(kept, target)

    def unfold_signal_space(self):
        """Unfold the slicing axes over the last one of them.

        The non-slicing axes keep their dimensions. Nothing is done, and
        False is returned after an information message, when the signal
        dimension is lower than 2; otherwise nothing is returned.
        """
        manager = self.axes_manager
        if manager.signal_dimension < 2:
            messages.information('Nothing done, the signal dimension was '
                                 'already 1')
            return False
        kept = [axis.index_in_array for axis in manager._non_slicing_axes]
        target = manager._slicing_axes[-1].index_in_array
        self._unfold(kept, target)

    def fold(self):
        """Restore the shape and axes manager stored by a previous unfold.

        Does nothing when no shape has been stored. After restoring, the
        stored state is cleared and the plot refreshed.
        """
        original_shape = self._shape_before_unfolding
        if original_shape is None:
            return
        self.data = self.data.reshape(original_shape)
        self.axes_manager = self._axes_manager_before_unfolding
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None
        self._replot()

    def _get_positive_axis_index_index(self, axis):
        """Return `axis` with negative values counted from the last
        dimension of the data; other values are returned unchanged.
        """
        return len(self.data.shape) + axis if axis < 0 else axis

    def iterate_axis(self, axis=-1):
        """Generator that yields the data one line along `axis` at a time.

        The data is first replaced by a copy of itself. It is then reshaped
        so that `axis` keeps its length, the dimension at position
        ``axis - 1`` takes all the remaining elements and every other
        dimension has length 1. One item is yielded for every index of that
        ``axis - 1`` dimension.

        Note that for ``axis == 0`` the position ``axis - 1`` is -1, i.e.
        the last dimension.

        Parameters
        ----------
        axis : int
            The axis along which each yielded item runs; negative values
            count from the end.
        """
        # We make a copy to guarantee that the data in contiguous, otherwise
        # it will not return a view of the data
        self.data = self.data.copy()
        axis = self._get_positive_axis_index_index(axis)
        unfolded_axis = axis - 1
        new_shape = [1] * len(self.data.shape)
        new_shape[axis] = self.data.shape[axis]
        new_shape[unfolded_axis] = -1
        # Warning! if the data is not contigous it will make a copy!!
        data = self.data.reshape(new_shape)
        for i in range(data.shape[unfolded_axis]):
            getitem = [0] * len(data.shape)
            getitem[axis] = slice(None)
            getitem[unfolded_axis] = i
            # A multidimensional index must be a tuple: current NumPy
            # rejects a list that contains a slice.
            yield data[tuple(getitem)]

    def sum(self, axis, return_signal=False):
        """Sum the data over the specify axis

        The summed axis is removed from the axes manager and the
        `index_in_array` of the axes that came after it is decreased by one.

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last dimension of the data.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum_in_mask, mean

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64,64,1024)
        >>> s.sum(-1)
        >>> s.data.shape
        (64,64)
        # If we just want to plot the result of the operation
        s.sum(-1, True).plot()
        """
        # Make the index positive first: with a negative value the
        # comparison below would be true for every remaining axis and all
        # their indices would be shifted.
        axis = self._get_positive_axis_index_index(axis)
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        s.data = s.data.sum(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def mean(self, axis, return_signal=False):
        """Average the data over the specify axis

        The averaged axis is removed from the axes manager and the
        `index_in_array` of the axes that came after it is decreased by one.

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last dimension of the data.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum_in_mask, sum

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64,64,1024)
        >>> s.mean(-1)
        >>> s.data.shape
        (64,64)
        # If we just want to plot the result of the operation
        s.mean(-1, True).plot()
        """
        # Make the index positive first: with a negative value the
        # comparison below would be true for every remaining axis and all
        # their indices would be shifted.
        axis = self._get_positive_axis_index_index(axis)
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        s.data = s.data.mean(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def copy(self):
        """Return a shallow copy of the signal made with `copy.copy`."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy of the signal made with `copy.deepcopy`."""
        return copy.deepcopy(self)

#    def sum_in_mask(self, mask):
#        """Returns the result of summing all the spectra in the mask.
#
#        Parameters
#        ----------
#        mask : boolean numpy array
#
#        Returns
#        -------
#        Spectrum
#        """
#        dc = self.data_cube.copy()
#        mask3D = mask.reshape([1,] + list(mask.shape)) * np.ones(dc.shape)
#        dc = (mask3D*dc).sum(1).sum(1) / mask.sum()
#        s = Spectrum()
#        s.data_cube = dc.reshape((-1,1,1))
#        s.get_dimensions_from_cube()
#        utils.copy_energy_calibration(self,s)
#        return s
#
#    def mean(self, axis):
#        """Average the SI over the given axis
#
#        Parameters
#        ----------
#        axis : int
#        """
#        dc = self.data_cube
#        dc = dc.mean(axis)
#        dc = dc.reshape(list(dc.shape) + [1,])
#        self.data_cube = dc
#        self.get_dimensions_from_cube()
#
#    def roll(self, axis = 2, shift = 1):
#        """Roll the SI. see numpy.roll
#
#        Parameters
#        ----------
#        axis : int
#        shift : int
#        """
#        self.data_cube = np.roll(self.data_cube, shift, axis)
#        self._replot()
#

#
#    def get_calibration_from(self, s):
#        """Copy the calibration from another Spectrum instance
#        Parameters
#        ----------
#        s : spectrum instance
#        """
#        utils.copy_energy_calibration(s, self)
#
#    def estimate_variance(self, dc = None, gaussian_noise_var = None):
#        """Variance estimation supposing Poissonian noise
#
#        Parameters
#        ----------
#        dc : None or numpy array
#            If None the SI is used to estimate its variance. Otherwise, the
#            provided array will be used.
#        Note
#        ----
#        The gain_factor and gain_offset from the aquisition parameters are used
#        """
#        print "Variace estimation using the following values:"
#        print "Gain factor = ", self.acquisition_parameters.gain_factor
#        print "Gain offset = ", self.acquisition_parameters.gain_offset
#        if dc is None:
#            dc = self.data_cube
#        gain_factor = self.acquisition_parameters.gain_factor
#        gain_offset = self.acquisition_parameters.gain_offset
#        self.variance = dc*gain_factor + gain_offset
#        if self.variance.min() < 0:
#            if gain_offset == 0 and gaussian_noise_var is None:
#                print "The variance estimation results in negative values"
#                print "Maybe the gain_offset is wrong?"
#                self.variance = None
#                return
#            elif gaussian_noise_var is None:
#                print "Clipping the variance to the gain_offset value"
#                self.variance = np.clip(self.variance, np.abs(gain_offset),
#                np.Inf)
#            else:
#                print "Clipping the variance to the gaussian_noise_var"
#                self.variance = np.clip(self.variance, gaussian_noise_var,
#                np.Inf)
#
#    def calibrate(self, lcE = 642.6, rcE = 849.7, lc = 161.9, rc = 1137.6,
#    modify_calibration = True):
#        dispersion = (rcE - lcE) / (rc - lc)
#        origin = lcE - dispersion * lc
#        print "Energy step = ", dispersion
#        print "Energy origin = ", origin
#        if modify_calibration is True:
#            self.set_new_calibration(origin, dispersion)
#        return origin, dispersion
#

    def _correct_navigation_mask_when_unfolded(self, navigation_mask=None):
        """Return the navigation mask reshaped to one dimension.

        None is returned unchanged. The reshaping is done whether or not the
        signal is unfolded: the check on the history that the name suggests
        is disabled (``if 'unfolded' in self.history``).
        """
        if navigation_mask is None:
            return None
        return navigation_mask.reshape((-1, ))