예제 #1
0
class FETS2D3U1M(FETSEval):
    r'''Triangular, three-node element.

    Declares the VTK export metadata, a one-point integration rule and the
    constant derivatives of the linear shape functions of the element.
    '''

    # =========================================================================
    # VTK export metadata
    # =========================================================================
    # ``np.float64`` replaces the alias ``np.float_`` that was removed in
    # NumPy 2.0; both names denote the same dtype on older NumPy releases.
    vtk_r = tr.Array(np.float64, value=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    vtk_cells = [[0, 1, 2]]
    vtk_cell_types = 'Triangle'
    vtk_cell = [0, 1, 2]
    vtk_cell_type = 'Triangle'

    # This class used to bind ``vtk_expand_operator`` twice; only the last
    # binding was ever effective, so that one is kept and the dead one dropped.
    vtk_expand_operator = tr.Array(value=[1, 1, 0])
    vtk_node_cell_data = tr.Array
    vtk_ip_cell_data = tr.Array

    # =========================================================================
    # Surface integrals using numerical integration
    # =========================================================================
    # i is gauss point index, p is point coords in natural coords
    # (zeta_1, zeta_2, zeta_3)
    eta_ip = tr.Array(np.float64)
    r'''Integration points within a triangle.
    '''

    def _eta_ip_default(self):
        # Just one integration point in the middle of the triangle
        # (zeta_1 = 1/3, zeta_2 = 1/3, zeta_3 = 1/3).  Built in double
        # precision so that 1/3 is not truncated to float32.
        return np.array([[1. / 3., 1. / 3., 1. / 3.]], dtype=np.float64)

    w_m = tr.Array(np.float64)
    r'''Weight factors for numerical integration.
    '''

    def _w_m_default(self):
        return np.array([1. / 2.], dtype=np.float64)

    n_m = tr.Property(depends_on='w_m')
    r'''Number of integration points (derived from the number of weights).
    '''

    @tr.cached_property
    def _get_n_m(self):
        # Previously this getter was named ``_get_w_m`` and therefore never
        # invoked, while ``n_m`` was a hard-coded ``Int(1)``.
        return len(self.w_m)

    n_nodal_dofs = tr.Int(3)

    # N_im = tr.Property(depends_on='eta_ip')
    # r'''Shape function values in integration points.
    # '''
    # @tr.cached_property
    # def _get_N_im(self):
    #     eta = self.eta_ip
    #     return np.array([eta[:, 0], eta[:, 1], 1 - eta[:, 0] - eta[:, 1]],
    #                     dtype='f')

    dN_imr = tr.Property(depends_on='eta_ip')
    r'''Derivatives of the shape functions in the integration points.
    '''

    @tr.cached_property
    def _get_dN_imr(self):
        # The derivatives are constant, hence a literal table with a single
        # leading entry (axis ``m``), two rows (axis ``r``) and three
        # columns (axis ``i``).
        dN_mri = np.array(
            [
                [
                    [1, 0, -1],  # dN1/d_zeta1, dN2/d_zeta1, dN3/d_zeta1
                    [0, 1, -1],  # dN1/d_zeta2, dN2/d_zeta2, dN3/d_zeta2
                ],
            ],
            dtype=np.float64)
        # Reorder the axes from (m, r, i) to (i, m, r).
        return np.einsum('mri->imr', dN_mri)

    dN_inr = tr.Property(depends_on='eta_ip')
    r'''Derivatives of the shape functions in the integration points
    (alias of ``dN_imr``).
    '''

    @tr.cached_property
    def _get_dN_inr(self):
        return self.dN_imr
예제 #2
0
class ModelInteract(tr.HasTraits):
    '''Interactive widget front-end for a list of models.

    Every trait tagged with ``interact=True`` gets a slider; the slider
    values are forwarded to the ``PlotModel`` wrappers of ``models`` which
    draw into the matplotlib axes created by ``init_fields`` or
    ``init_geometry``.
    '''

    models = tr.List([
    ])

    # Names of the model parameters.  All methods below skip the first entry
    # (``py_vars[1:]``) when collecting the parameter values.
    py_vars = tr.List(tr.Str)
    # Maps a python variable name to an object whose ``name`` attribute is
    # used in the slider label.
    map_py2sp = tr.Dict

    d = tr.Float(0.03, GEO=True)
    h = tr.Float(0.8, GEO=True)

    # define the free parameters as traits with default, min and max values
    w_max = tr.Float(1.0)
    t = tr.Float(0.0001, min=1e-5, max=1)
    tau = tr.Float(0.5, interact=True)
    L_b = tr.Float(200, interact=True)
    E_f = tr.Float(100000, interact=True)
    A_f = tr.Float(20, interact=True)
    p = tr.Float(40, interact=True)
    E_m = tr.Float(26000, interact=True)
    A_m = tr.Float(100, interact=True)

    n_steps = tr.Int(50)

    sliders = tr.Property

    @tr.cached_property
    def _get_sliders(self):
        '''Return a dict of sliders: ``'t'`` plus one per ``interact`` trait.

        The upper bound of a parameter slider is the ``max`` metadata of the
        trait when it is defined, otherwise twice the current value.
        '''
        traits = self.traits(interact=True)
        vals = self.trait_get(interact=True)
        sliders = {'t': ipw.FloatSlider(1e-5, min=1e-5, max=1, step=0.05,
                                        description=r'\(t\)')}
        for name, trait in traits.items():
            # The metadata has to be read from the individual trait; asking
            # the dict of traits (as done before) always yields the fallback.
            upper = getattr(trait, 'max', None)
            if upper is None:
                upper = vals[name] * 2
            sliders[name] = ipw.FloatSlider(
                value=vals[name],
                min=1e-5,
                max=upper,
                step=upper / self.n_steps,
                description=r'\(%s\)' % self.map_py2sp[name].name)
        return sliders

    w_range = tr.Property(tr.Array(np.float64), depends_on='w_max')

    @tr.cached_property
    def _get_w_range(self):
        return np.linspace(0, self.w_max, 50)

    x_range = tr.Property(tr.Array(np.float64), depends_on='L_b')

    @tr.cached_property
    def _get_x_range(self):
        return np.linspace(-self.L_b, 0, 100)

    model_plots = tr.Property(tr.List)

    @tr.cached_property
    def _get_model_plots(self):
        '''Wrap every entry of ``models`` in a ``PlotModel``.'''
        return [PlotModel(itr=self, model=m) for m in self.models]

    def _get_params(self, values):
        '''Return the values of ``py_vars[1:]`` taken from *values*, in order.'''
        return [values[py_var] for py_var in self.py_vars[1:]]

    #===========================================================================
    # Interaction on the field plots
    #===========================================================================
    def init_fields(self):
        '''Create the 2x2 figure and let each model plot initialise itself.'''
        self.fig, ((self.ax_po, self.ax_u), (self.ax_eps, self.ax_tau)) = \
            plt.subplots(2, 2, figsize=(9, 5), tight_layout=True)
        params = self._get_params(self.trait_get(interact=True))
        for mp in self.model_plots:
            mp.init_fields(*params)
            mp.init_Pw(*params)
        self.ax_po.set_xlim(0, self.w_max * 1.05)

    def clear_fields(self):
        clear_plot(self.ax_po, self.ax_u, self.ax_eps, self.ax_tau)

    def update_fields(self, t, **values):
        '''Redraw the field plots.

        Parameters
        ----------
        t : float
            Factor applied to ``w_max`` to obtain the current ``w``.
        **values
            Parameter values keyed by trait name; they are stored on the
            object and forwarded to the model plots.
        '''
        w = t * self.w_max
        self.trait_set(**values)
        params = self._get_params(values)
        L_b = self.L_b
        self.clear_fields()
        for mp in self.model_plots:
            mp.update_fields(w, *params)
            mp.update_Pw(w, *params)

        # Rescale all axes to the extreme values reported by the model plots.
        P_max = np.max(np.array([m.P_max for m in self.model_plots]))
        self.ax_po.set_ylim(0, P_max * 1.05)
        self.ax_po.set_xlim(0, self.w_max * 1.05)
        u_min = np.min(np.array([m.u_min for m in self.model_plots]))
        # The extra entry 1 keeps the upper limit at 1 or above.
        u_max = np.max(np.array([m.u_max for m in self.model_plots] + [1]))
        self.ax_u.set_ylim(u_min, u_max * 1.1)
        self.ax_u.set_xlim(xmin=-1.05 * L_b, xmax=0.05 * L_b)
        eps_min = np.min(np.array([m.eps_min for m in self.model_plots]))
        eps_max = np.max(np.array([m.eps_max for m in self.model_plots]))
        self.ax_eps.set_ylim(eps_min, eps_max * 1.1)
        self.ax_eps.set_xlim(xmin=-1.05 * L_b, xmax=0.05 * L_b)
        self.ax_tau.set_ylim(0, self.tau * 1.1)
        self.ax_tau.set_xlim(xmin=-1.05 * L_b, xmax=0.05 * L_b)
        self.fig.canvas.draw_idle()

    def set_w_max_fields(self, w_max):
        '''Store *w_max* and redraw the field plots with the slider values.'''
        self.w_max = w_max
        values = {name: slider.value for name, slider in self.sliders.items()}
        self.update_fields(**values)

    def interact_fields(self):
        '''Set up the field plots and display the interactive widgets.'''
        self.init_fields()
        self.on_w_max_change = self.update_fields
        sliders = self.sliders
        out = ipw.interactive_output(self.update_fields, sliders)
        self.widget_layout(out)

    #===========================================================================
    # Interaction on the pull-out curve spatial plot
    #===========================================================================
    def init_geometry(self):
        '''Create the two-panel figure and the empty patches of each model.'''
        self.fig, (self.ax_po, self.ax_geo) = plt.subplots(
            1, 2, figsize=(8, 3.4))  # , tight_layout=True)
        params = self._get_params(self.trait_get(interact=True))
        h = self.h
        x_C = np.array([[-1, 0], [0, 0], [0, h], [-1, h]], dtype=np.float64)
        self.line_C, = self.ax_geo.fill(*x_C.T, color='gray', alpha=0.3)
        for mp in self.model_plots:
            mp.line_aw, = self.ax_geo.fill([], [], color='white', alpha=1)
            mp.line_F, = self.ax_geo.fill([], [], color='black', alpha=0.8)
            mp.line_F0, = self.ax_geo.fill([], [], color='white', alpha=1)
            mp.init_Pw(*params)
        self.ax_po.set_xlim(0, self.w_max * 1.05)

    def clear_geometry(self):
        clear_plot(self.ax_po, self.ax_geo)

    def update_geometry(self, t, **values):
        '''Redraw the geometry plot.

        Parameters
        ----------
        t : float
            Factor applied to ``w_max`` to obtain the current ``w``.
        **values
            Parameter values keyed by trait name; they are stored on the
            object and forwarded to the models.
        '''
        w = t * self.w_max
        self.clear_geometry()
        self.trait_set(**values)
        params = self._get_params(values)
        h = self.h
        d = self.d
        L_b = self.L_b
        # Vertical extent of the strip of height d centred at h / 2.
        f_top = h / 2 + d / 2
        f_bot = h / 2 - d / 2
        self.ax_geo.set_xlim(xmin=-1.05 * L_b,
                             xmax=max(0.05 * L_b, 1.1 * self.w_max))
        x_C = np.array([[-L_b, 0], [0, 0], [0, h], [-L_b, h]],
                       dtype=np.float64)
        self.line_C.set_xy(x_C)
        for mp in self.model_plots:
            a_val = mp.model.get_aw_pull(w, *params)
            width = d * 0.5
            x_a = np.array([[a_val, f_bot - width], [0, f_bot - width],
                            [0, f_top + width], [a_val, f_top + width]],
                           dtype=np.float64)
            mp.line_aw.set_xy(x_a)

            w_L_b = mp.model.get_w_L_b(w, *params)
            x_F = np.array([[-L_b + w_L_b, f_bot], [w, f_bot],
                            [w, f_top], [-L_b + w_L_b, f_top]],
                           dtype=np.float64)
            mp.line_F.set_xy(x_F)
            x_F0 = np.array([[-L_b, f_bot], [-L_b + w_L_b, f_bot],
                             [-L_b + w_L_b, f_top], [-L_b, f_top]],
                            dtype=np.float64)
            mp.line_F0.set_xy(x_F0)

            mp.update_Pw(w, *params)

        P_max = np.max(np.array([mp.P_max for mp in self.model_plots]))
        self.ax_po.set_ylim(0, P_max * 1.1)
        self.ax_po.set_xlim(0, self.w_max * 1.05)
        self.fig.canvas.draw_idle()

    def set_w_max(self, w_max):
        '''Store *w_max* and call ``on_w_max_change`` with the slider values.'''
        self.w_max = w_max
        values = {name: slider.value for name, slider in self.sliders.items()}
        self.on_w_max_change(**values)

    # Callback used by ``set_w_max``; assigned by ``interact_fields`` and
    # ``interact_geometry``.
    on_w_max_change = tr.Callable

    def interact_geometry(self):
        '''Set up the geometry plot and display the interactive widgets.'''
        self.init_geometry()
        self.on_w_max_change = self.update_geometry
        sliders = self.sliders
        out = ipw.interactive_output(self.update_geometry, sliders)
        self.widget_layout(out)

    def widget_layout(self, out):
        '''Arrange the sliders, the ``w_max`` text field and *out*; display them.'''
        sliders = self.sliders
        layout = ipw.Layout(grid_template_columns='1fr 1fr')
        param_sliders_list = [sliders[py_var] for py_var in self.py_vars[1:]]
        t_slider = sliders['t']
        grid = ipw.GridBox(param_sliders_list, layout=layout)
        w_max_text = ipw.FloatText(
            value=self.w_max,
            description=r'w_max',
            disabled=False
        )
        out_w_max = ipw.interactive_output(self.set_w_max,
                                           {'w_max': w_max_text})

        hbox = ipw.HBox([t_slider, w_max_text])
        box = ipw.VBox([hbox, grid, out, out_w_max])
        display(box)
예제 #3
0
class config(BaseWorkflowConfig):
    """Configuration traits of the workflow.

    Groups the directory layout, subject selection, field-map, motion
    correction, artifact detection, smoothing, CompCor, filtering and
    advanced options, plus two buttons that check whether the input files
    can be found.
    """
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc='Workflow description')
    # Directories
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    # NOTE: the multi-line descriptions below are built by implicit string
    # concatenation.  A backslash continuation inside the literal would embed
    # the indentation of the following source line in the help text.
    field_dir = Directory(
        desc="Base directory of field-map data (Should be "
             "subject-independent) "
             "Set this value to None if you don't want fieldmap distortion "
             "correction")
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")

    # Subjects

    subjects = traits.List(
        traits.Str,
        mandatory=True,
        usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the "
             "Freesurfer directory. For simplicity, the subject id's should "
             "also match with the location of individual functional files.")
    func_template = traits.String('%s/functional.nii.gz')
    run_datagrabber_without_submitting = traits.Bool(
        desc="Run the datagrabber without submitting to the cluster")
    timepoints_to_remove = traits.Int(0, usedefault=True)

    # Fieldmap

    use_fieldmap = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='True to include fieldmap distortion correction. Note: '
             'field_dir must be specified')
    magnitude_template = traits.String('%s/magnitude.nii.gz')
    phase_template = traits.String('%s/phase.nii.gz')
    TE_diff = traits.Float(desc='difference in B0 field map TEs')
    sigma = traits.Int(
        2, desc='2D spatial gaussing smoothing stdev (default = 2mm)')
    echospacing = traits.Float(desc="EPI echo spacing")

    # Motion Correction

    do_slicetiming = Bool(True,
                          usedefault=True,
                          desc="Perform slice timing correction")
    SliceOrder = traits.List(traits.Int)
    TR = traits.Float(1.0, mandatory=True, desc="TR of functional")
    motion_correct_node = traits.Enum(
        'nipy',
        'fsl',
        'spm',
        'afni',
        desc="motion correction algorithm to use",
        usedefault=True,
    )
    loops = traits.List([5], traits.Int(5), usedefault=True)
    #between_loops = traits.Either("None",traits.List([5]),usedefault=True)
    speedup = traits.List([5], traits.Int(5), usedefault=True)
    # Artifact Detection

    norm_thresh = traits.Float(1,
                               min=0,
                               usedefault=True,
                               desc="norm thresh for art")
    z_thresh = traits.Float(3, min=0, usedefault=True, desc="z thresh for art")

    # Smoothing
    fwhm = traits.List(
        [0, 5],
        traits.Float(),
        mandatory=True,
        usedefault=True,
        desc="Full width at half max. The data will be smoothed at all "
             "values specified in this list.")
    smooth_type = traits.Enum("susan",
                              "isotropic",
                              'freesurfer',
                              usedefault=True,
                              desc="Type of smoothing to use")
    surface_fwhm = traits.Float(
        0.0,
        desc='surface smoothing kernel, if freesurfer is selected',
        usedefault=True)

    # CompCor
    compcor_select = traits.BaseTuple(
        traits.Bool,
        traits.Bool,
        mandatory=True,
        desc="The first value in the list corresponds to applying "
             "t-compcor, and the second value to a-compcor. Note: "
             "both can be true")
    num_noise_components = traits.Int(
        6,
        usedefault=True,
        desc="number of principle components of the noise to use")
    regress_before_PCA = traits.Bool(True)
    # Highpass Filter
    hpcutoff = traits.Float(128., desc="highpass cutoff", usedefault=True)

    #zscore
    do_zscore = Bool(False)

    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()
    debug = traits.Bool(False)

    # Buttons
    check_func_datagrabber = Button("Check")
    check_field_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        """Report, subject by subject, whether the functional file exists.

        The file name is ``func_template % subject`` below ``base_dir``.
        The check stops at the first missing file.
        """
        for s in self.subjects:
            path = os.path.join(self.base_dir, self.func_template % s)
            if not os.path.exists(path):
                # A single pre-formatted argument prints the same text under
                # Python 2 and Python 3.
                print("ERROR %s does NOT exist!" % path)
                break
            print("%s exists!" % path)

    def _check_field_datagrabber_fired(self):
        """Report whether the magnitude and phase files of each subject exist.

        Both files are looked up below ``field_dir``.  The check stops at
        the first missing file.
        """
        for s in self.subjects:
            missing = False
            for template in (self.magnitude_template, self.phase_template):
                # Report the very path that was tested; the success message
                # used to be built from ``base_dir`` instead of ``field_dir``.
                path = os.path.join(self.field_dir, template % s)
                if not os.path.exists(path):
                    print("ERROR: %s does NOT exist!" % path)
                    missing = True
                    break
                print("%s exists!" % path)
            if missing:
                break
예제 #4
0
class PermuteTimeSeriesInputSpec(TraitedSpec):
    """Input specification with a mandatory source volume and an integer id."""

    # Mandatory input file; ``exists=True`` is passed to the ``File`` trait.
    original_volume = File(
        mandatory=True,
        exists=True,
        desc="source volume for bootstrapping",
    )
    # Integer identifier.
    # NOTE(review): presumably tells the individual permutations apart -
    # confirm against the interface that uses this spec.
    id = traits.Int()
