コード例 #1
0
class CameraUI(traits.HasTraits):
    """Traits UI exposing basic camera settings and actions.

    Lets the user search for cameras, switch the selected camera on and off,
    start/stop a live preview, edit gain, shutter, pixel format and region of
    interest, and capture or save single frames.  Status text is shown in the
    window's status bar through ``message``.
    """
    # Driver wrapper; created on first access by _camera_control_default.
    camera_control = traits.Instance(Camera, transient=True)

    # Choices for the ``camera`` selector (filled by search_cameras).
    cameras = traits.List([_NO_CAMERAS], transient=True)
    camera = traits.Any(value=_NO_CAMERAS, desc='camera serial number',
                        editor=ui.EnumEditor(name='cameras'))

    search = traits.Button(desc='camera search action')

    _is_initialized = traits.Bool(False, transient=True)
    # Read-only public mirror of _is_initialized.  The view's
    # ``enabled_when='is_initialized'`` conditions refer to this name; without
    # it those conditions reference an undefined trait.
    is_initialized = traits.Property(depends_on='_is_initialized')

    play = traits.Button(desc='display preview action')
    stop = traits.Button(desc='close preview action')
    on_off = traits.Button('On/Off', desc='initiate/Uninitiate camera action')

    gain = create_range_feature('gain', desc='camera gain', transient=True)
    shutter = create_range_feature('shutter', desc='camera exposure time',
                                   transient=True)
    format = create_mapped_feature('format', _FORMAT, desc='image format',
                                   transient=True)
    roi = traits.Instance(ROI, transient=True)

    im_shape = traits.Property(depends_on='format.value,roi.values')
    im_dtype = traits.Property(depends_on='format.value')

    capture = traits.Button()
    save_button = traits.Button('Save as...')

    # Text shown in the status bar.
    message = traits.Str(transient=True)

    view = ui.View(
        ui.Group(
            ui.HGroup(
                ui.Item('camera', springy=True),
                ui.Item('search', show_label=False, springy=True),
                ui.Item('on_off', show_label=False, springy=True),
                ui.Item('play', show_label=False,
                        enabled_when='is_initialized', springy=True),
                ui.Item('stop', show_label=False,
                        enabled_when='is_initialized', springy=True),
            ),
            ui.Group(
                ui.Item('gain', style='custom'),
                ui.Item('shutter', style='custom'),
                ui.Item('format', style='custom'),
                ui.Item('roi', style='custom'),
                ui.HGroup(
                    ui.Item('capture', show_label=False),
                    ui.Item('save_button', show_label=False),
                ),
                enabled_when='is_initialized',
            ),
        ),
        resizable=True,
        statusbar=[ui.StatusItem(name='message')],
        buttons=['OK'],
    )

    # default initialization
    def __init__(self, **kw):
        super(CameraUI, self).__init__(**kw)
        # Populate the camera selector immediately.
        self.search_cameras()

    def _camera_control_default(self):
        return Camera()

    def _roi_default(self):
        return ROI()

    def _get_is_initialized(self):
        return self._is_initialized

    #@display_cls_error
    def _get_im_shape(self):
        """Return the frame shape ``(height, width[, colors])``.

        Raises:
            NotImplementedError: if the current format has no _COLORS entry.
        """
        _, _, width, height = self.roi.values
        shape = (height, width)
        try:
            colors = _COLORS[self.format.value]
        except KeyError:
            raise NotImplementedError('Unsupported format')
        if colors > 1:
            shape += (colors,)
        return shape

    #@display_cls_error
    def _get_im_dtype(self):
        """Return the numpy dtype for the current format.

        Raises:
            NotImplementedError: if the current format has no _DTYPE entry.
        """
        try:
            return _DTYPE[self.format.value]
        except KeyError:
            raise NotImplementedError('Unsupported format')

    def _search_fired(self):
        self.search_cameras()

    #@display_cls_error
    def search_cameras(self):
        """Find cameras, if any, and select the first one from the list.

        When nothing is found, or the search itself raises, the selector is
        reset to ``[_NO_CAMERAS]``; an exception from the search is then
        re-raised unchanged.
        """
        cameras = []
        try:
            cameras = get_number_cameras()
        finally:
            # Keep the selector consistent even if the search failed.
            if len(cameras) == 0:
                cameras = [_NO_CAMERAS]
            self.cameras = cameras
            self.camera = cameras[0]

    def _close_camera(self):
        """Close the camera if it is initialized and report it."""
        if self._is_initialized:
            self._is_initialized = False
            self.camera_control.close()
            self.message = 'Camera uninitialized'

    #@display_cls_error
    def _camera_changed(self):
        # A different camera was selected: release the current one.
        self._close_camera()

    #@display_cls_error
    def init_camera(self):
        """Open the selected camera (if any) and load its feature values."""
        self._is_initialized = False
        if self.camera != _NO_CAMERAS:
            self.camera_control.init(self.camera)
            self.init_features()
            self._is_initialized = True
            self.message = 'Camera initialized'

    #@display_cls_error
    def _on_off_fired(self):
        # Toggle between initialized and closed.
        if self._is_initialized:
            self._close_camera()
        else:
            self.init_camera()

    #@display_cls_error
    def init_features(self):
        """
        Initializes all features to values given by the camera
        """
        features = self.camera_control.get_camera_features()
        self._init_single_valued_features(features)
        self._init_roi(features)

    #@display_cls_error
    def _init_single_valued_features(self, features):
        """
        Initializes all single valued features to camera values
        """
        for name, feature_id in _SINGLE_VALUED_FEATURES.items():
            feature = getattr(self, name)
            feature.low, feature.high = features[feature_id]['params'][0]
            feature.value = self.camera_control.get_feature(feature_id)[0]

    #@display_cls_error
    def _init_roi(self, features):
        """Initialize the four ROI sub-features from the camera."""
        for i, name in enumerate(('top', 'left', 'width', 'height')):
            feature = getattr(self.roi, name)
            low, high = features[FEATURE_ROI]['params'][i]
            value = self.camera_control.get_feature(FEATURE_ROI)[i]
            try:
                feature.value = value
            finally:
                # Bounds are applied even if assigning the value failed.
                feature.low, feature.high = low, high

    @traits.on_trait_change('format.value')
    def _on_format_change(self, object, name, value):
        # The stream is stopped before the pixel format is changed.
        if self._is_initialized:
            self.camera_control.set_preview_state(STOP_PREVIEW)
            self.camera_control.set_stream_state(STOP_STREAM)
            self.set_feature(FEATURE_PIXEL_FORMAT, [value])

    @traits.on_trait_change('gain.value,shutter.value')
    def _single_valued_feature_changed(self, object, name, value):
        if self._is_initialized:
            self.set_feature(object.id, [value])

    #@display_cls_error
    def set_feature(self, id, values, flags=2):
        """Forward a feature change to the camera driver."""
        self.camera_control.set_feature(id, values, flags=flags)

    @traits.on_trait_change('roi.values')
    def a_roi_feature_changed(self, object, name, value):
        """Send a new ROI to the camera, then re-read all features."""
        if self._is_initialized:
            self.set_feature(FEATURE_ROI, value)
            try:
                # Suppress the change handlers while features are reloaded.
                self._is_initialized = False
                self.init_features()
            finally:
                self._is_initialized = True

    #@display_cls_error
    def _play_fired(self):
        # Restart stream and preview from a stopped state.
        self.camera_control.set_preview_state(STOP_PREVIEW)
        self.camera_control.set_stream_state(STOP_STREAM)
        self.camera_control.set_stream_state(START_STREAM)
        self.camera_control.set_preview_state(START_PREVIEW)

    #@display_cls_error
    def _stop_fired(self):
        self.camera_control.set_preview_state(STOP_PREVIEW)
        self.camera_control.set_stream_state(STOP_STREAM)
        # NOTE(review): ``error`` is not a declared trait of this class; this
        # may have been meant to clear ``message`` -- confirm.
        self.error = ''

    #@display_cls_error
    def _format_changed(self, value):
        # NOTE(review): this fires when the ``format`` feature object itself is
        # replaced, so ``value`` is that object rather than its ``.value``, and
        # it runs even when the camera is not initialized -- confirm intent
        # (``_on_format_change`` handles ``format.value`` edits).
        self.camera_control.set_preview_state(STOP_PREVIEW)
        self.camera_control.set_stream_state(STOP_STREAM)
        self.camera_control.set_feature(FEATURE_PIXEL_FORMAT, [value], 2)

    #@display_cls_error
    def _capture_fired(self):
        # Restart the stream, grab one frame and show it.
        self.camera_control.set_stream_state(STOP_STREAM)
        self.camera_control.set_stream_state(START_STREAM)
        im = self.capture_image()
        plt.imshow(im)
        plt.show()

    def capture_image(self):
        """Grab the next frame into a new array and return it.

        The array has shape ``im_shape`` and the dtype ``im_dtype``
        reinterpreted as big-endian.
        """
        im = numpy.empty(shape=self.im_shape, dtype=self.im_dtype)
        self.camera_control.get_next_frame(im)
        # Reinterpret the bytes as big-endian (no copy, no swap).
        # ndarray.newbyteorder() was removed in NumPy 2.0; this is the
        # documented equivalent.
        return im.view(im.dtype.newbyteorder('>'))

    def save_image(self, fname):
        """Capture an image and save it in a format guessed from *fname*.

        ``.npy`` files are written with numpy.save.  Any other extension is
        passed to ``toimage(...).save``.
        """
        im = self.capture_image()
        ext = os.path.splitext(fname)[1]
        if ext.lower() == '.npy':
            numpy.save(fname, im)
        else:
            toimage(im).save(fname)

    def _save_button_fired(self):
        f = pyface.FileDialog(action='save as')
        if f.open() == pyface.OK:
            self.save_image(f.path)

    def capture_HDR(self):
        # Not implemented.
        pass

    def __del__(self):
        # Best-effort shutdown of preview and stream.  Errors, e.g. during
        # interpreter teardown, are deliberately ignored.
        try:
            self.camera_control.set_preview_state(STOP_PREVIEW)
            self.camera_control.set_stream_state(STOP_STREAM)
        except Exception:
            pass
