Exemple #1
0
class HCFF2(tr.HasStrictTraits):
    """High-Cycle Fatigue Filter.

    Presents the ``hcf`` object through a tree editor on the left and a
    matplotlib figure on the right of a horizontal split.
    """

    # Root object displayed by the tree editor
    hcf = tr.Instance(HCFFRoot)

    # Figure displayed by the matplotlib figure editor
    figure = tr.Instance(Figure)

    def _hcf_default(self):
        """Return an ``HCFFRoot`` wired to a new ``FileImportManager``."""
        manager = FileImportManager()
        return HCFFRoot(import_manager=manager)

    def _figure_default(self):
        """Return a white-faced figure with tight layout enabled."""
        fig = Figure(facecolor='white')
        fig.set_tight_layout(True)
        return fig

    # Main window: tree on the left, plots on the right
    traits_view = ui.View(
        ui.HSplit(
            ui.Item(
                name='hcf',
                width=0.3,
                show_label=False,
                editor=tree_editor,
            ),
            ui.UItem(
                'figure',
                label='2d plots',
                springy=True,
                resizable=True,
                editor=MPLFigureEditor(),
            ),
        ),
        height=0.6,
        width=0.6,
        resizable=True,
        title='HCF Filter',
    )
class FileFrame(ta.HasTraits):
    """
    Frame for file selecting.

    Lets the user pick a PDF file or a directory and add either the single
    file or every PDF returned by ``GetAllPDF`` for the directory to the
    module-level ``files_selected.file_list``.
    """
    # Initial selections shown when the frame opens (hard-coded paths).
    def_file = '/home/jackdra/LQCD/Scripts/EDM_paper/graphs/FF/FullFFFit/Neutron_ContFit_a.pdf'
    def_folder = '/home/jackdra/LQCD/Scripts/EDM_paper/graphs/FF/FullFFFit/'
    file_directory = ta.Directory(def_folder)
    file_name = ta.File(def_file, filter=['*.pdf'])

    Add_File = ta.Button()
    Add_Folder = ta.Button()
    # Undo_Add = ta.Button()

    # Both traits appear twice: once with the 'custom' style and once in the
    # right-hand group next to the two buttons.
    view = tua.View(
        tua.HSplit(
            tua.Item('file_directory', style='custom', springy=True),
            tua.Item('file_name', style='custom', springy=True),
            tua.VGroup(tua.Item('file_directory', springy=True),
                       tua.Item('file_name', springy=True),
                       tua.Item('Add_File', show_label=False),
                       tua.Item('Add_Folder', show_label=False)
                       # tua.Item('Undo_Add',show_label=False),
                       )),
        resizable=True,
        height=1000,
        width=1500)

    def _file_name_changed(self):
        """Point ``file_directory`` at the folder part of ``file_name``.

        The folder is everything before the last '/' plus a trailing '/';
        a name without any '/' therefore yields '/'.
        """
        self.file_directory = '/'.join(self.file_name.split('/')[:-1]) + '/'

    def _file_directory_changed(self):
        """Select the first PDF of the new directory, if there is one."""
        # Scan once and reuse the result: a second scan costs extra I/O and
        # could come back empty, turning the [0] lookup into an IndexError.
        file_list = GetAllPDF(self.file_directory)
        if len(file_list) > 0:
            self.file_name = file_list[0]

    def _Add_File_fired(self):
        """Append the currently selected file to ``files_selected``."""
        # No ``global`` needed: the module-level name is only read here.
        files_selected.file_list.append(self.file_name)

    def _Add_Folder_fired(self):
        """Add every PDF of the current directory to ``files_selected``."""
        files_selected.file_list += GetAllPDF(self.file_directory)
Exemple #3
0
def create_view(window_name):
    """Assemble the two-pane plot view titled ``window_name``.

    The left pane stacks the 'Multi_Select' and 'Change_Axis' items; the
    right pane holds the 'Plot_Data' item.  The finished ``View`` is
    returned without being displayed.
    """
    # Left pane: selection list above the axis button
    selector = tuiapi.Item('Multi_Select',
                           show_label=False,
                           width=224,
                           height=668,
                           springy=True,
                           resizable=True)
    axis_button = tuiapi.Item('Change_Axis', show_label=False)
    side_panel = tuiapi.VGroup(selector, axis_button)

    # Right pane: the plot itself
    plot_panel = tuiapi.Item('Plot_Data',
                             show_label=False,
                             width=800,
                             height=768,
                             springy=True,
                             resizable=True)

    layout = tuiapi.HSplit(side_panel, plot_panel)
    return tuiapi.View(layout,
                       style='custom',
                       width=1124,
                       height=868,
                       resizable=True,
                       title=window_name)
Exemple #4
0
def create_model_plot(sources, handler=None, metadata=None):
    """Create the plot window

    Parameters
    ----------
    sources : list of str
        Paths of the files to plot.  If the base name of any entry contains
        ``F_EVALDB``, exactly one source is allowed and the files to plot
        are read from it with ``readtabular``.  Otherwise each entry must be
        an existing file whose extension is in ``L_REC_EXT``; entries that
        fail either check are reported through ``logerr`` and skipped.
    handler : optional
        Passed unchanged to ``configure_traits`` as ``handler``.
    metadata : optional
        Not supported here: any value other than None is reported through
        ``stop`` and the function returns without building a window.

    Returns
    -------
    ModelPlot or None
        The window object once ``configure_traits`` has returned, or None
        on the ``metadata`` path.

    """
    if not isinstance(sources, list):
        stop("*** error: sources must be list of files")

    def genrunid(path):
        # Run id = file name without directory and extension
        return os.path.splitext(os.path.basename(path))[0]

    if metadata is not None:
        stop("*** error: call create_view directly")
        # The original followed this with a call that used ``view`` before
        # its assignment below; it could only raise UnboundLocalError, so it
        # has been dropped.
        return

    if any(F_EVALDB in os.path.basename(source) for source in sources):
        if len(sources) > 1:
            stop(
                "*** error: only one source allowed with {0}".format(F_EVALDB))
        source = sources[0]
        if not os.path.isfile(source):
            stop("*** error: {0}: no such file".format(source))
        filepaths, variables = readtabular(source)
        runid = genrunid(filepaths[0])

    else:
        filepaths = []
        for source in sources:
            if not os.path.isfile(source):
                # NOTE(review): ``iam`` is not defined in this function;
                # presumably a module-level name - confirm.
                logerr("{0}: {1}: no such file".format(iam, source))
                continue
            fext = os.path.splitext(source)[1]
            if fext not in L_REC_EXT:
                logerr("{0}: unrecognized file extension".format(source))
                continue
            filepaths.append(source)
        # NOTE(review): logerr() without arguments is used as an
        # "any errors logged?" query - confirm against its definition.
        if logerr():
            stop("***error: stopping due to previous errors")
        if not filepaths:
            # Empty ``sources``: report it instead of failing on
            # ``filepaths[0]`` below.
            stop("*** error: no source files to plot")
        variables = [""] * len(filepaths)
        runid = ("Material Model Laboratory"
                 if len(filepaths) > 1 else genrunid(filepaths[0]))

    # Left column: controls; right column: plot plus step slider
    view = tuiapi.View(tuiapi.HSplit(
        tuiapi.VGroup(
            tuiapi.Item('Multi_Select', show_label=False),
            tuiapi.Item('Change_Axis', show_label=False),
            tuiapi.Item('Reset_Zoom', show_label=False),
            tuiapi.Item('Reload_Data', show_label=False),
            tuiapi.Item('Print_to_PDF', show_label=False),
            tuiapi.VGroup(tuiapi.HGroup(
                tuiapi.Item("X_Scale",
                            label="X Scale",
                            editor=tuiapi.TextEditor(multi_line=False)),
                tuiapi.Item("Y_Scale",
                            label="Y Scale",
                            editor=tuiapi.TextEditor(multi_line=False))),
                          show_border=True),
            tuiapi.VGroup(tuiapi.HGroup(
                tuiapi.Item('Load_Overlay', show_label=False, springy=True),
                tuiapi.Item('Close_Overlay', show_label=False, springy=True),
            ),
                          tuiapi.Item('Single_Select_Overlay_Files',
                                      show_label=False,
                                      resizable=False),
                          show_border=True)),
        tuiapi.VGroup(
            tuiapi.Item('Plot_Data',
                        show_label=False,
                        width=800,
                        height=568,
                        springy=True,
                        resizable=True),
            tuiapi.Item('Step',
                        editor=tuiapi.RangeEditor(low_name='low_step',
                                                  high_name='high_step',
                                                  format='%d',
                                                  label_width=28,
                                                  mode='slider')),
        )),
                       style='custom',
                       width=1124,
                       height=868,
                       resizable=True,
                       title=runid)

    main_window = ModelPlot(filepaths=filepaths, file_variables=variables)
    main_window.configure_traits(view=view, handler=handler)
    return main_window