예제 #5
0
파일: axes.py 프로젝트: malliwi88/hyperspy
class AxesManager(t.HasTraits):
    """Contains and manages the data axes.

    It supports indexing, slicing, subscriptins and iteration. As an iterator,
    iterate over the navigation coordinates returning the current indices.
    It can only be indexed and sliced to access the DataAxis objects that it
    contains. Standard indexing and slicing follows the "natural order" as in
    Signal, i.e. [nX, nY, ...,sX, sY,...] where `n` indicates a navigation axis
    and `s` a signal axis. In addition AxesManager support indexing using
    complex numbers a + bj, where b can be one of 0, 1, 2 and 3 and a a valid
    index. If b is 3 AxesManager is indexed using the order of the axes in the
    array. If b is 1(2), indexes only the navigation(signal) axes in the
    natural order. In addition AxesManager supports subscription using
    axis name.

    Attributes
    ----------

    coordinates : tuple
        Get and set the current coordinates if the navigation dimension
        is not 0. If the navigation dimension is 0 it raises
        AttributeError when attempting to set its value.


    indices : tuple
        Get and set the current indices if the navigation dimension
        is not 0. If the navigation dimension is 0 it raises
        AttributeError when attempting to set its value.

    signal_axes, navigation_axes : list
        Contain the corresponding DataAxis objects

    Examples
    --------

    >>> %hyperspy
    HyperSpy imported!
    The following commands were just executed:
    ---------------
    import numpy as np
    import hyperspy.api as hs
    %matplotlib qt
    import matplotlib.pyplot as plt

    >>> # Create a spectrum with random data

    >>> s = hs.signals.Signal1D(np.random.random((2,3,4,5)))
    >>> s.axes_manager
    <Axes manager, axes: (<axis2 axis, size: 4, index: 0>, <axis1 axis, size: 3, index: 0>, <axis0 axis, size: 2, index: 0>, <axis3 axis, size: 5>)>
    >>> s.axes_manager[0]
    <axis2 axis, size: 4, index: 0>
    >>> s.axes_manager[3j]
    <axis0 axis, size: 2, index: 0>
    >>> s.axes_manager[1j]
    <axis2 axis, size: 4, index: 0>
    >>> s.axes_manager[2j]
    <axis3 axis, size: 5>
    >>> s.axes_manager[1].name="y"
    >>> s.axes_manager['y']
    <y axis, size: 3 index: 0>
    >>> for i in s.axes_manager:
    >>>     print(i, s.axes_manager.indices)
    (0, 0, 0) (0, 0, 0)
    (1, 0, 0) (1, 0, 0)
    (2, 0, 0) (2, 0, 0)
    (3, 0, 0) (3, 0, 0)
    (0, 1, 0) (0, 1, 0)
    (1, 1, 0) (1, 1, 0)
    (2, 1, 0) (2, 1, 0)
    (3, 1, 0) (3, 1, 0)
    (0, 2, 0) (0, 2, 0)
    (1, 2, 0) (1, 2, 0)
    (2, 2, 0) (2, 2, 0)
    (3, 2, 0) (3, 2, 0)
    (0, 0, 1) (0, 0, 1)
    (1, 0, 1) (1, 0, 1)
    (2, 0, 1) (2, 0, 1)
    (3, 0, 1) (3, 0, 1)
    (0, 1, 1) (0, 1, 1)
    (1, 1, 1) (1, 1, 1)
    (2, 1, 1) (2, 1, 1)
    (3, 1, 1) (3, 1, 1)
    (0, 2, 1) (0, 2, 1)
    (1, 2, 1) (1, 2, 1)
    (2, 2, 1) (2, 2, 1)
    (3, 2, 1) (3, 2, 1)

    """

    _axes = t.List(DataAxis)
    signal_axes = t.Tuple()
    navigation_axes = t.Tuple()
    _step = t.Int(1)

    def __init__(self, axes_list):
        super(AxesManager, self).__init__()
        self.events = Events()
        self.events.indices_changed = Event("""
            Event that triggers when the indices of the `AxesManager` changes

            Triggers after the internal state of the `AxesManager` has been
            updated.

            Arguments:
            ----------
            obj : The AxesManager that the event belongs to.
            """,
                                            arguments=['obj'])
        self.events.any_axis_changed = Event("""
            Event that trigger when the space defined by the axes transforms.

            Specifically, it triggers when one or more of the folloing
            attributes changes on one or more of the axes:
                `offset`, `size`, `scale`

            Arguments:
            ----------
            obj : The AxesManager that the event belongs to.
            """,
                                             arguments=['obj'])
        self.create_axes(axes_list)
        # set_signal_dimension is called only if there is no current
        # view. It defaults to spectrum
        navigates = [i.navigate for i in self._axes]
        if t.Undefined in navigates:
            # Default to Signal1D view if the view is not fully defined
            self.set_signal_dimension(len(axes_list))

        self._update_attributes()
        self._update_trait_handlers()
        self._index = None  # index for the iterator

    def _update_trait_handlers(self, remove=False):
        things = {
            self._on_index_changed: '_axes.index',
            self._on_slice_changed: '_axes.slice',
            self._on_size_changed: '_axes.size',
            self._on_scale_changed: '_axes.scale',
            self._on_offset_changed: '_axes.offset'
        }

        for k, v in things.items():
            self.on_trait_change(k, name=v, remove=remove)

    def _get_positive_index(self, axis):
        if axis < 0:
            axis += len(self._axes)
            if axis < 0:
                raise IndexError("index out of bounds")
        return axis

    def _array_indices_generator(self):
        shape = (self.navigation_shape[::-1] if self.navigation_size > 0 else [
            1,
        ])
        return np.ndindex(*shape)

    def _am_indices_generator(self):
        shape = (self.navigation_shape if self.navigation_size > 0 else [
            1,
        ])[::-1]
        return ndindex_nat(*shape)

    def __getitem__(self, y):
        """x.__getitem__(y) <==> x[y]

        """
        if isinstance(y, str) or not np.iterable(y):
            return self[(y, )][0]
        axes = [self._axes_getter(ax) for ax in y]
        _, indices = np.unique([_id for _id in map(id, axes)],
                               return_index=True)
        ans = tuple(axes[i] for i in sorted(indices))
        return ans

    def _axes_getter(self, y):
        if y in self._axes:
            return y
        if isinstance(y, str):
            axes = list(self._get_axes_in_natural_order())
            while axes:
                axis = axes.pop()
                if y == axis.name:
                    return axis
            raise ValueError("There is no DataAxis named %s" % y)
        elif (isfloat(y.real) and not y.real.is_integer()
              or isfloat(y.imag) and not y.imag.is_integer()):
            raise TypeError("axesmanager indices must be integers, "
                            "complex intergers or strings")
        if y.imag == 0:  # Natural order
            return self._get_axes_in_natural_order()[y]
        elif y.imag == 3:  # Array order
            # Array order
            return self._axes[int(y.real)]
        elif y.imag == 1:  # Navigation natural order
            #
            return self.navigation_axes[int(y.real)]
        elif y.imag == 2:  # Signal natural order
            return self.signal_axes[int(y.real)]
        else:
            raise IndexError("axesmanager imaginary part of complex indices "
                             "must be 0, 1, 2 or 3")

    def __getslice__(self, i=None, j=None):
        """x.__getslice__(i, j) <==> x[i:j]

        """
        return self._get_axes_in_natural_order()[i:j]

    def _get_axes_in_natural_order(self):
        return self.navigation_axes + self.signal_axes

    @property
    def _navigation_shape_in_array(self):
        return self.navigation_shape[::-1]

    @property
    def _signal_shape_in_array(self):
        return self.signal_shape[::-1]

    @property
    def shape(self):
        nav_shape = (self.navigation_shape if self.navigation_shape !=
                     (0, ) else tuple())
        sig_shape = (self.signal_shape if self.signal_shape !=
                     (0, ) else tuple())
        return nav_shape + sig_shape

    def remove(self, axes):
        """Remove one or more axes
        """
        axes = self[axes]
        if not np.iterable(axes):
            axes = (axes, )
        for ax in axes:
            self._remove_one_axis(ax)

    def _remove_one_axis(self, axis):
        """Remove the given Axis.

        Raises
        ------
        ValueError if the Axis is not present.

        """
        axis = self._axes_getter(axis)
        axis.axes_manager = None
        self._axes.remove(axis)

    def __delitem__(self, i):
        self.remove(self[i])

    def _get_data_slice(self, fill=None):
        """Return a tuple of slice objects to slice the data.

        Parameters
        ----------
        fill: None or iterable of (int, slice)
            If not None, fill the tuple of index int with the given
            slice.

        """
        cslice = [
            slice(None),
        ] * len(self._axes)
        if fill is not None:
            for index, slice_ in fill:
                cslice[index] = slice_
        return tuple(cslice)

    def create_axes(self, axes_list):
        """Given a list of dictionaries defining the axes properties
        create the DataAxis instances and add them to the AxesManager.

        The index of the axis in the array and in the `_axes` lists
        can be defined by the index_in_array keyword if given
        for all axes. Otherwise it is defined by their index in the
        list.

        See also
        --------
        _append_axis

        """
        # Reorder axes_list using index_in_array if it is defined
        # for all axes and the indices are not repeated.
        indices = set([
            axis['index_in_array'] for axis in axes_list
            if hasattr(axis, 'index_in_array')
        ])
        if len(indices) == len(axes_list):
            axes_list.sort(key=lambda x: x['index_in_array'])
        for axis_dict in axes_list:
            self._append_axis(**axis_dict)

    def _update_max_index(self):
        self._max_index = 1
        for i in self.navigation_shape:
            self._max_index *= i
        if self._max_index != 0:
            self._max_index -= 1

    def __next__(self):
        """
        Standard iterator method, updates the index and returns the
        current coordiantes

        Returns
        -------
        val : tuple of ints
            Returns a tuple containing the coordiantes of the current
            iteration.

        """
        if self._index is None:
            self._index = 0
            val = (0, ) * self.navigation_dimension
            self.indices = val
        elif self._index >= self._max_index:
            raise StopIteration
        else:
            self._index += 1
            val = np.unravel_index(self._index,
                                   tuple(
                                       self._navigation_shape_in_array))[::-1]
            self.indices = val
        return val

    def __iter__(self):
        # Reset the _index that can have a value != None due to
        # a previous iteration that did not hit a StopIteration
        self._index = None
        return self

    def _append_axis(self, *args, **kwargs):
        axis = DataAxis(*args, **kwargs)
        axis.axes_manager = self
        self._axes.append(axis)

    def _on_index_changed(self):
        self._update_attributes()
        self.events.indices_changed.trigger(obj=self)

    def _on_slice_changed(self):
        self._update_attributes()

    def _on_size_changed(self):
        self._update_attributes()
        self.events.any_axis_changed.trigger(obj=self)

    def _on_scale_changed(self):
        self.events.any_axis_changed.trigger(obj=self)

    def _on_offset_changed(self):
        self.events.any_axis_changed.trigger(obj=self)

    def update_axes_attributes_from(self,
                                    axes,
                                    attributes=["scale", "offset", "units"]):
        """Update the axes attributes to match those given.

        The axes are matched by their index in the array. The purpose of this
        method is to update multiple axes triggering `any_axis_changed` only
        once.

        Parameters
        ----------
        axes: iterable of `DataAxis` instances.
            The axes to copy the attributes from.
        attributes: iterable of strings.
            The attributes to copy.

        """

        # To only trigger once even with several changes, we suppress here
        # and trigger manually below if there were any changes.
        changes = False
        with self.events.any_axis_changed.suppress():
            for axis in axes:
                changed = self._axes[axis.index_in_array].update_from(
                    axis=axis, attributes=attributes)
                changes = changes or changed
        if changes:
            self.events.any_axis_changed.trigger(obj=self)

    def _update_attributes(self):
        getitem_tuple = []
        values = []
        self.signal_axes = ()
        self.navigation_axes = ()
        for axis in self._axes:
            # Until we find a better place, take property of the axes
            # here to avoid difficult to debug bugs.
            axis.axes_manager = self
            if axis.slice is None:
                getitem_tuple += axis.index,
                values.append(axis.value)
                self.navigation_axes += axis,
            else:
                getitem_tuple += axis.slice,
                self.signal_axes += axis,
        if not self.signal_axes and self.navigation_axes:
            getitem_tuple[-1] = slice(axis.index, axis.index + 1)

        self.signal_axes = self.signal_axes[::-1]
        self.navigation_axes = self.navigation_axes[::-1]
        self._getitem_tuple = tuple(getitem_tuple)
        self.signal_dimension = len(self.signal_axes)
        self.navigation_dimension = len(self.navigation_axes)
        if self.navigation_dimension != 0:
            self.navigation_shape = tuple(
                [axis.size for axis in self.navigation_axes])
        else:
            self.navigation_shape = ()

        if self.signal_dimension != 0:
            self.signal_shape = tuple([axis.size for axis in self.signal_axes])
        else:
            self.signal_shape = ()
        self.navigation_size = (np.cumprod(self.navigation_shape)[-1]
                                if self.navigation_shape else 0)
        self.signal_size = (np.cumprod(self.signal_shape)[-1]
                            if self.signal_shape else 0)
        self._update_max_index()

    def set_signal_dimension(self, value):
        """Set the dimension of the signal.

        Attributes
        ----------
        value : int

        Raises
        ------
        ValueError if value if greater than the number of axes or
        is negative

        """
        if len(self._axes) == 0:
            return
        elif value > len(self._axes):
            raise ValueError("The signal dimension cannot be greater"
                             " than the number of axes which is %i" %
                             len(self._axes))
        elif value < 0:
            raise ValueError("The signal dimension must be a positive integer")

        tl = [True] * len(self._axes)
        if value != 0:
            tl[-value:] = (False, ) * value

        for axis in self._axes:
            axis.navigate = tl.pop(0)

    def key_navigator(self, event):
        """Move through the navigation axes in response to a key event.

        Only active with one or two navigation axes. ``right``/``6`` and
        ``left``/``4`` step the first navigation axis; ``pageup`` and
        ``pagedown`` grow and shrink the step (never below 1); with two
        navigation axes ``up``/``8`` and ``down``/``2`` step the second
        one. A ``TraitError`` raised while assigning an index is ignored.

        Parameters
        ----------
        event : object
            Any object exposing the pressed key as ``event.key``.

        """
        n_nav = len(self.navigation_axes)
        if n_nav != 1 and n_nav != 2:
            return
        first_axis = self.navigation_axes[0]
        try:
            pressed = event.key
            if pressed in ("right", "6"):
                first_axis.index += self._step
            elif pressed in ("left", "4"):
                first_axis.index -= self._step
            elif pressed == "pageup":
                self._step += 1
            elif pressed == "pagedown":
                if self._step > 1:
                    self._step -= 1
            if n_nav == 2:
                second_axis = self.navigation_axes[1]
                if pressed in ("up", "8"):
                    second_axis.index -= self._step
                elif pressed in ("down", "2"):
                    second_axis.index += self._step
        except TraitError:
            pass

    def gui(self):
        """Open a traits editor for every axis in ``self._axes``.

        Each axis is edited through ``data_axis_view`` from
        ``hyperspy.gui.axes``; the import is local to this method.
        """
        from hyperspy.gui.axes import data_axis_view
        for axis in self._axes:
            axis.edit_traits(view=data_axis_view)

    def copy(self):
        """Return a shallow copy of this object made with ``copy.copy``."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy of this object made with ``copy.deepcopy``."""
        return copy.deepcopy(self)

    def __deepcopy__(self, *args):
        """Deep-copy hook: build a new ``AxesManager`` from the axis dictionaries.

        The extra positional arguments (the ``memo`` dict passed by
        ``copy.deepcopy``) are accepted but not used.
        """
        return AxesManager(self._get_axes_dicts())

    def _get_axes_dicts(self):
        """Return the axis dictionary of every axis, in ``self._axes`` order."""
        return [axis.get_axis_dictionary() for axis in self._axes]

    def as_dictionary(self):
        """Return a dict mapping ``'axis-<i>'`` to each axis dictionary.

        ``<i>`` is the position of the axis in ``self._axes``.
        """
        return {
            'axis-%i' % position: axis.get_axis_dictionary()
            for position, axis in enumerate(self._axes)
        }

    def _get_signal_axes_dicts(self):
        """Return the signal axis dictionaries in reversed ``signal_axes`` order."""
        reversed_signal_axes = self.signal_axes[::-1]
        axes_dicts = []
        for signal_axis in reversed_signal_axes:
            axes_dicts.append(signal_axis.get_axis_dictionary())
        return axes_dicts

    def _get_navigation_axes_dicts(self):
        """Return the navigation axis dictionaries in reversed
        ``navigation_axes`` order."""
        reversed_navigation_axes = self.navigation_axes[::-1]
        axes_dicts = []
        for navigation_axis in reversed_navigation_axes:
            axes_dicts.append(navigation_axis.get_axis_dictionary())
        return axes_dicts

    def show(self):
        """Open one traits dialog that groups the editors of all axes.

        Axes are taken in natural order; axis number ``n`` is registered in
        the view context under the key ``'axis<n>'`` and gets a group built
        by ``get_axis_group``.
        """
        from hyperspy.gui.axes import get_axis_group
        import traitsui.api as tui
        axis_groups = []
        view_context = {}
        for position, axis in enumerate(self._get_axes_in_natural_order()):
            axis_groups.append(get_axis_group(position, str(axis)))
            view_context['axis%i' % position] = axis
        self.edit_traits(view=tui.View(*axis_groups), context=view_context)

    def _get_dimension_str(self):
        """Return the shape summary ``"(nav sizes|signal sizes)"``.

        Sizes within each group are separated by ``", "``; an empty group
        contributes nothing, e.g. ``"(|5)"`` when there is no navigation
        axis.
        """
        navigation_part = ", ".join(
            str(axis.size) for axis in self.navigation_axes)
        signal_part = ", ".join(str(axis.size) for axis in self.signal_axes)
        return "(" + navigation_part + "|" + signal_part + ")"

    def __repr__(self):
        """Return a text table describing the navigation and signal axes.

        The first line holds the dimension summary, followed by a column
        header, a ``=`` rule, one row per navigation axis, a ``-`` rule and
        one row per signal axis (whose index column is left blank). Axis
        names are truncated to 16 characters.
        """
        header_fmt = "% 16s | %6s | %6s | %7s | %7s | %6s "
        axis_fmt = "% 16s | %6g | %6s | %7.2g | %7.2g | %6s "

        def rule(char):
            # Horizontal rule with the same column widths as the header.
            return header_fmt % (char * 16, char * 6, char * 6, char * 7,
                                 char * 7, char * 6)

        lines = ['<Axes manager, axes: %s>' % self._get_dimension_str()]
        lines.append(header_fmt % ('Name', 'size', 'index', 'offset',
                                   'scale', 'units'))
        lines.append(rule('='))
        for nav_axis in self.navigation_axes:
            lines.append(axis_fmt % (str(nav_axis.name)[:16], nav_axis.size,
                                     str(nav_axis.index), nav_axis.offset,
                                     nav_axis.scale, nav_axis.units))
        lines.append(rule('-'))
        for sig_axis in self.signal_axes:
            lines.append(axis_fmt % (str(sig_axis.name)[:16], sig_axis.size,
                                     ' ', sig_axis.offset, sig_axis.scale,
                                     sig_axis.units))
        return '\n'.join(lines)

    def _repr_html_(self):
        """Return an HTML description of the axes for rich displays.

        The output is a ``<style>`` block, a paragraph with the dimension
        summary, and then one table for the navigation axes and one for the
        signal axes; a table is emitted only when its group of axes is
        non-empty. The signal table has no ``index`` column.
        """

        def format_row(*args, tag='td', bold=False):
            # One <tr> whose cells wrap each argument in <tag>...</tag>.
            opening = "\n<tr class='bolder_row'> " if bold else "\n<tr> "
            cells = ['\n<{0}>{1}</{0}>'.format(tag, cell) for cell in args]
            return opening + " ".join(cells) + " </tr>"

        pieces = [
            "<style>\n"
            "table, th, td {\n\t"
            "border: 1px solid black;\n\t"
            "border-collapse: collapse;\n}"
            "\nth, td {\n\t"
            "padding: 5px;\n}"
            "\n</style>",
            '\n<p><b>< Axes manager, axes: %s ></b></p>\n' %
            self._get_dimension_str(),
        ]

        if self.navigation_axes:
            pieces.append("<table style='width:100%'>\n")
            pieces.append(format_row('Navigation axis name', 'size', 'index',
                                     'offset', 'scale', 'units', tag='th'))
            for nav_axis in self.navigation_axes:
                pieces.append(format_row(nav_axis.name, nav_axis.size,
                                         nav_axis.index, nav_axis.offset,
                                         nav_axis.scale, nav_axis.units))
            pieces.append("</table>\n")

        if self.signal_axes:
            pieces.append("<table style='width:100%'>\n")
            pieces.append(format_row('Signal axis name', 'size', 'offset',
                                     'scale', 'units', tag='th'))
            for sig_axis in self.signal_axes:
                pieces.append(format_row(sig_axis.name, sig_axis.size,
                                         sig_axis.offset, sig_axis.scale,
                                         sig_axis.units))
            pieces.append("</table>\n")

        return "".join(pieces)

    @property
    def coordinates(self):
        """Current value of every navigation axis.

        Returns
        -------
        tuple
            The ``value`` of each axis, in ``navigation_axes`` order.

        """
        return tuple(nav_axis.value for nav_axis in self.navigation_axes)

    @coordinates.setter
    def coordinates(self, coordinates):
        """Assign a value to every navigation axis.

        Parameters
        ----------
        coordinates : tuple
            One value per navigation axis, in ``navigation_axes`` order.

        Raises
        ------
        AttributeError
            If the number of coordinates differs from the navigation
            dimension.

        """
        expected_length = self.navigation_dimension
        if len(coordinates) != expected_length:
            raise AttributeError(
                "The number of coordinates must be equal to the "
                "navigation dimension that is %i" % expected_length)
        for new_value, nav_axis in zip(coordinates, self.navigation_axes):
            nav_axis.value = new_value

    @property
    def indices(self):
        """Current index of every navigation axis.

        Returns
        -------
        tuple
            The ``index`` of each axis, in ``navigation_axes`` order.

        """
        return tuple(nav_axis.index for nav_axis in self.navigation_axes)

    @indices.setter
    def indices(self, indices):
        """Assign an index to every navigation axis.

        Parameters
        ----------
        indices : tuple
            One index per navigation axis, in ``navigation_axes`` order.

        Raises
        ------
        AttributeError
            If the number of indices differs from the navigation dimension.

        """
        expected_length = self.navigation_dimension
        if len(indices) != expected_length:
            raise AttributeError("The number of indices must be equal to the "
                                 "navigation dimension that is %i" %
                                 expected_length)
        for new_index, nav_axis in zip(indices, self.navigation_axes):
            nav_axis.index = new_index

    def _get_axis_attribute_values(self, attr):
        """Return the attribute named ``attr`` of every axis in ``self._axes``.

        Parameters
        ----------
        attr : string
            Name of the axis attribute to read.

        Returns
        -------
        list
            One value per axis, in ``self._axes`` order.

        """
        collected = []
        for axis in self._axes:
            collected.append(getattr(axis, attr))
        return collected

    def _set_axis_attribute_values(self, attr, values):
        """Set the given attribute of all the axes to the given value(s).

        Parameters
        ----------
        attr : string
            The DataAxis attribute to set.
        values : any
            If iterable (as decided by ``isiterable``), it must yield exactly
            one item per axis in this AxesManager instance; items are
            assigned in ``self._axes`` order. If not iterable, the attribute
            of all the axes is set to the given value.

        Raises
        ------
        ValueError
            If ``values`` is iterable but its number of items differs from
            the number of axes.

        """
        n_axes = len(self._axes)
        if not isiterable(values):
            values = [values] * n_axes
        else:
            # Materialise first so that iterables without ``len`` (e.g.
            # generators) can be length-checked too.
            values = list(values)
            if len(values) != n_axes:
                raise ValueError("Values must have the same number of items "
                                 "as there are axes in this AxesManager")
        for axis, value in zip(self._axes, values):
            setattr(axis, attr, value)

    @property
    def navigation_indices_in_array(self):
        """Tuple with the ``index_in_array`` of each navigation axis."""
        return tuple(nav_axis.index_in_array
                     for nav_axis in self.navigation_axes)

    @property
    def signal_indices_in_array(self):
        """Tuple with the ``index_in_array`` of each signal axis."""
        return tuple(sig_axis.index_in_array for sig_axis in self.signal_axes)

    @property
    def axes_are_aligned_with_data(self):
        """Tell whether the axes are stored in data-aligned order.

        The axes are aligned when the order in `self._axes` is
        [nav_n, nav_n-1, ..., nav_0, sig_m, sig_m-1 ..., sig_0], i.e. when
        the reversed navigation array indices followed by the reversed
        signal array indices read ``0, 1, 2, ...``.

        Returns
        -------
        aligned : bool

        """
        array_positions = (self.navigation_indices_in_array[::-1] +
                           self.signal_indices_in_array[::-1])
        expected_positions = tuple(range(len(array_positions)))
        return array_positions == expected_positions

    def _sort_axes(self):
        """Reorder ``self._axes`` into the data-aligned order.

        After the call `self._axes` is
        [nav_n, nav_n-1, ..., nav_0, sig_m, sig_m-1 ..., sig_0], i.e. the
        reversed navigation axes followed by the reversed signal axes.
        Warning: only the axes list is reordered, the `data` axes are not.

        """
        reversed_navigation = self.navigation_axes[::-1]
        reversed_signal = self.signal_axes[::-1]
        self._axes = list(reversed_navigation + reversed_signal)
예제 #6
0
class BackgroundRemoval(SpanSelectorInSpectrum):
    """Span-selector tool that fits a background model and subtracts it.

    The background component is chosen through ``background_type``; its
    parameters are estimated from the signal between ``ss_left_value`` and
    ``ss_right_value`` and the resulting curve is drawn as an extra line on
    the signal plot until ``apply`` subtracts it from the data.
    """
    # Model used for the background; drives ``set_background_estimator``.
    background_type = t.Enum('Power Law',
                             'Gaussian',
                             'Offset',
                             'Polynomial',
                             default='Power Law')
    # Order handed to ``components.Polynomial``.
    polynomial_order = t.Range(1, 10)
    # Component instance whose parameters are estimated from the span.
    background_estimator = t.Instance(Component)
    # Portion of the axis over which ``bg_to_plot`` evaluates the model.
    bg_line_range = t.Enum('from_left_range',
                           'full',
                           'ss_range',
                           default='full')
    hi = t.Int(0)
    view = tu.View(tu.Group(
        'background_type',
        tu.Group('polynomial_order',
                 visible_when='background_type == \'Polynomial\''),
    ),
                   buttons=[OKButton, CancelButton],
                   handler=SpanSelectorInSpectrumHandler,
                   title='Background removal tool')

    def __init__(self, signal):
        """Attach the tool to ``signal`` and create the default estimator."""
        super(BackgroundRemoval, self).__init__(signal)
        self.set_background_estimator()
        self.bg_line = None

    def on_disabling_span_selector(self):
        """Remove the background line, if any, from the plot."""
        if self.bg_line is not None:
            self.bg_line.close()
            self.bg_line = None

    def set_background_estimator(self):
        """Instantiate the component matching ``background_type``.

        Also selects the plotting range: a power law is drawn from the left
        edge of the span onwards, every other model over the full axis.
        """

        if self.background_type == 'Power Law':
            self.background_estimator = components.PowerLaw()
            self.bg_line_range = 'from_left_range'
        elif self.background_type == 'Gaussian':
            self.background_estimator = components.Gaussian()
            self.bg_line_range = 'full'
        elif self.background_type == 'Offset':
            self.background_estimator = components.Offset()
            self.bg_line_range = 'full'
        elif self.background_type == 'Polynomial':
            self.background_estimator = \
                components.Polynomial(self.polynomial_order)
            self.bg_line_range = 'full'

    def _polynomial_order_changed(self, old, new):
        self.background_estimator = components.Polynomial(new)
        self.span_selector_changed()

    def _background_type_changed(self, old, new):
        self.set_background_estimator()
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        self.span_selector_changed()

    def create_background_line(self):
        """Create the blue background line and add it to the signal plot."""
        self.bg_line = drawing.spectrum.SpectrumLine()
        self.bg_line.data_function = self.bg_to_plot
        self.bg_line.set_line_properties(color='blue', type='line')
        self.signal._plot.signal_plot.add_line(self.bg_line)
        self.bg_line.autoscale = False
        self.bg_line.plot()

    def bg_to_plot(self, axes_manager=None, fill_with=np.nan):
        """Return the estimated background sampled on the signal axis.

        Parameters
        ----------
        axes_manager : optional
            Accepted for compatibility with the line ``data_function``
            call signature; not used here.
        fill_with : float
            Value written where the background is not evaluated (outside
            the range selected by ``bg_line_range``).

        Returns
        -------
        numpy array with the same shape as ``self.axis.axis``.
        """
        # First try to update the estimation
        self.background_estimator.estimate_parameters(self.signal,
                                                      self.ss_left_value,
                                                      self.ss_right_value,
                                                      only_current=True)

        if self.bg_line_range == 'from_left_range':
            bg_array = np.zeros(self.axis.axis.shape)
            bg_array[:] = fill_with
            from_index = self.axis.value2index(self.ss_left_value)
            bg_array[from_index:] = self.background_estimator.function(
                self.axis.axis[from_index:])
            return bg_array
        elif self.bg_line_range == 'full':
            return self.background_estimator.function(self.axis.axis)
        elif self.bg_line_range == 'ss_range':
            bg_array = np.zeros(self.axis.axis.shape)
            bg_array[:] = fill_with
            from_index = self.axis.value2index(self.ss_left_value)
            to_index = self.axis.value2index(self.ss_right_value)
            # Fill only the selected span (the target slice must match the
            # evaluated slice) and hand the array back to the caller.
            bg_array[from_index:to_index] = \
                self.background_estimator.function(
                    self.axis.axis[from_index:to_index])
            return bg_array

    def span_selector_changed(self):
        """Create the background line on first success, else refresh it."""
        if self.background_estimator is None:
            print("No bg estimator")
            return
        if self.bg_line is None:
            # No line yet: only draw one once the estimation succeeds.
            # When it fails there is nothing to update.
            if self.background_estimator.estimate_parameters(
                    self.signal, self.ss_left_value, self.ss_right_value,
                    only_current=True) is True:
                self.create_background_line()
        else:
            self.bg_line.update()

    def apply(self):
        """Subtract the estimated background from every spectrum.

        Plot auto-update is switched off while iterating over the signal
        and the plot is redrawn once at the end. For a power law the
        channels below the right edge of the span are set to zero.
        """
        self.signal._plot.auto_update_plot = False
        maxval = self.signal.axes_manager.navigation_size
        if maxval > 0:
            pbar = progressbar(maxval=maxval)
        i = 0
        self.bg_line_range = 'full'
        for s in self.signal:
            s.data[:] -= \
            np.nan_to_num(self.bg_to_plot(self.signal.axes_manager,
                                          0))
            if self.background_type == 'Power Law':
                s.data[:self.axis.value2index(self.ss_right_value)] = 0

            i += 1
            if maxval > 0:
                pbar.update(i)
        if maxval > 0:
            pbar.finish()

        self.signal._replot()
        self.signal._plot.auto_update_plot = True