コード例 #2
0
class DeviceModel(traits.HasTraits):
    """Represent the trigger device on the host computer and push state to it.

    We keep a local copy of the device's state in host memory.  All state
    changes to the device go through this class, which lets us keep that
    copy up to date.

    When the environment variable ``REQUIRE_TRIGGER`` is ``0``, no USB device
    is opened.  ``real_device`` then stays False and device commands become
    no-ops.
    """
    # Private runtime details
    _libusb_handle = traits.Any(None, transient=True)
    _lock = traits.Any(None, transient=True)  # lock access to the handle
    real_device = traits.Bool(False, transient=True)  # real USB device present
    # Oscillator frequency used in the frame-rate arithmetic; overwritten from
    # the product string in _open_device.
    FOSC = traits.Float(8000000.0, transient=True)

    # When True, an unrecognised product string only prints a warning.
    ignore_version_mismatch = traits.Bool(False, transient=True)

    # A couple properties
    frames_per_second = RemoteFpsFloat
    frames_per_second_actual = traits.Property(depends_on='_t3_state')
    timer3_top = traits.Property(depends_on='_t3_state')

    # Timer 3 state (assign a whole new object for atomic updates):
    _t3_state = traits.Instance(DeviceTimer3State)

    # LEDs state (bit mask combining LEDS_LED1..LEDS_LED4)
    _led_state = traits.Int

    led1 = traits.Property(depends_on='_led_state')
    led2 = traits.Property(depends_on='_led_state')
    led3 = traits.Property(depends_on='_led_state')
    led4 = traits.Property(depends_on='_led_state')

    # Event would be fine for these, but use Button to get nice editor
    reset_framecount_A = traits.Button
    reset_AIN_overflow = traits.Button
    do_single_frame_pulse = traits.Button

    ext_trig1 = traits.Button
    ext_trig2 = traits.Button
    ext_trig3 = traits.Button

    # Analog input state (assign a whole new object for atomic updates):
    _ain_state = traits.Instance(DeviceAnalogInState)
    Vcc = traits.Property(depends_on='_ain_state')

    AIN_running = traits.Property(depends_on='_ain_state')
    enabled_channels = traits.Property(depends_on='_ain_state')
    enabled_channel_names = traits.Property(depends_on='_ain_state')

    # The view:
    traits_view = View(
        Group(
            Group(
                Item('frames_per_second', label='frame rate'),
                Item('frames_per_second_actual', show_label=False,
                     style='readonly'),
                orientation='horizontal',
            ),
            Group(
                Item('ext_trig1', show_label=False),
                Item('ext_trig2', show_label=False),
                Item('ext_trig3', show_label=False),
                orientation='horizontal',
            ),
            Item('_ain_state', show_label=False, style='custom'),
            Item('reset_AIN_overflow', show_label=False),
        ))

    def __init__(self, *a, **k):
        super(DeviceModel, self).__init__(*a, **k)
        self._t3_state = DeviceTimer3State()
        self._ain_state = DeviceAnalogInState(trigger_device=self)

    def __new__(cls, *args, **kwargs):
        """Set the transient object state

        This must be done outside of __init__, because instances can
        get created without calling __init__. In particular, when
        being loaded from a pickle.
        """
        self = super(DeviceModel, cls).__new__(cls, *args, **kwargs)
        self._lock = threading.Lock()
        # Traits keyword arguments are only applied by __init__, which runs
        # after this.  _open_device() below is the only reader of
        # ignore_version_mismatch, so honour it now.
        if 'ignore_version_mismatch' in kwargs:
            self.ignore_version_mismatch = kwargs['ignore_version_mismatch']
        self._open_device()
        # force the USBKEY's state to our idea of its state
        self.__led_state_changed()
        self.__t3_state_changed()
        self.__ain_state_changed()
        self.reset_AIN_overflow = True  # reset ain overflow
        return self

    @staticmethod
    def _make_buf(values, size=None):
        """Return a ctypes char buffer holding the byte *values*.

        Args:
            values: sequence of ints in ``range(256)``.
            size: buffer length.  Defaults to ``len(values)``; extra bytes
                are zero.

        ``bytes(bytearray(...))`` produces the string type ctypes expects on
        both Python 2 (``str``) and Python 3 (``bytes``).  Assigning
        ``chr(...)`` into the buffer only works on Python 2.

        Raises:
            ValueError: if a value is outside ``range(256)``.
        """
        if size is None:
            size = len(values)
        return ctypes.create_string_buffer(bytes(bytearray(values)), size)

    def _set_led_mask(self, led_mask, value):
        """Set (*value* true) or clear the *led_mask* bits of ``_led_state``."""
        if value:
            self._led_state = self._led_state | led_mask
        else:
            self._led_state = self._led_state & ~led_mask

    def __led_state_changed(self):
        # Change handler for _led_state (also called from __new__): push the
        # LED mask to the device.
        self._send_buf(self._make_buf([CAMTRIG_SET_LED_STATE,
                                       self._led_state]))

    @traits.cached_property
    def _get_led1(self):
        return bool(self._led_state & LEDS_LED1)

    def _set_led1(self, value):
        self._set_led_mask(LEDS_LED1, value)

    @traits.cached_property
    def _get_led2(self):
        return bool(self._led_state & LEDS_LED2)

    def _set_led2(self, value):
        self._set_led_mask(LEDS_LED2, value)

    @traits.cached_property
    def _get_led3(self):
        return bool(self._led_state & LEDS_LED3)

    def _set_led3(self, value):
        self._set_led_mask(LEDS_LED3, value)

    @traits.cached_property
    def _get_led4(self):
        return bool(self._led_state & LEDS_LED4)

    def _set_led4(self, value):
        self._set_led_mask(LEDS_LED4, value)

    @traits.cached_property
    def _get_Vcc(self):
        return self._ain_state.Vcc

    @traits.cached_property
    def _get_AIN_running(self):
        return self._ain_state.AIN_running

    @traits.cached_property
    def _get_enabled_channels(self):
        # Indices 0-3 of the analog inputs whose AIN<n>_enabled flag is set.
        ain = self._ain_state
        return [i for i in range(4) if getattr(ain, 'AIN%d_enabled' % i)]

    @traits.cached_property
    def _get_enabled_channel_names(self):
        # AIN<n>_name of each enabled analog input, in channel order.
        ain = self._ain_state
        return [getattr(ain, 'AIN%d_name' % i) for i in range(4)
                if getattr(ain, 'AIN%d_enabled' % i)]

    @traits.cached_property
    def _get_timer3_top(self):
        return self._t3_state.timer3_top

    @traits.cached_property
    def _get_frames_per_second_actual(self):
        # Clock select 0 is reported as 0 fps.
        if self._t3_state.timer3_CS == 0:
            return 0
        return self.FOSC / self._t3_state.timer3_CS / self._t3_state.timer3_top

    def set_frames_per_second_approximate(self, value):
        """Set the framerate as close as possible to the desired value.

        Every timer-3 clock-select prescaler is tried.  The one whose rounded
        TOP value gives the rate closest to *value* is chosen, and OCR3A is
        set to about 2% of TOP.  ``value == 0`` sets clock select 0.

        Raises:
            ValueError: if no valid OCR3A exists for the chosen TOP.
        """
        new_t3_state = DeviceTimer3State()
        if value == 0:
            new_t3_state.timer3_CS = 0
        else:
            # For all possible clock select values...
            CSs = np.array([1.0, 8.0, 64.0, 256.0, 1024.0])
            # ...find the TOP that gives the desired framerate (clipped to
            # the 16-bit range)...  (np.int was removed in NumPy 1.24.)
            best_top = np.clip(np.round(self.FOSC / CSs / value),
                               0, 2**16 - 1).astype(int)
            # ...compute the framerate each TOP would actually produce...
            best_rate = self.FOSC / CSs / best_top
            # ...and choose the best one.
            idx = np.argmin(abs(best_rate - value))
            new_t3_state.timer3_CS = CSs[idx]
            new_t3_state.timer3_top = best_top[idx]

            ideal_ocr3a = 0.02 * new_t3_state.timer3_top  # 2% duty cycle
            ocr3a = int(np.round(ideal_ocr3a))
            if ocr3a == 0:
                ocr3a = 1
            if ocr3a >= new_t3_state.timer3_top:
                ocr3a -= 1
                if ocr3a <= 0:
                    raise ValueError('impossible combination for ocr3a')
            new_t3_state.ocr3a = ocr3a
        self._t3_state = new_t3_state  # atomic update

    def get_framestamp(self, full_output=False):
        """Get the framestamp and the pulse width reported by the device.

        The framestamp includes the fraction of the inter-frame interval
        elapsed until the next frame.  The inter-frame counter counts up from
        0 to self.timer3_top between frame ticks.

        Returns:
            With a real device: ``(framestamp, pulse_width)``, or
            ``(framestamp, pulse_width, framecount, tcnt3)`` when
            *full_output* is true.  Without a device: ``time.time()``, or
            ``(now, now // 1, now % 1.0)`` with *full_output*.
            NOTE(review): the no-device shapes lack pulse_width -- confirm
            callers cope with both.

        Raises:
            NoDataError: if the device returned no data.
        """
        if not self.real_device:
            now = time.time()
            if full_output:
                return now, now // 1, now % 1.0
            return now
        self._send_buf(self._make_buf([CAMTRIG_GET_FRAMESTAMP_NOW]))
        data = self._read_buf()
        if data is None:
            raise NoDataError('no data available from device')
        # Bytes 0-7: little-endian frame counter.
        framecount = 0
        for i in range(8):
            framecount += ord(data[i]) << (i * 8)
        # Bytes 8-9: little-endian timer-3 count within the current frame.
        tcnt3 = ord(data[8]) + (ord(data[9]) << 8)
        frac = tcnt3 / float(self._t3_state.timer3_top)
        if frac > 1:
            print('In ttriger.DeviceModel.get_framestamp(): '
                  'large fractional value in framestamp. resetting')
            frac = 1
        framestamp = framecount + frac
        # Byte 10: pulse width.
        pulse_width = ord(data[10])
        if full_output:
            return framestamp, pulse_width, framecount, tcnt3
        return framestamp, pulse_width

    def get_analog_input_buffer_rawLE(self):
        """Drain the analog-input endpoint and return the raw samples.

        Returns:
            1-D ``'<u2'`` (unsigned 2-byte little-endian) array.  It is empty
            when there is no real device or no data was read.
        """
        if not self.real_device:
            return np.array([], dtype='<u2')  # unsigned 2 byte little endian
        ep_len = 256
        input_buffer = ctypes.create_string_buffer(ep_len)
        timeout = 50  # msec
        min_cnt = 2  # minimum number of times the endpoint should be read

        bufs = []
        cnt = 0  # number of times the endpoint has been read
        while True:
            # keep pumping until no more data
            try:
                with self._lock:
                    n_bytes = usb.bulk_read(self._libusb_handle,
                                            (ENDPOINT_DIR_IN | ANALOG_EPNUM),
                                            input_buffer, timeout)
            except usb.USBNoDataAvailableError:
                break
            cnt += 1
            n_elements = n_bytes // 2
            # np.fromstring's binary mode is deprecated; frombuffer is the
            # replacement (np.hstack below copies, so the result is writable).
            buf = np.frombuffer(input_buffer.raw, dtype='<u2')[:n_elements]
            bufs.append(buf)
            if n_bytes < ep_len and cnt >= min_cnt:
                break  # don't bother waiting for data to dribble in

        if bufs:
            return np.hstack(bufs)
        return np.array([], dtype='<u2')  # unsigned 2 byte little endian

    def __t3_state_changed(self):
        # A value was assigned to self._t3_state.
        # 1. Send its contents to device
        self._send_t3_state()
        # 2. Ensure updates to it also get sent to device
        if self._t3_state is None:
            return
        self._t3_state.on_trait_change(self._send_t3_state)

    def _send_t3_state(self):
        """ensure our concept of the device's state is correct by setting it"""
        t3 = self._t3_state  # shorthand
        if t3 is None:
            return
        # 16-bit registers are sent high byte first.
        self._send_buf(self._make_buf([
            CAMTRIG_NEW_TIMER3_DATA,
            t3.ocr3a // 0x100, t3.ocr3a % 0x100,
            t3.ocr3b // 0x100, t3.ocr3b % 0x100,
            t3.ocr3c // 0x100, t3.ocr3c % 0x100,
            t3.timer3_top // 0x100, t3.timer3_top % 0x100,  # icr3a
            t3.timer3_CS_,
        ]))

    def __ain_state_changed(self):
        # A value was assigned to self._ain_state.
        # 1. Send its contents to device
        self._send_ain_state()
        # 2. Ensure updates to it also get sent to device
        if self._ain_state is None:
            return
        self._ain_state.on_trait_change(self._send_ain_state)

    def _send_ain_state(self):
        """ensure our concept of the device's state is correct by setting it"""
        ain_state = self._ain_state  # shorthand
        if ain_state is None:
            return
        if ain_state.AIN_running:
            # analog_cmd_flags
            channel_list = 0
            if ain_state.AIN0_enabled:
                channel_list |= ENABLE_ADC_CHAN0
            if ain_state.AIN1_enabled:
                channel_list |= ENABLE_ADC_CHAN1
            if ain_state.AIN2_enabled:
                channel_list |= ENABLE_ADC_CHAN2
            if ain_state.AIN3_enabled:
                channel_list |= ENABLE_ADC_CHAN3
            analog_cmd_flags = ADC_START_STREAMING | channel_list
            analog_sample_bits = (ain_state.adc_prescaler_ |
                                  (ain_state.downsample_bits << 3))
        else:
            analog_cmd_flags = ADC_STOP_STREAMING
            analog_sample_bits = 0

        self._send_buf(self._make_buf([CAMTRIG_AIN_SERVICE,
                                       analog_cmd_flags,
                                       analog_sample_bits]))

    def enter_dfu_mode(self):
        self._send_buf(self._make_buf([CAMTRIG_ENTER_DFU]))

    def _do_single_frame_pulse_fired(self):
        self._send_buf(self._make_buf([CAMTRIG_DO_TRIG_ONCE]))

    def _ext_trig1_fired(self):
        self._send_buf(self._make_buf([CAMTRIG_SET_EXT_TRIG, EXT_TRIG1]))

    def _ext_trig2_fired(self):
        self._send_buf(self._make_buf([CAMTRIG_SET_EXT_TRIG, EXT_TRIG2]))

    def _ext_trig3_fired(self):
        self._send_buf(self._make_buf([CAMTRIG_SET_EXT_TRIG, EXT_TRIG3]))

    def _reset_framecount_A_fired(self):
        self._send_buf(self._make_buf([CAMTRIG_RESET_FRAMECOUNT_A]))

    def _reset_AIN_overflow_fired(self):
        # 3-byte command; the 3rd byte doesn't matter (left as zero).
        self._send_buf(self._make_buf([CAMTRIG_AIN_SERVICE, ADC_RESET_AIN],
                                      size=3))

    # WBD - functions for enabling and disabling random pulses
    # --------------------------------------------------------
    def rand_pulse_enable(self):
        self._send_buf(self._make_buf([CAMTRIG_RAND_PULSE, RAND_PULSE_ENABLE]))

    def rand_pulse_disable(self):
        self._send_buf(self._make_buf([CAMTRIG_RAND_PULSE,
                                       RAND_PULSE_DISABLE]))

    # WBD - function for setting analog output values
    # -------------------------------------------------------
    def set_aout_values(self, val0, val1):
        """Send two analog-output values, each as high byte then low byte."""
        self._send_buf(self._make_buf([
            CAMTRIG_SET_AOUT,
            val0 // 0x100, val0 % 0x100,
            val1 // 0x100, val1 % 0x100,
        ]))

    # WBD - get pulse width from frame count
    # -------------------------------------------------------
    def get_width_from_framecnt(self, framecnt):
        """Ask the device for the pulse width recorded for frame *framecnt*.

        The frame count is sent as 4 little-endian bytes, so only its low 32
        bits are transmitted.

        Raises:
            NoDataError: if no reply is available, including when there is
                no real device.
        """
        framecnt_bytes = [(framecnt >> (8 * i)) & 0xFF for i in range(4)]
        self._send_buf(self._make_buf([CAMTRIG_GET_PULSE_WIDTH] +
                                      framecnt_bytes))
        data = self._read_buf()
        if data is None:
            raise NoDataError('no data available from device')
        return ord(data[0])

    # WBD - modified read_buf functions for multiple epnum in buffers
    # ---------------------------------------------------------------
    def _read_buf(self):
        """Read up to 16 bytes from the command IN endpoint.

        Returns the ctypes buffer, or None when there is no real device or
        no data arrived within the 1000 ms timeout.
        """
        if not self.real_device:
            return None
        buf = ctypes.create_string_buffer(16)
        timeout = 1000
        epnum = (ENDPOINT_DIR_IN | CAMTRIG_EPNUM)
        with self._lock:
            try:
                usb.bulk_read(self._libusb_handle, epnum, buf, timeout)
            except usb.USBNoDataAvailableError:
                return None
        return buf
    # ---------------------------------------------------------------

    def _send_buf(self, buf):
        """Bulk-write *buf* to the device; a no-op without a real device."""
        if not self.real_device:
            return
        with self._lock:
            # Endpoint 0x06, 9999 ms timeout.
            usb.bulk_write(self._libusb_handle, 0x06, buf, 9999)

    def _open_device(self):
        """Find, open and claim the trigger USB device.

        This is skipped (``real_device = False``) when ``REQUIRE_TRIGGER`` is
        ``0``.  Otherwise the device with idVendor 0x1781 and idProduct
        0x0BAF must be present, and ``FOSC`` is set from its product string.

        Raises:
            RuntimeError: if the device cannot be found.
            ValueError: if the product string is not recognised (unless
                ``ignore_version_mismatch`` is set) or its F_CPU value cannot
                be parsed.
            AssertionError: if the manufacturer string is wrong.
        """
        require_trigger = int(os.environ.get('REQUIRE_TRIGGER', '1'))
        if not require_trigger:
            self.real_device = False
            return

        usb.init()
        if not usb.get_busses():
            usb.find_busses()
            usb.find_devices()

        dev = None
        for bus in usb.get_busses():
            for candidate in bus.devices:
                debug('idVendor: 0x%04x idProduct: 0x%04x' %
                      (candidate.descriptor.idVendor,
                       candidate.descriptor.idProduct))
                if (candidate.descriptor.idVendor == 0x1781 and
                        candidate.descriptor.idProduct == 0x0BAF):
                    dev = candidate
                    break
            if dev is not None:
                break
        if dev is None:
            raise RuntimeError("Cannot find device. (Perhaps run with "
                               "environment variable REQUIRE_TRIGGER=0.)")

        with self._lock:
            self._libusb_handle = usb.open(dev)

            manufacturer = usb.get_string_simple(
                self._libusb_handle, dev.descriptor.iManufacturer)
            product = usb.get_string_simple(
                self._libusb_handle, dev.descriptor.iProduct)
            serial = usb.get_string_simple(
                self._libusb_handle, dev.descriptor.iSerialNumber)

            assert manufacturer == 'Strawman', 'Wrong manufacturer: %s' % manufacturer
            valid_product = 'Camera Trigger 1.0'
            if product == valid_product:
                self.FOSC = 8000000.0
            elif product.startswith('Camera Trigger 1.01'):
                osc_re = r'Camera Trigger 1.01 \(F_CPU = (.*)\)\w*'
                match = re.search(osc_re, product)
                if match is None:
                    raise ValueError('Cannot parse F_CPU from product %r'
                                     % (product,))
                fosc_str = match.groups()[0]
                if fosc_str.endswith('UL'):
                    fosc_str = fosc_str[:-2]
                self.FOSC = float(fosc_str)
            else:
                errmsg = 'Expected product "%s", but you have "%s"' % (
                    valid_product, product)
                if self.ignore_version_mismatch:
                    # Function-call form with a single argument prints the
                    # same text on Python 2 and 3 (the old print statements
                    # were a SyntaxError on Python 3).
                    print('WARNING: %s' % errmsg)
                    self.FOSC = 8000000.0
                    print(' assuming FOSC= %s' % self.FOSC)
                else:
                    raise ValueError(errmsg)

            interface_nr = 0
            if hasattr(usb, 'get_driver_np'):
                # non-portable libusb extension
                name = usb.get_driver_np(self._libusb_handle, interface_nr)
                if name != '':
                    usb.detach_kernel_driver_np(self._libusb_handle,
                                                interface_nr)

            if dev.descriptor.bNumConfigurations > 1:
                debug("WARNING: more than one configuration, choosing first")

            config = dev.config[0]
            usb.set_configuration(self._libusb_handle,
                                  config.bConfigurationValue)
            usb.claim_interface(self._libusb_handle, interface_nr)
        self.real_device = True