class ImagePlotInspector(traits.HasTraits):
    """Colour-mapped image plot with a colorbar and two cross-section plots.

    ``create_plot`` builds the Chaco components and stores the outermost
    container in ``self.container``; ``update`` pushes the arrays of a model
    object (``xs``, ``ys``, ``zs``, ``minZ``, ``maxZ``) into them.

    NOTE(review): several handlers below read attributes that are not
    declared in this class (``physics``, ``fitList``, ``selectedFile``,
    ``updatePhysicsBool``, ``logFilePlotObject``, ``autoRange``, ``minz``,
    ``maxz``, ``num_levels``, ``lineplot``); presumably they are provided
    by a subclass or are leftovers - confirm before relying on them.
    """
    #Traits view definitions:

    settingsGroup = traitsui.VGroup(
        traitsui.VGroup(traitsui.HGroup('autoRangeColor', 'colorMapRangeLow',
                                        'colorMapRangeHigh'),
                        traitsui.HGroup('horizontalAutoRange',
                                        'horizontalLowerLimit',
                                        'horizontalUpperLimit'),
                        traitsui.HGroup('verticalAutoRange',
                                        'verticalLowerLimit',
                                        'verticalUpperLimit'),
                        label="axis limits",
                        show_border=True),
        traitsui.VGroup(traitsui.HGroup('object.model.scale',
                                        'object.model.offset'),
                        traitsui.HGroup(
                            traitsui.Item('object.model.pixelsX',
                                          label="Pixels X"),
                            traitsui.Item('object.model.pixelsY',
                                          label="Pixels Y")),
                        traitsui.HGroup(
                            traitsui.Item('object.model.ODCorrectionBool',
                                          label="Correct OD?"),
                            traitsui.Item('object.model.ODSaturationValue',
                                          label="OD saturation value")),
                        traitsui.HGroup(
                            traitsui.Item('contourLevels',
                                          label="Contour Levels"),
                            traitsui.Item('colormap', label="Colour Map")),
                        traitsui.HGroup(
                            traitsui.Item("cameraModel",
                                          label="Update Camera Settings to:")),
                        label="advanced",
                        show_border=True),
        label="settings")
    plotGroup = traitsui.Group(
        traitsui.Item('container',
                      editor=ComponentEditor(size=(800, 600)),
                      show_label=False))
    mainPlotGroup = traitsui.HSplit(plotGroup, label="Image")

    traits_view = traitsui.View(settingsGroup,
                                plotGroup,
                                handler=EagleHandler,
                                resizable=True)

    model = CameraImage()
    contourLevels = traits.Int(15)
    # list(...) so that a Python 3 keys view is expanded into the individual
    # colour map names instead of being passed as one object; on Python 2
    # keys() already is a list, so nothing changes there.
    colormap = traits.Enum(list(colormaps.color_map_name_dict.keys()))

    autoRangeColor = traits.Bool(True)
    colorMapRangeLow = traits.Float
    colorMapRangeHigh = traits.Float

    horizontalAutoRange = traits.Bool(True)
    horizontalLowerLimit = traits.Float
    horizontalUpperLimit = traits.Float

    verticalAutoRange = traits.Bool(True)
    verticalLowerLimit = traits.Float
    verticalUpperLimit = traits.Float

    fixAspectRatioBool = traits.Bool(False)

    cameraModel = traits.Enum("Custom", "ALTA0", "ANDOR0", "ALTA1", "ANDOR1")

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------
    _image_index = traits.Instance(chaco.GridDataSource)
    _image_value = traits.Instance(chaco.ImageData)
    _cmap = traits.Trait(colormaps.jet, traits.Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Initialise the traits and build the (initially empty) plot."""
        super(ImagePlotInspector, self).__init__(*args, **kwargs)
        #self.update(self.model)
        self.create_plot()
        #self._selectedFile_changed()
        logger.info("initialisation of experiment Eagle complete")

    def create_plot(self):
        """Build all Chaco components and assemble them in ``self.container``.

        Every data source starts out with empty arrays; ``update`` fills
        them later.
        """

        # Create the mapper, etc
        self._image_index = chaco.GridDataSource(scipy.array([]),
                                                 scipy.array([]),
                                                 sort_order=("ascending",
                                                             "ascending"))
        image_index_range = chaco.DataRange2D(self._image_index)
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = chaco.ImageData(data=scipy.array([]),
                                            value_depth=1)

        image_value_range = chaco.DataRange1D(self._image_value)

        # Create the contour plots
        #self.polyplot = ContourPolyPlot(index=self._image_index,
        self.polyplot = chaco.CMapImagePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=chaco.GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range))

        # Add a left axis to the plot
        left = chaco.PlotAxis(orientation='left',
                              title="y",
                              mapper=self.polyplot.index_mapper._ymapper,
                              component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = chaco.PlotAxis(orientation='bottom',
                                title="x",
                                mapper=self.polyplot.index_mapper._xmapper,
                                component=self.polyplot)
        self.polyplot.overlays.append(bottom)

        # Add some tools to the plot
        self.polyplot.tools.append(
            tools.PanTool(self.polyplot,
                          constrain_key="shift",
                          drag_button="middle"))

        self.polyplot.overlays.append(
            tools.ZoomTool(component=self.polyplot,
                           tool_mode="box",
                           always_on=False))

        self.lineInspectorX = clickableLineInspector.ClickableLineInspector(
            component=self.polyplot,
            axis='index_x',
            inspect_mode="indexed",
            write_metadata=True,
            is_listener=False,
            color="white")

        self.lineInspectorY = clickableLineInspector.ClickableLineInspector(
            component=self.polyplot,
            axis='index_y',
            inspect_mode="indexed",
            write_metadata=True,
            color="white",
            is_listener=False)

        self.polyplot.overlays.append(self.lineInspectorX)
        self.polyplot.overlays.append(self.lineInspectorY)

        self.boxSelection2D = boxSelection2D.BoxSelection2D(
            component=self.polyplot)
        self.polyplot.overlays.append(self.boxSelection2D)

        # Add these two plots to one container
        self.centralContainer = chaco.OverlayPlotContainer(padding=0,
                                                           use_backbuffer=True,
                                                           unified_draw=True)
        self.centralContainer.add(self.polyplot)

        # Create a colorbar
        cbar_index_mapper = chaco.LinearMapper(range=image_value_range)
        self.colorbar = chaco.ColorBar(
            index_mapper=cbar_index_mapper,
            plot=self.polyplot,
            padding_top=self.polyplot.padding_top,
            padding_bottom=self.polyplot.padding_bottom,
            padding_right=40,
            resizable='v',
            width=30)

        # Shared data store for both cross-section plots
        self.plotData = chaco.ArrayPlotData(
            line_indexHorizontal=scipy.array([]),
            line_valueHorizontal=scipy.array([]),
            scatter_indexHorizontal=scipy.array([]),
            scatter_valueHorizontal=scipy.array([]),
            scatter_colorHorizontal=scipy.array([]),
            fitLine_indexHorizontal=scipy.array([]),
            fitLine_valueHorizontal=scipy.array([]))

        self.crossPlotHorizontal = chaco.Plot(self.plotData, resizable="h")
        self.crossPlotHorizontal.height = 100
        self.crossPlotHorizontal.padding = 20
        self.crossPlotHorizontal.plot(
            ("line_indexHorizontal", "line_valueHorizontal"), line_style="dot")
        self.crossPlotHorizontal.plot(
            ("scatter_indexHorizontal", "scatter_valueHorizontal",
             "scatter_colorHorizontal"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=4)

        # Share the image's x range with the horizontal cross section
        self.crossPlotHorizontal.index_range = self.polyplot.index_range.x_range

        self.plotData.set_data("line_indexVertical", scipy.array([]))
        self.plotData.set_data("line_valueVertical", scipy.array([]))
        self.plotData.set_data("scatter_indexVertical", scipy.array([]))
        self.plotData.set_data("scatter_valueVertical", scipy.array([]))
        self.plotData.set_data("scatter_colorVertical", scipy.array([]))
        self.plotData.set_data("fitLine_indexVertical", scipy.array([]))
        self.plotData.set_data("fitLine_valueVertical", scipy.array([]))

        self.crossPlotVertical = chaco.Plot(self.plotData,
                                            width=140,
                                            orientation="v",
                                            resizable="v",
                                            padding=20,
                                            padding_bottom=160)
        self.crossPlotVertical.plot(
            ("line_indexVertical", "line_valueVertical"), line_style="dot")

        self.crossPlotVertical.plot(
            ("scatter_indexVertical", "scatter_valueVertical",
             "scatter_colorVertical"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=4)

        # Share the image's y range with the vertical cross section
        self.crossPlotVertical.index_range = self.polyplot.index_range.y_range

        # Create a container and add components
        self.container = chaco.HPlotContainer(padding=40,
                                              fill_padding=True,
                                              bgcolor="white",
                                              use_backbuffer=False)

        inner_cont = chaco.VPlotContainer(padding=40, use_backbuffer=True)
        inner_cont.add(self.crossPlotHorizontal)
        inner_cont.add(self.centralContainer)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.crossPlotVertical)

    def update(self, model):
        """Push the arrays of ``model`` into the plots and request a redraw.

        Reads ``model.xs``, ``model.ys``, ``model.zs`` and, when
        ``autoRangeColor`` is set, ``model.minZ`` / ``model.maxZ``.
        """
        # Parenthesised form prints the same text on Python 2 and Python 3.
        print("updating")
        logger.info("updating plot")
        #        if self.selectedFile=="":
        #            logger.warning("selected file was empty. Will not attempt to update plot.")
        #            return
        if self.autoRangeColor:
            self.colorbar.index_mapper.range.low = model.minZ
            self.colorbar.index_mapper.range.high = model.maxZ
        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs
        self.plotData.set_data("line_indexHorizontal", model.xs)
        self.plotData.set_data("line_indexVertical", model.ys)
        self.updatePlotLimits()
        # Fire the metadata event so the cross sections are refreshed
        self._image_index.metadata_changed = True
        self.container.invalidate_draw()
        self.container.request_redraw()

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line and scatter
        plots."""
        if self.horizontalAutoRange:
            self.crossPlotHorizontal.value_range.low = self.model.minZ
            self.crossPlotHorizontal.value_range.high = self.model.maxZ
        if self.verticalAutoRange:
            self.crossPlotVertical.value_range.low = self.model.minZ
            self.crossPlotVertical.value_range.high = self.model.maxZ
        # ``in`` replaces dict.has_key, which no longer exists on Python 3
        if "selections" in self._image_index.metadata:
            selections = self._image_index.metadata["selections"]
            if not selections:  #selections is empty list
                return  #don't need to do update lines as no mouse over screen. This happens at beginning of script
            x_ndx, y_ndx = selections
            # NOTE(review): this truth test also skips an index of 0 -
            # confirm whether row/column 0 should be selectable.
            if y_ndx and x_ndx:
                self.plotData.set_data("line_valueHorizontal",
                                       self._image_value.data[y_ndx, :])
                self.plotData.set_data("line_valueVertical",
                                       self._image_value.data[:, x_ndx])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.plotData.set_data("scatter_indexHorizontal",
                                       scipy.array([xdata[x_ndx]]))
                self.plotData.set_data("scatter_indexVertical",
                                       scipy.array([ydata[y_ndx]]))
                self.plotData.set_data(
                    "scatter_valueHorizontal",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                self.plotData.set_data(
                    "scatter_valueVertical",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                self.plotData.set_data(
                    "scatter_colorHorizontal",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                self.plotData.set_data(
                    "scatter_colorVertical",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
        else:
            self.plotData.set_data("scatter_valueHorizontal", scipy.array([]))
            self.plotData.set_data("scatter_valueVertical", scipy.array([]))
            self.plotData.set_data("line_valueHorizontal", scipy.array([]))
            self.plotData.set_data("line_valueVertical", scipy.array([]))
            self.plotData.set_data("fitLine_valueHorizontal", scipy.array([]))
            self.plotData.set_data("fitLine_valueVertical", scipy.array([]))

    def _colormap_changed(self):
        """Rebuild the colour mappers from the newly selected colour map."""
        self._cmap = colormaps.color_map_name_dict[self.colormap]
        if hasattr(self, "polyplot"):
            value_range = self.polyplot.color_mapper.range
            self.polyplot.color_mapper = self._cmap(value_range)
            value_range = self.crossPlotHorizontal.color_mapper.range
            self.crossPlotHorizontal.color_mapper = self._cmap(value_range)
            # FIXME: change when we decide how best to update plots using
            # the shared colormap in plot object
            self.crossPlotHorizontal.plots["dot"][0].color_mapper = self._cmap(
                value_range)
            self.crossPlotVertical.plots["dot"][0].color_mapper = self._cmap(
                value_range)
            self.container.request_redraw()

    def _colorMapRangeLow_changed(self):
        """Copy the new lower colour limit to the colorbar range."""
        self.colorbar.index_mapper.range.low = self.colorMapRangeLow

    def _colorMapRangeHigh_changed(self):
        """Copy the new upper colour limit to the colorbar range."""
        self.colorbar.index_mapper.range.high = self.colorMapRangeHigh

    def _horizontalLowerLimit_changed(self):
        """Copy the new lower limit to the horizontal cross-section plot."""
        self.crossPlotHorizontal.value_range.low = self.horizontalLowerLimit

    def _horizontalUpperLimit_changed(self):
        """Copy the new upper limit to the horizontal cross-section plot."""
        self.crossPlotHorizontal.value_range.high = self.horizontalUpperLimit

    def _verticalLowerLimit_changed(self):
        """Copy the new lower limit to the vertical cross-section plot."""
        self.crossPlotVertical.value_range.low = self.verticalLowerLimit

    def _verticalUpperLimit_changed(self):
        """Copy the new upper limit to the vertical cross-section plot."""
        self.crossPlotVertical.value_range.high = self.verticalUpperLimit

    # NOTE(review): no ``autoRange``, ``minz`` or ``maxz`` is declared in
    # this class (the declared trait is ``autoRangeColor``) - confirm whether
    # this handler is still meant to be used.
    def _autoRange_changed(self):
        if self.autoRange:
            self.colorbar.index_mapper.range.low = self.minz
            self.colorbar.index_mapper.range.high = self.maxz

    # NOTE(review): no ``num_levels`` or ``lineplot`` is declared in this
    # class - confirm whether this handler is still meant to be used.
    def _num_levels_changed(self):
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels

    def _colorMapRangeLow_default(self):
        """Default lower colour limit: the model's ``minZ``."""
        logger.debug("setting color map rangle low default")
        return self.model.minZ

    def _colorMapRangeHigh_default(self):
        """Default upper colour limit: the model's ``maxZ``."""
        return self.model.maxZ

    def _horizontalLowerLimit_default(self):
        """Default lower horizontal limit: the model's ``minZ``."""
        return self.model.minZ

    def _horizontalUpperLimit_default(self):
        """Default upper horizontal limit: the model's ``maxZ``."""
        return self.model.maxZ

    def _verticalLowerLimit_default(self):
        """Default lower vertical limit: the model's ``minZ``."""
        return self.model.minZ

    def _verticalUpperLimit_default(self):
        """Default upper vertical limit: the model's ``maxZ``."""
        return self.model.maxZ

    def _selectedFit_changed(self, selected):
        """Log that the selected fit changed; nothing else is done."""
        logger.debug("selected fit was changed")

    def _fixAspectRatioBool_changed(self):
        """Lock the central container to pixelsX / pixelsY, or release it."""
        if self.fixAspectRatioBool:
            #using zoom range works but then when you reset zoom this function isn't called...
            #            rangeObject = self.polyplot.index_mapper.range
            #            xrangeValue = rangeObject.high[0]-rangeObject.low[0]
            #            yrangeValue = rangeObject.high[1]-rangeObject.low[1]
            #            logger.info("xrange = %s, yrange = %s " % (xrangeValue, yrangeValue))
            #            aspectRatioSquare = (xrangeValue)/(yrangeValue)
            #            self.polyplot.aspect_ratio=aspectRatioSquare
            self.centralContainer.aspect_ratio = float(
                self.model.pixelsX) / float(self.model.pixelsY)
            #self.polyplot.aspect_ratio = self.model.pixelsX/self.model.pixelsY

        else:
            self.centralContainer.aspect_ratio = None
            #self.polyplot.aspect_ratio = None
        self.container.request_redraw()
        self.centralContainer.request_redraw()

    def updatePlotLimits(self):
        """just updates the values in the GUI  """
        if self.autoRangeColor:
            self.colorMapRangeLow = self.model.minZ
            self.colorMapRangeHigh = self.model.maxZ
        if self.horizontalAutoRange:
            self.horizontalLowerLimit = self.model.minZ
            self.horizontalUpperLimit = self.model.maxZ
        if self.verticalAutoRange:
            self.verticalLowerLimit = self.model.minZ
            self.verticalUpperLimit = self.model.maxZ

    def _selectedFile_changed(self):
        """Reload the image for ``selectedFile``, reset the fits and redraw."""
        self.model.getImageData(self.selectedFile)
        if self.updatePhysicsBool:
            self.physics.updatePhysics()
        for fit in self.fitList:
            fit.fitted = False
            fit.fittingStatus = fit.notFittedForCurrentStatus
            if fit.autoFitBool:  #we should automatically start fitting for this Fit
                fit._fit_routine(
                )  #starts a thread to perform the fit. auto guess and auto draw will be handled automatically
        self.update_view()
        #update log file plot if autorefresh is selected
        if self.logFilePlotObject.autoRefresh:
            try:
                self.logFilePlotObject.refreshPlot()
            except Exception as e:
                # Exceptions have no ``message`` attribute on Python 3; log
                # the exception itself through a lazy %-argument instead.
                logger.error("failed to update log plot -  %s....", e)

    def _cameraModel_changed(self):
        """camera model enum can be used as a helper. It just sets all the relevant
        editable parameters to the correct values. e.g. pixels size, etc.

        cameras:  "Andor Ixon 3838", "Apogee ALTA"
        """
        logger.info("camera model changed")
        if self.cameraModel == "ANDOR0":
            self.model.pixelsX = 512
            self.model.pixelsY = 512
            self.physics.pixelSize = 16.0
            self.physics.magnification = 2.0
            self.searchString = "ANDOR0"
        elif self.cameraModel == "ALTA0":
            self.model.pixelsX = 768
            self.model.pixelsY = 512
            self.physics.pixelSize = 9.0
            self.physics.magnification = 0.5
            self.searchString = "ALTA0"
        elif self.cameraModel == "ALTA1":
            self.model.pixelsX = 768
            self.model.pixelsY = 512
            self.physics.pixelSize = 9.0
            self.physics.magnification = 4.25
            self.searchString = "ALTA1"
        elif self.cameraModel == "ANDOR1":
            self.model.pixelsX = 512
            self.model.pixelsY = 512
            self.physics.pixelSize = 16.0
            self.physics.magnification = 2.0
            self.searchString = "ANDOR1"
        else:
            logger.error("unrecognised camera model")
        self.refreshFitReferences()
        self.model.getImageData(self.selectedFile)

    def refreshFitReferences(self):
        """When aspects of the image change so that the fits need to have
        properties updated, it should be done by this function"""
        for fit in self.fitList:
            fit.endX = self.model.pixelsX
            fit.endY = self.model.pixelsY

    def _pixelsX_changed(self):
        """If pixelsX or pixelsY change, we must send the new arrays to the fit functions """
        logger.info("pixels X Change detected")
        self.refreshFitReferences()
        self.update(self.model)
        self.model.getImageData(self.selectedFile)

    def _pixelsY_changed(self):
        """If pixelsX or pixelsY change, we must send the new arrays to the fit functions """
        logger.info("pixels Y Change detected")
        self.refreshFitReferences()
        self.update(self.model)
        self.model.getImageData(self.selectedFile)

    @traits.on_trait_change('model')
    def update_view(self):
        """Redraw from ``self.model`` unless it is None."""
        if self.model is not None:
            self.update(self.model)
Exemple #6
0
class HCFF(tr.HasStrictTraits):
    '''High-Cycle Fatigue Filter

    GUI helper that splits a delimited text file into one ``.npy`` file per
    column (stored in an ``NPY`` folder next to the input file), reduces the
    part of the record after the first load peak to the force extrema and
    plots the stored arrays in an embedded matplotlib figure.
    '''

    #=========================================================================
    # Traits definitions
    #=========================================================================
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    skip_rows = tr.Int(4, auto_set=False, enter_set=True)
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    npy_folder_path = tr.Str
    file_name = tr.Str
    apply_filters = tr.Bool
    force_name = tr.Str('Kraft')
    peak_force_before_cycles = tr.Float(30)
    plots_num = tr.Enum(1, 2, 3, 4, 6, 9)
    plot_list = tr.List()
    plot = tr.Button
    add_plot = tr.Button
    add_creep_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_npy = tr.Button
    # NOTE: shown in the view, but this class defines no handler for it.
    add_columns_average = tr.Button
    force_max = tr.Float(100)
    force_min = tr.Float(40)

    figure = tr.Instance(Figure)

    # Counter incremented by ``x_plot_fired``.
    plot_figure_num = tr.Int(0)

    # Set after the content of ``figure`` changed.
    # NOTE(review): presumably observed by the figure editor to trigger a
    # redraw -- confirm against MPLFigureEditor.
    data_changed = tr.Event

    #=========================================================================
    # File management
    #=========================================================================

    def _npy_path(self, column_name, suffix=''):
        """Return the path of the npy file that belongs to a column.

        The file lives in ``npy_folder_path`` and is named
        ``<file_name>_<column_name><suffix>.npy``.
        """
        return os.path.join(
            self.npy_folder_path,
            self.file_name + '_' + column_name + suffix + '.npy')

    def _open_file_csv_fired(self):
        """Handle a click on the 'Input file' button.

        Lets the user pick a file and, if the dialog was confirmed, fills
        ``columns_headers_list`` with the sanitised entries of the first
        row, makes sure the ``NPY`` folder next to the file exists and
        stores its path and the base name of the file.
        """
        # Imported here so that the block does not rely on the module
        # header; FileDialog comes from the same package.
        from pyface.api import OK

        extns = ['*.csv', ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open', wildcard=wildcard,
                            default_path=self.file_csv)
        # A dismissed dialog must not be treated as a selected file.
        if dialog.open() != OK:
            return
        self.file_csv = dialog.path

        # Fill the value list shared by x_axis and y_axis.
        headers_array = np.array(
            pd.read_csv(
                self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
                nrows=1, header=None
            )
        )[0]
        self.columns_headers_list = [
            self.get_valid_file_name(header) for header in headers_array]

        # Remember file name and path and create the NPY folder.
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        os.makedirs(self.npy_folder_path, exist_ok=True)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    #=========================================================================
    # Parameters of the filter algorithm
    #=========================================================================

    def _figure_default(self):
        """Create the default figure: white background, tight layout."""
        figure = Figure(facecolor='white')
        figure.set_tight_layout(True)
        return figure

    def _parse_csv_to_npy_fired(self):
        """Store every column of the input file as its own npy file.

        The file is read once per column (``usecols=[i]``), skipping the
        first ``skip_rows`` rows.
        """
        print('Parsing csv into npy files...')

        for i, header in enumerate(self.columns_headers_list):
            column_array = np.array(pd.read_csv(
                self.file_csv, delimiter=self.delimiter,
                decimal=self.decimal, skiprows=self.skip_rows, usecols=[i]))
            np.save(self._npy_path(header), column_array)

        print('Finsihed parsing csv into npy files.')

    def get_valid_file_name(self, original_file_name):
        """Return the name reduced to characters that are safe in file names.

        Only ASCII letters, digits, blanks and ``-_.()`` are kept.  Values
        that are not strings (e.g. numeric header cells) are converted with
        ``str`` first.
        """
        valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
        return ''.join(
            c for c in str(original_file_name) if c in valid_chars)

    def _generate_filtered_npy_fired(self):
        """Export the filtered and the creep arrays of all columns.

        The force record is split at the first sample whose magnitude
        exceeds ``peak_force_before_cycles``.  The part before the split is
        kept completely, the rest is reduced to the force extrema.  For
        every other column except the first one the same samples are
        exported as ``*_filtered.npy`` (with a smoothed ascending part) and
        the values at the accepted force maxima / minima are exported as
        ``*_max.npy`` / ``*_min.npy``.
        """
        # 1- Export filtered force
        force = np.load(self._npy_path(self.force_name)).flatten()
        above_threshold = np.where(
            abs(force) > abs(self.peak_force_before_cycles))[0]
        if above_threshold.size == 0:
            # Without a split point there are no cycles to filter.
            print('The force never exceeds the peak force before cycles, '
                  'no npy files were generated.')
            return
        peak_force_before_cycles_index = above_threshold[0]
        force_ascending = force[0:peak_force_before_cycles_index]
        force_rest = force[peak_force_before_cycles_index:]

        force_max_indices, force_min_indices = \
            self.get_array_max_and_min_indices(force_rest)
        force_max_min_indices = np.sort(np.concatenate(
            (force_min_indices, force_max_indices)))

        force_filtered = np.concatenate(
            (force_ascending, force_rest[force_max_min_indices]))
        np.save(self._npy_path(self.force_name, '_filtered'), force_filtered)

        # Cutting unwanted max min values to get correct full cycles and
        # remove false min/max values caused by noise
        force_max_indices_cutted, force_min_indices_cutted = \
            self.cut_indices_in_range(force_rest,
                                      force_max_indices,
                                      force_min_indices,
                                      self.force_max,
                                      self.force_min)

        print("Cycles number= ", len(force_min_indices))
        print("Cycles number after cutting unwanted max-min range= ",
              len(force_min_indices_cutted))

        # 2- Export filtered displacements and 3- their creep values
        # TODO I skipped time with presuming it's the first column
        for header in self.columns_headers_list[1:]:
            if header == str(self.force_name):
                continue

            disp = np.load(self._npy_path(header)).flatten()
            disp_ascending = savgol_filter(
                disp[0:peak_force_before_cycles_index],
                window_length=51, polyorder=2)
            # The unfiltered cyclic part of *this* column; the indices below
            # all refer to positions in it.
            disp_rest = disp[peak_force_before_cycles_index:]

            filtered_disp = np.concatenate(
                (disp_ascending, disp_rest[force_max_min_indices]))
            np.save(self._npy_path(header, '_filtered'), filtered_disp)

            np.save(self._npy_path(header, '_max'),
                    disp_rest[force_max_indices_cutted])
            np.save(self._npy_path(header, '_min'),
                    disp_rest[force_min_indices_cutted])

        print('Filtered npy files are generated.')

    def cut_indices_in_range(self, array, max_indices, min_indices,
                             range_upper_value, range_lower_value):
        """Drop extrema that do not reach the expected load levels.

        A maximum is kept when ``abs(array[index])`` is larger than
        ``abs(range_upper_value)``, a minimum when it is smaller than
        ``abs(range_lower_value)``.

        Returns the kept maxima and minima indices as two lists.
        """
        upper_limit = abs(range_upper_value)
        lower_limit = abs(range_lower_value)
        cutted_max_indices = [max_index for max_index in max_indices
                              if abs(array[max_index]) > upper_limit]
        cutted_min_indices = [min_index for min_index in min_indices
                              if abs(array[min_index]) < lower_limit]
        return cutted_max_indices, cutted_min_indices

    def get_array_max_and_min_indices(self, input_array):
        """Return the indices of the local maxima and minima of an array.

        "Maximum" refers to the magnitude: when most values are negative
        the roles of the two comparison operators are swapped.  Directly
        neighbouring indices are reduced to the last one and, if the two
        results differ in size, the last element of the longer one is
        removed.
        """
        # Checking dominant sign
        positive_values_count = np.sum(np.array(input_array) >= 0)
        negative_values_count = input_array.size - positive_values_count

        # Getting max and min indices
        if positive_values_count > negative_values_count:
            force_max_indices = argrelextrema(input_array, np.greater_equal)[0]
            force_min_indices = argrelextrema(input_array, np.less_equal)[0]
        else:
            force_max_indices = argrelextrema(input_array, np.less_equal)[0]
            force_min_indices = argrelextrema(input_array, np.greater_equal)[0]

        # Remove subsequent max/min indices (np.greater_equal will give 1,2 for
        # [4, 8, 8, 1])
        force_max_indices = self.remove_subsequent_max_values(
            force_max_indices)
        force_min_indices = self.remove_subsequent_min_values(
            force_min_indices)

        # If size is not equal remove the last element from the big one
        if force_max_indices.size > force_min_indices.size:
            force_max_indices = force_max_indices[:-1]
        elif force_max_indices.size < force_min_indices.size:
            force_min_indices = force_min_indices[:-1]

        return force_max_indices, force_min_indices

    def _drop_consecutive_indices(self, indices):
        """Return a copy of ``indices`` without entries that are directly
        followed by their successor (``indices[i + 1] == indices[i] + 1``)."""
        return np.delete(indices, np.flatnonzero(np.diff(indices) == 1))

    def remove_subsequent_max_values(self, force_max_indices):
        """Keep only the last index of every run of neighbouring maxima."""
        return self._drop_consecutive_indices(force_max_indices)

    def remove_subsequent_min_values(self, force_min_indices):
        """Keep only the last index of every run of neighbouring minima."""
        return self._drop_consecutive_indices(force_min_indices)

    #=========================================================================
    # Plotting
    #=========================================================================

    def _plot_fired(self):
        """Handle the 'plot' button: add a subplot to the figure."""
        self.figure.add_subplot()

    def x_plot_fired(self):
        # NOTE(review): the name does not follow the ``_<trait>_fired``
        # pattern used by the other handlers, so it looks like a disabled
        # handler -- confirm before removing.
        self.plot_figure_num += 1
        plt.draw()
        plt.show()

    def _add_plot_fired(self):
        """Plot the stored ``y_axis`` array over the stored ``x_axis`` array.

        With ``apply_filters`` set the ``*_filtered.npy`` files are used.
        Both arrays are scaled with their multiplier.  The check against
        ``plots_num`` is intentionally not applied here because the curve
        is always drawn into subplot (1, 1, 1).
        """
        print('Loading npy files...')

        suffix = '_filtered' if self.apply_filters else ''
        x_axis_name = self.x_axis + suffix
        y_axis_name = self.y_axis + suffix
        x_axis_array = self.x_axis_multiplier * \
            np.load(self._npy_path(x_axis_name))
        y_axis_array = self.y_axis_multiplier * \
            np.load(self._npy_path(y_axis_name))

        print('Adding Plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        ax = self.figure.add_subplot(1, 1, 1)

        ax.set_xlabel('Displacement [mm]')
        ax.set_ylabel('kN')
        ax.set_title('Original data', fontsize=20)
        ax.plot(x_axis_array, y_axis_array, 'k', linewidth=0.8)

        self.plot_list.append('{}, {}'.format(x_axis_name, y_axis_name))
        self.data_changed = True
        print('Finished adding plot!')

    def apply_new_subplot(self):
        """Add the subplot for the next plot and return its axes.

        The grid follows from ``plots_num``; the position inside the grid
        is the number of entries of ``plot_list`` plus one.
        """
        grids = {1: (1, 1), 2: (1, 2), 3: (1, 3),
                 4: (2, 2), 6: (2, 3), 9: (3, 3)}
        rows, columns = grids[self.plots_num]
        # A single-plot layout always uses its only position.
        position = 1 if self.plots_num == 1 else len(self.plot_list) + 1
        return self.figure.add_subplot(rows, columns, position)

    def _add_creep_plot_fired(self):
        """Plot the ``x_axis`` values at the force maxima and minima.

        Reads the ``*_max.npy`` / ``*_min.npy`` files of the column selected
        as ``x_axis`` and plots them over their running index in a new
        subplot.  Nothing is added once ``plots_num`` plots exist.
        """
        if len(self.plot_list) >= self.plots_num:
            dialog = MessageDialog(
                title='Attention!',
                message='Max plots number is {}'.format(self.plots_num))
            dialog.open()
            return

        disp_max = self.x_axis_multiplier * \
            np.load(self._npy_path(self.x_axis, '_max'))
        disp_min = self.x_axis_multiplier * \
            np.load(self._npy_path(self.x_axis, '_min'))

        print('Adding creep plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        # Labels and curves belong to the axes, not to the figure.
        ax = self.apply_new_subplot()
        ax.set_xlabel('Cycles number')
        ax.set_ylabel('mm')
        ax.set_title('Fatigue creep curve', fontsize=20)
        ax.plot(np.arange(0, disp_max.size), disp_max,
                linewidth=0.8, color='red')
        ax.plot(np.arange(0, disp_min.size), disp_min,
                linewidth=0.8, color='green')

        self.plot_list.append('Plot {}'.format(len(self.plot_list) + 1))
        self.data_changed = True

        print('Finished adding creep plot!')

    #=========================================================================
    # Configuration of the view
    #=========================================================================

    traits_view = ui.View(
        ui.HSplit(
            ui.VSplit(
                ui.HGroup(
                    ui.UItem('open_file_csv'),
                    ui.UItem('file_csv', style='readonly'),
                    label='Input data'
                ),
                ui.Item('add_columns_average', show_label=False),
                ui.VGroup(
                    ui.Item('skip_rows'),
                    ui.Item('decimal'),
                    ui.Item('delimiter'),
                    ui.Item('parse_csv_to_npy', show_label=False),
                    label='Filter parameters'
                ),
                ui.VGroup(
                    ui.Item('plots_num'),
                    ui.HGroup(ui.Item('x_axis'), ui.Item('x_axis_multiplier')),
                    ui.HGroup(ui.Item('y_axis'), ui.Item('y_axis_multiplier')),
                    ui.HGroup(ui.Item('add_plot', show_label=False),
                              ui.Item('apply_filters')),
                    ui.HGroup(ui.Item('add_creep_plot', show_label=False)),
                    ui.Item('plot_list'),
                    ui.Item('plot', show_label=False),
                    show_border=True,
                    label='Plotting settings'),
            ),
            ui.VGroup(
                ui.Item('force_name'),
                ui.HGroup(ui.Item('peak_force_before_cycles'),
                          show_border=True, label='Skip noise of ascending branch:'),
                ui.VGroup(ui.Item('force_max'),
                          ui.Item('force_min'),
                          show_border=True,
                          label='Cut fake cycles for creep:'),
                ui.Item('generate_filtered_npy', show_label=False),
                show_border=True,
                label='Filters'
            ),
            ui.UItem('figure', editor=MPLFigureEditor(),
                     resizable=True,
                     springy=True,
                     width=0.3,
                     label='2d plots'),
        ),
        title='HCFF Filter',
        resizable=True,
        width=0.6,
        height=0.6

    )
Exemple #7
0
class MainWindow(tr.HasStrictTraits):
    '''Window showing a ``FormingProcessView`` next to a
    ``FormingTaskView3D`` in a horizontal split.
    '''

    forming_process_view = tr.Instance(FormingProcessView, ())
    forming_task_scene = tr.Instance(FormingTaskView3D, ())

    # Read/write shortcut to ``forming_process_view.forming_process``.
    forming_process = tr.Property

    def _get_forming_process(self):
        '''Return the forming process held by the process view.'''
        return self.forming_process_view.forming_process

    def _set_forming_process(self, fp):
        '''Hand the forming process ``fp`` over to the process view.'''
        self.forming_process_view.forming_process = fp

    def _selected_node_changed(self):
        '''Register this window as ``ui`` of the selected node.'''
        # NOTE(review): ``selected_node`` is not declared in this class;
        # verify where it is expected to come from.
        self.selected_node.ui = self

    def get_vot_range(self):
        '''Return the vot range reported by the 3D scene.'''
        return self.forming_task_scene.get_vot_range()

    vot = tr.DelegatesTo('forming_task_scene')

    data_changed = tr.Event

    replot = tr.Button

    def _replot_fired(self):
        '''Clear the figure, let the selected node plot into it and flag
        the change through ``data_changed``.'''
        # NOTE(review): neither ``figure`` nor ``selected_node`` is declared
        # in this class; verify where they are expected to come from.
        self.figure.clear()
        self.selected_node.plot(self.figure)
        self.data_changed = True

    clear = tr.Button()

    def _clear_fired(self):
        '''Clear the figure and flag the change through ``data_changed``.'''
        self.figure.clear()
        self.data_changed = True

    view = tu.View(
        tu.HSplit(
            tu.VGroup(
                tu.Item(
                    'forming_process_view@',
                    id='oricreate.hsplit.left.tree.id',
                    resizable=True,
                    show_label=False,
                    width=300,
                ),
                id='oricreate.hsplit.left.id',
            ),
            tu.VGroup(
                tu.Item(
                    'forming_task_scene@',
                    show_label=False,
                    resizable=True,
                    id='oricreate.hsplit.viz3d.notebook.id',
                ),
                id='oricreate.hsplit.viz3d.id',
                label='viz sheet',
            ),
            id='oricreate.hsplit.id',
        ),
        id='oricreate.id',
        width=1.0,
        height=1.0,
        title='OriCreate',
        resizable=True,
        handler=TreeViewHandler(),
        key_bindings=key_bindings,
        toolbar=tu.ToolBar(
            *toolbar_actions,
            image_size=(32, 32),
            show_tool_names=False,
            show_divider=True,
            name='view_toolbar'
        ),
        menubar=tu.MenuBar(
            Menu(menu_exit, Separator(), name='File'),
        ),
    )
class HCFF(tr.HasStrictTraits):
    '''High-Cycle Fatigue Filter

    Splits a delimited text file into one npy file per column, exports
    filtered and creep arrays derived from the force column and plots them.
    '''

    #=========================================================================
    # Traits definitions
    #=========================================================================
    # Decimal separator and column delimiter handed to pandas.read_csv.
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    # Used to generate the time array when the first column is not taken
    # from the file (take_time_from_first_column is False).
    records_per_second = tr.Float(100)
    take_time_from_first_column = tr.Bool
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    # Number of rows skipped when the columns are parsed.
    skip_first_rows = tr.Int(3, auto_set=False, enter_set=True)
    # Sanitised entries of the first row of the file; the names of averaged
    # columns are appended to the end.
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    # Folder 'NPY' next to the input file and base name of the input file;
    # both are set when a file is opened.
    npy_folder_path = tr.Str
    file_name = tr.Str
    apply_filters = tr.Bool
    normalize_cycles = tr.Bool
    smooth = tr.Bool
    plot_every_nth_point = tr.Range(low=1, high=1000000, mode='spinner')
    # Name of the column that holds the force.
    force_name = tr.Str('Kraft')
    old_peak_force_before_cycles = tr.Float
    # The first sample whose force magnitude exceeds this value separates
    # the ascending part from the rest of the record.
    peak_force_before_cycles = tr.Float
    # Parameters of the Savitzky-Golay filter applied to the ascending part
    # of the non-force columns when 'activate' is set.
    window_length = tr.Int(31)
    polynomial_order = tr.Int(2)
    activate = tr.Bool(False)
    plots_num = tr.Enum(1, 2, 3, 4, 6, 9)
    plot_list = tr.List()
    add_plot = tr.Button
    add_creep_plot = tr.Button(desc='Creep plot of X axis array')
    clear_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_and_creep_npy = tr.Button
    add_columns_average = tr.Button
    # Limits used by the cutting method 'Define Max, Min'.
    force_max = tr.Float(100)
    force_min = tr.Float(40)
    # Minimum difference between a maximum and its minimum used by the
    # cutting method 'Define min cycle range(force difference)'.
    min_cycle_force_range = tr.Float(50)
    cutting_method = tr.Enum('Define min cycle range(force difference)',
                             'Define Max, Min')
    # One list of column names per requested average.
    columns_to_be_averaged = tr.List

    figure = tr.Instance(Figure)

    def _figure_default(self):
        """Build the default ``figure``: white background, tight layout."""
        default_figure = Figure(facecolor='white')
        default_figure.set_tight_layout(True)
        return default_figure

    #=========================================================================
    # File management
    #=========================================================================

    def _open_file_csv_fired(self):
        """Handle the user clicking the 'Input file' button.

        Calls ``reset()``, lets the user pick a file and, if the dialog was
        confirmed, fills ``columns_headers_list`` with the sanitised entries
        of the first row, makes sure the ``NPY`` folder next to the file
        exists and stores its path and the base name of the file.
        """
        # NOTE(review): reset() runs before the dialog, so it also takes
        # effect when the user cancels -- confirm that this is intended.
        self.reset()

        extns = [
            '*.csv',
        ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open',
                            wildcard=wildcard,
                            default_path=self.file_csv)

        # Only continue if the user really opened a file; otherwise
        # read_csv below would fail.
        if dialog.open() != OK:
            return
        self.file_csv = dialog.path

        # Fill the value list shared by x_axis and y_axis.
        headers_array = np.array(
            pd.read_csv(self.file_csv,
                        delimiter=self.delimiter,
                        decimal=self.decimal,
                        nrows=1,
                        header=None))[0]
        self.columns_headers_list = [
            self.get_valid_file_name(header) for header in headers_array]

        # Remember file name and path and create the NPY folder.
        # exist_ok avoids the race between checking and creating.
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        os.makedirs(self.npy_folder_path, exist_ok=True)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    def _parse_csv_to_npy_fired(self):
        """Store the columns of the input file as npy files.

        Every column of the file is read on its own (``usecols=[i]``) and
        saved as ``<file_name>_<header>.npy``.  Unless
        ``take_time_from_first_column`` is set, the first column is replaced
        by a generated time array with ``records_per_second`` samples per
        second.  Afterwards the average of every group in
        ``columns_to_be_averaged`` is saved as well.
        """
        print('Parsing csv into npy files...')

        # The names of the averaged columns are appended to the end of
        # columns_headers_list and are not read from the csv file.
        csv_columns_count = (len(self.columns_headers_list) -
                             len(self.columns_to_be_averaged))
        for i in range(csv_columns_count):
            column_array = np.array(
                pd.read_csv(self.file_csv,
                            delimiter=self.delimiter,
                            decimal=self.decimal,
                            skiprows=self.skip_first_rows,
                            usecols=[i]))
            # TODO! The time array is created supposing its column is the
            # first one in the file.
            if i == 0 and not self.take_time_from_first_column:
                # Integer arange divided by the rate: unlike arange with a
                # float step this always yields exactly one value per row.
                column_array = (np.arange(len(column_array)) /
                                self.records_per_second)

            np.save(
                os.path.join(
                    self.npy_folder_path, self.file_name + '_' +
                    self.columns_headers_list[i] + '.npy'), column_array)

        # Exporting npy arrays of averaged columns
        for columns_names in self.columns_to_be_averaged:
            temp = np.zeros((1))
            for column_name in columns_names:
                temp = temp + np.load(
                    os.path.join(self.npy_folder_path, self.file_name + '_' +
                                 column_name + '.npy')).flatten()
            avg = temp / len(columns_names)

            avg_file_suffex = self.get_suffex_for_columns_to_be_averaged(
                columns_names)
            np.save(
                os.path.join(self.npy_folder_path,
                             self.file_name + '_' + avg_file_suffex + '.npy'),
                avg)

        print('Finsihed parsing csv into npy files.')

    def get_suffex_for_columns_to_be_averaged(self, columns_names):
        """Return the name used for the average of the given columns:
        ``'avg_'`` followed by the column names joined with ``'_'``."""
        joined_names = '_'.join(columns_names)
        return 'avg_{}'.format(joined_names)

    def get_valid_file_name(self, original_file_name):
        """Return the name reduced to characters that are safe in file names.

        Only ASCII letters, digits, blanks and ``-_.()`` are kept.  Values
        that are not strings (e.g. a header cell that pandas parsed as a
        number) are converted with ``str`` first instead of raising a
        ``TypeError``.
        """
        valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
        new_valid_file_name = ''.join(c for c in str(original_file_name)
                                      if c in valid_chars)
        return new_valid_file_name

    def _clear_plot_fired(self):
        """Handle the ``clear_plot`` button.

        Clears the figure, replaces ``plot_list`` by an empty list and sets
        ``data_changed``.
        """
        self.figure.clear()
        self.plot_list = []
        # NOTE(review): presumably makes the figure editor redraw -- confirm
        # against the editor that listens to data_changed.
        self.data_changed = True

    def _add_columns_average_fired(self):
        """Handle the ``add_columns_average`` button.

        Opens a modal ``ColumnsAverage`` window listing all current column
        names.  If at least one column is selected there, the selection is
        appended to ``columns_to_be_averaged`` and the name of the resulting
        average is appended to ``columns_headers_list``.
        """
        averaging_window = ColumnsAverage()
        for header in self.columns_headers_list:
            averaging_window.columns.append(Column(column_name=header))

        # kind='modal' blocks here until the window is closed
        averaging_window.configure_traits(kind='modal')

        selected_names = [column.column_name
                          for column in averaging_window.columns
                          if column.selected]
        if not selected_names:
            return

        self.columns_to_be_averaged.append(selected_names)
        self.columns_headers_list.append(
            self.get_suffex_for_columns_to_be_averaged(selected_names))

    def _generate_filtered_and_creep_npy_fired(self):
        """Export the filtered and the creep arrays of all columns.

        The force record is split at the first sample whose magnitude
        exceeds ``peak_force_before_cycles``.  The part before the split is
        kept completely, the rest is reduced to the force extrema.  For all
        columns except the first one the values at the accepted force
        maxima / minima are saved as ``*_max.npy`` / ``*_min.npy``; for the
        non-force columns among them the reduced record is saved as
        ``*_filtered.npy`` (with a Savitzky-Golay smoothed ascending part
        when ``activate`` is set).
        """

        def npy_path(column_name, suffix=''):
            # <npy folder>/<file name>_<column name><suffix>.npy
            return os.path.join(
                self.npy_folder_path,
                self.file_name + '_' + column_name + suffix + '.npy')

        # The helper is not visible here, so its result is compared exactly
        # as before instead of relying on truthiness.
        if self.npy_files_exist(npy_path(self.force_name)) == False:  # noqa: E712
            return

        # 1- Export filtered force
        force = np.load(npy_path(self.force_name)).flatten()
        above_threshold = np.where(
            abs(force) > abs(self.peak_force_before_cycles))[0]
        if above_threshold.size == 0:
            # Without a split point there are no cycles to filter; this
            # used to end in an IndexError.
            print('The force never exceeds the peak force before cycles, '
                  'no npy files were generated.')
            return
        peak_force_before_cycles_index = above_threshold[0]
        force_ascending = force[0:peak_force_before_cycles_index]
        force_rest = force[peak_force_before_cycles_index:]

        force_max_indices, force_min_indices = \
            self.get_array_max_and_min_indices(force_rest)

        force_max_min_indices = np.sort(np.concatenate(
            (force_min_indices, force_max_indices)))

        force_filtered = np.concatenate(
            (force_ascending, force_rest[force_max_min_indices]))
        np.save(npy_path(self.force_name, '_filtered'), force_filtered)

        # Cutting unwanted max min values to get correct full cycles and
        # remove false min/max values caused by noise
        if self.cutting_method == "Define Max, Min":
            force_max_indices_cutted, force_min_indices_cutted = \
                self.cut_indices_of_min_max_range(force_rest,
                                                  force_max_indices,
                                                  force_min_indices,
                                                  self.force_max,
                                                  self.force_min)
        elif self.cutting_method == "Define min cycle range(force difference)":
            force_max_indices_cutted, force_min_indices_cutted = \
                self.cut_indices_of_defined_range(force_rest,
                                                  force_max_indices,
                                                  force_min_indices,
                                                  self.min_cycle_force_range)

        print("Cycles number= ", len(force_min_indices))
        print("Cycles number after cutting fake cycles because of noise= ",
              len(force_min_indices_cutted))

        # 2- Export filtered displacements and 3- creep of all columns.
        # Every file is loaded only once for both exports.
        # TODO I skipped time presuming it's the first column
        for column_name in self.columns_headers_list[1:]:
            array = np.load(npy_path(column_name)).flatten()
            array_rest = array[peak_force_before_cycles_index:]

            # The filtered force has already been exported above.
            if column_name != str(self.force_name):
                array_ascending = array[0:peak_force_before_cycles_index]
                if self.activate:
                    array_ascending = savgol_filter(
                        array_ascending,
                        window_length=self.window_length,
                        polyorder=self.polynomial_order)
                np.save(npy_path(column_name, '_filtered'),
                        np.concatenate(
                            (array_ascending,
                             array_rest[force_max_min_indices])))

            np.save(npy_path(column_name, '_max'),
                    array_rest[force_max_indices_cutted])
            np.save(npy_path(column_name, '_min'),
                    array_rest[force_min_indices_cutted])

        print('Filtered and creep npy files are generated.')

    def cut_indices_of_min_max_range(self, array, max_indices, min_indices,
                                     range_upper_value, range_lower_value):
        """Drop extrema that do not reach the given load levels.

        A maximum is kept when ``abs(array[index])`` is larger than
        ``abs(range_upper_value)``, a minimum when it is smaller than
        ``abs(range_lower_value)``.

        Returns the kept maxima and minima indices as two lists.
        """
        kept_maxima = [index for index in max_indices
                       if abs(array[index]) > abs(range_upper_value)]
        kept_minima = [index for index in min_indices
                       if abs(array[index]) < abs(range_lower_value)]
        return kept_maxima, kept_minima

    def cut_indices_of_defined_range(self, array, max_indices, min_indices,
                                     range_):
        """Keep only the cycles whose force difference is large enough.

        Maxima and minima are paired in order; a pair is kept when the
        absolute difference of the two array values is larger than
        ``range_``.

        Returns the kept maxima and minima indices as two lists.
        """
        kept_pairs = [
            (upper, lower)
            for upper, lower in zip(max_indices, min_indices)
            if abs(array[upper] - array[lower]) > range_
        ]
        kept_maxima = [upper for upper, _ in kept_pairs]
        kept_minima = [lower for _, lower in kept_pairs]
        return kept_maxima, kept_minima

    def get_array_max_and_min_indices(self, input_array):
        """Locate the local extrema of ``input_array``.

        The dominant sign of the data decides what counts as a "maximum":
        when non-negative samples outnumber negative ones, maxima are the
        local peaks; otherwise the roles are swapped, so that "maximum"
        refers to the most negative values.

        Returns ``(max_indices, min_indices)`` as index arrays. Runs of
        adjacent indices are reduced by the ``remove_subsequent_*`` helpers,
        and if the two arrays still differ in length the longer one loses
        its last element.
        """
        # Dominant sign: zeros are counted on the non-negative side.
        non_negative_count = np.sum(np.array(input_array) >= 0)
        negative_count = input_array.size - non_negative_count

        if non_negative_count > negative_count:
            max_comparator, min_comparator = np.greater_equal, np.less_equal
        else:
            max_comparator, min_comparator = np.less_equal, np.greater_equal

        # The non-strict comparators report every sample of a plateau
        # (e.g. indices 1 and 2 for [4, 8, 8, 1]); the helpers thin those
        # runs of adjacent indices out again.
        max_indices = self.remove_subsequent_max_values(
            argrelextrema(input_array, max_comparator)[0])
        min_indices = self.remove_subsequent_min_values(
            argrelextrema(input_array, min_comparator)[0])

        # Drop the trailing element of whichever array is longer.
        if max_indices.size > min_indices.size:
            max_indices = max_indices[:-1]
        elif min_indices.size > max_indices.size:
            min_indices = min_indices[:-1]

        return max_indices, min_indices

    def remove_subsequent_max_values(self, force_max_indices):
        """Thin out runs of directly adjacent maxima indices.

        Every entry whose successor in ``force_max_indices`` is exactly one
        greater is removed, so of a run of consecutive indices only the last
        one survives. A new array is returned; the input is not modified.
        """
        # Positions i for which indices[i + 1] - indices[i] == 1.
        has_adjacent_successor = np.flatnonzero(
            np.diff(force_max_indices) == 1)
        return np.delete(force_max_indices, has_adjacent_successor)

    def remove_subsequent_min_values(self, force_min_indices):
        """Thin out runs of directly adjacent minima indices.

        Every entry whose successor in ``force_min_indices`` is exactly one
        greater is removed, so of a run of consecutive indices only the last
        one survives. A new array is returned; the input is not modified.
        """
        keep = np.ones(force_min_indices.size, dtype=bool)
        # An entry is dropped when the next index follows it immediately;
        # the final entry has no successor and is therefore always kept.
        keep[:-1] = np.diff(force_min_indices) != 1
        return force_min_indices[keep]

    def _activate_changed(self):
        if self.activate == False:
            self.old_peak_force_before_cycles = self.peak_force_before_cycles
            self.peak_force_before_cycles = 0
        else:
            self.peak_force_before_cycles = self.old_peak_force_before_cycles

    def _window_length_changed(self, new):

        if new <= self.polynomial_order:
            dialog = MessageDialog(
                title='Attention!',
                message='Window length must be bigger than polynomial order.')
            dialog.open()

        if new % 2 == 0 or new <= 0:
            dialog = MessageDialog(
                title='Attention!',
                message='Window length must be odd positive integer.')
            dialog.open()

    def _polynomial_order_changed(self, new):

        if new >= self.window_length:
            dialog = MessageDialog(
                title='Attention!',
                message='Polynomial order must be less than window length.')
            dialog.open()

    #=========================================================================
    # Plotting
    #=========================================================================

    # Length bookkeeping for plot_list: _plot_list_changed raises this value
    # whenever the list has grown beyond it.
    plot_list_current_elements_num = tr.Int(0)

    def npy_files_exist(self, path):
        """Report whether ``path`` exists, warning the user when it does not.

        Returns ``True`` when the path exists. Otherwise a dialog asks the
        user to parse the csv file first and ``False`` is returned.
        """
        if os.path.exists(path):
            return True

        # The message has no placeholders, so the former
        # ``.format(self.plots_num)`` call was a no-op and has been dropped.
        dialog = MessageDialog(
            title='Attention!',
            message='Please parse csv file to generate npy files first.')
        dialog.open()
        return False

    def filtered_and_creep_npy_files_exist(self, path):
        """Report whether ``path`` exists, warning the user when it does not.

        Returns ``True`` when the path exists. Otherwise a dialog asks the
        user to generate the filtered and creep npy files first and
        ``False`` is returned.
        """
        if os.path.exists(path):
            return True

        # The message has no placeholders, so the former
        # ``.format(self.plots_num)`` call was a no-op and has been dropped.
        dialog = MessageDialog(
            title='Attention!',
            message='Please generate filtered and creep npy files first.')
        dialog.open()
        return False

    def max_plots_number_is_reached(self):
        """Tell whether ``plot_list`` already holds ``plots_num`` entries.

        Returns ``False`` while there is still room. Once the limit is
        reached a dialog stating the maximum is shown and ``True`` is
        returned.
        """
        if len(self.plot_list) < self.plots_num:
            return False

        limit_dialog = MessageDialog(
            title='Attention!',
            message='Max plots number is {}'.format(self.plots_num))
        limit_dialog.open()
        return True

    def _plot_list_changed(self):
        if len(self.plot_list) > self.plot_list_current_elements_num:
            self.plot_list_current_elements_num = len(self.plot_list)

    # Event set to True by the plotting handlers once a plot has been added.
    # NOTE(review): presumably observed by the figure editor to trigger a
    # redraw - confirm against MPLFigureEditor.
    data_changed = tr.Event

    def _add_plot_fired(self):
        """Handle the ``add_plot`` button: plot ``y_axis`` against ``x_axis``.

        With ``apply_filters`` set, the ``*_filtered.npy`` arrays are used,
        otherwise the raw ``*.npy`` arrays. Both arrays are scaled by their
        axis multiplier, drawn in a randomly coloured line on a new subplot,
        and the plot is recorded in ``plot_list``.

        Nothing is plotted when the maximum number of plots is reached or
        when one of the two input files is missing (a dialog is shown by the
        respective check).
        """
        if self.max_plots_number_is_reached():
            return

        # Filtered and raw data differ only in the file-name suffix and in
        # which "missing file" dialog is appropriate.
        if self.apply_filters:
            name_suffix = '_filtered'
            files_exist = self.filtered_and_creep_npy_files_exist
        else:
            name_suffix = ''
            files_exist = self.npy_files_exist

        x_axis_name = self.x_axis + name_suffix
        y_axis_name = self.y_axis + name_suffix
        x_path = os.path.join(self.npy_folder_path,
                              self.file_name + '_' + x_axis_name + '.npy')
        y_path = os.path.join(self.npy_folder_path,
                              self.file_name + '_' + y_axis_name + '.npy')

        # Check both inputs before loading: previously only the x-axis file
        # was checked, so a missing y-axis file surfaced as an unhandled
        # error from np.load. `and` short-circuits, so at most one dialog
        # is shown.
        if not (files_exist(x_path) and files_exist(y_path)):
            return

        print('Loading npy files...')
        x_axis_array = self.x_axis_multiplier * np.load(x_path)
        y_axis_array = self.y_axis_multiplier * np.load(y_path)

        print('Adding Plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        ax = self.apply_new_subplot()

        ax.set_xlabel(x_axis_name)
        ax.set_ylabel(y_axis_name)
        # The colour is given by keyword only; the former 'k' format string
        # was overridden by `color=` and merely triggered a matplotlib
        # "redundantly defined" warning.
        ax.plot(x_axis_array,
                y_axis_array,
                linewidth=1.2,
                color=np.random.rand(3, ),
                label=self.file_name + ', ' + x_axis_name)

        ax.legend()

        self.plot_list.append('{}, {}'.format(x_axis_name, y_axis_name))
        self.data_changed = True
        print('Finished adding plot!')

    def apply_new_subplot(self):
        """Add the next subplot to ``self.figure`` and return it.

        The grid layout follows ``plots_num`` (1 -> 1x1, 2 -> 1x2, 3 -> 1x3,
        4 -> 2x2, 6 -> 2x3, 9 -> 3x3) and the position inside the grid is
        ``len(plot_list) + 1``. For any other ``plots_num`` nothing is added
        and ``None`` is returned.
        """
        fig = self.figure
        if self.plots_num == 1:
            return fig.add_subplot(1, 1, 1)

        # Leading "rows, columns" digits of the three-digit subplot code.
        grid_digits = {2: '12', 3: '13', 4: '22', 6: '23', 9: '33'}
        layout = grid_digits.get(self.plots_num)
        if layout is None:
            return None

        next_position = str(len(self.plot_list) + 1)
        return fig.add_subplot(int(layout + next_position))

    def _add_creep_plot_fired(self):
        """Handle the ``add_creep_plot`` button: plot per-cycle extrema.

        Loads the ``*_max.npy`` / ``*_min.npy`` arrays of the current
        ``x_axis`` column, scales them by ``x_axis_multiplier`` and draws
        both (max in red, min in green) on a new subplot. Optionally only
        every n-th point is used and the curves are smoothed with a
        Savitzky-Golay filter. The horizontal axis runs from 0 to the number
        of loaded cycles, or from 0 to 1 when ``normalize_cycles`` is set.

        Nothing is plotted when the ``*_max.npy`` file is missing or the
        maximum number of plots is reached.
        """

        def extrema_path(suffix):
            # <npy_folder>/<file_name>_<x_axis><suffix>
            return os.path.join(self.npy_folder_path,
                                self.file_name + '_' + self.x_axis + suffix)

        if not self.filtered_and_creep_npy_files_exist(
                extrema_path('_max.npy')):
            return

        if self.max_plots_number_is_reached():
            return

        disp_max = self.x_axis_multiplier * np.load(extrema_path('_max.npy'))
        disp_min = self.x_axis_multiplier * np.load(extrema_path('_min.npy'))
        # Taken before any thinning below, so the cycle axis still spans the
        # full number of loaded cycles.
        complete_cycles_number = disp_max.size

        print('Adding creep-fatigue plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        ax = self.apply_new_subplot()

        ax.set_xlabel('Cycles number')
        ax.set_ylabel(self.x_axis)

        if self.plot_every_nth_point > 1:
            disp_max = disp_max[0::self.plot_every_nth_point]
            disp_min = disp_min[0::self.plot_every_nth_point]

        if self.smooth:

            def smooth_after_first(values):
                # Keeping the first item of the array and filtering the rest
                return np.concatenate(
                    (np.array([values[0]]),
                     savgol_filter(values[1:],
                                   window_length=self.window_length,
                                   polyorder=self.polynomial_order)))

            disp_max = smooth_after_first(disp_max)
            disp_min = smooth_after_first(disp_min)

        # Both curves share one cycle axis, sized after the max array.
        last_cycle = 1. if self.normalize_cycles else complete_cycles_number
        cycles = np.linspace(0, last_cycle, disp_max.size)

        # The colour is given by keyword only; the former 'k' format string
        # was overridden by `color=` and merely triggered a matplotlib
        # "redundantly defined" warning.
        for values, line_color, prefix in ((disp_max, 'red', 'Max'),
                                           (disp_min, 'green', 'Min')):
            ax.plot(cycles,
                    values,
                    linewidth=1.2,
                    color=line_color,
                    label=prefix + ', ' + self.file_name + ', ' + self.x_axis)

        ax.legend()

        self.plot_list.append('Creep-fatigue: {}, {}'.format(
            self.x_axis, self.y_axis))
        self.data_changed = True

        print('Finished adding creep-fatigue plot!')

    def reset(self):
        """Put the import, filter and plot settings back to their defaults.

        The attributes are assigned in the order listed below; each list
        attribute receives its own fresh empty list.
        """
        defaults = (
            ('delimiter', ';'),
            ('skip_first_rows', 3),
            ('columns_headers_list', []),
            ('npy_folder_path', ''),
            ('file_name', ''),
            ('apply_filters', False),
            ('force_name', 'Kraft'),
            ('plot_list', []),
            ('columns_to_be_averaged', []),
        )
        for attribute, value in defaults:
            setattr(self, attribute, value)

    #=========================================================================
    # Configuration of the view
    #=========================================================================

    # Main window layout: a horizontal split with three panes -
    #   1. input / csv processing / plotting controls,
    #   2. the 'Filters' settings,
    #   3. the matplotlib figure.
    traits_view = ui.View(ui.HSplit(
        # --- Pane 1: input data, csv processing and plotting controls ---
        ui.VSplit(
            ui.HGroup(ui.UItem('open_file_csv'),
                      ui.UItem('file_csv', style='readonly', width=0.1),
                      label='Input data'),
            ui.Item('add_columns_average', show_label=False),
            ui.VGroup(
                # records_per_second is only editable while the time is not
                # taken from the first column.
                ui.VGroup(ui.Item(
                    'records_per_second',
                    enabled_when='take_time_from_first_column == False'),
                          ui.Item('take_time_from_first_column'),
                          label='Time calculation',
                          show_border=True),
                ui.VGroup(ui.Item('skip_first_rows'),
                          ui.Item('decimal'),
                          ui.Item('delimiter'),
                          ui.Item('parse_csv_to_npy', show_label=False),
                          label='Processing csv file',
                          show_border=True),
                ui.VGroup(ui.HGroup(ui.Item('plots_num'),
                                    ui.Item('clear_plot')),
                          ui.HGroup(ui.Item('x_axis'),
                                    ui.Item('x_axis_multiplier')),
                          ui.HGroup(ui.Item('y_axis'),
                                    ui.Item('y_axis_multiplier')),
                          ui.VGroup(ui.HGroup(
                              ui.Item('add_plot', show_label=False),
                              ui.Item('apply_filters')),
                                    show_border=True,
                                    label='Plotting X axis with Y axis'),
                          ui.VGroup(ui.HGroup(
                              ui.Item('add_creep_plot', show_label=False),
                              ui.VGroup(ui.Item('normalize_cycles'),
                                        ui.Item('smooth'),
                                        ui.Item('plot_every_nth_point'))),
                                    show_border=True,
                                    label='Plotting Creep-fatigue of x-axis'),
                          ui.Item('plot_list'),
                          show_border=True,
                          label='Plotting'))),
        # --- Pane 2: filter settings ---
        ui.VGroup(
            ui.Item('force_name'),
            # Smoothing parameters are editable when either the ascending
            # branch smoothing ('activate') or creep smoothing ('smooth') is on.
            ui.VGroup(ui.VGroup(
                ui.Item('window_length'),
                ui.Item('polynomial_order'),
                enabled_when='activate == True or smooth == True'),
                      show_border=True,
                      label='Smoothing parameters (Savitzky-Golay filter):'),
            ui.VGroup(ui.VGroup(
                ui.Item('activate'),
                ui.Item('peak_force_before_cycles',
                        enabled_when='activate == True')),
                      show_border=True,
                      label='Smooth ascending branch for all displacements:'),
            # Only the sub-group matching the selected cutting_method is
            # enabled.
            ui.VGroup(
                ui.Item('cutting_method'),
                ui.VGroup(ui.Item('force_max'),
                          ui.Item('force_min'),
                          label='Max, Min:',
                          show_border=True,
                          enabled_when='cutting_method == "Define Max, Min"'),
                ui.VGroup(
                    ui.Item('min_cycle_force_range'),
                    label='Min cycle force range:',
                    show_border=True,
                    enabled_when=
                    'cutting_method == "Define min cycle range(force difference)"'
                ),
                show_border=True,
                label='Cut fake cycles for creep:'),
            ui.Item('generate_filtered_and_creep_npy', show_label=False),
            show_border=True,
            label='Filters'),
        # --- Pane 3: the matplotlib figure ---
        ui.UItem('figure',
                 editor=MPLFigureEditor(),
                 resizable=True,
                 springy=True,
                 width=0.8,
                 label='2d plots')),
                          title='HCFF Filter',
                          resizable=True,
                          width=0.85,
                          height=0.7)
class ImagePlotInspector(traits.HasTraits):
    #Traits view definitions:

    # "settings" page: file selection, axis/colour limits and advanced
    # options (the latter mostly edit attributes of `model`).
    settingsGroup = traitsui.VGroup(
        # Either a single file or a watched folder (plus a file-name
        # sub-string) is shown, depending on watchFolderBool.
        traitsui.VGroup(
            traitsui.Item("watchFolderBool", label="Watch Folder?"),
            traitsui.HGroup(traitsui.Item("selectedFile",
                                          label="Select a File"),
                            visible_when="not watchFolderBool"),
            traitsui.HGroup(traitsui.Item("watchFolder",
                                          label="Select a Directory"),
                            visible_when="watchFolderBool"),
            traitsui.HGroup(traitsui.Item("searchString",
                                          label="Filename sub-string"),
                            visible_when="watchFolderBool"),
            label="File Settings",
            show_border=True),
        # One row per range: auto flag followed by lower and upper limit.
        traitsui.VGroup(traitsui.HGroup('autoRangeColor', 'colorMapRangeLow',
                                        'colorMapRangeHigh'),
                        traitsui.HGroup('horizontalAutoRange',
                                        'horizontalLowerLimit',
                                        'horizontalUpperLimit'),
                        traitsui.HGroup('verticalAutoRange',
                                        'verticalLowerLimit',
                                        'verticalUpperLimit'),
                        label="axis limits",
                        show_border=True),
        traitsui.VGroup(traitsui.HGroup('object.model.scale',
                                        'object.model.offset'),
                        traitsui.HGroup(
                            traitsui.Item('object.model.pixelsX',
                                          label="Pixels X"),
                            traitsui.Item('object.model.pixelsY',
                                          label="Pixels Y")),
                        traitsui.HGroup(
                            traitsui.Item('object.model.ODCorrectionBool',
                                          label="Correct OD?"),
                            traitsui.Item('object.model.ODSaturationValue',
                                          label="OD saturation value")),
                        traitsui.HGroup(
                            traitsui.Item('contourLevels',
                                          label="Contour Levels"),
                            traitsui.Item('colormap', label="Colour Map")),
                        traitsui.HGroup(
                            traitsui.Item('fixAspectRatioBool',
                                          label="Fix Plot Aspect Ratio?")),
                        traitsui.HGroup(
                            traitsui.Item('updatePhysicsBool',
                                          label="Update Physics with XML?")),
                        traitsui.HGroup(
                            traitsui.Item("cameraModel",
                                          label="Update Camera Settings to:")),
                        label="advanced",
                        show_border=True),
        label="settings")

    # "Image" page: the Chaco plot container next to a notebook of fits.
    plotGroup = traitsui.Group(
        traitsui.Item('container',
                      editor=ComponentEditor(size=(800, 600)),
                      show_label=False))
    # One notebook page per entry of fitList; the current page is mirrored
    # in selectedFit.
    fitsGroup = traitsui.Group(traitsui.Item('fitList',
                                             style="custom",
                                             editor=traitsui.ListEditor(
                                                 use_notebook=True,
                                                 selected="selectedFit",
                                                 deletable=False,
                                                 export='DockWindowShell',
                                                 page_name=".name"),
                                             label="Fits",
                                             show_label=False),
                               springy=True)

    mainPlotGroup = traitsui.HSplit(plotGroup, fitsGroup, label="Image")

    # Empty group; it is not passed to traits_view below.
    fftGroup = traitsui.Group(label="Fourier Transform")

    physicsGroup = traitsui.Group(traitsui.Item(
        "physics",
        editor=traitsui.InstanceEditor(),
        style="custom",
        show_label=False),
                                  label="Physics")

    logFilePlotGroup = traitsui.Group(traitsui.Item(
        "logFilePlotObject",
        editor=traitsui.InstanceEditor(),
        style="custom",
        show_label=False),
                                      label="Log File Plotter")

    # "File" menu; the action names refer to methods looked up by name.
    # NOTE(review): presumably implemented on EagleHandler - confirm.
    eagleMenubar = traitsmenu.MenuBar(
        traitsmenu.Menu(
            traitsui.Action(name='Save Image Copy As...',
                            action='_save_image_as'),
            traitsui.Action(name='Save Image Copy',
                            action='_save_image_default'),
            name="File",
        ))

    # Top-level view combining the pages above; the status bar shows
    # selectedFile.
    traits_view = traitsui.View(settingsGroup,
                                mainPlotGroup,
                                physicsGroup,
                                logFilePlotGroup,
                                buttons=traitsmenu.NoButtons,
                                menubar=eagleMenubar,
                                handler=EagleHandler,
                                title="Experiment Eagle",
                                statusbar="selectedFile",
                                icon=pyface.image_resource.ImageResource(
                                    os.path.join('icons', 'eagles.ico')),
                                resizable=True)

    # Camera image model. Created once at class level, so the same object
    # (and its fit list below) is shared by every inspector instance.
    model = CameraImage()
    # Physics properties object; __init__ hands it to every fit and to the
    # log-file plot object.
    physics = physicsProperties.physicsProperties.PhysicsProperties()
    logFilePlotObject = logFilePlot.LogFilePlot()
    # Alias of the model's fit list, shown as notebook pages in fitsGroup.
    fitList = model.fitList
    selectedFit = traits.Instance(fits.Fit)
    drawFitRequest = traits.Event
    drawFitBool = traits.Bool(False)  # true when drawing a fit as well
    selectedFile = traits.File()
    # Switches the settings page between single-file and watched-folder
    # selection.
    watchFolderBool = traits.Bool(False)
    watchFolder = traits.Directory()
    searchString = traits.String(
        desc=
        "sub string that must be contained in file name for it to be shown in Eagle. Can be used to allow different instances of Eagle to detect different saved images."
    )
    # NOTE(review): plain class attribute, so this set is shared between
    # instances unless it is rebound per instance - confirm that is intended.
    oldFiles = set()
    contourLevels = traits.Int(15)
    # Enum only expands a single argument into its list of choices when that
    # argument is a list or tuple. On Python 3 ``dict.keys()`` is a view
    # object, so it is converted explicitly (a harmless copy on Python 2).
    colormap = traits.Enum(list(colormaps.color_map_name_dict.keys()))

    # Colour range: auto flag plus manual lower/upper limits.
    autoRangeColor = traits.Bool(True)
    colorMapRangeLow = traits.Float
    colorMapRangeHigh = traits.Float

    # Horizontal axis range: auto flag plus manual lower/upper limits.
    horizontalAutoRange = traits.Bool(True)
    horizontalLowerLimit = traits.Float
    horizontalUpperLimit = traits.Float

    # Vertical axis range: auto flag plus manual lower/upper limits.
    verticalAutoRange = traits.Bool(True)
    verticalLowerLimit = traits.Float
    verticalUpperLimit = traits.Float

    fixAspectRatioBool = traits.Bool(False)
    updatePhysicsBool = traits.Bool(True)

    cameraModel = traits.Enum("Custom", "ALTA0", "ANDOR0", "ALTA1", "ANDOR1")

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------
    # Data sources behind the image plot; both are created in create_plot().
    _image_index = traits.Instance(chaco.GridDataSource)
    _image_value = traits.Instance(chaco.ImageData)
    # Colour-map factory (defaults to jet); create_plot() calls it with a
    # data range to build the colour mappers.
    _cmap = traits.Trait(colormaps.jet, traits.Callable)

    #---------------------------------------------------------------------------
    # Public View interface
    #---------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        """Build the plots and wire the fits to this inspector.

        After the base-class initialisation the plot components are created,
        every fit in ``fitList`` is given a reference to this inspector and
        to the shared physics object, the first fit becomes the selected
        one, and the log-file plot object is pointed at the same physics
        object.
        """
        super(ImagePlotInspector, self).__init__(*args, **kwargs)
        self.create_plot()

        shared_physics = self.physics
        for fit_entry in self.fitList:
            fit_entry.imageInspectorReference = self
            fit_entry.physics = shared_physics

        self.selectedFit = self.fitList[0]
        self.logFilePlotObject.physicsReference = shared_physics
        logger.info("initialisation of experiment Eagle complete")

    def create_plot(self):
        """Create all Chaco components and assemble them in ``self.container``.

        Builds, in this order: the (initially empty) image data sources, the
        colour-mapped image plot with axes, tools and overlays, a colour
        bar, and the horizontal and vertical cross-section plots. The final
        container holds the colour bar, an inner container (horizontal cross
        plot + image) and the vertical cross plot.

        NOTE(review): ``scipy.array`` is a NumPy alias that newer SciPy
        releases deprecate - confirm the supported SciPy version.
        """

        # Create the mapper, etc
        self._image_index = chaco.GridDataSource(scipy.array([]),
                                                 scipy.array([]),
                                                 sort_order=("ascending",
                                                             "ascending"))
        image_index_range = chaco.DataRange2D(self._image_index)
        # Be notified when the line inspectors write index metadata.
        self._image_index.on_trait_change(self._metadata_changed,
                                          "metadata_changed")

        self._image_value = chaco.ImageData(data=scipy.array([]),
                                            value_depth=1)

        image_value_range = chaco.DataRange1D(self._image_value)

        # Create the contour plots
        #self.polyplot = ContourPolyPlot(index=self._image_index,
        self.polyplot = chaco.CMapImagePlot(
            index=self._image_index,
            value=self._image_value,
            index_mapper=chaco.GridMapper(range=image_index_range),
            color_mapper=self._cmap(image_value_range))

        # Add a left axis to the plot
        left = chaco.PlotAxis(orientation='left',
                              title="y",
                              mapper=self.polyplot.index_mapper._ymapper,
                              component=self.polyplot)
        self.polyplot.overlays.append(left)

        # Add a bottom axis to the plot
        bottom = chaco.PlotAxis(orientation='bottom',
                                title="x",
                                mapper=self.polyplot.index_mapper._xmapper,
                                component=self.polyplot)
        self.polyplot.overlays.append(bottom)

        # Add some tools to the plot
        self.polyplot.tools.append(
            tools.PanTool(self.polyplot,
                          constrain_key="shift",
                          drag_button="middle"))

        self.polyplot.overlays.append(
            tools.ZoomTool(component=self.polyplot,
                           tool_mode="box",
                           always_on=False))

        # One clickable line inspector per image axis; both write their
        # position into the index metadata.
        self.lineInspectorX = clickableLineInspector.ClickableLineInspector(
            component=self.polyplot,
            axis='index_x',
            inspect_mode="indexed",
            write_metadata=True,
            is_listener=False,
            color="white")

        self.lineInspectorY = clickableLineInspector.ClickableLineInspector(
            component=self.polyplot,
            axis='index_y',
            inspect_mode="indexed",
            write_metadata=True,
            color="white",
            is_listener=False)

        self.polyplot.overlays.append(self.lineInspectorX)
        self.polyplot.overlays.append(self.lineInspectorY)

        self.boxSelection2D = boxSelection2D.BoxSelection2D(
            component=self.polyplot)
        self.polyplot.overlays.append(self.boxSelection2D)

        # Add these two plots to one container
        # (only the image plot is added here; initialiseFitPlot() later adds
        # the fit contour plot to the same container)
        self.centralContainer = chaco.OverlayPlotContainer(padding=0,
                                                           use_backbuffer=True,
                                                           unified_draw=True)
        self.centralContainer.add(self.polyplot)

        # Create a colorbar
        cbar_index_mapper = chaco.LinearMapper(range=image_value_range)
        self.colorbar = chaco.ColorBar(
            index_mapper=cbar_index_mapper,
            plot=self.polyplot,
            padding_top=self.polyplot.padding_top,
            padding_bottom=self.polyplot.padding_bottom,
            padding_right=40,
            resizable='v',
            width=30)

        # Shared data store for both cross-section plots; starts with the
        # (empty) horizontal series.
        self.plotData = chaco.ArrayPlotData(
            line_indexHorizontal=scipy.array([]),
            line_valueHorizontal=scipy.array([]),
            scatter_indexHorizontal=scipy.array([]),
            scatter_valueHorizontal=scipy.array([]),
            scatter_colorHorizontal=scipy.array([]),
            fitLine_indexHorizontal=scipy.array([]),
            fitLine_valueHorizontal=scipy.array([]))

        # Horizontal cross-section: dotted line plus a colour-mapped marker.
        self.crossPlotHorizontal = chaco.Plot(self.plotData, resizable="h")
        self.crossPlotHorizontal.height = 100
        self.crossPlotHorizontal.padding = 20
        self.crossPlotHorizontal.plot(
            ("line_indexHorizontal", "line_valueHorizontal"), line_style="dot")
        self.crossPlotHorizontal.plot(
            ("scatter_indexHorizontal", "scatter_valueHorizontal",
             "scatter_colorHorizontal"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=4)

        # Share the image's x range so both plots use the same range object.
        self.crossPlotHorizontal.index_range = self.polyplot.index_range.x_range

        # Same set of (empty) series for the vertical cross-section.
        self.plotData.set_data("line_indexVertical", scipy.array([]))
        self.plotData.set_data("line_valueVertical", scipy.array([]))
        self.plotData.set_data("scatter_indexVertical", scipy.array([]))
        self.plotData.set_data("scatter_valueVertical", scipy.array([]))
        self.plotData.set_data("scatter_colorVertical", scipy.array([]))
        self.plotData.set_data("fitLine_indexVertical", scipy.array([]))
        self.plotData.set_data("fitLine_valueVertical", scipy.array([]))

        self.crossPlotVertical = chaco.Plot(self.plotData,
                                            width=140,
                                            orientation="v",
                                            resizable="v",
                                            padding=20,
                                            padding_bottom=160)
        self.crossPlotVertical.plot(
            ("line_indexVertical", "line_valueVertical"), line_style="dot")

        self.crossPlotVertical.plot(
            ("scatter_indexVertical", "scatter_valueVertical",
             "scatter_colorVertical"),
            type="cmap_scatter",
            name="dot",
            color_mapper=self._cmap(image_value_range),
            marker="circle",
            marker_size=4)

        # Share the image's y range so both plots use the same range object.
        self.crossPlotVertical.index_range = self.polyplot.index_range.y_range

        # Create a container and add components
        self.container = chaco.HPlotContainer(padding=40,
                                              fill_padding=True,
                                              bgcolor="white",
                                              use_backbuffer=False)

        inner_cont = chaco.VPlotContainer(padding=40, use_backbuffer=True)
        inner_cont.add(self.crossPlotHorizontal)
        inner_cont.add(self.centralContainer)
        self.container.add(self.colorbar)
        self.container.add(inner_cont)
        self.container.add(self.crossPlotVertical)

    def initialiseFitPlot(self):
        """Set up the fit overlay the first time a fit is drawn.

        Creates the contour sample positions (``contourXS`` / ``contourYS``),
        an empty image data source for the fitted values, a contour line
        plot on top of the image, and the fit line plots in both
        cross-section plots.
        """
        half_step_x = 1.0 / 2.
        half_step_y = 1.0 / 2.
        pixels_x = self.model.pixelsX
        pixels_y = self.model.pixelsY

        self.contourXS = scipy.linspace(half_step_x, pixels_x - half_step_x,
                                        pixels_x - 1)
        self.contourYS = scipy.linspace(half_step_y, pixels_y - half_step_y,
                                        pixels_y - 1)
        logger.debug("contour initialise fit debug. xs shape %s" %
                     self.contourXS.shape)
        logger.debug("contour initialise xs= %s" % self.contourXS)

        # Fit values are filled in later by addFitPlot().
        self._fit_value = chaco.ImageData(data=scipy.array([]), value_depth=1)

        # The contour plot reuses the image plot's index and index range.
        contour_mapper = chaco.GridMapper(
            range=self.polyplot.index_mapper.range)
        self.lineplot = chaco.ContourLinePlot(index=self._image_index,
                                              value=self._fit_value,
                                              index_mapper=contour_mapper,
                                              levels=self.contourLevels)
        self.centralContainer.add(self.lineplot)

        self.plotData.set_data("fitLine_indexHorizontal", self.model.xs)
        self.plotData.set_data("fitLine_indexVertical", self.model.ys)

        vertical_series = ("fitLine_indexVertical", "fitLine_valueVertical")
        self.crossPlotVertical.plot(vertical_series,
                                    type="line",
                                    name="fitVertical")
        horizontal_series = ("fitLine_indexHorizontal",
                             "fitLine_valueHorizontal")
        self.crossPlotHorizontal.plot(horizontal_series,
                                      type="line",
                                      name="fitHorizontal")
        logger.debug("initialise fit plot %s " % self.crossPlotVertical.plots)

    def addFitPlot(self, fit):
        """Draw the fitted function of ``fit`` as contours over the image.

        Does nothing (apart from logging an error) when ``fit`` has not been
        fitted yet. On the first call the contour plot is initialised. The
        fit function is then evaluated on the contour grid, the result is
        stored as the contour data, a redraw is requested and
        ``drawFitBool`` is set.
        """
        logger.debug("adding fit plot with fit %s " % fit)
        if not fit.fitted:
            logger.error(
                "cannot add a fitted plot for unfitted data. Run fit first")
            return

        if not self.drawFitBool:
            logger.info("first fit plot so initialising contour plot")
            self.initialiseFitPlot()

        logger.info("attempting to set fit data")
        columns = len(self.contourXS)
        rows = len(self.contourYS)

        # Flattened (x, y) coordinates of every grid point, row by row, in
        # the form the fit function expects.
        flat_xs = scipy.tile(self.contourXS, rows)
        flat_ys = scipy.repeat(self.contourYS, columns)
        self.contourPositions = [flat_xs, flat_ys]

        fitted_values = fit.fitFunc(self.contourPositions,
                                    *fit._getCalculatedValues())
        self.contourZS = fitted_values.reshape((rows, columns))
        self._fit_value.data = self.contourZS

        self.container.invalidate_draw()
        self.container.request_redraw()
        self.drawFitBool = True

    def update(self, model):
        """Copy the data of ``model`` into the plot data sources and redraw.

        model -- object providing ``xs``, ``ys``, ``zs``, ``minZ`` and
        ``maxZ``.
        """
        logger.info("updating plot")
        if self.autoRangeColor:
            color_range = self.colorbar.index_mapper.range
            color_range.low = model.minZ
            color_range.high = model.maxZ

        self._image_index.set_data(model.xs, model.ys)
        self._image_value.data = model.zs

        plot_data = self.plotData
        plot_data.set_data("line_indexHorizontal", model.xs)
        plot_data.set_data("line_indexVertical", model.ys)
        if self.drawFitBool:
            # a fit overlay exists, so its cross-section index data is reset too
            plot_data.set_data("fitLine_indexHorizontal", self.contourXS)
            plot_data.set_data("fitLine_indexVertical", self.contourYS)

        self.updatePlotLimits()
        self._image_index.metadata_changed = True

        plot_container = self.container
        plot_container.invalidate_draw()
        plot_container.request_redraw()

    #---------------------------------------------------------------------------
    # Event handlers
    #---------------------------------------------------------------------------

    def _metadata_changed(self, old, new):
        """ This function takes out a cross section from the image data, based
        on the line inspector selections, and updates the line and scatter
        plots."""
        if self.horizontalAutoRange:
            self.crossPlotHorizontal.value_range.low = self.model.minZ
            self.crossPlotHorizontal.value_range.high = self.model.maxZ
        if self.verticalAutoRange:
            self.crossPlotVertical.value_range.low = self.model.minZ
            self.crossPlotVertical.value_range.high = self.model.maxZ
        if self._image_index.metadata.has_key("selections"):
            selections = self._image_index.metadata["selections"]
            if not selections:  #selections is empty list
                return  #don't need to do update lines as no mouse over screen. This happens at beginning of script
            x_ndx, y_ndx = selections
            if y_ndx and x_ndx:
                self.plotData.set_data("line_valueHorizontal",
                                       self._image_value.data[y_ndx, :])
                self.plotData.set_data("line_valueVertical",
                                       self._image_value.data[:, x_ndx])
                xdata, ydata = self._image_index.get_data()
                xdata, ydata = xdata.get_data(), ydata.get_data()
                self.plotData.set_data("scatter_indexHorizontal",
                                       scipy.array([xdata[x_ndx]]))
                self.plotData.set_data("scatter_indexVertical",
                                       scipy.array([ydata[y_ndx]]))
                self.plotData.set_data(
                    "scatter_valueHorizontal",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                self.plotData.set_data(
                    "scatter_valueVertical",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                self.plotData.set_data(
                    "scatter_colorHorizontal",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                self.plotData.set_data(
                    "scatter_colorVertical",
                    scipy.array([self._image_value.data[y_ndx, x_ndx]]))
                if self.drawFitBool:
                    self.plotData.set_data("fitLine_valueHorizontal",
                                           self._fit_value.data[y_ndx, :])
                    self.plotData.set_data("fitLine_valueVertical",
                                           self._fit_value.data[:, x_ndx])
        else:
            self.plotData.set_data("scatter_valueHorizontal", scipy.array([]))
            self.plotData.set_data("scatter_valueVertical", scipy.array([]))
            self.plotData.set_data("line_valueHorizontal", scipy.array([]))
            self.plotData.set_data("line_valueVertical", scipy.array([]))
            self.plotData.set_data("fitLine_valueHorizontal", scipy.array([]))
            self.plotData.set_data("fitLine_valueVertical", scipy.array([]))

    def _colormap_changed(self):
        """Rebuild the colour mappers after a different colormap name is set.

        Looks the colormap factory up by name and, once ``polyplot`` exists,
        replaces the colour mapper of the image plot, of the horizontal cross
        plot and of the "dot" plots in both cross plots, keeping the ranges
        of the mappers they replace.
        """
        self._cmap = colormaps.color_map_name_dict[self.colormap]
        if not hasattr(self, "polyplot"):
            return

        image_range = self.polyplot.color_mapper.range
        self.polyplot.color_mapper = self._cmap(image_range)

        cross_range = self.crossPlotHorizontal.color_mapper.range
        self.crossPlotHorizontal.color_mapper = self._cmap(cross_range)
        # FIXME: change when we decide how best to update plots using
        # the shared colormap in plot object
        for cross_plot in (self.crossPlotHorizontal, self.crossPlotVertical):
            cross_plot.plots["dot"][0].color_mapper = self._cmap(cross_range)
        self.container.request_redraw()

    def _colorMapRangeLow_changed(self):
        self.colorbar.index_mapper.range.low = self.colorMapRangeLow

    def _colorMapRangeHigh_changed(self):
        self.colorbar.index_mapper.range.high = self.colorMapRangeHigh

    def _horizontalLowerLimit_changed(self):
        self.crossPlotHorizontal.value_range.low = self.horizontalLowerLimit

    def _horizontalUpperLimit_changed(self):
        self.crossPlotHorizontal.value_range.high = self.horizontalUpperLimit

    def _verticalLowerLimit_changed(self):
        self.crossPlotVertical.value_range.low = self.verticalLowerLimit

    def _verticalUpperLimit_changed(self):
        self.crossPlotVertical.value_range.high = self.verticalUpperLimit

    def _autoRange_changed(self):
        if self.autoRange:
            self.colorbar.index_mapper.range.low = self.minz
            self.colorbar.index_mapper.range.high = self.maxz

    def _num_levels_changed(self):
        if self.num_levels > 3:
            self.polyplot.levels = self.num_levels
            self.lineplot.levels = self.num_levels

    def _colorMapRangeLow_default(self):
        """Default for colorMapRangeLow: the model's minZ."""
        logger.debug("setting color map rangle low default")
        lowest = self.model.minZ
        return lowest

    def _colorMapRangeHigh_default(self):
        return self.model.maxZ

    def _horizontalLowerLimit_default(self):
        return self.model.minZ

    def _horizontalUpperLimit_default(self):
        return self.model.maxZ

    def _verticalLowerLimit_default(self):
        return self.model.minZ

    def _verticalUpperLimit_default(self):
        return self.model.maxZ

    def _selectedFit_changed(self, selected):
        """Log that the selected fit changed; ``selected`` itself is not used."""
        message = "selected fit was changed"
        logger.debug(message)

    def _fixAspectRatioBool_changed(self):
        if self.fixAspectRatioBool:
            #using zoom range works but then when you reset zoom this function isn't called...
            #            rangeObject = self.polyplot.index_mapper.range
            #            xrangeValue = rangeObject.high[0]-rangeObject.low[0]
            #            yrangeValue = rangeObject.high[1]-rangeObject.low[1]
            #            logger.info("xrange = %s, yrange = %s " % (xrangeValue, yrangeValue))
            #            aspectRatioSquare = (xrangeValue)/(yrangeValue)
            #            self.polyplot.aspect_ratio=aspectRatioSquare
            self.centralContainer.aspect_ratio = float(
                self.model.pixelsX) / float(self.model.pixelsY)
            #self.polyplot.aspect_ratio = self.model.pixelsX/self.model.pixelsY

        else:
            self.centralContainer.aspect_ratio = None
            #self.polyplot.aspect_ratio = None
        self.container.request_redraw()
        self.centralContainer.request_redraw()

    def updatePlotLimits(self):
        """Copy the model's minZ/maxZ into the limit values shown in the GUI.

        Only the limit pairs whose auto-range flag is set are touched.
        """
        model = self.model
        if self.autoRangeColor:
            self.colorMapRangeLow = model.minZ
            self.colorMapRangeHigh = model.maxZ
        if self.horizontalAutoRange:
            self.horizontalLowerLimit = model.minZ
            self.horizontalUpperLimit = model.maxZ
        if self.verticalAutoRange:
            self.verticalLowerLimit = model.minZ
            self.verticalUpperLimit = model.maxZ

    def _selectedFile_changed(self):
        """Load the newly selected image and update everything that depends on it.

        Reloads the model's image data, updates the physics when
        updatePhysicsBool is set, marks every fit in fitList as not fitted
        (starting the fit routine of those with autoFitBool set), redraws the
        view and, when its autoRefresh flag is set, refreshes the log file
        plot. A failure while refreshing the log file plot is logged and
        otherwise ignored.
        """
        self.model.getImageData(self.selectedFile)
        if self.updatePhysicsBool:
            self.physics.updatePhysics()
        for fit in self.fitList:
            fit.fitted = False
            fit.fittingStatus = fit.notFittedForCurrentStatus
            if fit.autoFitBool:
                # starts a thread to perform the fit; auto guess and auto
                # draw will be handled automatically
                fit._fit_routine()
        self.update_view()
        # update log file plot if autorefresh is selected
        if self.logFilePlotObject.autoRefresh:
            try:
                self.logFilePlotObject.refreshPlot()
            except Exception as e:
                # Format the exception itself: the ``message`` attribute does
                # not exist on Python 3 exceptions, so reading it here would
                # raise AttributeError from inside the handler.
                logger.error("failed to update log plot -  %s...." % e)

    def _cameraModel_changed(self):
        """Apply the preset parameters that belong to the chosen camera model.

        The camera model works as a helper: selecting one sets the pixel
        counts of the model, the pixel size and magnification of the physics
        object and the search string. Recognised values are "ANDOR0",
        "ALTA0", "ALTA1" and "ANDOR1"; any other value is logged as an error
        and changes none of these parameters. In every case the fit
        references are refreshed and the selected image is reloaded.
        """
        logger.info("camera model changed")
        # camera -> (pixelsX, pixelsY, pixelSize, magnification, searchString)
        presets = {
            "ANDOR0": (512, 512, 16.0, 2.0, "ANDOR0"),
            "ALTA0": (768, 512, 9.0, 0.5, "ALTA0"),
            "ALTA1": (768, 512, 9.0, 4.25, "ALTA1"),
            "ANDOR1": (512, 512, 16.0, 2.0, "ANDOR1"),
        }
        preset = presets.get(self.cameraModel)
        if preset is None:
            logger.error("unrecognised camera model")
        else:
            pixels_x, pixels_y, pixel_size, magnification, search = preset
            self.model.pixelsX = pixels_x
            self.model.pixelsY = pixels_y
            self.physics.pixelSize = pixel_size
            self.physics.magnification = magnification
            self.searchString = search
        self.refreshFitReferences()
        self.model.getImageData(self.selectedFile)

    def refreshFitReferences(self):
        """Give every fit in fitList the current pixel counts of the model.

        To be called whenever aspects of the image change in a way that
        requires the fits' properties to be updated; sets each fit's endX and
        endY to the model's pixelsX and pixelsY.
        """
        model = self.model
        for current_fit in self.fitList:
            current_fit.endX = model.pixelsX
            current_fit.endY = model.pixelsY

    def _pixelsX_changed(self):
        """Handle a change of pixelsX.

        The fits must receive the new array sizes, so their references are
        refreshed; then the plot is updated from the model and the selected
        image is loaded again.
        """
        logger.info("pixels X Change detected")
        self.refreshFitReferences()
        model = self.model
        self.update(model)
        model.getImageData(self.selectedFile)

    def _pixelsY_changed(self):
        """Handle a change of pixelsY.

        The fits must receive the new array sizes, so their references are
        refreshed; then the plot is updated from the model and the selected
        image is loaded again.
        """
        logger.info("pixels Y Change detected")
        self.refreshFitReferences()
        model = self.model
        self.update(model)
        model.getImageData(self.selectedFile)

    @traits.on_trait_change('model')
    def update_view(self):
        """Redraw from the current model; does nothing while model is None."""
        current_model = self.model
        if current_model is None:
            return
        self.update(current_model)

    def _save_image(self, originalFilePath, newFilePath):
        """given the original file path this saves a new copy to new File path """
        shutil.copy2(originalFilePath, newFilePath)

    def _save_image_as(self):
        """Open a "save as" dialog and copy the current image to the chosen path.

        The copy is only made when the dialog returns OK; the debug message
        at the end is logged either way.
        """
        # take the path now so that an auto update of selectedFile while the
        # dialog is open cannot affect which image is copied
        source_path = str(self.selectedFile)
        start_directory = os.path.join("\\\\ursa", "AQOGroupFolder",
                                       "Experiment Humphry", "Data",
                                       "savedEagleImages")
        save_dialog = FileDialog(
            action="save as",
            default_directory=start_directory,
            wildcard=str("PNG (*.png)|All files|*"))
        save_dialog.open()
        if save_dialog.return_code == OK:
            self._save_image(source_path, save_dialog.path)
        logger.debug("custom image copy made")

    def _save_image_default(self):
        """Copy the current image, under its own file name, into the
        savedEagleImages folder on the group share."""
        file_name = os.path.basename(self.selectedFile)
        destination = os.path.join("\\\\ursa", "AQOGroupFolder",
                                   "Experiment Humphry", "Data",
                                   "savedEagleImages", file_name)
        self._save_image(self.selectedFile, destination)
        logger.debug("default image copy made")
class Comparator(HasTraits):
    """ The main application.

    Shows an SVG test file next to its reference PNG, together with the test
    description, the XML tree and the profiling results of parsing and
    document creation.
    """

    #### Configuration traits ##################################################

    # The root directory of the test suite.
    suitedir = Str()

    # Mapping of SVG basenames to their reference PNGs. Use None if there is no
    # reference PNG.
    svg_png = Dict()

    # The list of SVG file names.
    svg_files = List()

    # The name of the default PNG file to display when no reference PNG exists.
    default_png = Str(os.path.join(this_dir, 'images/default.png'))

    #### State traits ##########################################################

    # The currently selected SVG file.
    current_file = Str()
    abs_current_file = Property(depends_on=['current_file'])

    # The current XML ElementTree root Element and its XMLTree view model.
    current_xml = Any()
    current_xml_view = Any()

    # The profilers.
    profile_this = Instance(ProfileThis, args=())

    #### GUI traits ############################################################

    # The text showing the current mouse coordinates over any of the components.
    mouse_coords = Property(Str, depends_on=['ch_controller.svg_coords'])

    # Move forward and backward through the list of SVG files.
    move_forward = Button('>>')
    move_backward = Button('<<')

    # The description of the test.
    description = HTML()

    document = Instance(document.SVGDocument)

    # The components to view.
    kiva_component = ComponentTrait(klass=SVGComponent)
    ref_component = ComponentTrait(klass=ImageComponent, args=())
    ch_controller = Instance(MultiController)

    # The profiler views.
    parsing_sike = Instance(Sike, args=())
    drawing_sike = Instance(Sike, args=())
    wx_doc_sike = Instance(Sike, args=())
    kiva_doc_sike = Instance(Sike, args=())

    traits_view = tui.View(
        tui.Tabbed(
            tui.VGroup(
                tui.HGroup(
                    tui.Item('current_file',
                             editor=tui.EnumEditor(name='svg_files'),
                             style='simple',
                             width=1.0,
                             show_label=False),
                    tui.Item(
                        'move_backward',
                        show_label=False,
                        enabled_when="svg_files.index(current_file) != 0"),
                    tui.Item(
                        'move_forward',
                        show_label=False,
                        enabled_when=
                        "svg_files.index(current_file) != len(svg_files)-1"),
                ),
                tui.VSplit(
                    tui.HSplit(
                        tui.Item('description',
                                 label='Description',
                                 show_label=False),
                        tui.Item('current_xml_view',
                                 editor=xml_tree_editor,
                                 show_label=False),
                    ),
                    tui.HSplit(
                        tui.Item('document',
                                 editor=SVGEditor(),
                                 show_label=False),
                        tui.Item('kiva_component', show_label=False),
                        tui.Item('ref_component', show_label=False),
                        # TODO: tui.Item('agg_component', show_label=False),
                    ),
                ),
                label='SVG',
            ),
            tui.Item('parsing_sike',
                     style='custom',
                     show_label=False,
                     label='Parsing Profile'),
            tui.Item('drawing_sike',
                     style='custom',
                     show_label=False,
                     label='Kiva Drawing Profile'),
            tui.Item('wx_doc_sike',
                     style='custom',
                     show_label=False,
                     label='Creating WX document'),
            # This tab shows the 'Creating Kiva document' profile (see
            # _update_profiling), so it is labelled accordingly instead of
            # repeating the WX label.
            tui.Item('kiva_doc_sike',
                     style='custom',
                     show_label=False,
                     label='Creating Kiva document'),
        ),
        width=1280,
        height=768,
        resizable=True,
        statusbar='mouse_coords',
        title='SVG Comparator',
    )

    def __init__(self, **traits):
        """ Create the comparator and attach a crosshair tool to both the
        Kiva and the reference component; the two crosshairs are handed to a
        MultiController stored in ``ch_controller``.
        """
        super(Comparator, self).__init__(**traits)
        kiva_ch = activate_tool(self.kiva_component,
                                Crosshair(self.kiva_component))
        ref_ch = activate_tool(self.ref_component,
                               Crosshair(self.ref_component))
        self.ch_controller = MultiController(kiva_ch, ref_ch)

    @classmethod
    def fromsuitedir(cls, dirname, **traits):
        """ Find all SVG files and their related reference PNG files under
        a directory.

        This assumes that the SVGs are located under <dirname>/svg/ and the
        related PNGs under <dirname>/png/ and that there are no subdirectories.

        The first SVG (in sorted order) becomes the current file, so an
        IndexError is raised when the directory contains no SVG files.
        """
        dirname = os.path.abspath(dirname)
        svgs = glob.glob(os.path.join(dirname, 'svg', '*.svg'))
        pngdir = os.path.join(dirname, 'png')
        d = {}
        for svg in svgs:
            png = None
            base = os.path.splitext(os.path.basename(svg))[0]
            # The reference PNG may carry one of these prefixes; the first
            # existing candidate wins.
            for prefix in ('full-', 'basic-', 'tiny-', ''):
                fn = os.path.join(pngdir, prefix + base + '.png')
                if os.path.exists(fn):
                    png = os.path.basename(fn)
                    break
            d[os.path.basename(svg)] = png
        svgs = sorted(d)
        x = cls(suitedir=dirname, svg_png=d, svg_files=svgs, **traits)
        x.current_file = svgs[0]
        return x

    def display_reference_png(self, filename):
        """ Read the image file and shove its data into the display component.
        """
        img = Image.open(filename)
        arr = np.array(img)
        self.ref_component.image = arr

    def display_test_description(self):
        """ Extract the test description for display.

        Builds a small HTML document from the title, description, version
        and paragraph elements found in the current XML and stores its text
        in ``description``.
        """
        html = ET.Element('html')

        title = self.current_xml.find('.//{http://www.w3.org/2000/svg}title')
        if title is not None:
            title_text = title.text
        else:
            # No <title>: fall back to the file name without its extension.
            title_text = os.path.splitext(self.current_file)[0]
        p = ET.SubElement(html, 'p')
        b = ET.SubElement(p, 'b')
        b.text = 'Title: '
        b.tail = title_text

        desc_text = None
        version_text = None
        desc = self.current_xml.find('.//{http://www.w3.org/2000/svg}desc')
        if desc is not None:
            desc_text = desc.text
        else:
            testcase = self.current_xml.find(
                './/{http://www.w3.org/2000/02/svg/testsuite/description/}SVGTestCase'
            )
            if testcase is not None:
                desc_text = testcase.get('desc', None)
                version_text = testcase.get('version', None)
        if desc_text is not None:
            p = ET.SubElement(html, 'p')
            b = ET.SubElement(p, 'b')
            b.text = 'Description: '
            b.tail = normalize_text(desc_text)

        if version_text is None:
            script = self.current_xml.find(
                './/{http://www.w3.org/2000/02/svg/testsuite/description/}OperatorScript'
            )
            if script is not None:
                version_text = script.get('version', None)
        if version_text is not None:
            p = ET.SubElement(html, 'p')
            b = ET.SubElement(p, 'b')
            b.text = 'Version: '
            b.tail = version_text

        paras = self.current_xml.findall(
            './/{http://www.w3.org/2000/02/svg/testsuite/description/}Paragraph'
        )
        if len(paras) > 0:
            div = ET.SubElement(html, 'div')
            for para in paras:
                p = ET.SubElement(div, 'p')
                p.text = normalize_text(para.text)
                # Copy over any children elements like <a>.
                p[:] = para[:]

        tree = ET.ElementTree(html)
        f = StringIO()
        tree.write(f)
        text = f.getvalue()
        self.description = text

    def locate_file(self, name, kind):
        """ Find the location of the given file in the suite.

        Parameters
        ----------
        name : str
            Path of the file relative to the suitedir.
        kind : either 'svg' or 'png'
            The kind of file.

        Returns
        -------
        path : str
            The full path to the file.
        """
        return os.path.join(self.suitedir, kind, name)

    def _kiva_component_default(self):
        """ Default Kiva component: an SVGComponent given this object's
        profiler.
        """
        return SVGComponent(profile_this=self.profile_this)

    def _move_backward_fired(self):
        """ Select the previous SVG file, stopping at the first one.
        """
        idx = self.svg_files.index(self.current_file)
        idx = max(idx - 1, 0)
        self.current_file = self.svg_files[idx]

    def _move_forward_fired(self):
        """ Select the next SVG file, stopping at the last one.
        """
        idx = self.svg_files.index(self.current_file)
        idx = min(idx + 1, len(self.svg_files) - 1)
        self.current_file = self.svg_files[idx]

    def _get_abs_current_file(self):
        """ Full path of the current SVG file inside the suite.
        """
        return self.locate_file(self.current_file, 'svg')

    def _current_file_changed(self, new):
        """ Parse the newly selected SVG file and refresh every view of it.

        Parsing and the creation of the WX and Kiva documents are each
        wrapped in a profiling run. A document that fails to build is logged
        and replaced by None. Finally the test description and the reference
        PNG (or the default PNG when there is none) are displayed.
        """
        # Reset the warnings filters. While it's good to only get 1 warning per
        # file, we want to get the same warning again if a new file issues it.
        warnings.resetwarnings()

        self.profile_this.start('Parsing')
        self.current_xml = ET.parse(self.abs_current_file).getroot()
        self.current_xml_view = xml_to_tree(self.current_xml)
        resources = document.ResourceGetter.fromfilename(self.abs_current_file)
        self.profile_this.stop()
        try:
            self.profile_this.start('Creating WX document')
            self.document = document.SVGDocument(self.current_xml,
                                                 resources=resources,
                                                 renderer=WxRenderer)
        except Exception:
            # ``except Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt and SystemExit are not swallowed here.
            logger.exception('Error parsing document %s', new)
            self.document = None

        self.profile_this.stop()

        try:
            self.profile_this.start('Creating Kiva document')
            self.kiva_component.document = document.SVGDocument(
                self.current_xml, resources=resources, renderer=KivaRenderer)
        except Exception:
            logger.exception('Error parsing document %s', new)
            # Clear the document like the WX branch above does, so that the
            # previous file's drawing does not stay on screen.
            # NOTE(review): assumes SVGComponent.document accepts None - confirm.
            self.kiva_component.document = None

        self.profile_this.stop()

        png_file = self.svg_png.get(new, None)
        if png_file is None:
            png_file = self.default_png
        else:
            png_file = self.locate_file(png_file, 'png')
        self.display_test_description()
        self.display_reference_png(png_file)

    def _get_mouse_coords(self):
        """ Text for the status bar: the controller's SVG coordinates, or an
        empty string while there is no controller.
        """
        if self.ch_controller is None:
            return ''
        else:
            return '%1.3g %1.3g' % self.ch_controller.svg_coords

    @on_trait_change('profile_this:profile_ended')
    def _update_profiling(self, new):
        """ Hand the statistics of a finished profiling run to the Sike view
        that belongs to the name of the run.
        """
        if new is not None:
            name, p = new
            stats = pstats.Stats(p)
            if name == 'Parsing':
                self.parsing_sike.stats = stats
            elif name == 'Drawing':
                self.drawing_sike.stats = stats
            elif name == 'Creating WX document':
                self.wx_doc_sike.stats = stats
            elif name == 'Creating Kiva document':
                self.kiva_doc_sike.stats = stats
Exemple #11
0
class ExperimentSnake(traits.HasTraits):
    """Main Experiment Snake GUI.

    Sends arbitrary hardware actions based on the experiment runner sequence
    and on the actions that have been set up. The class body below declares
    the traits holding the snake state, the toolbar/menu actions (which refer
    to handler methods of this class by name) and the TraitsUI view layout.
    """

    # Previous log widget, kept for reference:
    #mainLog = utilities.TextDisplay()
    mainLog = outputStream.OutputStream()
    statusString = traits.String("Press Start Snake to begin...")
    isRunning = traits.Bool(False)  # True while the snake is running
    sequenceStarted = traits.Bool(
        False)  # flashes True for ~1 ms when a sequence starts
    queue = traits.Int(0)

    variables = traits.Dict(
        key_trait=traits.Str, value_trait=traits.Float
    )  # maps variable names in Exp control to their values in this sequence
    timingEdges = traits.Dict(
        key_trait=traits.Str, value_trait=traits.Float
    )  # maps timing edge names in Exp control to their values in this sequence
    statusList = [
    ]  # plain class attribute; replaced with the lines of the latest status report on each poll

    # Toolbar actions; ``action`` is the name of the handler method to call.
    startAction = traitsui.Action(name='start',
                                  action='_startSnake',
                                  image=pyface.image_resource.ImageResource(
                                      os.path.join('icons', 'start.png')))
    stopAction = traitsui.Action(name='stop',
                                 action='_stopSnakeToolbar',
                                 image=pyface.image_resource.ImageResource(
                                     os.path.join('icons', 'stop.png')))
    reloadHWAsAction = traitsui.Action(
        name='reload',
        action='_reloadHWAsToolbar',
        image=pyface.image_resource.ImageResource(
            os.path.join('icons', 'reload.png')))

    connectionTimer = traits.Instance(
        Timer
    )  # polls the experiment runner and starts off callbacks at appropriate times
    statusStringTimer = traits.Instance(
        Timer)  # updates the status bar at regular times (less frequently)
    getCurrentTimer = traits.Instance(
        Timer
    )  # waits for get current to return, which marks the beginning of a sequence

    getCurrentThread = traits.Instance(SocketThread)

    connectionPollFrequency = traits.Float(
        1000.0)  # milliseconds; defines the accuracy callbacks are performed at
    statusStringFrequency = traits.Float(2000.0)  # milliseconds
    getCurrentFrequency = traits.Float(
        1000.0)  # milliseconds; should be shorter than the sequence

    timeRunning = traits.Float(0.0)  # how long the sequence has been running
    totalTime = traits.Float(0.0)  # total length of sequence
    runnerHalted = traits.Bool(True)  # True if the runner is halted
    haltedCount = 0  # plain class attribute (not a trait); counts consecutive polls with the runner halted
    progress = traits.Float(0.0)  # % of cycle complete
    #progressBar = ProgressDialog()
    hardwareActions = traits.List(hardwareAction.hardwareAction.HardwareAction)

    examineVariablesDictionary = traits.Instance(
        variableDictionary.ExamineVariablesDictionary)
    xmlString = ""  # string that will contain the entire XML file

    # Menu bar: start/stop/reload entries plus a "Log Level" sub-menu.
    menubar = traitsui.MenuBar(
        traitsui.Menu(
            traitsui.Action(name='Start Snake', action='_startSnake'),
            traitsui.Action(name='Stop Snake', action='_stopSnake'),
            traitsui.Action(name='Reload', action='_reloadHWAs'),
            traitsui.Menu(traitsui.Action(name='DEBUG',
                                          action='_changeLoggingLevelDebug'),
                          traitsui.Action(name='INFO',
                                          action='_changeLoggingLevelInfo'),
                          traitsui.Action(name='WARNING',
                                          action='_changeLoggingLevelWarning'),
                          traitsui.Action(name='ERROR',
                                          action='_changeLoggingLevelError'),
                          name="Log Level"),
            name='Menu'))

    toolbar = traitsui.ToolBar(startAction, stopAction, reloadHWAsAction)

    # Status line above the main log.
    mainSnakeGroup = traitsui.VGroup(
        traitsui.Item('statusString', show_label=False, style='readonly'),
        traitsui.Item('mainLog',
                      show_label=False,
                      springy=True,
                      style='custom',
                      editor=traitsui.InstanceEditor()))

    hardwareActionsGroup = traitsui.Group(traitsui.Item(
        'hardwareActions',
        show_label=False,
        style='custom',
        editor=traitsui.ListEditor(style="custom")),
                                          label="Hardware Actions",
                                          show_border=True)

    variableExaminerGroup = traitsui.Group(traitsui.Item(
        "examineVariablesDictionary",
        editor=traitsui.InstanceEditor(),
        style="custom",
        show_label=False),
                                           label="Variable Examiner")

    # Side panel: hardware actions above the variable examiner.
    sidePanelGroup = traitsui.VSplit(hardwareActionsGroup,
                                     variableExaminerGroup)

    # Top-level view: the side panel and the status/log group in a horizontal split.
    traits_view = traitsui.View(traitsui.HSplit(sidePanelGroup,
                                                mainSnakeGroup,
                                                show_labels=True),
                                resizable=True,
                                menubar=menubar,
                                toolbar=toolbar,
                                width=0.5,
                                height=0.75,
                                title="Experiment Snake",
                                icon=pyface.image_resource.ImageResource(
                                    os.path.join('icons', 'snakeIcon.ico')))

    def __init__(self, **traits):
        """Construct the snake; everything else is driven through the GUI.

        Trait keyword arguments are forwarded to the base class. A
        ``Connection`` to the experiment runner is created with its default
        settings, the list of hardware actions is instantiated and a welcome
        line is written to the main log.

        The first positional argument of each hardware action is presumably
        its callback time (a number, or a variable name given as a string)
        -- confirm against the HardwareAction base class.
        """
        super(ExperimentSnake, self).__init__(**traits)
        # Default ports and IP are used here; they can be overridden.
        self.connection = experimentRunnerConnection.Connection()

        hwa = hardwareAction  # short alias for the hardware action package
        self.hardwareActions = [
            hwa.sequenceLoggerHWA.SequenceLogger(0.0, snakeReference=self),
            hwa.experimentTablesHWA.ExperimentTables(
                0.0, snakeReference=self, enabled=False),
            hwa.dlicEvapHWA.EvaporationRamp(1.0, snakeReference=self),
            # dlicRFSweepHWA.DLICRFSweep is left out (commented out upstream).
            hwa.dlicRFSweepLZHWA.DLICRFSweep(
                1.0, snakeReference=self, enabled=False),
            hwa.dlicRFSweepLZWithPowerCtrlHWA.DLICRFSweep(
                1.0, snakeReference=self, enabled=False),
            hwa.dlicRFSweepLZWithPowerCtrl13PreparationHWA.DLICRFSweep(
                1.0, snakeReference=self, enabled=True),
            hwa.dlicPiPulseHWA.DLICPiPulse(
                1.0, snakeReference=self, enabled=False),
            hwa.evapAttenuationHWA.EvapAttenuation(1.0, snakeReference=self),
            hwa.greyMollassesOffsetFreqHWA.GreyMollassesOffset(
                1.0, snakeReference=self, enabled=False),
            hwa.evapAttenuation2HWA.EvapAttenuation(
                "EvapSnakeAttenuationTimeFinal",
                snakeReference=self,
                enabled=False),
            hwa.picomotorPlugHWA.PicomotorPlug(
                1.0, snakeReference=self, enabled=False),
            hwa.windFreakOffsetLockHWA.WindFreak(
                0.0, snakeReference=self, enabled=False),
            hwa.windFreakOffsetLockHighFieldImagingHWA.WindFreak(
                0.0, snakeReference=self, enabled=True),
            hwa.windFreakOffsetLock6ImagingHWA.WindFreak(
                2.0, snakeReference=self, enabled=False),
            hwa.windFreak6To1HWA.WindFreak(
                2.0, snakeReference=self, enabled=False),
            hwa.windFreakOffsetLockLaser3.WindFreak(
                3.0, snakeReference=self, enabled=False),
            # AOM channel actions
            hwa.AOMChannelHWAs.AOMChannelNaZSFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaZSAtten(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaZSEOMFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaZSEOMAtten(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaSpecFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiImaging(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiImagingDetuning(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiPushPulseAttenuation(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiPushPulseDetuning(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaDarkSpotAOMFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaDarkSpotAOMAtten(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaMOTFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaMOTAtten(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaMOTEOMAtten(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaImagingDP(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiMOTRep(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiMOTCool(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelLiOpticalPump(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNa2to2OpticalPumpingFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNa2to2OpticalPumpingAtt(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaHighFieldImagingFreq(
                1.0, snakeReference=self, enabled=False),
            hwa.AOMChannelHWAs.AOMChannelNaHighFieldImagingAtt(
                1.0, snakeReference=self, enabled=False),
            # remaining instruments
            hwa.digitalMultimeterCurrentMeasureHWA.DigitalMultimeterMeasurement(
                1.0, snakeReference=self, enabled=True),
            hwa.MXGPiPulseHWA.PiPulse(
                1.0, snakeReference=self, enabled=False),
            hwa.variableExplorerHWA.VariableExplorer(
                2.0, snakeReference=self, enabled=False),
            hwa.jds6600HWA.JDS6600HWA(
                1.0, snakeReference=self, enabled=False),
            hwa.watchdogHWA.WatchdogHWA(
                18.0, snakeReference=self, enabled=True),
        ]

        self.mainLog.addLine("Welcome to experiment snake.", 1)

    def initialiseHardwareActions(self):
        """Initialise every enabled hardware action.

        For each enabled action: call its ``init()``, hand it a reference to
        the snake's ``variables`` dictionary and write the string returned by
        ``init()`` to the main log. Disabled actions are skipped.
        """
        for action in self.hardwareActions:
            if not action.enabled:
                continue
            initMessage = action.init()
            action.variablesReference = self.variables
            self.mainLog.addLine(initMessage)

    def closeHardwareActions(self):
        """Close every hardware action that has been initialised.

        Called when the user presses stop; it should cleanly close or shut
        down all hardware, so each hardware action must implement ``close``
        appropriately. The string returned by ``close()`` is written to the
        main log. Actions that were never initialised are skipped.
        """
        for action in self.hardwareActions:
            if not action.initialised:
                continue
            closeMessage = action.close()
            self.mainLog.addLine(closeMessage)

    def _startSnake(self):
        """Action callback from the menu or toolbar that starts the snake.

        Logs the start, sets ``isRunning`` to True, blocks in
        ``getCurrentBlocking`` until the next sequence has started, then
        initialises the enabled hardware actions and starts the polling
        timers.
        """
        self.mainLog.addLine("Experiment Snake Started", 1)
        self.isRunning = True
        # Blocks until the experiment runner reports the start of a sequence.
        self.getCurrentBlocking()
        self.initialiseHardwareActions()
        self.startTimers()

    def newSequenceStarted(self):
        """Handle the beginning of a new sequence.

        Called from ``getCurrentBlocking`` (and, per the original note, by the
        get-current thread) at the beginning of every sequence. While the
        snake is running, the status is refreshed, the new cycle is logged,
        and the variable examiner and the variable-dependent callback times
        are updated. In all cases every hardware action is re-armed by setting
        ``awaitingCallback`` to True.
        """
        if self.isRunning:  # otherwise we already stopped before the new sequence began
            self.getStatusUpdate()
            # getStatusUpdate returns early (leaving statusList untouched)
            # when the runner cannot be reached, so the list may still be
            # empty here; do not let that abort the start of the sequence.
            cycleName = self.statusList[0] if self.statusList else "unknown"
            self.mainLog.addLine("New cycle started: %s" % cycleName, 1)
            # Show the latest variable values in the examiner pane.
            self.refreshExamineVariablesDictionary()
            # Callback times given as a timing edge or variable name must be
            # resolved again for this sequence.
            self.refreshVariableDependentCallbackTimes()
        else:
            self.mainLog.addLine("final connection closed")
        for hdwAct in self.hardwareActions:
            hdwAct.awaitingCallback = True

    def _stopSnakeToolbar(self):
        """Toolbar stop handler: log that the snake was stopped, then stop it."""
        stopNotice = "Experiment Snake Stopped (you should still wait till the end of this sequence before continuing)"
        self.mainLog.addLine(stopNotice, 1)
        self._stopSnake()

    def _reloadHWAsToolbar(self):
        """Toolbar reload handler: write a notice to the main log, then reload.

        NOTE(review): the logged notice says the snake was *stopped*, although
        this handler only calls ``_reloadHWAs`` -- confirm the intended text.
        """
        notice = "Experiment Snake Stopped (you should still wait till the end of this sequence before continuing)"
        self.mainLog.addLine(notice, 1)
        self._reloadHWAs()

    def _reloadHWAs(self):
        """Reload the hardware action modules and re-run ``__init__``.

        Logs the reload, reloads the listed ``hardwareAction`` sub-modules in
        order and then calls ``self.__init__()`` so the hardware action list
        is rebuilt from the freshly loaded code. Only the modules listed here
        are reloaded.
        """
        self.mainLog.addLine("Reloading hardware actions (advanced feature)",
                             3)
        try:
            reloadModule = reload  # builtin on Python 2
        except NameError:
            # Python 3 moved reload() out of the builtins.
            from importlib import reload as reloadModule
        modules = (
            hardwareAction.hardwareAction,
            hardwareAction.sequenceLoggerHWA,
            hardwareAction.dlicEvapHWA,
            hardwareAction.dlicRFSweepHWA,  # was listed twice; once is enough
            hardwareAction.evapAttenuationHWA,
            hardwareAction.evapAttenuation2HWA,
            hardwareAction.picomotorPlugHWA,
            hardwareAction.windFreakOffsetLockHWA,
            # hardwareAction.AOMChannelHWAs is not reloaded on purpose:
            # reloading it causes referencing problems.
            hardwareAction.experimentTablesHWA,
            hardwareAction.windFreakOffsetLockHighFieldImagingHWA,
            hardwareAction.greyMollassesOffsetFreqHWA,
            hardwareAction.dlicRFSweepLZHWA,
            hardwareAction.digitalMultimeterCurrentMeasureHWA,
        )
        for module in modules:
            reloadModule(module)
        self.__init__()

    def stopTimers(self):
        """Stop the connection, status-string and get-current timers.

        Each timer is handled independently: a failure while stopping one is
        logged and does not prevent the remaining timers from being stopped.
        Timers that are still ``None`` are skipped.
        """
        for timerName in ("connectionTimer", "statusStringTimer",
                          "getCurrentTimer"):
            try:
                # stop any previous timer, there should only be one at a time
                timer = getattr(self, timerName)
                if timer is not None:
                    timer.stop()
            except Exception as e:
                # Log the exception itself: ``e.message`` is deprecated since
                # Python 2.6 and missing in Python 3, where using it would
                # raise from inside this handler.
                logger.error(
                    "couldn't stop current timer before starting new one: %s",
                    e)

    def _stopSnake(self):
        """Stop the snake.

        Stops the timers first, then closes the initialised hardware actions
        and finally sets ``isRunning`` to False.
        """
        self.stopTimers()
        self.closeHardwareActions()
        self.isRunning = False

    def startTimers(self):
        """(Re)start the three polling timers.

        Any previously created timers are stopped first. Then a timer is
        created for each of ``getStatus``, ``updateStatusString`` and
        ``getCurrent``, using the corresponding frequency trait as interval.
        """
        # stop any previous timers
        self.stopTimers()
        # start the timers; the short sleeps presumably stagger them so they
        # do not fire at the same instant -- TODO confirm
        self.connectionTimer = Timer(self.connectionPollFrequency,
                                     self.getStatus)
        time.sleep(0.1)
        self.statusStringTimer = Timer(self.statusStringFrequency,
                                       self.updateStatusString)
        time.sleep(0.1)
        self.getCurrentTimer = Timer(self.getCurrentFrequency, self.getCurrent)
        logger.info("timers started")

    def getStatus(self):
        """Poll the runner status and fire any hardware callbacks that are due.

        Calls ``getStatusUpdate`` followed by ``checkForCallback``. Any
        exception raised by either is logged and written to the main log
        instead of being propagated.
        """
        logger.debug("starting getStatus")
        try:
            self.getStatusUpdate()
            self.checkForCallback()
        except Exception as e:
            # Format the exception itself; ``e.message`` is deprecated since
            # Python 2.6 and does not exist in Python 3.
            logger.error("error in getStatus Function")
            logger.error("error: %s ", e)
            self.mainLog.addLine(
                "error in getStatus Function. Error: %s" % e, 4)

    def getStatusUpdate(self):
        """Fetch the runner status and update ``statusList`` and ``timeRunning``.

        The status string returned by the connection is split into lines and
        stored in ``statusList``; lines 2 and 3 are parsed as the begin and
        current time (``'%d/%m/%Y %H:%M:%S'``) and their difference in seconds
        becomes ``timeRunning``.

        On a socket error the failure is logged and the method returns without
        changing ``statusList`` or ``timeRunning``. A malformed status string
        still raises (``IndexError`` / ``ValueError``).
        """
        try:
            statusString = self.connection.getStatus()
        except socket.error as e:
            # Format the exception itself; ``e.message`` is deprecated since
            # Python 2.6 and does not exist in Python 3.
            details = (e, e.errno, e.strerror)
            logger.error(
                "failed to get status . message=%s . errno=%s . errstring=%s "
                % details)
            self.mainLog.addLine(
                "Failed to get status from Experiment Runner. message=%s . errno=%s . errstring=%s"
                % details, 3)
            self.mainLog.addLine(
                "Cannot update timeRunning - callbacks could be wrong this sequence!",
                4)
            return
        self.statusList = statusString.split("\n")
        timeFormat = '%d/%m/%Y %H:%M:%S'
        timeBegin = datetime.datetime.strptime(self.statusList[2], timeFormat)
        timeCurrent = datetime.datetime.strptime(self.statusList[3],
                                                 timeFormat)
        self.timeRunning = (timeCurrent - timeBegin).total_seconds()
        logger.debug("time Running = %s " % self.timeRunning)

    def checkForCallback(self):
        """Fire the callback of every enabled hardware action that is due.

        An action is due when it is awaiting a callback and ``timeRunning``
        has reached its ``callbackTime``. For a due action the variables
        dictionary is handed over, ``callback()`` is called, its return string
        is logged, and the action is marked as no longer awaiting a callback
        with its ``callbackCounter`` incremented.

        An error in one callback is logged and does not stop the remaining
        actions from being checked; no exception is propagated.
        """
        try:
            # The list of enabled actions is built once, before iterating.
            enabledActions = [
                hdwA for hdwA in self.hardwareActions if hdwA.enabled
            ]
            for hdwAct in enabledActions:
                if not (hdwAct.awaitingCallback
                        and self.timeRunning >= hdwAct.callbackTime):
                    continue
                # callback should be started!
                try:
                    logger.debug("attempting to callback %s " %
                                 hdwAct.hardwareActionName)
                    hdwAct.setVariablesDictionary(self.variables)
                    logger.debug("vars dictionary set to %s " %
                                 self.variables)
                    callbackReturnString = hdwAct.callback()
                    self.mainLog.addLine(
                        "%s @ %s secs : %s" %
                        (hdwAct.hardwareActionName, self.timeRunning,
                         callbackReturnString), 2)
                    hdwAct.awaitingCallback = False
                    hdwAct.callbackCounter += 1
                except Exception as e:
                    # Format the exception itself; ``e.message`` is deprecated
                    # since Python 2.6 and does not exist in Python 3.
                    logger.error(
                        "error while performing callback on %s. see error message below"
                        % (hdwAct.hardwareActionName))
                    logger.error("error: %s ", e)
                    self.mainLog.addLine(
                        "error while performing callback on %s. Error: %s"
                        % (hdwAct.hardwareActionName, e), 4)
        except Exception as e:
            logger.error("error in checkForCallbackFunction")
            logger.error("error: %s ", e)
            self.mainLog.addLine(
                "error in checkForCallbackFunction. Error: %s" % e, 4)

    def getCurrent(self):
        """Start a ``SocketThread`` unless one is already alive.

        While a previous get-current thread is still alive no new thread is
        started and ``sequenceStarted`` is reset to False. Otherwise a new
        ``SocketThread`` is created, given a reference to this snake and
        started.
        """
        pending = self.getCurrentThread
        if pending and pending.isAlive():
            # Already waiting on the runner. This is deliberately not logged:
            # it would fill the log without any useful information.
            self.sequenceStarted = False
            return

        logger.info("starting getCurrent Thread")
        worker = SocketThread()
        self.getCurrentThread = worker
        worker.snakeReference = self  # for calling functions of the snake
        worker.start()

    def getCurrentBlocking(self):
        """Wait for the next sequence and load its variables and timing edges.

        Unlike the threaded ``getCurrent``, this does not return until the
        connection's ``getCurrent`` has returned the XML and it has been
        parsed. This is useful when starting up the snake, so that hardware
        events are not looked for before a sequence has started.

        On success ``variables`` and ``timingEdges`` are replaced and
        ``newSequenceStarted`` is called. An XML parse error is logged (with a
        hint about the receive buffer size) and not propagated.
        """
        self.mainLog.addLine("Waiting for next sequence to start")
        # Only returns at the beginning of a sequence; the experiment runner
        # then returns the entire XML file.
        self.xmlString = self.connection.getCurrent()
        logger.debug("length of xml string = %s " % len(self.xmlString))
        logger.debug("end of xml file is like [-30:]= %s" %
                     self.xmlString[-30:])
        try:
            root = ET.fromstring(self.xmlString)
            variables = root.find("variables")
            # variables dictionary : name --> value
            self.variables = {
                child[0].text: float(child[1].text)
                for child in variables
            }
            # timing edges dictionary : name --> value
            self.timingEdges = {
                timingEdge.find("name").text:
                float(timingEdge.find("value").text)
                for timingEdge in root.find("timing")
            }
            self.newSequenceStarted()
        except ET.ParseError as e:
            # Format the exception itself; ``e.message`` is deprecated since
            # Python 2.6 and does not exist in Python 3.
            self.mainLog.addLine("Error. Could not parse XML: %s" % e, 3)
            self.mainLog.addLine(
                "Possible cause is that buffer is full. is XML length %s>= limit %s ????"
                % (len(self.xmlString), self.connection.BUFFER_SIZE_XML), 3)
            logger.error("could not parse XML: %s " % self.xmlString)
            logger.error(e)

    def updateStatusString(self):
        """Update the status string, queue length, halt state and progress.

        Uses the most recent ``statusList``: entry 0 is shown in the status
        string, entry 1 is the queue length, entries 2 and 4 are parsed as the
        begin and end time of the sequence. This analysis does not need to be
        as accurate as the callback timing (e.g. progress bar).

        When ``timeRunning`` exceeds the total sequence time the runner is
        considered halted; on the first such poll the hardware actions are
        closed. When the runner is seen running again after a halt, the
        hardware actions are re-initialised.
        """
        logger.info("starting update status string")
        if len(self.statusList) < 5:
            # No complete status report is available yet (e.g. the runner
            # could not be reached); there is nothing to display.
            logger.warning(
                "status list incomplete, skipping status string update")
            return
        self.statusString = self.statusList[
            0] + "- Time Running = %s " % self.timeRunning
        self.queue = int(self.statusList[1])
        timeFormat = '%d/%m/%Y %H:%M:%S'
        timeBegin = datetime.datetime.strptime(self.statusList[2], timeFormat)
        timeEnd = datetime.datetime.strptime(self.statusList[4], timeFormat)
        # NOTE(review): the class declares a ``totalTime`` trait that is never
        # written; this method uses ``timeTotal`` -- confirm which is intended.
        self.timeTotal = (timeEnd - timeBegin).total_seconds()
        if self.timeRunning > self.timeTotal:
            self.haltedCount += 1
            self.runnerHalted = True
            # React only on the first halted poll. The count was incremented
            # just above, so the previous ``== 0`` test could never be true.
            if self.haltedCount == 1:
                logger.critical("runner was stopped.")
                self.mainLog.addLine("Runner stopped!", 3)
                self.closeHardwareActions()
        else:
            if self.haltedCount > 0:
                self.initialiseHardwareActions()
            self.haltedCount = 0
            self.runnerHalted = False
        if self.timeTotal:
            self.progress = 100.0 * self.timeRunning / self.timeTotal
        else:
            # Begin and end time coincide; avoid dividing by zero.
            self.progress = 0.0

    def _examineVariablesDictionary_default(self):
        """Default for ``examineVariablesDictionary``.

        Returns an ``ExamineVariablesDictionary`` bound to the first hardware
        action, or ``None`` (with a warning) when there are no hardware
        actions.
        """
        actions = self.hardwareActions
        if not actions:
            logger.warning(
                "hardwareActions list was empty. how should I populate variable examiner...?!."
            )
            return None

        firstAction = actions[0]  # default is the first in the list
        logger.debug(
            "returning first hardware action %s for examineVariablesDictionary default"
            % firstAction.hardwareActionName)
        return variableDictionary.ExamineVariablesDictionary(hdwA=firstAction)

    def updateExamineVariablesDictionary(self, hdwA):
        """Point the variable examiner pane at ``hdwA`` and refresh it.

        The hardware action and its name are stored on
        ``examineVariablesDictionary`` before its display list is updated.
        """
        examiner = self.examineVariablesDictionary
        examiner.hdwA = hdwA
        examiner.hardwareActionName = hdwA.hardwareActionName
        examiner.updateDisplayList()
        logger.critical("examineVariablesDictionary changed")

    def refreshExamineVariablesDictionary(self):
        """Refresh the display list of the variable examiner.

        Calls ``updateDisplayList`` on ``examineVariablesDictionary`` so the
        displayed values reflect the latest variables; useful at the beginning
        of a sequence.
        """
        examiner = self.examineVariablesDictionary
        examiner.updateDisplayList()
        logger.info("refreshed examine variables dictionary")

    def refreshVariableDependentCallbackTimes(self):
        """Re-parse callback times that depend on a variable.

        Calls ``parseCallbackTime`` on every hardware action whose
        ``callbackTimeVariableDependent`` flag is set. This should be called
        in each sequence.
        """
        for action in self.hardwareActions:
            if action.callbackTimeVariableDependent:
                action.parseCallbackTime()

    def _changeLoggingLevelDebug(self):
        """Menu action ("Log Level" > DEBUG): set ``logger`` to DEBUG level."""
        logger.setLevel(logging.DEBUG)

    def _changeLoggingLevelInfo(self):
        """Menu action ("Log Level" > INFO): set ``logger`` to INFO level."""
        logger.setLevel(logging.INFO)

    def _changeLoggingLevelWarning(self):
        """Menu action ("Log Level" > WARNING): set ``logger`` to WARNING level."""
        logger.setLevel(logging.WARNING)

    def _changeLoggingLevelError(self):
        """Menu action ("Log Level" > ERROR): set ``logger`` to ERROR level."""
        logger.setLevel(logging.ERROR)
Exemple #12
0
class HCFT(tr.HasStrictTraits):
    '''High-Cycle Fatigue Tool

    GUI model that imports a csv file, stores each of its columns as an npy
    file, derives filtered and per-cycle max/min arrays based on the force
    column, and plots the results on a matplotlib figure.
    '''
    #=========================================================================
    # Traits definitions
    #=========================================================================
    # csv parsing options handed to pd.read_csv
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    # Sampling rate used to build the time column when
    # take_time_from_time_column is False
    records_per_second = tr.Float(100)
    take_time_from_time_column = tr.Bool(True)
    # Path of the selected csv file and the button that opens the file dialog
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    # Passed as `skiprows` to pd.read_csv when parsing the columns
    skip_first_rows = tr.Range(low=1, high=10**9, mode='spinner')
    # Sanitized csv header names (plus 'avg_...' entries added for averaged
    # columns); also the source of values for the column selectors below
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    force_column = tr.Enum(values='columns_headers_list')
    time_column = tr.Enum(values='columns_headers_list')
    # Factors applied to the arrays right before plotting
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    # 'NPY' folder next to the csv file and the csv base name without
    # extension; together they form the prefix of every npy file
    npy_folder_path = tr.Str
    file_name = tr.Str
    # When set, add_plot uses the '_filtered' npy files
    apply_filters = tr.Bool
    plot_settings_btn = tr.Button
    # NOTE(review): created once at class-definition time, so presumably
    # shared by all HCFT instances - confirm this is intended
    plot_settings = PlotSettings()
    plot_settings_active = tr.Bool
    # Options of the creep plot
    normalize_cycles = tr.Bool
    smooth = tr.Bool
    plot_every_nth_point = tr.Range(low=1, high=1000000, mode='spinner')
    # The peak force is stashed in old_* while `activate` is off
    old_peak_force_before_cycles = tr.Float
    peak_force_before_cycles = tr.Float
    # Savitzky-Golay filter parameters
    window_length = tr.Range(low=1, high=10**9 - 1, value=31, mode='spinner')
    polynomial_order = tr.Range(low=1, high=10**9, value=2, mode='spinner')
    # Enables smoothing of the ascending branch of the displacements
    activate = tr.Bool(False)
    # Action buttons; each has a matching _<name>_fired handler
    add_plot = tr.Button
    add_creep_plot = tr.Button(desc='Creep plot of X axis array')
    clear_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_and_creep_npy = tr.Button
    add_columns_average = tr.Button
    # Thresholds used to discard fake cycles, depending on cutting_method
    force_max = tr.Float(100)
    force_min = tr.Float(40)
    min_cycle_force_range = tr.Float(50)
    cutting_method = tr.Enum(
        'Define min cycle range(force difference)', 'Define Max, Min')
    # List of lists of column names; each inner list yields one averaged column
    columns_to_be_averaged = tr.List
    figure = tr.Instance(mpl.figure.Figure)
    # Text shown in the log pane; filled by print_custom
    log = tr.Str('')
    clear_log = tr.Button

    def _figure_default(self):
        """Create the default figure: white background, tight layout on."""
        fig = mpl.figure.Figure(facecolor='white')
        fig.set_tight_layout(True)
        return fig

    #=========================================================================
    # File management
    #=========================================================================

    def _open_file_csv_fired(self):
        """Handle the 'Input file' button: pick a csv file and prepare for it.

        On acceptance the previous averages and log are reset, the column
        selectors are filled from the first csv row, and an 'NPY' folder is
        created next to the csv file. Cancelling the dialog leaves the
        current state untouched. Any exception is reported through
        ``deal_with_exception``.
        """
        try:
            extns = ['*.csv', ]  # seems to handle only one extension...
            wildcard = '|'.join(extns)

            dialog = FileDialog(title='Select text file',
                                action='open', wildcard=wildcard,
                                default_path=self.file_csv)

            # Nothing to do (and nothing to forget) if no file was chosen.
            if dialog.open() != OK:
                return

            # Reset only now: resetting before the dialog would drop
            # columns_to_be_averaged on cancel while columns_headers_list
            # still contains the matching 'avg_...' entries.
            self.reset()
            self.file_csv = dialog.path

            # Fill the column selectors from the first row of the csv file.
            # str() guards against header cells pandas parsed as numbers/NaN.
            headers_row = np.array(
                pd.read_csv(
                    self.file_csv, delimiter=self.delimiter,
                    decimal=self.decimal, nrows=1, header=None
                )
            )[0]
            self.columns_headers_list = [
                self.get_valid_file_name(str(header))
                for header in headers_row]

            # Remember the file name and create the NPY folder beside the csv.
            dir_path = os.path.dirname(self.file_csv)
            self.npy_folder_path = os.path.join(dir_path, 'NPY')
            os.makedirs(self.npy_folder_path, exist_ok=True)

            self.file_name = os.path.splitext(
                os.path.basename(self.file_csv))[0]

        except Exception as e:
            self.deal_with_exception(e)

    def _parse_csv_to_npy_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.parse_csv_to_npy_fired)
        thread.start()

    def parse_csv_to_npy_fired(self):
        """Store every csv column, and every averaged column, as an npy file.

        Files are written to ``npy_folder_path`` as
        ``<file_name>_<column name>.npy``. When ``take_time_from_time_column``
        is off, the time column is replaced by ``index / records_per_second``.
        Averaged columns are computed from the npy files just written and
        saved under the ``avg_...`` suffix. Any exception is reported through
        ``deal_with_exception``.
        """
        def npy_path(suffix):
            # Full path of the npy file belonging to one column/suffix
            return os.path.join(self.npy_folder_path,
                                self.file_name + '_' + suffix + '.npy')

        try:
            self.print_custom('Parsing csv into npy files...')

            # The trailing entries of columns_headers_list are the 'avg_...'
            # names of averaged columns; they have no column in the csv file.
            csv_columns_count = (len(self.columns_headers_list) -
                                 len(self.columns_to_be_averaged))
            for i in range(csv_columns_count):
                current_column_name = self.columns_headers_list[i]
                column_array = np.array(pd.read_csv(
                    self.file_csv, delimiter=self.delimiter,
                    decimal=self.decimal,
                    skiprows=self.skip_first_rows, usecols=[i]))

                if current_column_name == self.time_column and \
                        not self.take_time_from_time_column:
                    # Integer range divided by the rate: unlike np.arange
                    # with a float step, this always yields exactly one time
                    # stamp per record.
                    column_array = (np.arange(len(column_array)) /
                                    self.records_per_second)

                np.save(npy_path(current_column_name), column_array)

            # Exporting npy arrays of averaged columns
            for columns_names in self.columns_to_be_averaged:
                temp = np.zeros((1))
                for column_name in columns_names:
                    temp = temp + np.load(npy_path(column_name)).flatten()
                avg = temp / len(columns_names)

                avg_file_suffex = self.get_suffex_for_columns_to_be_averaged(
                    columns_names)
                np.save(npy_path(avg_file_suffex), avg)

            self.print_custom('Finsihed parsing csv into npy files.')
        except Exception as e:
            self.deal_with_exception(e)

    def get_suffex_for_columns_to_be_averaged(self, columns_names):
        """Return 'avg_' followed by the given column names joined by '_'."""
        return 'avg_{}'.format('_'.join(columns_names))

    def get_valid_file_name(self, original_file_name):
        """Drop every character that is not allowed in the generated names.

        Kept characters are ASCII letters, digits, space and ``-_.()``.
        """
        allowed = "-_.() " + string.ascii_letters + string.digits
        kept = filter(allowed.__contains__, original_file_name)
        return ''.join(kept)

    def _clear_plot_fired(self):
        self.figure.clear()
        self.data_changed = True

    def _add_columns_average_fired(self):
        """Let the user pick columns to average and register the selection.

        A ``ColumnsAverage`` window listing all current headers is shown.
        If at least one column was selected, the list of names is appended to
        ``columns_to_be_averaged`` and the matching 'avg_...' name is appended
        to ``columns_headers_list``. Any exception is reported through
        ``deal_with_exception``.
        """
        try:
            chooser = ColumnsAverage()
            for header in self.columns_headers_list:
                chooser.columns.append(Column(column_name=header))

            # kind='modal' blocks here until the window is closed
            chooser.configure_traits(kind='modal')

            picked = [column.column_name
                      for column in chooser.columns if column.selected]
            if not picked:
                return

            self.columns_to_be_averaged.append(picked)
            self.columns_headers_list.append(
                self.get_suffex_for_columns_to_be_averaged(picked))
        except Exception as e:
            self.deal_with_exception(e)

    def _generate_filtered_and_creep_npy_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.generate_filtered_and_creep_npy_fired)
        thread.start()

    def generate_filtered_and_creep_npy_fired(self):
        """Write the '_filtered', '_max' and '_min' npy files.

        Requires the npy file of the force column to exist. The force signal
        is split at the first sample whose absolute value exceeds
        ``abs(peak_force_before_cycles)``: the part before it (ascending
        branch) is kept as is, the rest is reduced to its local extrema.
        The same indices are then applied to the other columns. Finally the
        extrema are cut according to ``cutting_method`` and the values of
        every non-time column at the remaining maxima/minima are saved.
        Any exception is reported through ``deal_with_exception``.
        """
        try:
            if self.npy_files_exist(os.path.join(
                    self.npy_folder_path, self.file_name + '_' + self.force_column
                    + '.npy')) == False:
                return

            self.print_custom('Generating filtered and creep files...')

            # 1- Export filtered force
            force = np.load(os.path.join(self.npy_folder_path,
                                         self.file_name + '_' + self.force_column
                                         + '.npy')).flatten()
            # Index of the first sample beyond the peak force; raises
            # IndexError (handled below) if no sample exceeds it
            peak_force_before_cycles_index = np.where(
                abs((force)) > abs(self.peak_force_before_cycles))[0][0]
            force_ascending = force[0:peak_force_before_cycles_index]
            force_rest = force[peak_force_before_cycles_index:]

            # Indices are relative to force_rest, not to the full array
            force_max_indices, force_min_indices = self.get_array_max_and_min_indices(
                force_rest)

            # Merge both kinds of extrema in chronological order
            force_max_min_indices = np.concatenate(
                (force_min_indices, force_max_indices))
            force_max_min_indices.sort()

            force_rest_filtered = force_rest[force_max_min_indices]
            force_filtered = np.concatenate(
                (force_ascending, force_rest_filtered))
            np.save(os.path.join(self.npy_folder_path, self.file_name +
                                 '_' + self.force_column + '_filtered.npy'),
                    force_filtered)

            # 2- Export filtered displacements
            # Every column except force and time is treated as a displacement
            for i in range(0, len(self.columns_headers_list)):
                if self.columns_headers_list[i] != self.force_column and \
                        self.columns_headers_list[i] != self.time_column:

                    disp = np.load(os.path.join(self.npy_folder_path, self.file_name
                                                + '_' +
                                                self.columns_headers_list[i]
                                                + '.npy')).flatten()
                    disp_ascending = disp[0:peak_force_before_cycles_index]
                    disp_rest = disp[peak_force_before_cycles_index:]

                    # Optional Savitzky-Golay smoothing of the ascending branch
                    if self.activate == True:
                        disp_ascending = savgol_filter(
                            disp_ascending, window_length=self.window_length,
                            polyorder=self.polynomial_order)

                    # Sample the displacement at the force extrema
                    disp_rest_filtered = disp_rest[force_max_min_indices]
                    filtered_disp = np.concatenate(
                        (disp_ascending, disp_rest_filtered))
                    np.save(os.path.join(self.npy_folder_path, self.file_name + '_'
                                         + self.columns_headers_list[i] +
                                         '_filtered.npy'), filtered_disp)

            # 3- Export creep for displacements
            # Cutting unwanted max min values to get correct full cycles and remove
            # false min/max values caused by noise
            if self.cutting_method == "Define Max, Min":
                force_max_indices_cutted, force_min_indices_cutted = \
                    self.cut_indices_of_min_max_range(force_rest,
                                                      force_max_indices,
                                                      force_min_indices,
                                                      self.force_max,
                                                      self.force_min)
            elif self.cutting_method == "Define min cycle range(force difference)":
                force_max_indices_cutted, force_min_indices_cutted = \
                    self.cut_indices_of_defined_range(force_rest,
                                                      force_max_indices,
                                                      force_min_indices,
                                                      self.min_cycle_force_range)

            self.print_custom("Cycles number= ", len(force_min_indices))
            self.print_custom("Cycles number after cutting fake cycles = ",
                              len(force_min_indices_cutted))

            # Values of every non-time column (force included) at the kept
            # force maxima and minima
            for i in range(0, len(self.columns_headers_list)):
                if self.columns_headers_list[i] != self.time_column:
                    array = np.load(os.path.join(self.npy_folder_path, self.file_name +
                                                 '_' +
                                                 self.columns_headers_list[i]
                                                 + '.npy')).flatten()
                    array_rest = array[peak_force_before_cycles_index:]
                    array_rest_maxima = array_rest[force_max_indices_cutted]
                    array_rest_minima = array_rest[force_min_indices_cutted]
                    np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                         self.columns_headers_list[i] + '_max.npy'), array_rest_maxima)
                    np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                         self.columns_headers_list[i] + '_min.npy'), array_rest_minima)

            self.print_custom('Filtered and creep npy files are generated.')
        except Exception as e:
            self.deal_with_exception(e)

    def cut_indices_of_min_max_range(self, array, max_indices, min_indices,
                                     range_upper_value, range_lower_value):
        """Keep only the extrema that lie outside the given force band.

        A maximum is kept when its absolute value is above
        ``abs(range_upper_value)``; a minimum is kept when its absolute value
        is below ``abs(range_lower_value)``. Returns two lists:
        ``(kept max indices, kept min indices)``.
        """
        upper_limit = abs(range_upper_value)
        lower_limit = abs(range_lower_value)
        kept_max = [idx for idx in max_indices if abs(array[idx]) > upper_limit]
        kept_min = [idx for idx in min_indices if abs(array[idx]) < lower_limit]
        return kept_max, kept_min

    def cut_indices_of_defined_range(self, array, max_indices, min_indices, range_):
        """Keep only max/min pairs whose value difference exceeds ``range_``.

        Maxima and minima are paired in order. When one index array is longer
        than the other, its last index is appended to its result list as
        well. ``max_indices`` and ``min_indices`` must provide ``.size``.
        Returns two lists: ``(kept max indices, kept min indices)``.
        """
        kept_pairs = [
            (peak, valley)
            for peak, valley in zip(max_indices, min_indices)
            if abs(array[peak] - array[valley]) > range_
        ]
        kept_max = [peak for peak, _ in kept_pairs]
        kept_min = [valley for _, valley in kept_pairs]

        # The unpaired trailing extremum of the longer array is always kept
        if max_indices.size > min_indices.size:
            kept_max.append(max_indices[-1])
        elif min_indices.size > max_indices.size:
            kept_min.append(min_indices[-1])

        return kept_max, kept_min

    def get_array_max_and_min_indices(self, input_array):
        """Return ``(max indices, min indices)`` in terms of magnitude.

        If most values are non-negative, local maxima/minima are returned as
        they are. Otherwise (mostly negative signal, or a tie) the roles are
        swapped, so that local minima are reported as the "max" indices.
        """
        non_negative_count = np.sum(np.array(input_array) >= 0)
        mostly_positive = (
            non_negative_count > input_array.size - non_negative_count)

        if mostly_positive:
            return (self.get_max_indices(input_array),
                    self.get_min_indices(input_array))
        return (self.get_min_indices(input_array),
                self.get_max_indices(input_array))

    def get_max_indices(self, a):
        """Return the indices of the local maxima of the 1-D array ``a``.

        An element is a maximum when it is greater than both its neighbours.
        For a run of equal values the index of the last element of the run is
        reported, compared against the value preceding the run. The first and
        last elements never qualify.

        Returns an integer index array (empty, but still of integer dtype,
        when there is no maximum) so it can always be used for indexing.
        """
        max_indices = []
        last = a.size - 1
        i = 1
        while i < last:
            # Value right before the (possibly repeated) candidate
            previous_element = a[i - 1]

            # Skip repeated elements, stopping so that a[i + 1] stays valid
            while i < last - 1 and a[i] == a[i + 1]:
                i += 1

            if a[i] > a[i + 1] and a[i] > previous_element:
                max_indices.append(i)
            i += 1
        # Explicit dtype: np.array([]) would be float64 and unusable as index
        return np.array(max_indices, dtype=np.intp)

    def get_min_indices(self, a):
        """Return the indices of the local minima of the 1-D array ``a``.

        An element is a minimum when it is smaller than both its neighbours.
        For a run of equal values the index of the last element of the run is
        reported, compared against the value preceding the run. The first and
        last elements never qualify.

        Returns an integer index array (empty, but still of integer dtype,
        when there is no minimum) so it can always be used for indexing.
        """
        min_indices = []
        last = a.size - 1
        i = 1
        while i < last:
            # Value right before the (possibly repeated) candidate
            previous_element = a[i - 1]

            # Skip repeated elements, stopping so that a[i + 1] stays valid
            while i < last - 1 and a[i] == a[i + 1]:
                i += 1

            if a[i] < a[i + 1] and a[i] < previous_element:
                min_indices.append(i)
            i += 1
        # Explicit dtype: np.array([]) would be float64 and unusable as index
        return np.array(min_indices, dtype=np.intp)

    def _activate_changed(self):
        if self.activate == False:
            self.old_peak_force_before_cycles = self.peak_force_before_cycles
            self.peak_force_before_cycles = 0
        else:
            self.peak_force_before_cycles = self.old_peak_force_before_cycles

    def _window_length_changed(self, new):
        """Warn the user when the new window length is not usable.

        One dialog is shown if it is not bigger than the polynomial order and
        another one if it is not an odd positive integer. The value itself is
        not modified.
        """
        checks = (
            (new <= self.polynomial_order,
             'Window length must be bigger than polynomial order.'),
            (new % 2 == 0 or new <= 0,
             'Window length must be odd positive integer.'),
        )
        for violated, text in checks:
            if violated:
                MessageDialog(title='Attention!', message=text).open()

    def _polynomial_order_changed(self, new):
        """Warn the user when the new polynomial order is not below the window length."""
        if new < self.window_length:
            return
        warning = MessageDialog(
            title='Attention!',
            message='Polynomial order must be smaller than window length.')
        warning.open()

    #=========================================================================
    # Plotting
    #=========================================================================

    def _plot_settings_btn_fired(self):
        try:
            self.plot_settings.configure_traits(kind='modal')
        except Exception as e:
            self.deal_with_exception(e)

    def npy_files_exist(self, path):
        """Return True if ``path`` exists; otherwise log a hint and return False.

        The hint asks the user to parse the csv file first.
        """
        if os.path.exists(path):
            return True
        # TODO fix this: the hint used to be a MessageDialog
        # (title='Attention!'), which is currently disabled.
        self.print_custom(
            'Please parse csv file to generate npy files first.')
        return False

    def filtered_and_creep_npy_files_exist(self, path):
        """Return True if ``path`` exists; otherwise log a hint and return False.

        The hint asks the user to generate the filtered and creep files first.
        """
        if os.path.exists(path):
            return True
        # TODO fix this: the hint used to be a MessageDialog
        # (title='Attention!'), which is currently disabled.
        self.print_custom(
            'Please generate filtered and creep npy files first.')
        return False

    data_changed = tr.Event

    def _add_plot_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.add_plot_fired)
        thread.start()

    def add_plot_fired(self):
        """Plot the selected Y column over the selected X column.

        With ``apply_filters`` set, the '_filtered' npy files are used,
        otherwise the raw ones. Both files must exist; if one is missing a
        hint is logged and nothing is plotted. With ``plot_settings_active``
        set, only the rows chosen by ``get_indices_array`` are plotted. Both
        arrays are scaled by their axis multiplier. Any exception is reported
        through ``deal_with_exception``.
        """
        try:
            # The filtered and the raw case differ only in the file suffix
            # and in the hint that is logged when a file is missing.
            if self.apply_filters:
                suffix = '_filtered'
                files_exist = self.filtered_and_creep_npy_files_exist
            else:
                suffix = ''
                files_exist = self.npy_files_exist

            x_axis_name = self.x_axis + suffix
            y_axis_name = self.y_axis + suffix
            x_path = os.path.join(self.npy_folder_path,
                                  self.file_name + '_' + x_axis_name + '.npy')
            y_path = os.path.join(self.npy_folder_path,
                                  self.file_name + '_' + y_axis_name + '.npy')

            # Both files are loaded below, so both have to be there
            if not (files_exist(x_path) and files_exist(y_path)):
                return

            self.print_custom('Loading npy files...')
            # when mmap_mode!=None, the array will be loaded as 'numpy.memmap'
            # object which doesn't load the array to memory until it's
            # indexed
            x_axis_array = np.load(x_path, mmap_mode='r')
            y_axis_array = np.load(y_path, mmap_mode='r')

            if self.plot_settings_active:
                indices = self.get_indices_array(
                    np.size(x_axis_array),
                    self.plot_settings.first_rows,
                    self.plot_settings.distance,
                    self.plot_settings.num_of_rows_after_each_distance)
                x_axis_array = x_axis_array[indices]
                y_axis_array = y_axis_array[indices]

            x_axis_array = self.x_axis_multiplier * x_axis_array
            y_axis_array = self.y_axis_multiplier * y_axis_array

            self.print_custom('Adding Plot...')
            mpl.rcParams['agg.path.chunksize'] = 10000

            ax = self.figure.add_subplot(1, 1, 1)

            ax.set_xlabel(x_axis_name)
            ax.set_ylabel(y_axis_name)
            # Colour is given by keyword only; also passing the 'k' format
            # string defines it twice and makes matplotlib emit a warning.
            ax.plot(x_axis_array, y_axis_array,
                    linewidth=1.2, color=np.random.rand(3),
                    label=self.file_name + ', ' + x_axis_name)

            ax.legend()
            self.data_changed = True
            self.print_custom('Finished adding plot.')

        except Exception as e:
            self.deal_with_exception(e)

    def _add_creep_plot_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.add_creep_plot_fired)
        thread.start()

    def add_creep_plot_fired(self):
        """Plot the per-cycle max and min values of the X axis column.

        Uses the '_max' and '_min' npy files of the selected X column; if one
        is missing a hint is logged and nothing is plotted. Values are scaled
        by ``x_axis_multiplier``, optionally thinned to every n-th point and
        optionally smoothed (Savitzky-Golay, first point kept unchanged). The
        horizontal axis runs from 0 to the number of cycles, or from 0 to 1
        with ``normalize_cycles``. Any exception is reported through
        ``deal_with_exception``.
        """
        def smooth_keeping_first(values):
            # Keeping the first item of the array and filtering the rest
            return np.concatenate((
                np.array([values[0]]),
                savgol_filter(values[1:],
                              window_length=self.window_length,
                              polyorder=self.polynomial_order)
            ))

        try:
            base_path = os.path.join(self.npy_folder_path,
                                     self.file_name + '_' + self.x_axis)
            max_path = base_path + '_max.npy'
            min_path = base_path + '_min.npy'

            # Both files are loaded below, so both have to be there
            if not (self.filtered_and_creep_npy_files_exist(max_path) and
                    self.filtered_and_creep_npy_files_exist(min_path)):
                return

            self.print_custom('Loading npy files...')
            disp_max = self.x_axis_multiplier * np.load(max_path)
            disp_min = self.x_axis_multiplier * np.load(min_path)
            # Taken before thinning so the axis still spans all cycles
            complete_cycles_number = disp_max.size

            self.print_custom('Adding creep-fatigue plot...')
            mpl.rcParams['agg.path.chunksize'] = 10000

            ax = self.figure.add_subplot(1, 1, 1)

            ax.set_xlabel('Cycles number')
            ax.set_ylabel(self.x_axis)

            if self.plot_every_nth_point > 1:
                disp_max = disp_max[0::self.plot_every_nth_point]
                disp_min = disp_min[0::self.plot_every_nth_point]

            if self.smooth:
                disp_max = smooth_keeping_first(disp_max)
                disp_min = smooth_keeping_first(disp_min)

            cycles_axis_end = 1. if self.normalize_cycles \
                else complete_cycles_number

            # Colour is given by keyword only; also passing the 'k' format
            # string defines it twice and makes matplotlib emit a warning.
            for prefix, values in (('Max', disp_max), ('Min', disp_min)):
                ax.plot(np.linspace(0, cycles_axis_end, values.size), values,
                        linewidth=1.2, color=np.random.rand(3),
                        label=prefix + ', ' + self.file_name + ', '
                        + self.x_axis)

            ax.legend()
            self.data_changed = True
            self.print_custom('Finished adding creep-fatigue plot.')

        except Exception as e:
            self.deal_with_exception(e)

    def get_indices_array(self,
                          array_size,
                          first_rows,
                          distance,
                          num_of_rows_after_each_distance):
        """Return the row indices selected by the plot settings.

        The selection consists of the first ``first_rows`` rows followed by
        repeated groups of ``num_of_rows_after_each_distance`` rows, each
        group followed by a gap of ``distance`` skipped rows. All returned
        indices are smaller than ``array_size``, so the result can be used to
        index an array of that size directly.
        (Arguments are assumed to be non-negative integers - TODO confirm
        against PlotSettings.)
        """
        # Never select more leading rows than the array has
        result_1 = np.arange(min(first_rows, array_size))

        # First index of every group after the leading rows
        group_starts = np.arange(start=first_rows, stop=array_size,
                                 step=distance + num_of_rows_after_each_distance)
        # Expand every start to its group in one broadcasted addition instead
        # of concatenating inside a loop
        result_2 = (group_starts[:, np.newaxis] +
                    np.arange(num_of_rows_after_each_distance)).ravel()
        # The last group may run past the end of the array
        result_2 = result_2[result_2 < array_size]

        return np.concatenate((result_1, result_2))

    def reset(self):
        """Forget the registered column averages and empty the log."""
        self.columns_to_be_averaged, self.log = [], ''

    def print_custom(self, *input_args):
        """Print the arguments and append them to the ``log`` trait.

        stdout receives the arguments as ``print`` formats them. The log
        entry is the ``str()`` of every argument concatenated without a
        separator; it starts on a new line unless the log is still empty.
        """
        print(*input_args)
        entry = ''.join(map(str, input_args))
        self.log = entry if self.log == '' else self.log + '\n' + entry

    def deal_with_exception(self, e):
        """Report the exception currently being handled via ``print_custom``.

        The text comes from ``traceback.format_exc()``; the argument ``e``
        itself is not used.
        """
        report = (
            'SOMETHING WENT WRONG!',
            '--------- Error message: ---------',
            traceback.format_exc(),
            '----------------------------------',
        )
        for line in report:
            self.print_custom(line)

    def _clear_log_fired(self):
        self.log = ''

    #=========================================================================
    # Configuration of the view
    #=========================================================================
    traits_view = ui.View(
        ui.HSplit(
            ui.VSplit(
                ui.VGroup(
                    ui.VGroup(
                        ui.Item('decimal'),
                        ui.Item('delimiter'),
                        ui.HGroup(
                            ui.UItem('open_file_csv', has_focus=True),
                            ui.UItem('file_csv', style='readonly', width=0.1)),
                        label='Importing csv file',
                        show_border=True)),
                ui.VGroup(
                    ui.VGroup(
                        ui.VGroup(
                            ui.Item('take_time_from_time_column'),
                            ui.Item('time_column',
                                    enabled_when='take_time_from_time_column == True'),
                            ui.Item('records_per_second',
                                    enabled_when='take_time_from_time_column == False'),
                            label='Time calculation',
                            show_border=True),
                        ui.UItem('add_columns_average'),
                        ui.Item('skip_first_rows'),
                        ui.UItem('parse_csv_to_npy', resizable=True),
                        label='Processing csv file',
                        show_border=True)),
                ui.VGroup(
                    ui.VGroup(
                        ui.HGroup(ui.Item('x_axis'), ui.Item(
                            'x_axis_multiplier')),
                        ui.HGroup(ui.Item('y_axis'), ui.Item(
                            'y_axis_multiplier')),
                        ui.VGroup(
                            ui.HGroup(ui.UItem('add_plot'),
                                      ui.Item('apply_filters'),
                                      ui.Item('plot_settings_btn',
                                              label='Settings',
                                              show_label=False,
                                              enabled_when='plot_settings_active == True'),
                                      ui.Item('plot_settings_active',
                                              show_label=False)
                                      ),
                            show_border=True,
                            label='Plotting X axis with Y axis'
                        ),
                        ui.VGroup(
                            ui.HGroup(ui.UItem('add_creep_plot'),
                                      ui.VGroup(
                                          ui.Item('normalize_cycles'),
                                          ui.Item('smooth'),
                                          ui.Item('plot_every_nth_point'))
                                      ),
                            show_border=True,
                            label='Plotting Creep-fatigue of X axis variable'
                        ),
                        ui.UItem('clear_plot', resizable=True),
                        show_border=True,
                        label='Plotting'))
            ),
            ui.VGroup(
                ui.Item('force_column'),
                ui.VGroup(ui.VGroup(
                    ui.Item('window_length'),
                    ui.Item('polynomial_order'),
                    enabled_when='activate == True or smooth == True'),
                    show_border=True,
                    label='Smoothing parameters (Savitzky-Golay filter):'
                ),
                ui.VGroup(ui.VGroup(
                    ui.Item('activate'),
                    ui.Item('peak_force_before_cycles',
                            enabled_when='activate == True')
                ),
                    show_border=True,
                    label='Smooth ascending branch for all displacements:'
                ),
                ui.VGroup(ui.Item('cutting_method'),
                          ui.VGroup(ui.Item('force_max'),
                                    ui.Item('force_min'),
                                    label='Max, Min:',
                                    show_border=True,
                                    enabled_when='cutting_method == "Define Max, Min"'),
                          ui.VGroup(ui.Item('min_cycle_force_range'),
                                    label='Min cycle force range:',
                                    show_border=True,
                                    enabled_when='cutting_method == "Define min cycle range(force difference)"'),
                          show_border=True,
                          label='Cut fake cycles for creep:'),

                ui.VSplit(
                    ui.UItem('generate_filtered_and_creep_npy'),
                    ui.VGroup(
                        ui.Item('log',
                                width=0.1, style='custom'),
                        ui.UItem('clear_log'))),
                show_border=True,
                label='Filters'
            ),
            ui.UItem('figure', editor=MPLFigureEditor(),
                     resizable=True,
                     springy=True,
                     width=0.8,
                     label='2d plots')
        ),
        title='High-cycle fatigue tool',
        resizable=True,
        width=0.85,
        height=0.7
    )
class HCFF(tr.HasStrictTraits):
    '''High-Cycle Fatigue Filter
    '''

    # Helper object; its initial value is produced by _something_default.
    something = tr.Instance(Something)
    # Decimal mark of the csv file; handed to pd.read_csv as `decimal`.
    decimal = tr.Enum(',', '.')
    # Column separator of the csv file; handed to pd.read_csv and
    # np.loadtxt as `delimiter`.
    delimiter = tr.Str(';')

    # Path of the hdf5 file; written by read_csv and read by
    # read_hdf5_no_filter.
    path_hdf5 = tr.Str('')

    def _something_default(self):
        '''Return a new ``Something`` used as default of ``something``.'''
        return Something()

    #=========================================================================
    # File management
    #=========================================================================
    # Path of the input csv file; assigned from the file dialog in
    # _open_file_csv_fired.
    file_csv = tr.File

    # Button labelled 'Input file'; handled by _open_file_csv_fired.
    open_file_csv = tr.Button('Input file')

    def _open_file_csv_fired(self):
        """Handle a click on the 'Input file' button.

        Opens a file dialog, stores the chosen path in ``file_csv`` and
        fills ``columns_headers_list`` with the sanitized entries of the
        first line of that file.  Nothing is changed when the dialog is
        cancelled.
        """
        # Local import: the dialog result has to be compared with the
        # pyface OK constant to detect a cancelled dialog.
        from pyface.api import OK

        extns = [
            '*.csv',
        ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open',
                            wildcard=wildcard,
                            default_path=self.file_csv)
        # Keep the current selection (and do not try to read a file that
        # was never chosen) when the user cancels the dialog.
        if dialog.open() != OK or not dialog.path:
            return
        self.file_csv = dialog.path

        # Read only the first line of the file; its entries become the
        # choices offered by the x_axis and y_axis traits.
        first_row = np.array(
            pd.read_csv(self.file_csv,
                        delimiter=self.delimiter,
                        decimal=self.decimal,
                        nrows=1,
                        header=None))[0]
        # NOTE(review): get_valid_file_name is not defined in the visible
        # part of this class - confirm it is available on HCFF.
        self.columns_headers_list = [
            self.get_valid_file_name(header) for header in first_row
        ]

    #=========================================================================
    # Parameters of the filter algorithm
    #=========================================================================

    # Number of rows per chunk when read_csv streams the csv file
    # (handed to pd.read_csv as `chunksize`).
    chunk_size = tr.Int(10000, auto_set=False, enter_set=True)

    # Number of leading rows skipped by read_csv and by
    # _read_loadtxt_button_fired.
    skip_rows = tr.Int(4, auto_set=False, enter_set=True)

    # 1) use the decorator
    @tr.on_trait_change('chunk_size, skip_rows')
    def whatever_name_size_changed(self):
        '''Print a notification when ``chunk_size`` or ``skip_rows`` changes.

        The handler is registered through the decorator, so its name is
        arbitrary.
        '''
        message = 'chunk-size changed'
        print(message)

    # 2) use the _changed or _fired extension
    def _chunk_size_changed(self):
        '''Print a notification when ``chunk_size`` changes.

        The handler is bound to the trait through its ``_changed`` name
        suffix rather than through a decorator.
        '''
        message = 'chunk_size changed - calling the named function'
        print(message)

    # Numerical content of the csv file, assigned by
    # _read_loadtxt_button_fired.  np.float64 is spelled out because the
    # np.float_ alias was removed in NumPy 2.0.
    data = tr.Array(dtype=np.float64)

    # Button handled by _read_loadtxt_button_fired.
    read_loadtxt_button = tr.Button()

    def _read_loadtxt_button_fired(self):
        '''Load the whole csv file into ``data`` and print its shape.

        ``np.loadtxt`` is tried first.  It has no notion of a decimal
        mark other than '.', so when it fails with a ``ValueError`` and
        ``decimal`` is set to another character, the file is parsed with
        pandas instead, which honours the ``decimal`` setting.

        Raises:
            ValueError: if the file cannot be converted to floats.
        '''
        try:
            loaded = np.loadtxt(self.file_csv,
                                skiprows=self.skip_rows,
                                delimiter=self.delimiter)
        except ValueError:
            if self.decimal == '.':
                # Nothing else to try: the failure is not caused by the
                # decimal mark.
                raise
            loaded = np.array(pd.read_csv(self.file_csv,
                                          skiprows=self.skip_rows,
                                          delimiter=self.delimiter,
                                          decimal=self.decimal,
                                          header=None),
                              dtype=np.float64)
        self.data = loaded
        print(self.data.shape)

    # Buttons handled by _read_csv_button_fired and
    # _read_hdf5_button_fired respectively.
    read_csv_button = tr.Button
    read_hdf5_button = tr.Button

    def _read_csv_button_fired(self):
        '''Run read_csv when the ``read_csv_button`` is pressed.'''
        self.read_csv()

    def _read_hdf5_button_fired(self):
        '''Run read_hdf5_no_filter when the ``read_hdf5_button`` is pressed.'''
        self.read_hdf5_no_filter()

    def read_csv(self):
        '''Read the csv file and transform it to the hdf5 format.

        The output file has the same name as the input csv file
        with an extension hdf5.  Its path is stored in ``path_hdf5``.
        The csv file is processed in chunks of ``chunk_size`` rows; every
        chunk is written under the key 'all_data' with the column names
        'a' to 'f', so the csv file must provide exactly six columns.
        '''
        path_csv = self.file_csv
        # Following splitext splits the path into a pair (root, extension)
        self.path_hdf5 = os.path.splitext(path_csv)[0] + '.hdf5'

        # NOTE(review): no header=None is passed, so pandas takes the first
        # line after the skipped rows as column header, not as data -
        # confirm this is intended.
        chunks = pd.read_csv(path_csv,
                             delimiter=self.delimiter,
                             decimal=self.decimal,
                             skiprows=self.skip_rows,
                             chunksize=self.chunk_size)
        for i, chunk in enumerate(chunks):
            chunk_data_frame = pd.DataFrame(
                np.array(chunk), columns=['a', 'b', 'c', 'd', 'e', 'f'])
            # `key` is passed by keyword: positional use is deprecated
            # since pandas 2.1 and rejected by pandas 3.0.
            if i == 0:
                # The first chunk (re)creates the file.
                chunk_data_frame.to_hdf(self.path_hdf5,
                                        key='all_data',
                                        mode='w',
                                        format='table')
            else:
                # Later chunks are appended to the existing table.
                chunk_data_frame.to_hdf(self.path_hdf5,
                                        key='all_data',
                                        format='table',
                                        append=True)

    def read_hdf5_no_filter(self):
        '''Export the unfiltered hdf5 columns as npy files and plot them.

        Columns 'b' to 'f' of the table stored at ``path_hdf5`` are each
        saved, with one zero row prepended, into the folder ``NPY`` next
        to the input csv file.  The file names consist of the csv base
        name followed by a suffix naming the quantity.  Finally column 'b'
        (force) is plotted over column 'e' in a pyplot window.
        '''
        # (suffix of the npy file, column of the hdf5 table), in save order
        exports = (
            ('_Force_nofilter.npy', 'b'),
            ('_Displacement_machine_nofilter.npy', 'c'),
            ('_Displacement_sliding1_nofilter.npy', 'd'),
            ('_Displacement_sliding2_nofilter.npy', 'e'),
            ('_Displacement_crack1_nofilter.npy', 'f'),
        )

        # reading hdf files is really memory-expensive, therefore all
        # columns are fetched in one pass instead of one read per column
        table = pd.read_hdf(self.path_hdf5,
                            columns=[column for _, column in exports])

        # every quantity becomes a column vector starting with a zero row
        leading_zero = np.zeros((1, 1))
        series = {
            column: np.concatenate((leading_zero, np.array(table[[column]])))
            for _, column in exports
        }
        del table  # the frame is no longer needed once the arrays exist

        dir_path = os.path.dirname(self.file_csv)
        npy_folder_path = os.path.join(dir_path, 'NPY')
        # exist_ok avoids the race between a separate existence check and
        # the creation of the folder
        os.makedirs(npy_folder_path, exist_ok=True)

        file_name = os.path.splitext(os.path.basename(self.file_csv))[0]
        for suffix, column in exports:
            np.save(os.path.join(npy_folder_path, file_name + suffix),
                    series[column])

        # Defining chunk size for matplotlib points visualization
        mpl.rcParams['agg.path.chunksize'] = 50000

        plt.subplot(111)
        plt.xlabel('Displacement [mm]')
        plt.ylabel('kN')
        plt.title('original data', fontsize=20)
        plt.plot(series['e'], series['b'], 'k')
        plt.show()

    # Matplotlib figure shown in the view; created by _figure_default.
    figure = tr.Instance(Figure)

    def _figure_default(self):
        '''Build the initial figure: white background, tight layout.'''
        fig = Figure(facecolor='white')
        fig.set_tight_layout(True)
        return fig

    # Entries of the first csv line as filled in by _open_file_csv_fired.
    columns_headers_list = tr.List([])
    # Selected x / y column; the allowed values follow
    # columns_headers_list.
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    # NOTE(review): these two traits are not assigned anywhere in the
    # visible part of the class (read_hdf5_no_filter uses local variables
    # of the same names) - confirm they are still needed.
    npy_folder_path = tr.Str
    file_name = tr.Str

    # Button handled by _plot_fired.
    plot = tr.Button

    def _plot_fired(self):
        '''Plot the column selected in ``y_axis`` over the one in ``x_axis``.

        ``x_axis`` and ``y_axis`` hold entries of ``columns_headers_list``
        rather than integer positions, so they are translated into column
        indices of ``data`` before the array is sliced.  Nothing is
        plotted when a selection is not part of the header list.
        '''
        headers = list(self.columns_headers_list)
        if self.x_axis not in headers or self.y_axis not in headers:
            # e.g. before an input file has been selected
            print('no data columns available for plotting')
            return
        x_column = headers.index(self.x_axis)
        y_column = headers.index(self.y_axis)

        ax = self.figure.add_subplot(111)
        print('plotting figure')
        print(type(self.x_axis), type(self.y_axis))
        print(self.data[:, 1])
        print(self.data[:, x_column])
        print(self.data[:, y_column])
        ax.plot(self.data[:, x_column], self.data[:, y_column])

    # Default view of the tool.  A horizontal split holds the control
    # panel (input file selection, reader parameters, action buttons) and
    # the matplotlib figure rendered through MPLFigureEditor.
    traits_view = ui.View(ui.HSplit(
        ui.VSplit(
            ui.HGroup(ui.UItem('open_file_csv'),
                      ui.UItem('file_csv', style='readonly'),
                      label='Input data'),
            ui.VGroup(ui.Item('chunk_size'),
                      ui.Item('skip_rows'),
                      ui.Item('decimal'),
                      ui.Item('delimiter'),
                      label='Filter parameters'),
            ui.VGroup(
                ui.HGroup(ui.Item('read_loadtxt_button', show_label=False),
                          ui.Item('plot', show_label=False),
                          show_border=True),
                ui.HGroup(ui.Item('read_csv_button', show_label=False),
                          ui.Item('read_hdf5_button', show_label=False),
                          show_border=True))),
        ui.UItem('figure',
                 editor=MPLFigureEditor(),
                 resizable=True,
                 springy=True,
                 label='2d plots'),
    ),
                          resizable=True,
                          width=0.8,
                          height=0.6)