class FETS2DMITC(FETSEval):
    r'''MITC3 shell finite element:
        See https://www.sesamx.io/blog/standard_linear_triangular_shell_element/
        See http://dx.doi.org/10.1016/j.compstruc.2014.02.005
    '''

    vtk_r = tr.Array(np.float_, value=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    vtk_cells = [[0, 1, 2]]
    vtk_cell_types = 'Triangle'
    vtk_cell = [0, 1, 2]
    vtk_cell_type = 'Triangle'

    # =========================================================================
    # Surface integrals using numerical integration
    # =========================================================================
    # i is point indices, p is point coords in natural coords (zeta_1, zeta_2, zeta_3)
    eta_ip = tr.Array('float_')
    r'''Integration points within a triangle.
    '''
    def _eta_ip_default(self):
        # Each row is [r, s, t]: r, s are the natural in-plane coordinates
        # of the shell element and t runs along the thickness.
        # The default rule uses a single in-plane location
        # (r = s = 1/3) combined with two points along the thickness at
        # t = -sqrt(1/3) and t = +sqrt(1/3).
        return np.array([[1. / 3., 1. / 3., -np.sqrt(1 / 3)],
                         [1. / 3., 1. / 3., np.sqrt(1 / 3)]],
                        dtype='f')

    w_m = tr.Array('float_')
    r'''Weight factors for numerical integration.
    '''

    def _w_m_default(self):
        # One weight per row of the default ``eta_ip``.
        return np.array([1. / 2., 1. / 2.], dtype='f')

    # TODO, different node thickness in each node according to the original
    #  implementation can be easily integrated, but here the same thickness
    #  is used for simplicity.
    a = tr.Float(1.0, label='thickness')

    n_m = tr.Property(depends_on='w_m')
    r'''Number of integration points.
    '''

    @tr.cached_property
    def _get_n_m(self):
        return len(self.w_m)

    n_nodal_dofs = tr.Int(5)

    N_im = tr.Property(depends_on='eta_ip')
    r'''Shape function values in integration points.
    '''

    @tr.cached_property
    def _get_N_im(self):
        # Shouldn't be mi instead of im?
        # The returned array has one row per shape function and one column
        # per integration point.
        eta = self.eta_ip
        return np.array([eta[:, 0], eta[:, 1], 1 - eta[:, 0] - eta[:, 1]],
                        dtype='f')

    dh_imr = tr.Property(depends_on='eta_ip')
    r'''Derivatives of the shape functions in the integration points.
    '''

    @tr.cached_property
    def _get_dh_imr(self):
        # Same for all Gauss points
        dh_ri = np.array(
            [
                [1, 0, -1],  # dh1/d_r, dh2/d_r, dh3/d_r
                [0, 1, -1],  # dh1/d_s, dh2/d_s, dh3/d_s
                [0, 0, 0]
            ],  # dh1/d_t, dh2/d_t, dh3/d_t
            dtype=np.float_)
        # Repeat the constant derivative matrix once per integration point.
        dh_mri = np.tile(dh_ri, (self.n_m, 1, 1))

        return np.einsum('mri->imr', dh_mri)

    dht_imr = tr.Property(depends_on='eta_ip')
    r'''Derivatives of the (shape functions * t) in the integration points.
    '''

    @tr.cached_property
    def _get_dht_imr(self):
        # m: gauss points, r: r, s, t, and i: h1, h2, h3
        eta_ip = self.eta_ip
        dh_mri = np.array(
            [
                [
                    [t, 0, -t],  # (t*dh1)/d_r, (t*dh2)/d_r, (t*dh3)/d_r
                    [0, t, -t],  # (t*dh1)/d_s, (t*dh2)/d_s, (t*dh3)/d_s
                    [r, s, 1 - r - s]
                ]  # (t*dh1)/d_t, (t*dh2)/d_t, (t*dh3)/d_t
                for r, s, t in zip(eta_ip[:, 0], eta_ip[:, 1], eta_ip[:, 2])
            ],
            dtype=np.float_)
        return np.einsum('mri->imr', dh_mri)

    dh_inr = tr.Property(depends_on='eta_ip')
    r'''Derivatives of the shape functions in the integration points.
    '''

    @tr.cached_property
    def _get_dh_inr(self):
        return self.dh_imr

    # Single definition of the expand operator: an earlier
    # ``tr.Array(np.float_, value=DELTA23_ab)`` assignment in this class body
    # was overwritten by this one and has been dropped as dead code.
    vtk_expand_operator = tr.Array(value=[1, 1, 0])
    vtk_node_cell_data = tr.Array
    vtk_ip_cell_data = tr.Array
예제 #8
0
class SpikesRemoval(SpanSelectorInSignal1D):
    """Interactive tool to locate spikes in a 1D signal and replace them.

    A spike is a channel where the first derivative of the current
    spectrum reaches ``threshold``. The region around it (the span
    selector range, or ``default_spike_width`` channels on each side) is
    replaced by an interpolation of the neighbouring data, optionally with
    noise added.
    """
    interpolator_kind = t.Enum(
        'Linear',
        'Spline',
        default='Linear',
        desc="the type of interpolation to use when\n"
             "replacing the signal where a spike has been replaced")
    threshold = t.Float(400, desc="the derivative magnitude threshold above\n"
                        "which to find spikes")
    click_to_show_instructions = t.Button()
    show_derivative_histogram = t.Button()
    spline_order = t.Range(1, 10, 3,
                           desc="the order of the spline used to\n"
                           "connect the reconstructed data")
    interpolator = None
    default_spike_width = t.Int(5,
                                desc="the width over which to do the interpolation\n"
                                "when removing a spike (this can be "
                                "adjusted for each\nspike by clicking "
                                     "and dragging on the display during\n"
                                     "spike replacement)")
    index = t.Int(0)
    add_noise = t.Bool(True,
                       desc="whether to add noise to the interpolated\nportion"
                       "of the spectrum. The noise properties defined\n"
                       "in the Signal metadata are used if present,"
                            "otherwise\nshot noise is used as a default")

    def __init__(self, signal, navigation_mask=None, signal_mask=None):
        """Set up the tool on an already plotted ``signal``.

        Parameters
        ----------
        signal : signal instance
            Its ``_plot`` attribute must already hold the signal plot.
        navigation_mask : boolean array or None
            Navigation positions where the mask is True are skipped.
        signal_mask : boolean array or None
            Channels whose derivative is ignored when detecting spikes.
        """
        super(SpikesRemoval, self).__init__(signal)
        self.interpolated_line = None
        # Navigation positions to visit; masked positions are left out.
        self.coordinates = [coordinate for coordinate in
                            signal.axes_manager._am_indices_generator()
                            if (navigation_mask is None or not
                                navigation_mask[coordinate[::-1]])]
        self.signal = signal
        self.line = signal._plot.signal_plot.ax_lines[0]
        self.ax = signal._plot.signal_plot.ax
        signal._plot.auto_update_plot = False
        if len(self.coordinates) > 1:
            signal.axes_manager.indices = self.coordinates[0]
        self.index = 0
        self.argmax = None
        self.derivmax = None
        # ``kind`` is what is passed to ``interp1d``: 'linear' or the
        # spline order.
        self.kind = "linear"
        self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
        self.signal_mask = signal_mask
        self.navigation_mask = navigation_mask
        md = self.signal.metadata
        from hyperspy.signal import BaseSignal

        # Pick the noise model from the metadata; fall back to shot noise.
        if "Signal.Noise_properties" in md:
            if "Signal.Noise_properties.variance" in md:
                self.noise_variance = md.Signal.Noise_properties.variance
                if isinstance(md.Signal.Noise_properties.variance, BaseSignal):
                    self.noise_type = "heteroscedastic"
                else:
                    self.noise_type = "white"
            else:
                self.noise_type = "shot noise"
        else:
            self.noise_type = "shot noise"

    def _threshold_changed(self, old, new):
        self.index = 0
        self.update_plot()

    def _click_to_show_instructions_fired(self):
        from pyface.message_dialog import information
        information(None, SPIKES_REMOVAL_INSTRUCTIONS,
                    title="Instructions")

    def _show_derivative_histogram_fired(self):
        self.signal._spikes_diagnosis(signal_mask=self.signal_mask,
                                      navigation_mask=self.navigation_mask)

    def detect_spike(self):
        """Look for a spike in the current spectrum.

        Returns
        -------
        bool
            True when the maximum of the derivative reaches ``threshold``;
            ``argmax`` and ``derivmax`` are then updated.
        """
        derivative = np.diff(self.signal())
        if self.signal_mask is not None:
            derivative[self.signal_mask[:-1]] = 0
        if self.argmax is not None:
            # Mask the range of the spike found previously so that it is
            # not detected again.
            left, right = self.get_interpolation_range()
            self._temp_mask[left:right] = True
            derivative[self._temp_mask[:-1]] = 0
        if abs(derivative.max()) >= self.threshold:
            self.argmax = derivative.argmax()
            self.derivmax = abs(derivative.max())
            return True
        else:
            return False

    def _reset_line(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
            self.reset_span_selector()

    def find(self, back=False):
        """Move to the next (or, with ``back=True``, previous) spike.

        When no spike is found before the end of ``coordinates`` a message
        is shown and the position is reset to the first coordinate.
        """
        self._reset_line()
        ncoordinates = len(self.coordinates)
        spike = self.detect_spike()
        with self.signal.axes_manager.events.indices_changed.suppress():
            while not spike and (
                    (self.index < ncoordinates - 1 and back is False) or
                    (self.index > 0 and back is True)):
                if back is False:
                    self.index += 1
                else:
                    self.index -= 1
                spike = self.detect_spike()

        if spike is False:
            m = SimpleMessage()
            m.text = 'End of dataset reached'
            try:
                m.gui()
            except (NotImplementedError, ImportError):
                # This is only available for traitsui, ipywidgets has a
                # progress bar instead.
                pass
            except ValueError as error:
                _logger.warning(error)
            self.index = 0
            self._reset_line()
            return
        else:
            # Zoom on a window of up to 50 channels around the spike.
            minimum = max(0, self.argmax - 50)
            maximum = min(len(self.signal()) - 1, self.argmax + 50)
            thresh_label = DerivativeTextParameters(
                text=r"$\mathsf{\delta}_\mathsf{max}=$",
                color="black")
            self.ax.legend([thresh_label], [repr(int(self.derivmax))],
                           handler_map={DerivativeTextParameters:
                                        DerivativeTextHandler()},
                           loc='best')
            self.ax.set_xlim(
                self.signal.axes_manager.signal_axes[0].index2value(
                    minimum),
                self.signal.axes_manager.signal_axes[0].index2value(
                    maximum))
            self.update_plot()
            self.create_interpolation_line()

    def update_plot(self):
        """Drop the interpolation line and redraw the spectrum."""
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.update_spectrum_line()
        if len(self.coordinates) > 1:
            self.signal._plot.pointer._on_navigate(self.signal.axes_manager)

    def update_spectrum_line(self):
        # Temporarily enable auto update to force a single redraw.
        self.line.auto_update = True
        self.line.update()
        self.line.auto_update = False

    def _index_changed(self, old, new):
        self.signal.axes_manager.indices = self.coordinates[new]
        self.argmax = None
        self._temp_mask[:] = False

    def on_disabling_span_selector(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None

    def _spline_order_changed(self, old, new):
        # The spline order only defines the interpolation kind while the
        # spline interpolator is selected; linear mode keeps 'linear'.
        if self.interpolator_kind == 'Spline':
            self.kind = self.spline_order
        self.span_selector_changed()

    def _interpolator_kind_changed(self, old, new):
        # The Enum values are capitalised ('Linear', 'Spline') while
        # ``kind`` uses the lowercase name expected by ``interp1d``.
        if new == 'Linear':
            self.kind = 'linear'
        else:
            self.kind = self.spline_order
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def create_interpolation_line(self):
        """Add the blue line previewing the interpolated spectrum."""
        self.interpolated_line = drawing.signal1d.Signal1DLine()
        self.interpolated_line.data_function = self.get_interpolated_spectrum
        self.interpolated_line.set_line_properties(
            color='blue',
            type='line')
        self.signal._plot.signal_plot.add_line(self.interpolated_line)
        self.interpolated_line.auto_update = False
        self.interpolated_line.autoscale = False
        self.interpolated_line.plot()

    def get_interpolation_range(self):
        """Return the ``(left, right)`` channel indices to be replaced.

        The span selector range is used when set; otherwise the range is
        ``default_spike_width`` channels on each side of ``argmax``. The
        result is clipped to the signal axis.
        """
        axis = self.signal.axes_manager.signal_axes[0]
        if np.isnan(self.ss_left_value) or np.isnan(self.ss_right_value):
            left = self.argmax - self.default_spike_width
            right = self.argmax + self.default_spike_width
        else:
            left = axis.value2index(self.ss_left_value)
            right = axis.value2index(self.ss_right_value)

        # Clip to the axis dimensions
        nchannels = self.signal.axes_manager.signal_shape[0]
        left = left if left >= 0 else 0
        right = right if right < nchannels else nchannels - 1

        return left, right

    def get_interpolated_spectrum(self, axes_manager=None):
        """Return a copy of the current spectrum with the spike replaced.

        The data between ``left`` and ``right`` are interpolated from
        ``pad`` channels on each side (1 for linear, ``spline_order``
        otherwise); when the padded range touches an end of the spectrum
        the mean of the available side is used instead. Noise is added
        according to ``noise_type`` when ``add_noise`` is True.
        ``axes_manager`` is accepted for the line ``data_function`` call
        signature and is not used.
        """
        data = self.signal().copy()
        axis = self.signal.axes_manager.signal_axes[0]
        left, right = self.get_interpolation_range()
        if self.kind == 'linear':
            pad = 1
        else:
            pad = self.spline_order
        ileft = left - pad
        iright = right + pad
        ileft = np.clip(ileft, 0, len(data))
        iright = np.clip(iright, 0, len(data))
        left = int(np.clip(left, 0, len(data)))
        right = int(np.clip(right, 0, len(data)))
        if ileft == 0:
            # Extrapolate to the left
            if right == iright:
                right -= 1
            data[:right] = data[right:iright].mean()

        elif iright == len(data):
            # Extrapolate to the right
            if left == ileft:
                left += 1
            data[left:] = data[ileft:left].mean()

        else:
            # Interpolate
            x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
            y = np.hstack((data[ileft:left], data[right:iright]))
            intp = sp.interpolate.interp1d(x, y, kind=self.kind)
            data[left:right] = intp(axis.axis[left:right])

        # Add noise
        if self.add_noise is True:
            if self.noise_type == "white":
                data[left:right] += np.random.normal(
                    scale=np.sqrt(self.noise_variance),
                    size=right - left)
            elif self.noise_type == "heteroscedastic":
                noise_variance = self.noise_variance(
                    axes_manager=self.signal.axes_manager)[left:right]
                noise = [np.random.normal(scale=np.sqrt(item))
                         for item in noise_variance]
                data[left:right] += noise
            else:
                data[left:right] = np.random.poisson(
                    np.clip(data[left:right], 0, np.inf))

        return data

    def span_selector_changed(self):
        if self.interpolated_line is None:
            return
        else:
            self.interpolated_line.update()

    def apply(self):
        """Write the interpolated spectrum into the data and go on."""
        if not self.interpolated_line:  # No spike selected
            return
        self.signal()[:] = self.get_interpolated_spectrum()
        self.signal.events.data_changed.trigger(obj=self.signal)
        self.update_spectrum_line()
        self.interpolated_line.close()
        self.interpolated_line = None
        self.reset_span_selector()
        self.find()
예제 #9
0
class TLoop(tr.HasTraits):
    """Time loop driving a time-stepping object along a time line.

    For every target time the residual norm ``tstep.R_norm`` is iterated
    below ``acc`` with ``tstep.make_iter()``; converged steps are accepted
    with ``tstep.make_incr``. Setting ``paused`` or ``restart`` asks the
    loop to stop.
    """

    tstep = tr.Instance(ITStep)
    sim = tr.DelegatesTo('tstep')
    tline = tr.Property

    def _get_tline(self):
        return self.sim.tline

    # Maximum number of iterations per time step.
    k_max = tr.Int(100, enter_set=True, auto_set=False)

    # Residual norm below which a time step counts as converged.
    acc = tr.Float(1e-4, enter_set=True, auto_set=False)

    # Print the time and the iteration count of every step.
    verbose = tr.Bool(False, enter_set=True, auto_set=False)

    paused = tr.Bool(False)

    restart = tr.Bool(True)

    user_wants_abort = tr.Property

    def _get_user_wants_abort(self):
        return self.restart or self.paused

    def init(self):
        """Clear the pause flag and, on restart, rewind the calculation.

        A restart moves the time line back to its minimum and
        re-initialises the state of ``tstep``.
        """
        if self.paused:
            self.paused = False
        if self.restart:
            self.tline.val = self.tline.min
            self.tstep.init_state()
            self.restart = False

    def eval(self):
        """Run the time loop from the current time up to ``tline.max``.

        Returns early, without accepting the step in progress, as soon as
        ``user_wants_abort`` is set.

        Raises
        ------
        StopIteration
            If a time step does not converge within ``k_max`` iterations;
            ``restart`` is set to True beforehand.
        """
        t_n1 = self.tline.val
        t_max = self.tline.max
        dt = self.tline.step

        if self.verbose:
            print('t:', end='')

        while t_n1 <= (t_max + 1e-8):
            if self.verbose:
                print('\t%5.2f' % t_n1, end='')
            k = 0
            self.tstep.t_n1 = t_n1
            while (k < self.k_max) and (not self.user_wants_abort):
                R_norm = self.tstep.R_norm
                if R_norm < self.acc:
                    if self.verbose:
                        print('(%g), ' % k, end='\n')
                    break
                self.tstep.make_iter()
                k += 1
            else:  # handle unfinished iteration loop
                if k >= self.k_max:  # maximum number of iterations exceeded
                    # no success abort the simulation
                    self.restart = True
                    print('')
                    raise StopIteration('Warning: '
                                        'convergence not reached in %g iterations' % k)
                # The iteration was interrupted by a pause/restart request.
                # Leave the loop: retrying the same target time while the
                # abort flag is still set would spin forever.
                return

            # accept the time step and record the state in history
            self.tstep.make_incr(t_n1)
            # update the line - launches notifiers to subscribers
            self.tline.val = min(t_n1, self.tline.max)
            # set a new target time
            t_n1 += dt
            self.tstep.t_n1 = t_n1
        return
예제 #10
0
class SmoothingSavitzkyGolay(Smoothing):
    """Savitzky-Golay smoothing tool built on top of ``Smoothing``.

    The trait handlers keep the three parameters consistent with each other:
    the window length stays above the polynomial order, and the polynomial
    order stays at or above the differential order.
    """

    polynomial_order = t.Int(
        3,
        desc="The order of the polynomial used to fit the samples."
             "`polyorder` must be less than `window_length`.")

    window_length = t.Int(
        5,
        desc="`window_length` must be a positive odd integer.")

    increase_window_length = t.Button(orientation="horizontal", label="+")
    decrease_window_length = t.Button(orientation="horizontal", label="-")

    def _increase_window_length_fired(self):
        """Step the window length up to the next odd value, if it fits."""
        if self.window_length % 2:
            nwl = self.window_length + 2
        else:
            nwl = self.window_length + 1
        # NOTE(review): the complex index ``2j`` is kept as in the original;
        # presumably it selects the signal axis through the axes manager's
        # complex-number indexing convention - confirm against AxesManager.
        if nwl < self.signal.axes_manager[2j].size:
            self.window_length = nwl

    def _decrease_window_length_fired(self):
        """Step the window length down to the previous odd value.

        The change is refused, with a warning, if the new length would not
        exceed the polynomial order.
        """
        if self.window_length % 2:
            nwl = self.window_length - 2
        else:
            nwl = self.window_length - 1
        if nwl > self.polynomial_order:
            self.window_length = nwl
        else:
            _logger.warning(
                "The window length must be greater than the polynomial order")

    def _polynomial_order_changed(self, old, new):
        """Grow the window to the next odd length above the new order."""
        if self.window_length <= new:
            self.window_length = new + 2 if new % 2 else new + 1
            _logger.warning(
                "Polynomial order must be < window length. "
                "Window length set to %i.", self.window_length)
        self.update_lines()

    def _window_length_changed(self, old, new):
        """Redraw the smoothed lines with the new window length."""
        self.update_lines()

    def _differential_order_changed(self, old, new):
        """Keep the polynomial order >= the differential order.

        The polynomial order is raised to ``new`` itself rather than by a
        single step, so the constraint also holds when the differential
        order jumps by more than one.
        """
        if new > self.polynomial_order:
            self.polynomial_order = new
            _logger.warning(
                "Differential order must be <= polynomial order. "
                "Polynomial order set to %i.", self.polynomial_order)
        super(
            SmoothingSavitzkyGolay,
            self)._differential_order_changed(
            old,
            new)

    def diff_model2plot(self, axes_manager=None):
        """Return the smoothed derivative of the current spectrum."""
        self.single_spectrum.data = self.signal().copy()
        self.single_spectrum.smooth_savitzky_golay(
            polynomial_order=self.polynomial_order,
            window_length=self.window_length,
            differential_order=self.differential_order)
        return self.single_spectrum.data

    def model2plot(self, axes_manager=None):
        """Return the smoothed (non-differentiated) current spectrum."""
        self.single_spectrum.data = self.signal().copy()
        self.single_spectrum.smooth_savitzky_golay(
            polynomial_order=self.polynomial_order,
            window_length=self.window_length,
            differential_order=0)
        return self.single_spectrum.data

    def apply(self):
        """Smooth the whole signal with the current settings and replot."""
        self.signal.smooth_savitzky_golay(
            polynomial_order=self.polynomial_order,
            window_length=self.window_length,
            differential_order=self.differential_order)
        self.signal._replot()
예제 #11
0
class BackgroundRemoval(SpanSelectorInSignal1D):
    """Span-selector tool that fits a background model and removes it.

    The background component is re-estimated over the selected span each
    time the span, the background type or the polynomial order changes; the
    fitted background (and optionally the remainder) is drawn on the signal
    plot. ``apply`` replaces the signal data by the background-subtracted
    data.
    """
    background_type = t.Enum(
        'Power Law',
        'Gaussian',
        'Offset',
        'Polynomial',
        default='Power Law')
    polynomial_order = t.Range(1, 10)
    fast = t.Bool(True,
                  desc=("Perform a fast (analytic, but possibly less accurate)"
                        " estimation of the background. Otherwise use "
                        "use non-linear least squares."))
    zero_fill = t.Bool(
        False,
        desc=("Set all spectral channels lower than the lower \n"
              "bound of the fitting range to zero (this is the \n"
              "default behavior of Gatan's DigitalMicrograph). \n"
              "Otherwise leave the pre-fitting region as-is \n"
              "(useful for inspecting quality of background fit)."))
    background_estimator = t.Instance(Component)
    bg_line_range = t.Enum('from_left_range',
                           'full',
                           'ss_range',
                           default='full')
    hi = t.Int(0)

    def __init__(self, signal, background_type='Power Law', polynomial_order=2,
                 fast=True, plot_remainder=True, zero_fill=False,
                 show_progressbar=None):
        """Set up the tool on ``signal`` with the given removal options."""
        super(BackgroundRemoval, self).__init__(signal)
        # Initialise the attributes read by `span_selector_changed` before
        # any trait assignment below can fire a change handler that calls it.
        self.bg_line = None
        self.rm_line = None
        self.plot_remainder = plot_remainder
        # Changing the polynomial order replaces the estimator with a
        # Polynomial component (see `_polynomial_order_changed`), so it is
        # set before the background type.
        self.polynomial_order = polynomial_order
        self.background_type = background_type
        # Called explicitly because no change handler fires when the
        # requested type equals the current trait value.
        self.set_background_estimator()
        self.fast = fast
        self.zero_fill = zero_fill
        self.show_progressbar = show_progressbar

    def on_disabling_span_selector(self):
        """Remove the background and remainder lines from the plot."""
        if self.bg_line is not None:
            self.bg_line.close()
            self.bg_line = None
        if self.rm_line is not None:
            self.rm_line.close()
            self.rm_line = None

    def set_background_estimator(self):
        """Create the component and plot range matching `background_type`."""
        if self.background_type == 'Power Law':
            self.background_estimator = components1d.PowerLaw()
            self.bg_line_range = 'from_left_range'
        elif self.background_type == 'Gaussian':
            self.background_estimator = components1d.Gaussian()
            self.bg_line_range = 'full'
        elif self.background_type == 'Offset':
            self.background_estimator = components1d.Offset()
            self.bg_line_range = 'full'
        elif self.background_type == 'Polynomial':
            self.background_estimator = components1d.Polynomial(
                self.polynomial_order)
            self.bg_line_range = 'full'

    def _polynomial_order_changed(self, old, new):
        self.background_estimator = components1d.Polynomial(new)
        self.span_selector_changed()

    def _background_type_changed(self, old, new):
        self.set_background_estimator()
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def create_background_line(self):
        """Add the line showing the fitted background to the signal plot."""
        self.bg_line = drawing.signal1d.Signal1DLine()
        self.bg_line.data_function = self.bg_to_plot
        self.bg_line.set_line_properties(
            color='blue',
            type='line',
            scaley=False)
        self.signal._plot.signal_plot.add_line(self.bg_line)
        self.bg_line.autoscale = False
        self.bg_line.plot()

    def create_remainder_line(self):
        """Add the line showing signal minus background to the signal plot."""
        self.rm_line = drawing.signal1d.Signal1DLine()
        self.rm_line.data_function = self.rm_to_plot
        self.rm_line.set_line_properties(
            color='green',
            type='line',
            scaley=False)
        self.signal._plot.signal_plot.add_line(self.rm_line)
        self.rm_line.autoscale = False
        self.rm_line.plot()

    def bg_to_plot(self, axes_manager=None, fill_with=np.nan):
        """Return the background evaluated on the signal axis.

        Depending on ``bg_line_range`` the background is evaluated on the
        full axis, from the left edge of the span onwards, or only inside
        the span; points outside that range are set to ``fill_with``.
        """
        # First try to update the estimation
        self.background_estimator.estimate_parameters(
            self.signal, self.ss_left_value, self.ss_right_value,
            only_current=True)

        if self.bg_line_range == 'from_left_range':
            bg_array = np.zeros(self.axis.axis.shape)
            bg_array[:] = fill_with
            from_index = self.axis.value2index(self.ss_left_value)
            bg_array[from_index:] = self.background_estimator.function(
                self.axis.axis[from_index:])
            to_return = bg_array
        elif self.bg_line_range == 'full':
            to_return = self.background_estimator.function(self.axis.axis)
        elif self.bg_line_range == 'ss_range':
            bg_array = np.zeros(self.axis.axis.shape)
            bg_array[:] = fill_with
            from_index = self.axis.value2index(self.ss_left_value)
            to_index = self.axis.value2index(self.ss_right_value)
            # Source and target slices must cover the same index range.
            bg_array[from_index:to_index] = self.background_estimator.function(
                self.axis.axis[from_index:to_index])
            to_return = bg_array

        if self.signal.metadata.Signal.binned is True:
            to_return *= self.axis.scale
        return to_return

    def rm_to_plot(self, axes_manager=None, fill_with=np.nan):
        """Return the current signal minus the plotted background line."""
        return self.signal() - self.bg_line.line.get_ydata()

    def span_selector_changed(self):
        """Re-estimate the background and create or refresh the lines.

        Does nothing while the span is undefined (NaN bound), empty or
        inverted, or while no estimator is set.
        """
        # `np.isnan` is used because an identity test against `np.nan` only
        # matches that one float object, not every NaN value.
        if (np.isnan(self.ss_left_value) or np.isnan(self.ss_right_value) or
                self.ss_right_value <= self.ss_left_value):
            return
        if self.background_estimator is None:
            return
        res = self.background_estimator.estimate_parameters(
            self.signal, self.ss_left_value,
            self.ss_right_value,
            only_current=True)
        if self.bg_line is None:
            if res:
                self.create_background_line()
        else:
            self.bg_line.update()
        if self.plot_remainder:
            if self.rm_line is None:
                if res:
                    self.create_remainder_line()
            else:
                self.rm_line.update()

    def apply(self):
        """Remove the background from the signal in place.

        An open plot is closed before the data is replaced and re-opened
        afterwards.
        """
        if self.signal._plot:
            self.signal._plot.close()
            plot = True
        else:
            plot = False
        background_type = ("PowerLaw" if self.background_type == "Power Law"
                           else self.background_type)
        new_spectra = self.signal.remove_background(
            signal_range=(self.ss_left_value, self.ss_right_value),
            background_type=background_type,
            fast=self.fast,
            zero_fill=self.zero_fill,
            polynomial_order=self.polynomial_order,
            show_progressbar=self.show_progressbar)
        self.signal.data = new_spectra.data
        self.signal.events.data_changed.trigger(self)
        if plot:
            self.signal.plot()
예제 #12
0
class Smoothing(t.HasTraits):
    """Base class of the interactive smoothing tools.

    It switches the data line of the signal plot to scatter style, overlays
    the smoothed line returned by the subclass' ``model2plot`` and, for a
    non-zero ``differential_order``, adds a derivative line on a right-hand
    axis. ``close`` restores the original appearance of the data line.
    """
    # The following is disabled because as of traits 4.6 the Color trait
    # imports traitsui (!)
    # try:
    #     line_color = t.Color("blue")
    # except ModuleNotFoundError:
    #     # traitsui is required to define this trait so it is not defined when
    #     # traitsui is not installed.
    #     pass
    line_color_ipy = t.Str("blue")
    differential_order = t.Int(0)

    @property
    def line_color_rgb(self):
        """Colour of the smoothed lines as a matplotlib-compatible value.

        Uses the toolkit colour object in ``line_color`` when that attribute
        exists, and falls back to the ``line_color_ipy`` string otherwise.
        """
        if hasattr(self, "line_color"):
            try:
                # PyQt and WX
                return np.array(self.line_color.Get()) / 255.
            except AttributeError:
                try:
                    # PySide
                    return np.array(self.line_color.getRgb()) / 255.
                except Exception:
                    # Any ordinary failure falls back to the string colour;
                    # KeyboardInterrupt/SystemExit are deliberately not
                    # swallowed.
                    return matplotlib.colors.to_rgb(self.line_color_ipy)
        else:
            return matplotlib.colors.to_rgb(self.line_color_ipy)

    def __init__(self, signal):
        """Attach the tool to ``signal`` and draw the smoothed line."""
        self.ax = None
        self.data_line = None
        self.smooth_line = None
        self.signal = signal
        self.single_spectrum = self.signal.get_current_signal().deepcopy()
        self.axis = self.signal.axes_manager.signal_axes[0].axis
        self.plot()

    def plot(self):
        """Plot the signal if needed and overlay the smoothed line."""
        if self.signal._plot is None or not self.signal._plot.is_active:
            self.signal.plot()
        hse = self.signal._plot
        l1 = hse.signal_plot.ax_lines[0]
        self.original_color = l1.line.get_color()
        l1.set_line_properties(color=self.original_color,
                               type='scatter')
        l2 = drawing.signal1d.Signal1DLine()
        l2.data_function = self.model2plot

        l2.set_line_properties(
            color=self.line_color_rgb,
            type='line')
        # Add the line to the figure
        hse.signal_plot.add_line(l2)
        l2.plot()
        self.data_line = l1
        self.smooth_line = l2
        self.smooth_diff_line = None

    def update_lines(self):
        """Redraw the smoothed line and, if present, the derivative line."""
        self.smooth_line.update()
        if self.smooth_diff_line is not None:
            self.smooth_diff_line.update()

    def turn_diff_line_on(self, diff_order):
        """Create the derivative line on a right-hand axis of the plot."""
        self.signal._plot.signal_plot.create_right_axis()
        self.smooth_diff_line = drawing.signal1d.Signal1DLine()
        self.smooth_diff_line.axes_manager = self.signal.axes_manager
        self.smooth_diff_line.data_function = self.diff_model2plot
        self.smooth_diff_line.set_line_properties(
            color=self.line_color_rgb,
            type='line')
        self.signal._plot.signal_plot.add_line(self.smooth_diff_line,
                                               ax='right')

    def _line_color_ipy_changed(self):
        if hasattr(self, "line_color"):
            self.line_color = str(self.line_color_ipy)
        else:
            self._line_color_changed(None, None)

    def turn_diff_line_off(self):
        """Remove the derivative line, if there is one."""
        if self.smooth_diff_line is None:
            return
        self.smooth_diff_line.close()
        self.smooth_diff_line = None

    def _differential_order_changed(self, old, new):
        if new == 0:
            self.turn_diff_line_off()
            return
        if old == 0:
            self.turn_diff_line_on(new)
            self.smooth_diff_line.plot()
        else:
            self.smooth_diff_line.update(force_replot=False)

    def _line_color_changed(self, old, new):
        self.smooth_line.line_properties = {
            'color': self.line_color_rgb}
        if self.smooth_diff_line is not None:
            self.smooth_diff_line.line_properties = {
                'color': self.line_color_rgb}
        try:
            # it seems that changing the properties can be done before the
            # first rendering event, which can cause issue with blitting
            self.update_lines()
        except AttributeError:
            pass

    def diff_model2plot(self, axes_manager=None):
        """Return the ``differential_order``-th difference of the model."""
        smoothed = np.diff(self.model2plot(axes_manager),
                           self.differential_order)
        return smoothed

    def close(self):
        """Remove the tool's lines and restore the data line style.

        Nothing is done when the signal has no plot or the plot is no
        longer active.
        """
        # Same guard as in `plot`: `_plot` may be None.
        if self.signal._plot is not None and self.signal._plot.is_active:
            if self.differential_order != 0:
                self.turn_diff_line_off()
            self.smooth_line.close()
            self.data_line.set_line_properties(
                color=self.original_color,
                type='line')
예제 #13
0
class X(tr.HasStrictTraits):
    """Minimal strict-traits class with a single integer trait."""

    # Integer trait, declared without an explicit default value.
    i = tr.Int()
예제 #14
0
class PlsrPcr(Model):
    """Represent the PlsrPcr model between one X and Y data set."""

    # Consumer liking
    ds_C = DataSet()
    # Descriptive analysis / sensory profiling
    ds_S = DataSet()
    ds_X = _traits.Property()
    ds_Y = _traits.Property()
    settings = _traits.WeakRef()
    # Checkbox bool for standardised results
    standardise_x = _traits.Bool(False)
    standardise_y = _traits.Bool(False)
    int_ext_mapping = _traits.Enum('Internal', 'External')
    plscr_method = _traits.Enum('PLSR', 'PCR')
    calc_n_pc = _traits.Int()
    min_pc = 2
    # max_pc = _traits.Property()
    max_pc = 10
    # Standard deviations below this threshold are treated as zero.
    min_std = _traits.Float(0.001)
    # Names of the variables found to have (near) zero standard deviation.
    C_zero_std = _traits.List()
    S_zero_std = _traits.List()


    def _get_res(self):
        """Fit the selected model (PLSR or PCR) and return the packed result.

        Raises ``InComputeable`` when a data set that is to be standardised
        contains variables with (near) zero standard deviation.
        """
        if self._have_zero_std():
            raise InComputeable('Matrix have variables with zero variance',
                                self.C_zero_std, self.S_zero_std)
        n_pc = min(self.settings.calc_n_pc, self._get_max_pc())
        method = self.settings.plscr_method
        if method == 'PLSR':
            estimator = PLSR
        elif method == 'PCR':
            estimator = PCR
        else:
            return None
        fitted = estimator(self.ds_X.values, self.ds_Y.values,
                           numComp=n_pc, cvType=["loo"],
                           Xstand=self.settings.standardise_x,
                           Ystand=self.settings.standardise_y)
        return self._pack_res(fitted)


    def _have_zero_std(self):
        """Tell whether a data set to be standardised has zero-std variables.

        Resets and refills ``C_zero_std`` / ``S_zero_std``. Only data sets
        that will be standardised are checked; both are checked when both
        are standardised so that both lists get filled.
        """
        self.C_zero_std = []
        self.S_zero_std = []
        found = False
        if self._std_C():
            found = self._C_have_zero_std_var() or found
        if self._std_S():
            found = self._S_have_zero_std_var() or found
        return found


    def _std_C(self):
        """Return the standardise flag that applies to ``ds_C``."""
        if self.settings.int_ext_mapping == 'Internal':
            return self.settings.standardise_x
        else:
            return self.settings.standardise_y


    def _std_S(self):
        """Return the standardise flag that applies to ``ds_S``."""
        if self.settings.int_ext_mapping == 'Internal':
            return self.settings.standardise_y
        else:
            return self.settings.standardise_x


    def _C_have_zero_std_var(self):
        self.C_zero_std = self._check_zero_std(self.ds_C)
        return bool(self.C_zero_std)


    def _S_have_zero_std_var(self):
        self.S_zero_std = self._check_zero_std(self.ds_S)
        return bool(self.S_zero_std)


    def _check_zero_std(self, ds):
        """Return the names of the variables in ``ds`` with std < ``min_std``."""
        zero_std_var = []
        sv = ds.values.std(axis=0)
        dm = sv < self.min_std
        if _np.any(dm):
            vv = _np.array(ds.var_n)
            zero_std_var = list(vv[_np.nonzero(dm)])
        return zero_std_var


    def _get_ds_X(self):
        """Return the X data set: ``ds_C`` for internal mapping, else ``ds_S``."""
        if self.settings.int_ext_mapping == 'Internal':
            return self.ds_C
        else:
            return self.ds_S


    def _get_ds_Y(self):
        """Return the Y data set: ``ds_S`` for internal mapping, else ``ds_C``."""
        if self.settings.int_ext_mapping == 'Internal':
            return self.ds_S
        else:
            return self.ds_C


    def _get_max_pc(self):
        """Return the largest usable number of components for the X data.

        That is ``min(n_objs, n_vars, 11) - 1``, but never less than
        ``min_pc``.
        """
        if self.settings.int_ext_mapping == 'Internal':
            ds = self.ds_C
        else:
            ds = self.ds_S
        return max((min(ds.n_objs, ds.n_vars, 11) - 1), self.min_pc)


    def _calc_n_pc_default(self):
        return self.max_pc


    @staticmethod
    def _pc_labels(n_pc):
        """Return the column labels ``PC-1`` ... ``PC-<n_pc>``."""
        return ["PC-{0}".format(i + 1) for i in range(n_pc)]


    def _mk_expl_var_ds(self, cal, cum_cal, val, cum_val, display_name):
        """Pack the four explained-variance rows into one data set."""
        return DataSet(
            mat=_pd.DataFrame(
                data=[cal, cum_cal, val, cum_val],
                index=['calibrated', 'cumulative calibrated', 'validated', 'cumulative validated'],
                columns=self._pc_labels(len(cal)),
                ),
            display_name=display_name)


    def _mk_pred_ds(self, pred_mat, npc):
        """Wrap a matrix of predicted Y values in a data set."""
        pred_ds = DataSet(
            mat=_pd.DataFrame(
                data=pred_mat,
                index=self.ds_Y.obj_n,
                columns=self.ds_Y.var_n,
            ),
            display_name='Predicted after PC{}'.format(npc))
        return pred_ds


    def _pack_res(self, pls_obj):
        """Collect the outputs of the fitted model in a ``Result`` object."""
        res = Result('PLSR/PCR {0}(X) & {1}(Y)'.format(self.ds_X.display_name, self.ds_Y.display_name))

        res.external_mapping = self.settings.int_ext_mapping == 'External'

        res.plscr_method = self.settings.plscr_method

        # Scores X
        mT = pls_obj.X_scores()
        res.scores_x = DataSet(
            mat=_pd.DataFrame(
                data=mT,
                index=self.ds_X.obj_n,
                columns=self._pc_labels(mT.shape[1]),
                ),
            display_name='X scores')

        # loadings_x
        mP = pls_obj.X_loadings()
        res.loadings_x = DataSet(
            mat=_pd.DataFrame(
                data=mP,
                index=self.ds_X.var_n,
                columns=self._pc_labels(mP.shape[1]),
                ),
            display_name='X loadings')

        # loadings_y
        # Same as loading_x in external mapping?
        mQ = pls_obj.Y_loadings()
        res.loadings_y = DataSet(
            mat=_pd.DataFrame(
                data=mQ,
                index=self.ds_Y.var_n,
                columns=self._pc_labels(mQ.shape[1]),
                ),
            display_name='Y loadings')

        # expl_var_x
        # The cumulative vectors are used without their first element.
        res.expl_var_x = self._mk_expl_var_ds(
            pls_obj.X_calExplVar(),
            pls_obj.X_cumCalExplVar()[1:],
            pls_obj.X_valExplVar(),
            pls_obj.X_cumValExplVar()[1:],
            'Explained variance in X')

        # expl_var_y
        res.expl_var_y = self._mk_expl_var_ds(
            pls_obj.Y_calExplVar(),
            pls_obj.Y_cumCalExplVar()[1:],
            pls_obj.Y_valExplVar(),
            pls_obj.Y_cumValExplVar()[1:],
            'Explained variance in Y')

        # X_corrLoadings()
        # corr_loadings_x
        mXcl = pls_obj.X_corrLoadings()
        res.corr_loadings_x = DataSet(
            mat=_pd.DataFrame(
                data=mXcl,
                index=self.ds_X.var_n,
                columns=self._pc_labels(mXcl.shape[1]),
                ),
            display_name='X & Y correlation loadings')

        # Y_corrLoadings()
        # corr_loadings_y
        mYcl = pls_obj.Y_corrLoadings()
        res.corr_loadings_y = DataSet(
            mat=_pd.DataFrame(
                data=mYcl,
                index=self.ds_Y.var_n,
                # labels are sized from the Y matrix itself
                columns=self._pc_labels(mYcl.shape[1]),
                ),
            display_name=self.ds_Y.display_name)

        # Y_predCal()
        # Return a dict with Y pred for each PC
        pYc = pls_obj.Y_predCal()
        res.pred_cal_y = [self._mk_pred_ds(pYc[k], k) for k in pYc.keys()]

        # Y_predVal()
        # Return a dict with Y pred for each PC
        pYv = pls_obj.Y_predVal()
        res.pred_val_y = [self._mk_pred_ds(pYv[k], k) for k in pYv.keys()]

        return res
예제 #15
0
class BackgroundRemoval(SpanSelectorInSignal1D):
    """Span-selector tool that fits a background model and removes it.

    The background component is estimated over the selected span and drawn
    on the signal plot; ``apply`` replaces the signal data by the
    background-subtracted data.
    """
    background_type = t.Enum(
        'Power Law',
        'Gaussian',
        'Offset',
        'Polynomial',
        default='Power Law')
    polynomial_order = t.Range(1, 10)
    fast = t.Bool(True,
                  desc=("Perform a fast (analytic, but possibly less accurate)"
                        " estimation of the background. Otherwise use "
                        "non-linear least squares."))
    background_estimator = t.Instance(Component)
    bg_line_range = t.Enum('from_left_range',
                           'full',
                           'ss_range',
                           default='full')
    hi = t.Int(0)
    view = tu.View(
        tu.Group(
            'background_type',
            'fast',
            tu.Group(
                'polynomial_order',
                visible_when='background_type == \'Polynomial\''), ),
        buttons=[OKButton, CancelButton],
        handler=SpanSelectorInSignal1DHandler,
        title='Background removal tool',
        resizable=True,
        width=300,
    )

    def __init__(self, signal):
        """Set up the tool on ``signal`` with the default estimator."""
        super(BackgroundRemoval, self).__init__(signal)
        self.set_background_estimator()
        self.bg_line = None

    def on_disabling_span_selector(self):
        """Remove the background line from the plot."""
        if self.bg_line is not None:
            self.bg_line.close()
            self.bg_line = None

    def set_background_estimator(self):
        """Create the component and plot range matching `background_type`."""
        if self.background_type == 'Power Law':
            self.background_estimator = components1d.PowerLaw()
            self.bg_line_range = 'from_left_range'
        elif self.background_type == 'Gaussian':
            self.background_estimator = components1d.Gaussian()
            self.bg_line_range = 'full'
        elif self.background_type == 'Offset':
            self.background_estimator = components1d.Offset()
            self.bg_line_range = 'full'
        elif self.background_type == 'Polynomial':
            self.background_estimator = components1d.Polynomial(
                self.polynomial_order)
            self.bg_line_range = 'full'

    def _polynomial_order_changed(self, old, new):
        self.background_estimator = components1d.Polynomial(new)
        self.span_selector_changed()

    def _background_type_changed(self, old, new):
        self.set_background_estimator()
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def create_background_line(self):
        """Add the line showing the fitted background to the signal plot."""
        self.bg_line = drawing.signal1d.Signal1DLine()
        self.bg_line.data_function = self.bg_to_plot
        self.bg_line.set_line_properties(
            color='blue',
            type='line',
            scaley=False)
        self.signal._plot.signal_plot.add_line(self.bg_line)
        self.bg_line.autoscale = False
        self.bg_line.plot()

    def bg_to_plot(self, axes_manager=None, fill_with=np.nan):
        """Return the background evaluated on the signal axis.

        Depending on ``bg_line_range`` the background is evaluated on the
        full axis, from the left edge of the span onwards, or only inside
        the span; points outside that range are set to ``fill_with``.
        """
        # First try to update the estimation
        self.background_estimator.estimate_parameters(
            self.signal, self.ss_left_value, self.ss_right_value,
            only_current=True)

        if self.bg_line_range == 'from_left_range':
            bg_array = np.zeros(self.axis.axis.shape)
            bg_array[:] = fill_with
            from_index = self.axis.value2index(self.ss_left_value)
            bg_array[from_index:] = self.background_estimator.function(
                self.axis.axis[from_index:])
            to_return = bg_array
        elif self.bg_line_range == 'full':
            to_return = self.background_estimator.function(self.axis.axis)
        elif self.bg_line_range == 'ss_range':
            bg_array = np.zeros(self.axis.axis.shape)
            bg_array[:] = fill_with
            from_index = self.axis.value2index(self.ss_left_value)
            to_index = self.axis.value2index(self.ss_right_value)
            # Source and target slices must cover the same index range.
            bg_array[from_index:to_index] = self.background_estimator.function(
                self.axis.axis[from_index:to_index])
            to_return = bg_array

        if self.signal.metadata.Signal.binned is True:
            to_return *= self.axis.scale
        return to_return

    def span_selector_changed(self):
        """Create the background line on first success, else refresh it.

        Does nothing while the span is undefined (NaN bound), empty or
        inverted, or while no estimator is set.
        """
        # `np.isnan` is used because an identity test against `np.nan` only
        # matches that one float object, not every NaN value.
        if (np.isnan(self.ss_left_value) or np.isnan(self.ss_right_value) or
                self.ss_right_value <= self.ss_left_value):
            return
        if self.background_estimator is None:
            return
        if self.bg_line is None:
            # Only create the line once an estimation has succeeded; there
            # is nothing to update while the line does not exist yet.
            if self.background_estimator.estimate_parameters(
                    self.signal, self.ss_left_value,
                    self.ss_right_value,
                    only_current=True) is True:
                self.create_background_line()
        else:
            self.bg_line.update()

    def apply(self):
        """Remove the background from the signal in place and replot.

        Automatic plot updates are switched off while the data is being
        replaced and switched back on afterwards, even if the removal
        fails.
        """
        self.signal._plot.auto_update_plot = False
        try:
            new_spectra = self.signal._remove_background_cli(
                (self.ss_left_value, self.ss_right_value),
                self.background_estimator, fast=self.fast)
            self.signal.data = new_spectra.data
            self.signal._replot()
        finally:
            self.signal._plot.auto_update_plot = True
예제 #16
0
class Smoothing(t.HasTraits):
    """Base class of the interactive smoothing tools.

    It switches the data line of the signal plot to scatter style, overlays
    the smoothed line returned by the subclass' ``model2plot`` and, for a
    non-zero ``differential_order``, adds a derivative line on a right-hand
    axis. ``close`` restores the original appearance of the data line.
    """
    line_color = t.Color('blue')
    differential_order = t.Int(0)

    @property
    def line_color_rgb(self):
        """Colour components as returned by the toolkit colour object.

        Any error raised by the PySide accessor propagates to the caller.
        """
        try:
            # PyQt and WX
            return self.line_color.Get()
        except AttributeError:
            # PySide
            return self.line_color.getRgb()

    def _mpl_line_color(self):
        """Return ``line_color_rgb`` scaled from 0-255 to 0-1."""
        return np.array(self.line_color_rgb) / 255.

    def __init__(self, signal):
        """Attach the tool to ``signal`` and draw the smoothed line."""
        self.ax = None
        self.data_line = None
        self.smooth_line = None
        self.signal = signal
        self.single_spectrum = self.signal.get_current_signal().deepcopy()
        self.axis = self.signal.axes_manager.signal_axes[0].axis
        self.plot()

    def plot(self):
        """Plot the signal if needed and overlay the smoothed line."""
        if self.signal._plot is None or not \
                self.signal._plot.is_active():
            self.signal.plot()
        hse = self.signal._plot
        l1 = hse.signal_plot.ax_lines[0]
        self.original_color = l1.line.get_color()
        l1.set_line_properties(color=self.original_color, type='scatter')
        l2 = drawing.signal1d.Signal1DLine()
        l2.data_function = self.model2plot

        l2.set_line_properties(color=self._mpl_line_color(),
                               type='line')
        # Add the line to the figure
        hse.signal_plot.add_line(l2)
        l2.plot()
        self.data_line = l1
        self.smooth_line = l2
        self.smooth_diff_line = None

    def update_lines(self):
        """Redraw the smoothed line and, if present, the derivative line."""
        self.smooth_line.update()
        if self.smooth_diff_line is not None:
            self.smooth_diff_line.update()

    def turn_diff_line_on(self, diff_order):
        """Create the derivative line on a right-hand axis of the plot."""
        self.signal._plot.signal_plot.create_right_axis()
        self.smooth_diff_line = drawing.signal1d.Signal1DLine()
        self.smooth_diff_line.data_function = self.diff_model2plot
        self.smooth_diff_line.set_line_properties(
            color=self._mpl_line_color(), type='line')
        self.signal._plot.signal_plot.add_line(self.smooth_diff_line,
                                               ax='right')
        self.smooth_diff_line.axes_manager = self.signal.axes_manager

    def turn_diff_line_off(self):
        """Remove the derivative line, if there is one."""
        if self.smooth_diff_line is None:
            return
        self.smooth_diff_line.close()
        self.smooth_diff_line = None

    def _differential_order_changed(self, old, new):
        if old == 0:
            self.turn_diff_line_on(new)
            self.smooth_diff_line.plot()
        if new == 0:
            self.turn_diff_line_off()
            return
        self.smooth_diff_line.update(force_replot=False)

    def _line_color_changed(self, old, new):
        self.smooth_line.line_properties = {
            'color': self._mpl_line_color()
        }
        if self.smooth_diff_line is not None:
            self.smooth_diff_line.line_properties = {
                'color': self._mpl_line_color()
            }
        self.update_lines()

    def diff_model2plot(self, axes_manager=None):
        """Return the ``differential_order``-th difference of the model."""
        smoothed = np.diff(self.model2plot(axes_manager),
                           self.differential_order)
        return smoothed

    def close(self):
        """Remove the tool's lines and restore the data line style.

        Nothing is done when the signal has no plot or the plot is no
        longer active.
        """
        # Same guard as in `plot`: `_plot` may be None.
        if self.signal._plot is not None and self.signal._plot.is_active():
            if self.differential_order != 0:
                self.turn_diff_line_off()
            self.smooth_line.close()
            self.data_line.set_line_properties(color=self.original_color,
                                               type='line')
예제 #17
0
class SpikesRemoval(SpanSelectorInSpectrum):
    """Interactive tool that locates spikes in spectra and replaces them.

    The tool walks through the navigation coordinates that are not masked
    by ``navigation_mask``. In each spectrum a spike is reported when the
    largest value of the first difference reaches ``threshold``. The
    channels around the spike (or the range chosen with the span selector)
    are replaced by interpolated values with Poisson noise added.

    Parameters
    ----------
    signal
        Signal to clean; it must already be plotted, because the tool
        reuses the first line and the axes of its signal plot.
    navigation_mask : array of bool or None
        Coordinates for which the mask is True are skipped.
    signal_mask : array of bool or None
        Channels for which the mask is True are ignored when looking for
        spikes.
    """
    interpolator_kind = t.Enum('Linear', 'Spline', default='Linear')
    threshold = t.Float()
    show_derivative_histogram = t.Button()
    spline_order = t.Range(1, 10, 3)
    interpolator = None
    default_spike_width = t.Int(5)
    index = t.Int(0)
    view = tu.View(
        tu.Group(
            tu.Group(
                tu.Item('show_derivative_histogram', show_label=False),
                'threshold',
                show_border=True,
            ),
            tu.Group(
                'interpolator_kind',
                'default_spike_width',
                tu.Group('spline_order',
                         visible_when='interpolator_kind == \'Spline\''),
                show_border=True,
                label='Advanced settings',
            ),
        ),
        buttons=[
            OKButton,
            OurPreviousButton,
            OurFindButton,
            OurApplyButton,
        ],
        handler=SpikesRemovalHandler,
        title='Spikes removal tool',
    )

    def __init__(self, signal, navigation_mask=None, signal_mask=None):
        super(SpikesRemoval, self).__init__(signal)
        self.interpolated_line = None
        # Navigation coordinates to visit; the mask is indexed with the
        # reversed coordinate tuple.
        self.coordinates = [
            coordinate
            for coordinate in signal.axes_manager._am_indices_generator()
            if (navigation_mask is None
                or not navigation_mask[coordinate[::-1]])
        ]
        self.signal = signal
        # ``find`` is iterative, so the interpreter recursion limit is left
        # untouched.
        self.line = signal._plot.signal_plot.ax_lines[0]
        self.ax = signal._plot.signal_plot.ax
        signal._plot.auto_update_plot = False
        signal.axes_manager.indices = self.coordinates[0]
        # Assigning the threshold fires ``_threshold_changed``, which needs
        # ``self.line`` and ``self.interpolated_line`` to exist already.
        self.threshold = 400
        self.index = 0
        self.argmax = None
        self.kind = "linear"
        self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
        self.signal_mask = signal_mask
        self.navigation_mask = navigation_mask

    def _threshold_changed(self, old, new):
        """Restart from the first coordinate when the threshold changes."""
        self.index = 0
        self.update_plot()

    def _show_derivative_histogram_fired(self):
        """Show the signal's spikes diagnosis using the current masks."""
        self.signal._spikes_diagnosis(signal_mask=self.signal_mask,
                                      navigation_mask=self.navigation_mask)

    def detect_spike(self):
        """Look for the next spike in the current spectrum.

        Channels flagged in ``signal_mask`` and the ranges of spikes
        already visited in this spectrum are excluded. When a spike is
        found its position is stored in ``self.argmax``.

        Returns
        -------
        bool
            True when the largest first difference reaches ``threshold``.
        """
        derivative = np.diff(self.signal())
        if self.signal_mask is not None:
            derivative[self.signal_mask[:-1]] = 0
        if self.argmax is not None:
            # Remember the previous spike so it is not reported again.
            left, right = self.get_interpolation_range()
            self._temp_mask[left:right] = True
            derivative[self._temp_mask[:-1]] = 0
        if abs(derivative.max()) >= self.threshold:
            self.argmax = derivative.argmax()
            return True
        return False

    def find(self, back=False):
        """Move to the next spike and show the proposed replacement.

        The current spectrum is examined first; while it holds no further
        spike the tool steps to the next coordinate (the previous one when
        ``back`` is true). Every spectrum, including the last one, is
        examined before 'End of dataset reached' is reported.

        Parameters
        ----------
        back : bool
            Step backwards through the coordinates instead of forwards.
        """
        if self.interpolated_line is not None:
            self._discard_interpolated_line()
            self.reset_span_selector()

        last_index = len(self.coordinates) - 1
        while not self.detect_spike():
            if back:
                exhausted = self.index <= 0
            else:
                exhausted = self.index >= last_index
            if exhausted:
                messages.information('End of dataset reached')
                return
            # Changing ``index`` fires ``_index_changed``, which moves the
            # signal to the new coordinate and clears the spike state.
            self.index += -1 if back else 1

        # Zoom on a window of +/- 50 channels around the spike.
        minimum = max(0, self.argmax - 50)
        maximum = min(len(self.signal()) - 1, self.argmax + 50)
        axis = self.signal.axes_manager.signal_axes[0]
        self.ax.set_xlim(axis.index2value(minimum),
                         axis.index2value(maximum))
        self.update_plot()
        self.create_interpolation_line()

    def update_plot(self):
        """Drop the preview line and redraw the spectrum and pointer."""
        self._discard_interpolated_line()
        self.reset_span_selector()
        self.update_spectrum_line()
        self.signal._plot.pointer.update_patch_position()

    def update_spectrum_line(self):
        """Redraw the spectrum line once, then switch auto-update off."""
        self.line.auto_update = True
        self.line.update()
        self.line.auto_update = False

    def _index_changed(self, old, new):
        """Move to coordinate ``new`` and forget spikes of the old one."""
        self.signal.axes_manager.indices = self.coordinates[new]
        self.argmax = None
        self._temp_mask[:] = False

    def _discard_interpolated_line(self):
        """Close the preview line, if any, and forget it."""
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None

    def on_disabling_span_selector(self):
        self._discard_interpolated_line()

    def _spline_order_changed(self, old, new):
        self.kind = self.spline_order
        self.span_selector_changed()

    def _interpolator_kind_changed(self, old, new):
        """Translate the GUI choice into the ``kind`` given to interp1d."""
        # The Enum values are capitalised ('Linear'/'Spline') while
        # ``self.kind`` uses the lower-case 'linear' or a spline order.
        if new.lower() == 'linear':
            self.kind = 'linear'
        else:
            self.kind = self.spline_order
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        self.span_selector_changed()

    def create_interpolation_line(self):
        """Add a blue preview line showing the interpolated spectrum."""
        self.interpolated_line = drawing.spectrum.SpectrumLine()
        self.interpolated_line.data_function = \
            self.get_interpolated_spectrum
        self.interpolated_line.set_line_properties(color='blue', type='line')
        self.signal._plot.signal_plot.add_line(self.interpolated_line)
        self.interpolated_line.autoscale = False
        self.interpolated_line.plot()

    def get_interpolation_range(self):
        """Return the ``(left, right)`` channel indices to replace.

        With an empty span selector the range is ``default_spike_width``
        channels either side of ``self.argmax``; otherwise it is the
        selected span. The result is clipped to the signal axis.
        """
        axis = self.signal.axes_manager.signal_axes[0]
        if self.ss_left_value == self.ss_right_value:
            left = self.argmax - self.default_spike_width
            right = self.argmax + self.default_spike_width
        else:
            left = axis.value2index(self.ss_left_value)
            right = axis.value2index(self.ss_right_value)

        # Clip to the axis dimensions
        nchannels = self.signal.axes_manager.signal_shape[0]
        left = max(left, 0)
        right = min(right, nchannels - 1)

        return left, right

    def get_interpolated_spectrum(self, axes_manager=None):
        """Return a copy of the current spectrum with the spike replaced.

        Channels ``left:right`` are interpolated from ``pad`` channels on
        each side (1 for linear, 10 for spline interpolation). When the
        padded window reaches an end of the spectrum, the range is filled
        with a constant taken from the opposite side instead. Poisson
        noise is added to the replaced channels only.
        """
        data = self.signal().copy()
        nchannels = len(data)
        axis = self.signal.axes_manager.signal_axes[0]
        left, right = self.get_interpolation_range()
        pad = 1 if self.kind == 'linear' else 10
        ileft = np.clip(left - pad, 0, nchannels)
        iright = np.clip(right + pad, 0, nchannels)
        left = np.clip(left, 0, nchannels)
        right = np.clip(right, 0, nchannels)
        if ileft == 0:
            # Extrapolate to the left; stay inside the array when the
            # range also reaches the last channel.
            data[left:right] = data[min(right + 1, nchannels - 1)]

        elif iright == nchannels:
            # Extrapolate to the right. ``iright`` is clipped to
            # ``nchannels``, so that value marks a window touching the end.
            data[left:right] = data[left - 1]

        else:
            # Interpolate
            x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
            y = np.hstack((data[ileft:left], data[right:iright]))
            intp = sp.interpolate.interp1d(x, y, kind=self.kind)
            data[left:right] = intp(axis.axis[left:right])

        # Add noise to the replaced channels only, so the untouched part
        # of the spectrum is returned unchanged.
        data[left:right] = np.random.poisson(
            np.clip(data[left:right], 0, np.inf))
        return data

    def span_selector_changed(self):
        """Refresh the preview line, if one is being shown."""
        if self.interpolated_line is not None:
            self.interpolated_line.update()

    def apply(self):
        """Write the previewed replacement to the data and go on searching.

        Does nothing when no spike is currently selected.
        """
        if self.interpolated_line is None:
            return
        self.signal()[:] = self.get_interpolated_spectrum()
        self.update_spectrum_line()
        self._discard_interpolated_line()
        self.reset_span_selector()
        self.find()
예제 #18
0
class SmoothingSavitzkyGolay(Smoothing):
    """Interactive Savitzky-Golay smoothing (and differentiation) tool.

    The traits are kept mutually consistent: the window length stays
    greater than the polynomial order, and the polynomial order is raised
    when the differential order exceeds it.
    """

    polynomial_order = t.Int(
        3,
        desc="The order of the polynomial used to fit the samples."
        "`polyorder` must be less than `window_length`.")

    window_length = t.Int(
        5, desc="`window_length` must be a positive odd integer.")

    increase_window_length = t.Button(orientation="horizontal", label="+")
    decrease_window_length = t.Button(orientation="horizontal", label="-")

    view = tu.View(
        tu.Group(
            tu.Group('window_length',
                     tu.Item('decrease_window_length', show_label=False),
                     tu.Item('increase_window_length', show_label=False),
                     orientation="horizontal"), 'polynomial_order',
            tu.Item(
                name='differential_order',
                tooltip='The order of the derivative to compute. This must '
                'be a nonnegative integer. The default is 0, which '
                'means to filter the data without differentiating.',
            ), 'line_color'),
        kind='live',
        handler=SmoothingHandler,
        buttons=OKCancelButtons,
        title='Savitzky-Golay Smoothing',
    )

    def _increase_window_length_fired(self):
        """Step the window length up to the next odd value.

        The change is applied only while the new length is smaller than
        the size of ``self.signal.axes_manager[2j]``.
        """
        step = 2 if self.window_length % 2 else 1
        nwl = self.window_length + step
        if nwl < self.signal.axes_manager[2j].size:
            self.window_length = nwl

    def _decrease_window_length_fired(self):
        """Step the window length down to the previous odd value.

        The change is refused, with a warning, when the new length would
        not be greater than the polynomial order.
        """
        step = 2 if self.window_length % 2 else 1
        nwl = self.window_length - step
        if nwl > self.polynomial_order:
            self.window_length = nwl
        else:
            _logger.warning(
                "The window length must be greater than the polynomial order")

    def _polynomial_order_changed(self, old, new):
        """Grow the window to the next odd value above ``new`` if needed."""
        if self.window_length <= new:
            self.window_length = new + 2 if new % 2 else new + 1
            _logger.warning(
                "Polynomial order must be < window length. "
                "Window length set to %i.", self.window_length)
        self.update_lines()

    def _window_length_changed(self, old, new):
        self.update_lines()

    def _differential_order_changed(self, old, new):
        """Raise the polynomial order so it is at least ``new``."""
        if new > self.polynomial_order:
            # Jump straight to ``new``: a single +1 step is not enough when
            # the differential order grows by more than one at once.
            self.polynomial_order = new
            _logger.warning(
                "Differential order must be <= polynomial order. "
                "Polynomial order set to %i.", self.polynomial_order)
        super(SmoothingSavitzkyGolay,
              self)._differential_order_changed(old, new)

    def _smooth_current_spectrum(self, differential_order):
        """Filter a copy of the current spectrum and return the result."""
        self.single_spectrum.data = self.signal().copy()
        self.single_spectrum.smooth_savitzky_golay(
            polynomial_order=self.polynomial_order,
            window_length=self.window_length,
            differential_order=differential_order)
        return self.single_spectrum.data

    def diff_model2plot(self, axes_manager=None):
        """Return the current spectrum filtered at ``differential_order``."""
        return self._smooth_current_spectrum(self.differential_order)

    def model2plot(self, axes_manager=None):
        """Return the current spectrum smoothed without differentiation."""
        return self._smooth_current_spectrum(0)

    def apply(self):
        """Filter the whole signal with the current settings and replot."""
        self.signal.smooth_savitzky_golay(
            polynomial_order=self.polynomial_order,
            window_length=self.window_length,
            differential_order=self.differential_order)
        self.signal._replot()
예제 #19
0
class LogFilePlotFitter(traits.HasTraits):
    """Fit the data shown in log file plots.

    The user picks one of the predefined models or compiles a custom
    function, edits the initial parameter values and runs the fit on the
    selected data set. Fit results can be copied back to the initial
    guesses and appended to a CSV file next to the log file.
    """

    # Mapped trait: ``model`` holds the selected name and ``model_`` the
    # matching Model object, see
    # http://docs.enthought.com/traits/traits_user_manual/custom.html#mapped-traits
    model = traits.Trait(
        "Gaussian", {
            "Linear": Model(fittingFunctions.linear),
            "Quadratic": Model(fittingFunctions.quadratic),
            "Gaussian": Model(fittingFunctions.gaussian),
            "lorentzian": Model(fittingFunctions.lorentzian),
            "parabola": Model(fittingFunctions.parabola),
            "exponential": Model(fittingFunctions.exponentialDecay),
            "sineWave": Model(fittingFunctions.sineWave),
            "sineWaveDecay1": Model(fittingFunctions.sineWaveDecay1),
            "sineWaveDecay2": Model(fittingFunctions.sineWaveDecay2),
            "sincSquared": Model(fittingFunctions.sincSquared),
            "sineSquared": Model(fittingFunctions.sineSquared),
            "sineSquaredDecay": Model(fittingFunctions.sineSquaredDecay),
            "custom": Model(custom)
        },
        desc="model selected for fitting the data")
    parametersList = traits.List(
        Parameter, desc="list of parameters for fitting in chosen model")

    customCode = traits.Code(
        "def custom(x, param1, param2):\n\treturn param1*param2*x",
        desc="python code for a custom fitting function")
    customCodeCompileButton = traits.Button(
        "compile",
        desc=
        "defines the above function and assigns it to the custom model for fitting"
    )
    fitButton = traits.Button(
        "fit",
        desc="runs fit on selected data set using selected parameters and model"
    )
    usePreviousFitButton = traits.Button(
        "use previous fit",
        desc="use the fitted values as the initial guess for the next fit")
    guessButton = traits.Button(
        "guess",
        desc=
        "guess initial values from data using _guess function in library. If not defined button is disabled"
    )
    saveFitButton = traits.Button(
        "save fit",
        desc="writes fit parameters values and tolerances to a file")
    cycleAndFitButton = traits.Button(
        "cycle fit",
        desc=
        "fits using current initial parameters, saves fit, copies calculated values to initial guess and moves to next dataset in ordered dict"
    )
    # Maps data set name --> (xdata, ydata) tuple of arrays, e.g.
    # {"myData": (array([1,2,3]), array([1,2,3]))}. The class-level mapping
    # is only a fallback; __init__ gives each instance its own.
    dataSets = collections.OrderedDict()
    dataSetNames = traits.List(traits.String)
    selectedDataSet = traits.Enum(values="dataSetNames")
    modelFitResult = None
    logFilePlotReference = None
    modelFitMessage = traits.String("not yet fitted")
    isFitted = traits.Bool(False)
    maxFitTime = traits.Float(
        10.0, desc="maximum time fitting can last before abort")
    statisticsButton = traits.Button("stats")
    statisticsString = traits.String("statistics not calculated")
    plotPoints = traits.Int(200, label="Number of plot points")

    predefinedModelGroup = traitsui.VGroup(
        traitsui.Item("model", show_label=False),
        traitsui.Item("object.model_.definitionString",
                      style="readonly",
                      show_label=False,
                      visible_when="model!='custom'"))
    customFunctionGroup = traitsui.VGroup(
        traitsui.Item("customCode", show_label=False),
        traitsui.Item("customCodeCompileButton", show_label=False),
        visible_when="model=='custom'")
    modelGroup = traitsui.VGroup(predefinedModelGroup,
                                 customFunctionGroup,
                                 show_border=True)
    dataAndFittingGroup = traitsui.VGroup(
        traitsui.HGroup(
            traitsui.Item("selectedDataSet", label="dataset"),
            traitsui.Item("fitButton", show_label=False),
            traitsui.Item("usePreviousFitButton", show_label=False),
            traitsui.Item("guessButton",
                          show_label=False,
                          enabled_when="model_.guessFunction is not None")),
        traitsui.HGroup(traitsui.Item("cycleAndFitButton", show_label=False),
                        traitsui.Item("saveFitButton", show_label=False),
                        traitsui.Item("statisticsButton", show_label=False)),
        traitsui.Item("plotPoints"),
        traitsui.Item("statisticsString", style="readonly"),
        traitsui.Item("modelFitMessage", style="readonly"),
        show_border=True)
    variablesGroup = traitsui.VGroup(
        traitsui.Item("parametersList",
                      editor=traitsui.ListEditor(style="custom"),
                      show_label=False,
                      resizable=True),
        show_border=True,
        label="parameters")

    traits_view = traitsui.View(traitsui.Group(modelGroup,
                                               dataAndFittingGroup,
                                               variablesGroup,
                                               layout="split"),
                                resizable=True)

    def __init__(self, **traitsDict):
        super(LogFilePlotFitter, self).__init__(**traitsDict)
        if "dataSets" not in traitsDict:
            # A per-instance mapping: with only the class attribute every
            # fitter would read and write the same data sets.
            self.dataSets = collections.OrderedDict()
        self._set_parametersList()

    def _set_parametersList(self):
        """Rebuild ``parametersList`` from the parameters of the current model."""
        self.parametersList = [
            Parameter(name=parameterName, parameter=parameterObject)
            for (parameterName,
                 parameterObject) in self.model_.parameters.items()
        ]

    def _model_changed(self):
        """Update the parameter list for the new model and guess start values."""
        self._set_parametersList()
        # only guesses if there is a valid guessing function
        self._guessButton_fired()

    def _customCodeCompileButton_fired(self):
        """Compile the user's code and use its ``custom`` function as model.

        The code is executed with this module's globals, so it can use the
        names imported here; the names it defines are collected in a
        separate dictionary from which ``custom`` is read.
        """
        # NOTE(security): this runs arbitrary code typed into the GUI.
        namespace = {}
        try:
            exec(self.customCode, globals(), namespace)
        except Exception:
            logger.exception("could not compile custom fitting function")
            return
        customFunction = namespace.get("custom")
        if customFunction is None:
            logger.error(
                "custom code does not define a function named 'custom'")
            return
        self.model_.__init__(customFunction)
        self._set_parametersList()

    def setFitData(self, name, xData, yData):
        """Store ``(xData, yData)`` under ``name`` in the data sets."""
        self.dataSets[name] = (xData, yData)

    def cleanValidNames(self, uniqueValidNames):
        """Remove every data set whose name is not in ``uniqueValidNames``."""
        # Iterate over a snapshot: entries are deleted inside the loop.
        for dataSetName in list(self.dataSets):
            if dataSetName not in uniqueValidNames:
                del self.dataSets[dataSetName]

    def setValidNames(self):
        """Set the list of valid choices for the selected data set."""
        self.dataSetNames = list(self.dataSets.keys())

    def getParameters(self):
        """Return the lmfit parameters object for the fit function."""
        return lmfit.Parameters({
            variable.name: variable.parameter
            for variable in self.parametersList
        })

    def _setCalculatedValues(self, modelFitResult):
        """Copy the fitted values from ``modelFitResult`` to the parameters."""
        parametersResult = modelFitResult.params
        for variable in self.parametersList:
            variable.calculatedValue = parametersResult[variable.name].value

    def _setCalculatedValuesErrors(self, modelFitResult):
        """Copy the ``stderr`` of each fitted parameter to ``stdevError``."""
        parametersResult = modelFitResult.params
        for variable in self.parametersList:
            variable.stdevError = parametersResult[variable.name].stderr

    def fit(self):
        """Fit the selected data set with the current model and parameters.

        The result is kept in ``modelFitResult``; the parameter list, the
        fit message and ``isFitted`` are updated and, when a plot reference
        is set, the fit is drawn.
        """
        params = self.getParameters()
        x, y = self.dataSets[self.selectedDataSet]
        # NOTE: getFitCallback(time.time()) could be passed as ``iter_cb``
        # to enforce maxFitTime; it is currently not used.
        self.modelFitResult = self.model_.model.fit(y, x=x, params=params)
        # update fitting parameters final values
        self._setCalculatedValues(self.modelFitResult)
        self._setCalculatedValuesErrors(self.modelFitResult)
        self.modelFitMessage = self.modelFitResult.message
        if not self.modelFitResult.success:
            logger.error("failed to fit in LogFilePlotFitter")
        self.isFitted = True
        if self.logFilePlotReference is not None:
            self.logFilePlotReference.plotFit()

    def getFitCallback(self, startTime):
        """Return a per-iteration callback that asks to stop a slow fit.

        The callback returns True once more than ``maxFitTime`` seconds
        have passed since ``startTime``.
        """
        def fitCallback(params, iter, resid, *args, **kws):
            """check the time and compare to start time """
            if time.time() - startTime > self.maxFitTime:
                return True

        return fitCallback

    def _fitButton_fired(self):
        self.fit()

    def _usePreviousFitButton_fired(self):
        """Use the fitted value of every parameter as its next initial guess."""
        for parameter in self.parametersList:
            parameter.initialValue = parameter.calculatedValue

    def _guessButton_fired(self):
        """Call the model's guess function and update the initial values.

        Returns without changes when the model has no guess function or
        when the selected data set is not available yet.
        """
        print("guess button clicked")
        if self.model_.guessFunction is None:
            print("attempted to guess initial values but no guess function is defined. returning without changing initial values")
            logger.error(
                "attempted to guess initial values but no guess function is defined. returning without changing initial values"
            )
            return
        if self.selectedDataSet not in self.dataSets:
            # e.g. the model was changed before any data set was loaded
            logger.warning(
                "cannot guess initial values: no data set selected")
            return
        logger.info("attempting to guess initial values using %s",
                    self.model_.guessFunction.__name__)
        xs, ys = self.dataSets[self.selectedDataSet]
        guessDictionary = self.model_.guessFunction(xs, ys)
        logger.debug("guess results = %s", guessDictionary)
        print("guess results = %s" % guessDictionary)
        for parameter in self.parametersList:
            if parameter.name in guessDictionary:
                parameter.initialValue = guessDictionary[parameter.name]

    def _saveFitButton_fired(self):
        """Append the fitted values and their errors to the fit CSV file.

        The file is ``<folder name>-<function name>-fitSave.csv`` in the
        folder of the log file. A header row is written when the file is
        created. The selected data set name (like "aaaa=1.31 bbbb=1.21")
        provides the leading column names and values.
        """
        if self.modelFitResult is None or self.logFilePlotReference is None:
            logger.error("cannot save fit: no fit result or no log file plot")
            return
        saveFolder, filename = os.path.split(self.logFilePlotReference.logFile)
        parametersResult = self.modelFitResult.params
        logFileName = os.path.split(saveFolder)[1]
        functionName = self.model_.function.__name__
        saveFileName = os.path.join(
            saveFolder, logFileName + "-" + functionName + "-fitSave.csv")

        # parse selected data set name to get column names and values
        seriesColumnNames = []
        seriesValues = []
        for seriesString in self.selectedDataSet.split(" "):
            parts = seriesString.split("=")
            seriesColumnNames.append(parts[0])
            # an entry without "=" gets an empty value instead of failing
            seriesValues.append(parts[1] if len(parts) > 1 else "")

        names = [variable.name for variable in self.parametersList]
        writeHeader = not os.path.exists(saveFileName)
        # NOTE(review): binary append mode is what the csv module expects
        # on Python 2; Python 3 would need text mode with newline="".
        with open(saveFileName, "ab+") as csvFile:
            writer = csv.writer(csvFile)
            if writeHeader:
                writer.writerow(seriesColumnNames + names +
                                [name + "-tolerance" for name in names])
            # the legend values tell which data set the fit belongs to
            writer.writerow(
                seriesValues +
                [parametersResult[name].value for name in names] +
                [parametersResult[name].stderr for name in names])

    def _cycleAndFitButton_fired(self):
        """Fit, save, reuse the result as guess and select the next data set.

        The selection is left unchanged after the last data set.
        """
        logger.info("cycle and fit button pressed")
        self._fitButton_fired()
        self._saveFitButton_fired()
        self._usePreviousFitButton_fired()
        names = list(self.dataSets.keys())
        nextDataSetIndex = names.index(self.selectedDataSet) + 1
        if nextDataSetIndex < len(names):
            self.selectedDataSet = names[nextDataSetIndex]
        else:
            logger.info("cycle and fit reached the last data set")

    def _statisticsButton_fired(self):
        """Compute summary statistics of the selected data set's y values."""
        from scipy.stats import pearsonr
        xs, ys = self.dataSets[self.selectedDataSet]
        # numpy functions: the scipy.mean/median/std/nanmin/nanmax aliases
        # are no longer part of the scipy namespace.
        mean = np.mean(ys)
        median = np.median(ys)
        std = np.std(ys)
        minimum = np.nanmin(ys)
        maximum = np.nanmax(ys)
        peakToPeak = maximum - minimum
        pearsonCorrelation = pearsonr(xs, ys)
        resultString = "mean=%G , median=%G stdev =%G\nmin=%G,max=%G, pk-pk=%G\nPearson Correlation=(%G,%G)\n(stdev/mean)=%G" % (
            mean, median, std, minimum, maximum, peakToPeak,
            pearsonCorrelation[0], pearsonCorrelation[1], std / mean)
        self.statisticsString = resultString

    def getFitData(self):
        """Return the fitted model sampled on ``plotPoints`` even x values.

        The x values span the range of the selected data set.
        """
        dataX = self.dataSets[self.selectedDataSet][0]
        # resample x data
        dataX = np.linspace(min(dataX), max(dataX), self.plotPoints)
        dataY = self.modelFitResult.eval(x=dataX)
        return dataX, dataY
예제 #20
0
class PhysicsProperties(traits.HasTraits):
    """Element, trap and imaging parameters used by the analysis.

    Besides the values edited in the GUI, ``updatePhysics`` can refresh the
    time of flight and the trap gradient from the variables stored in the
    latest sequence XML file.
    """

    # default of Li6 set in init
    selectedElement = traits.Instance(element.Element)
    species = traits.Enum(*element.names)

    massATU = traits.Float(
        22.9897692807,
        label="mass (u)",
        desc="mass in atomic mass units")
    decayRateMHz = traits.Float(
        9.7946,
        label=u"Decay Rate \u0393 (MHz)",
        desc="decay rate/ natural line width of 2S1/2 -> 2P3/2")
    crossSectionSigmaPlus = traits.Float(
        1.6573163925E-13,
        label=u"cross section \u03C3 + (m^2)",
        desc="resonant cross section 2S1/2 -> 2P3/2. Warning not accurate for 6Li yet")
    scatteringLength = traits.Float(62.0, label="scattering length (a0)")
    IsatSigmaPlus = traits.Float(
        6.260021,
        width=10,
        label=u"Isat (mW/cm^2)",
        desc="I sat sigma + 2S1/2 -> 2P3/2")

    TOFFromVariableBool = traits.Bool(
        True,
        label="Use TOF variable?",
        desc="Attempt to read TOF variable from latestSequence.xml. If found update the TOF variable automatically")
    TOFVariableName = traits.String(
        "ImagingTimeTOFLi",
        label="variable name:",
        desc="The name of the TOF variable in Experiment Control")
    timeOfFlightTimems = traits.Float(
        4.0,
        label="TOF Time (ms)",
        desc="Time of Flight time in ms. Should match experiment control")

    trapGradientXFromVariableBool = traits.Bool(
        True,
        label="Use MTGradientX variable?",
        desc="Attempt to read MTGradientX variable from latestSequence.xml. If found update the TOF variable automatically")
    trapGradientXVariableName = traits.String(
        "MagneticTrapEvaporation2GradientX",
        label="variable name:",
        desc="The name of the trapGradientX variable in Experiment Control")
    trapGradientX = traits.Float(
        20.0,
        label="Trap Gradient (small) (G/cm)",
        desc="gradient of trap before time of flight. Smaller of the anti helmholtz gradients")
    trapFrequencyXHz = traits.Float(
        100.0,
        label="Trap frequency X (Hz)",
        desc="trap frequency in X direction in Hz")
    trapFrequencyYHz = traits.Float(
        100.0,
        label="Trap frequency Y (Hz)",
        desc="trap frequency in Y direction in Hz")
    trapFrequencyZHz = traits.Float(
        100.0,
        label="Trap frequency Z (Hz)",
        desc="trap frequency in Z direction in Hz")
    imagingDetuningLinewidths = traits.Float(
        0.0,
        label=u"imaging detuning (\u0393)",
        desc="imaging light detuning from resonance in units of linewidths")

    inTrapSizeX = traits.Float(
        0.0,
        label="In trap Size X (pixels)",
        desc="size of cloud in trap in x direciton in pixels. Use very short TOF to estimate")
    inTrapSizeY = traits.Float(
        0.0,
        label="In trap Size Y (pixels)",
        desc="size of cloud in trap in y direciton in pixels. Use very short TOF to estimate")
    autoInTrapSizeBool = traits.Bool(
        False,
        label="Change TOFTime InTrap Calibration?",
        desc="Allows user to change the TOF time for which the fit will automatically update the in trap size if the autoSetSize box is checked for Gaussian fit")
    inTrapSizeTOFTimems = traits.Float(
        0.2,
        label="In Trap Size TOFTime",
        desc="If the TOF time is this value and the autoSetSize box is checked, then we will automatically update the size whenever the TOFTime equals this value")

    pixelSize = traits.Float(
        9.0,
        label=u"Pixel Size (\u03BCm)",
        desc="Actual pixel size of the camera (excluding magnification)")
    magnification = traits.Float(
        0.5,
        label="Magnification",
        desc="Magnification of the imaging system")
    binningSize = traits.Int(
        1,
        label=u"Binning Size (px)",
        desc="Binning size; influences effective pixel size")

    # XML files holding the variables of the latest and second latest
    # sequence.
    latestSequenceFile = os.path.join("\\\\ursa", "AQOGroupFolder", "Experiment Humphry", "Experiment Control And Software", "currentSequence", "latestSequence.xml")
    secondLatestSequenceFile = os.path.join("\\\\ursa", "AQOGroupFolder", "Experiment Humphry", "Experiment Control And Software", "currentSequence", "secondLatestSequence.xml")

    traits_view = traitsui.View(
        traitsui.HGroup(
            traitsui.VGroup(
                traitsui.Item("species"),
                traitsui.Item("selectedElement", editor=traitsui.InstanceEditor(), style="custom", show_label=False),
                show_border=True, label="Element Properties"
            ),
            traitsui.VGroup(
                traitsui.HGroup(
                    traitsui.Item("TOFFromVariableBool"),
                    traitsui.Item("TOFVariableName", visible_when="TOFFromVariableBool"),
                    traitsui.Item("timeOfFlightTimems", style="readonly", visible_when="(TOFFromVariableBool)"),
                    traitsui.Item("timeOfFlightTimems", visible_when="(not TOFFromVariableBool)")),
                traitsui.HGroup(
                    traitsui.Item("trapGradientXFromVariableBool"),
                    traitsui.Item("trapGradientXVariableName", visible_when="trapGradientXFromVariableBool"),
                    traitsui.Item("trapGradientX", style="readonly", visible_when="(trapGradientXFromVariableBool)"),
                    traitsui.Item("trapGradientX", visible_when="(not trapGradientXFromVariableBool)")),
                traitsui.Item("trapFrequencyXHz"),
                traitsui.Item("trapFrequencyYHz"),
                traitsui.Item("trapFrequencyZHz"),
                traitsui.Item("inTrapSizeX"),
                traitsui.Item("inTrapSizeY"),
                traitsui.HGroup(
                    traitsui.Item("autoInTrapSizeBool"),
                    traitsui.Item("inTrapSizeTOFTimems", visible_when="(autoInTrapSizeBool)")),
                label="Experiment Parameters", show_border=True
            ),
            traitsui.VGroup(
                traitsui.Item("imagingDetuningLinewidths"),
                traitsui.Item("pixelSize"),
                traitsui.Item("binningSize"),
                traitsui.Item("magnification"), label="Camera and Imaging", show_border=True
            )
        )
    )

    def __init__(self, **traitsDict):
        super(PhysicsProperties, self).__init__(**traitsDict)
        # Selecting the species fires _species_changed, which sets
        # selectedElement (Li6 by default).
        self.species = element.Li6.nameID

        # pull some useful values from the physical constants dictionary
        self.constants = scipy.constants.physical_constants
        self.u = self.constants["atomic mass constant"][0]
        self.bohrMagneton = self.constants["Bohr magneton"][0]
        self.bohrRadius = self.constants["Bohr radius"][0]
        self.kb = self.constants["Boltzmann constant"][0]
        self.joulesToKelvin = self.constants["joule-kelvin relationship"][0]
        self.hbar = self.constants["Planck constant over 2 pi"][0]
        self.h = self.constants["Planck constant"][0]
        self.joulesToHertz = self.constants["joule-hertz relationship"][0]
        self.hertzToKelvin = self.constants["hertz-kelvin relationship"][0]
        self.a0 = self.constants["Bohr radius"][0]

    def _species_changed(self):
        """Select the element object that matches the new species."""
        logger.debug("species changed to %s" % self.species)
        self.selectedElement = element.elements[self.species]

    def updatePhysics(self):
        """Refresh TOF time and trap gradient from the sequence XML file.

        The variables are read from ``latestSequenceFile``, or from
        ``secondLatestSequenceFile`` when the latest file looks newer than
        the image. Nothing is changed when the file is missing or cannot
        be read. For each value whose "from variable" flag is set, the
        trait is updated from the XML; an unknown variable name falls back
        to a default (1 ms, 50 G/cm).
        """
        try:
            logger.debug("attempting to update physics from xml")
            if not os.path.exists(self.latestSequenceFile):
                logger.error("Could not find latest xml File. cannot update physics.")
                return
            modifiedTime = os.path.getmtime(self.latestSequenceFile)
            # NOTE(review): selectedFileModifiedTime is not defined in this
            # class; presumably the owner sets it before calling.
            imageTime = self.selectedFileModifiedTime
            timeDiff = imageTime - modifiedTime
            timeDiff += 31  # ToDo: debug this strange time offset!!
            if timeDiff > 300.0:  # >5min
                logger.warning("Found a time difference of >5min between modification time of XML and of image. Are you sure the latest XML file is being updated? Check snake is running?")
            if timeDiff < 0:
                logger.error("Found very fresh sequence file. Probably read already variables of next sequence?")
                logger.warning("Use second last sequence file instead..")
                self.tree = ET.parse(self.secondLatestSequenceFile)
            else:
                self.tree = ET.parse(self.latestSequenceFile)
            # for debugging, remove or reduce log level later
            logger.warning("Age of sequence file: {}".format(timeDiff))
            self.root = self.tree.getroot()
            variables = self.root.find("variables")
            self.variables = {child[0].text: float(child[1].text) for child in variables}
            # .get(): a missing TOF variable must not abort here, it is
            # handled with a default further down.
            logger.debug("Read a TOF time of %s from variables in XML " % self.variables.get(self.TOFVariableName))
        except Exception as e:
            logger.error("Error when trying to load XML %s", e)
            return

        # update TOF Time
        if self.TOFFromVariableBool:
            logger.debug("attempting to update TOF time from xml")
            try:
                # the XML value is multiplied by 1e3 to get milliseconds
                self.timeOfFlightTimems = self.variables[self.TOFVariableName] * 1.0E3
            except KeyError:
                logger.error("incorrect variable name. No variable %s found. Using default 1ms" % self.TOFVariableName)
                self.timeOfFlightTimems = 1.0

        if self.trapGradientXFromVariableBool:
            logger.debug("attempting to update trapGradientX from xml")
            try:
                self.trapGradientX = self.variables[self.trapGradientXVariableName]
            except KeyError:
                logger.error("incorrect variable name. No variable %s found. Using default 50G/cm" % self.trapGradientXVariableName)
                self.trapGradientX = 50.0
예제 #21
0
class config(HasTraits):
    """Declarative configuration of a workflow.

    The class only declares traits; it defines no methods.  The traits are
    grouped by the section comments below: directories, execution, subjects,
    first-level parameters, preprocessing info and advanced options.
    """
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc="Workflow Description")
    # Directories
    working_dir = Directory(mandatory=True,
                            desc="Location of the Nipype working directory")
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    crash_dir = Directory(mandatory=False,
                          desc="Location to store crash files")
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")

    # Execution

    run_using_plugin = Bool(
        False,
        usedefault=True,
        desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS",
                         "MultiProc",
                         "SGE",
                         "Condor",
                         usedefault=True,
                         desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
                              usedefault=True,
                              desc='Plugin arguments.')
    # NOTE(review): the backslash continuation below keeps the indentation of
    # the second line inside the description string.
    test_mode = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='Affects whether where and if the workflow keeps its \
                            intermediary files. True to keep intermediary files. '
    )
    timeout = traits.Float(14.0)
    # Subjects

    # The explicit subject list is disabled here; a ``Data`` instance
    # (``datagrabber`` below) is declared instead.
    #subjects= traits.List(traits.Str, mandatory=True, usedefault=True,
    #    desc="Subject id's. Note: These MUST match the subject id's in the \
    #                            Freesurfer directory. For simplicity, the subject id's should \
    #                            also match with the location of individual functional files.")

    datagrabber = traits.Instance(Data, ())
    # First Level

    subjectinfo = traits.Code()
    contrasts = traits.Code()
    interscan_interval = traits.Float()
    film_threshold = traits.Float()
    input_units = traits.Enum('scans', 'secs')
    is_sparse = traits.Bool(False)
    model_hrf = traits.Bool(True)
    stimuli_as_impulses = traits.Bool(True)
    use_temporal_deriv = traits.Bool(True)
    volumes_in_cluster = traits.Int(1)
    ta = traits.Float()
    tr = traits.Float()
    hpcutoff = traits.Float()
    scan_onset = traits.Int(0)
    scale_regressors = traits.Bool(True)
    #bases = traits.Dict({'dgamma':{'derivs': False}},use_default=True)
    # NOTE(review): this trait passes ``use_default`` while every other trait
    # in the class passes ``usedefault`` -- confirm which spelling the
    # consumer of this metadata expects.
    bases = traits.Dict(
        {'dgamma': {
            'derivs': False
        }}, use_default=True
    )  #traits.Enum('dgamma','gamma','none'), traits.Enum(traits.Dict(traits.Enum('derivs',None), traits.Bool),None), desc="name of basis function and options e.g., {'dgamma': {'derivs': True}}")

    # preprocessing info
    preproc_config = traits.File(desc="preproc config file")
    use_compcor = traits.Bool(desc="use noise components from CompCor")
    #advanced_options
    use_advanced_options = Bool(False)
    advanced_options = traits.Code()
    save_script_only = traits.Bool(False)
예제 #22
0
class HCFF(tr.HasStrictTraits):
    '''High-Cycle Fatigue Filter.

    Traits model behind a small GUI: a CSV file is selected, every column is
    stored as a ``.npy`` file, the force column and the remaining columns are
    reduced to the force extrema of the cyclic part, and the stored arrays
    are plotted on ``figure``.
    '''

    #=========================================================================
    # Traits definitions
    #=========================================================================
    # --- CSV import settings (forwarded to ``pd.read_csv``) ---
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    skip_rows = tr.Int(4, auto_set=False, enter_set=True)
    # Sanitized column names read from the first row of the CSV file; they
    # are also the allowed values of ``x_axis`` and ``y_axis``.
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    # Factors applied to the loaded arrays before plotting.
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    # Folder and base name used to build the ``.npy`` file paths.
    npy_folder_path = tr.Str
    file_name = tr.Str
    # When set, ``_add_plot_fired`` loads the ``*_filtered.npy`` files.
    apply_filters = tr.Bool
    # Name of the column that holds the force signal.
    force_name = tr.Str('Kraft')
    # Force magnitude that separates the ascending branch from the cycles.
    peak_force_before_cycles = tr.Float(30)
    plots_num = tr.Enum(1, 2, 3, 4, 6, 9)
    plot_list = tr.List()
    # --- Buttons; each one is handled by the matching ``_<name>_fired`` ---
    plot = tr.Button
    add_plot = tr.Button
    add_creep_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_npy = tr.Button
    # NOTE(review): the handler for this button is commented out further down.
    add_columns_average = tr.Button
    # Bounds handed to ``cut_indices_in_range`` when exporting creep data.
    force_max = tr.Float(100)
    force_min = tr.Float(40)

    figure = tr.Instance(Figure)

#     plots_list = tr.List(editor=ui.SetEditor(
#         values=['kumquats', 'pomegranates', 'kiwi'],
#         can_move_all=False,
#         left_column_title='List'))

    #=========================================================================
    # File management
    #=========================================================================

    def _open_file_csv_fired(self):
        """Let the user pick a CSV file and prepare the import state.

        Triggered by the ``open_file_csv`` button.  Once a file is selected
        the method

        * stores its path in ``file_csv``;
        * reads the first row of the file and stores the sanitized column
          names in ``columns_headers_list`` (the value source of ``x_axis``
          and ``y_axis``);
        * creates an ``NPY`` folder next to the CSV file if it is missing and
          stores its path in ``npy_folder_path``;
        * stores the CSV base name without extension in ``file_name``.

        If the dialog yields no path, nothing is changed.
        """
        extns = ['*.csv', ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open', wildcard=wildcard,
                            default_path=self.file_csv)
        dialog.open()
        if not dialog.path:
            # No file selected (e.g. the dialog was dismissed): keep the
            # current state instead of handing an empty path to pandas.
            return
        self.file_csv = dialog.path

        # Fill x_axis and y_axis with values: the first row holds the headers.
        first_row = pd.read_csv(
            self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
            nrows=1, header=None
        ).iloc[0]
        # str() keeps numeric or empty header cells from breaking the
        # character filter, which iterates over its argument.
        self.columns_headers_list = [
            self.get_valid_file_name(str(header)) for header in first_row
        ]

        # Save file name and path and create the NPY folder.
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        if not os.path.exists(self.npy_folder_path):
            os.makedirs(self.npy_folder_path)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    #=========================================================================
    # Parameters of the filter algorithm
    #=========================================================================

    def _figure_default(self):
        """Create the default value of the ``figure`` trait.

        Returns a ``Figure`` with a white face colour and tight layout
        switched on.
        """
        default_figure = Figure(facecolor='white')
        default_figure.set_tight_layout(True)
        return default_figure

    def _parse_csv_to_npy_fired(self):
        print('Parsing csv into npy files...')

        for i in range(len(self.columns_headers_list)):
            column_array = np.array(pd.read_csv(
                self.file_csv, delimiter=self.delimiter, decimal=self.decimal, skiprows=self.skip_rows, usecols=[i]))
            np.save(os.path.join(self.npy_folder_path, self.file_name +
                                 '_' + self.columns_headers_list[i] + '.npy'), column_array)

        print('Finsihed parsing csv into npy files.')

    def get_valid_file_name(self, original_file_name):
        """Return ``original_file_name`` reduced to file-name-safe characters.

        ASCII letters, digits, the space and the characters ``-_.()`` are
        kept; every other character is dropped.
        """
        allowed = "-_.() " + string.ascii_letters + string.digits
        kept = [char for char in original_file_name if char in allowed]
        return ''.join(kept)

#     def _add_columns_average_fired(self):
#         columns_average = ColumnsAverage(
#             columns_names=self.columns_headers_list)
#         # columns_average.set_columns_headers_list(self.columns_headers_list)
#         columns_average.configure_traits()

    def _generate_filtered_npy_fired(self):

        # 1- Export filtered force
        force = np.load(os.path.join(self.npy_folder_path,
                                     self.file_name + '_' + self.force_name + '.npy')).flatten()
        peak_force_before_cycles_index = np.where(
            abs((force)) > abs(self.peak_force_before_cycles))[0][0]
        force_ascending = force[0:peak_force_before_cycles_index]
        force_rest = force[peak_force_before_cycles_index:]

        force_max_indices, force_min_indices = self.get_array_max_and_min_indices(
            force_rest)

        force_max_min_indices = np.concatenate(
            (force_min_indices, force_max_indices))
        force_max_min_indices.sort()

        force_rest_filtered = force_rest[force_max_min_indices]
        force_filtered = np.concatenate((force_ascending, force_rest_filtered))
        np.save(os.path.join(self.npy_folder_path, self.file_name +
                             '_' + self.force_name + '_filtered.npy'), force_filtered)

        # 2- Export filtered displacements
        # TODO I skipped time with presuming it's the first column
        for i in range(1, len(self.columns_headers_list)):
            if self.columns_headers_list[i] != str(self.force_name):

                disp = np.load(os.path.join(self.npy_folder_path, self.file_name +
                                            '_' + self.columns_headers_list[i] + '.npy')).flatten()
                disp_ascending = disp[0:peak_force_before_cycles_index]
                disp_rest = disp[peak_force_before_cycles_index:]
                disp_ascending = savgol_filter(
                    disp_ascending, window_length=51, polyorder=2)
                disp_rest = disp_rest[force_max_min_indices]
                filtered_disp = np.concatenate((disp_ascending, disp_rest))
                np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                     self.columns_headers_list[i] + '_filtered.npy'), filtered_disp)

        # 3- Export creep for displacements
        # Cutting unwanted max min values to get correct full cycles and remove
        # false min/max values caused by noise
        force_max_indices_cutted, force_min_indices_cutted = self.cut_indices_in_range(force_rest,
                                                                                       force_max_indices,
                                                                                       force_min_indices,
                                                                                       self.force_max,
                                                                                       self.force_min)

        print("Cycles number= ", len(force_min_indices))
        print("Cycles number after cutting unwanted max-min range= ",
              len(force_min_indices_cutted))

        # TODO I skipped time with presuming it's the first column
        for i in range(1, len(self.columns_headers_list)):
            if self.columns_headers_list[i] != str(self.force_name):
                disp_rest_maxima = disp_rest[force_max_indices_cutted]
                disp_rest_minima = disp_rest[force_min_indices_cutted]
                np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                     self.columns_headers_list[i] + '_max.npy'), disp_rest_maxima)
                np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                     self.columns_headers_list[i] + '_min.npy'), disp_rest_minima)

        print('Filtered npy files are generated.')

    def cut_indices_in_range(self, array, max_indices, min_indices, range_upper_value, range_lower_value):
        """Select the extrema indices whose values lie outside a band.

        A maximum index is kept when the magnitude of ``array`` at that index
        is greater than ``abs(range_upper_value)``; a minimum index is kept
        when the magnitude is smaller than ``abs(range_lower_value)``.

        Returns:
            Tuple ``(kept_max_indices, kept_min_indices)`` of two lists, each
            in the order of the corresponding input.
        """
        kept_max = [index for index in max_indices
                    if abs(array[index]) > abs(range_upper_value)]
        kept_min = [index for index in min_indices
                    if abs(array[index]) < abs(range_lower_value)]
        return kept_max, kept_min

    def get_array_max_and_min_indices(self, input_array):

        # Checking dominant sign
        positive_values_count = np.sum(np.array(input_array) >= 0)
        negative_values_count = input_array.size - positive_values_count

        # Getting max and min indices
        if (positive_values_count > negative_values_count):
            force_max_indices = argrelextrema(input_array, np.greater_equal)[0]
            force_min_indices = argrelextrema(input_array, np.less_equal)[0]
        else:
            force_max_indices = argrelextrema(input_array, np.less_equal)[0]
            force_min_indices = argrelextrema(input_array, np.greater_equal)[0]

        # Remove subsequent max/min indices (np.greater_equal will give 1,2 for
        # [4, 8, 8, 1])
        force_max_indices = self.remove_subsequent_max_values(
            force_max_indices)
        force_min_indices = self.remove_subsequent_min_values(
            force_min_indices)

        # If size is not equal remove the last element from the big one
        if force_max_indices.size > force_min_indices.size:
            force_max_indices = force_max_indices[:-1]
        elif force_max_indices.size < force_min_indices.size:
            force_min_indices = force_min_indices[:-1]

        return force_max_indices, force_min_indices

    def remove_subsequent_max_values(self, force_max_indices):
        """Thin out runs of directly neighbouring maxima indices.

        Every index whose successor in the array is exactly one larger is
        removed, so only the last index of each run of consecutive indices
        remains.  Returns a new array.
        """
        gaps = np.diff(force_max_indices)
        redundant_positions = np.flatnonzero(gaps == 1)
        return np.delete(force_max_indices, redundant_positions)

    def remove_subsequent_min_values(self, force_min_indices):
        """Thin out runs of directly neighbouring minima indices.

        Every index whose successor in the array is exactly one larger is
        removed, so only the last index of each run of consecutive indices
        remains.  Returns a new array.
        """
        steps = force_min_indices[1:] - force_min_indices[:-1]
        superseded_positions = np.where(steps == 1)[0]
        return np.delete(force_min_indices, superseded_positions)

    #=========================================================================
    # Plotting
    #=========================================================================
    # Counter that ``x_plot_fired`` increments on every call.
    plot_figure_num = tr.Int(0)

    def _plot_fired(self):
        ax = self.figure.add_subplot()

    def x_plot_fired(self):
        """Advance ``plot_figure_num`` and show the pyplot figures.

        NOTE(review): the name lacks the leading underscore of the other
        ``_<name>_fired`` handlers -- presumably a disabled variant of
        ``_plot_fired``; confirm before relying on it.
        """
        self.plot_figure_num = self.plot_figure_num + 1
        plt.draw()
        plt.show()

    # Event set by ``_add_plot_fired`` after a curve has been drawn.
    data_changed = tr.Event

    def _add_plot_fired(self):
        """Plot the selected y column over the selected x column.

        Triggered by the ``add_plot`` button.  The two arrays are loaded from
        ``<npy_folder_path>/<file_name>_<axis>.npy`` -- or from the
        ``..._filtered.npy`` variants when ``apply_filters`` is set --
        multiplied by their axis multipliers and drawn on subplot (1, 1, 1)
        of ``figure``.  A ``'<x name>, <y name>'`` entry is appended to
        ``plot_list`` and ``data_changed`` is fired.
        """
        # The check against the maximum number of plots (``plots_num``) is
        # switched off for this handler; ``_add_creep_plot_fired`` still
        # performs it.

        print('Loading npy files...')

        suffix = '_filtered' if self.apply_filters else ''
        x_axis_name = self.x_axis + suffix
        y_axis_name = self.y_axis + suffix

        x_path = os.path.join(self.npy_folder_path,
                              self.file_name + '_' + x_axis_name + '.npy')
        x_axis_array = self.x_axis_multiplier * np.load(x_path)
        y_path = os.path.join(self.npy_folder_path,
                              self.file_name + '_' + y_axis_name + '.npy')
        y_axis_array = self.y_axis_multiplier * np.load(y_path)

        print('Adding Plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        axes = self.figure.add_subplot(1, 1, 1)
        axes.set_xlabel('Displacement [mm]')
        axes.set_ylabel('kN')
        axes.set_title('Original data', fontsize=20)
        axes.plot(x_axis_array, y_axis_array, 'k', linewidth=0.8)

        self.plot_list.append('{}, {}'.format(x_axis_name, y_axis_name))
        self.data_changed = True
        print('Finished adding plot!')

    def apply_new_subplot(self):
        """Add the next subplot to ``figure`` according to ``plots_num``.

        For a single plot the subplot (1, 1, 1) is used.  For 2, 3, 4, 6 or 9
        plots the grid is 1x2, 1x3, 2x2, 2x3 or 3x3 and the position is the
        number of entries in ``plot_list`` plus one, passed to
        ``add_subplot`` as one combined integer.  Any other ``plots_num``
        value adds nothing.
        """
        figure = self.figure
        if self.plots_num == 1:
            figure.add_subplot(1, 1, 1)
            return

        # Maps the number of plots to the "<rows><columns>" grid prefix.
        grid_prefixes = {2: '12', 3: '13', 4: '22', 6: '23', 9: '33'}
        grid_prefix = grid_prefixes.get(self.plots_num)
        if grid_prefix is None:
            return

        next_position = len(self.plot_list) + 1
        figure.add_subplot(int(grid_prefix + str(next_position)))

    def _add_creep_plot_fired(self):
        """Plot the fatigue creep curves of the column selected as ``x_axis``.

        Triggered by the ``add_creep_plot`` button.  If ``plot_list`` already
        holds ``plots_num`` entries a message dialog is shown and nothing is
        plotted.  Otherwise the ``..._max.npy`` and ``..._min.npy`` files of
        the ``x_axis`` column are loaded, multiplied by ``x_axis_multiplier``
        and drawn over the cycle number (red: maxima, green: minima) on a new
        subplot.  A ``'Plot <n>'`` entry is appended to ``plot_list`` and
        ``data_changed`` is fired.
        """
        if len(self.plot_list) >= self.plots_num:
            dialog = MessageDialog(
                title='Attention!',
                message='Max plots number is {}'.format(self.plots_num))
            dialog.open()
            return

        disp_max = self.x_axis_multiplier * \
            np.load(os.path.join(self.npy_folder_path,
                                 self.file_name + '_' + self.x_axis + '_max.npy'))
        disp_min = self.x_axis_multiplier * \
            np.load(os.path.join(self.npy_folder_path,
                                 self.file_name + '_' + self.x_axis + '_min.npy'))

        print('Adding creep plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        self.apply_new_subplot()
        # Labels, title and curves belong to an axes object, not to the
        # figure; the subplot added above is the figure's current axes.
        axes = self.figure.gca()
        axes.set_xlabel('Cycles number')
        axes.set_ylabel('mm')
        axes.set_title('Fatigue creep curve', fontsize=20)
        # The colour is given by keyword only; combining it with a colour
        # format string would define the colour twice.
        axes.plot(np.arange(0, disp_max.size), disp_max,
                  linewidth=0.8, color='red')
        axes.plot(np.arange(0, disp_min.size), disp_min,
                  linewidth=0.8, color='green')

        self.plot_list.append('Plot {}'.format(len(self.plot_list) + 1))
        # Fired for consistency with ``_add_plot_fired``; presumably this is
        # what makes the figure editor redraw -- confirm against the editor.
        self.data_changed = True

        print('Finished adding creep plot!')

    #=========================================================================
    # Configuration of the view
    #=========================================================================

    # Layout of the window: a horizontal split with three parts --
    #   1. input file, CSV parsing parameters and plotting settings,
    #   2. the filter parameters with the button that generates the filtered
    #      files,
    #   3. the figure, shown through ``MPLFigureEditor``.
    traits_view = ui.View(
        ui.HSplit(
            ui.VSplit(
                ui.HGroup(
                    ui.UItem('open_file_csv'),
                    ui.UItem('file_csv', style='readonly'),
                    label='Input data'
                ),
                ui.Item('add_columns_average', show_label=False),
                ui.VGroup(
                    ui.Item('skip_rows'),
                    ui.Item('decimal'),
                    ui.Item('delimiter'),
                    ui.Item('parse_csv_to_npy', show_label=False),
                    label='Filter parameters'
                ),
                ui.VGroup(
                    ui.Item('plots_num'),
                    ui.HGroup(ui.Item('x_axis'), ui.Item('x_axis_multiplier')),
                    ui.HGroup(ui.Item('y_axis'), ui.Item('y_axis_multiplier')),
                    ui.HGroup(ui.Item('add_plot', show_label=False),
                              ui.Item('apply_filters')),
                    ui.HGroup(ui.Item('add_creep_plot', show_label=False)),
                    ui.Item('plot_list'),
                    ui.Item('plot', show_label=False),
                    show_border=True,
                    label='Plotting settings'),
            ),
            ui.VGroup(
                ui.Item('force_name'),
                ui.HGroup(ui.Item('peak_force_before_cycles'),
                          show_border=True, label='Skip noise of ascending branch:'),
                #                     ui.Item('plots_list'),
                ui.VGroup(ui.Item('force_max'),
                          ui.Item('force_min'),
                          show_border=True,
                          label='Cut fake cycles for creep:'),
                ui.Item('generate_filtered_npy', show_label=False),
                show_border=True,
                label='Filters'
            ),
            ui.UItem('figure', editor=MPLFigureEditor(),
                     resizable=True,
                     springy=True,
                     width=0.3,
                     label='2d plots'),
        ),
        title='HCFF Filter',
        resizable=True,
        width=0.6,
        height=0.6

    )
class config(HasTraits):
    """Declarative configuration of a workflow that processes functional data.

    The traits are grouped by the section comments below: directories,
    execution, subjects and functional-data options, artifact detection and
    smoothing.  The ``check_func_datagrabber`` button reports whether the
    functional file of each subject exists.
    """
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc='Workflow description')
    # Directories
    working_dir = Directory(mandatory=True,
                            desc="Location of the Nipype working directory")
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    # NOTE(review): the backslash continuations in the descriptions of this
    # class keep the indentation of the following line inside the string.
    field_dir = Directory(
        desc="Base directory of field-map data (Should be subject-independent) \
                                                 Set this value to None if you don't want fieldmap distortion correction"
    )
    crash_dir = Directory(mandatory=False,
                          desc="Location to store crash files")
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")

    # Execution

    run_using_plugin = Bool(
        False,
        usedefault=True,
        desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS",
                         "PBSGraph",
                         "MultiProc",
                         "SGE",
                         "Condor",
                         usedefault=True,
                         desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
                              usedefault=True,
                              desc='Plugin arguments.')
    test_mode = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='Affects whether where and if the workflow keeps its \
                            intermediary files. True to keep intermediary files. '
    )
    # Subjects

    subjects = traits.List(
        traits.Str,
        mandatory=True,
        usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the \
                                Freesurfer directory. For simplicity, the subject id's should \
                                also match with the location of individual functional files."
    )
    # Path template relative to ``base_dir``; ``%s`` is filled with the
    # subject id (see ``_check_func_datagrabber_fired``).
    func_template = traits.String('%s/functional.nii.gz')
    run_datagrabber_without_submitting = traits.Bool(
        desc="Run the datagrabber without \
    submitting to the cluster")
    timepoints_to_remove = traits.Int(0, usedefault=True)

    do_slicetiming = Bool(True,
                          usedefault=True,
                          desc="Perform slice timing correction")
    SliceOrder = traits.List(traits.Int)
    # NOTE(review): ``use_default`` here versus ``usedefault`` on the other
    # traits -- confirm which spelling the consumer of this metadata expects.
    order = traits.Enum('motion_slicetime',
                        'slicetime_motion',
                        use_default=True)
    TR = traits.Float(mandatory=True, desc="TR of functional")
    motion_correct_node = traits.Enum(
        'nipy',
        'fsl',
        'spm',
        'afni',
        desc="motion correction algorithm to use",
        usedefault=True,
    )

    csf_prob = traits.File(desc='CSF_prob_map')
    grey_prob = traits.File(desc='grey_prob_map')
    white_prob = traits.File(desc='white_prob_map')
    # Artifact Detection

    norm_thresh = traits.Float(1,
                               min=0,
                               usedefault=True,
                               desc="norm thresh for art")
    z_thresh = traits.Float(3, min=0, usedefault=True, desc="z thresh for art")

    # Smoothing
    fwhm = traits.Float(6.0, usedefault=True)
    save_script_only = traits.Bool(False)
    # Button handled by ``_check_func_datagrabber_fired``.
    check_func_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        subs = self.subjects

        for s in subs:
            if not os.path.exists(
                    os.path.join(self.base_dir, self.func_template % s)):
                print "ERROR", os.path.join(self.base_dir, self.func_template %
                                            s), "does NOT exist!"
                break
            else:
                print os.path.join(self.base_dir,
                                   self.func_template % s), "exists!"