コード例 #3
0
class LiveTimestampModeler(traits.HasTraits):
    """Model host time as a linear function of trigger-device framestamps.

    update() samples (host time.time(), device framestamp) pairs into
    ``timestamps_framestamps``. A least-squares line
    ``host_time = gain * framestamp + offset`` is fit to those samples and
    exposed through the ``gain``, ``offset`` and ``residual_error``
    properties; register_frame() uses the fit to turn camera frame numbers
    into trigger timestamps.
    """
    _trigger_device = traits.Instance(ttrigger.DeviceModel)

    # Seconds. Used both as the trigger pause length when synchronizing and
    # as the frame gap after which a camera's frame offset is re-measured.
    sync_interval = traits.Float(2.0)
    has_ever_synchronized = traits.Bool(False, transient=True)

    frame_offset_changed = traits.Event

    # Rows of (host time, framestamp). ``float`` replaces the ``np.float``
    # alias (same type), which NumPy deprecated in 1.20 and removed in 1.24.
    timestamps_framestamps = traits.Array(shape=(None, 2), dtype=float)

    timestamp_data = traits.Any()
    block_activity = traits.Bool(False, transient=True)

    synchronize = traits.Button(label='Synchronize')
    # None, or (time at which to restore the frame rate, original fps)
    synchronizing_info = traits.Any(None)

    gain_offset_residuals = traits.Property(
        depends_on=['timestamps_framestamps'])

    residual_error = traits.Property(depends_on='gain_offset_residuals')

    gain = traits.Property(depends_on='gain_offset_residuals')

    offset = traits.Property(depends_on='gain_offset_residuals')

    # id_string -> frame number offset of that camera
    frame_offsets = traits.Dict()
    # id_string -> host time of that camera's last registered frame
    last_frame = traits.Dict()

    view_time_model_plot = traits.Button

    traits_view = View(
        Group(
            Item(
                name='gain',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item(
                name='offset',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat2),
            ),
            Item(
                name='residual_error',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item('synchronize', show_label=False),
            Item('view_time_model_plot', show_label=False),
        ),
        title='Timestamp modeler',
    )

    def _block_activity_changed(self):
        """Tell the user whether frame rate / AIN changes are safe."""
        if self.block_activity:
            print('Do not change frame rate or AIN parameters. '
                  'Automatic prevention of doing '
                  'so is not currently implemented.')
        else:
            print('You may change frame rate again')

    def _view_time_model_plot_fired(self):
        raise NotImplementedError('')

    def _synchronize_fired(self):
        """Start a re-synchronization.

        Trigger output is paused (0 fps) and a framecount reset is requested;
        update() restores the original frame rate once
        ``sync_interval + 0.1`` seconds have passed.
        """
        if self.block_activity:
            print('Not synchronizing because activity is blocked. '
                  '(Perhaps because you are saving data now.')
            return

        orig_fps = self._trigger_device.frames_per_second_actual
        self._trigger_device.set_frames_per_second_approximate(0.0)
        self._trigger_device.reset_framecount_A = True  # trigger reset event
        self.synchronizing_info = (time.time() + self.sync_interval + 0.1,
                                   orig_fps)

    @traits.cached_property
    def _get_gain(self):
        """Slope of the fit, or None when there is not enough data."""
        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            return None
        gain, offset, residuals = result
        return gain

    @traits.cached_property
    def _get_offset(self):
        """Intercept of the fit, or None when there is not enough data."""
        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            return None
        gain, offset, residuals = result
        return offset

    @traits.cached_property
    def _get_residual_error(self):
        """Sum of squared residuals of the fit, or None if unavailable."""
        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            return None
        gain, offset, residuals = result
        if residuals is None or len(residuals) == 0:
            # lstsq returns no residuals when the system is not overdetermined
            return None
        assert len(residuals) == 1
        return residuals[0]

    @traits.cached_property
    def _get_gain_offset_residuals(self):
        """Least-squares fit ``host_time = gain * framestamp + offset``.

        Returns (gain, offset, residuals) or None with fewer than 2 samples.
        """
        if self.timestamps_framestamps is None:
            return None

        timestamps = self.timestamps_framestamps[:, 0]
        framestamps = self.timestamps_framestamps[:, 1]

        if len(timestamps) < 2:
            return None

        # like model_remote_to_local in flydra.analysis
        remote_timestamps = framestamps
        local_timestamps = timestamps

        a1 = remote_timestamps[:, np.newaxis]
        a2 = np.ones((len(remote_timestamps), 1))
        A = np.hstack((a1, a2))
        b = local_timestamps[:, np.newaxis]
        # rcond=-1 keeps the historical default (machine precision) and
        # silences NumPy's FutureWarning about the changed default.
        x, resids, rank, s = np.linalg.lstsq(A, b, rcond=-1)

        gain = x[0, 0]
        offset = x[1, 0]
        return gain, offset, resids

    def set_trigger_device(self, device):
        """Attach the trigger device and listen to its AIN-overflow reset."""
        self._trigger_device = device
        self._trigger_device.on_trait_event(
            self._on_trigger_device_reset_AIN_overflow_fired,
            name='reset_AIN_overflow')

    def _on_trigger_device_reset_AIN_overflow_fired(self):
        self.ain_overflowed = 0

    def _get_now_framestamp(self, max_error_seconds=0.003, full_output=False):
        """Read the device framestamp bracketed by two host clock reads.

        Up to 11 attempts are made until the bracket is shorter than
        ``max_error_seconds`` and the framestamp's fractional part is at
        least 0.1 (workaround for a TCNT race on the MCU).

        Returns (now, framestamp) where ``now`` is the bracket midpoint, or
        with full_output (now, framestamp, now1, now2, framecount, tcnt).

        Raises ImpreciseMeasurementError when no data is available or no
        acceptable measurement was obtained.
        """
        count = 0
        while count <= 10:
            now1 = time.time()
            try:
                results = self._trigger_device.get_framestamp(
                    full_output=full_output)
            except ttrigger.NoDataError:
                raise ImpreciseMeasurementError('no data available')
            now2 = time.time()
            if full_output:
                framestamp, framecount, tcnt = results
            else:
                framestamp = results
            count += 1
            measurement_error = abs(now2 - now1)
            if framestamp % 1.0 < 0.1:
                warnings.warn('workaround of TCNT race condition on MCU...')
                continue
            if measurement_error < max_error_seconds:
                break
            time.sleep(0.01)  # wait 10 msec before trying again
        if not measurement_error < max_error_seconds:
            raise ImpreciseMeasurementError(
                'could not obtain low error measurement')
        if framestamp % 1.0 < 0.1:
            raise ImpreciseMeasurementError('workaround MCU bug')

        now = (now1 + now2) * 0.5
        if full_output:
            results = now, framestamp, now1, now2, framecount, tcnt
        else:
            results = now, framestamp
        return results

    def clear_samples(self, call_update=True):
        """Discard all samples, optionally taking a fresh one via update()."""
        self.timestamps_framestamps = np.empty((0, 2))
        if call_update:
            self.update()

    def update(self, return_last_measurement_info=False):
        """call this function fairly often to pump information from the USB device

        Finishes a pending synchronization (restoring the frame rate and
        clearing samples) once its deadline passes, then appends one
        (host time, framestamp) sample. Only the newest 50 samples are kept
        once more than 100 accumulate.

        With return_last_measurement_info, returns
        (start_timestamp, stop_timestamp, framecount, tcnt) of the sample.
        """
        if self.synchronizing_info is not None:
            done_time, orig_fps = self.synchronizing_info
            # suspended trigger pulses to re-synchronize
            if time.time() >= done_time:
                # we've waited the sync duration, restart
                self._trigger_device.set_frames_per_second_approximate(
                    orig_fps)
                self.clear_samples(call_update=False)  # avoid recursion
                self.synchronizing_info = None
                self.has_ever_synchronized = True

        results = self._get_now_framestamp(
            full_output=return_last_measurement_info)
        now, framestamp = results[:2]
        if return_last_measurement_info:
            start_timestamp, stop_timestamp, framecount, tcnt = results[2:]

        self.timestamps_framestamps = np.vstack(
            (self.timestamps_framestamps, [now, framestamp]))

        # If more than 100 samples,
        if len(self.timestamps_framestamps) > 100:
            # keep only the most recent 50.
            self.timestamps_framestamps = self.timestamps_framestamps[-50:]

        if return_last_measurement_info:
            return start_timestamp, stop_timestamp, framecount, tcnt

    def get_frame_offset(self, id_string):
        return self.frame_offsets[id_string]

    def register_frame(self,
                       id_string,
                       framenumber,
                       frame_timestamp,
                       full_output=False):
        """note that a frame happened and return start-of-frame time

        Returns trigger_timestamp, or with full_output the tuple
        (trigger_timestamp, corrected_framenumber, did_frame_offset_change).
        While the model has too little data the timestamp (and corrected
        frame number) are None.
        """

        # This may get called from another thread (e.g. the realtime
        # image processing thread).

        # An important note about locking and thread safety: This code
        # relies on the Python interpreter to lock data structures
        # across threads. To do this internally, a lock would be made
        # for each variable in this instance and acquired before each
        # access. Because the data structures are simple Python
        # objects, I believe the operations are atomic and thus this
        # function is OK.

        # Don't trust camera drivers with giving a good timestamp. We
        # only use this to reset our framenumber-to-time data
        # gathering, anyway.
        frame_timestamp = time.time()

        # Bound up front so every return path below can report it.
        did_frame_offset_change = False
        if frame_timestamp is not None:
            last_frame_timestamp = self.last_frame.get(id_string, -np.inf)
            this_interval = frame_timestamp - last_frame_timestamp

            if this_interval > self.sync_interval:
                if self.block_activity:
                    print(
                        'changing frame offset is disallowed, but you attempted to do it. ignoring.'
                    )
                else:
                    # re-synchronize camera

                    # XXX need to figure out where frame offset of two comes from:
                    self.frame_offsets[id_string] = framenumber - 2
                    did_frame_offset_change = True

            self.last_frame[id_string] = frame_timestamp

            if did_frame_offset_change:
                self.frame_offset_changed = True  # fire any listeners

        result = self.gain_offset_residuals
        if result is None:
            # not enough data
            if full_output:
                results = None, None, did_frame_offset_change
            else:
                results = None
            return results

        gain, offset, residuals = result
        # NOTE(review): raises KeyError if no offset was ever recorded for
        # id_string (e.g. block_activity was set on its first frame).
        corrected_framenumber = framenumber - self.frame_offsets[id_string]
        trigger_timestamp = corrected_framenumber * gain + offset

        if full_output:
            results = trigger_timestamp, corrected_framenumber, did_frame_offset_change
        else:
            results = trigger_timestamp
        return results
コード例 #4
0
class LiveTimestampModelerWithAnalogInput(LiveTimestampModeler):
    """Timestamp modeler that also pumps analog input (AIN) from the device.

    update_analog_input() decodes the raw AIN wordstream read from the
    trigger device and appends the samples to the per-channel buffers of
    ``viewer``.
    """
    view_AIN = traits.Button(label='view analog input (AIN)')
    viewer = traits.Instance(AnalogInputViewer)

    # the actual analog data (as a wordstream)
    ain_data_raw = traits.Array(dtype=np.uint16, transient=True)
    # words left undecoded by the previous update_analog_input() call
    old_data_raw = traits.Array(dtype=np.uint16, transient=True)

    # necessary to calculate precise timestamps for AIN data
    timer3_top = traits.Property()
    channel_names = traits.Property()
    Vcc = traits.Property(depends_on='_trigger_device')
    # integer for display (boolean readonly editor ugly)
    ain_overflowed = traits.Int(0, transient=True)

    ain_wordstream_buffer = traits.Any()
    traits_view = View(
        Group(
            Item('synchronize', show_label=False),
            Item('view_time_model_plot', show_label=False),
            Item('ain_overflowed', style='readonly'),
            Item(
                name='gain',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item(
                name='offset',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat2),
            ),
            Item(
                name='residual_error',
                style='readonly',
                editor=TextEditor(evaluate=float, format_func=myformat),
            ),
            Item('view_AIN', show_label=False),
        ),
        title='Timestamp modeler',
    )

    @traits.cached_property
    def _get_Vcc(self):
        return self._trigger_device.Vcc

    def _get_timer3_top(self):
        return self._trigger_device.timer3_top

    def _get_channel_names(self):
        return self._trigger_device.enabled_channel_names

    def update_analog_input(self):
        """call this function frequently to avoid overruns

        Reads the device's raw AIN buffer, decodes it with cDecode.process
        and appends the decoded samples to the matching channel of
        ``viewer``, keeping at most the newest MAXLEN samples per channel.
        Words the decoder could not consume are kept for the next call.

        Raises AnalogDataOverflowedError (after setting ``ain_overflowed``
        to 1) if the decoder reported an overflow.
        """
        new_data_raw = self._trigger_device.get_analog_input_buffer_rawLE()
        # The undecoded tail from the previous call precedes the newly read
        # words in the stream, so it must come first.
        data_raw = np.hstack((self.old_data_raw, new_data_raw))
        self.ain_data_raw = new_data_raw
        newdata_all = []
        chan_all = []
        any_overflow = False
        while len(data_raw):
            result = cDecode.process(data_raw)
            (N, samples, channels, did_overflow, framestamp) = result
            if N == 0:
                # no data was able to be processed
                break
            data_raw = data_raw[N:]
            newdata_all.append(samples)
            chan_all.append(channels)
            if did_overflow:
                any_overflow = True
            # TODO: the decoded framestamp is not saved yet.
        self.old_data_raw = data_raw  # save unprocessed data for next run

        if any_overflow:
            # XXX should move to logging the error.
            self.ain_overflowed = 1
            raise AnalogDataOverflowedError()

        if len(chan_all) == 0:
            # no data
            return
        chan_all = np.hstack(chan_all)
        newdata_all = np.hstack(newdata_all)
        USB_channel_numbers = np.unique(chan_all)

        # Maximum number of samples kept per channel for display. (A
        # duration-based limit derived from F_OSC / ADC prescaler /
        # downsampling was once considered but is not implemented.)
        MAXLEN = 5000

        for USB_chan in USB_channel_numbers:
            vi = self.viewer.usb_device_number2index[USB_chan]
            cond = chan_all == USB_chan
            newdata = newdata_all[cond]

            oldidx = self.viewer.channels[vi].index
            olddata = self.viewer.channels[vi].data

            if len(oldidx):
                baseidx = oldidx[-1] + 1
            else:
                baseidx = 0.0
            # ``float`` replaces the removed ``np.float`` alias (same type).
            newidx = np.arange(len(newdata), dtype=float) + baseidx

            tmpidx = np.hstack((oldidx, newidx))
            tmpdata = np.hstack((olddata, newdata))

            if len(tmpidx) > MAXLEN:
                # clip to MAXLEN
                self.viewer.channels[vi].index = tmpidx[-MAXLEN:]
                self.viewer.channels[vi].data = tmpdata[-MAXLEN:]
            else:
                self.viewer.channels[vi].index = tmpidx
                self.viewer.channels[vi].data = tmpdata

    def _view_AIN_fired(self):
        self.viewer.edit_traits()
コード例 #5
0
class Signal(t.HasTraits, MVA):
    data = t.Any()
    axes_manager = t.Instance(AxesManager)
    original_parameters = t.Instance(Parameters)
    mapped_parameters = t.Instance(Parameters)
    physical_property = t.Str()

    def __init__(self, file_data_dict=None, *args, **kw):
        """All data interaction is made through this class or its subclasses

        Parameters:
        -----------
        file_data_dict : dictionary or None
           If a dict (including dict subclasses such as OrderedDict) it is
           passed to load_dictionary; see that method for the format. Any
           other value is ignored.
        *args, **kw
           Accepted for call compatibility; not used.
        """
        super(Signal, self).__init__()
        self.mapped_parameters = Parameters()
        self.original_parameters = Parameters()
        # isinstance also accepts dict subclasses, which the previous exact
        # type-name comparison silently ignored.
        if isinstance(file_data_dict, dict):
            self.load_dictionary(file_data_dict)
        self._plot = None
        self.mva_results = MVA_Results()
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None

    def load_dictionary(self, file_data_dict):
        """Load the signal's data and metadata from a dictionary.

        Parameters:
        -----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a dictionary that defines the axes (see the
                    AxesManager class)
                attributes: a dictionary whose keys are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that have been
                    imported from the original data file.

        Notes
        -----
        Missing 'axes', 'mapped_parameters' and 'original_parameters' entries
        are filled in with defaults in file_data_dict itself.
        """
        self.data = file_data_dict['data']
        if 'axes' not in file_data_dict:
            file_data_dict['axes'] = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(file_data_dict['axes'])
        file_data_dict.setdefault('mapped_parameters', {})
        file_data_dict.setdefault('original_parameters', {})
        if 'attributes' in file_data_dict:
            # items() works on Python 2 and 3; iteritems() is Python 2 only.
            for key, value in file_data_dict['attributes'].items():
                setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict['original_parameters'])
        self.mapped_parameters.load_dictionary(
            file_data_dict['mapped_parameters'])

    def _get_signal_dict(self):
        """Return a dictionary describing this signal.

        The keys match those read by load_dictionary: a copy of the data
        array plus the axes and parameter dictionaries exported by their
        respective managers.
        """
        return {
            'data': self.data.copy(),
            'axes': self.axes_manager._get_axes_dicts(),
            'mapped_parameters':
                self.mapped_parameters._get_parameters_dictionary(),
            'original_parameters':
                self.original_parameters._get_parameters_dictionary(),
        }

    def _get_undefined_axes_list(self):
        """Build a default axis description for every dimension of the data.

        Each entry uses unit scale, zero offset and 'undefined' name/units,
        with 'size' taken from the corresponding entry of data.shape.
        """
        return [{'name': 'undefined',
                 'scale': 1.,
                 'offset': 0.,
                 'size': int(dim_size),
                 'units': 'undefined',
                 'index_in_array': dim_index}
                for dim_index, dim_size in enumerate(self.data.shape)]

    def __call__(self, axes_manager=None):
        """Index the data with an axes manager's current _getitem_tuple.

        Falls back to self.axes_manager when no manager is supplied.
        """
        manager = self.axes_manager if axes_manager is None else axes_manager
        return self.data[manager._getitem_tuple]

    def _get_hse_1D_explorer(self, *args, **kwargs):
        """Return the squeezed data with the non-slicing axis first.

        The squeezed array is transposed when the slicing axis precedes the
        non-slicing axis in the array (presumably leaving a 2D array; not
        checked here).
        """
        manager = self.axes_manager
        slicing_index = manager._slicing_axes[0].index_in_array
        non_slicing_index = manager._non_slicing_axes[0].index_in_array
        squeezed = self.data.squeeze()
        if slicing_index <= non_slicing_index:
            return squeezed.T
        return squeezed

    def _get_hse_2D_explorer(self, *args, **kwargs):
        """Return the data summed over the first slicing axis."""
        summed_axis = self.axes_manager._slicing_axes[0].index_in_array
        return self.data.sum(summed_axis)

    def _get_hie_explorer(self, *args, **kwargs):
        """Return the data summed over the first two slicing axes."""
        slicing = self.axes_manager._slicing_axes
        lower, upper = sorted([slicing[0].index_in_array,
                               slicing[1].index_in_array])
        # Sum the higher axis first so the lower axis index stays valid.
        return self.data.sum(upper).sum(lower)

    def _get_explorer(self, *args, **kwargs):
        """Return the navigator image for the current dimensionality.

        Dispatches on (signal_dimension, navigation_dimension) to one of the
        _get_hse_*/_get_hie_ helpers, or returns None for unsupported
        combinations.
        """
        nav_dim = self.axes_manager.navigation_dimension
        sig_dim = self.axes_manager.signal_dimension
        if sig_dim == 1:
            if nav_dim == 1:
                return self._get_hse_1D_explorer(*args, **kwargs)
            if nav_dim == 2:
                return self._get_hse_2D_explorer(*args, **kwargs)
            return None
        if sig_dim == 2 and nav_dim in (1, 2):
            return self._get_hie_explorer(*args, **kwargs)
        return None

    def plot(self, axes_manager=None):
        """Plot the signal with the explorer matching its signal dimension.

        Any previously opened plot is closed first. Signal dimension 1 uses
        the hyperspectrum explorer, dimension 2 the hyperimage explorer; any
        other dimension triggers messages.warning_exit.

        Parameters
        ----------
        axes_manager : AxesManager or None
            Manager handed to the plot; defaults to self.axes_manager.
            NOTE(review): axis labels and the spectrum axis are still read
            from self.axes_manager even when another manager is given.
        """
        if self._plot is not None:
            try:
                self._plot.close()
            except Exception:
                # If it was already closed it will raise an exception,
                # but we want to carry on... (Exception rather than a bare
                # except so KeyboardInterrupt/SystemExit still propagate.)
                pass

        if axes_manager is None:
            axes_manager = self.axes_manager

        if axes_manager.signal_dimension == 1:
            # Hyperspectrum

            self._plot = mpl_hse.MPL_HyperSpectrum_Explorer()
            self._plot.spectrum_data_function = self.__call__
            self._plot.spectrum_title = self.mapped_parameters.name
            self._plot.xlabel = '%s (%s)' % (
                self.axes_manager._slicing_axes[0].name,
                self.axes_manager._slicing_axes[0].units)
            self._plot.ylabel = 'Intensity'
            self._plot.axes_manager = axes_manager
            self._plot.axis = self.axes_manager._slicing_axes[0].axis

            # Image properties
            if self.axes_manager._non_slicing_axes:
                self._plot.image_data_function = self._get_explorer
                self._plot.image_title = ''
                self._plot.pixel_size = \
                self.axes_manager._non_slicing_axes[0].scale
                self._plot.pixel_units = \
                self.axes_manager._non_slicing_axes[0].units
            self._plot.plot()

        elif axes_manager.signal_dimension == 2:

            # Mike's playground with new plotting toolkits - needs to be a
            # branch. Disabled alternatives:
            #     if len(self.data.shape)==2:
            #         from drawing.guiqwt_hie import image_plot_2D
            #         image_plot_2D(self)
            #
            #     import drawing.chaco_hie
            #     self._plot = drawing.chaco_hie.Chaco_HyperImage_Explorer(self)
            #     self._plot.configure_traits()
            self._plot = mpl_hie.MPL_HyperImage_Explorer()
            self._plot.image_data_function = self.__call__
            self._plot.navigator_data_function = self._get_explorer
            self._plot.axes_manager = axes_manager
            self._plot.plot()

        else:
            messages.warning_exit('Plotting is not supported for this view')

    traits_view = tui.View(
        tui.Item('name'),
        tui.Item('physical_property'),
        tui.Item('units'),
        tui.Item('offset'),
        tui.Item('scale'),
    )

    def plot_residual(self, axes_manager=None):
        """Plot the residual between original data and reconstructed data

        Requires you to have already run PCA or ICA, and to reconstruct data
        using either the pca_build_SI or ica_build_SI methods. Prints a hint
        instead when this object has no ``residual`` attribute.
        """
        if not hasattr(self, 'residual'):
            print("Object does not have any residual information.  Is it a "
                  "reconstruction created using either pca_build_SI or "
                  "ica_build_SI methods?")
            return
        self.residual.plot(axes_manager)

    def save(self, filename, only_view=False, **kwds):
        """Saves the signal in the specified format.

        The function gets the format from the extension. You can use:
            - hdf5 for HDF5
            - nc for NetCDF
            - msa for EMSA/MSA single spectrum saving.
            - bin to produce a raw binary file
            - Many image formats such as png, tiff, jpeg...

        Please note that not all the formats supports saving datasets of
        arbitrary dimensions, e.g. msa only suports 1D data.

        Parameters
        ----------
        filename : str
        msa_format : {'Y', 'XY'}
            'Y' will produce a file without the energy axis. 'XY' will also
            save another column with the energy axis. For compatibility with
            Gatan Digital Micrograph 'Y' is the default.
        only_view : bool
            Not implemented yet: the full dataset is always saved. Passing
            True emits a warning instead of being silently ignored.
        **kwds
            Forwarded unchanged to io.save.
        """
        if only_view:
            import warnings
            warnings.warn("only_view is not implemented yet; the full "
                          "dataset will be saved")
        io.save(filename, self, **kwds)

    def _replot(self):
        """Redraw the signal if a plot exists and reports itself active."""
        current_plot = self._plot
        if current_plot is None:
            return
        if current_plot.is_active() is True:
            self.plot()

    def get_dimensions_from_data(self):
        """Get the dimension parameters from the data_cube. Useful when the
        data_cube was externally modified, or when the SI was not loaded from
        a file

        Each axis' size is set from data.shape[axis.index_in_array] and
        printed; the plot is then refreshed via _replot.
        """
        shape = self.data.shape
        for axis in self.axes_manager.axes:
            dim_size = shape[axis.index_in_array]
            axis.size = int(dim_size)
            print("%s size: %i" % (axis.name, dim_size))
        self._replot()

    def crop_in_pixels(self, axis, i1=None, i2=None):
        """Crops the data in a given axis. The range is given in pixels

        Parameters
        ----------
        axis : int
            Axis to crop; negative values count from the last dimension.
        i1 : int or None
            Start index. When given, the axis offset is moved to the axis
            value at this index.
        i2 : int or None
            End index (exclusive, as in a Python slice).

        See also:
        ---------
        crop_in_units
        """
        axis = self._get_positive_axis_index_index(axis)
        new_offset = None
        if i1 is not None:
            new_offset = self.axes_manager.axes[axis].axis[i1]
        selector = (slice(None), ) * axis + (slice(i1, i2), Ellipsis)
        # Copy so the cropped data is contiguous instead of a view.
        self.data = self.data[selector].copy()

        if i1 is not None:
            self.axes_manager.axes[axis].offset = new_offset
        self.get_dimensions_from_data()

    def crop_in_units(self, axis, x1=None, x2=None):
        """Crops the data in a given axis. The range is given in the units of
        the axis

        Parameters
        ----------
        axis : int
            Axis to crop.
        x1 : float or None
            Start value, in axis units; converted with the axis' value2index.
        x2 : float or None
            End value, in axis units; converted with the axis' value2index.

        See also:
        ---------
        crop_in_pixels
        """
        cropped_axis = self.axes_manager.axes[axis]
        self.crop_in_pixels(axis,
                            cropped_axis.value2index(x1),
                            cropped_axis.value2index(x2))

    def roll_xy(self, n_x, n_y=1):
        """Roll over the x axis n_x positions and n_y positions the former rows

        This method has the purpose of "fixing" a bug in the acquisition of the
        Orsay's microscopes and probably it does not have general interest.
        The data is rolled by n_x along axis 0, then the first n_x entries
        along axis 0 are rolled by n_y along axis 1.

        Parameters
        ----------
        n_x : int
        n_y : int

        Note: Useful to correct the SI column storing bug in Marcel's
        acquisition routines.
        """
        self.data = np.roll(self.data, n_x, 0)
        leading_block = self.data[:n_x, ...]
        self.data[:n_x, ...] = np.roll(leading_block, n_y, 1)
        self._replot()

    # TODO: After using this function the plotting does not work
    def swap_axis(self, axis1, axis2):
        """Swaps the axes

        Swaps the data axes, exchanges the two axis objects' index_in_array
        and their positions in axes_manager.axes, then recomputes the signal
        dimension.

        Parameters
        ----------
        axis1 : positive int
        axis2 : positive int
        """
        self.data = self.data.swapaxes(axis1, axis2)
        axes = self.axes_manager.axes
        first, second = axes[axis1], axes[axis2]
        first.index_in_array, second.index_in_array = (
            second.index_in_array, first.index_in_array)
        axes[axis1], axes[axis2] = second, first
        self.axes_manager.set_signal_dimension()
        self._replot()

    def rebin(self, new_shape):
        """
        Rebins the data to the new shape

        Each axis' scale is multiplied by old_size / new_size along its
        array dimension.

        Parameters
        ----------
        new_shape: tuple of ints
            The new shape must be a divisor of the original shape
        """
        old_shape = np.array(self.data.shape)
        scale_factors = old_shape / np.array(new_shape)
        self.data = utils.rebin(self.data, new_shape)
        for ax in self.axes_manager.axes:
            ax.scale *= scale_factors[ax.index_in_array]
        self.get_dimensions_from_data()

    def split_in(self, axis, number_of_parts=None, steps=None):
        """Splits the data

        The split can be defined either by the `number_of_parts` or by the
        `steps` size. When neither is given, ``self._splitting_steps`` is used
        if it is set and non-empty.

        Parameters
        ----------
        number_of_parts : int or None
            Number of parts in which the SI will be splitted. Trailing
            elements that do not fill a whole part are dropped.
        steps : sequence of int or None
            Size of the splitted parts; takes precedence over
            number_of_parts.
        axis : int
            The splitting axis

        Return
        ------
        list with the splitted signals (new Signal instances whose data are
        slices of this signal's data)
        """
        axis = self._get_positive_axis_index_index(axis)
        if number_of_parts is None and steps is None:
            # _splitting_steps is not set by Signal itself; getattr avoids an
            # AttributeError when no subclass or loaded attribute provides it.
            default_steps = getattr(self, '_splitting_steps', None)
            if not default_steps:
                messages.warning_exit(
                    "Please provide either number_of_parts or a steps list")
            else:
                steps = default_steps
                print("Splitting in  %s" % (steps, ))
        elif number_of_parts is not None and steps is not None:
            print("Using the given steps list. number_of_parts dimissed")
        splitted = []
        shape = self.data.shape

        if steps is None:
            rounded = (shape[axis] - (shape[axis] % number_of_parts))
            # Floor division keeps step an int on Python 3 (range needs it).
            step = rounded // number_of_parts
            cut_node = list(range(0, rounded + step, step))
        else:
            # list() so a tuple (or other sequence) of steps is accepted.
            cut_node = np.array([0] + list(steps)).cumsum()
        for start, stop in zip(cut_node[:-1], cut_node[1:]):
            data = self.data[(slice(None), ) * axis +
                             (slice(start, stop), Ellipsis)]
            s = Signal({'data': data})
            # TODO: When copying plotting does not work
            #            s.axes = copy.deepcopy(self.axes_manager)
            s.get_dimensions_from_data()
            splitted.append(s)
        return splitted

    def unfold_if_multidim(self):
        """Unfold the datacube if it has more than two axes

        Returns
        -------

        Boolean. True if the data was unfolded by the function.
        """
        if len(self.axes_manager.axes) <= 2:
            return False
        print("Automatically unfolding the data")
        self.unfold()
        return True

    def _unfold(self, steady_axes, unfolded_axis):
        """Modify the shape of the data by specifying the axes whose
        dimension does not change and the axis over which the remaining axes
        will be unfolded

        Parameters
        ----------
        steady_axes : list
            The indexes of the axes which dimensions do not change
        unfolded_axis : int
            The index of the axis over which all the rest of the axes (except
            the steady axes) will be unfolded

        Returns
        -------
        False when the squeezed data has fewer than 3 dimensions (nothing is
        done); None otherwise.

        See also
        --------
        fold
        """

        # It doesn't make sense unfolding when dim < 3
        if len(self.data.squeeze().shape) < 3:
            return False

        # We need to store the original shape and coordinates to be used by
        # the fold function only if it has not been already stored by a
        # previous unfold
        if self._shape_before_unfolding is None:
            self._shape_before_unfolding = self.data.shape
            self._axes_manager_before_unfolding = self.axes_manager

        new_shape = [1] * len(self.data.shape)
        for index in steady_axes:
            new_shape[index] = self.data.shape[index]
        new_shape[unfolded_axis] = -1
        self.data = self.data.reshape(new_shape)
        # Modify a copy so the previous manager (possibly saved for fold())
        # is left untouched.
        self.axes_manager = self.axes_manager.deepcopy()
        i = 0
        uname = ''
        uunits = ''
        to_remove = []
        for axis, dim in zip(self.axes_manager.axes, new_shape):
            if dim == 1:
                # Collapsed axis: append its name and units to the unfolded
                # axis, then drop it. (Units previously used "=", which kept
                # only the last collapsed axis' units.)
                uname += ',' + axis.name
                uunits += ',' + axis.units
                to_remove.append(axis)
            else:
                axis.index_in_array = i
                i += 1
        self.axes_manager.axes[unfolded_axis].name += uname
        self.axes_manager.axes[unfolded_axis].units += uunits
        self.axes_manager.axes[unfolded_axis].size = \
                                                self.data.shape[unfolded_axis]
        for axis in to_remove:
            self.axes_manager.axes.remove(axis)

        self.data = self.data.squeeze()
        self._replot()

    def unfold(self):
        """Modifies the shape of the data by unfolding the signal and
        navigation dimensions separately (navigation first, then signal)

        """
        for unfold_step in (self.unfold_navigation_space,
                            self.unfold_signal_space):
            unfold_step()

    def unfold_navigation_space(self):
        """Reshape the data so that the navigation space has dimension 1.

        The array axes listed in ``axes_manager._slicing_axes`` keep their
        size. Every other axis is merged into the last axis of
        ``axes_manager._non_slicing_axes`` by `_unfold`.

        Returns
        -------
        False (after informing the user) when the navigation dimension is
        already lower than 2; otherwise None. The value returned by `_unfold`
        is not propagated.
        """
        manager = self.axes_manager
        if manager.navigation_dimension >= 2:
            kept_axes = [ax.index_in_array for ax in manager._slicing_axes]
            target_axis = manager._non_slicing_axes[-1].index_in_array
            self._unfold(kept_axes, target_axis)
            return None
        messages.information('Nothing done, the navigation dimension was '
                             'already 1')
        return False

    def unfold_signal_space(self):
        """Reshape the data so that the signal space has dimension 1.

        The array axes listed in ``axes_manager._non_slicing_axes`` keep their
        size. Every other axis is merged into the last axis of
        ``axes_manager._slicing_axes`` by `_unfold`.

        Returns
        -------
        False (after informing the user) when the signal dimension is already
        lower than 2; otherwise None. The value returned by `_unfold` is not
        propagated.
        """
        manager = self.axes_manager
        if manager.signal_dimension >= 2:
            kept_axes = [ax.index_in_array for ax in manager._non_slicing_axes]
            target_axis = manager._slicing_axes[-1].index_in_array
            self._unfold(kept_axes, target_axis)
            return None
        messages.information('Nothing done, the signal dimension was '
                             'already 1')
        return False

    def fold(self):
        """Fold the data back if it was previously unfolded.

        Restores the shape and the axes manager saved before unfolding,
        clears the saved state and replots. When no shape was saved this is a
        no-op.
        """
        saved_shape = self._shape_before_unfolding
        if saved_shape is None:
            # Nothing was unfolded (or it has already been folded back).
            return
        self.data = self.data.reshape(saved_shape)
        self.axes_manager = self._axes_manager_before_unfolding
        # Forget the saved state so that a later unfold stores fresh values.
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None
        self._replot()

    def _get_positive_axis_index_index(self, axis):
        """Return `axis` as a non-negative index into ``self.data.shape``.

        Negative values are counted from the end, as in NumPy. No range
        check is performed: non-negative values are returned unchanged.
        """
        return len(self.data.shape) + axis if axis < 0 else axis

    def iterate_axis(self, axis=-1):
        """Iterate over the 1-D slices of the data taken along `axis`.

        Each yielded array spans the whole length of `axis` with every other
        index fixed. It is a view of ``self.data``, so writing into it
        modifies the data.

        Parameters
        ----------
        axis : int
            The array axis along which the slices are taken. Negative values
            count from the last axis.

        Yields
        ------
        numpy array views of shape ``(self.data.shape[axis],)``.

        Notes
        -----
        When iteration starts ``self.data`` is replaced by a copy of itself.
        The copy is C-contiguous, which lets the reshape below return a view
        instead of a copy.
        """
        # We make a copy to guarantee that the data is contiguous, otherwise
        # the reshape below would not return a view of the data
        self.data = self.data.copy()
        axis = self._get_positive_axis_index_index(axis)
        ndim = len(self.data.shape)
        if ndim == 1:
            # There is a single 1-D slice: the whole array. The generic code
            # below would make unfolded_axis (-1) alias `axis` and yield
            # scalars instead.
            yield self.data[:]
            return
        # Merge all the other axes into a single one so that they can be
        # walked with one index. The merged axis is the one just before
        # `axis`, or the last axis when `axis` is 0.
        unfolded_axis = axis - 1
        new_shape = [1] * ndim
        new_shape[axis] = self.data.shape[axis]
        new_shape[unfolded_axis] = -1
        # Warning! if the data is not contiguous it will make a copy!!
        data = self.data.reshape(new_shape)
        for i in range(data.shape[unfolded_axis]):
            getitem = [0] * ndim
            getitem[axis] = slice(None)
            getitem[unfolded_axis] = i
            # Index with a tuple: NumPy does not accept a list mixing
            # integers and slices as a multidimensional index.
            yield data[tuple(getitem)]

    def sum(self, axis, return_signal=False):
        """Sum the data over the specified axis

        Parameters
        ----------
        axis : int
            The array axis over which the operation will be performed.
            Negative values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum_in_mask, mean

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.sum(-1)
        >>> s.data.shape
        (64, 64)
        >>> # If we just want to plot the result of the operation
        >>> s.sum(-1, True).plot()
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # Normalise the axis before the data changes shape. It is compared
        # below with the non-negative ``index_in_array`` of the remaining
        # axes, and a raw negative value (e.g. -1) would wrongly shift the
        # index of every remaining axis.
        axis = s._get_positive_axis_index_index(axis)
        s.data = s.data.sum(axis)
        # NOTE(review): assumes the axes list is ordered like the array
        # dimensions, so position ``axis`` describes the summed dimension.
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        # The dimensions after the removed one move one position to the left.
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def mean(self, axis, return_signal=False):
        """Average the data over the specified axis

        Parameters
        ----------
        axis : int
            The array axis over which the operation will be performed.
            Negative values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum_in_mask, sum

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.mean(-1)
        >>> s.data.shape
        (64, 64)
        >>> # If we just want to plot the result of the operation
        >>> s.mean(-1, True).plot()
        """
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        # Normalise the axis before the data changes shape. It is compared
        # below with the non-negative ``index_in_array`` of the remaining
        # axes, and a raw negative value (e.g. -1) would wrongly shift the
        # index of every remaining axis.
        axis = s._get_positive_axis_index_index(axis)
        s.data = s.data.mean(axis)
        # NOTE(review): assumes the axes list is ordered like the array
        # dimensions, so position ``axis`` describes the averaged dimension.
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        # The dimensions after the removed one move one position to the left.
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def copy(self):
        return (copy.copy(self))

    def deepcopy(self):
        """Return a deep copy of this object (``copy.deepcopy``).

        Attributes are copied recursively, so the result shares no mutable
        state with the original.
        """
        deep_duplicate = copy.deepcopy(self)
        return deep_duplicate

#    def sum_in_mask(self, mask):
#        """Returns the result of summing all the spectra in the mask.
#
#        Parameters
#        ----------
#        mask : boolean numpy array
#
#        Returns
#        -------
#        Spectrum
#        """
#        dc = self.data_cube.copy()
#        mask3D = mask.reshape([1,] + list(mask.shape)) * np.ones(dc.shape)
#        dc = (mask3D*dc).sum(1).sum(1) / mask.sum()
#        s = Spectrum()
#        s.data_cube = dc.reshape((-1,1,1))
#        s.get_dimensions_from_cube()
#        utils.copy_energy_calibration(self,s)
#        return s
#
#    def mean(self, axis):
#        """Average the SI over the given axis
#
#        Parameters
#        ----------
#        axis : int
#        """
#        dc = self.data_cube
#        dc = dc.mean(axis)
#        dc = dc.reshape(list(dc.shape) + [1,])
#        self.data_cube = dc
#        self.get_dimensions_from_cube()
#
#    def roll(self, axis = 2, shift = 1):
#        """Roll the SI. see numpy.roll
#
#        Parameters
#        ----------
#        axis : int
#        shift : int
#        """
#        self.data_cube = np.roll(self.data_cube, shift, axis)
#        self._replot()
#