예제 #24
0
class FileImportManager(tr.HasTraits):
    """Model for choosing a CSV file and converting its columns to ``.npy``.

    NOTE(review): ``npy_folder_path`` and ``file_name`` are assigned in
    ``_open_file_csv_fired`` and read in ``_parse_csv_to_npy_fired`` but are
    not declared as traits in this class.
    """
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    # CSV parsing settings forwarded to ``pd.read_csv``.
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    skip_rows = tr.Int(4, auto_set=False, enter_set=True)
    # Sanitized column names read from the first row of the CSV file.
    columns_headers_list = tr.List([])

    parse_csv_to_npy = tr.Button

    # File chooser row, followed by the parsing settings and the parse button.
    view = ui.View(ui.VGroup(
        ui.HGroup(
            ui.UItem('open_file_csv'),
            ui.UItem('file_csv', style='readonly'),
        ),
        ui.Item('skip_rows'),
        ui.Item('decimal'),
        ui.Item('delimiter'),
        ui.Item('parse_csv_to_npy', show_label=False),
    ))

    def _open_file_csv_fired(self):
        """Let the user pick a CSV file and prepare the import state.

        Triggered by the ``open_file_csv`` button.  Once a file is selected
        the method stores its path in ``file_csv``, fills
        ``columns_headers_list`` with the sanitized entries of the first row
        of the file, creates an ``NPY`` folder next to the file if it is
        missing (stored in ``npy_folder_path``) and stores the base name
        without extension in ``file_name``.

        If the dialog yields no path, nothing is changed.
        """
        extns = ['*.csv', ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open', wildcard=wildcard,
                            default_path=self.file_csv)
        dialog.open()
        if not dialog.path:
            # No file selected (e.g. the dialog was dismissed): keep the
            # current state instead of handing an empty path to pandas.
            return
        self.file_csv = dialog.path

        # Fill columns_headers_list: the first row holds the headers.
        first_row = pd.read_csv(
            self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
            nrows=1, header=None
        ).iloc[0]
        # str() keeps numeric or empty header cells from breaking the
        # character filter, which iterates over its argument.
        self.columns_headers_list = [
            self.get_valid_file_name(str(header)) for header in first_row
        ]

        # Save file name and path and create the NPY folder.
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        if not os.path.exists(self.npy_folder_path):
            os.makedirs(self.npy_folder_path)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    def get_valid_file_name(self, original_file_name):
        """Return ``original_file_name`` without file-name-unsafe characters.

        Only ASCII letters, digits, the space and the characters ``-_.()``
        survive; all other characters are removed.
        """
        permitted = "-_.() " + string.ascii_letters + string.digits
        surviving = []
        for character in original_file_name:
            if character in permitted:
                surviving.append(character)
        return ''.join(surviving)

    def _parse_csv_to_npy_fired(self):
        print('Parsing csv into npy files...')

        for i in range(len(self.columns_headers_list)):
            column_array = np.array(pd.read_csv(
                self.file_csv, delimiter=self.delimiter, decimal=self.decimal, skiprows=self.skip_rows, usecols=[i]))
            np.save(os.path.join(self.npy_folder_path, self.file_name +
                                 '_' + self.columns_headers_list[i] + '.npy'), column_array)

        print('Finsihed parsing csv into npy files.')
예제 #25
0
파일: axes.py 프로젝트: malliwi88/hyperspy
class DataAxis(t.HasTraits):
    """One axis of a dataset, described by ``size``, ``scale`` and ``offset``.

    The axis values are stored in ``axis``.  ``index`` and ``value`` describe
    the current position and are kept in step with each other by the
    ``_index_changed`` and ``_value_changed`` handlers.
    """
    name = t.Str()
    units = t.Str()
    scale = t.Float()
    offset = t.Float()
    size = t.CInt()
    low_value = t.Float()
    high_value = t.Float()
    # ``value`` is bounded by the ``low_value`` / ``high_value`` traits.
    value = t.Range('low_value', 'high_value')
    low_index = t.Int(0)
    high_index = t.Int()
    # NOTE: this class attribute shadows the builtin ``slice`` inside the
    # class body only; methods still see the builtin.
    slice = t.Instance(slice)
    navigate = t.Bool(t.Undefined)
    # ``index`` is bounded by the ``low_index`` / ``high_index`` traits.
    index = t.Range('low_index', 'high_index')
    axis = t.Array()
    continuous_value = t.Bool(False)

    def __init__(self,
                 size,
                 index_in_array=None,
                 name=t.Undefined,
                 scale=1.,
                 offset=0.,
                 units=t.Undefined,
                 navigate=t.Undefined):
        """Create the axis, its events and its change listeners.

        Parameters
        ----------
        size : int
            Number of points on the axis.
        index_in_array : optional
            Accepted but not used by this method.
        name, units : str or t.Undefined
            Label and units of the axis.
        scale, offset : float
            Passed to ``generate_axis`` (via ``update_axis``) to build the
            ``axis`` array.
        navigate : bool or t.Undefined
            Stored in ``navigate``; also decides the initial ``slice`` (see
            ``_update_slice``).
        """
        super(DataAxis, self).__init__()
        self.events = Events()
        self.events.index_changed = Event("""
            Event that triggers when the index of the `DataAxis` changes

            Triggers after the internal state of the `DataAxis` has been
            updated.

            Arguments:
            ---------
            obj : The DataAxis that the event belongs to.
            index : The new index
            """,
                                          arguments=["obj", 'index'])
        self.events.value_changed = Event("""
            Event that triggers when the value of the `DataAxis` changes

            Triggers after the internal state of the `DataAxis` has been
            updated.

            Arguments:
            ---------
            obj : The DataAxis that the event belongs to.
            value : The new value
            """,
                                          arguments=["obj", 'value'])
        # Re-entrancy guards used by _index_changed / _value_changed.
        self._suppress_value_changed_trigger = False
        self._suppress_update_value = False
        self.name = name
        self.units = units
        self.scale = scale
        self.offset = offset
        self.size = size
        # The index bounds are set before ``index`` itself is assigned.
        self.high_index = self.size - 1
        self.low_index = 0
        self.index = 0
        self.update_axis()
        self.navigate = navigate
        self.axes_manager = None
        # Listeners are registered only now, after the initial assignments
        # above have been made.
        self.on_trait_change(self.update_axis, ['scale', 'offset', 'size'])
        self.on_trait_change(self._update_slice, 'navigate')
        self.on_trait_change(self.update_index_bounds, 'size')
        # The slice must be updated even if the default value did not
        # change to correctly set its value.
        self._update_slice(self.navigate)

    def _index_changed(self, name, old, new):
        """Trait handler: announce the new index and sync ``value`` to it.

        The ``index_changed`` event always fires.  Unless
        ``_suppress_update_value`` is set, ``value`` is then moved to the
        axis entry at the current index, but only if it differs.
        """
        self.events.index_changed.trigger(obj=self, index=self.index)
        if self._suppress_update_value:
            return
        grid_value = self.axis[self.index]
        if grid_value != self.value:
            self.value = grid_value

    def _value_changed(self, name, old, new):
        """Trait handler: keep ``index`` consistent with a new ``value``.

        With ``continuous_value`` False only values stored in ``axis`` are
        kept: a value that maps to a different index moves ``index``, and a
        value that maps to the current index is replaced by the grid value.
        The ``value_changed`` event is triggered only when the new value
        equals the grid value.  With ``continuous_value`` True the event is
        always triggered and ``index`` follows without touching ``value``.
        """
        old_index = self.index
        new_index = self.value2index(new)
        if self.continuous_value is False:  # Only values in the grid allowed
            if old_index != new_index:
                # Assigning ``index`` runs ``_index_changed``, which may set
                # ``value`` again to the grid value.
                self.index = new_index
                if new == self.axis[self.index]:
                    self.events.value_changed.trigger(obj=self, value=new)
            elif old_index == new_index:
                new_value = self.index2value(new_index)
                if new_value == old:
                    # Restore the previous grid value without triggering the
                    # event from the nested call to this handler.
                    self._suppress_value_changed_trigger = True
                    try:
                        self.value = new_value
                    finally:
                        self._suppress_value_changed_trigger = False

                elif new_value == new and not\
                        self._suppress_value_changed_trigger:
                    self.events.value_changed.trigger(obj=self, value=new)
        else:  # Intergrid values are allowed. This feature is deprecated
            self.events.value_changed.trigger(obj=self, value=new)
            if old_index != new_index:
                # Move the index without letting ``_index_changed`` overwrite
                # the off-grid value.
                self._suppress_update_value = True
                self.index = new_index
                self._suppress_update_value = False

    @property
    def index_in_array(self):
        """Position of this axis in ``axes_manager._axes``.

        Raises
        ------
        AttributeError
            If the axis is not attached to an ``AxesManager``.
        """
        manager = self.axes_manager
        if manager is None:
            raise AttributeError(
                "This DataAxis does not belong to an AxesManager"
                " and therefore its index_in_array attribute "
                " is not defined")
        return manager._axes.index(self)

    @property
    def index_in_axes_manager(self):
        """Position of this axis in the manager's natural axes order.

        The order is the one returned by
        ``axes_manager._get_axes_in_natural_order()``.

        Raises
        ------
        AttributeError
            If the axis is not attached to an ``AxesManager``.
        """
        if self.axes_manager is None:
            # The message names this property (it used to name
            # ``index_in_array`` by mistake).
            raise AttributeError(
                "This DataAxis does not belong to an AxesManager"
                " and therefore its index_in_axes_manager attribute"
                " is not defined")
        return self.axes_manager._get_axes_in_natural_order().index(self)

    def _get_positive_index(self, index):
        """Convert a possibly negative index into a non-negative one.

        Negative indices count from the end of the axis (``size + index``).
        Non-negative indices are returned unchanged; no upper bound is
        checked.

        Raises
        ------
        IndexError
            If a negative index reaches before the start of the axis.
        """
        if not index < 0:
            return index
        wrapped = self.size + index
        if wrapped < 0:
            raise IndexError("index out of bounds")
        return wrapped

    def _get_index(self, value):
        """Return an index for ``value``.

        Values that ``isfloat`` accepts are converted with ``value2index``;
        anything else is returned unchanged.
        """
        return self.value2index(value) if isfloat(value) else value

    def _get_array_slices(self, slice_):
        """Returns a slice to slice the corresponding data axis without
        changing the offset and scale of the DataAxis.

        Float start/stop entries are interpreted as axis values and converted
        to indices; a float step is divided by ``scale`` and rounded.

        Parameters
        ----------
        slice_ : {float, int, slice}

        Returns
        -------
        my_slice : slice

        Raises
        ------
        IndexError
            If a float start lies above ``high_value`` or a float stop lies
            below ``low_value``.
        ValueError
            If the resulting step is zero.

        """
        v2i = self.value2index

        if isinstance(slice_, slice):
            start = slice_.start
            stop = slice_.stop
            step = slice_.step
        else:
            # A single number selects exactly one element.
            if isfloat(slice_):
                start = v2i(slice_)
            else:
                start = self._get_positive_index(slice_)
            stop = start + 1
            step = None

        if isfloat(step):
            # Step given in axis units: convert to a number of indices.
            step = int(round(step / self.scale))
        if isfloat(start):
            try:
                start = v2i(start)
            except ValueError:
                if start > self.high_value:
                    # The start value is above the axis limit
                    raise IndexError(
                        "Start value above axis high bound for  axis %s."
                        "value: %f high_bound: %f" %
                        (repr(self), start, self.high_value))
                else:
                    # The start value is below the axis limit,
                    # we slice from the start.
                    start = None
        if isfloat(stop):
            try:
                stop = v2i(stop)
            except ValueError:
                if stop < self.low_value:
                    # The stop value is below the axis limits
                    raise IndexError(
                        "Stop value below axis low bound for  axis %s."
                        "value: %f low_bound: %f" %
                        (repr(self), stop, self.low_value))
                else:
                    # The stop value is above the axis limit,
                    # we slice until the end.
                    stop = None

        if step == 0:
            raise ValueError("slice step cannot be zero")

        return slice(start, stop, step)

    def _slice_me(self, slice_):
        """Slice the axis in place and return the matching array slice.

        ``offset`` becomes the axis value at the first selected index and
        ``scale`` is multiplied by the slice step, if there is one.

        Parameters
        ----------
        slice_ : {float, int, slice}

        Returns
        -------
        my_slice : slice
            The slice produced by ``_get_array_slices``.

        """
        array_slice = self._get_array_slices(slice_)
        first, stride = array_slice.start, array_slice.step

        if first is None:
            # An open start means the first element in the direction of
            # travel: index 0 going forwards, the last index going backwards.
            first = 0 if (stride is None or stride > 0) else self.size - 1

        self.offset = self.index2value(first)
        if stride is not None:
            self.scale *= stride

        return array_slice

    def _get_name(self):
        """Return the display name of the axis.

        An explicit ``name`` wins.  Otherwise the result is ``"Unnamed"``,
        followed by the ordinal of ``index_in_axes_manager`` when the axis
        belongs to an axes manager.
        """
        if self.name is not t.Undefined:
            return self.name
        if self.axes_manager is None:
            return "Unnamed"
        return "Unnamed " + ordinal(self.index_in_axes_manager)

    def __repr__(self):
        """Return ``<NAME axis, size: N>``, plus the index when navigating."""
        pieces = ['<%s axis, size: %i' % (self._get_name(), self.size)]
        if self.navigate is True:
            pieces.append(", index: %i" % self.index)
        pieces.append(">")
        return "".join(pieces)

    def __str__(self):
        """Return the display name followed by the word ``axis``."""
        label = self._get_name()
        return label + " axis"

    def update_index_bounds(self):
        """Set ``high_index`` to the last valid index for the current size."""
        last_index = self.size - 1
        self.high_index = last_index

    def update_axis(self):
        """Rebuild ``axis`` and refresh ``low_value`` / ``high_value``.

        The array comes from ``generate_axis(offset, scale, size)``.  The
        bounds are left untouched when the resulting axis is empty.
        """
        self.axis = generate_axis(self.offset, self.scale, self.size)
        if len(self.axis) == 0:
            return
        lowest, highest = self.axis.min(), self.axis.max()
        self.low_value = lowest
        self.high_value = highest

    def _update_slice(self, value):
        """Set ``slice`` from the ``navigate`` flag.

        ``slice(None)`` is stored only when ``value`` is exactly ``False``;
        any other value stores ``None``.
        """
        self.slice = slice(None) if value is False else None

    def get_axis_dictionary(self):
        """Return the axis parameters as a dict.

        The keys are ``name``, ``scale``, ``offset``, ``size``, ``units``
        and ``navigate``.
        """
        keys = ('name', 'scale', 'offset', 'size', 'units', 'navigate')
        return {key: getattr(self, key) for key in keys}

    def copy(self):
        """Return a new ``DataAxis`` built from ``get_axis_dictionary()``."""
        parameters = self.get_axis_dictionary()
        return DataAxis(**parameters)

    def __copy__(self):
        """Support ``copy.copy`` by delegating to ``copy()``."""
        duplicate = self.copy()
        return duplicate

    def __deepcopy__(self, memo):
        """Support ``copy.deepcopy`` via ``copy()``; ``memo`` is not used."""
        return self.copy()

    def value2index(self, value, rounding=round):
        """Return the closest index to the given value if between the limits.

        Parameters
        ----------
        value : number or numpy array
        rounding : callable
            Function used to turn ``(value - offset) / scale`` into an
            integer position.  For array input the builtin ``round`` and
            ``math.ceil`` / ``math.floor`` are replaced by their numpy
            counterparts.

        Returns
        -------
        index : integer or numpy array
            ``None`` is returned unchanged.

        Raises
        ------
        ValueError if any value is out of the axis limits.

        """
        if value is None:
            return None

        if isinstance(value, np.ndarray):
            # The scalar rounding functions do not work on arrays.
            if rounding is round:
                rounding = np.round
            elif rounding is math.ceil:
                rounding = np.ceil
            elif rounding is math.floor:
                rounding = np.floor
            indices = rounding(
                (value - self.offset) / self.scale).astype(int)
            if np.all(self.size > indices) and np.all(indices >= 0):
                return indices
            raise ValueError("A value is out of the axis limits")

        position = int(rounding((value - self.offset) / self.scale))
        if 0 <= position < self.size:
            return position
        raise ValueError("The value is out of the axis limits")

    def index2value(self, index):
        """Return the axis value(s) stored at ``index``.

        An array of indices gives an array of values with the same shape;
        any other index is used to subscript ``axis`` directly.
        """
        if not isinstance(index, np.ndarray):
            return self.axis[index]
        flat = index.ravel()
        return self.axis[flat].reshape(index.shape)

    def calibrate(self, value_tuple, index_tuple, modify_calibration=True):
        """Derive ``offset`` and ``scale`` from two (value, index) pairs.

        Parameters
        ----------
        value_tuple : pair of numbers
            Axis values at the two reference points.
        index_tuple : pair of numbers
            Indices of the same two reference points.
        modify_calibration : bool
            If exactly ``True`` the results are stored on the axis and
            nothing is returned; otherwise ``(offset, scale)`` is returned
            and the axis is left untouched.
        """
        value_span = value_tuple[1] - value_tuple[0]
        index_span = index_tuple[1] - index_tuple[0]
        scale = value_span / index_span
        offset = value_tuple[0] - scale * index_tuple[0]
        if modify_calibration is not True:
            return offset, scale
        self.offset = offset
        self.scale = scale

    def value_range_to_indices(self, v1, v2):
        """Convert the given range to index range.

        When an end point is ``None`` or falls outside the axis limits, the
        corresponding end of the axis (index ``0`` or ``size - 1``) is used
        instead.

        Parameters
        ----------
        v1, v2 : float
            The end points of the interval in the axis units. v2 must be
            greater than v1.

        Returns
        -------
        i1, i2 : int

        Raises
        ------
        ValueError
            If both end points are given and ``v1 > v2``.

        """
        both_given = v1 is not None and v2 is not None
        if both_given and v1 > v2:
            raise ValueError("v2 must be greater than v1.")

        start_inside = v1 is not None and \
            self.low_value < v1 <= self.high_value
        i1 = self.value2index(v1) if start_inside else 0

        stop_inside = v2 is not None and \
            self.high_value > v2 >= self.low_value
        i2 = self.value2index(v2) if stop_inside else self.size - 1

        return i1, i2

    def update_from(self, axis, attributes=("scale", "offset", "units")):
        """Copy values of the specified fields from another axis.

        Parameters
        ----------
        axis : DataAxis
            The DataAxis instance to use as a source for values.
        attributes : iterable container of strings.
            The names of the attributes to update. If an attribute does not
            exist in either of the two axes, an AttributeError will be
            raised.

        Returns
        -------
        A boolean indicating whether any changes were made.

        """
        # Collect only the fields that differ, so that ``trait_set`` is
        # called at most once and not at all when nothing changed.
        changed = {}
        for field in attributes:
            source_value = getattr(axis, field)
            if getattr(self, field) != source_value:
                changed[field] = source_value
        if changed:
            self.trait_set(**changed)
        return bool(changed)