#
#    def get_calibration_from(self, s):
#        """Copy the calibration from another Spectrum instance
#        Parameters
#        ----------
#        s : spectrum instance
#        """
#        utils.copy_energy_calibration(s, self)
#
#    def estimate_variance(self, dc = None, gaussian_noise_var = None):
#        """Variance estimation supposing Poissonian noise
#
#        Parameters
#        ----------
#        dc : None or numpy array
#            If None the SI is used to estimate its variance. Otherwise, the
#            provided array will be used.
#        Note
#        ----
#        The gain_factor and gain_offset from the aquisition parameters are used
#        """
#        print "Variace estimation using the following values:"
#        print "Gain factor = ", self.acquisition_parameters.gain_factor
#        print "Gain offset = ", self.acquisition_parameters.gain_offset
#        if dc is None:
#            dc = self.data_cube
#        gain_factor = self.acquisition_parameters.gain_factor
#        gain_offset = self.acquisition_parameters.gain_offset
#        self.variance = dc*gain_factor + gain_offset
#        if self.variance.min() < 0:
#            if gain_offset == 0 and gaussian_noise_var is None:
#                print "The variance estimation results in negative values"
#                print "Maybe the gain_offset is wrong?"
#                self.variance = None
#                return
#            elif gaussian_noise_var is None:
#                print "Clipping the variance to the gain_offset value"
#                self.variance = np.clip(self.variance, np.abs(gain_offset),
#                np.Inf)
#            else:
#                print "Clipping the variance to the gaussian_noise_var"
#                self.variance = np.clip(self.variance, gaussian_noise_var,
#                np.Inf)
#
#    def calibrate(self, lcE = 642.6, rcE = 849.7, lc = 161.9, rc = 1137.6,
#    modify_calibration = True):
#        dispersion = (rcE - lcE) / (rc - lc)
#        origin = lcE - dispersion * lc
#        print "Energy step = ", dispersion
#        print "Energy origin = ", origin
#        if modify_calibration is True:
#            self.set_new_calibration(origin, dispersion)
#        return origin, dispersion
#

    def _correct_navigation_mask_when_unfolded(self, navigation_mask=None):
        """Return `navigation_mask` flattened to one dimension.

        A None mask is returned unchanged.

        NOTE(review): a commented-out ``'unfolded' in self.history`` check
        suggests this was meant to apply only to unfolded signals. As written,
        any mask is flattened unconditionally.
        """
        if navigation_mask is None:
            return None
        return navigation_mask.reshape((-1, ))