예제 #26
0
파일: CXDVizNX.py 프로젝트: bfrosik/cdi
class CXDViz(tr.HasTraits):
    """Holds a data array and its point coordinates for export through tvtk.

    ``arr`` is the data set with ``set_array``; ``coords`` is filled by
    ``update_coords``.  ``cropx`` / ``cropy`` / ``cropz`` are the crop sizes
    chosen by ``set_crop``.
    """
    coords = tr.Array()
    arr = tr.Array()

    cropx = tr.Int()
    cropy = tr.Int()
    cropz = tr.Int()

    def __init__(self):
        """Create the tvtk containers that the ``get_*`` methods fill in."""
        # Run the HasTraits initialiser as well, so the trait machinery is
        # set up before any trait is used.
        super(CXDViz, self).__init__()
        self.imd = tvtk.ImageData()
        self.sg = tvtk.StructuredGrid()

    def set_geometry(self, params, shape, space='direct'):
        """Compute the coordinate transform ``T`` and the grid steps.

        Three reciprocal vectors (``Astar``, ``Bstar``, ``Cstar``) are built
        from the experiment parameters, and their dual vectors ``A``, ``B``,
        ``C`` satisfy ``A . Astar == 2*pi`` and so on.

        Parameters
        ----------
        params : object
            Must provide ``lamda``, ``delta``, ``gamma``, ``dpx``, ``dpy``
            and ``dth``.  NOTE(review): the names suggest wavelength, two
            detector angles in radians, pixel sizes and a rotation step;
            confirm the units against the caller.
        shape : sequence of 3 ints
            Array shape; used for the grid steps in ``'direct'`` space.
        space : {'direct', 'recip'}
            ``'direct'`` (the default, and the only mode previously
            reachable) stores ``A``, ``B``, ``C`` as the rows of ``T`` and
            sets the steps to ``1 / shape[i]``.  ``'recip'`` stores the
            starred vectors as the columns of ``T`` with unit steps.

        Raises
        ------
        ValueError
            If ``space`` is not one of the two supported values.
        """
        if space not in ('direct', 'recip'):
            raise ValueError(
                "space must be 'direct' or 'recip', got %r" % (space,))

        lam = params.lamda
        tth = params.delta
        gam = params.gamma

        dQdpx = np.array([-m.cos(tth) * m.cos(gam),
                          0.0,
                          m.sin(tth) * m.cos(gam)])
        dQdpy = np.array([m.sin(tth) * m.sin(gam),
                          -m.cos(gam),
                          m.cos(tth) * m.sin(gam)])
        dQdth = np.array([-m.cos(tth) * m.cos(gam) + 1.0,
                          0.0,
                          m.sin(tth) * m.cos(gam)])

        k = 2 * m.pi / lam
        Astar = k * params.dpx * dQdpx
        Bstar = k * params.dpy * dQdpy
        Cstar = k * params.dth * dQdth

        if space == 'recip':
            self.T = np.column_stack((Astar, Bstar, Cstar))
            self.dx = 1.0
            self.dy = 1.0
            self.dz = 1.0
            return

        # Dual basis of the starred vectors.
        denom = np.dot(Astar, np.cross(Bstar, Cstar))
        A = 2 * m.pi * np.cross(Bstar, Cstar) / denom
        B = 2 * m.pi * np.cross(Cstar, Astar) / denom
        C = 2 * m.pi * np.cross(Astar, Bstar) / denom

        self.T = np.array((A, B, C))
        self.dx = 1.0 / shape[0]
        self.dy = 1.0 / shape[1]
        self.dz = 1.0 / shape[2]

    def update_coords(self):
        """Compute ``coords`` for every point of the cropped array.

        A regular grid with steps ``dx``, ``dy``, ``dz`` is generated for
        the shape of ``arr[cropobj]`` (the first axis runs from high to low)
        and each grid point is multiplied by ``T``.  The result has one row
        of three coordinates per point, with the last array axis varying
        fastest.
        """
        dims = list(self.arr[self.cropobj].shape)

        # Build the grid from integer indices and scale it afterwards.  A
        # float step in ``mgrid`` can yield one element too many or too few
        # through rounding, which would break the reshape below.
        indices = np.mgrid[dims[0] - 1:-1:-1, 0:dims[1], 0:dims[2]]
        steps = np.array([self.dx, self.dy, self.dz])
        r = indices.reshape(3, dims[0] * dims[1] * dims[2]).transpose() * steps

        print(r.shape)
        print(self.T.shape)

        self.coords = np.dot(r, self.T)

    def set_array(self, array, logentry=None):
        """Store ``array`` in ``arr``, padded to at least three dimensions.

        Arrays with fewer than three dimensions get trailing axes of length
        one by assigning a new ``shape`` in place.  ``logentry`` is accepted
        but not used.
        """
        self.arr = array
        missing = 3 - len(self.arr.shape)
        if missing > 0:
            self.arr.shape = tuple(self.arr.shape) + (1,) * missing

    def set_crop(self, cropx, cropy, cropz):
        """Choose a centred crop window and store it in ``cropobj``.

        A requested size is used only if it is positive and smaller than the
        matching dimension of ``arr``; otherwise the full dimension is used.
        The accepted sizes are stored in ``cropx`` / ``cropy`` / ``cropz``.
        Each window spans ``dim // 2 - crop // 2`` to
        ``dim // 2 + crop // 2`` and is widened to one element if that range
        is empty.

        Floor division keeps the slice bounds integral on Python 3, where
        ``/`` would give floats that cannot be used for indexing.
        """
        dims = list(self.arr.shape)
        if len(dims) == 2:
            dims.append(1)

        self.cropx = cropx if 0 < cropx < dims[0] else dims[0]
        self.cropy = cropy if 0 < cropy < dims[1] else dims[1]
        self.cropz = cropz if 0 < cropz < dims[2] else dims[2]

        def centred(dim, crop):
            # Window of roughly ``crop`` elements around the middle of the
            # axis; never empty.
            start = dim // 2 - crop // 2
            stop = dim // 2 + crop // 2
            if start == stop:
                stop += 1
            return slice(start, stop, None)

        self.cropobj = (centred(dims[0], self.cropx),
                        centred(dims[1], self.cropy),
                        centred(dims[2], self.cropz))

    def get_structured_grid(self, **args):
        """Fill ``self.sg`` with the cropped data and return it.

        Parameters
        ----------
        **args
            Only ``mode`` is read.  ``mode="Phase"`` stores the phase angle
            of each element and any other ``mode`` stores the magnitude.
            Without ``mode`` the data is stored as is, except that complex
            data is stored as magnitude scalars named ``"Amp"`` plus an
            extra ``"Phase"`` array.

        Returns
        -------
        The ``self.sg`` object, with points, scalars, dimensions and extent
        set.  The dimensions are given in reversed axis order.
        """
        self.update_coords()
        cropped = self.arr[self.cropobj]
        dims = list(cropped.shape)
        self.sg.points = self.coords
        # ``in`` replaces ``dict.has_key``, which no longer exists in
        # Python 3.
        if "mode" in args:
            if args["mode"] == "Phase":
                flat = cropped.ravel()
                arr = np.arctan2(flat.imag, flat.real)
            else:
                arr = np.abs(cropped.ravel())
        else:
            arr = cropped.ravel()
        if arr.dtype == np.complex128 or arr.dtype == np.complex64:
            self.sg.point_data.scalars = np.abs(arr)
            self.sg.point_data.scalars.name = "Amp"
            ph = tvtk.DoubleArray()
            ph.from_array(np.arctan2(arr.imag, arr.real))
            ph.name = "Phase"
            self.sg.point_data.add_array(ph)
        else:
            self.sg.point_data.scalars = arr
        self.sg.dimensions = (dims[2], dims[1], dims[0])
        self.sg.extent = 0, dims[2] - 1, 0, dims[1] - 1, 0, dims[0] - 1
        return self.sg

    def get_image_data(self, **args):
        """Fill ``self.imd`` with the cropped data and return it.

        The crop window is recomputed from the stored crop sizes first.
        ``**args`` is accepted but not used.

        NOTE(review): ``dimensions`` is set in array order while ``extent``
        uses the reversed order. Confirm which one is intended.
        """
        self.set_crop(self.cropx, self.cropy, self.cropz)
        cropped = self.arr[self.cropobj]
        dims = list(cropped.shape)
        if len(dims) == 2:
            dims.append(1)
        image = self.imd
        image.dimensions = tuple(dims)
        image.extent = 0, dims[2] - 1, 0, dims[1] - 1, 0, dims[0] - 1
        image.point_data.scalars = cropped.ravel()
        return image

    def write_structured_grid(self, filename, **args):
        """Write the structured grid to a binary ``.vtk`` file.

        Parameters
        ----------
        filename : str
            Output path; ``.vtk`` is appended unless already present.
        **args
            Accepted but not forwarded: ``get_structured_grid`` is called
            without arguments.
        """
        # print() with one argument behaves the same on Python 2 and 3.
        print('in WriteStructuredGrid')
        sgwriter = tvtk.StructuredGridWriter()
        sgwriter.file_type = 'binary'
        if filename.endswith(".vtk"):
            sgwriter.file_name = filename
        else:
            sgwriter.file_name = filename + '.vtk'

        sgwriter.set_input_data(self.get_structured_grid())
        print(sgwriter.file_name)
        sgwriter.write()
예제 #27
0
 class spec2(nib.CommandLineInputSpec):
     # Input spec with four positional command-line arguments.
     # NOTE(review): presumably a fixture for ``name_source`` handling, where
     # an output file name is derived from another input -- confirm.
     # ``moo`` names ``doo`` as its name source.
     moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
                    position=2)
     # ``doo`` must refer to an existing file.
     doo = nib.File(exists=True, argstr="%s", position=1)
     goo = traits.Int(argstr="%d", position=4)
     # ``poo`` names ``goo``, an integer trait, as its name source.
     poo = nib.File(name_source=['goo'], hash_files=False, argstr="%s",position=3)
예제 #28
0
class SpikesRemoval(SpanSelectorInSignal1D):
    """Interactive tool to find spikes in a signal and interpolate over them.

    Spikes are located from the derivative of the current spectrum (see
    ``detect_spike``) and replaced with interpolated data, optionally with
    added noise (see ``get_interpolated_spectrum``).
    """
    interpolator_kind = t.Enum(
        'Linear',
        'Spline',
        default='Linear',
        desc="the type of interpolation to use when\n"
             "replacing the signal where a spike has been replaced")
    threshold = t.Float(desc="the derivative magnitude threshold above\n"
                             "which to find spikes")
    click_to_show_instructions = t.Button()
    show_derivative_histogram = t.Button()
    spline_order = t.Range(1, 10, 3,
                           desc="the order of the spline used to\n"
                                "connect the reconstructed data")
    interpolator = None
    default_spike_width = t.Int(
        5,
        desc="the width over which to do the interpolation\n"
             "when removing a spike (this can be "
             "adjusted for each\nspike by clicking "
             "and dragging on the display during\n"
             "spike replacement)")
    index = t.Int(0)
    # The tooltip fragments below end with a space where needed so that the
    # concatenated text does not run words together.
    add_noise = t.Bool(
        True,
        desc="whether to add noise to the interpolated\nportion "
             "of the spectrum. The noise properties defined\n"
             "in the Signal metadata are used if present, "
             "otherwise\nshot noise is used as a default")

    thisOKButton = tu.Action(name="OK",
                             action="OK",
                             tooltip="Close the spikes removal tool")

    thisApplyButton = tu.Action(name="Remove spike",
                                action="apply",
                                tooltip="Remove the current spike by "
                                        "interpolating\n"
                                        "with the specified settings (and find\n"
                                        "the next spike automatically)")
    thisFindButton = tu.Action(name="Find next",
                               action="find",
                               tooltip="Find the next (in terms of navigation\n"
                                       "dimensions) spike in the data.")

    thisPreviousButton = tu.Action(name="Find previous",
                                   action="back",
                                   tooltip="Find the previous (in terms of "
                                           "navigation\n"
                                           "dimensions) spike in the data.")
    view = tu.View(tu.Group(
        tu.Group(
            tu.Item('click_to_show_instructions',
                    show_label=False, ),
            tu.Item('show_derivative_histogram',
                    show_label=False,
                    tooltip="To determine the appropriate threshold,\n"
                            "plot the derivative magnitude histogram, \n"
                            "and look for outliers at high magnitudes \n"
                            "(which represent sudden spikes in the data)"),
            'threshold',
            show_border=True,
        ),
        tu.Group(
            'add_noise',
            'interpolator_kind',
            'default_spike_width',
            tu.Group(
                'spline_order',
                enabled_when='interpolator_kind == \'Spline\''),
            show_border=True,
            label='Advanced settings'),
    ),
        buttons=[thisOKButton,
                 thisPreviousButton,
                 thisFindButton,
                 thisApplyButton, ],
        handler=SpikesRemovalHandler,
        title='Spikes removal tool',
        resizable=False,
    )

    def __init__(self, signal, navigation_mask=None, signal_mask=None):
        """Attach the tool to an already plotted ``signal``.

        Parameters
        ----------
        signal :
            Signal to clean.  Its ``_plot.signal_plot`` must already exist,
            because the first plotted line and the axes are taken from it.
        navigation_mask : array or None
            Navigation positions whose mask entry is true are left out of
            ``coordinates``.
        signal_mask : array or None
            Channels to ignore when looking for spikes (see
            ``detect_spike``).
        """
        super(SpikesRemoval, self).__init__(signal)
        self.interpolated_line = None
        # Navigation positions to visit; the mask is indexed with the
        # reversed coordinate.
        self.coordinates = [coordinate for coordinate in
                            signal.axes_manager._am_indices_generator()
                            if (navigation_mask is None or not
                                navigation_mask[coordinate[::-1]])]
        self.signal = signal
        self.line = signal._plot.signal_plot.ax_lines[0]
        self.ax = signal._plot.signal_plot.ax
        signal._plot.auto_update_plot = False
        if len(self.coordinates) > 1:
            signal.axes_manager.indices = self.coordinates[0]
        # NOTE(review): assigning the ``threshold`` trait presumably fires
        # ``_threshold_changed`` (and so ``update_plot``), which is why
        # ``line`` is set before this point -- confirm.
        self.threshold = 400
        self.index = 0
        self.argmax = None
        self.derivmax = None
        self.kind = "linear"
        self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
        self.signal_mask = signal_mask
        self.navigation_mask = navigation_mask
        md = self.signal.metadata
        from hyperspy.signal import BaseSignal

        # Choose the noise model used by ``get_interpolated_spectrum``:
        # a signal-valued variance -> "heteroscedastic", any other variance
        # -> "white", no variance -> "shot noise".
        if "Signal.Noise_properties" in md:
            if "Signal.Noise_properties.variance" in md:
                self.noise_variance = md.Signal.Noise_properties.variance
                if isinstance(md.Signal.Noise_properties.variance, BaseSignal):
                    self.noise_type = "heteroscedastic"
                else:
                    self.noise_type = "white"
            else:
                self.noise_type = "shot noise"
        else:
            self.noise_type = "shot noise"

    def _threshold_changed(self, old, new):
        """Trait handler: return to the first position and refresh the plot."""
        self.index = 0
        self.update_plot()

    def _click_to_show_instructions_fired(self):
        """Button handler: show the usage instructions in a dialog."""
        # The result of ``information`` is not needed, so it is no longer
        # bound to a name.
        information(None,
                    "\nTo remove spikes from the data:\n\n"

                    "   1. Click \"Show derivative histogram\" to "
                    "determine at what magnitude the spikes are present.\n"
                    "   2. Enter a suitable threshold (lower than the "
                    "lowest magnitude outlier in the histogram) in the "
                    "\"Threshold\" box, which will be the magnitude "
                    "from which to search. \n"
                    "   3. Click \"Find next\" to find the first spike.\n"
                    "   4. If desired, the width and position of the "
                    "boundaries used to replace the spike can be "
                    "adjusted by clicking and dragging on the displayed "
                    "plot.\n"
                    "   5. View the spike (and the replacement data that "
                    "will be added) and click \"Remove spike\" in order "
                    "to alter the data as shown. The tool will "
                    "automatically find the next spike to replace.\n"
                    "   6. Repeat this process for each spike throughout "
                    "the dataset, until the end of the dataset is "
                    "reached.\n"
                    "   7. Click \"OK\" when finished to close the spikes "
                    "removal tool.\n\n"

                    "Note: Various settings can be configured in "
                    "the \"Advanced settings\" section. Hover the "
                    "mouse over each parameter for a description of what "
                    "it does."

                    "\n",
                    title="Instructions")

    def _show_derivative_histogram_fired(self):
        """Button handler: run the signal's spikes diagnosis with both masks."""
        self.signal._spikes_diagnosis(signal_mask=self.signal_mask,
                                      navigation_mask=self.navigation_mask)

    def detect_spike(self):
        """Look for a spike in the current spectrum.

        The first difference of the spectrum is searched for its maximum.
        Channels in ``signal_mask`` are ignored, and so is the interpolation
        range around a spike already found at this position.  On success
        ``argmax`` and ``derivmax`` are updated.

        Returns
        -------
        bool
            ``True`` if the absolute value of the largest derivative reaches
            ``threshold``.
        """
        derivative = np.diff(self.signal())
        if self.signal_mask is not None:
            derivative[self.signal_mask[:-1]] = 0
        if self.argmax is not None:
            # Blank out the previous spike so that it is not found again.
            lo, hi = self.get_interpolation_range()
            self._temp_mask[lo:hi] = True
            derivative[self._temp_mask[:-1]] = 0
        peak = abs(derivative.max())
        if peak >= self.threshold:
            self.argmax = derivative.argmax()
            self.derivmax = peak
            return True
        return False

    def _reset_line(self):
        """Close the interpolation line, if any, and reset the span selector.

        Nothing happens when no interpolation line exists.
        """
        if self.interpolated_line is None:
            return
        self.interpolated_line.close()
        self.interpolated_line = None
        self.reset_span_selector()

    def find(self, back=False):
        """Move to the next (or previous) spike and display it.

        Parameters
        ----------
        back : bool
            ``False`` searches towards higher ``index`` values and ``True``
            towards lower ones.  The checks use identity, so only the two
            bool singletons let the search move.
        """
        self._reset_line()
        ncoordinates = len(self.coordinates)
        spike = self.detect_spike()
        # Step through the navigation positions until a spike is found or
        # the end of ``coordinates`` is reached in the search direction.
        while not spike and (
                (self.index < ncoordinates - 1 and back is False) or
                (self.index > 0 and back is True)):
            if back is False:
                self.index += 1
            else:
                self.index -= 1
            spike = self.detect_spike()

        if spike is False:
            messages.information('End of dataset reached')
            self.index = 0
            self._reset_line()
            return
        else:
            # Show up to 50 channels on either side of the spike.
            minimum = max(0, self.argmax - 50)
            maximum = min(len(self.signal()) - 1, self.argmax + 50)
            thresh_label = DerivativeTextParameters(
                text=r"$\mathsf{\delta}_\mathsf{max}=$",
                color="black")
            self.ax.legend([thresh_label], [repr(int(self.derivmax))],
                           handler_map={DerivativeTextParameters:
                                        DerivativeTextHandler()},
                           loc='best')
            self.ax.set_xlim(
                self.signal.axes_manager.signal_axes[0].index2value(
                    minimum),
                self.signal.axes_manager.signal_axes[0].index2value(
                    maximum))
            self.update_plot()
            self.create_interpolation_line()

    def update_plot(self):
        """Redraw the spectrum for the current position.

        Any interpolation line is closed first.  The span selector is then
        reset, the spectrum line is redrawn and, when there is more than one
        coordinate, the navigator pointer is moved.
        """
        stale_line = self.interpolated_line
        if stale_line is not None:
            stale_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.update_spectrum_line()
        if len(self.coordinates) > 1:
            self.signal._plot.pointer._update_patch_position()

    def update_spectrum_line(self):
        """Force one redraw of the spectrum line.

        ``auto_update`` is enabled only for the duration of the call and is
        switched off again even if ``update()`` raises.
        """
        self.line.auto_update = True
        try:
            self.line.update()
        finally:
            self.line.auto_update = False

    def _index_changed(self, old, new):
        """Trait handler: move the signal to coordinate ``new``.

        The spike search state for the previous position (``argmax`` and the
        temporary mask) is cleared as well.
        """
        self.signal.axes_manager.indices = self.coordinates[new]
        self.argmax = None
        self._temp_mask[:] = False

    def on_disabling_span_selector(self):
        """Close and forget the interpolation line, if there is one."""
        line = self.interpolated_line
        if line is None:
            return
        line.close()
        self.interpolated_line = None

    def _spline_order_changed(self, old, new):
        """Trait handler: use the new spline order as the interpolation kind.

        NOTE(review): ``kind`` is overwritten whatever ``interpolator_kind``
        currently is -- confirm that this is intended when 'Linear' is
        selected.
        """
        self.kind = self.spline_order
        self.span_selector_changed()

    def _add_noise_changed(self, old, new):
        """Trait handler: refresh the interpolation preview."""
        self.span_selector_changed()

    def _interpolator_kind_changed(self, old, new):
        """Trait handler: translate the UI choice into ``self.kind``.

        ``interpolator_kind`` takes the values ``'Linear'`` and ``'Spline'``,
        while ``get_interpolated_spectrum`` tests ``self.kind`` against the
        lowercase ``'linear'``.  The comparison is therefore done without
        regard to case and the lowercase name is stored; any other choice
        selects the current ``spline_order``.
        """
        if new.lower() == 'linear':
            self.kind = 'linear'
        else:
            self.kind = self.spline_order
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        """Trait handler: refresh the preview once both span edges are set."""
        if np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value):
            return
        self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        """Trait handler: refresh the preview once both span edges are set."""
        if np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value):
            return
        self.span_selector_changed()

    def create_interpolation_line(self):
        """Add a blue line that previews the interpolated spectrum.

        The line takes its data from ``get_interpolated_spectrum`` and is
        added to the signal plot with autoscaling switched off.
        """
        self.interpolated_line = drawing.signal1d.Signal1DLine()
        self.interpolated_line.data_function = self.get_interpolated_spectrum
        self.interpolated_line.set_line_properties(
            color='blue',
            type='line')
        self.signal._plot.signal_plot.add_line(self.interpolated_line)
        self.interpolated_line.autoscale = False
        self.interpolated_line.plot()

    def get_interpolation_range(self):
        """Return the ``(left, right)`` channel range to interpolate over.

        Without a span selection (either edge is NaN) the range is
        ``argmax`` plus and minus ``default_spike_width``; otherwise the
        span edges are converted to indices.  The result is clipped to
        ``[0, nchannels - 1]``.
        """
        axis = self.signal.axes_manager.signal_axes[0]
        no_selection = (np.isnan(self.ss_left_value) or
                        np.isnan(self.ss_right_value))
        if no_selection:
            left = self.argmax - self.default_spike_width
            right = self.argmax + self.default_spike_width
        else:
            left = axis.value2index(self.ss_left_value)
            right = axis.value2index(self.ss_right_value)

        # Clip to the axis dimensions
        nchannels = self.signal.axes_manager.signal_shape[0]
        if not left >= 0:
            left = 0
        if not right < nchannels:
            right = nchannels - 1

        return left, right

    def get_interpolated_spectrum(self, axes_manager=None):
        """Return a copy of the current spectrum with the spike replaced.

        The channels in the interpolation range are replaced by an
        interpolation through ``pad`` neighbouring channels on each side, or
        by a constant when the padded range touches an end of the spectrum.
        If ``add_noise`` is true, noise matching ``noise_type`` is applied
        to the replaced channels.

        Parameters
        ----------
        axes_manager : optional
            Accepted but not used; the signal's own axes manager is used.

        Returns
        -------
        numpy array
            The modified copy; the signal data itself is not changed here.
        """
        data = self.signal().copy()
        axis = self.signal.axes_manager.signal_axes[0]
        left, right = self.get_interpolation_range()
        # Number of neighbouring channels used on each side as support.
        if self.kind == 'linear':
            pad = 1
        else:
            pad = 10
        ileft = left - pad
        iright = right + pad
        ileft = np.clip(ileft, 0, len(data))
        iright = np.clip(iright, 0, len(data))
        left = int(np.clip(left, 0, len(data)))
        right = int(np.clip(right, 0, len(data)))
        # Support points: the padding channels on both sides of the gap.
        x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
        y = np.hstack((data[ileft:left], data[right:iright]))
        if ileft == 0:
            # Extrapolate to the left
            # NOTE(review): ``right + 1`` is not checked against the length
            # of ``data`` -- confirm it cannot go out of range.
            data[left:right] = data[right + 1]

        elif iright == (len(data) - 1):
            # Extrapolate to the right
            # NOTE(review): ``iright`` is clipped to ``len(data)``, not
            # ``len(data) - 1`` -- confirm this edge test is as intended.
            data[left:right] = data[left - 1]

        else:
            # Interpolate
            intp = sp.interpolate.interp1d(x, y, kind=self.kind)
            data[left:right] = intp(axis.axis[left:right])

        # Add noise
        if self.add_noise is True:
            if self.noise_type == "white":
                # One variance for all channels.
                data[left:right] += np.random.normal(
                    scale=np.sqrt(self.noise_variance),
                    size=right - left)
            elif self.noise_type == "heteroscedastic":
                # One variance per channel, read from the variance signal.
                noise_variance = self.noise_variance(
                    axes_manager=self.signal.axes_manager)[left:right]
                noise = [np.random.normal(scale=np.sqrt(item))
                         for item in noise_variance]
                data[left:right] += noise
            else:
                # Shot noise: draw Poisson values around the interpolated
                # data, with negative values clipped to zero first.
                data[left:right] = np.random.poisson(
                    np.clip(data[left:right], 0, np.inf))

        return data

    def span_selector_changed(self):
        """Redraw the interpolation preview, if one is being shown."""
        preview = self.interpolated_line
        if preview is not None:
            preview.update()

    def apply(self):
        """Write the interpolated spectrum into the signal and find the next spike.

        An interpolation line must exist when this is called, because it is
        closed without a ``None`` check.
        """
        # Overwrite the current spectrum in place with the preview data.
        self.signal()[:] = self.get_interpolated_spectrum()
        self.signal.events.data_changed.trigger(obj=self.signal)
        self.update_spectrum_line()
        self.interpolated_line.close()
        self.interpolated_line = None
        self.reset_span_selector()
        self.find()
예제 #29
0
class FloatRangeEditor(EditorFactory):
    """Editor factory rendering a float value as a discrete range slider.

    The bounds and the number of slider positions are either set directly
    (``low``, ``high``, ``n_steps``) or, when the corresponding ``*_name``
    trait is non-empty, read from the attribute of that name on
    ``self.model`` every time :meth:`render` is called.
    """
    low = tr.Float
    high = tr.Float
    # Optional names of model attributes that supply the bounds.
    low_name = tr.Str
    high_name = tr.Str
    n_steps = tr.Int(20)
    # Optional name of the model attribute that supplies the step count.
    n_steps_name = tr.Str
    continuous_update = tr.Bool(False)
    readout = tr.Bool(True)
    readout_format = tr.Str

    def render(self):
        """Build and return the ``ipw.SelectionSlider`` for this editor.

        Side effects: ``low``/``high``/``n_steps`` are refreshed from the
        model when the matching ``*_name`` trait is set, ``readout_format``
        receives a default when empty, and ``self.value`` is snapped to the
        nearest slider option.
        """
        if self.low_name:
            self.low = getattr(self.model, str(self.low_name))
        if self.high_name:
            self.high = getattr(self.model, str(self.high_name))
        if self.n_steps_name:
            self.n_steps = getattr(self.model, str(self.n_steps_name))

        round_value = self._get_round_value(self.low, self.high, self.n_steps)
        if not self.readout_format:
            self.readout_format = '.' + str(round_value) + 'f'

        # There's a bug in FloatSlider for very small step, see https://github.com/jupyter-widgets/ipywidgets/issues/259
        # it will be fixed in ipywidgets v8.0.0, but until then, the following fix will be used
        # with this implementation, entering the number manually in the readout will not work
        values = np.round(
            np.linspace(self.low, self.high, int(self.n_steps)), round_value)

        # This is for SelectionSlider because 'value' must match exactly one of values array
        self.value = self._find_nearest(values, self.value)

        # NOTE(review): `style` is not defined in this class; it is presumably
        # a module-level name - confirm.
        return ipw.SelectionSlider(
            options=values,
            value=self.value,
            tooltip=self.tooltip,
            continuous_update=self.continuous_update,
            description=self.label,
            disabled=self.disabled,
            readout=self.readout,
            style=style
        )

        # FloatSlider variant to switch back to once the upstream bug is
        # fixed (the slider step is computed inline):
        # return ipw.FloatSlider(
        #     value=self.value,
        #     min=self.low,
        #     max=self.high,
        #     step=(self.high - self.low) / self.n_steps,
        #     tooltip=self.tooltip,
        #     continuous_update=self.continuous_update,
        #     description=self.label,
        #     disabled=self.disabled,
        #     readout=self.readout,
        #     readout_format=self.readout_format
        # )

    def _find_nearest(self, array, value):
        """Return the element of ``array`` closest to ``value``."""
        array = np.asarray(array)
        return array[np.abs(array - value).argmin()]

    def _get_round_value(self, low, high, n_steps):
        """Return the number of decimals used to round the slider options.

        Two decimals are used when both bounds have a non-negative order of
        magnitude; otherwise the decimals grow with the smaller bound's
        magnitude plus the order of magnitude of ``n_steps``.
        """
        min_magnitude = min(self._get_order_of_magnitude(low),
                            self._get_order_of_magnitude(high))
        if min_magnitude >= 0:
            return 2
        return abs(min_magnitude) + self._get_order_of_magnitude(n_steps)

    def _get_order_of_magnitude(self, num):
        """Return the base-10 exponent of ``num`` in ``'.1e'`` notation."""
        exponent = '{:.1e}'.format(num).split('e')[1]
        return int(exponent)
class HCFF(tr.HasStrictTraits):
    '''High-Cycle Fatigue Filter

    Traits-based tool that parses the columns of a csv file into ``.npy``
    files, exports filtered and per-cycle max/min arrays derived from the
    force column, and plots the stored arrays into a matplotlib figure.
    '''

    #=========================================================================
    # Traits definitions
    #=========================================================================
    # csv dialect, passed to pandas.read_csv
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    # Sampling rate used to synthesise the time column when it is not taken
    # from the first csv column
    records_per_second = tr.Float(100)
    take_time_from_first_column = tr.Bool
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    # Number of leading csv rows skipped when parsing the data columns
    skip_first_rows = tr.Int(3, auto_set=False, enter_set=True)
    # Sanitised column headers; also the value set of x_axis / y_axis
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    # Factors applied to the loaded arrays before plotting
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    # Output folder and base name of the generated npy files
    npy_folder_path = tr.Str
    file_name = tr.Str
    apply_filters = tr.Bool
    normalize_cycles = tr.Bool
    smooth = tr.Bool
    plot_every_nth_point = tr.Range(low=1, high=1000000, mode='spinner')
    # Header of the force column
    force_name = tr.Str('Kraft')
    old_peak_force_before_cycles = tr.Float
    peak_force_before_cycles = tr.Float
    # Savitzky-Golay filter parameters
    window_length = tr.Int(31)
    polynomial_order = tr.Int(2)
    # Enables smoothing of the ascending branch of the displacements
    activate = tr.Bool(False)
    # Maximum number of subplots and the descriptions of those added so far
    plots_num = tr.Enum(1, 2, 3, 4, 6, 9)
    plot_list = tr.List()
    add_plot = tr.Button
    add_creep_plot = tr.Button(desc='Creep plot of X axis array')
    clear_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_and_creep_npy = tr.Button
    add_columns_average = tr.Button
    # Thresholds used when cutting false cycles for the creep export
    force_max = tr.Float(100)
    force_min = tr.Float(40)
    min_cycle_force_range = tr.Float(50)
    cutting_method = tr.Enum('Define min cycle range(force difference)',
                             'Define Max, Min')
    # List of lists of column names; each inner list is averaged together
    columns_to_be_averaged = tr.List

    figure = tr.Instance(Figure)

    def _figure_default(self):
        """Create the default figure: white background, tight layout."""
        default_figure = Figure(facecolor='white')
        default_figure.set_tight_layout(True)
        return default_figure

    #=========================================================================
    # File management
    #=========================================================================

    def _open_file_csv_fired(self):
        """Handle the user clicking the 'Input file' button.

        Opens a file dialog. If the user confirms a file, the previous state
        is reset, the first csv row is read and sanitised into
        ``columns_headers_list``, and the ``NPY`` output folder next to the
        csv file is created if missing. Cancelling the dialog leaves the
        current state untouched.
        """
        extns = [
            '*.csv',
        ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open',
                            wildcard=wildcard,
                            default_path=self.file_csv)

        # Only touch the current state once a file was actually chosen;
        # resetting before the dialog would wipe the loaded columns on cancel.
        if dialog.open() != OK:
            return

        self.reset()
        self.file_csv = dialog.path

        # Filling x_axis and y_axis with values
        header_row = np.array(
            pd.read_csv(self.file_csv,
                        delimiter=self.delimiter,
                        decimal=self.decimal,
                        nrows=1,
                        header=None))[0]
        # str() keeps numeric or empty header cells from breaking the
        # character filter, which iterates over its argument.
        self.columns_headers_list = [
            self.get_valid_file_name(str(header)) for header in header_row
        ]

        # Saving file name and path and creating NPY folder
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        if not os.path.exists(self.npy_folder_path):
            os.makedirs(self.npy_folder_path)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    def _parse_csv_to_npy_fired(self):
        """Store every csv column, and every requested average, as npy file.

        Each raw column is read separately from ``file_csv`` and saved as
        ``<file_name>_<header>.npy`` in ``npy_folder_path``. Unless
        ``take_time_from_first_column`` is set, the first column is replaced
        by a generated time array with ``records_per_second`` samples per
        second. Afterwards one averaged array is saved for every entry of
        ``columns_to_be_averaged``.
        """
        print('Parsing csv into npy files...')

        def npy_path(column_name):
            # <npy folder>/<file name>_<column>.npy
            return os.path.join(self.npy_folder_path,
                                self.file_name + '_' + column_name + '.npy')

        # The averaged pseudo-columns sit at the end of the header list and
        # have no counterpart in the csv file.
        raw_columns_count = (len(self.columns_headers_list) -
                             len(self.columns_to_be_averaged))
        for i in range(raw_columns_count):
            column_array = np.array(
                pd.read_csv(self.file_csv,
                            delimiter=self.delimiter,
                            decimal=self.decimal,
                            skiprows=self.skip_first_rows,
                            usecols=[i]))
            # TODO! Create time array supposing it's column is the
            # first one in the file and that we have 100 reads in 1 second
            if i == 0 and not self.take_time_from_first_column:
                # Integer arange divided by the rate always yields exactly
                # one sample per row; arange with a float step may return
                # one element too many because of rounding.
                column_array = (np.arange(len(column_array)) /
                                self.records_per_second)

            np.save(npy_path(self.columns_headers_list[i]), column_array)

        # Exporting npy arrays of averaged columns
        for columns_names in self.columns_to_be_averaged:
            total = np.zeros((1))
            for column_name in columns_names:
                total = total + np.load(npy_path(column_name)).flatten()
            avg = total / len(columns_names)

            avg_file_suffex = self.get_suffex_for_columns_to_be_averaged(
                columns_names)
            np.save(npy_path(avg_file_suffex), avg)

        print('Finsihed parsing csv into npy files.')

    def get_suffex_for_columns_to_be_averaged(self, columns_names):
        """Return ``'avg_'`` followed by the column names joined with ``'_'``."""
        return 'avg_' + '_'.join(columns_names)

    def get_valid_file_name(self, original_file_name):
        """Drop every character that is not allowed in the npy file names.

        Kept characters are ASCII letters, digits, space and ``-_.()``.
        """
        allowed = "-_.() " + string.ascii_letters + string.digits
        return ''.join(filter(lambda char: char in allowed,
                              original_file_name))

    def _clear_plot_fired(self):
        """Handle the 'clear_plot' button: empty the figure and plot list."""
        self.figure.clear()
        self.plot_list = []
        # Fire the data_changed event after the figure was cleared.
        self.data_changed = True

    def _add_columns_average_fired(self):
        """Let the user pick columns to average and register the selection.

        A modal ``ColumnsAverage`` dialog lists all current headers. When at
        least one column was selected, the selection is appended to
        ``columns_to_be_averaged`` and its generated name is appended to
        ``columns_headers_list``.
        """
        chooser = ColumnsAverage()
        for header in self.columns_headers_list:
            chooser.columns.append(Column(column_name=header))

        # kind='modal' pauses the implementation until the window is closed
        chooser.configure_traits(kind='modal')

        picked_names = [
            column.column_name for column in chooser.columns
            if column.selected
        ]
        if not picked_names:
            return

        self.columns_to_be_averaged.append(picked_names)
        self.columns_headers_list.append(
            self.get_suffex_for_columns_to_be_averaged(picked_names))

    def _generate_filtered_and_creep_npy_fired(self):
        """Export the filtered and the per-cycle max/min npy arrays.

        Steps:

        1. The force array is split at the first sample whose magnitude
           exceeds ``peak_force_before_cycles``; the part after the split is
           reduced to its local maxima/minima and saved as ``*_filtered``.
        2. Every other column except the first is reduced at the same
           indices (its leading part optionally smoothed with a
           Savitzky-Golay filter when ``activate`` is set) and saved as
           ``*_filtered``.
        3. The extrema are cut according to ``cutting_method`` and, for
           every column except the first, the values at the remaining
           maxima/minima are saved as ``*_max`` / ``*_min``.

        Returns early when the force npy file does not exist.
        """

        def npy_path(column_name, suffix=''):
            # <npy folder>/<file name>_<column><suffix>.npy
            return os.path.join(
                self.npy_folder_path,
                self.file_name + '_' + column_name + suffix + '.npy')

        if not self.npy_files_exist(npy_path(self.force_name)):
            return

        # 1- Export filtered force
        force = np.load(npy_path(self.force_name)).flatten()
        cycles_start = np.where(
            abs(force) > abs(self.peak_force_before_cycles))[0][0]
        force_rest = force[cycles_start:]

        force_max_indices, force_min_indices = \
            self.get_array_max_and_min_indices(force_rest)

        extrema_indices = np.concatenate(
            (force_min_indices, force_max_indices))
        extrema_indices.sort()

        np.save(
            npy_path(self.force_name, '_filtered'),
            np.concatenate((force[:cycles_start],
                            force_rest[extrema_indices])))

        # 2- Export filtered displacements
        # TODO I skipped time presuming it's the first column
        for column_name in self.columns_headers_list[1:]:
            if column_name == str(self.force_name):
                continue

            disp = np.load(npy_path(column_name)).flatten()
            disp_ascending = disp[:cycles_start]
            disp_rest = disp[cycles_start:]

            if self.activate:
                disp_ascending = savgol_filter(
                    disp_ascending,
                    window_length=self.window_length,
                    polyorder=self.polynomial_order)

            np.save(
                npy_path(column_name, '_filtered'),
                np.concatenate((disp_ascending, disp_rest[extrema_indices])))

        # 3- Export creep for displacements
        # Cutting unwanted max min values to get correct full cycles and remove
        # false min/max values caused by noise
        if self.cutting_method == "Define Max, Min":
            kept_max_indices, kept_min_indices = \
                self.cut_indices_of_min_max_range(force_rest,
                                                  force_max_indices,
                                                  force_min_indices,
                                                  self.force_max,
                                                  self.force_min)
        elif self.cutting_method == "Define min cycle range(force difference)":
            kept_max_indices, kept_min_indices = \
                self.cut_indices_of_defined_range(force_rest,
                                                  force_max_indices,
                                                  force_min_indices,
                                                  self.min_cycle_force_range)

        print("Cycles number= ", len(force_min_indices))
        print("Cycles number after cutting fake cycles because of noise= ",
              len(kept_min_indices))

        # TODO I skipped time with presuming it's the first column
        for column_name in self.columns_headers_list[1:]:
            values_rest = np.load(
                npy_path(column_name)).flatten()[cycles_start:]
            maxima = values_rest[kept_max_indices]
            minima = values_rest[kept_min_indices]
            np.save(npy_path(column_name, '_max'), maxima)
            np.save(npy_path(column_name, '_min'), minima)

        print('Filtered and creep npy files are generated.')

    def cut_indices_of_min_max_range(self, array, max_indices, min_indices,
                                     range_upper_value, range_lower_value):
        """Keep only the extrema that lie outside the given magnitude band.

        A maximum index is kept when the magnitude of its value exceeds
        ``abs(range_upper_value)``; a minimum index is kept when the
        magnitude of its value is below ``abs(range_lower_value)``.

        Returns the tuple ``(kept_max_indices, kept_min_indices)`` as lists.
        """
        upper_limit = abs(range_upper_value)
        lower_limit = abs(range_lower_value)
        kept_max = [idx for idx in max_indices
                    if abs(array[idx]) > upper_limit]
        kept_min = [idx for idx in min_indices
                    if abs(array[idx]) < lower_limit]
        return kept_max, kept_min

    def cut_indices_of_defined_range(self, array, max_indices, min_indices,
                                     range_):
        """Keep the max/min index pairs whose value difference exceeds ``range_``.

        The indices are paired positionally; surplus entries of the longer
        sequence are ignored.

        Returns the tuple ``(kept_max_indices, kept_min_indices)`` as lists.
        """
        kept_pairs = [
            (upper, lower) for upper, lower in zip(max_indices, min_indices)
            if abs(array[upper] - array[lower]) > range_
        ]
        kept_max = [upper for upper, _ in kept_pairs]
        kept_min = [lower for _, lower in kept_pairs]
        return kept_max, kept_min

    def get_array_max_and_min_indices(self, input_array):
        """Return the indices of the local maxima and minima of the array.

        "Maximum" refers to the magnitude: when the array is predominantly
        negative, the comparison operators are swapped. Directly adjacent
        indices are thinned out by the ``remove_subsequent_*`` helpers, and
        if the two results differ in size the longer one loses its last
        entry.

        Returns the tuple ``(max_indices, min_indices)``.
        """
        # Checking dominant sign
        non_negative_count = np.sum(np.array(input_array) >= 0)
        mostly_positive = (non_negative_count >
                           input_array.size - non_negative_count)

        # Getting max and min indices
        if mostly_positive:
            peak_cmp, valley_cmp = np.greater_equal, np.less_equal
        else:
            peak_cmp, valley_cmp = np.less_equal, np.greater_equal
        raw_maxima = argrelextrema(input_array, peak_cmp)[0]
        raw_minima = argrelextrema(input_array, valley_cmp)[0]

        # Remove subsequent max/min indices (np.greater_equal will give 1,2 for
        # [4, 8, 8, 1])
        maxima = self.remove_subsequent_max_values(raw_maxima)
        minima = self.remove_subsequent_min_values(raw_minima)

        # If size is not equal remove the last element from the big one
        surplus = maxima.size - minima.size
        if surplus > 0:
            maxima = maxima[:-1]
        elif surplus < 0:
            minima = minima[:-1]

        return maxima, minima

    def remove_subsequent_max_values(self, force_max_indices):
        """Drop every index that is directly followed by its successor.

        For a run of consecutive indices only the last one is kept.
        Returns a new array; the input is not modified.
        """
        positions_to_drop = [
            pos for pos in range(force_max_indices.size - 1)
            if force_max_indices[pos + 1] - force_max_indices[pos] == 1
        ]
        return np.delete(force_max_indices, positions_to_drop)

    def remove_subsequent_min_values(self, force_min_indices):
        """Drop every index that is directly followed by its successor.

        For a run of consecutive indices only the last one is kept.
        Returns a new array; the input is not modified.
        """
        positions_to_drop = [
            pos for pos in range(force_min_indices.size - 1)
            if force_min_indices[pos + 1] - force_min_indices[pos] == 1
        ]
        return np.delete(force_min_indices, positions_to_drop)

    def _activate_changed(self):
        """Stash or restore the peak force when ``activate`` is toggled.

        Switching off remembers the current ``peak_force_before_cycles`` and
        sets it to 0; switching on restores the remembered value.
        """
        if self.activate:
            self.peak_force_before_cycles = self.old_peak_force_before_cycles
            return
        self.old_peak_force_before_cycles = self.peak_force_before_cycles
        self.peak_force_before_cycles = 0

    def _window_length_changed(self, new):
        """Warn about a window length that is unusable for the filter.

        One dialog is shown when the new value does not exceed
        ``polynomial_order`` and another one when it is even or not
        positive. The value itself is left unchanged.
        """
        messages = []
        if new <= self.polynomial_order:
            messages.append(
                'Window length must be bigger than polynomial order.')
        if new <= 0 or new % 2 == 0:
            messages.append('Window length must be odd positive integer.')

        for message in messages:
            MessageDialog(title='Attention!', message=message).open()

    def _polynomial_order_changed(self, new):
        """Warn when the polynomial order is not below ``window_length``."""
        if new < self.window_length:
            return
        warning = MessageDialog(
            title='Attention!',
            message='Polynomial order must be less than window length.')
        warning.open()

    #=========================================================================
    # Plotting
    #=========================================================================

    plot_list_current_elements_num = tr.Int(0)

    def npy_files_exist(self, path):
        """Return True if ``path`` exists; otherwise warn and return False.

        The warning dialog asks the user to parse the csv file first.
        """
        if os.path.exists(path):
            return True
        # The message has no placeholders, so the former ``.format()`` call
        # on it was a no-op and has been dropped.
        dialog = MessageDialog(
            title='Attention!',
            message='Please parse csv file to generate npy files first.')
        dialog.open()
        return False

    def filtered_and_creep_npy_files_exist(self, path):
        """Return True if ``path`` exists; otherwise warn and return False.

        The warning dialog asks the user to generate the filtered and creep
        npy files first.
        """
        if os.path.exists(path):
            return True
        # The message has no placeholders, so the former ``.format()`` call
        # on it was a no-op and has been dropped.
        dialog = MessageDialog(
            title='Attention!',
            message='Please generate filtered and creep npy files first.')
        dialog.open()
        return False

    def max_plots_number_is_reached(self):
        """Tell whether ``plot_list`` already holds ``plots_num`` entries.

        Shows a warning dialog and returns True when the limit is reached;
        returns False otherwise.
        """
        if len(self.plot_list) < self.plots_num:
            return False
        limit_dialog = MessageDialog(title='Attention!',
                                     message='Max plots number is {}'.format(
                                         self.plots_num))
        limit_dialog.open()
        return True

    def _plot_list_changed(self):
        """Raise ``plot_list_current_elements_num`` when the list has grown."""
        current_length = len(self.plot_list)
        if current_length > self.plot_list_current_elements_num:
            self.plot_list_current_elements_num = current_length

    data_changed = tr.Event

    def _add_plot_fired(self):
        """Plot the selected y-axis array over the selected x-axis array.

        With ``apply_filters`` set, the ``*_filtered`` npy files are used,
        otherwise the raw ones. Nothing is plotted when the maximum number
        of plots is reached or when one of the two npy files is missing
        (a dialog informs the user in both cases).
        """
        if self.max_plots_number_is_reached():
            return

        # Both variants differ only in the file suffix and in the dialog
        # shown when the files are missing.
        if self.apply_filters:
            suffix = '_filtered'
            files_exist = self.filtered_and_creep_npy_files_exist
        else:
            suffix = ''
            files_exist = self.npy_files_exist

        x_axis_name = self.x_axis + suffix
        y_axis_name = self.y_axis + suffix
        x_path = os.path.join(self.npy_folder_path,
                              self.file_name + '_' + x_axis_name + '.npy')
        y_path = os.path.join(self.npy_folder_path,
                              self.file_name + '_' + y_axis_name + '.npy')

        # Check the y file as well; otherwise np.load fails on a missing
        # file instead of showing the dialog.
        if not files_exist(x_path) or not files_exist(y_path):
            return

        print('Loading npy files...')
        x_axis_array = self.x_axis_multiplier * np.load(x_path)
        y_axis_array = self.y_axis_multiplier * np.load(y_path)

        print('Adding Plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        ax = self.apply_new_subplot()

        ax.set_xlabel(x_axis_name)
        ax.set_ylabel(y_axis_name)
        # No 'k' format string: it would define the color a second time and
        # be overridden by the ``color`` keyword anyway.
        ax.plot(x_axis_array,
                y_axis_array,
                linewidth=1.2,
                color=np.random.rand(3, ),
                label=self.file_name + ', ' + x_axis_name)

        ax.legend()

        self.plot_list.append('{}, {}'.format(x_axis_name, y_axis_name))
        self.data_changed = True
        print('Finished adding plot!')

    def apply_new_subplot(self):
        """Add the next subplot to the figure and return its axes.

        The grid is chosen from ``plots_num`` (1x1, 1x2, 1x3, 2x2, 2x3 or
        3x3) and the position is the number of plots already listed plus
        one. Returns ``None`` for any other ``plots_num``.
        """
        # rows/columns digits of the three-digit subplot code per plots_num
        grid_prefix = {2: '12', 3: '13', 4: '22', 6: '23', 9: '33'}

        fig = self.figure
        if self.plots_num == 1:
            return fig.add_subplot(1, 1, 1)

        prefix = grid_prefix.get(self.plots_num)
        if prefix is None:
            return None
        return fig.add_subplot(int(prefix + str(len(self.plot_list) + 1)))

    def _add_creep_plot_fired(self):
        """Plot the per-cycle maxima and minima of the x-axis column.

        Loads the ``*_max`` / ``*_min`` npy files of the selected x-axis
        column, optionally thins them out (``plot_every_nth_point``) and
        smooths them (``smooth``), and plots both curves over the cycle
        number, or over the range 0..1 when ``normalize_cycles`` is set.
        Nothing is plotted when the ``*_max`` file is missing or the maximum
        number of plots is reached.
        """

        def npy_path(suffix):
            # <npy folder>/<file name>_<x axis><suffix>.npy
            return os.path.join(
                self.npy_folder_path,
                self.file_name + '_' + self.x_axis + suffix + '.npy')

        if not self.filtered_and_creep_npy_files_exist(npy_path('_max')):
            return

        if self.max_plots_number_is_reached():
            return

        disp_max = self.x_axis_multiplier * np.load(npy_path('_max'))
        disp_min = self.x_axis_multiplier * np.load(npy_path('_min'))
        # Taken before thinning so the x axis still spans all cycles.
        complete_cycles_number = disp_max.size

        print('Adding creep-fatigue plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        ax = self.apply_new_subplot()

        ax.set_xlabel('Cycles number')
        ax.set_ylabel(self.x_axis)

        if self.plot_every_nth_point > 1:
            disp_max = disp_max[0::self.plot_every_nth_point]
            disp_min = disp_min[0::self.plot_every_nth_point]

        if self.smooth:

            def smooth_tail(values):
                # Keeping the first item of the array and filtering the rest
                return np.concatenate(
                    (np.array([values[0]]),
                     savgol_filter(values[1:],
                                   window_length=self.window_length,
                                   polyorder=self.polynomial_order)))

            disp_max = smooth_tail(disp_max)
            disp_min = smooth_tail(disp_min)

        x_end = 1. if self.normalize_cycles else complete_cycles_number
        for values, color, tag in ((disp_max, 'red', 'Max'),
                                   (disp_min, 'green', 'Min')):
            # Size each x array from its own curve: the max and min arrays
            # may differ in length depending on the cutting method.
            # No 'k' format string, since ``color`` is given explicitly.
            ax.plot(np.linspace(0, x_end, values.size),
                    values,
                    linewidth=1.2,
                    color=color,
                    label=tag + ', ' + self.file_name + ', ' + self.x_axis)

        ax.legend()

        self.plot_list.append('Creep-fatigue: {}, {}'.format(
            self.x_axis, self.y_axis))
        self.data_changed = True

        print('Finished adding creep-fatigue plot!')

    def reset(self):
        """Restore the file- and plot-related traits to their start values."""
        # Assigned one after another, in this order; the lists are created
        # fresh on every call.
        start_values = (
            ('delimiter', ';'),
            ('skip_first_rows', 3),
            ('columns_headers_list', []),
            ('npy_folder_path', ''),
            ('file_name', ''),
            ('apply_filters', False),
            ('force_name', 'Kraft'),
            ('plot_list', []),
            ('columns_to_be_averaged', []),
        )
        for trait_name, value in start_values:
            setattr(self, trait_name, value)

    #=========================================================================
    # Configuration of the view
    #=========================================================================

    # Top-level layout is a horizontal split with three parts:
    #   1. a vertical split holding the file input, the column-averaging
    #      button, the csv processing options and the plotting controls,
    #   2. the 'Filters' group with the smoothing and cycle-cutting
    #      parameters and the export button,
    #   3. the matplotlib figure, rendered with MPLFigureEditor.
    traits_view = ui.View(ui.HSplit(
        ui.VSplit(
            ui.HGroup(ui.UItem('open_file_csv'),
                      ui.UItem('file_csv', style='readonly', width=0.1),
                      label='Input data'),
            ui.Item('add_columns_average', show_label=False),
            ui.VGroup(
                ui.VGroup(ui.Item(
                    'records_per_second',
                    enabled_when='take_time_from_first_column == False'),
                          ui.Item('take_time_from_first_column'),
                          label='Time calculation',
                          show_border=True),
                ui.VGroup(ui.Item('skip_first_rows'),
                          ui.Item('decimal'),
                          ui.Item('delimiter'),
                          ui.Item('parse_csv_to_npy', show_label=False),
                          label='Processing csv file',
                          show_border=True),
                ui.VGroup(ui.HGroup(ui.Item('plots_num'),
                                    ui.Item('clear_plot')),
                          ui.HGroup(ui.Item('x_axis'),
                                    ui.Item('x_axis_multiplier')),
                          ui.HGroup(ui.Item('y_axis'),
                                    ui.Item('y_axis_multiplier')),
                          ui.VGroup(ui.HGroup(
                              ui.Item('add_plot', show_label=False),
                              ui.Item('apply_filters')),
                                    show_border=True,
                                    label='Plotting X axis with Y axis'),
                          ui.VGroup(ui.HGroup(
                              ui.Item('add_creep_plot', show_label=False),
                              ui.VGroup(ui.Item('normalize_cycles'),
                                        ui.Item('smooth'),
                                        ui.Item('plot_every_nth_point'))),
                                    show_border=True,
                                    label='Plotting Creep-fatigue of x-axis'),
                          ui.Item('plot_list'),
                          show_border=True,
                          label='Plotting'))),
        ui.VGroup(
            ui.Item('force_name'),
            ui.VGroup(ui.VGroup(
                ui.Item('window_length'),
                ui.Item('polynomial_order'),
                enabled_when='activate == True or smooth == True'),
                      show_border=True,
                      label='Smoothing parameters (Savitzky-Golay filter):'),
            ui.VGroup(ui.VGroup(
                ui.Item('activate'),
                ui.Item('peak_force_before_cycles',
                        enabled_when='activate == True')),
                      show_border=True,
                      label='Smooth ascending branch for all displacements:'),
            ui.VGroup(
                ui.Item('cutting_method'),
                ui.VGroup(ui.Item('force_max'),
                          ui.Item('force_min'),
                          label='Max, Min:',
                          show_border=True,
                          enabled_when='cutting_method == "Define Max, Min"'),
                ui.VGroup(
                    ui.Item('min_cycle_force_range'),
                    label='Min cycle force range:',
                    show_border=True,
                    enabled_when=
                    'cutting_method == "Define min cycle range(force difference)"'
                ),
                show_border=True,
                label='Cut fake cycles for creep:'),
            ui.Item('generate_filtered_and_creep_npy', show_label=False),
            show_border=True,
            label='Filters'),
        ui.UItem('figure',
                 editor=MPLFigureEditor(),
                 resizable=True,
                 springy=True,
                 width=0.8,
                 label='2d plots')),
                          title='HCFF Filter',
                          resizable=True,
                          width=0.85,
                          height=0.7)