Exemplo n.º 1
0
class FileSelectedFrame(ta.HasTraits):
    """
    Frame listing the files currently selected.

    Entries are taken from the module-level ``select_files`` frame: either
    its single chosen ``file_name`` or every PDF in its ``file_directory``.
    The most recently added entry can be removed with "Undo Add".
    """

    file_list = ta.List(ta.Str, [])

    Add_File = ta.Button()
    Add_Folder = ta.Button()
    Undo_Add = ta.Button()

    view = tua.View(tua.Item('file_list'),
                    tua.Item('Add_File', show_label=False),
                    tua.Item('Add_Folder', show_label=False),
                    tua.Item('Undo_Add', show_label=False),
                    resizable=True)

    def _Add_File_fired(self):
        # ``select_files`` is a module-level frame bound elsewhere; it is only
        # read here, so no ``global`` declaration is required.
        self.file_list.append(select_files.file_name)

    def _Add_Folder_fired(self):
        self.file_list.extend(GetAllPDF(select_files.file_directory))

    def _Undo_Add_fired(self):
        # Removes only the last entry: after "Add Folder" that is the final
        # PDF of the folder, not the whole batch. Pressing the button on an
        # empty list is a no-op instead of an IndexError.
        if self.file_list:
            self.file_list.pop()
Exemplo n.º 2
0
class IntegrateArea(SpanSelectorInSignal1D):
    """Interactive tool that integrates a 1D signal over a selected span.

    Creating the tool plots the signal (unless a live plot already exists)
    and switches the span selector on. ``apply`` replaces the signal in
    place with its integral over the selected range.
    """
    integrate = t.Button()

    def __init__(self, signal, signal_range=None):
        """
        Parameters
        ----------
        signal : signal whose ``axes_manager.signal_dimension`` must be 1.
        signal_range : accepted for API compatibility; not used here.

        Raises
        ------
        SignalDimensionError
            If the signal dimension is not 1.
        """
        if signal.axes_manager.signal_dimension != 1:
            # Report the dimension from ``axes_manager`` (as checked above);
            # the former ``signal.axes`` lookup presumably raised
            # AttributeError and masked this error.
            raise SignalDimensionError(
                signal.axes_manager.signal_dimension, 1)

        self.signal = signal
        self.axis = self.signal.axes_manager.signal_axes[0]
        self.span_selector = None
        # Plot the signal unless an active plot already exists.
        if (not hasattr(self.signal, '_plot') or self.signal._plot is None or
                not self.signal._plot.is_active):
            self.signal.plot()
        self.span_selector_switch(on=True)

    def apply(self):
        """Replace ``self.signal`` in place with its integral over the span."""
        integrated_spectrum = self.signal._integrate_in_range_commandline(
            signal_range=(
                self.ss_left_value,
                self.ss_right_value)
        )
        # Close the current plot only when the integrated result has a
        # non-empty ``axes_manager.shape``; it is re-plotted after the swap.
        plot = False
        if self.signal._plot and integrated_spectrum.axes_manager.shape != ():
            self.signal._plot.close()
            plot = True
        # Re-initialise the existing object so callers holding a reference to
        # ``signal`` see the integrated data.
        self.signal.__init__(**integrated_spectrum._to_dictionary())
        self.signal._assign_subclass()
        self.signal.axes_manager.set_signal_dimension(0)

        if plot:
            self.signal.plot()
Exemplo n.º 3
0
class Something(tr.HasTraits):
    """Demo panel: a file picker plus numeric inputs that fire an event."""

    txt_file_name = tr.File

    openTxt = tr.Button('Open...')

    # Traits tagged ``input=True`` are watched by the '+input' listener below.
    a = tr.Int(20, auto_set=False, enter_set=True,
               input=True)

    b = tr.Float(20, auto_set=False, enter_set=True,
                 input=True)

    # Fired whenever any '+input' trait changes.
    input_event = tr.Event

    @tr.on_trait_change('+input')
    def _handle_input_change(self):
        print('some input parameter changed')
        self.input_event = True

    def _input_event_fired(self):
        # Static handler for ``input_event``. The previous name
        # (``_some_event_changed``) matched no trait, so it never ran.
        print('input happend')

    # Old name kept so any direct callers keep working.
    _some_event_changed = _input_event_fired

    def _openTxt_fired(self):
        print('do something')
        print(self.txt_file_name)

    traits_view = ui.View(
        ui.VGroup(
            ui.HGroup(
                ui.Item('openTxt', show_label=False),
                ui.Item('txt_file_name', width=200),
                ui.Item('a')
            ),
        )
    )
Exemplo n.º 4
0
class EntryBlock(traits.HasTraits):
    """Editable text block whose contents can be appended to a comments file.

    ``commentFile`` is not a declared trait: the owner must assign it on the
    instance before the commit button is used. Each commit appends a block
    delimited by ``__<fieldName>__<start>`` / ``__<fieldName>__<end>``,
    containing a timestamp and the current ``textBlock``.
    """

    fieldName = traits.String("fieldName", desc="describes what the information to be entered in the text block is referring to")
    textBlock = traits.String()
    commitButton = traits.Button("save", desc="commit information in text block to logFile")

    traits_view = traitsui.View(traitsui.VGroup(
        traitsui.Item("fieldName", show_label=False, style="readonly"),
        traitsui.Item("textBlock", show_label=False, style="custom"),
        traitsui.Item("commitButton", show_label=False),
        show_border=True, label="information"
    ))

    def __init__(self, **traitsDict):
        """Forward keyword trait values (e.g. ``fieldName``) to HasTraits."""
        super(EntryBlock, self).__init__(**traitsDict)

    def _commitButton_fired(self):
        """Append the timestamped, delimited text block to ``commentFile``."""
        commentFile = getattr(self, "commentFile", None)
        if not commentFile:
            # Previously an AttributeError; report clearly and do nothing.
            logger.error("cannot save %s info: no comment file has been set", self.fieldName)
            return
        logger.info("saving %s info starting", self.fieldName)
        timeStamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        blockDelimiterStart = "__" + self.fieldName + "__<start>"
        blockDelimiterEnd = "__" + self.fieldName + "__<end>"
        # Produces "\n<start>\n<time>\n<text>\n<end>\n".
        fullString = "\n".join(["", blockDelimiterStart, timeStamp, self.textBlock, blockDelimiterEnd, ""])
        with open(commentFile, "a") as writeFile:
            writeFile.write(fullString)
        logger.info("saving %s info finished", self.fieldName)

    def clearTextBlock(self):
        """Empty the editable text."""
        self.textBlock = ""
Exemplo n.º 5
0
class FileFrame(ta.HasTraits):
    """
    Frame for selecting a PDF file or a whole folder of PDFs.

    ``file_name`` and ``file_directory`` are kept in sync. "Add File" and
    "Add Folder" push into the module-level ``files_selected`` frame.
    """
    # NOTE(review): hard-coded, user-specific defaults; consider making these
    # configurable.
    def_file = '/home/jackdra/LQCD/Scripts/EDM_paper/graphs/FF/FullFFFit/Neutron_ContFit_a.pdf'
    def_folder = '/home/jackdra/LQCD/Scripts/EDM_paper/graphs/FF/FullFFFit/'
    file_directory = ta.Directory(def_folder)
    file_name = ta.File(def_file, filter=['*.pdf'])

    Add_File = ta.Button()
    Add_Folder = ta.Button()

    view = tua.View(
        tua.HSplit(
            tua.Item('file_directory', style='custom', springy=True),
            tua.Item('file_name', style='custom', springy=True),
            tua.VGroup(tua.Item('file_directory', springy=True),
                       tua.Item('file_name', springy=True),
                       tua.Item('Add_File', show_label=False),
                       tua.Item('Add_Folder', show_label=False)
                       )),
        resizable=True,
        height=1000,
        width=1500)

    def _file_name_changed(self):
        # Point the directory at the chosen file's folder (with trailing '/').
        self.file_directory = '/'.join(self.file_name.split('/')[:-1]) + '/'

    def _file_directory_changed(self):
        pdfs = GetAllPDF(self.file_directory)
        # Keep the current file if it is already in this directory, e.g. when
        # the directory changed *because* the user picked that file.
        # Otherwise default to the first PDF found. (Assumes GetAllPDF returns
        # paths in the same form as ``file_name``; if not, this falls back
        # to always choosing the first PDF, as before.)
        if pdfs and self.file_name not in pdfs:
            self.file_name = pdfs[0]

    def _Add_File_fired(self):
        # ``files_selected`` is a module-level frame, only read here.
        files_selected.file_list.append(self.file_name)

    def _Add_Folder_fired(self):
        files_selected.file_list.extend(GetAllPDF(self.file_directory))
Exemplo n.º 6
0
class Librarian(traits.HasTraits):
    """Librarian provides a way of writing useful information into the
    log folder for eagle logs. It is designed to make the information inside
    an eagle log easier to come back to. It mainly writes default strings into
    the comments file in the log folder"""

    logType = traits.Enum("important", "debug", "calibration")
    typeCommitButton = traits.Button("save")
    # NOTE(review): these sub-panels are created once, when the class is
    # defined, so all Librarian instances appear to share them. __init__ then
    # re-points their file paths. Confirm whether per-instance objects are
    # needed.
    axisList = AxisSelector()
    purposeBlock = EntryBlock(fieldName="What is the purpose of this log?")
    explanationBlock = EntryBlock(fieldName = "Explain what the data shows (important parameters that change, does it make sense etc.)?")
    additionalComments = EntryBlock(fieldName = "Anything Else?")

    traits_view = traitsui.View(
        traitsui.VGroup(
            traitsui.Item("logFolder", show_label=False, style="readonly"),
            traitsui.HGroup(traitsui.Item("logType", show_label=False), traitsui.Item("typeCommitButton", show_label=False)),
            traitsui.Item("axisList", show_label=False, editor=traitsui.InstanceEditor(), style='custom'),
            traitsui.Item("purposeBlock", show_label=False, editor=traitsui.InstanceEditor(), style='custom'),
            traitsui.Item("explanationBlock", show_label=False, editor=traitsui.InstanceEditor(), style='custom'),
            traitsui.Item("additionalComments", show_label=False, editor=traitsui.InstanceEditor(), style='custom')
        ), resizable=True, kind="live"
    )

    def __init__(self, **traitsDict):
        """Librarian object requires the log folder it is referring to. If a .csv
        file is given as logFolder argument it will use parent folder as the
        logFolder"""
        super(Librarian, self).__init__(**traitsDict)
        if os.path.isfile(self.logFolder):
            self.logFolder = os.path.split(self.logFolder)[0]
        else:
            logger.debug("found these in %s: %s", self.logFolder, os.listdir(self.logFolder))

        # The log CSV is named after its folder: <folder>/<folder name>.csv
        self.logFile = os.path.join(self.logFolder, os.path.split(self.logFolder)[1] + ".csv")
        self.commentFile = os.path.join(self.logFolder, "comments.txt")
        # Point every sub-panel at this log's files.
        self.axisList.commentFile = self.commentFile
        self.axisList.logFile = self.logFile
        self.purposeBlock.commentFile = self.commentFile
        self.explanationBlock.commentFile = self.commentFile
        self.additionalComments.commentFile = self.commentFile

    def _typeCommitButton_fired(self):
        """Append the selected log type as a delimited, timestamped block."""
        # The log messages previously said "axes" (copied from AxisSelector).
        logger.info("saving log type info starting")
        timeStamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        blockDelimiterStart = "__Log Type__<start>"
        blockDelimiterEnd = "__Log Type__<end>"
        fullString = "\n".join(["", blockDelimiterStart, timeStamp, self.logType, blockDelimiterEnd, ""])
        with open(self.commentFile, "a") as writeFile:
            writeFile.write(fullString)
        logger.info("saving log type info finished")
Exemplo n.º 7
0
class BeginFrame(ta.HasTraits):
    """Launcher window with a single Start button that opens the work frames."""

    Start = ta.Button()

    view = tua.View(tua.Item('Start', show_label=False))

    def _Start_fired(self):
        # The frames reach each other through these module-level names, so
        # both are rebound before any window is shown.
        global files_selected, select_files
        files_selected = FileSelectedFrame()
        select_files = FileFrame()
        plot_attributes = PlotAtributes()
        # Shown one after another, in this order.
        for frame in (files_selected, select_files, plot_attributes):
            frame.configure_traits()
Exemplo n.º 8
0
class ColumnEditor(traits.HasTraits):
    """ Define the main column Editor class. Complex part is handled
    by get columns function below that defines the view and editor."""
    columns = traits.List()
    numberOfColumns = traits.Int()
    selectAllButton = traits.Button('Select/Deselect All')
    # True -> the next button press clears the selection;
    # False -> the next press selects every column.
    selectDeselectBool = traits.Bool(True)

    def _selectAllButton_fired(self):
        """Toggle between selecting no columns and selecting all of them."""
        if self.selectDeselectBool:
            self.columns = []
        else:
            # Build a real list: on Python 3 ``range`` is not a list, and the
            # List trait rejects it.
            self.columns = list(range(self.numberOfColumns))
        self.selectDeselectBool = not self.selectDeselectBool
Exemplo n.º 9
0
class AxisSelector(traits.HasTraits):
    """here we select what axes the user should use when plotting this data

    The choices come from the header row of the CSV ``logFile``. ``logFile``
    and ``commentFile`` are not declared traits: the owner must assign them
    on the instance. The commit button appends the chosen axes to
    ``commentFile`` as a delimited, timestamped block.
    """
    masterList = traits.List
    masterListWithNone = traits.List
    xAxis = traits.Enum(values="masterList")
    yAxis = traits.Enum(values="masterList")
    series = traits.Enum(values="masterListWithNone")
    commitButton = traits.Button("save", desc="commit information in text block to logFile")

    traits_view = traitsui.View(traitsui.VGroup(
        traitsui.Item("xAxis", label="x axis"),
        traitsui.Item("yAxis", label="y axis"),
        traitsui.Item("series", label="series"),
        traitsui.Item("commitButton", show_label=False),
        show_border=True, label="axes selection"))

    def __init__(self, **traitsDict):
        """allows user to select which axes are useful for plotting in this log"""
        super(AxisSelector, self).__init__(**traitsDict)

    def _masterList_default(self):
        """Return the header row of ``logFile`` as the plottable column names.

        Returns an empty list when ``logFile`` is unset, does not exist,
        cannot be read, or is empty.
        """
        logger.info("updating master list of axis choices")
        logFile = getattr(self, "logFile", None)
        # Replaces a Py2-only ``print`` that mislabelled this as the comment file.
        logger.debug("log file = %s", logFile)
        if not logFile or not os.path.exists(logFile):
            return []
        try:
            with open(logFile) as csvfile:
                # next() works on csv readers in both Python 2 and 3;
                # StopIteration means the file has no rows at all.
                return next(csv.reader(csvfile))
        except (IOError, StopIteration):
            return []

    def _masterListWithNone_default(self):
        """Axis choices for ``series``: "None" followed by the master list."""
        # Reuse the (cached) master list instead of re-reading the CSV file.
        return ["None"] + list(self.masterList)

    def _commitButton_fired(self):
        """Append the current axis selection to ``commentFile``."""
        commentFile = getattr(self, "commentFile", None)
        if not commentFile:
            logger.error("cannot save axes info: no comment file has been set")
            return
        logger.info("saving axes info starting")
        timeStamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        blockDelimiterStart = "__Axes Selection__<start>"
        blockDelimiterEnd = "__Axes Selection__<end>"
        textBlock = "xAxis = %s\nyAxis = %s\n series = %s" % (self.xAxis, self.yAxis, self.series)
        fullString = "\n".join(["", blockDelimiterStart, timeStamp, textBlock, blockDelimiterEnd, ""])
        with open(commentFile, "a") as writeFile:
            writeFile.write(fullString)
        logger.info("saving axes info finished")
Exemplo n.º 10
0
class ChangeAxis(tapi.HasStrictTraits):
    """Panel with one button that opens an x-axis chooser for ``Plot_Data``."""

    Change_X_Axis = tapi.Button('Change X-axis')
    Plot_Data = tapi.Instance(Plot2D)
    headers = tapi.List(tapi.Str)

    # NOTE(review): enabled_when refers to a module-level flag; confirm that
    # traitsui's evaluation context can actually see this name.
    view = tuiapi.View(
        tuiapi.Item('Change_X_Axis',
                    enabled_when='Change_X_Axis_Enabled==True',
                    show_label=False))

    def __init__(self, **traits):
        super(ChangeAxis, self).__init__(**traits)

    def _Change_X_Axis_fired(self):
        # Clear the module-level flag that the button's enabled_when checks.
        # Presumably ChangeAxisHandler sets it back; verify.
        global Change_X_Axis_Enabled
        Change_X_Axis_Enabled = False
        chooser = SingleSelect(choices=self.headers, plot=self.Plot_Data)
        chooser.configure_traits(handler=ChangeAxisHandler())
Exemplo n.º 11
0
class newDarkPictureDialog(traits.HasTraits):
    """Dialog for choosing source images and an output path for a dark picture.

    The "Auto Filename" button builds an output path from the date, camera,
    interval and temperature fields.
    """
    # Alternative network location:
    # pathSourceImages = traits.Directory( os.path.join("\\\\192.168.16.71","Humphry","Data","eagleLogs") )
    pathSourceImages = traits.Directory( eagleLogsFolder )
    pathNewDarkPicture = traits.File( defaultDarkPictureFilename, editor = traitsui.FileEditor(dialog_style='save') )
    cancelButton = traitsui.Action(name = 'Cancel', action = '_cancel')
    okButton = traitsui.Action(name = 'Calculate dark picture', action = '_ok')

    date = traits.String( time.strftime('%Y %m %d'), desc='Date' )
    camera = traits.String( "Andor1" )
    interval = traits.Float(0.003)
    filterCountLi = traits.Int(1)
    temperature = traits.Float(-40.0)
    autoFilename = traits.Button('Auto Filename')

    traits_view = traitsui.View(
        traitsui.Group(
            traitsui.Item('pathSourceImages'),
            traitsui.Group(
                traitsui.Item('date'),
                traitsui.Item('camera'),
                traitsui.Item('interval'),
                traitsui.Item('temperature'),
                traitsui.Item('autoFilename'),
                label='Auto Filename', show_border=True
            ),
            traitsui.Item('pathNewDarkPicture')
        ),
        buttons = [cancelButton, okButton],
        handler = newDarkPictureDialogHandler()
    )

    def _autoFilename_fired(self):
        """Set ``pathNewDarkPicture`` to an auto-generated ``.npy`` path.

        The name encodes date, camera, interval and temperature, with every
        '.' replaced by '_'. The file goes in a per-camera sub-directory of
        ``defaultDarkPictureFilename``, which is created if missing.
        """
        filename = self.date + ' - dark ' + self.camera + ' - '
        filename += 'interval {} '.format(self.interval)
        filename += 'temperature {} '.format(self.temperature)
        # Dots in the float values would otherwise look like extensions.
        filename = filename.replace('.', '_')
        filename += '.npy'
        # NOTE(review): joins a camera folder onto the *default filename*
        # setting; confirm that this value is really a directory.
        path = os.path.join( defaultDarkPictureFilename, self.camera)
        if not os.path.exists( path ):
            # makedirs also creates missing parent directories, which
            # os.mkdir would fail on.
            os.makedirs( path )
        self.pathNewDarkPicture = os.path.join( path, filename )
Exemplo n.º 12
0
class IntegrateArea(SpanSelectorInSpectrum):
    """Interactive tool that integrates a spectrum over a selected span.

    Creating the tool plots the signal (unless an active plot exists) and
    switches the span selector on. ``apply`` replaces the signal in place
    with its integral over the selected range.
    """
    integrate = t.Button()

    view = tu.View(
        buttons=[OKButton, CancelButton],
        title='Integrate in range',
        handler=SpanSelectorInSpectrumHandler,
    )

    def __init__(self, signal, signal_range=None):
        """
        Parameters
        ----------
        signal : signal whose ``axes_manager.signal_dimension`` must be 1.
        signal_range : accepted for API compatibility; not used here.

        Raises
        ------
        SignalDimensionError
            If the signal dimension is not 1.
        """
        if signal.axes_manager.signal_dimension != 1:
            # Report the dimension from ``axes_manager`` (as checked above);
            # the former ``signal.axes`` lookup presumably raised
            # AttributeError and masked this error.
            raise SignalDimensionError(
                signal.axes_manager.signal_dimension, 1)

        self.signal = signal
        self.span_selector = None
        # Plot when there is no plot attribute, no plot, or an inactive one.
        if (not hasattr(self.signal, '_plot') or
                self.signal._plot is None or
                self.signal._plot.is_active() is False):
            self.signal.plot()
        self.span_selector_switch(on=True)

    def apply(self):
        """Replace ``self.signal`` in place with its integral over the span."""
        integrated_spectrum = self.signal._integrate_in_range_commandline(
            signal_range=(
                self.ss_left_value,
                self.ss_right_value)
        )
        # Close the current plot only when the integrated result has a
        # non-empty ``axes_manager.shape``; it is re-plotted after the swap.
        plot = False
        if self.signal._plot and integrated_spectrum.axes_manager.shape != ():
            self.signal._plot.close()
            plot = True
        # Re-initialise the existing object so references to ``signal`` see
        # the integrated data.
        self.signal.__init__(**integrated_spectrum._to_dictionary())
        self.signal._assign_subclass()
        self.signal.axes_manager.set_signal_dimension(0)

        if plot:
            self.signal.plot()
Exemplo n.º 13
0
class BeginFrame(ta.HasTraits):
    """
    starting frame, to select what type of thing you want to do
    options are: load the chosen .xml results file into the FlowOp, Alpha
    or Plot window.
    """

    file_name = ta.File(Load_Prev_File(), filter=['*.xml'])
    Alpha_window = AlphaFrame()
    FlowOp_window = FlowOpFrame()
    Plot_window = PlotFrame()
    Load_FlowOp = ta.Button()
    Load_Alpha = ta.Button()
    Load_Plot = ta.Button()

    view = tua.View('_',
                    tua.Item('file_name'),
                    tua.Item('Load_FlowOp', show_label=False),
                    tua.Item('Load_Alpha', show_label=False),
                    tua.Item('Load_Plot', show_label=False),
                    buttons=['OK'],
                    resizable=True)

    @staticmethod
    def _tex_path(xml_path):
        """Return *xml_path* with a trailing '.xml' swapped for '.tex'.

        Only the extension is replaced, so directories whose names contain
        '.xml' are left intact. Names without an .xml extension keep the old
        replace-everywhere behaviour.
        """
        root, ext = os.path.splitext(xml_path)
        if ext == '.xml':
            return root + '.tex'
        return xml_path.replace('.xml', '.tex')

    def _open_results_window(self, window):
        """Parse ``file_name`` into the global ``xml_data`` and show *window*.

        On success ``xml_data`` holds the parsed file, with
        ``['Results']['save_file']`` set to the matching .tex path. *window*
        is then shown, and the file name is stored via ``Save_Prev_File``.
        If the file does not exist, ``xml_data`` is reset to
        ``{'save_file': file_name}`` and a message is printed.
        """
        global xml_data
        xml_data = {'save_file': self.file_name}
        if not os.path.isfile(self.file_name):
            print('File not found:')
            print(xml_data['save_file'])
            return
        xml_data = ReadXml(self.file_name)
        xml_data['Results']['save_file'] = self._tex_path(self.file_name)
        window.configure_traits()
        Save_Prev_File(str(self.file_name))

    def _Load_Alpha_fired(self):
        self._open_results_window(self.Alpha_window)

    def _Load_Plot_fired(self):
        self._open_results_window(self.Plot_window)

    def _Load_FlowOp_fired(self):
        self._open_results_window(self.FlowOp_window)
Exemplo n.º 14
0
class FlowOpFrame(ta.HasTraits):
    """Window that tabulates one fit parameter from the global ``xml_data``.

    ``Apply`` does three things. It collects ``this_parameter`` for every fit
    range under ``Results/Prop_Fit/<this_flow_time>/<this_t_sum>``, plus the
    integrated ``Results/Chi_boot/<this_flow_time>`` result when present. It
    stores the table in the global ``latex_data``. It writes the table as
    LaTeX to ``Results/save_file``.
    """
    output_file = ta.Str('')
    this_flow_time = ta.Str('t_f6.01')
    this_t_sum = ta.Str('ts63')
    this_parameter = ta.Str('A')
    do_chi = ta.Bool(True)
    do_chi_err = ta.Bool(True)  # shown in the view; not read by Apply
    Load = ta.Button()
    Show = ta.Button()
    Show_latex = ta.Button()
    Apply = ta.Button()

    view = tua.View(tua.Item('output_file'),
                    tua.Item('this_flow_time'),
                    tua.Item('this_t_sum'),
                    tua.Item('this_parameter'),
                    tua.Item('do_chi'),
                    tua.Item('do_chi_err', enabled_when='do_chi'),
                    tua.Item('Load', show_label=False),
                    tua.Item('Show', show_label=False),
                    tua.Item('Show_latex', show_label=False),
                    tua.Item('Apply', show_label=False),
                    buttons=['OK'],
                    resizable=True)

    def _Load_fired(self):
        global xml_data
        if xml_data is None:
            raise EnvironmentError('plot data has not been loaded')
        self.output_file = xml_data['Results']['save_file']

    def _Show_fired(self):
        global xml_data
        print(unparse(xml_data, pretty=True))

    def _Show_latex_fired(self):
        global latex_data
        if latex_data is None:
            print('latex data not generated yet')
        else:
            print(latex_data)

    @staticmethod
    def _format_value(entry):
        """Format an ``{'Avg', 'Std'}`` mapping as LaTeX; return anything else as-is."""
        if isinstance(entry, dict):  # also covers OrderedDict
            return MakeValAndErr(entry['Avg'], entry['Std'], latex=True)
        return entry

    def _integrated_rows(self, results):
        """Return ``[(label, value, chi)]`` for the integrated 'chiint' result.

        Returns ``[]`` when there is no 'Chi_boot' entry, and ``None``
        (abort) when the selected flow time is missing from it.
        """
        if 'Chi_boot' not in results:
            print('Integrated FlowOp result not in keys')
            return []
        by_flow_time = results['Chi_boot']
        if self.this_flow_time not in by_flow_time:
            print(self.this_flow_time, ' not in flow time keys ')
            print(list(by_flow_time.keys()))
            return None
        value = self._format_value(by_flow_time[self.this_flow_time])
        return [('chiint', value, 'N.A.')]

    def _fit_rows(self, results):
        """Return sorted ``[(fit_range, value, chi)]`` rows for the parameter.

        ``chi`` is ``None`` for fits without a 'Chi_pow_2_pdf' entry. Returns
        ``None`` (abort) when 'Prop_Fit' or the selected flow time / t_sum
        is missing.
        """
        if 'Prop_Fit' not in results:
            print('Fits not in keys')
            print(list(results.keys()))
            return None
        fits = results['Prop_Fit']
        if self.this_flow_time not in fits:
            print(self.this_flow_time, ' not in flow time keys ')
            print(list(fits.keys()))
            return None
        fits = fits[self.this_flow_time]
        if self.this_t_sum not in fits:
            print(self.this_t_sum, ' not in sum source time keys ')
            print(list(fits.keys()))
            return None
        fits = fits[self.this_t_sum]

        keyed_rows = []
        for fit_range, fit_data in fits.items():
            if self.this_parameter not in fit_data:
                continue
            value = self._format_value(fit_data[self.this_parameter])
            chi = None
            if 'Chi_pow_2_pdf' in fit_data:
                chi = self._format_value(fit_data['Chi_pow_2_pdf'])
            # Keys look like 'fitr<left>-<right>'. Sort by 1/right, then
            # left, then the key itself (decreasing right edge when edges
            # are positive).
            left, right = map(float, fit_range.replace('fitr', '').split('-'))
            keyed_rows.append(((1 / right, left, fit_range), value, chi))
        keyed_rows.sort(key=lambda row: row[0])
        return [(sort_key[2], value, chi) for sort_key, value, chi in keyed_rows]

    def _Apply_fired(self):
        global xml_data, latex_data
        results = xml_data['Results']
        head_rows = self._integrated_rows(results)
        if head_rows is None:
            return
        fit_rows = self._fit_rows(results)
        if fit_rows is None:
            return
        # Rows are kept as (label, value, chi) tuples so the columns can
        # never drift out of alignment. Previously a missing 'Chi_boot' or a
        # fit without chi^2 shifted rows against each other.
        rows = head_rows + fit_rows
        if not rows:
            print(self.this_parameter, ' not found in any fit')
            return
        index = [label for label, _value, _chi in rows]
        values = [value for _label, value, _chi in rows]
        chis = [chi for _label, _value, chi in rows]

        latex_data = pa.DataFrame(index=index)
        latex_data.loc[:, self.this_parameter + ' Avg'] = pa.Series(values,
                                                                    index=index)
        # Only add the chi^2/pdf column when every row has a value for it.
        if self.do_chi and all(chi is not None for chi in chis):
            latex_data.loc[:, self.this_parameter + ' chi2pdf'] = pa.Series(
                chis, index=index)
        with open(results['save_file'], 'w') as f:
            f.write(FormatLatex(latex_data.to_latex(escape=False)))
Exemplo n.º 15
0
class PlotFrame(ta.HasTraits):
    """Window that tabulates per-configuration results from ``xml_data``.

    ``Apply`` builds one table column per entry of ``Results``. Bookkeeping
    keys such as window sizes, 'Info' and 'save_file' are skipped. Each
    column is filled from the entry's ``this_data_type`` section. The table
    is stored in the global ``latex_data`` (optionally transposed) and
    written as LaTeX to ``Results/save_file``.
    """
    output_file = ta.Str('')
    this_data_type = ta.Enum(['data', 'fit_parameters'])
    transpose_table = ta.Bool(False)
    include_chi = ta.Bool(True)
    fmt_latex = ta.Bool(True)
    Load = ta.Button()
    Show = ta.Button()
    Show_latex = ta.Button()
    Apply = ta.Button()

    view = tua.View(tua.Item('output_file'),
                    tua.Item('this_data_type'),
                    tua.Item('transpose_table'),
                    tua.Item('include_chi'),
                    tua.Item('fmt_latex'),
                    tua.Item('Load', show_label=False),
                    tua.Item('Show', show_label=False),
                    tua.Item('Show_latex', show_label=False),
                    tua.Item('Apply', show_label=False),
                    buttons=['OK'],
                    resizable=True)

    def _Load_fired(self):
        global xml_data
        if xml_data is None:
            raise EnvironmentError('plot data has not been loaded')
        self.output_file = xml_data['Results']['save_file']

    def _Show_fired(self):
        global xml_data
        print(unparse(xml_data, pretty=True))

    def _Show_latex_fired(self):
        global latex_data
        if latex_data is None:
            print('latex data not generated yet')
        else:
            print(latex_data.to_string())

    def _format_pair(self, avg, std):
        """Return ``avg`` +- ``std`` as a $...$ cell via MakeValAndErr."""
        return '$' + MakeValAndErr(avg, std, Dec=2,
                                   latex=self.fmt_latex) + '$'

    @staticmethod
    def _format_scalar(value):
        """Return a single value as a $...$ cell with 3 decimals if numeric."""
        try:
            return f'${float(value):.3f}$'
        except (TypeError, ValueError):
            # Non-numeric (e.g. text from the xml): show it verbatim.
            return f'${value}$'

    def _config_column(self, cfg_data):
        """Build the table column (a Series of strings) for one configuration."""
        column = pa.Series(dtype=object)
        if self.this_data_type not in cfg_data.keys():
            return column
        if (self.this_data_type == 'fit_parameters' and self.include_chi
                and 'chi_pow_2_pdf' in cfg_data.keys()):
            chi = cfg_data['chi_pow_2_pdf']
            column[r'$\chi^{2}_{pdf}$'] = self._format_pair(chi['Avg'],
                                                            chi['Std'])
        for data_key, data_data in cfg_data[self.this_data_type].items():
            if isinstance(data_data, dict):
                column[data_key] = self._format_pair(data_data['Avg'],
                                                     data_data['Std'])
            elif isinstance(data_data, (list, tuple, np.ndarray)):
                # [avg, std] pairs or single-element sequences; other
                # lengths are skipped.
                if len(data_data) == 2:
                    column[data_key] = self._format_pair(data_data[0],
                                                         data_data[1])
                elif len(data_data) == 1:
                    column[data_key] = self._format_scalar(data_data[0])
            else:
                # Plain scalar. The old code indexed ``data_data[0]`` here,
                # which fails for numbers.
                column[data_key] = self._format_scalar(data_data)
        return column

    def _Apply_fired(self):
        global xml_data, latex_data
        results = xml_data['Results']
        table_data = pa.DataFrame()
        for cfg_key, cfg_data in results.items():
            if cfg_key in ['window_size_x', 'window_size_y', 'Info',
                           'save_file']:
                continue
            column = self._config_column(cfg_data)
            if len(column) > 0:
                table_data[cfg_key] = column
        latex_data = table_data
        if self.transpose_table:
            latex_data = latex_data.transpose()
        with open(results['save_file'], 'w') as f:
            f.write(FormatLatex(latex_data.to_latex(escape=False)))
Exemplo n.º 16
0
class CalibrationAlignmentWindow(Widget):
    """Sub-panel for interactively aligning reconstructed 3D data.

    Editing the ``talign.Alignment`` parameters transforms the displayed
    points. The buttons save the alignment as JSON, or save an aligned copy
    of the reconstructor as an .xml file or as a directory.
    """
    params = traits.Instance( talign.Alignment )
    save_align_json = traits.Button(label='Save alignment data as .json file')
    save_new_cal = traits.Button(label='Save new calibration as .xml file')
    save_new_cal_dir = traits.Button(label='Save new calibration as directory')

    traits_view = View( Group( ( Item( 'params', style='custom',
                                       show_label=False),
                                 Item( 'save_align_json', show_label = False ),
                                 Item( 'save_new_cal', show_label = False ),
                                 Item( 'save_new_cal_dir', show_label = False ),
                                 )),
                        title = 'Calibration Alignment',
                        )
    orig_data_verts = traits.Instance(object)
    orig_data_speeds = traits.Instance(object)
    reconstructor = traits.Instance(object)
    viewed_data = traits.Instance(tvtk.DataSet)
    source = traits.Instance(VTKDataSource)

    def __init__(self, parent, **traits):
        """Build the panel inside *parent* and track alignment-parameter edits."""
        super(CalibrationAlignmentWindow, self).__init__(**traits)
        self.params = talign.Alignment()

        self.control = self.edit_traits(parent=parent,
                                        kind='subpanel',
                                        context={'h1':self.params, # XXX ??? 'h1' purpose unclear
                                                 'object':self},
                                        ).control
        # Re-render whenever any alignment parameter changes.
        self.params.on_trait_change( self._params_changed )

    def set_data(self, orig_data_verts, orig_data_speeds, reconstructor, align_json):
        """Load the data to align and, optionally, a saved alignment.

        orig_data_verts: 2D array with 4 rows (homogeneous coordinates), one
            column per point.
        orig_data_speeds: 1D array with one value per point, shown as the
            'speed' scalars.
        reconstructor: object providing ``get_aligned_copy``.
        align_json: path to a JSON file with "s", "t" and "R" entries, or a
            falsy value to keep the current parameters.
        """
        # Check shapes before touching any state so a bad call leaves the
        # window unchanged.
        assert orig_data_verts.ndim == 2
        assert orig_data_speeds.ndim == 1
        assert orig_data_verts.shape[0] == 4
        assert orig_data_verts.shape[1] == orig_data_speeds.shape[0]

        self.orig_data_verts = orig_data_verts
        self.orig_data_speeds = orig_data_speeds
        self.reconstructor = reconstructor

        # from mayavi2-2.0.2a1/enthought.tvtk/enthought/tvtk/tools/mlab.py
        #   Glyphs.__init__ : one single-vertex "poly" per point
        points = hom2vtk(orig_data_verts)
        polys = numpy.arange(0, len(points), 1, 'l')
        polys = numpy.reshape(polys, (len(points), 1))
        pd = tvtk.PolyData(points=points, polys=polys)
        pd.point_data.scalars = orig_data_speeds
        pd.point_data.scalars.name = 'speed'
        self.viewed_data = pd
        self.source = VTKDataSource(data=self.viewed_data, name='aligned data')

        if align_json:
            # Context manager so the file handle is closed promptly.
            with open(align_json) as fd:
                j = json.load(fd)
            self.params.s = j["s"]
            for i, k in enumerate(("tx", "ty", "tz")):
                setattr(self.params, k, j["t"][i])

            R = np.array(j["R"])
            rx, ry, rz = np.rad2deg(euler_from_matrix(R, 'sxyz'))

            self.params.r_x = rx
            self.params.r_y = ry
            self.params.r_z = rz

            self._params_changed()

    def _params_changed(self):
        """Re-apply the alignment matrix to the original points and re-render."""
        if self.orig_data_verts is None or self.viewed_data is None:
            # no data set yet
            return
        M = self.params.get_matrix()
        verts = np.dot(M, self.orig_data_verts)
        self.viewed_data.points = hom2vtk(verts)
        self.source.update()
        self.source.render()

    def get_aligned_R(self):
        """Return an aligned copy of the reconstructor for the current parameters."""
        M = self.params.get_matrix()
        R = self.reconstructor
        alignedR = R.get_aligned_copy(M, update_water_boundary=False)
        return alignedR

    def _save_align_json_fired(self):
        wildcard = 'JSON files (*.json)|*.json|' + FileDialog.WILDCARD_ALL
        dialog = FileDialog(#parent=self.window.control,
                            title='Save alignment as .json file',
                            action='save as', wildcard=wildcard
                            )
        if dialog.open() == OK:
            buf = json.dumps(self.params.as_dict())
            with open(dialog.path, mode='w') as fd:
                fd.write(buf)

    def _save_new_cal_fired(self):
        wildcard = 'XML files (*.xml)|*.xml|' + FileDialog.WILDCARD_ALL
        dialog = FileDialog(#parent=self.window.control,
                            title='Save calibration .xml file',
                            action='save as', wildcard=wildcard
                            )
        if dialog.open() == OK:
            alignedR = self.get_aligned_R()
            alignedR.save_to_xml_filename(dialog.path)

    def _save_new_cal_dir_fired(self):
        dialog = FileDialog(#parent=self.window.control,
                            title='Save calibration directory',
                            action='save as',
                            )
        if dialog.open() == OK:
            alignedR = self.get_aligned_R()
            alignedR.save_to_files_in_new_directory(dialog.path)
Exemplo n.º 17
0
class SpikesRemoval(SpanSelectorInSpectrum):
    """Interactive tool that finds spikes and heals them, one spectrum at a time.

    A spike is reported where the first difference of the current spectrum
    reaches ``threshold``. The channels around it (or the range chosen with
    the span selector) are replaced by an interpolation of the neighbouring
    channels, or by a flat fill near the spectrum edges, optionally with noise
    added. A preview line is drawn before the change is applied.
    """
    interpolator_kind = t.Enum(
        'Linear',
        'Spline',
        default='Linear')
    threshold = t.Float()
    show_derivative_histogram = t.Button()
    spline_order = t.Range(1, 10, 3)
    interpolator = None
    default_spike_width = t.Int(5)
    index = t.Int(0)
    add_noise = t.Bool(True,
                       desc="Add noise to the healed portion of the "
                       "spectrum. Use the noise properties "
                       "defined in metadata if present, otherwise "
                       "it defaults to shot noise.")
    view = tu.View(tu.Group(
        tu.Group(
            tu.Item('show_derivative_histogram', show_label=False),
            'threshold',
            show_border=True,),
        tu.Group(
            'add_noise',
            'interpolator_kind',
            'default_spike_width',
            tu.Group(
                'spline_order',
                visible_when='interpolator_kind == \'Spline\''),
            show_border=True,
            label='Advanced settings'),
    ),
        buttons=[OKButton,
                 OurPreviousButton,
                 OurFindButton,
                 OurApplyButton, ],
        handler=SpikesRemovalHandler,
        title='Spikes removal tool')

    def __init__(self, signal, navigation_mask=None, signal_mask=None):
        """Set up the tool on ``signal``, which must already be plotted.

        ``navigation_mask`` marks navigation positions to skip (True = skip).
        ``signal_mask`` marks channels whose derivative is ignored when
        detecting spikes.
        """
        super(SpikesRemoval, self).__init__(signal)
        self.interpolated_line = None
        self.coordinates = [coordinate for coordinate in
                            signal.axes_manager._am_indices_generator()
                            if (navigation_mask is None or not
                                navigation_mask[coordinate[::-1]])]
        self.signal = signal
        self.line = signal._plot.signal_plot.ax_lines[0]
        self.ax = signal._plot.signal_plot.ax
        signal._plot.auto_update_plot = False
        if len(self.coordinates) > 1:
            signal.axes_manager.indices = self.coordinates[0]
        self.threshold = 400
        self.index = 0
        self.argmax = None
        self.derivmax = None
        self.kind = "linear"
        self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
        self.signal_mask = signal_mask
        self.navigation_mask = navigation_mask
        md = self.signal.metadata
        from hyperspy.signal import Signal
        # Shot noise is the fallback whenever no variance is stored, including
        # the case of a Noise_properties node without a variance entry (which
        # would otherwise leave noise_type unset and crash later).
        self.noise_type = "shot noise"
        if "Signal.Noise_properties.variance" in md:
            self.noise_variance = md.Signal.Noise_properties.variance
            if isinstance(self.noise_variance, Signal):
                self.noise_type = "heteroscedastic"
            else:
                self.noise_type = "white"

    def _threshold_changed(self, old, new):
        # A new threshold restarts the search from the first position.
        self.index = 0
        self.update_plot()

    def _show_derivative_histogram_fired(self):
        self.signal._spikes_diagnosis(signal_mask=self.signal_mask,
                                      navigation_mask=self.navigation_mask)

    def detect_spike(self):
        """Return True if the current spectrum has a derivative >= threshold.

        Masked channels and ranges already visited in this spectrum (tracked
        in ``_temp_mask``) are excluded. On success ``argmax`` and
        ``derivmax`` are updated.
        """
        derivative = np.diff(self.signal())
        if self.signal_mask is not None:
            derivative[self.signal_mask[:-1]] = 0
        if self.argmax is not None:
            left, right = self.get_interpolation_range()
            self._temp_mask[left:right] = True
            derivative[self._temp_mask[:-1]] = 0
        if abs(derivative.max()) >= self.threshold:
            self.argmax = derivative.argmax()
            self.derivmax = abs(derivative.max())
            return True
        else:
            return False

    def _reset_line(self):
        """Close the preview line, if any, and reset the span selector."""
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
            self.reset_span_selector()

    def find(self, back=False):
        """Search forward (or backward if ``back``) for the next spike and preview it."""
        self._reset_line()
        ncoordinates = len(self.coordinates)
        spike = self.detect_spike()
        while not spike and (
                (self.index < ncoordinates - 1 and back is False) or
                (self.index > 0 and back is True)):
            if back is False:
                self.index += 1
            else:
                self.index -= 1
            spike = self.detect_spike()

        if spike is False:
            messages.information('End of dataset reached')
            self.index = 0
            self._reset_line()
            return
        else:
            minimum = max(0, self.argmax - 50)
            maximum = min(len(self.signal()) - 1, self.argmax + 50)
            # Raw string: the LaTeX backslashes are not Python escapes.
            thresh_label = DerivativeTextParameters(
                text=r"$\mathsf{\delta}_\mathsf{max}=$",
                color="black")
            self.ax.legend([thresh_label], [repr(int(self.derivmax))], handler_map={
                           DerivativeTextParameters: DerivativeTextHandler()}, loc='best')
            self.ax.set_xlim(
                self.signal.axes_manager.signal_axes[0].index2value(
                    minimum),
                self.signal.axes_manager.signal_axes[0].index2value(
                    maximum))
            self.update_plot()
            self.create_interpolation_line()

    def update_plot(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.update_spectrum_line()
        if len(self.coordinates) > 1:
            self.signal._plot.pointer._update_patch_position()

    def update_spectrum_line(self):
        # Temporarily enable auto-update so this single redraw happens.
        self.line.auto_update = True
        self.line.update()
        self.line.auto_update = False

    def _index_changed(self, old, new):
        self.signal.axes_manager.indices = self.coordinates[new]
        self.argmax = None
        self._temp_mask[:] = False

    def on_disabling_span_selector(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None

    def _spline_order_changed(self, old, new):
        # Only the spline interpolator uses the order; changing it while
        # 'Linear' is selected must not switch the interpolation kind.
        if self.interpolator_kind == 'Spline':
            self.kind = self.spline_order
        self.span_selector_changed()

    def _add_noise_changed(self, old, new):
        self.span_selector_changed()

    def _interpolator_kind_changed(self, old, new):
        # The enum values are capitalised ('Linear'/'Spline'), whereas
        # ``self.kind`` holds scipy's 'linear' or an integer spline order.
        if new == 'Linear':
            self.kind = 'linear'
        else:
            self.kind = self.spline_order
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        self.span_selector_changed()

    def create_interpolation_line(self):
        """Add a blue preview line showing the healed spectrum."""
        self.interpolated_line = drawing.spectrum.SpectrumLine()
        self.interpolated_line.data_function = \
            self.get_interpolated_spectrum
        self.interpolated_line.set_line_properties(
            color='blue',
            type='line')
        self.signal._plot.signal_plot.add_line(self.interpolated_line)
        self.interpolated_line.autoscale = False
        self.interpolated_line.plot()

    def get_interpolation_range(self):
        """Return the (left, right) channel indices of the range to heal.

        If the span selector is collapsed, the range is
        ``argmax +/- default_spike_width``; otherwise it is the span selection.
        Both ends are clipped to the signal axis.
        """
        axis = self.signal.axes_manager.signal_axes[0]
        if self.ss_left_value == self.ss_right_value:
            left = self.argmax - self.default_spike_width
            right = self.argmax + self.default_spike_width
        else:
            left = axis.value2index(self.ss_left_value)
            right = axis.value2index(self.ss_right_value)

        # Clip to the axis dimensions
        nchannels = self.signal.axes_manager.signal_shape[0]
        left = left if left >= 0 else 0
        right = right if right < nchannels else nchannels - 1

        return left, right

    def get_interpolated_spectrum(self, axes_manager=None):
        """Return a copy of the current spectrum with ``[left:right]`` healed.

        ``axes_manager`` is unused here. It is presumably supplied by the
        preview line, which uses this method as its data function.
        """
        data = self.signal().copy()
        size = len(data)
        axis = self.signal.axes_manager.signal_axes[0]
        left, right = self.get_interpolation_range()
        # Number of neighbouring channels fed to the interpolator per side.
        pad = 1 if self.kind == 'linear' else 10
        ileft = int(np.clip(left - pad, 0, size))
        iright = int(np.clip(right + pad, 0, size))
        left = int(np.clip(left, 0, size))
        right = int(np.clip(right, 0, size))
        if ileft == 0:
            # Padding window hits the left edge: extrapolate flat from the
            # right neighbour (guarded so it cannot index past the end).
            data[left:right] = data[min(right + 1, size - 1)]
        elif iright == size:
            # Padding window hits the right edge (iright is clipped to
            # ``size``, mirroring the ``ileft == 0`` test above).
            data[left:right] = data[left - 1]
        else:
            # Interpolate from the padding channels on both sides.
            x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
            y = np.hstack((data[ileft:left], data[right:iright]))
            intp = sp.interpolate.interp1d(x, y, kind=self.kind)
            data[left:right] = intp(axis.axis[left:right])

        # Add noise to the healed range only.
        if self.add_noise is True:
            if self.noise_type == "white":
                data[left:right] += np.random.normal(
                    scale=np.sqrt(self.noise_variance),
                    size=right - left)
            elif self.noise_type == "heteroscedastic":
                noise_variance = self.noise_variance(
                    axes_manager=self.signal.axes_manager)[left:right]
                data[left:right] += np.random.normal(
                    scale=np.sqrt(noise_variance))
            else:
                data[left:right] = np.random.poisson(
                    np.clip(data[left:right], 0, np.inf))

        return data

    def span_selector_changed(self):
        if self.interpolated_line is None:
            return
        else:
            self.interpolated_line.update()

    def apply(self):
        """Write the healed spectrum into the signal and search for the next spike."""
        self.signal()[:] = self.get_interpolated_spectrum()
        self.update_spectrum_line()
        # The preview line only exists after find() located a spike; do not
        # crash when applying a manually selected range without one.
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.find()
Exemplo n.º 18
0
class FileImportManager(tr.HasTraits):
    """Split the columns of a CSV file into individual ``.npy`` files.

    Opening a file reads its first row as column headers, turns them into
    filename-safe strings and creates an ``NPY`` folder next to the CSV.
    Parsing then saves one ``<file name>_<header>.npy`` array per column,
    read after skipping ``skip_rows`` lines.
    """
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    skip_rows = tr.Int(4, auto_set=False, enter_set=True)
    columns_headers_list = tr.List([])

    parse_csv_to_npy = tr.Button

    view = ui.View(ui.VGroup(
        ui.HGroup(
            ui.UItem('open_file_csv'),
            ui.UItem('file_csv', style='readonly'),
        ),
        ui.Item('skip_rows'),
        ui.Item('decimal'),
        ui.Item('delimiter'),
        ui.Item('parse_csv_to_npy', show_label=False),
    ))

    def _open_file_csv_fired(self):
        """Let the user pick a CSV file, then read headers and prepare the NPY folder."""
        extns = ['*.csv', ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open', wildcard=wildcard,
                            default_path=self.file_csv)
        dialog.open()
        if not dialog.path:
            # Dialog was cancelled: keep the current selection instead of
            # trying to read an empty path.
            return
        self.file_csv = dialog.path

        # Fill columns_headers_list from the first row of the file.
        first_row = pd.read_csv(
            self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
            nrows=1, header=None
        ).iloc[0]
        self.columns_headers_list = [
            self.get_valid_file_name(header) for header in first_row]

        # Remember the file name and create the NPY output folder.
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        if not os.path.exists(self.npy_folder_path):
            os.makedirs(self.npy_folder_path)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    def get_valid_file_name(self, original_file_name):
        """Return ``original_file_name`` with every non filename-safe character removed.

        Non-string values (e.g. numeric or empty header cells read by pandas)
        are converted with ``str`` first instead of raising ``TypeError``.
        """
        valid_chars = frozenset(
            "-_.() %s%s" % (string.ascii_letters, string.digits))
        return ''.join(
            c for c in str(original_file_name) if c in valid_chars)

    def _parse_csv_to_npy_fired(self):
        """Save each CSV column as ``<file name>_<header>.npy`` in the NPY folder."""
        print('Parsing csv into npy files...')

        if self.columns_headers_list:
            # Parse the file once instead of once per column.
            table = pd.read_csv(
                self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
                skiprows=self.skip_rows)
            for i, header in enumerate(self.columns_headers_list):
                # ``[i]`` keeps the 2-D (n_rows, 1) shape of the saved arrays.
                column_array = np.array(table.iloc[:, [i]])
                np.save(os.path.join(self.npy_folder_path,
                                     self.file_name + '_' + header + '.npy'),
                        column_array)

        print('Finished parsing csv into npy files.')
Exemplo n.º 19
0
class Librarian(traits.HasTraits):
    """Librarian provides a way of writing useful information into the
    log folder for eagle logs. It is designed to make the information inside
    an eagle log easier to come back to. It mainly writes default strings into
    the comments file in the log folder.

    The purpose/results/comments text and the selected axes are read from,
    and written to, a OneNote page named after the log folder.
    """

    logType = traits.Enum("important", "debug", "calibration")
    writeToOneNoteButton = traits.Button("save")
    refreshInformation = traits.Button("refresh")
    saveImage = traits.Button("save plot")
    # Sub-objects are Instance traits with a factory, so every Librarian gets
    # its own objects. Plain class-level instances would be shared by all
    # Librarians, and each __init__ would overwrite the others' state.
    axisList = traits.Instance(AxisSelector, ())
    purposeBlock = traits.Instance(
        EntryBlock, kw={"fieldName": "What is the purpose of this log?"})
    resultsBlock = traits.Instance(
        EntryBlock,
        kw={"fieldName":
            "Explain what the data shows (important parameters that change, does it make sense etc.)?"})
    commentsBlock = traits.Instance(EntryBlock, kw={"fieldName": "Anything Else?"})
    saveButton = traits.Button("Save")
    #    notebooks = traits.Enum(values = "notebookNames") # we could let user select from a range of notebooks
    #    notebookNames = traits.List
    notebookName = traits.String("Humphry's Notebook")
    sectionName = traits.String("Eagle Logs")
    logName = traits.String("")
    xAxis = traits.String("")
    yAxis = traits.String("")

    traits_view = traitsui.View(traitsui.VGroup(
        traitsui.Item("logName", show_label=False, style="readonly"),
        traitsui.Item("axisList",
                      show_label=False,
                      editor=traitsui.InstanceEditor(),
                      style='custom'),
        traitsui.Item("purposeBlock",
                      show_label=False,
                      editor=traitsui.InstanceEditor(),
                      style='custom'),
        traitsui.Item("resultsBlock",
                      show_label=False,
                      editor=traitsui.InstanceEditor(),
                      style='custom'),
        traitsui.Item("commentsBlock",
                      show_label=False,
                      editor=traitsui.InstanceEditor(),
                      style='custom'),
        traitsui.HGroup(
            traitsui.Item("writeToOneNoteButton", show_label=False),
            traitsui.Item("refreshInformation", show_label=False)),
    ),
                                resizable=True,
                                kind="live",
                                title="Eagle OneNote")

    def __init__(self, **traitsDict):
        """Librarian object requires the log folder it is referring to. If a .csv
        file is given as logFolder argument it will use parent folder as the
        logFolder.

        If a OneNote page for this log already exists, its text blocks and
        axis parameters are loaded into the editors.
        """
        super(Librarian, self).__init__(**traitsDict)
        if os.path.isfile(self.logFolder):
            self.logFolder = os.path.split(self.logFolder)[0]
        else:
            logger.debug("found these in %s: %s",
                         self.logFolder, os.listdir(self.logFolder))

        self.logName = os.path.split(self.logFolder)[1]
        self.logFile = os.path.join(self.logFolder, self.logName + ".csv")
        # The axis selector needs its own copy of the log file so it can
        # calculate valid values.
        self.axisList.logFile = self.logFile
        self.axisList.masterList = self.axisList._masterList_default()
        self.axisList.masterListWithNone = (
            self.axisList._masterListWithNone_default())
        if self.xAxis != "":
            self.axisList.xAxis = self.xAxis
        if self.yAxis != "":
            self.axisList.yAxis = self.yAxis

        # NOTE: the original author reports that creating EagleLogOneNote can
        # fail (e.g. OneNote not installed, registry not correct; see the
        # oneNotePython module). No error handling is done here.
        self.eagleOneNote = oneNotePython.eagleLogsOneNote.EagleLogOneNote(
            notebookName=self.notebookName, sectionName=self.sectionName)
        logPage = self.eagleOneNote.setPage(self.logName)
        if logPage is not None:  # page exists
            self.purposeBlock.textBlock = self.eagleOneNote.getOutlineText(
                "purpose")
            self.resultsBlock.textBlock = self.eagleOneNote.getOutlineText(
                "results")
            self.commentsBlock.textBlock = self.eagleOneNote.getOutlineText(
                "comments")
            xAxis, yAxis, series = (
                self.eagleOneNote.getParametersOutlineValues())
            try:
                self.axisList.xAxis, self.axisList.yAxis, self.axisList.series = xAxis, yAxis, series
            except Exception as e:
                # ``e.message`` does not exist on Python 3 exceptions; format
                # the exception itself so the handler cannot fail.
                logger.error(
                    "error when trying to read analysis parameters: %s", e)
            self.pageExists = True
        else:
            self.pageExists = False
            self.purposeBlock.textBlock = ""
            self.resultsBlock.textBlock = ""
            self.commentsBlock.textBlock = ""
            # could also reset axis list but it isn't really necessary

    def _writeToOneNoteButton_fired(self):
        """Write the librarian's text blocks and axis parameters to the OneNote page.

        The page is created first if it does not exist yet.
        """
        if not self.pageExists:
            self.eagleOneNote.createNewEagleLogPage(self.logName,
                                                    refresh=True,
                                                    setCurrent=True)
            self.pageExists = True
        self.eagleOneNote.setOutline("purpose",
                                     self.purposeBlock.textBlock,
                                     rewrite=False)
        self.eagleOneNote.setOutline("results",
                                     self.resultsBlock.textBlock,
                                     rewrite=False)
        self.eagleOneNote.setOutline("comments",
                                     self.commentsBlock.textBlock,
                                     rewrite=False)
        self.eagleOneNote.setDataOutline(self.logName, rewrite=False)
        self.eagleOneNote.setParametersOutline(self.axisList.xAxis,
                                               self.axisList.yAxis,
                                               self.axisList.series,
                                               rewrite=False)
        self.eagleOneNote.currentPage.rewritePage()
        # To get resizing right, re-pull the XML and data completely
        # (brute force): rebuild the OneNote wrapper and re-select the page.
        self.eagleOneNote = oneNotePython.eagleLogsOneNote.EagleLogOneNote(
            notebookName=self.notebookName, sectionName=self.sectionName)
        self.eagleOneNote.setPage(self.logName)  # sets the current page
        self.eagleOneNote.organiseOutlineSizes()
Exemplo n.º 20
0
class SpikesRemoval(SpanSelectorInSpectrum):
    """Interactive tool that finds spikes and heals them, one spectrum at a time.

    A spike is reported where the first difference of the current spectrum
    reaches ``threshold``. The channels around it (or the range chosen with
    the span selector) are replaced by an interpolation of the neighbouring
    channels, or by a flat fill near the spectrum edges. Poisson noise is
    then drawn for the healed channels only.
    """
    interpolator_kind = t.Enum(
        'Linear',
        'Spline',
        default='Linear')
    threshold = t.Float()
    show_derivative_histogram = t.Button()
    spline_order = t.Range(1, 10, 3)
    interpolator = None
    default_spike_width = t.Int(5)
    index = t.Int(0)
    view = tu.View(tu.Group(
        tu.Group(
            tu.Item('show_derivative_histogram', show_label=False),
            'threshold',
            show_border=True,),
        tu.Group(
            'interpolator_kind',
            'default_spike_width',
            tu.Group(
                'spline_order',
                visible_when='interpolator_kind == \'Spline\''),
            show_border=True,
            label='Advanced settings'),
    ),
        buttons=[OKButton,
                 OurPreviousButton,
                 OurFindButton,
                 OurApplyButton, ],
        handler=SpikesRemovalHandler,
        title='Spikes removal tool')

    def __init__(self, signal, navigation_mask=None, signal_mask=None):
        """Set up the tool on ``signal``, which must already be plotted.

        ``navigation_mask`` marks navigation positions to skip (True = skip).
        ``signal_mask`` marks channels whose derivative is ignored when
        detecting spikes.
        """
        super(SpikesRemoval, self).__init__(signal)
        self.interpolated_line = None
        self.coordinates = [coordinate for coordinate in
                            signal.axes_manager._am_indices_generator()
                            if (navigation_mask is None or not
                                navigation_mask[coordinate[::-1]])]
        self.signal = signal
        # find() is iterative, so the process-wide recursion limit is no
        # longer raised to the total data size (which could crash Python).
        self.line = signal._plot.signal_plot.ax_lines[0]
        self.ax = signal._plot.signal_plot.ax
        signal._plot.auto_update_plot = False
        signal.axes_manager.indices = self.coordinates[0]
        self.threshold = 400
        self.index = 0
        self.argmax = None
        self.kind = "linear"
        self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
        self.signal_mask = signal_mask
        self.navigation_mask = navigation_mask

    def _threshold_changed(self, old, new):
        # A new threshold restarts the search from the first position.
        self.index = 0
        self.update_plot()

    def _show_derivative_histogram_fired(self):
        self.signal._spikes_diagnosis(signal_mask=self.signal_mask,
                                      navigation_mask=self.navigation_mask)

    def detect_spike(self):
        """Return True if the current spectrum has a derivative >= threshold.

        Masked channels and ranges already visited in this spectrum (tracked
        in ``_temp_mask``) are excluded. On success ``argmax`` is updated.
        """
        derivative = np.diff(self.signal())
        if self.signal_mask is not None:
            derivative[self.signal_mask[:-1]] = 0
        if self.argmax is not None:
            left, right = self.get_interpolation_range()
            self._temp_mask[left:right] = True
            derivative[self._temp_mask[:-1]] = 0
        if abs(derivative.max()) >= self.threshold:
            self.argmax = derivative.argmax()
            return True
        else:
            return False

    def _reset_line(self):
        """Close the preview line, if any, and reset the span selector."""
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
            self.reset_span_selector()

    def find(self, back=False):
        """Search forward (or backward if ``back``) for the next spike and preview it.

        The current spectrum is searched first, so remaining spikes in the
        last (or first) spectrum are still found. An information message is
        shown once the end of the dataset is reached without a spike.
        """
        self._reset_line()
        last_index = len(self.coordinates) - 1
        spike = self.detect_spike()
        while not spike and (
                (self.index < last_index and back is False) or
                (self.index > 0 and back is True)):
            if back is False:
                self.index += 1
            else:
                self.index -= 1
            spike = self.detect_spike()

        if not spike:
            messages.information('End of dataset reached')
            return

        minimum = max(0, self.argmax - 50)
        maximum = min(len(self.signal()) - 1, self.argmax + 50)
        self.ax.set_xlim(
            self.signal.axes_manager.signal_axes[0].index2value(
                minimum),
            self.signal.axes_manager.signal_axes[0].index2value(
                maximum))
        self.update_plot()
        self.create_interpolation_line()

    def update_plot(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.update_spectrum_line()
        # There may be no navigation pointer (e.g. no navigation axes).
        pointer = getattr(self.signal._plot, 'pointer', None)
        if pointer is not None:
            pointer.update_patch_position()

    def update_spectrum_line(self):
        # Temporarily enable auto-update so this single redraw happens.
        self.line.auto_update = True
        self.line.update()
        self.line.auto_update = False

    def _index_changed(self, old, new):
        self.signal.axes_manager.indices = self.coordinates[new]
        self.argmax = None
        self._temp_mask[:] = False

    def on_disabling_span_selector(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None

    def _spline_order_changed(self, old, new):
        # Only the spline interpolator uses the order; changing it while
        # 'Linear' is selected must not switch the interpolation kind.
        if self.interpolator_kind == 'Spline':
            self.kind = self.spline_order
        self.span_selector_changed()

    def _interpolator_kind_changed(self, old, new):
        # The enum values are capitalised ('Linear'/'Spline'), whereas
        # ``self.kind`` holds scipy's 'linear' or an integer spline order.
        if new == 'Linear':
            self.kind = 'linear'
        else:
            self.kind = self.spline_order
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        self.span_selector_changed()

    def create_interpolation_line(self):
        """Add a blue preview line showing the healed spectrum."""
        self.interpolated_line = drawing.spectrum.SpectrumLine()
        self.interpolated_line.data_function = \
            self.get_interpolated_spectrum
        self.interpolated_line.set_line_properties(
            color='blue',
            type='line')
        self.signal._plot.signal_plot.add_line(self.interpolated_line)
        self.interpolated_line.autoscale = False
        self.interpolated_line.plot()

    def get_interpolation_range(self):
        """Return the (left, right) channel indices of the range to heal.

        If the span selector is collapsed, the range is
        ``argmax +/- default_spike_width``; otherwise it is the span selection.
        Both ends are clipped to the signal axis.
        """
        axis = self.signal.axes_manager.signal_axes[0]
        if self.ss_left_value == self.ss_right_value:
            left = self.argmax - self.default_spike_width
            right = self.argmax + self.default_spike_width
        else:
            left = axis.value2index(self.ss_left_value)
            right = axis.value2index(self.ss_right_value)

        # Clip to the axis dimensions
        nchannels = self.signal.axes_manager.signal_shape[0]
        left = left if left >= 0 else 0
        right = right if right < nchannels else nchannels - 1

        return left, right

    def get_interpolated_spectrum(self, axes_manager=None):
        """Return a copy of the current spectrum with ``[left:right]`` healed.

        ``axes_manager`` is unused here. It is presumably supplied by the
        preview line, which uses this method as its data function.
        """
        data = self.signal().copy()
        size = len(data)
        axis = self.signal.axes_manager.signal_axes[0]
        left, right = self.get_interpolation_range()
        # Number of neighbouring channels fed to the interpolator per side.
        pad = 1 if self.kind == 'linear' else 10
        ileft = int(np.clip(left - pad, 0, size))
        iright = int(np.clip(right + pad, 0, size))
        left = int(np.clip(left, 0, size))
        right = int(np.clip(right, 0, size))
        if ileft == 0:
            # Padding window hits the left edge: extrapolate flat from the
            # right neighbour (guarded so it cannot index past the end).
            data[left:right] = data[min(right + 1, size - 1)]
        elif iright == size:
            # Padding window hits the right edge (iright is clipped to
            # ``size``, mirroring the ``ileft == 0`` test above).
            data[left:right] = data[left - 1]
        else:
            # Interpolate from the padding channels on both sides.
            x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
            y = np.hstack((data[ileft:left], data[right:iright]))
            intp = sp.interpolate.interp1d(x, y, kind=self.kind)
            data[left:right] = intp(axis.axis[left:right])

        # Add shot noise to the healed range only; noising the whole spectrum
        # would corrupt untouched channels every time a spike is applied.
        data[left:right] = np.random.poisson(
            np.clip(data[left:right], 0, np.inf))
        return data

    def span_selector_changed(self):
        if self.interpolated_line is None:
            return
        else:
            self.interpolated_line.update()

    def apply(self):
        """Write the healed spectrum into the signal and search for the next spike."""
        self.signal()[:] = self.get_interpolated_spectrum()
        self.update_spectrum_line()
        # The preview line only exists after find() located a spike; do not
        # crash when applying a manually selected range without one.
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.find()
Exemplo n.º 21
0
class GCS(t.HasTraits):
    """
    This is the ground control station GUI class.

    For usage, for example if telemetry radio is on COM8 and the GCS Pixhawk is
    on COM7:
    >>> gcs = GCS()
    >>> gcs.setup_uav_link("COM8")
    >>> gcs.poll_uav()
    >>> gcs.setup_gcs_link("COM7")
    >>> gcs.poll_gcs()
    >>> gcs.configure_traits()
    >>> gcs.close()
    """
    # Connections
    dialect = t.Str("gcs_pixhawk")
    show_errors = t.Bool(True)

    # ON GCS: Outgoing
    mission_message = t.Enum(set_mission_mode.keys(), label="Mission Type")
    sweep_angle = t.Float(label="Angle (degrees)")
    sweep_alt_start = t.Float(label="Start Altitude (m)")
    sweep_alt_end = t.Float(label="End Altitude (m)")
    sweep_alt_step = t.Float(label="Number of Altitude Steps")

    mavlink_message = t.Enum(mavlink_msgs, label="Mavlink Message")
    mavlink_message_filt = t.Enum(mavlink_msgs_filt, label="Mavlink Message")
    mavlink_message_params = t.Str(label="params")
    mavlink_message_args = t.Str(', '.join(
        mavlink_msgs_attr[mavlink_msgs_filt[0]]['args'][1:]),
                                 label="Arguments")

    # ON GCS: Incoming
    # Tether
    tether_length = t.Float(t.Undefined, label='Length (m)')
    tether_tension = t.Float(t.Undefined, label='Tension (N)')
    tether_velocity = t.Float(t.Undefined, label="Velocity")

    # ON GCS: Incoming
    # GCS Pixhawk
    gcs_eph = t.Float(t.Undefined)
    gcs_epv = t.Float(t.Undefined)
    gcs_satellites_visible = t.Int(t.Undefined)
    gcs_fix_type = t.Int(t.Undefined)

    gcs_airspeed = t.Float(t.Undefined)
    gcs_groundspeed = t.Float(t.Undefined)
    gcs_heading = t.Float(t.Undefined)
    gcs_velocity = t.Array(shape=(3, ))

    # Location inputs
    gcs_alt = t.Float(t.Undefined)
    gcs_lat = t.Float(t.Undefined)
    gcs_lon = t.Float(t.Undefined)

    # Attitude inputs
    gcs_pitch = t.Float(t.Undefined)
    gcs_roll = t.Float(t.Undefined)
    gcs_yaw = t.Float(t.Undefined)
    gcs_pitchspeed = t.Float(t.Undefined)
    gcs_yawspeed = t.Float(t.Undefined)
    gcs_rollspeed = t.Float(t.Undefined)

    # Battery Inputs
    gcs_current = t.Float(t.Undefined)
    gcs_level = t.Float(t.Undefined)
    gcs_voltage = t.Float(t.Undefined)

    # GCS connections (declared once; the class previously repeated this set)
    gcs = t.Any(t.Undefined)
    gcs_polling = t.Bool(False)
    gcs_msg_thread = t.Instance(threading.Thread)
    gcs_error = t.Int(0)
    gcs_port = t.Str(t.Undefined)
    gcs_baud = t.Int(t.Undefined)

    # ON DRONE: Incoming
    # Mission Status
    mission_status = t.Enum(mission_status.keys())

    # Probe
    probe_u = t.Float(t.Undefined, label="u (m/s)")
    probe_v = t.Float(t.Undefined, label="v (m/s)")
    probe_w = t.Float(t.Undefined, label="w (m/s)")

    # Vehicle inputs
    uav_modename = t.Str(t.Undefined)
    uav_armed = t.Bool(t.Undefined)
    uav_eph = t.Float(t.Undefined)
    uav_epv = t.Float(t.Undefined)
    uav_satellites_visible = t.Int(t.Undefined)
    uav_fix_type = t.Int(t.Undefined)

    uav_airspeed = t.Float(t.Undefined)
    uav_groundspeed = t.Float(t.Undefined)
    uav_heading = t.Float(t.Undefined)
    uav_velocity = t.Array(shape=(3, ))

    # Location inputs
    uav_alt = t.Float(t.Undefined)
    uav_lat = t.Float(t.Undefined)
    uav_lon = t.Float(t.Undefined)

    # Attitude inputs
    uav_pitch = t.Float(t.Undefined)
    uav_roll = t.Float(t.Undefined)
    uav_yaw = t.Float(t.Undefined)
    uav_pitchspeed = t.Float(t.Undefined)
    uav_yawspeed = t.Float(t.Undefined)
    uav_rollspeed = t.Float(t.Undefined)

    # Battery Inputs
    uav_current = t.Float(t.Undefined)
    uav_level = t.Float(t.Undefined)
    uav_voltage = t.Float(t.Undefined)

    # Vehicle Connections
    uav = t.Any(t.Undefined)
    uav_polling = t.Bool(False)
    uav_msg_thread = t.Instance(threading.Thread)
    uav_error = t.Int(0)
    uav_port = t.Str(t.Undefined)
    uav_baud = t.Int(t.Undefined)

    # ui Buttons and display groups
    update_mission = t.Button("Update")
    send_mavlink_message = t.Button("Send")
    filtered = t.Bool(True)

    group_input = tui.Group(
        tui.Item(name="mission_status", enabled_when='False'),
        tui.Item(name="mission_message"),
        tui.Item(name="sweep_angle",
                 visible_when='mission_message=="SCHEDULE_SWEEP"'),
        tui.Item(name="sweep_alt_start",
                 visible_when='mission_message=="SCHEDULE_SWEEP"'),
        tui.Item(name="sweep_alt_end",
                 visible_when='mission_message=="SCHEDULE_SWEEP"'),
        tui.Item(name="sweep_alt_step",
                 visible_when='mission_message=="SCHEDULE_SWEEP"'),
        tui.Item(name="update_mission"),
        tui.Item("_"),
        tui.Item("filtered"),
        tui.Item("mavlink_message", visible_when='filtered==False'),
        tui.Item("mavlink_message_filt", visible_when='filtered'),
        tui.Item("mavlink_message_args",
                 enabled_when='False',
                 editor=tui.TextEditor(),
                 height=-40),
        tui.Item("mavlink_message_params"),
        tui.Item("send_mavlink_message"),
        tui.Item("_"),
        tui.Item(name="tether_tension", enabled_when='False'),
        tui.Item(name="tether_length", enabled_when='False'),
        tui.Item(name="tether_velocity", enabled_when='False'),
        tui.Item("_"),
        orientation="vertical",
        show_border=True,
        label="On GCS")
    group_uav = tui.Group(tui.Item(name="uav_modename", enabled_when='False'),
                          tui.Item(name="uav_airspeed", enabled_when='False'),
                          tui.Item(name="uav_groundspeed",
                                   enabled_when='False'),
                          tui.Item(name='uav_armed', enabled_when='False'),
                          tui.Item(name='uav_alt', enabled_when='False'),
                          tui.Item(name='uav_lat', enabled_when='False'),
                          tui.Item(name='uav_lon', enabled_when='False'),
                          tui.Item(name='uav_velocity', enabled_when='False'),
                          tui.Item(name='uav_pitch', enabled_when='False'),
                          tui.Item(name='uav_roll', enabled_when='False'),
                          tui.Item(name='uav_yaw', enabled_when='False'),
                          tui.Item(name='uav_current', enabled_when='False'),
                          tui.Item(name='uav_level', enabled_when='False'),
                          tui.Item(name='uav_voltage', enabled_when='False'),
                          tui.Item("_"),
                          tui.Item(name='probe_u', enabled_when='False'),
                          tui.Item(name='probe_v', enabled_when='False'),
                          tui.Item(name='probe_w', enabled_when='False'),
                          orientation='vertical',
                          show_border=True,
                          label="Incoming")
    group_gcs = tui.Group(tui.Item(name="gcs_airspeed", enabled_when='False'),
                          tui.Item(name="gcs_groundspeed",
                                   enabled_when='False'),
                          tui.Item(name='gcs_alt', enabled_when='False'),
                          tui.Item(name='gcs_lat', enabled_when='False'),
                          tui.Item(name='gcs_lon', enabled_when='False'),
                          tui.Item(name='gcs_velocity', enabled_when='False'),
                          tui.Item(name='gcs_pitch', enabled_when='False'),
                          tui.Item(name='gcs_roll', enabled_when='False'),
                          tui.Item(name='gcs_yaw', enabled_when='False'),
                          tui.Item(name='gcs_current', enabled_when='False'),
                          tui.Item(name='gcs_level', enabled_when='False'),
                          tui.Item(name='gcs_voltage', enabled_when='False'),
                          orientation='vertical',
                          show_border=True,
                          label="GCS")
    traits_view = tui.View(tui.Group(group_input,
                                     group_uav,
                                     group_gcs,
                                     orientation='horizontal'),
                           resizable=True)

    def _update_mission_fired(self):
        """ This will fire when the update_mission button is clicked

        In that case we send one of our custom MAVLINK messages, either
        set_mission_mode (non-negative mode code) or schedule_sweep
        (negative mode code).
        """
        mode = set_mission_mode[self.mission_message]
        if mode >= 0:
            self.uav.mav.set_mission_mode_send(mode)
        else:
            self.uav.mav.schedule_sweep_send(self.sweep_angle,
                                             self.sweep_alt_start,
                                             self.sweep_alt_end,
                                             self.sweep_alt_step)

    def _mavlink_message_changed(self):
        """ This will fire when the dropdown is changed
        """
        self.mavlink_message_args = ', '.join(
            mavlink_msgs_attr[self.mavlink_message]['args'][1:])

    def _mavlink_message_filt_changed(self):
        """ This will fire when the filtered dropdown is changed
        """
        self.mavlink_message_args = ', '.join(
            mavlink_msgs_attr[self.mavlink_message_filt]['args'][1:])

    def _send_mavlink_message_fired(self):
        """ This will fire when the send_mavlink_message button is clicked

        In that case we pass on the mavlink message that the user is trying
        to send. The message comes from whichever dropdown is visible:
        ``mavlink_message_filt`` when ``filtered`` is set, otherwise
        ``mavlink_message``. ``mavlink_message_params`` is a comma-separated
        list of numbers. Blank entries are ignored, so an empty string sends
        the message with no arguments.
        """
        if self.filtered:
            msg_name = self.mavlink_message_filt
        else:
            msg_name = self.mavlink_message
        func = mavlink_msgs_attr[msg_name]['name']
        args = [float(p) for p in self.mavlink_message_params.split(',')
                if p.strip()]
        getattr(self.uav.mav, func)(*args)

    def setup_uav_link(self, uav_port, uav_baud=56700):
        """
        This sets up the connection to the UAV.

        Parameters
        -----------
        uav_port : str
            Serial port where UAV is connected (via telemetry radio)
        uav_baud: int, optional
            The baud rate. Default is 56700
            NOTE(review): 56700 is not a standard serial rate; telemetry
            radios commonly run at 57600. Confirm this default is intended.
        """
        mavutil.set_dialect(self.dialect)
        self.uav = mavutil.mavlink_connection(uav_port, uav_baud)
        self.uav_port = uav_port
        self.uav_baud = uav_baud

    def setup_gcs_link(self, gcs_port, gcs_baud=115200):
        """
        This sets up the connection to the GCS Pixhawk.

        Parameters
        -----------
        gcs_port : str
            Serial port where GCS Pixhawk is connected (via usb cable)
        gcs_baud: int, optional
            The baud rate. Default is 115200
        """
        mavutil.set_dialect(self.dialect)
        self.gcs = mavutil.mavlink_connection(gcs_port, gcs_baud)
        self.gcs_port = gcs_port
        self.gcs_baud = gcs_baud

    def _poll_link(self, name, handler):
        """Receive, decode and dispatch messages from one MAVLink connection.

        This is the body of the polling threads. It keeps running while the
        ``<name>_polling`` trait is True, and clears that trait on exit.

        Parameters
        ----------
        name : str
            ``'uav'`` or ``'gcs'``. This selects the connection trait
            (``self.<name>``), the ``<name>_polling`` flag and the
            ``<name>_error`` counter.
        handler : callable
            Called with every decoded message that is not BAD_DATA.
        """
        polling_flag = name + '_polling'
        error_counter = name + '_error'
        # Make sure we are connected
        m = getattr(self, name)
        m.mav.heartbeat_send(mavutil.mavlink.MAV_TYPE_GCS,
                             mavutil.mavlink.MAV_AUTOPILOT_INVALID, 0, 0, 0)
        print("Waiting for heartbeat from %s" % m.address)
        m.wait_heartbeat()
        print("Found Heardbeat, continuing")

        # The platform cannot change while we run, so check it once here
        # rather than on every loop iteration.
        strip_ansi = 'windows' in platform.architecture()[-1].lower()

        while getattr(self, polling_flag):
            try:
                s = m.recv(16 * 1024)
            except Exception:
                # Without the `continue`, `s` would be unbound on the first
                # pass, or the previous buffer would be parsed a second time.
                time.sleep(0.1)
                continue
            # prevent a dead serial port from causing the CPU to spin
            if len(s) == 0:
                time.sleep(0.1)
                continue

            if strip_ansi:
                # strip nsh ansi codes
                s = s.replace("\033[K", "")

            if m.first_byte:
                m.auto_mavlink_version(s)
            msgs = m.mav.parse_buffer(s)
            if not msgs:
                continue
            for msg in msgs:
                if getattr(m, '_timestamp', None) is None:
                    m.post_message(msg)
                if msg.get_type() == "BAD_DATA":
                    if self.show_errors:
                        print("MAV error: %s" % msg)
                    setattr(self, error_counter,
                            getattr(self, error_counter) + 1)
                else:
                    handler(msg)
        print("%s_polling Stopped" % name)
        setattr(self, polling_flag, False)

    def poll_uav(self):
        """
        This runs a new thread that listens for messages from the UAV and
        parses them for the GCS
        """
        self.uav_polling = True
        self.uav_msg_thread = threading.Thread(
            target=self._poll_link, args=('uav', self.parse_uav_msg))
        self.uav_msg_thread.start()

    def poll_gcs(self):
        """
        This runs a new thread that listens for messages from the GCS Pixhawk
        and parses them for the GCS, it also forwards relevant messages to the
        UAV
        """
        self.gcs_polling = True
        self.gcs_msg_thread = threading.Thread(
            target=self._poll_link, args=('gcs', self.parsefwd_gcs_msg))
        self.gcs_msg_thread.start()

    def parse_uav_msg(self, m):
        """
        This parses a message received from the UAV and stores the values
        in the class attributes so that the GUI will update
        """
        typ = m.get_type()
        if typ == 'GLOBAL_POSITION_INT':
            # lat/lon arrive as degrees * 1e7, velocities as cm/s
            (self.uav_lat, self.uav_lon) = (m.lat / 1.0e7, m.lon / 1.0e7)
            self.uav_velocity = (m.vx / 100.0, m.vy / 100.0, m.vz / 100.0)
        elif typ == 'GPS_RAW':
            pass  # better to just use global position int
        elif typ == 'GPS_RAW_INT':
            # position is taken from GLOBAL_POSITION_INT instead
            self.uav_eph = m.eph
            self.uav_epv = m.epv
            self.uav_satellites_visible = m.satellites_visible
            self.uav_fix_type = m.fix_type
        elif typ == "VFR_HUD":
            self.uav_heading = m.heading
            self.uav_alt = m.alt
            self.uav_airspeed = m.airspeed
            self.uav_groundspeed = m.groundspeed
        elif typ == "ATTITUDE":
            self.uav_pitch = m.pitch
            self.uav_yaw = m.yaw
            self.uav_roll = m.roll
            self.uav_pitchspeed = m.pitchspeed
            self.uav_yawspeed = m.yawspeed
            self.uav_rollspeed = m.rollspeed
        elif typ == "SYS_STATUS":
            self.uav_voltage = m.voltage_battery
            self.uav_current = m.current_battery
            self.uav_level = m.battery_remaining
        elif typ == "HEARTBEAT":
            pass

    def fwd_msg_to_uav(self, m):
        """This forwards messages from the GCS Pixhawk to the UAV if there is
        a UAV connected"""
        if self.uav is not t.Undefined:
            self.uav.write(m.get_msgbuf())

    def parsefwd_gcs_msg(self, m):
        """
        This parses a message received from the GCS Pixhawk, stores the values
        in the class attributes so that the GUI will update, and forwards
        relevant messages to the UAV (every handled type except SYS_STATUS)
        """
        typ = m.get_type()
        if typ == 'GLOBAL_POSITION_INT':
            # lat/lon arrive as degrees * 1e7, velocities as cm/s
            (self.gcs_lat, self.gcs_lon) = (m.lat / 1.0e7, m.lon / 1.0e7)
            self.gcs_velocity = (m.vx / 100.0, m.vy / 100.0, m.vz / 100.0)
            # Forward message
            self.fwd_msg_to_uav(m)
        elif typ == 'GPS_RAW':
            # better to just use global position int; only forward it
            self.fwd_msg_to_uav(m)
        elif typ == 'GPS_RAW_INT':
            # position is taken from GLOBAL_POSITION_INT instead
            self.gcs_eph = m.eph
            self.gcs_epv = m.epv
            self.gcs_satellites_visible = m.satellites_visible
            self.gcs_fix_type = m.fix_type
            # Forward message
            self.fwd_msg_to_uav(m)
        elif typ == "VFR_HUD":
            self.gcs_heading = m.heading
            self.gcs_alt = m.alt
            self.gcs_airspeed = m.airspeed
            self.gcs_groundspeed = m.groundspeed
            # Forward message
            self.fwd_msg_to_uav(m)
        elif typ == "ATTITUDE":
            self.gcs_pitch = m.pitch
            self.gcs_yaw = m.yaw
            self.gcs_roll = m.roll
            self.gcs_pitchspeed = m.pitchspeed
            self.gcs_yawspeed = m.yawspeed
            self.gcs_rollspeed = m.rollspeed
            # Forward message
            self.fwd_msg_to_uav(m)
        elif typ == "SYS_STATUS":
            self.gcs_voltage = m.voltage_battery
            self.gcs_current = m.current_battery
            self.gcs_level = m.battery_remaining
        elif typ == "HEARTBEAT":
            # Forward message
            self.fwd_msg_to_uav(m)

    def close(self, *args, **kwargs):
        """
        This closes down the serial connections and stop the GUI polling

        Closing is best-effort. A link that was never set up is still
        ``t.Undefined`` and has no ``close()``. Errors are therefore ignored,
        but no longer with a bare ``except``, which would also swallow
        KeyboardInterrupt and SystemExit.
        """
        print('Closing down connection')
        for name in ('uav', 'gcs'):
            # Clear the flag first so the polling loop exits once the port
            # closes and recv() starts failing.
            setattr(self, name + '_polling', False)
            try:
                getattr(self, name).close()
            except Exception:
                pass
Exemplo n.º 22
0
class ModelPlot(tapi.HasStrictTraits):
    """Traits controller that plots the data in ``filepaths`` with Plot2D.

    It wires the plot to header/axis selectors, axis scaling, step reloading,
    PDF export and overlay-file management.
    """

    Plot_Data = tapi.Instance(Plot2D)
    plot_info = tapi.Dict(tapi.Int, tapi.Dict(tapi.Str, tapi.List(tapi.Str)))
    Multi_Select = tapi.Instance(MultiSelect)
    Change_Axis = tapi.Instance(ChangeAxis)
    Reset_Zoom = tapi.Button('Reset Zoom')
    Reload_Data = tapi.Button('Reload Data')
    Print_to_PDF = tapi.Button('Print to PDF')
    Load_Overlay = tapi.Button('Open Overlay')
    Close_Overlay = tapi.Button('Close Overlay')
    Step = tapi.Int
    low_step = tapi.Int
    high_step = tapi.Int
    X_Scale = tapi.String("1.0")
    Y_Scale = tapi.String("1.0")
    Single_Select_Overlay_Files = tapi.Instance(SingleSelectOverlayFiles)
    filepaths = tapi.List(tapi.String)
    file_variables = tapi.List(tapi.String)

    def __init__(self, **traits):
        """Put together information to be sent to Plot2D information
        needed:

        plot_info : dict
           {0: {file_0: header_0}}
           {1: {file_1: header_1}}
           ...
           {n: {file_n: header_n}}
        variables : list
           list of variables that changed from one simulation to another
        x_idx : int
           column containing x variable to be plotted

        Raises
        ------
        ValueError
            If ``filepaths`` yields no data. Previously this case failed with
            an UnboundLocalError.
        """
        tapi.HasStrictTraits.__init__(self, **traits)
        data, mheader = self._read_files()

        self.Plot_Data = Plot2D(plot_data=data,
                                variables=self.file_variables,
                                x_idx=0,
                                plot_info=self.plot_info)
        self.Multi_Select = MultiSelect(choices=mheader, plot=self.Plot_Data)
        self.Change_Axis = ChangeAxis(Plot_Data=self.Plot_Data,
                                      headers=mheader)
        self.Single_Select_Overlay_Files = SingleSelectOverlayFiles(choices=[])
        self.low_step = 0
        self.high_step = 1

    def _read_files(self, **kwargs):
        """Load ``self.filepaths`` and record each file's header in plot_info.

        ``kwargs`` are forwarded unchanged to ``get_sorted_fileinfo`` (for
        example ``step=``).

        Returns
        -------
        (list, header)
            The per-file data, and the header of the first file.

        Raises
        ------
        ValueError
            If no files were returned.
        """
        data = []
        mheader = None
        fileinfo = get_sorted_fileinfo(self.filepaths, **kwargs)
        for idx, (fnam, fhead, fdata) in enumerate(fileinfo):
            if idx == 0:
                mheader = fhead
            self.plot_info[idx] = {fnam: fhead}
            data.append(fdata)
        if not data:
            raise ValueError("ModelPlot: no data found in filepaths")
        return data, mheader

    def _Reset_Zoom_fired(self):
        self.Plot_Data.change_plot(self.Plot_Data.plot_indices)

    def _parse_scale(self, scale, axis):
        """Convert stripped, non-empty scale text for ``axis`` to a float.

        Parameters
        ----------
        scale : str
            A number, an expression over the names in GDICT/LDICT, or one of
            the magic words ``min``, ``max`` or ``normalize``.
        axis : str
            ``'x'`` or ``'y'``. This selects Plot_Data's ``min_<axis>``,
            ``max_<axis>`` and ``abs_max_<axis>``.

        Returns
        -------
        float or None
            None when the text cannot be evaluated to a number.
        """
        plot = self.Plot_Data
        if scale == "max":
            scale = str(getattr(plot, "max_" + axis)())
        elif scale == "min":
            scale = str(getattr(plot, "min_" + axis)())
        elif scale == "normalize":
            _max = getattr(plot, "abs_max_" + axis)()
            _max = 1. if _max < EPSILON else _max
            scale = str(1. / _max)
        try:
            # NOTE(security): this evals text typed into the GUI. GDICT/LDICT
            # limit the namespace, but eval is not a sandbox. Never route
            # untrusted input here.
            return float(eval(scale, GDICT, LDICT))
        except Exception:
            return None

    def _X_Scale_changed(self, scale):
        """Detect if the x-axis scale was changed and let the plotter know

        Parameters
        ----------
        scale : str
           The user entered scale

        Notes
        -----
        scale should be a float, one of the operations in LDICT, or one of the
        optional magic keywords: min, max, normalize. On entry, scale is
        stripped, and if an empty string is sent in, it is reset to 1.0. If
        the magic words min or max are specified, the scale is set to the min
        or max of the x-axis data for the FIRST set of data. If the magic
        keyword normalize is specified, scale is set to 1 / max. Text that
        cannot be evaluated is ignored.
        """
        scale = scale.strip()
        if not scale:
            scale = self.X_Scale = "1.0"
        value = self._parse_scale(scale, "x")
        if value is None:
            return
        self.Plot_Data.change_plot(self.Plot_Data.plot_indices, x_scale=value)

    def _Y_Scale_changed(self, scale):
        """Detect if the y-axis scale was changed and let the plotter know

        Parameters
        ----------
        scale : str
           The user entered scale

        Notes
        -----
        scale should be a float, one of the operations in LDICT, or one of the
        optional magic keywords: min, max, normalize. On entry, scale is
        stripped, and if an empty string is sent in, it is reset to 1.0. If
        the magic words min or max are specified, the scale is set to the min
        or max of the y-axis data for the FIRST set of data. If the magic
        keyword normalize is specified, scale is set to 1 / max. Text that
        cannot be evaluated is ignored.
        """
        scale = scale.strip()
        if not scale:
            scale = self.Y_Scale = "1.0"
        value = self._parse_scale(scale, "y")
        if value is None:
            return
        self.Plot_Data.change_plot(self.Plot_Data.plot_indices, y_scale=value)

    def _Reload_Data_fired(self):
        self.Step = 0
        self.reload_data()

    @tapi.on_trait_change('Step')
    def get_data_at_step(self):
        self.reload_data()

    def reload_data(self):
        """Re-read the files at the current ``Step`` and refresh the plot."""
        data, mheader = self._read_files(step=self.Step)
        self.Plot_Data.plot_data = data
        self.Plot_Data.plot_info = self.plot_info
        self.Multi_Select.choices = mheader
        self.Change_Axis.headers = mheader
        self.Plot_Data.change_plot(self.Plot_Data.plot_indices)

    def _Print_to_PDF_fired(self):
        """Save the curves in XY_DATA to "<yname>-vs-<xname>.pdf".

        Every curve is divided by the largest |y| across all curves.
        """
        if not XY_DATA:
            return

        # get the maximum of Y for normalization; fall back to 1 when every
        # curve is (numerically) zero so we don't divide by zero
        ymax = max(np.amax(np.abs(xyd.y)) for xyd in XY_DATA)
        ymax = 1. if ymax < EPSILON else ymax

        # setup figure
        plt.figure(0)
        plt.cla()
        plt.clf()

        # plot y value for each plot on window
        ynames = []
        for xyd in sorted(XY_DATA, key=lambda x: x.lw, reverse=True):
            if len(XY_DATA) > 1:
                label = xyd.key + ":" + xyd.yname
            else:
                label = xyd.yname
            ynames.append(xyd.yname)
            plt.plot(xyd.x, xyd.y / ymax, label=label, lw=xyd.lw)
        # the x label and file name come from the last curve plotted
        xname = xyd.xname
        yname = common_prefix(ynames)
        plt.xlabel(xname)
        plt.ylabel(yname)
        plt.legend(loc="best")
        plt.savefig("{0}-vs-{1}.pdf".format(yname, xname))

    def _Close_Overlay_fired(self):
        """Remove the selected overlay file and select a neighbouring one."""
        overlays = self.Single_Select_Overlay_Files
        selected = overlays.selected
        if not selected:
            return
        index = overlays.choices.index(selected)
        overlays.choices.remove(selected)
        del self.Plot_Data.overlay_plot_data[selected]
        # Keep the header map in step with the data map.
        self.Plot_Data.overlay_headers.pop(selected, None)
        if not overlays.choices:
            overlays.selected = ""
        else:
            # Same position, or the new last entry if the tail was removed.
            index = min(index, len(overlays.choices) - 1)
            overlays.selected = overlays.choices[index]
        self.Plot_Data.change_plot(self.Plot_Data.plot_indices)

    def _Load_Overlay_fired(self):
        """Ask for overlay files, load them, and add them to the overlay list.

        Files that fail to load are logged and skipped. Loading a file whose
        base name is already listed replaces its data without adding a second
        entry to the list. A duplicate entry would otherwise break Close
        Overlay.
        """
        dialog = papi.FileDialog(action="open")
        dialog.open()
        if dialog.return_code != papi.OK:
            return
        choices = self.Single_Select_Overlay_Files.choices
        for eachfile in dialog.paths:
            try:
                fhead, fdata = loadcontents(eachfile)
            except Exception:
                logmes("{0}: Error reading overlay data".format(eachfile))
                continue
            fnam = os.path.basename(eachfile)
            self.Plot_Data.overlay_plot_data[fnam] = fdata
            self.Plot_Data.overlay_headers[fnam] = fhead
            if fnam not in choices:
                choices.append(fnam)
        self.Plot_Data.change_plot(self.Plot_Data.plot_indices)
Exemplo n.º 23
0
class Matplotlibify(traits.HasTraits):
    """Turn a log file plot into a stand-alone matplotlib script.

    A template script (``templateFile``) contains ``{{placeholder}}`` tokens.
    ``generateReplacementStrings`` maps those tokens to settings of the
    referenced log file plot(s) and of this object; the filled-in text is then
    either written to ``generatedScriptLocation`` or executed directly.
    """

    logFilePlotReference = traits.Instance(
        logFilePlot.LogFilePlot
    )  #gives access to most of the required attributes
    logFilePlotsReference = traits.Instance(
        logFilePlots.LogFilePlots)  #reference to logFilePlots object
    xAxisLabel = traits.String("")
    yAxisLabel = traits.String("")
    xAxisLabel2 = traits.String("")  #used if in dual plot mode
    yAxisLabel2 = traits.String("")
    legendReplacements = traits.Dict(key_trait=traits.String,
                                     value_trait=traits.String)
    # placeholder -> value; rebuilt per instance by generateReplacementStrings
    replacementStrings = {}
    setXLimitsBool = traits.Bool(False)
    setYLimitsBool = traits.Bool(False)

    xMin = traits.Float
    xMax = traits.Float
    yMin = traits.Float
    yMax = traits.Float

    matplotlibifyMode = traits.Enum("default", "dual plot")
    # Placeholders only: __init__ replaces both with mapped traits whose value
    # is a tab name and whose mapped value (logFilePlot1_/logFilePlot2_) is the
    # corresponding log file plot object.
    logFilePlot1 = traits.Any()
    logFilePlot2 = traits.Any()

    generatePlotScriptButton = traits.Button("generate plot")
    showPlotButton = traits.Button("show")
    # os.sep is required after the drive letter: os.path.join("C:", "Users")
    # yields the drive-relative path "C:Users\\...", not "C:\\Users\\...".
    templatesFolder = os.path.join("C:", os.sep, "Users", "tharrison",
                                   "Google Drive", "Thesis", "python scripts",
                                   "matplotlibify")
    templateFile = traits.File(
        os.path.join(templatesFolder, "matplotlibifyDefaultTemplate.py"))
    generatedScriptLocation = traits.File(
        os.path.join(templatesFolder, "debug.py"))

    secondPlotGroup = traitsui.VGroup(
        traitsui.Item("matplotlibifyMode", label="add second plot"),
        traitsui.HGroup(
            traitsui.Item("logFilePlot1",
                          visible_when="matplotlibifyMode=='dual plot'"),
            traitsui.Item("logFilePlot2",
                          visible_when="matplotlibifyMode=='dual plot'")))

    labelsGroup = traitsui.VGroup(
        traitsui.HGroup(traitsui.Item("xAxisLabel"),
                        traitsui.Item("yAxisLabel")),
        traitsui.HGroup(
            traitsui.Item("xAxisLabel2",
                          label="X axis label (2nd)",
                          visible_when="matplotlibifyMode=='dual plot'"),
            traitsui.Item("yAxisLabel2",
                          label="Y axis label (2nd)",
                          visible_when="matplotlibifyMode=='dual plot'")))

    limitsGroup = traitsui.VGroup(
        traitsui.Item("setXLimitsBool", label="set x limits?"),
        traitsui.Item("setYLimitsBool", label="set y limits?"),
        traitsui.HGroup(
            traitsui.Item("xMin", label="x min",
                          visible_when="setXLimitsBool"),
            traitsui.Item("xMax", label="x max",
                          visible_when="setXLimitsBool"),
            traitsui.Item("yMin", label="y min",
                          visible_when="setYLimitsBool"),
            traitsui.Item("yMax", label="y max",
                          visible_when="setYLimitsBool")))

    traits_view = traitsui.View(secondPlotGroup,
                                labelsGroup,
                                limitsGroup,
                                traitsui.Item("legendReplacements"),
                                traitsui.Item("templateFile"),
                                traitsui.Item("generatedScriptLocation"),
                                traitsui.Item('generatePlotScriptButton'),
                                traitsui.Item('showPlotButton'),
                                resizable=True)

    def __init__(self, **traitsDict):
        super(Matplotlibify, self).__init__(**traitsDict)
        # Add the mapped plot-selection traits first: generateReplacementStrings
        # reads logFilePlot1_/logFilePlot2_ when in dual plot mode.
        for traitName in ("logFilePlot1", "logFilePlot2"):
            self.add_trait(
                traitName,
                traits.Trait(self.logFilePlotReference.logFilePlotsTabName,
                             self._logFilePlotsByTabName()))
        self.generateReplacementStrings()

    def _logFilePlotsByTabName(self):
        """Return a fresh {tab name: log file plot} dict of all known plots."""
        return {
            lfp.logFilePlotsTabName: lfp
            for lfp in self.logFilePlotsReference.lfps
        }

    def generateReplacementStrings(self):
        """Rebuild ``self.replacementStrings`` for the current mode.

        In "default" mode the plot-specific keys come from
        ``logFilePlotReference``; in "dual plot" mode they come from both
        selected plots, prefixed with "lfp1." / "lfp2.". The global keys are
        applied last in both modes. String values are wrapped in double quotes
        (and converted to unicode) so they become string literals in the
        generated script.
        """
        self.replacementStrings = {}

        if self.matplotlibifyMode == 'default':
            self.replacementStrings.update(
                self.getReplacementStringsFor(self.logFilePlotReference))
        elif self.matplotlibifyMode == 'dual plot':
            self.replacementStrings.update(
                self.getReplacementStringsFor(self.logFilePlot1_,
                                              identifier="lfp1."))
            self.replacementStrings.update(
                self.getReplacementStringsFor(self.logFilePlot2_,
                                              identifier="lfp2."))
        self.replacementStrings.update(self.getGlobalReplacementStrings())

        for key, value in self.replacementStrings.items():
            logger.info("%s = %s", value, type(value))
            if isinstance(value, (str, unicode)):  #wrap strings in double quotes
                self.replacementStrings[key] = unicode(self.wrapInQuotes(value))

    def getReplacementStringsFor(self, logFilePlot, identifier=""):
        """Return the placeholder -> value mapping specific to one log file plot.

        ``identifier`` is inserted after the opening braces to make the keys
        unique per plot; it must include the trailing "." (e.g. "lfp1." gives
        keys like ``{{lfp1.mode}}``). The default "" gives ``{{mode}}``.
        """
        return {
            '{{%smode}}' % identifier: logFilePlot.mode,
            '{{%slogFile}}' % identifier: logFilePlot.logFile,
            '{{%sxAxis}}' % identifier: logFilePlot.xAxis,
            '{{%syAxis}}' % identifier: logFilePlot.yAxis,
            '{{%saggregateAxis}}' % identifier: logFilePlot.aggregateAxis,
            '{{%sseries}}' % identifier: logFilePlot.series,
            # "fiterYs" is a historical misspelling kept for existing
            # templates; the correctly spelled key is provided as well.
            '{{%sfiterYs}}' % identifier: logFilePlot.filterYs,
            '{{%sfilterYs}}' % identifier: logFilePlot.filterYs,
            '{{%sfilterMinYs}}' % identifier: logFilePlot.filterMinYs,
            '{{%sfilterMaxYs}}' % identifier: logFilePlot.filterMaxYs,
            '{{%sfilterXs}}' % identifier: logFilePlot.filterXs,
            '{{%sfilterMinXs}}' % identifier: logFilePlot.filterMinXs,
            '{{%sfilterMaxXs}}' % identifier: logFilePlot.filterMaxXs,
            '{{%sfilterNaN}}' % identifier: logFilePlot.filterNaN,
            '{{%sfilterSpecific}}' % identifier: logFilePlot.filterSpecific,
            '{{%sfilterSpecificString}}' % identifier:
            logFilePlot.filterSpecificString,
            '{{%sxLogScale}}' % identifier: logFilePlot.xLogScale,
            '{{%syLogScale}}' % identifier: logFilePlot.yLogScale,
            '{{%sinterpretAsTimeAxis}}' % identifier:
            logFilePlot.interpretAsTimeAxis
        }

    def getGlobalReplacementStrings(self, identifier=""):
        """Return the placeholder -> value mapping taken from this object's own
        traits (axis labels, legend replacements, limits, mode), shared by all
        plots in the generated script."""
        return {
            '{{%sxAxisLabel}}' % identifier: self.xAxisLabel,
            '{{%syAxisLabel}}' % identifier: self.yAxisLabel,
            '{{%sxAxisLabel2}}' % identifier: self.xAxisLabel2,
            '{{%syAxisLabel2}}' % identifier: self.yAxisLabel2,
            '{{%slegendReplacements}}' % identifier: self.legendReplacements,
            '{{%ssetXLimitsBool}}' % identifier: self.setXLimitsBool,
            '{{%ssetYLimitsBool}}' % identifier: self.setYLimitsBool,
            '{{%sxlimits}}' % identifier: (self.xMin, self.xMax),
            '{{%sylimits}}' % identifier: (self.yMin, self.yMax),
            '{{%smatplotlibifyMode}}' % identifier: self.matplotlibifyMode
        }

    def wrapInQuotes(self, string):
        """Return ``string`` surrounded by double quotes."""
        return '"%s"' % string

    def _xAxisLabel_default(self):
        return self.logFilePlotReference.xAxis

    def _yAxisLabel_default(self):
        return self.logFilePlotReference.yAxis

    def _legendReplacements_default(self):
        """Initially map every series name to itself (no replacement)."""
        return {name: name for name in self.logFilePlotReference.parseSeries()}

    def _xMin_default(self):
        return self.logFilePlotReference.firstPlot.x_axis.mapper.range.low

    def _xMax_default(self):
        return self.logFilePlotReference.firstPlot.x_axis.mapper.range.high

    def _yMin_default(self):
        return self.logFilePlotReference.firstPlot.y_axis.mapper.range.low

    def _yMax_default(self):
        return self.logFilePlotReference.firstPlot.y_axis.mapper.range.high

    def _generatedScriptLocation_default(self):
        """Default output path ``<templatesFolder>/<log>-<y>-vs-<x>.py``.

        If that file already exists, "-0", "-1", ... is appended until the
        name is free.
        """
        tail = os.path.basename(self.logFilePlotReference.logFile)
        matplotlibifyName = os.path.splitext(tail)[0] + "-%s-vs-%s" % (
            self._yAxisLabel_default(), self._xAxisLabel_default())
        baseName = os.path.join(self.templatesFolder, matplotlibifyName)
        filename = baseName + ".py"
        counter = 0
        # The original loop tested baseName + ".py" and never advanced the
        # counter, so it spun forever once that file existed.
        while os.path.exists(filename):
            filename = baseName + "-%s.py" % counter
            counter += 1
        return filename

    def replace_all(self, text, replacementDictionary):
        """Replace every placeholder key in ``text`` with ``str(value)``."""
        for placeholder, new in replacementDictionary.items():
            text = text.replace(placeholder, str(new))
        return text

    def _generatePlotScriptButton_fired(self):
        """Fill in the template and write it to ``generatedScriptLocation``."""
        logger.info("attempting to generate matplotlib script...")
        self.generateReplacementStrings()
        with open(self.templateFile, "rb") as template:
            text = self.replace_all(template.read(), self.replacementStrings)
        with open(self.generatedScriptLocation, "wb") as output:
            output.write(text)
        logger.info("succesfully generated matplotlib script at location %s " %
                    self.generatedScriptLocation)

    def _showPlotButton_fired(self):
        """Fill in the template and execute it directly in a fresh namespace."""
        logger.info("attempting to show matplotlib plot...")
        self.generateReplacementStrings()
        with open(self.templateFile, "rb") as template:
            text = self.replace_all(template.read(), self.replacementStrings)
        ns = {}
        # NOTE(review): executes arbitrary code from the chosen template file;
        # only safe with trusted templates.
        exec(text, ns)
        logger.info("exec completed succesfully...")

    def _matplotlibifyMode_changed(self):
        """change default template depending on whether or not this is a double axis plot """
        if self.matplotlibifyMode == "default":
            self.templateFile = os.path.join(
                self.templatesFolder, "matplotlibifyDefaultTemplate.py")
        elif self.matplotlibifyMode == "dual plot":
            self.templateFile = os.path.join(
                self.templatesFolder, "matplotlibifyDualPlotTemplate.py")
            # logFilePlot2 holds the tab-name key of the mapped trait; the
            # plot object itself is the mapped value logFilePlot2_.
            self.xAxisLabel2 = self.logFilePlot2_.xAxis
            self.yAxisLabel2 = self.logFilePlot2_.yAxis
Exemplo n.º 24
0
class LogFilePlotFitter(traits.HasTraits):
    """Fit the data sets of a log file plot with a predefined or custom model.

    Data sets are registered with ``setFitData`` and chosen through
    ``selectedDataSet``; fits run via ``model_.model.fit`` with lmfit
    parameters built from ``parametersList``, and the results are copied
    back onto those ``Parameter`` objects.
    """

    model = traits.Trait(
        "Gaussian", {
            "Linear": Model(fittingFunctions.linear),
            "Quadratic": Model(fittingFunctions.quadratic),
            "Gaussian": Model(fittingFunctions.gaussian),
            "lorentzian": Model(fittingFunctions.lorentzian),
            "parabola": Model(fittingFunctions.parabola),
            "exponential": Model(fittingFunctions.exponentialDecay),
            "sineWave": Model(fittingFunctions.sineWave),
            "sineWaveDecay1": Model(fittingFunctions.sineWaveDecay1),
            "sineWaveDecay2": Model(fittingFunctions.sineWaveDecay2),
            "sincSquared": Model(fittingFunctions.sincSquared),
            "sineSquared": Model(fittingFunctions.sineSquared),
            "sineSquaredDecay": Model(fittingFunctions.sineSquaredDecay),
            "custom": Model(custom)
        },
        desc="model selected for fitting the data"
    )  # mapped trait. so model --> string and model_ goes to Model object. see http://docs.enthought.com/traits/traits_user_manual/custom.html#mapped-traits
    parametersList = traits.List(
        Parameter, desc="list of parameters for fitting in chosen model")

    customCode = traits.Code(
        "def custom(x, param1, param2):\n\treturn param1*param2*x",
        desc="python code for a custom fitting function")
    customCodeCompileButton = traits.Button(
        "compile",
        desc=
        "defines the above function and assigns it to the custom model for fitting"
    )
    fitButton = traits.Button(
        "fit",
        desc="runs fit on selected data set using selected parameters and model"
    )
    usePreviousFitButton = traits.Button(
        "use previous fit",
        desc="use the fitted values as the initial guess for the next fit")
    guessButton = traits.Button(
        "guess",
        desc=
        "guess initial values from data using _guess function in library. If not defined button is disabled"
    )
    saveFitButton = traits.Button(
        "save fit",
        desc="writes fit parameters values and tolerances to a file")
    cycleAndFitButton = traits.Button(
        "cycle fit",
        desc=
        "fits using current initial parameters, saves fit, copies calculated values to initial guess and moves to next dataset in ordered dict"
    )
    # dataset name -> (xdata, ydata) tuple of arrays, e.g.
    # {"myData": (array([1,2,3]), array([1,2,3]))}. __init__ gives every
    # instance its own OrderedDict so fitters do not share data sets.
    dataSets = collections.OrderedDict()
    dataSetNames = traits.List(traits.String)
    selectedDataSet = traits.Enum(values="dataSetNames")
    modelFitResult = None
    logFilePlotReference = None
    modelFitMessage = traits.String("not yet fitted")
    isFitted = traits.Bool(False)
    maxFitTime = traits.Float(
        10.0, desc="maximum time fitting can last before abort")
    statisticsButton = traits.Button("stats")
    statisticsString = traits.String("statistics not calculated")
    plotPoints = traits.Int(200, label="Number of plot points")

    predefinedModelGroup = traitsui.VGroup(
        traitsui.Item("model", show_label=False),
        traitsui.Item("object.model_.definitionString",
                      style="readonly",
                      show_label=False,
                      visible_when="model!='custom'"))
    customFunctionGroup = traitsui.VGroup(traitsui.Item("customCode",
                                                        show_label=False),
                                          traitsui.Item(
                                              "customCodeCompileButton",
                                              show_label=False),
                                          visible_when="model=='custom'")
    modelGroup = traitsui.VGroup(predefinedModelGroup,
                                 customFunctionGroup,
                                 show_border=True)
    dataAndFittingGroup = traitsui.VGroup(
        traitsui.HGroup(
            traitsui.Item("selectedDataSet", label="dataset"),
            traitsui.Item("fitButton", show_label=False),
            traitsui.Item("usePreviousFitButton", show_label=False),
            traitsui.Item("guessButton",
                          show_label=False,
                          enabled_when="model_.guessFunction is not None")),
        traitsui.HGroup(traitsui.Item("cycleAndFitButton", show_label=False),
                        traitsui.Item("saveFitButton", show_label=False),
                        traitsui.Item("statisticsButton", show_label=False)),
        traitsui.Item("plotPoints"),
        traitsui.Item("statisticsString", style="readonly"),
        traitsui.Item("modelFitMessage", style="readonly"),
        show_border=True)
    variablesGroup = traitsui.VGroup(traitsui.Item(
        "parametersList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                     show_border=True,
                                     label="parameters")

    traits_view = traitsui.View(traitsui.Group(modelGroup,
                                               dataAndFittingGroup,
                                               variablesGroup,
                                               layout="split"),
                                resizable=True)

    def __init__(self, **traitsDict):
        super(LogFilePlotFitter, self).__init__(**traitsDict)
        if "dataSets" not in traitsDict:
            # the class-level OrderedDict would otherwise be shared (and
            # pruned by cleanValidNames) across every fitter instance
            self.dataSets = collections.OrderedDict()
        self._set_parametersList()

    def _set_parametersList(self):
        """sets the parameter list to the correct values given the current model """
        self.parametersList = [
            Parameter(name=parameterName, parameter=parameterObject)
            for (parameterName,
                 parameterObject) in self.model_.parameters.items()
        ]

    def _model_changed(self):
        """updates model and hence changes parameters appropriately"""
        self._set_parametersList()
        self._guessButton_fired(
        )  # will only guess if there is a valid guessing function and data

    def _customCodeCompileButton_fired(self):
        """Execute ``customCode`` and rebind the custom model to its ``custom``.

        The code runs in a copy of this module's globals, so names it imports
        are visible to the function it defines. If it does not define
        ``custom``, the module-level ``custom`` is used (as before).
        NOTE: executes arbitrary Python typed into the GUI by the user.
        """
        namespace = dict(globals())
        exec(self.customCode, namespace)
        self.model_.__init__(namespace["custom"])
        self._set_parametersList()

    def setFitData(self, name, xData, yData):
        """updates the dataSets dictionary """
        self.dataSets[name] = (xData, yData)

    def cleanValidNames(self, uniqueValidNames):
        """removes any elements from datasets dictionary that do not
        have a key in uniqueValidNames"""
        # iterate over a snapshot: deleting while iterating a dict view fails
        for dataSetName in list(self.dataSets.keys()):
            if dataSetName not in uniqueValidNames:
                del self.dataSets[dataSetName]

    def setValidNames(self):
        """sets list of valid choices for datasets """
        self.dataSetNames = list(self.dataSets.keys())

    def getParameters(self):
        """ returns the lmfit parameters object for the fit function"""
        # Add parameters by key instead of passing a dict to the constructor:
        # newer lmfit versions treat the first positional argument of
        # Parameters() as something else (asteval/usersyms), not initial items.
        params = lmfit.Parameters()
        for variable in self.parametersList:
            params[variable.name] = variable.parameter
        return params

    def _setCalculatedValues(self, modelFitResult):
        """updates calculated values with calculated argument """
        parametersResult = modelFitResult.params
        for variable in self.parametersList:
            variable.calculatedValue = parametersResult[variable.name].value

    def _setCalculatedValuesErrors(self, modelFitResult):
        """copy the standard error lmfit reports for each fitted parameter
        into the stdevError attribute of the matching variable"""
        parametersResult = modelFitResult.params
        for variable in self.parametersList:
            variable.stdevError = parametersResult[variable.name].stderr

    def fit(self):
        """Fit the selected data set and update values, errors and the plot."""
        params = self.getParameters()
        x, y = self.dataSets[self.selectedDataSet]
        # getFitCallback(time.time()) can be passed as iter_cb to enforce maxFitTime
        self.modelFitResult = self.model_.model.fit(y, x=x, params=params)
        self._setCalculatedValues(
            self.modelFitResult)  #update fitting parameters final values
        self._setCalculatedValuesErrors(self.modelFitResult)
        self.modelFitMessage = self.modelFitResult.message
        if not self.modelFitResult.success:
            logger.error("failed to fit in LogFilePlotFitter")
        self.isFitted = True
        if self.logFilePlotReference is not None:
            self.logFilePlotReference.plotFit()

    def getFitCallback(self, startTime):
        """returns the callback function that is called at every iteration of fit to check if it
        has been running too long"""
        def fitCallback(params, iter, resid, *args, **kws):
            """return True (abort) once maxFitTime seconds have passed since startTime"""
            if time.time() - startTime > self.maxFitTime:
                return True

        return fitCallback

    def _fitButton_fired(self):
        self.fit()

    def _usePreviousFitButton_fired(self):
        """update the initial guess value with the fitted values of the parameter """
        for parameter in self.parametersList:
            parameter.initialValue = parameter.calculatedValue

    def _guessButton_fired(self):
        """calls _guess function and updates initial fit values accordingly """
        print("guess button clicked")
        if self.model_.guessFunction is None:
            message = "attempted to guess initial values but no guess function is defined. returning without changing initial values"
            print(message)
            logger.error(message)
            return
        if self.selectedDataSet not in self.dataSets:
            # reached from _model_changed before any data has been loaded
            logger.error("cannot guess initial values: no data set selected")
            return
        logger.info("attempting to guess initial values using %s" %
                    self.model_.guessFunction.__name__)
        xs, ys = self.dataSets[self.selectedDataSet]
        guessDictionary = self.model_.guessFunction(xs, ys)
        logger.debug("guess results = %s" % guessDictionary)
        print("guess results = %s" % guessDictionary)
        for parameter in self.parametersList:
            if parameter.name in guessDictionary:
                parameter.initialValue = guessDictionary[parameter.name]

    def _saveFitButton_fired(self):
        """Append the current fit to ``<folder>/<folder name>-<function>-fitSave.csv``.

        ``selectedDataSet`` is expected to look like "aaaa=1.31 bbbb=1.21";
        its names form the leading header columns and its values the leading
        row cells, followed by every parameter's fitted value and then its
        standard error. The header row is written only when the file is new.
        """
        if self.modelFitResult is None:
            logger.error("no fit result to save; run a fit first")
            return
        saveFolder = os.path.dirname(self.logFilePlotReference.logFile)
        parametersResult = self.modelFitResult.params
        logFileName = os.path.basename(saveFolder)
        functionName = self.model_.function.__name__
        saveFileName = os.path.join(
            saveFolder, logFileName + "-" + functionName + "-fitSave.csv")

        seriesPairs = [
            seriesString.split("=")
            for seriesString in self.selectedDataSet.split(" ")
        ]
        seriesColumnNames = [pair[0] for pair in seriesPairs]
        seriesValues = [pair[1] for pair in seriesPairs]
        header = (seriesColumnNames +
                  [variable.name for variable in self.parametersList] + [
                      variable.name + "-tolerance"
                      for variable in self.parametersList
                  ])
        row = (seriesValues + [
            parametersResult[variable.name].value
            for variable in self.parametersList
        ] + [
            parametersResult[variable.name].stderr
            for variable in self.parametersList
        ])

        writeHeader = not os.path.exists(saveFileName)
        with open(saveFileName, "ab+") as csvFile:
            writer = csv.writer(csvFile)
            if writeHeader:
                writer.writerow(header)
            writer.writerow(row)

    def _cycleAndFitButton_fired(self):
        """Fit, save, reuse the result as the next guess, then advance to the
        next data set (stays on the last one once reached)."""
        logger.info("cycle and fit button pressed")
        self._fitButton_fired()
        self._saveFitButton_fired()
        self._usePreviousFitButton_fired()
        names = list(self.dataSets.keys())
        nextIndex = names.index(self.selectedDataSet) + 1
        if nextIndex >= len(names):
            logger.info("cycle fit reached the last data set (%s)" %
                        self.selectedDataSet)
            return
        self.selectedDataSet = names[nextIndex]

    def _statisticsButton_fired(self):
        """Summarise the selected data set's y values in ``statisticsString``."""
        from scipy.stats import pearsonr
        xs, ys = self.dataSets[self.selectedDataSet]
        # numpy equivalents of the scipy.* aliases removed from newer SciPy
        mean = np.mean(ys)
        median = np.median(ys)
        std = np.std(ys)
        minimum = np.nanmin(ys)
        maximum = np.nanmax(ys)
        peakToPeak = maximum - minimum
        pearsonCorrelation = pearsonr(xs, ys)
        resultString = "mean=%G , median=%G stdev =%G\nmin=%G,max=%G, pk-pk=%G\nPearson Correlation=(%G,%G)\n(stdev/mean)=%G" % (
            mean, median, std, minimum, maximum, peakToPeak,
            pearsonCorrelation[0], pearsonCorrelation[1], std / mean)
        self.statisticsString = resultString

    def getFitData(self):
        """Evaluate the fit on ``plotPoints`` evenly spaced x values spanning
        the selected data set; returns (x, y) arrays."""
        dataX = self.dataSets[self.selectedDataSet][0]
        # resample x data
        dataX = np.linspace(min(dataX), max(dataX), self.plotPoints)
        dataY = self.modelFitResult.eval(x=dataX)
        return dataX, dataY
Exemplo n.º 25
0
class Fit(traits.HasTraits):

    name = traits.Str(desc="name of fit")
    function = traits.Str(desc="function we are fitting with all parameters")
    variablesList = traits.List(Parameter)
    calculatedParametersList = traits.List(CalculatedParameter)
    xs = None  # will be a scipy array
    ys = None  # will be a scipy array
    zs = None  # will be a scipy array
    performFitButton = traits.Button("Perform Fit")
    getInitialParametersButton = traits.Button("Guess Initial Values")
    usePreviousFitValuesButton = traits.Button("Use Previous Fit")
    drawRequestButton = traits.Button("Draw Fit")
    setSizeButton = traits.Button("Set Initial Size")
    chooseVariablesButtons = traits.Button("choose logged variables")
    logAllVariables = traits.Bool(True)
    logLibrarianButton = traits.Button("librarian")
    logLastFitButton = traits.Button("log current fit")
    removeLastFitButton = traits.Button("remove last fit")
    autoFitBool = traits.Bool(
        False,
        desc=
        "Automatically perform this Fit with current settings whenever a new image is loaded"
    )
    autoDrawBool = traits.Bool(
        True,
        desc=
        "Once a fit is complete update the drawing of the fit or draw the fit for the first time"
    )
    autoGuessBool = traits.Bool(False,
                                desc="Perform a new guess before fitting")
    autoSizeBool = traits.Bool(
        False,
        desc=
        "If TOF variable is read from latest XML and is equal to 0.11ms (or time set in Physics) then it will automatically update the physics sizex and sizey with the Sigma x and sigma y from the gaussian fit"
    )
    autoPreviousBool = traits.Bool(
        False,
        desc=
        "Whenever a fit is completed replace the guess values with the calculated values (useful for increasing speed of the next fit)"
    )
    logBool = traits.Bool(
        False, desc="Log the calculated and fitted values with a timestamp")
    logName = traits.String(
        desc="name of the scan - will be used in the folder name")
    logToNas = traits.Bool(
        True, desc="If true, log goes to Humphry-NAS instead of ursa")
    ##    logDirectory = os.path.join("\\\\ursa","AQOGroupFolder","Experiment Humphry","Data","eagleLogs")
    logDirectory = os.path.join("G:", os.sep, "Experiment Humphry", "Data",
                                "eagleLogs")
    logDirectoryNas = os.path.join("\\\\192.168.16.71", "Humphry", "Data",
                                   "eagleLogs")
    latestSequence = os.path.join("\\\\ursa", "AQOGroupFolder",
                                  "Experiment Humphry",
                                  "Experiment Control And Software",
                                  "currentSequence", "latestSequence.xml")
    conditionalFitBool = traits.Bool(
        False,
        desc=
        "If true, fit is only executed, if current sequence contains matching variable 'eagleID'"
    )
    conditionalFitID = traits.Int(0)

    logFile = traits.File(desc="file path of logFile")

    logAnalyserBool = traits.Bool(
        False, desc="only use log analyser script when True")
    logAnalysers = [
    ]  #list containing full paths to each logAnalyser file to run
    logAnalyserDisplayString = traits.String(
        desc=
        "comma separated read only string that is a list of all logAnalyser python scripts to run. Use button to choose files"
    )
    logAnalyserSelectButton = traits.Button("sel. analyser",
                                            image='@icons:function_node',
                                            style="toolbar")
    xmlLogVariables = []
    imageInspectorReference = None  #will be a reference to the image inspector
    fitting = traits.Bool(False)  #true when performing fit
    fitted = traits.Bool(
        False)  #true when current data displayed has been fitted
    fitSubSpace = traits.Bool(True)
    drawSubSpace = traits.Bool(True)
    startX = traits.Int(230)
    startY = traits.Int(230)
    endX = traits.Int(550)
    endY = traits.Int(430)
    fittingStatus = traits.Str()
    fitThread = None
    fitTimeLimit = traits.Float(
        10.0,
        desc=
        "Time limit in seconds for fitting function. Only has an effect when fitTimeLimitBool is True"
    )
    fitTimeLimitBool = traits.Bool(
        True,
        desc=
        "If True then fitting functions will be limited to time limit defined by fitTimeLimit "
    )
    physics = traits.Instance(
        physicsProperties.physicsProperties.PhysicsProperties)
    #status strings
    notFittedForCurrentStatus = "Not Fitted for Current Image"
    fittedForCurrentImageStatus = "Fit Complete for Current Image"
    currentlyFittingStatus = "Currently Fitting..."
    failedFitStatus = "Failed to finish fit. See logger"
    timeExceededStatus = "Fit exceeded user time limit"

    lmfitModel = traits.Instance(
        lmfit.Model
    )  #reference to the lmfit model  must be initialised in subclass
    # Plain class attribute (not a trait); read by mostRecentModelFitReport().
    mostRecentModelResult = None  # updated to the most recent ModelResult object from lmfit when a fit thread is performed

    # ---- TraitsUI layout ----------------------------------------------------
    # The groups below only describe layout; they are assembled into
    # traits_view at the end of this section.

    # Region-of-interest controls: the start/end corner fields are only
    # visible while fitSubSpace is enabled.
    fitSubSpaceGroup = traitsui.VGroup(
        traitsui.HGroup(
            traitsui.Item("fitSubSpace", label="Fit Sub Space",
                          resizable=True), traitsui.Item("drawSubSpace")),
        traitsui.VGroup(traitsui.HGroup(
            traitsui.Item("startX", resizable=True),
            traitsui.Item("startY", resizable=True)),
                        traitsui.HGroup(traitsui.Item("endX", resizable=True),
                                        traitsui.Item("endY", resizable=True)),
                        visible_when="fitSubSpace"),
        label="Fit Sub Space",
        show_border=True)

    # Read-only fit name and function, followed by the sub-space controls.
    generalGroup = traitsui.VGroup(traitsui.Item("name",
                                                 label="Fit Name",
                                                 style="readonly",
                                                 resizable=True),
                                   traitsui.Item("function",
                                                 label="Fit Function",
                                                 style="readonly",
                                                 resizable=True),
                                   fitSubSpaceGroup,
                                   label="Fit",
                                   show_border=True)

    # One custom-style editor per entry of variablesList (the fit parameters).
    variablesGroup = traitsui.VGroup(traitsui.Item(
        "variablesList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                     show_border=True,
                                     label="parameters")

    # One custom-style editor per entry of calculatedParametersList.
    derivedGroup = traitsui.VGroup(traitsui.Item(
        "calculatedParametersList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                   show_border=True,
                                   label="derived values")

    # Action buttons, each paired with its "auto ..." toggle. The left column
    # holds fit / guess / previous-values; the right column holds draw / size /
    # conditional fit.
    buttons = traitsui.HGroup(
        traitsui.VGroup(traitsui.HGroup(
            traitsui.Item("autoFitBool", label="Auto fit?", resizable=True),
            traitsui.Item("performFitButton", show_label=False,
                          resizable=True)),
                        traitsui.HGroup(
                            traitsui.Item("autoGuessBool",
                                          label="Auto guess?",
                                          resizable=True),
                            traitsui.Item("getInitialParametersButton",
                                          show_label=False,
                                          resizable=True)),
                        traitsui.HGroup(
                            traitsui.Item("autoPreviousBool",
                                          label="Auto previous?",
                                          resizable=True),
                            traitsui.Item("usePreviousFitValuesButton",
                                          show_label=False,
                                          resizable=True)),
                        show_border=True),
        traitsui.VGroup(traitsui.HGroup(
            traitsui.Item("autoDrawBool", label="Auto draw?", resizable=True),
            traitsui.Item("drawRequestButton",
                          show_label=False,
                          resizable=True)),
                        traitsui.HGroup(
                            traitsui.Item("autoSizeBool",
                                          label="Auto size?",
                                          resizable=True),
                            traitsui.Item("setSizeButton",
                                          show_label=False,
                                          resizable=True)),
                        traitsui.HGroup(
                            traitsui.Item("conditionalFitBool",
                                          label="Conditional Fit?",
                                          resizable=True),
                            traitsui.Item("conditionalFitID",
                                          label="Fit ID",
                                          resizable=True)),
                        show_border=True))

    # Logging controls. The variable chooser is disabled while every variable
    # is logged anyway (logAllVariables).
    logGroup = traitsui.VGroup(
        traitsui.HGroup(
            traitsui.Item("logBool", resizable=True),
            traitsui.Item("logAllVariables", resizable=True),
            traitsui.Item("chooseVariablesButtons",
                          show_label=False,
                          resizable=True,
                          enabled_when="not logAllVariables")),
        traitsui.HGroup(traitsui.Item("logName", resizable=True)),
        traitsui.HGroup(traitsui.Item("logToNas", resizable=True)),  # write under logDirectoryNas instead of logDirectory (see _log_fit)
        traitsui.HGroup(
            traitsui.Item("removeLastFitButton",
                          show_label=False,
                          resizable=True),
            traitsui.Item("logLastFitButton", show_label=False,
                          resizable=True)),
        traitsui.HGroup(
            traitsui.Item("logAnalyserBool", label="analyser?",
                          resizable=True),
            traitsui.Item("logAnalyserDisplayString",
                          show_label=False,
                          style="readonly",
                          resizable=True),
            traitsui.Item("logAnalyserSelectButton",
                          show_label=False,
                          resizable=True)),
        label="Logging",
        show_border=True)

    # Read-only fit status on top of the logging and action-button groups.
    actionsGroup = traitsui.VGroup(traitsui.Item("fittingStatus",
                                                 style="readonly",
                                                 resizable=True),
                                   logGroup,
                                   buttons,
                                   label="Fit Actions",
                                   show_border=True)
    # Complete view, displayed as a TraitsUI "subpanel".
    traits_view = traitsui.View(traitsui.VGroup(generalGroup, variablesGroup,
                                                derivedGroup, actionsGroup),
                                kind="subpanel")

    def __init__(self, **traitsDict):
        """Initialise the traits, wrap ``fitFunc`` in an ``lmfit.Model`` and
        restore the log-analyser settings stored in the JSON config file."""
        super(Fit, self).__init__(**traitsDict)
        self.lmfitModel = lmfit.Model(self.fitFunc)

        # Persisted settings: only the keys that are present are applied.
        with open(configFile, 'r') as configHandle:
            storedSettings = json.load(configHandle)
        if 'logAnalysers' in storedSettings:
            self.logAnalysers = storedSettings['logAnalysers']
            analyserNames = [os.path.split(analyserPath)[1]
                             for analyserPath in self.logAnalysers]
            self.logAnalyserDisplayString = str(analyserNames)
        if 'logAnalyserBool' in storedSettings:
            self.logAnalyserBool = storedSettings['logAnalyserBool']

    def _set_xs(self, xs):
        self.xs = xs

    def _set_ys(self, ys):
        self.ys = ys

    def _set_zs(self, zs):
        self.zs = zs

    def _fittingStatus_default(self):
        return self.notFittedForCurrentStatus

    def _getInitialValues(self):
        """returns ordered list of initial values from variables List """
        return [_.initialValue for _ in self.variablesList]

    def _getParameters(self):
        """Build an ``lmfit.Parameters`` collection from ``variablesList``.

        Each entry contributes its ``parameter`` attribute under its ``name``,
        in ``variablesList`` order.

        The parameters are inserted one by one instead of handing a dict to the
        constructor. Since lmfit 0.9 the first positional argument of
        ``Parameters()`` is ``asteval`` (later ``usersyms``), so a dict passed
        there is not added as parameters.
        """
        parameters = lmfit.Parameters()
        for variable in self.variablesList:
            # assumes variable.parameter is an lmfit.Parameter; lmfit rejects
            # any other value on item assignment
            parameters[variable.name] = variable.parameter
        return parameters

    def _getCalculatedValues(self):
        """returns ordered list of fitted values from variables List """
        return [_.calculatedValue for _ in self.variablesList]

    def _intelligentInitialValues(self):
        """If possible we can auto set the initial parameters to intelligent guesses user can always overwrite them """
        self._setInitialValues(self._getIntelligentInitialValues())

    def _get_subSpaceArrays(self):
        """returns the arrays of the selected sub space. If subspace is not
        activated then returns the full arrays"""
        if self.fitSubSpace:
            xs = self.xs[self.startX:self.endX]
            ys = self.ys[self.startY:self.endY]
            logger.info("xs array sliced length %s " % (xs.shape))
            logger.info("ys array sliced length %s  " % (ys.shape))
            zs = self.zs[self.startY:self.endY, self.startX:self.endX]
            logger.info("zs sub space array %s,%s " % (zs.shape))

            return xs, ys, zs
        else:
            return self.xs, self.ys, self.zs

    def _getIntelligentInitialValues(self):
        """Return automatic initial-value guesses for ``variablesList``.

        Placeholder meant to be overridden by subclasses (the user can always
        overwrite the guesses). This base version only logs a warning and
        returns ``None``.
        """
        logger.warning("Dummy function should not be called directly")
        return None

    def fitFunc(self, data, *p):
        """Model function wrapped by ``lmfit.Model`` in ``__init__``.

        Placeholder meant to be overridden by subclasses. This base version
        only logs an error and returns ``None``.
        """
        logger.error("Dummy function should not be called directly")
        return None

    def _setCalculatedValues(self, modelFitResult):
        """updates calculated values with calculated argument """
        parametersResult = modelFitResult.params
        for variable in self.variablesList:
            variable.calculatedValue = parametersResult[variable.name].value

    def _setCalculatedValuesErrors(self, modelFitResult):
        """given the covariance matrix returned by scipy optimize fit
        convert this into stdeviation errors for parameters list and updated
        the stdevError attribute of variables"""
        parametersResult = modelFitResult.params
        for variable in self.variablesList:
            variable.stdevError = parametersResult[variable.name].stderr

    def _setInitialValues(self, guesses):
        """updates calculated values with calculated argument """
        c = 0
        for variable in self.variablesList:
            variable.initialValue = guesses[c]
            c += 1

    def deriveCalculatedParameters(self):
        """Compute the derived values, but only once a fit has been done.

        General checks live here; the actual work is delegated to the
        subclass hook ``_deriveCalculatedParameters``.
        """
        if not self.fitted:
            return
        self._deriveCalculatedParameters()

    def _deriveCalculatedParameters(self):
        """Subclass hook: update every entry of ``calculatedParametersList``.

        This base version only logs an error and returns ``None``.
        """
        logger.error("Should only be called by subclass")
        return None

    def _fit_routine(self):
        """Start a fit in a new ``FitThread`` unless one is still running.

        If the previous fit thread is still alive, a warning is logged and no
        new thread is started. Otherwise a new thread gets a reference back to
        this object, is marked as the current fit thread and is started, and
        ``fittingStatus`` switches to ``currentlyFittingStatus``.
        """
        self.fitting = True
        # is_alive() exists on threading.Thread in Python 2.6+ and 3; the
        # camelCase isAlive() alias was removed in Python 3.9.
        # NOTE(review): assumes FitThread derives from threading.Thread.
        if self.fitThread and self.fitThread.is_alive():
            logger.warning(
                "Fitting is already running. You should wait till this fit has timed out before a new thread is started...."
            )
            return
        self.fitThread = FitThread()  # new fitting thread
        self.fitThread.fitReference = self
        # A fit can own several threads over time, but only the latest one is
        # allowed to affect the GUI.
        self.fitThread.isCurrentFitThread = True
        self.fitThread.start()
        self.fittingStatus = self.currentlyFittingStatus

    def _perform_fit(self):
        """Fit ``lmfitModel`` to the current image data (or its sub-space).

        lmfit receives the independent variable as one ``positions`` array and
        the image flattened. A 2-D image over grids xs/ys is unrolled as::

            xs: 0 1 2 0 1 2 0
            ys: 0 0 0 1 1 1 2
            zs: 1 5 6 1 9 8 2

        i.e. xs is tiled, ys is repeated and zs is ravelled row by row. When
        ``fitTimeLimitBool`` is set, an iteration callback aborts fits that run
        longer than ``fitTimeLimit`` seconds (see ``getFitCallback``).

        Returns:
            the lmfit ``ModelResult``, or ``([], [])`` if xs, ys or zs has not
            been set (kept for compatibility with existing callers).
        """
        # NumPy directly: the scipy.array/tile/repeat/ravel re-exports are
        # deprecated since SciPy 1.4 and removed in later releases.
        import numpy

        if self.xs is None or self.ys is None or self.zs is None:
            logger.warning(
                "attempted to fit data but had no data inside the Fit object. set xs,ys,zs first"
            )
            return ([], [])
        params = self._getParameters()
        # Sliced to the sub space when fitSubSpace is on, full arrays otherwise.
        xs, ys, zs = self._get_subSpaceArrays()
        positions = numpy.array([
            numpy.tile(xs, len(ys)),
            numpy.repeat(ys, len(xs))
        ])
        fitKwargs = dict(positions=positions, params=params)
        if self.fitTimeLimitBool:
            fitKwargs["iter_cb"] = self.getFitCallback(time.time())
        return self.lmfitModel.fit(numpy.ravel(zs), **fitKwargs)

    def getFitCallback(self, startTime):
        """Return an lmfit ``iter_cb`` enforcing the user's time limit.

        The returned callback raises ``FitException`` on the first iteration
        that happens more than ``self.fitTimeLimit`` seconds after
        ``startTime``. Otherwise it returns ``None`` and the fit continues.
        """
        def fitCallback(params, iter, resid, *args, **kws):
            """Abort the fit once the time budget is spent."""
            elapsed = time.time() - startTime
            if elapsed > self.fitTimeLimit:
                raise FitException("Fit time exceeded user limit")

        return fitCallback

    def _performFitButton_fired(self):
        self._fit_routine()

    def _getInitialParametersButton_fired(self):
        self._intelligentInitialValues()

    def _drawRequestButton_fired(self):
        """tells the imageInspector to try and draw this fit as an overlay contour plot"""
        self.imageInspectorReference.addFitPlot(self)

    def _setSizeButton_fired(self):
        """use the sigmaX and sigmaY from the current fit to overwrite the 
        inTrapSizeX and inTrapSizeY parameters in the Physics Instance"""
        self.physics.inTrapSizeX = abs(self.sigmax.calculatedValue)
        self.physics.inTrapSizeY = abs(self.sigmay.calculatedValue)

    def _getFitFuncData(self):
        """if data has been fitted, this returns the zs data for the ideal
        fitted function using the calculated paramters"""
        positions = [
            scipy.tile(self.xs, len(self.ys)),
            scipy.repeat(self.ys, len(self.xs))
        ]  #for creating data necessary for gauss2D function
        zsravelled = self.fitFunc(positions, *self._getCalculatedValues())
        return zsravelled.reshape(self.zs.shape)

    def _logAnalyserSelectButton_fired(self):
        """Let the user pick one or more analyser scripts in a file dialog.

        ``logAnalysers`` is replaced only when the dialog is confirmed. The
        displayed list of file names is refreshed either way.
        """
        dialog = FileDialog(action="open files")
        dialog.open()
        if dialog.return_code == pyface.constant.OK:
            self.logAnalysers = dialog.paths
            logger.info("selected log analysers: %s " % self.logAnalysers)
        fileNames = [os.path.basename(path) for path in self.logAnalysers]
        self.logAnalyserDisplayString = str(fileNames)

    def runSingleAnalyser(self, module):
        """Import (and re-load) ``logAnalysers.<module>`` and call its ``run``.

        The module is reloaded on every call, so edits to the analyser script
        take effect without a restart. ``run`` receives
        ``[xs, ys, zs, rawImage]``, the physics variables, ``variablesList``
        and ``calculatedParametersList``. ``rawImage`` is ``None`` when the
        image inspector has no ``rawImage`` attribute; it can differ from zs
        when an image processor is in use.

        Returns:
            whatever the analyser's ``run`` returns (callers unpack it as
            ``(columnNames, values)``).
        """
        import importlib
        try:
            from importlib import reload as reloadModule  # Python 3.4+
        except ImportError:
            reloadModule = reload  # Python 2 builtin

        # import_module instead of exec(): no source code is built from a
        # file name, and it also works on Python 3, where exec() cannot bind
        # a function local.
        currentAnalyser = importlib.import_module("logAnalysers.%s" % module)
        currentAnalyser = reloadModule(currentAnalyser)  # pick up edits
        rawImage = getattr(self.imageInspectorReference, "rawImage", None)
        return currentAnalyser.run([self.xs, self.ys, self.zs, rawImage],
                                   self.physics.variables, self.variablesList,
                                   self.calculatedParametersList)

    def runAnalyser(self):
        """Run every script in ``logAnalysers`` and concatenate their output.

        All paths must exist. If one is missing, an error is logged and
        ``None`` is returned before any analyser runs. Paths without a ``.py``
        extension are logged and skipped. Each remaining script is run via
        ``runSingleAnalyser``, whose ``run`` receives the image data, the
        physics variables dictionary, the fitted parameters and the derived
        values.

        Returns:
            ``(columnNames, values)`` merged over all analysers, or ``None``
            when an analyser file is missing.
        """
        missing = [path for path in self.logAnalysers
                   if not os.path.isfile(path)]
        if missing:
            logger.error(
                "attempted to runAnalyser but could not find the logAnalyser File: %s"
                % missing[0])
            return None

        allColumns, allValues = [], []
        for analyserPath in self.logAnalysers:
            moduleName, extension = os.path.splitext(
                os.path.basename(analyserPath))
            if extension == ".py":
                columns, values = self.runSingleAnalyser(moduleName)
                allColumns.extend(columns)
                allValues.extend(values)
            else:
                logger.error("file was not a python module. %s" % analyserPath)
        return allColumns, allValues

    def mostRecentModelFitReport(self):
        """Return lmfit's text report for ``mostRecentModelResult``, followed
        by a blank line, or ``"No fit performed"`` if there is none yet."""
        if self.mostRecentModelResult is None:
            return "No fit performed"
        report = lmfit.fit_report(self.mostRecentModelResult)
        return report + "\n\n"

    def getCalculatedParameters(self):
        """Return ``(name, value)`` pairs for every derived parameter, handy
        for printing."""
        pairs = []
        for derived in self.calculatedParametersList:
            pairs.append((derived.name, derived.value))
        return pairs

    def _log_fit(self):
        """Append the results of the current fit as one row of the log CSV.

        Steps, in order:

        * pick the log folder (``logDirectoryNas`` when ``logToNas`` is True,
          otherwise ``logDirectory``) and create it, its ``images`` folder,
          ``comments.txt`` and ``copyOfInitialSequence.ctr`` as needed; in
          "process raw image" mode also record the processor options once;
        * copy the current image, plus the ``_X1.tif`` partner of an
          ``_X2.tif`` file, into ``images``. An IOError is logged and ignored;
          any other error is logged and re-raised;
        * optionally run the log analysers for extra columns;
        * write the header row if the CSV does not exist yet, then append the
          data row.
        """
        if self.logName == "":
            logger.warning("no log file defined. Will not log")
            return
        # generate folders if they don't exist
        if self.logToNas is not True:
            logFolder = os.path.join(self.logDirectory, self.logName)
        else:
            logFolder = os.path.join(self.logDirectoryNas, self.logName)
        if not os.path.isdir(logFolder):
            logger.info("creating a new log folder %s" % logFolder)
            os.mkdir(logFolder)

        imagesFolder = os.path.join(logFolder, "images")
        if not os.path.isdir(imagesFolder):
            logger.info("creating a new images Folder %s" % imagesFolder)
            os.mkdir(imagesFolder)

        commentsFile = os.path.join(logFolder, "comments.txt")
        if not os.path.exists(commentsFile):
            logger.info("creating a comments file %s" % commentsFile)
            # create an (empty) comments file in every folder
            open(commentsFile, "a+").close()

        firstSequenceCopy = os.path.join(logFolder,
                                         "copyOfInitialSequence.ctr")
        if not os.path.exists(firstSequenceCopy):
            logger.info("creating a copy of the first sequence %s -> %s" %
                        (self.latestSequence, firstSequenceCopy))
            shutil.copy(self.latestSequence, firstSequenceCopy)

        # When a processor is used, record its details in the log folder.
        # TODO: also save the processor script itself (usedProcessor.py).
        if self.imageInspectorReference.model.imageMode == "process raw image":
            processorParamtersFile = os.path.join(logFolder,
                                                  "processorOptions.txt")
            if not os.path.exists(processorParamtersFile):
                with open(processorParamtersFile, "a+") as processorParamsFile:
                    inspectorModel = self.imageInspectorReference.model
                    string = str(inspectorModel.chosenProcessor) + "\n"
                    string += str(inspectorModel.processor.optionsDict)
                    processorParamsFile.write(string)

        logger.debug("finished all checks on log folder")
        # copy current image
        try:
            selectedFile = self.imageInspectorReference.selectedFile
            shutil.copy(selectedFile, imagesFolder)
            if selectedFile.endswith("_X2.tif"):
                shutil.copy(selectedFile.replace("_X2.tif", "_X1.tif"),
                            imagesFolder)
        except IOError as e:
            # Format the exception itself: BaseException.message is empty for
            # IOError(errno, strerror) on Python 2 and gone on Python 3.
            logger.error("Could not copy image. Got IOError: %s " % e)
        except Exception as e:
            logger.error("Could not copy image. Got %s: %s " % (type(e), e))
            raise  # bare raise keeps the original traceback
        logger.info("copying current image")
        self.logFile = os.path.join(logFolder, self.logName + ".csv")

        # analyser logic
        analyserColumnNames = []
        analyserValues = []
        if self.logAnalyserBool:  # run the analyser script as requested
            logger.info(
                "log analyser bool enabled... will attempt to run analyser script"
            )
            analyserResult = self.runAnalyser()
            if analyserResult is None:
                # The analyser failed (runAnalyser already logged why), so
                # continue as if nothing happened. This is checked before
                # logging because list(None) raises TypeError.
                logger.info("analyser result = %s " % analyserResult)
            else:
                logger.info("analyser result = %s " % list(analyserResult))
                analyserColumnNames, analyserValues = analyserResult

        if not os.path.exists(self.logFile):
            # first entry of this log: write the header row
            variables = [_.name for _ in self.variablesList]
            calculated = [_.name for _ in self.calculatedParametersList]
            times = ["datetime", "epoch seconds"]
            info = ["img file name"]
            xmlVariables = self.getXmlVariables()
            columnNames = times + info + variables + calculated + xmlVariables + analyserColumnNames
            with open(
                    self.logFile, 'ab+'
            ) as logFile:  # note use of binary file so that windows doesn't write too many /r
                writer = csv.writer(logFile)
                writer.writerow(columnNames)
        # column names already exist so append the data row
        logger.debug("copying current image")
        variables = [_.calculatedValue for _ in self.variablesList]
        calculated = [_.value for _ in self.calculatedParametersList]
        now = time.time()  # epoch seconds
        timeTuple = time.localtime(now)
        date = time.strftime("%Y-%m-%dT%H:%M:%S", timeTuple)
        times = [date, now]
        info = [self.imageInspectorReference.selectedFile]
        xmlVariables = [
            self.physics.variables[varName]
            for varName in self.getXmlVariables()
        ]
        data = times + info + variables + calculated + xmlVariables + analyserValues

        with open(self.logFile, 'ab+') as logFile:
            writer = csv.writer(logFile)
            writer.writerow(data)

    def _logLastFitButton_fired(self):
        """logs the fit. User can use this for non automated logging. i.e. log
        particular fits"""
        self._log_fit()

    def _removeLastFitButton_fired(self):
        """removes the last line in the log file """
        logFolder = os.path.join(self.logDirectory, self.logName)
        self.logFile = os.path.join(logFolder, self.logName + ".csv")
        if self.logFile == "":
            logger.warning("no log file defined. Will not log")
            return
        if not os.path.exists(self.logFile):
            logger.error(
                "cant remove a line from a log file that doesn't exist")
        with open(self.logFile, 'r') as logFile:
            lines = logFile.readlines()
        with open(self.logFile, 'wb') as logFile:
            logFile.writelines(lines[:-1])

    def saveLastFit(self):
        """Write the result of the last fit to ``<camera>-<species>-lastFit.csv``.

        The file gets a ``time`` row followed by one ``name, value`` row per
        fitted variable and per derived parameter. It can be used for live
        analysis or for generating sequences from the last fit. This is best
        effort: any failure is logged and not propagated.
        """
        try:
            with open(
                    self.imageInspectorReference.cameraModel + "-" +
                    self.physics.species + "-" + "lastFit.csv",
                    "wb") as lastFitFile:
                writer = csv.writer(lastFitFile)
                writer.writerow(["time", time.time()])
                for variable in self.variablesList:
                    writer.writerow([variable.name, variable.calculatedValue])
                for variable in self.calculatedParametersList:
                    writer.writerow([variable.name, variable.value])
        except Exception as e:
            # Format the exception itself: e.message does not exist on
            # Python 3 (the handler would then raise) and is often empty on 2.
            logger.error("failed to save last fit to text file. message %s " %
                         e)

    def _chooseVariablesButtons_fired(self):
        self.xmlLogVariables = self.chooseVariables()

    def _usePreviousFitValuesButton_fired(self):
        """Button handler: use the last fitted values as the new initial
        guesses."""
        logger.info(
            "use previous fit values button fired. loading previous initial values"
        )
        previousValues = self._getCalculatedValues()
        self._setInitialValues(previousValues)

    def getXmlVariables(self):
        """Return the names of the physics variables to log.

        This is every variable, sorted by name, when ``logAllVariables`` is
        set; otherwise the user's ``xmlLogVariables`` selection.
        """
        if not self.logAllVariables:
            return self.xmlLogVariables
        return sorted(self.physics.variables)

    def chooseVariables(self):
        """Open a modal check-list dialog of the physics variable names.

        The names already in ``xmlLogVariables`` are pre-selected. If any of
        them no longer exists, an error is logged and the selection starts
        empty.

        Returns:
            the list of variable names matching the indices ticked in the
            dialog.
        """
        # sorted() also works where keys() returns a view without .sort()
        # (Python 3); the result is identical on Python 2.
        columns = sorted(self.physics.variables.keys())
        # A concrete list of (index, label) pairs; a lazy zip object would be
        # consumed on first use under Python 3.
        values = list(enumerate(columns))

        checklist_group = traitsui.Group(
            '10',  # insert vertical space
            traitsui.Label('Select the additional variables you wish to log'),
            traitsui.UItem('columns',
                           style='custom',
                           editor=traitsui.CheckListEditor(values=values,
                                                           cols=6)),
            traitsui.UItem('selectAllButton'))

        traits_view = traitsui.View(checklist_group,
                                    title='CheckListEditor',
                                    buttons=['OK'],
                                    resizable=True,
                                    kind='livemodal')

        col = ColumnEditor(numberOfColumns=len(columns))
        try:
            col.columns = [
                columns.index(varName) for varName in self.xmlLogVariables
            ]
        except Exception as e:
            logger.error(
                "couldn't selected correct variable names. Returning empty selection"
            )
            # format the exception itself; e.message is gone on Python 3
            logger.error("%s " % e)
            col.columns = []
        col.edit_traits(view=traits_view)
        logger.debug("value of columns selected = %s ", col.columns)
        logger.debug("value of columns selected = %s ",
                     [columns[i] for i in col.columns])
        return [columns[i] for i in col.columns]

    def _logLibrarianButton_fired(self):
        """Open the log librarian on the folder named by ``logName``.

        The folder is resolved the same way ``_log_fit`` writes it
        (``logDirectoryNas`` when ``logToNas`` is True, else
        ``logDirectory``). A missing folder is logged and nothing is opened.
        """
        if self.logToNas is not True:
            logFolder = os.path.join(self.logDirectory, self.logName)
        else:
            logFolder = os.path.join(self.logDirectoryNas, self.logName)
        if not os.path.isdir(logFolder):
            logger.error(
                "cant open librarian on a log that doesn't exist.... Could not find %s"
                % logFolder)
            return
        librarian = plotObjects.logLibrarian.Librarian(logFolder=logFolder)
        librarian.edit_traits()

    def _drawSubSpace_changed(self):
        newVisibility = self.drawSubSpace and self.fitSubSpace
        self.imageInspectorReference.ROIPolyPlot.visible = newVisibility
        self._startX_changed()  # update ROI data for plot

    def _fitSubSpace_changed(self):
        self._drawSubSpace_changed()

    def _startX_changed(self):
        if self.imageInspectorReference is None:
            return  # not yet initialized yet
        self.imageInspectorReference.ROIPolyPlot.index = chaco.ArrayDataSource(
            [self.startX, self.endX, self.endX, self.startX])
        self.imageInspectorReference.ROIPolyPlot.value = chaco.ArrayDataSource(
            [self.startY, self.startY, self.endY, self.endY])

    _endX_changed = _startX_changed
    _startY_changed = _startX_changed
    _endY_changed = _startX_changed
Exemplo n.º 26
0
class HCFT(tr.HasStrictTraits):
    """High-Cycle Fatigue Tool.

    Traits GUI that converts a delimited text (CSV) file into one ``.npy``
    file per column, derives filtered and creep (per-cycle max/min) ``.npy``
    files from them and plots the results into ``figure``. Long-running
    actions run on background threads so the GUI does not freeze; progress
    and errors are reported through the ``log`` trait.
    """
    # =========================================================================
    # Traits definitions
    # =========================================================================
    # Assigning the view
    traits_view = hcft_window

    # CSV import
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    file_path = tr.File
    open_file_button = tr.Button('Open file')
    columns_headers = tr.List
    npy_folder_path = tr.Str
    file_name = tr.Str

    # CSV processing
    take_time_from_time_column = tr.Bool(True)
    records_per_second = tr.Float(100)
    time_column = tr.Enum(values='columns_headers')
    skip_first_rows = tr.Range(low=1, high=10**9, value=3, mode='spinner')
    add_columns_average = tr.Button
    columns_to_be_averaged = tr.List
    parse_csv_to_npy = tr.Button

    # Plotting
    x_axis = tr.Enum(values='columns_headers')
    y_axis = tr.Enum(values='columns_headers')
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    add_plot = tr.Button
    apply_filters = tr.Bool
    plot_settings_btn = tr.Button
    plot_settings = PlotSettings()
    plot_settings_active = tr.Bool
    normalize_cycles = tr.Bool
    smooth = tr.Bool
    plot_every_nth_point = tr.Range(low=1, high=1000000, mode='spinner')
    old_peak_force_before_cycles = tr.Float
    peak_force_before_cycles = tr.Float
    add_creep_plot = tr.Button(desc='Creep plot of X axis array')
    clear_plot = tr.Button

    force_column = tr.Enum(values='columns_headers')
    window_length = tr.Range(low=1, high=10**9 - 1, value=31, mode='spinner')
    polynomial_order = tr.Range(low=1, high=10**9, value=2, mode='spinner')
    activate_ascending_branch_smoothing = tr.Bool(False, label='Activate')

    generate_filtered_and_creep_npy = tr.Button
    force_max = tr.Float(100)
    force_min = tr.Float(40)
    min_cycle_force_range = tr.Float(50)
    cutting_method = tr.Enum('Define min cycle range(force difference)',
                             'Define Max, Min')

    log = tr.Str('')
    clear_log = tr.Button

    # =========================================================================
    # Assigning default values
    # =========================================================================
    ax = tr.Any
    figure = tr.Instance(mpl.figure.Figure)

    def _figure_default(self):
        """Create the white, tight-layout figure and its single axes."""
        figure = mpl.figure.Figure(facecolor='white')
        figure.set_tight_layout(True)
        self.create_axes(figure)
        return figure

    def create_axes(self, figure):
        """Add one subplot to ``figure`` and store it in ``ax``."""
        self.ax = figure.add_subplot(1, 1, 1)

    # =========================================================================
    # Helpers
    # =========================================================================
    def _run_in_background(self, target):
        """Run ``target`` on a new thread so the GUI doesn't freeze."""
        Thread(target=target).start()

    # =========================================================================
    # File management
    # =========================================================================
    def _open_file_button_fired(self):
        """Ask for a CSV file and prepare everything needed to process it.

        Clears the previous state, reads the column headers (which populate
        the column/axis enums), creates an ``NPY`` folder next to the file
        and restores settings from a previously exported ``.json`` file if
        one exists. Cancelling the dialog leaves ``file_path`` unchanged.
        """
        try:
            self.reset()

            dialog = FileDialog(title='Select text file',
                                action='open',
                                default_path=self.file_path)
            # Only continue when the user actually picked a file
            if dialog.open() != OK:
                return
            self.file_path = dialog.path

            # Populating the headers fills the x-axis/y-axis enums automatically
            self.columns_headers = get_headers(self.file_path,
                                               decimal=self.decimal,
                                               delimiter=self.delimiter)

            # All generated files live in an NPY folder next to the CSV file
            dir_path = os.path.dirname(self.file_path)
            self.npy_folder_path = os.path.join(dir_path, 'NPY')
            os.makedirs(self.npy_folder_path, exist_ok=True)

            self.file_name = os.path.splitext(
                os.path.basename(self.file_path))[0]

            self.import_data_json()
        except Exception:
            self.log_exception()

    def _add_columns_average_fired(self):
        """Let the user pick columns whose average becomes an extra column.

        The selection is appended to ``columns_to_be_averaged``. A column
        named by ``get_suffix_for_columns_to_be_averaged`` is appended to
        ``columns_headers``; its data is written when the CSV is parsed.
        """
        try:
            columns_average = ColumnsAverage()
            for name in self.columns_headers:
                columns_average.columns.append(Column(column_name=name))

            # kind='modal' pauses execution until the window is closed
            columns_average.configure_traits(kind='modal')

            selected_columns = [column.column_name
                                for column in columns_average.columns
                                if column.selected]
            if selected_columns:
                self.columns_to_be_averaged.append(selected_columns)
                self.columns_headers.append(
                    self.get_suffix_for_columns_to_be_averaged(
                        selected_columns))
        except Exception:
            self.log_exception()

    def _parse_csv_to_npy_fired(self):
        self._run_in_background(self.parse_csv_to_npy_fired)

    def parse_csv_to_npy_fired(self):
        """Save every CSV column, and every averaged column, as an .npy file.

        The current settings are exported to the .json file first. CSV
        columns are read one at a time. Averaged columns are computed
        afterwards from the saved .npy files of their source columns.
        """
        try:
            self.print_custom('Parsing csv into npy files...')
            self.export_data_json()

            # Exporting npy arrays of original columns. Averaged columns are
            # appended at the end of columns_headers and are not in the CSV.
            n_original_columns = (len(self.columns_headers) -
                                  len(self.columns_to_be_averaged))
            for i in range(n_original_columns):
                column_name = self.columns_headers[i]
                # Opening the stream ourselves (instead of passing the path
                # to pd.read_csv) keeps paths containing chars like ü,ä
                # working; `with` closes the stream afterwards.
                with open(self.file_path, encoding='utf-8') as file_stream:
                    column_array = np.array(
                        pd.read_csv(file_stream,
                                    delimiter=self.delimiter,
                                    decimal=self.decimal,
                                    skiprows=self.skip_first_rows,
                                    usecols=[i]))

                # TODO detect column name before loading completely to skip
                # loading if the following condition applies
                if (column_name == self.time_column
                        and not self.take_time_from_time_column):
                    # Derive the time axis from the row count. np.arange with
                    # a float step can return one extra sample because of
                    # rounding; this always yields exactly one value per row.
                    column_array = (np.arange(len(column_array)) /
                                    self.records_per_second)

                np.save(self.get_npy_file_path(column_name), column_array)

            # Exporting npy arrays of averaged columns
            for columns_names in self.columns_to_be_averaged:
                summed = np.zeros(1)
                for column_name in columns_names:
                    summed = summed + np.load(
                        self.get_npy_file_path(column_name)).flatten()
                np.save(self.get_average_npy_file_path(columns_names),
                        summed / len(columns_names))

            self.print_custom('Finished parsing csv into npy files.')
        except Exception:
            self.log_exception()

    def get_npy_file_path(self, column_name):
        """Path of the raw .npy file for ``column_name``."""
        return os.path.join(self.npy_folder_path,
                            self.file_name + '_' + column_name + '.npy')

    def get_filtered_npy_file_path(self, column_name):
        """Path of the filtered .npy file for ``column_name``."""
        return os.path.join(
            self.npy_folder_path,
            self.file_name + '_' + column_name + '_filtered.npy')

    def get_max_npy_file_path(self, column_name):
        """Path of the per-cycle maxima .npy file for ``column_name``."""
        return os.path.join(self.npy_folder_path,
                            self.file_name + '_' + column_name + '_max.npy')

    def get_min_npy_file_path(self, column_name):
        """Path of the per-cycle minima .npy file for ``column_name``."""
        return os.path.join(self.npy_folder_path,
                            self.file_name + '_' + column_name + '_min.npy')

    def get_average_npy_file_path(self, columns_names):
        """Path of the .npy file holding the average of ``columns_names``."""
        avg_file_suffix = self.get_suffix_for_columns_to_be_averaged(
            columns_names)
        return os.path.join(self.npy_folder_path,
                            self.file_name + '_' + avg_file_suffix + '.npy')

    def get_suffix_for_columns_to_be_averaged(self, columns_names):
        """Name used for an averaged column, e.g. ``avg_a_b`` for [a, b]."""
        return 'avg_' + '_'.join(columns_names)

    def export_data_json(self):
        """Write the current processing/plotting settings to the .json file."""
        # Output data MUST use exactly the trait names as keys, because
        # import_data_json assigns them back with setattr
        output_data = {
            'take_time_from_time_column': self.take_time_from_time_column,
            'time_column': self.time_column,
            'records_per_second': self.records_per_second,
            'skip_first_rows': self.skip_first_rows,
            'columns_headers': self.columns_headers,
            'columns_to_be_averaged': self.columns_to_be_averaged,
            'x_axis': self.x_axis,
            'y_axis': self.y_axis,
            'x_axis_multiplier': self.x_axis_multiplier,
            'y_axis_multiplier': self.y_axis_multiplier,
            'force_column': self.force_column,
            'window_length': self.window_length,
            'polynomial_order': self.polynomial_order,
            'peak_force_before_cycles': self.peak_force_before_cycles,
            'cutting_method': self.cutting_method,
            'force_max': self.force_max,
            'force_min': self.force_min,
            'min_cycle_force_range': self.min_cycle_force_range
        }
        with open(self.get_json_file_path(), 'w',
                  encoding='utf-8') as outfile:
            json.dump(output_data, outfile, sort_keys=True, indent=4)
        self.print_custom('.json data file exported successfully.')

    def import_data_json(self):
        """Restore settings from the .json file, if it exists.

        Every key that matches a public attribute name is assigned to that
        attribute; other keys are ignored. Keys are applied in file order,
        which export_data_json writes sorted, so ``columns_headers`` is set
        before the enums that take their values from it.
        """
        json_path = self.get_json_file_path()
        if not os.path.isfile(json_path):
            return
        # dir(self) is used because vars(self) / self.__dict__ don't include
        # some trait variables like force_column = tr.Enum(values=...)
        public_names = {name for name in dir(self) if not name.startswith('_')}
        with open(json_path, encoding='utf-8') as infile:
            data_in = json.load(infile)
        for key, value in data_in.items():
            if key in public_names:
                setattr(self, key, value)
        self.print_custom('.json data file imported successfully.')

    def get_json_file_path(self):
        """Path of the settings .json file for the current CSV file."""
        return os.path.join(self.npy_folder_path, self.file_name + '.json')

    def _generate_filtered_and_creep_npy_fired(self):
        self._run_in_background(self.generate_filtered_and_creep_npy_fired)

    def generate_filtered_and_creep_npy_fired(self):
        """Create the filtered and creep (per-cycle max/min) .npy files.

        The force signal is split at the first sample whose magnitude
        exceeds ``|peak_force_before_cycles|``:
        - the ascending branch before that sample is kept in full;
        - of the cyclic rest, only the local maxima/minima are kept.
        The same indices are applied to the other columns.
        """
        try:
            self.export_data_json()
            if not self.npy_files_exist(
                    self.get_npy_file_path(self.force_column)):
                return
            self.print_custom('Generating filtered and creep files...')

            # 1- Export filtered force
            force = np.load(self.get_npy_file_path(
                self.force_column)).flatten()
            above_peak_indices = np.flatnonzero(
                np.abs(force) > abs(self.peak_force_before_cycles))
            if above_peak_indices.size == 0:
                self.print_custom(
                    'No force value exceeds peak_force_before_cycles; '
                    'filtered and creep files were not generated.')
                return
            peak_force_before_cycles_index = above_peak_indices[0]
            force_ascending = force[0:peak_force_before_cycles_index]
            force_rest = force[peak_force_before_cycles_index:]

            force_max_indices, force_min_indices = \
                self.get_array_max_and_min_indices(force_rest)

            force_max_min_indices = np.concatenate(
                (force_min_indices, force_max_indices))
            force_max_min_indices.sort()

            force_rest_filtered = force_rest[force_max_min_indices]
            force_filtered = np.concatenate(
                (force_ascending, force_rest_filtered))
            np.save(self.get_filtered_npy_file_path(self.force_column),
                    force_filtered)

            # 2- Export filtered displacements: the (optionally smoothed)
            # ascending branch plus the unprocessed min/max values
            self.export_filtered_displacements(force_max_min_indices,
                                               peak_force_before_cycles_index)

            # 3- Export creep for displacements. Unwanted max/min values are
            # cut to keep full cycles and drop false extrema caused by noise
            self.export_displacements_creep(force_rest, force_max_indices,
                                            force_min_indices,
                                            peak_force_before_cycles_index)

            self.print_custom('Filtered and creep npy files are generated.')
        except Exception:
            self.log_exception()

    def export_filtered_displacements(self, force_max_min_indices,
                                      peak_force_before_cycles_index):
        """Save filtered versions of all columns except force and time.

        Each saved array is the column's ascending branch, Savitzky-Golay
        smoothed when ``activate_ascending_branch_smoothing`` is set,
        followed by the column's values at ``force_max_min_indices``
        (relative to the cyclic rest).
        """
        for column_name in self.columns_headers:
            if column_name in (self.force_column, self.time_column):
                continue

            disp = np.load(self.get_npy_file_path(column_name)).flatten()
            disp_ascending = disp[0:peak_force_before_cycles_index]
            disp_rest = disp[peak_force_before_cycles_index:]

            if self.activate_ascending_branch_smoothing:
                disp_ascending = savgol_filter(
                    disp_ascending,
                    window_length=self.window_length,
                    polyorder=self.polynomial_order)

            disp_rest_filtered = disp_rest[force_max_min_indices]
            filtered_disp = np.concatenate(
                (disp_ascending, disp_rest_filtered))
            np.save(self.get_filtered_npy_file_path(column_name),
                    filtered_disp)

    def export_displacements_creep(self, force_rest, force_max_indices,
                                   force_min_indices,
                                   peak_force_before_cycles_index):
        """Save per-cycle max and min values of every non-time column.

        The force extrema are first cut according to ``cutting_method``.
        The surviving indices (relative to the cyclic rest) then select the
        ``_max``/``_min`` arrays saved for each column.
        """
        if self.cutting_method == "Define Max, Min":
            force_max_indices_cut, force_min_indices_cut = \
                self.cut_indices_of_min_max_range(
                    force_rest, force_max_indices, force_min_indices,
                    self.force_max, self.force_min)
        elif self.cutting_method == "Define min cycle range(force difference)":
            force_max_indices_cut, force_min_indices_cut = \
                self.cut_indices_of_defined_range(
                    force_rest, force_max_indices, force_min_indices,
                    self.min_cycle_force_range)
        self.print_custom("Cycles number= ", len(force_min_indices))
        self.print_custom("Cycles number after cutting fake cycles = ",
                          len(force_min_indices_cut))

        for column_name in self.columns_headers:
            if column_name == self.time_column:
                continue
            array = np.load(self.get_npy_file_path(column_name)).flatten()
            array_rest = array[peak_force_before_cycles_index:]
            np.save(self.get_max_npy_file_path(column_name),
                    array_rest[force_max_indices_cut])
            np.save(self.get_min_npy_file_path(column_name),
                    array_rest[force_min_indices_cut])

    def get_array_max_and_min_indices(self, input_array):
        """Return (max_indices, min_indices) of the signal's cycle peaks.

        If values < 0 are the majority, the signal is treated as
        negative-dominant. The roles are then swapped: its local minima are
        returned as "max" and its local maxima as "min".
        """
        # Checking dominant sign
        positive_values_count = np.sum(np.array(input_array) >= 0)
        negative_values_count = input_array.size - positive_values_count

        if positive_values_count > negative_values_count:
            force_max_indices = self.get_max_indices(input_array)
            force_min_indices = self.get_min_indices(input_array)
        else:
            force_max_indices = self.get_min_indices(input_array)
            force_min_indices = self.get_max_indices(input_array)

        return force_max_indices, force_min_indices

    def _get_extrema_indices(self, a, find_max):
        """Return indices of local maxima (``find_max``) or minima of ``a``.

        ``a`` is a 1-D numpy array. The first and last elements never
        qualify. A run of repeated values counts as one point, and its last
        index is reported when the values before and after the run are both
        smaller (maxima) or both larger (minima). The result always has an
        integer dtype, so it stays usable for indexing even when empty.
        """
        # TODO try to vectorize this
        extrema_indices = []
        i = 1
        while i < a.size - 1:
            previous_element = a[i - 1]

            # Skip repeated elements; stop at the second-to-last element so
            # that a[i + 1] stays in bounds
            while a[i] == a[i + 1] and i < a.size - 2:
                i += 1

            if find_max:
                is_extremum = a[i] > a[i + 1] and a[i] > previous_element
            else:
                is_extremum = a[i] < a[i + 1] and a[i] < previous_element
            if is_extremum:
                extrema_indices.append(i)
            i += 1
        return np.array(extrema_indices, dtype=int)

    def get_max_indices(self, a):
        """Indices of local maxima of ``a`` (first/last elements excluded)."""
        return self._get_extrema_indices(a, find_max=True)

    def get_min_indices(self, a):
        """Indices of local minima of ``a`` (first/last elements excluded)."""
        return self._get_extrema_indices(a, find_max=False)

    def cut_indices_of_min_max_range(self, array, max_indices, min_indices,
                                     range_upper_value, range_lower_value):
        """Keep maxima with |value| > |upper| and minima with |value| < |lower|.

        Returns two lists of indices into ``array``.
        """
        cut_max_indices = [max_index for max_index in max_indices
                           if abs(array[max_index]) > abs(range_upper_value)]
        cut_min_indices = [min_index for min_index in min_indices
                           if abs(array[min_index]) < abs(range_lower_value)]
        return cut_max_indices, cut_min_indices

    def cut_indices_of_defined_range(self, array, max_indices, min_indices,
                                     range_):
        """Keep max/min pairs whose force difference exceeds ``range_``.

        Maxima and minima are paired by position. If one array is longer,
        its last index is appended unconditionally. Returns two lists.
        """
        # TODO try to vectorize this
        cut_max_indices = []
        cut_min_indices = []

        for max_index, min_index in zip(max_indices, min_indices):
            if abs(array[max_index] - array[min_index]) > range_:
                cut_max_indices.append(max_index)
                cut_min_indices.append(min_index)

        if len(max_indices) > len(min_indices):
            cut_max_indices.append(max_indices[-1])
        elif len(min_indices) > len(max_indices):
            cut_min_indices.append(min_indices[-1])

        return cut_max_indices, cut_min_indices

    def _activate_changed(self):
        # NOTE(review): Traits only auto-wires a static `_<name>_changed`
        # handler to a trait called `<name>`. This class has no `activate`
        # trait (the checkbox is `activate_ascending_branch_smoothing`), so
        # this handler appears never to run. It is deliberately left
        # unwired: hooking it up would reset peak_force_before_cycles
        # whenever smoothing is toggled. Confirm the intended behaviour.
        if not self.activate_ascending_branch_smoothing:
            self.old_peak_force_before_cycles = self.peak_force_before_cycles
            self.peak_force_before_cycles = 0
        else:
            self.peak_force_before_cycles = self.old_peak_force_before_cycles

    def _window_length_changed(self, new):
        """Warn when the Savitzky-Golay window length is invalid."""
        if new <= self.polynomial_order:
            dialog = MessageDialog(
                title='Attention!',
                message='Window length must be bigger than polynomial order.')
            dialog.open()

        if new % 2 == 0 or new <= 0:
            dialog = MessageDialog(
                title='Attention!',
                message='Window length must be odd positive integer.')
            dialog.open()

    def _polynomial_order_changed(self, new):
        """Warn when the polynomial order is not below the window length."""
        if new >= self.window_length:
            dialog = MessageDialog(
                title='Attention!',
                message='Polynomial order must be smaller than window length.')
            dialog.open()

    # =========================================================================
    # Plotting
    # =========================================================================
    data_changed = tr.Event

    def _plot_settings_btn_fired(self):
        try:
            self.plot_settings.configure_traits(kind='modal')
        except Exception:
            self.log_exception()

    def npy_files_exist(self, path):
        """Return True if ``path`` exists, else ask the user to parse first."""
        if os.path.exists(path):
            return True
        self.print_custom('Please parse csv file to generate npy files first!')
        return False

    def filtered_and_creep_npy_files_exist(self, path):
        """Return True if ``path`` exists, else ask to generate the files."""
        if os.path.exists(path):
            return True
        self.print_custom('Please generate filtered and creep npy files first!')
        return False

    def _add_plot_fired(self):
        self._run_in_background(self.add_plot_fired)

    def add_plot_fired(self):
        """Plot the y-axis column against the x-axis column into ``ax``.

        Uses the filtered arrays when ``apply_filters`` is set. Rows are
        optionally sub-sampled according to ``plot_settings``, and the axis
        multipliers are applied before plotting.
        """
        try:
            if self.apply_filters:
                get_path = self.get_filtered_npy_file_path
                files_exist = self.filtered_and_creep_npy_files_exist
                # TODO link this _filtered to the path creation function
                name_suffix = '_filtered'
            else:
                get_path = self.get_npy_file_path
                files_exist = self.npy_files_exist
                name_suffix = ''

            x_axis_path = get_path(self.x_axis)
            y_axis_path = get_path(self.y_axis)
            # Check both files so a missing y file yields a hint, not a
            # traceback
            if not (files_exist(x_axis_path) and files_exist(y_axis_path)):
                return

            x_axis_name = self.x_axis + name_suffix
            y_axis_name = self.y_axis + name_suffix
            self.print_custom('Loading npy files...')
            # With mmap_mode set, the array is loaded as a 'numpy.memmap'
            # object, which doesn't load data into memory until it's indexed
            x_axis_array = np.load(x_axis_path, mmap_mode='r')
            y_axis_array = np.load(y_axis_path, mmap_mode='r')

            if self.plot_settings_active:
                indices = self.get_indices_array(
                    np.size(x_axis_array),
                    self.plot_settings.num_of_first_rows_to_take,
                    self.plot_settings.num_of_rows_to_skip_after_each_section,
                    self.plot_settings.num_of_rows_in_each_section)
                x_axis_array = x_axis_array[indices]
                y_axis_array = y_axis_array[indices]

            x_axis_array = self.x_axis_multiplier * x_axis_array
            y_axis_array = self.y_axis_multiplier * y_axis_array

            self.print_custom('Adding Plot...')
            mpl.rcParams['agg.path.chunksize'] = 10000

            ax = self.ax
            ax.set_xlabel(x_axis_name)
            ax.set_ylabel(y_axis_name)
            ax.plot(x_axis_array,
                    y_axis_array,
                    linewidth=1.2,
                    color=np.random.rand(3),
                    label=self.file_name + ', ' + x_axis_name)
            ax.legend()

            self.data_changed = True
            self.print_custom('Finished adding plot.')
        except Exception:
            self.log_exception()

    def _add_creep_plot_fired(self):
        self._run_in_background(self.add_creep_plot_fired)

    def _smooth_keeping_first(self, values):
        """Savitzky-Golay smooth ``values`` but keep its first item as is."""
        return np.concatenate(
            (np.array([values[0]]),
             savgol_filter(values[1:],
                           window_length=self.window_length,
                           polyorder=self.polynomial_order)))

    def add_creep_plot_fired(self):
        """Plot the per-cycle max and min values of the x-axis column."""
        try:
            if not self.filtered_and_creep_npy_files_exist(
                    self.get_max_npy_file_path(self.x_axis)):
                return

            self.print_custom('Loading npy files...')
            disp_max = self.x_axis_multiplier * np.load(
                self.get_max_npy_file_path(self.x_axis))
            disp_min = self.x_axis_multiplier * np.load(
                self.get_min_npy_file_path(self.x_axis))
            complete_cycles_number = disp_max.size

            self.print_custom('Adding creep-fatigue plot...')
            mpl.rcParams['agg.path.chunksize'] = 10000

            ax = self.ax
            ax.set_xlabel('Cycles number')
            ax.set_ylabel(self.x_axis)

            if self.plot_every_nth_point > 1:
                disp_max = disp_max[0::self.plot_every_nth_point]
                disp_min = disp_min[0::self.plot_every_nth_point]

            if self.smooth:
                disp_max = self._smooth_keeping_first(disp_max)
                disp_min = self._smooth_keeping_first(disp_min)

            # The x axis spans [0, 1] when normalized, otherwise the number
            # of complete cycles, however many points are actually drawn.
            # No fmt string is passed: it would conflict with color=.
            x_end = 1. if self.normalize_cycles else complete_cycles_number
            for kind, values in (('Max', disp_max), ('Min', disp_min)):
                ax.plot(np.linspace(0, x_end, values.size),
                        values,
                        linewidth=1.2,
                        color=np.random.rand(3),
                        label=kind + ', ' + self.file_name + ', ' +
                        self.x_axis)

            ax.legend()
            self.data_changed = True
            self.print_custom('Finished adding creep-fatigue plot.')
        except Exception:
            self.log_exception()

    def get_indices_array(self, array_size, first_rows, distance,
                          num_of_rows_after_each_distance):
        """Build row indices for sub-sampled plotting.

        The first ``first_rows`` rows are taken. After that, starting at
        ``first_rows``, sections of ``num_of_rows_after_each_distance`` rows
        are taken every ``distance + num_of_rows_after_each_distance`` rows.
        Indices that would fall at or beyond ``array_size`` are dropped, so
        the result is always valid for an array of that length.
        """
        first_indices = np.arange(first_rows)
        section_starts = np.arange(
            start=first_rows,
            stop=array_size,
            step=distance + num_of_rows_after_each_distance)
        # One row per section start, one column per offset inside a section
        section_indices = (section_starts[:, np.newaxis] +
                           np.arange(num_of_rows_after_each_distance)).ravel()
        indices = np.concatenate((first_indices, section_indices))
        return indices[indices < array_size]

    def _clear_plot_fired(self):
        self.figure.clear()
        self.create_axes(self.figure)
        self.data_changed = True

    # =========================================================================
    # Logging
    # =========================================================================
    def print_custom(self, *input_args):
        """Print ``input_args`` and append them (joined as-is) to ``log``."""
        print(*input_args)
        message = ''.join(str(e) for e in input_args)
        if self.log == '':
            self.log = message
        else:
            self.log = self.log + '\n' + message

    def log_exception(self):
        """Report the exception currently being handled into ``log``."""
        self.print_custom('SOMETHING WENT WRONG!')
        self.print_custom('--------- Error message: ---------')
        self.print_custom(traceback.format_exc())
        self.print_custom('----------------------------------')

    def _clear_log_fired(self):
        self.log = ''

    # =========================================================================
    # Other functions
    # =========================================================================
    def reset(self):
        """Forget averaged-column selections and clear the log."""
        self.columns_to_be_averaged = []
        self.log = ''
Exemplo n.º 27
0
class SensorOperationController(tui.Controller):
    """ UI for controlling the hardware.

    The Connect/Disconnect buttons forward to the ``EEGSensor`` model, and the
    view exposes the COM port, the connection state and per-channel
    activate/deactivate buttons for channels 1-8.
    """

    model = t.Instance(EEGSensor)
    connect = t.Button()
    disconnect = t.Button()

    def _connect_changed(self):
        # Button pressed: ask the sensor model to open the connection.
        sensor = self.model
        sensor.connect()

    def _disconnect_changed(self):
        # Button pressed: ask the sensor model to close the connection.
        sensor = self.model
        sensor.disconnect()

    traits_view = View(
        HGroup(
            spring,
            VGroup(
                HGroup(spring, Heading('EEG Sensor Controls'), spring),
                VGroup(
                    # The port can only be edited while disconnected.
                    Item('com_port',
                         style='simple',
                         enabled_when="not object.connected"),
                    Item('object.connected', style='readonly'),
                    show_labels=True),
                HGroup(
                    Item('controller.connect',
                         enabled_when='not object.connected'),
                    Item('controller.disconnect',
                         enabled_when='object.connected'),
                    spring,
                    show_labels=False),
                Label('Last %d points saved to disk on exit.' %
                      MAX_HISTORY_LENGTH),
            ),
            spring,
            # 9 columns: one heading followed by the eight channel buttons,
            # first row activates, second row deactivates.
            tui.VGrid(
                Heading('Activate Ch:'),
                *[Item('channel_%d_on' % ch,
                       label=str(ch),
                       enabled_when='not channel_%d_enabled' % ch)
                  for ch in range(1, 9)],
                Heading('Deactivate Ch:'),
                *[Item('channel_%d_off' % ch,
                       label=str(ch),
                       enabled_when='channel_%d_enabled' % ch)
                  for ch in range(1, 9)],
                show_labels=False,
                show_border=True,
                columns=9,
                enabled_when='object.connected'),
            spring,
        ), )
Exemplo n.º 28
0
class HCFF(tr.HasStrictTraits):
    '''High-Cycle Fatigue Filter

    Parses a CSV file of test measurements into one ``.npy`` file per column,
    detects the force maxima/minima of the cyclic part of the test, exports
    filtered and creep (per-cycle max/min) arrays and plots them.
    '''

    #=========================================================================
    # Traits definitions
    #=========================================================================
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    skip_rows = tr.Int(4, auto_set=False, enter_set=True)
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    npy_folder_path = tr.Str
    file_name = tr.Str
    apply_filters = tr.Bool
    force_name = tr.Str('Kraft')
    peak_force_before_cycles = tr.Float(30)
    plots_num = tr.Enum(1, 2, 3, 4, 6, 9)
    plot_list = tr.List()
    plot = tr.Button
    add_plot = tr.Button
    add_creep_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_npy = tr.Button
    add_columns_average = tr.Button
    force_max = tr.Float(100)
    force_min = tr.Float(40)

    figure = tr.Instance(Figure)

#     plots_list = tr.List(editor=ui.SetEditor(
#         values=['kumquats', 'pomegranates', 'kiwi'],
#         can_move_all=False,
#         left_column_title='List'))

    #=========================================================================
    # File management
    #=========================================================================

    def _npy_path(self, column_name, suffix=''):
        '''Return ``<npy_folder_path>/<file_name>_<column_name><suffix>.npy``.'''
        return os.path.join(self.npy_folder_path,
                            self.file_name + '_' + column_name + suffix + '.npy')

    def _open_file_csv_fired(self):
        """ Handles the user clicking the 'Open...' button.

        Reads the first CSV row as the column headers (sanitised with
        ``get_valid_file_name``), creates an ``NPY`` folder next to the file
        and remembers the file's base name.
        """
        extns = ['*.csv', ]  # seems to handle only one extension...
        wildcard = '|'.join(extns)

        dialog = FileDialog(title='Select text file',
                            action='open', wildcard=wildcard,
                            default_path=self.file_csv)
        dialog.open()
        if not dialog.path:
            # No file chosen (e.g. dialog cancelled): keep the current state
            # instead of failing inside pd.read_csv('').
            return
        self.file_csv = dialog.path

        # Fill the x_axis / y_axis choices with the sanitised headers
        headers_array = np.array(
            pd.read_csv(
                self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
                nrows=1, header=None
            )
        )[0]
        self.columns_headers_list = [
            self.get_valid_file_name(header) for header in headers_array]

        # Saving file name and path and creating NPY folder
        dir_path = os.path.dirname(self.file_csv)
        self.npy_folder_path = os.path.join(dir_path, 'NPY')
        os.makedirs(self.npy_folder_path, exist_ok=True)

        self.file_name = os.path.splitext(os.path.basename(self.file_csv))[0]

    #=========================================================================
    # Parameters of the filter algorithm
    #=========================================================================

    def _figure_default(self):
        figure = Figure(facecolor='white')
        figure.set_tight_layout(True)
        return figure

    def _parse_csv_to_npy_fired(self):
        '''Save every CSV column as ``<file_name>_<header>.npy``.

        The file is read once. After ``skip_rows`` rows, pandas' default
        header handling consumes the next row, and each column is stored as
        an ``(n, 1)`` array, the shape of a single-column selection.
        '''
        print('Parsing csv into npy files...')

        data = pd.read_csv(self.file_csv, delimiter=self.delimiter,
                           decimal=self.decimal, skiprows=self.skip_rows)
        for i, header in enumerate(self.columns_headers_list):
            column_array = np.array(data.iloc[:, [i]])
            np.save(self._npy_path(header), column_array)

        print('Finished parsing csv into npy files.')

    def get_valid_file_name(self, original_file_name):
        '''Return ``original_file_name`` reduced to file-name-safe characters.

        Only ASCII letters, digits, spaces and ``-_.()`` are kept. Non-string
        input (e.g. a numeric or missing CSV header) is converted with
        ``str`` first.
        '''
        valid_chars = frozenset("-_.() " + string.ascii_letters + string.digits)
        return ''.join(c for c in str(original_file_name) if c in valid_chars)

#     def _add_columns_average_fired(self):
#         columns_average = ColumnsAverage(
#             columns_names=self.columns_headers_list)
#         # columns_average.set_columns_headers_list(self.columns_headers_list)
#         columns_average.configure_traits()

    def _generate_filtered_npy_fired(self):
        '''Export filtered force/displacement arrays and creep extrema.

        Writes next to the parsed files:

        * ``<force>_filtered.npy``: the ascending branch followed by the force
          values at the detected maxima/minima;
        * ``<column>_filtered.npy`` for every other column except the first
          (presumed to be time): Savitzky-Golay smoothed ascending branch
          followed by the column values at the force extrema;
        * ``<column>_max.npy`` / ``<column>_min.npy``: the column values at
          the force maxima/minima kept by ``cut_indices_in_range``.
        '''
        # 1- Export filtered force
        force = np.load(self._npy_path(self.force_name)).flatten()
        peak_force_before_cycles_index = np.where(
            abs(force) > abs(self.peak_force_before_cycles))[0][0]
        force_ascending = force[0:peak_force_before_cycles_index]
        force_rest = force[peak_force_before_cycles_index:]

        force_max_indices, force_min_indices = self.get_array_max_and_min_indices(
            force_rest)

        force_max_min_indices = np.concatenate(
            (force_min_indices, force_max_indices))
        force_max_min_indices.sort()

        force_rest_filtered = force_rest[force_max_min_indices]
        force_filtered = np.concatenate((force_ascending, force_rest_filtered))
        np.save(self._npy_path(self.force_name, '_filtered'), force_filtered)

        # Cutting unwanted max min values to get correct full cycles and remove
        # false min/max values caused by noise
        force_max_indices_cutted, force_min_indices_cutted = self.cut_indices_in_range(
            force_rest, force_max_indices, force_min_indices,
            self.force_max, self.force_min)

        print("Cycles number= ", len(force_min_indices))
        print("Cycles number after cutting unwanted max-min range= ",
              len(force_min_indices_cutted))

        # 2- Export filtered displacements and 3- their creep (max/min) values
        # TODO I skipped time with presuming it's the first column
        for column_name in self.columns_headers_list[1:]:
            if column_name == str(self.force_name):
                continue

            disp = np.load(self._npy_path(column_name)).flatten()
            disp_ascending = savgol_filter(
                disp[0:peak_force_before_cycles_index],
                window_length=51, polyorder=2)
            # Keep this column's *unfiltered* cyclic part: the extrema indices
            # are positions in force_rest, so they must index an array of the
            # same layout (not the already-filtered one).
            disp_rest = disp[peak_force_before_cycles_index:]

            filtered_disp = np.concatenate(
                (disp_ascending, disp_rest[force_max_min_indices]))
            np.save(self._npy_path(column_name, '_filtered'), filtered_disp)
            np.save(self._npy_path(column_name, '_max'),
                    disp_rest[force_max_indices_cutted])
            np.save(self._npy_path(column_name, '_min'),
                    disp_rest[force_min_indices_cutted])

        print('Filtered npy files are generated.')

    def cut_indices_in_range(self, array, max_indices, min_indices, range_upper_value, range_lower_value):
        '''Drop extrema that fall inside the band between the two limits.

        A max index is kept when ``|array[i]| > |range_upper_value|`` and a
        min index when ``|array[i]| < |range_lower_value|``, which removes
        false extrema caused by noise.

        Returns:
            ``(cutted_max_indices, cutted_min_indices)`` as lists.
        '''
        upper = abs(range_upper_value)
        lower = abs(range_lower_value)
        cutted_max_indices = [i for i in max_indices if abs(array[i]) > upper]
        cutted_min_indices = [i for i in min_indices if abs(array[i]) < lower]
        return cutted_max_indices, cutted_min_indices

    def get_array_max_and_min_indices(self, input_array):
        '''Return ``(max_indices, min_indices)`` of local extrema of ``input_array``.

        When negative values dominate, the comparators are swapped
        (``less_equal`` for "max"), presumably so that "max" means the peaks
        of largest magnitude. Plateau duplicates are collapsed, and if the two
        results differ in size the longer one loses its last element.
        '''
        # Checking dominant sign
        positive_values_count = np.sum(np.array(input_array) >= 0)
        negative_values_count = input_array.size - positive_values_count

        # Getting max and min indices
        if (positive_values_count > negative_values_count):
            force_max_indices = argrelextrema(input_array, np.greater_equal)[0]
            force_min_indices = argrelextrema(input_array, np.less_equal)[0]
        else:
            force_max_indices = argrelextrema(input_array, np.less_equal)[0]
            force_min_indices = argrelextrema(input_array, np.greater_equal)[0]

        # Remove subsequent max/min indices (np.greater_equal will give 1,2 for
        # [4, 8, 8, 1])
        force_max_indices = self.remove_subsequent_max_values(
            force_max_indices)
        force_min_indices = self.remove_subsequent_min_values(
            force_min_indices)

        # If size is not equal remove the last element from the big one
        if force_max_indices.size > force_min_indices.size:
            force_max_indices = force_max_indices[:-1]
        elif force_max_indices.size < force_min_indices.size:
            force_min_indices = force_min_indices[:-1]

        return force_max_indices, force_min_indices

    def remove_subsequent_max_values(self, force_max_indices):
        '''Drop every index that is directly followed by ``index + 1``.'''
        return self._remove_consecutive_indices(force_max_indices)

    def remove_subsequent_min_values(self, force_min_indices):
        '''Drop every index that is directly followed by ``index + 1``.'''
        return self._remove_consecutive_indices(force_min_indices)

    @staticmethod
    def _remove_consecutive_indices(indices):
        '''Collapse each run of consecutive indices to its last index.'''
        indices = np.asarray(indices)
        return np.delete(indices, np.flatnonzero(np.diff(indices) == 1))

    #=========================================================================
    # Plotting
    #=========================================================================
    plot_figure_num = tr.Int(0)

    def _plot_fired(self):
        self.figure.add_subplot()

    def x_plot_fired(self):
        self.plot_figure_num += 1
        plt.draw()
        plt.show()

    data_changed = tr.Event

    def _add_plot_fired(self):
        '''Plot ``y_axis`` against ``x_axis`` from the parsed ``.npy`` files.

        Uses the ``*_filtered.npy`` files when ``apply_filters`` is set and
        scales each array by its axis multiplier. Unlike
        ``_add_creep_plot_fired``, no ``plots_num`` limit is enforced; the
        data is drawn into a 1x1 subplot.
        '''
        print('Loading npy files...')

        suffix = '_filtered' if self.apply_filters else ''
        x_axis_name = self.x_axis + suffix
        y_axis_name = self.y_axis + suffix
        x_axis_array = self.x_axis_multiplier * \
            np.load(self._npy_path(self.x_axis, suffix))
        y_axis_array = self.y_axis_multiplier * \
            np.load(self._npy_path(self.y_axis, suffix))

        print('Adding Plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        ax = self.figure.add_subplot(1, 1, 1)

        ax.set_xlabel('Displacement [mm]')
        ax.set_ylabel('kN')
        ax.set_title('Original data', fontsize=20)
        ax.plot(x_axis_array, y_axis_array, 'k', linewidth=0.8)

        self.plot_list.append('{}, {}'.format(x_axis_name, y_axis_name))
        self.data_changed = True
        print('Finished adding plot!')

    def apply_new_subplot(self):
        '''Add the next subplot of the ``plots_num`` grid and return its Axes.

        The position is ``len(plot_list) + 1``; with ``plots_num == 1`` the
        single 1x1 subplot is always used.
        '''
        grid_shapes = {1: (1, 1), 2: (1, 2), 3: (1, 3),
                       4: (2, 2), 6: (2, 3), 9: (3, 3)}
        rows, columns = grid_shapes[self.plots_num]
        position = 1 if self.plots_num == 1 else len(self.plot_list) + 1
        return self.figure.add_subplot(rows, columns, position)

    def _add_creep_plot_fired(self):
        '''Plot the creep curves: ``x_axis`` values at force max/min per cycle.

        Loads ``<x_axis>_max.npy`` and ``<x_axis>_min.npy`` and draws them in
        the next subplot of the ``plots_num`` grid; shows a message instead
        when all grid cells are used.
        '''
        if (len(self.plot_list) >= self.plots_num):
            dialog = MessageDialog(
                title='Attention!', message='Max plots number is {}'.format(self.plots_num))
            dialog.open()
            return

        disp_max = self.x_axis_multiplier * \
            np.load(self._npy_path(self.x_axis, '_max'))
        disp_min = self.x_axis_multiplier * \
            np.load(self._npy_path(self.x_axis, '_min'))

        print('Adding creep plot...')
        mpl.rcParams['agg.path.chunksize'] = 50000

        # A matplotlib Figure has no xlabel/ylabel/title/plot methods, so draw
        # on the Axes created by apply_new_subplot.
        ax = self.apply_new_subplot()
        ax.set_xlabel('Cycles number')
        ax.set_ylabel('mm')
        ax.set_title('Fatigue creep curve', fontsize=20)
        # Colour only via ``color``: adding a 'k' fmt string is redundant and
        # makes matplotlib emit a warning.
        ax.plot(np.arange(0, disp_max.size), disp_max,
                linewidth=0.8, color='red')
        ax.plot(np.arange(0, disp_min.size), disp_min,
                linewidth=0.8, color='green')

        self.plot_list.append('Plot {}'.format(len(self.plot_list) + 1))
        self.data_changed = True

        print('Finished adding creep plot!')

    #=========================================================================
    # Configuration of the view
    #=========================================================================

    traits_view = ui.View(
        ui.HSplit(
            ui.VSplit(
                ui.HGroup(
                    ui.UItem('open_file_csv'),
                    ui.UItem('file_csv', style='readonly'),
                    label='Input data'
                ),
                ui.Item('add_columns_average', show_label=False),
                ui.VGroup(
                    ui.Item('skip_rows'),
                    ui.Item('decimal'),
                    ui.Item('delimiter'),
                    ui.Item('parse_csv_to_npy', show_label=False),
                    label='Filter parameters'
                ),
                ui.VGroup(
                    ui.Item('plots_num'),
                    ui.HGroup(ui.Item('x_axis'), ui.Item('x_axis_multiplier')),
                    ui.HGroup(ui.Item('y_axis'), ui.Item('y_axis_multiplier')),
                    ui.HGroup(ui.Item('add_plot', show_label=False),
                              ui.Item('apply_filters')),
                    ui.HGroup(ui.Item('add_creep_plot', show_label=False)),
                    ui.Item('plot_list'),
                    ui.Item('plot', show_label=False),
                    show_border=True,
                    label='Plotting settings'),
            ),
            ui.VGroup(
                ui.Item('force_name'),
                ui.HGroup(ui.Item('peak_force_before_cycles'),
                          show_border=True, label='Skip noise of ascending branch:'),
                #                     ui.Item('plots_list'),
                ui.VGroup(ui.Item('force_max'),
                          ui.Item('force_min'),
                          show_border=True,
                          label='Cut fake cycles for creep:'),
                ui.Item('generate_filtered_npy', show_label=False),
                show_border=True,
                label='Filters'
            ),
            ui.UItem('figure', editor=MPLFigureEditor(),
                     resizable=True,
                     springy=True,
                     width=0.3,
                     label='2d plots'),
        ),
        title='HCFF Filter',
        resizable=True,
        width=0.6,
        height=0.6

    )
Exemplo n.º 29
0
class SpikesRemoval(SpanSelectorInSignal1D):
    """Interactive tool to find and remove spikes from a 1D signal.

    A spike is reported where the first difference of the current spectrum
    reaches ``threshold``. The selected range is then replaced by a linear or
    spline interpolation of the surrounding data, optionally with added noise.
    """
    interpolator_kind = t.Enum(
        'Linear',
        'Spline',
        default='Linear',
        desc="the type of interpolation to use when\n"
        "replacing the signal where a spike has been replaced")
    threshold = t.Float(desc="the derivative magnitude threshold above\n"
                        "which to find spikes")
    click_to_show_instructions = t.Button()
    show_derivative_histogram = t.Button()
    spline_order = t.Range(1,
                           10,
                           3,
                           desc="the order of the spline used to\n"
                           "connect the reconstructed data")
    interpolator = None
    default_spike_width = t.Int(
        5,
        desc="the width over which to do the interpolation\n"
        "when removing a spike (this can be "
        "adjusted for each\nspike by clicking "
        "and dragging on the display during\n"
        "spike replacement)")
    index = t.Int(0)
    add_noise = t.Bool(True,
                       desc="whether to add noise to the interpolated\nportion"
                       "of the spectrum. The noise properties defined\n"
                       "in the Signal metadata are used if present,"
                       "otherwise\nshot noise is used as a default")

    thisOKButton = tu.Action(name="OK",
                             action="OK",
                             tooltip="Close the spikes removal tool")

    thisApplyButton = tu.Action(name="Remove spike",
                                action="apply",
                                tooltip="Remove the current spike by "
                                "interpolating\n"
                                "with the specified settings (and find\n"
                                "the next spike automatically)")
    thisFindButton = tu.Action(
        name="Find next",
        action="find",
        tooltip="Find the next (in terms of navigation\n"
        "dimensions) spike in the data.")

    thisPreviousButton = tu.Action(name="Find previous",
                                   action="back",
                                   tooltip="Find the previous (in terms of "
                                   "navigation\n"
                                   "dimensions) spike in the data.")
    view = tu.View(
        tu.Group(
            tu.Group(
                tu.Item(
                    'click_to_show_instructions',
                    show_label=False,
                ),
                tu.Item('show_derivative_histogram',
                        show_label=False,
                        tooltip="To determine the appropriate threshold,\n"
                        "plot the derivative magnitude histogram, \n"
                        "and look for outliers at high magnitudes \n"
                        "(which represent sudden spikes in the data)"),
                'threshold',
                show_border=True,
            ),
            tu.Group('add_noise',
                     'interpolator_kind',
                     'default_spike_width',
                     tu.Group('spline_order',
                              enabled_when='interpolator_kind == \'Spline\''),
                     show_border=True,
                     label='Advanced settings'),
        ),
        buttons=[
            thisOKButton,
            thisPreviousButton,
            thisFindButton,
            thisApplyButton,
        ],
        handler=SpikesRemovalHandler,
        title='Spikes removal tool',
        resizable=False,
    )

    def __init__(self, signal, navigation_mask=None, signal_mask=None):
        super(SpikesRemoval, self).__init__(signal)
        self.interpolated_line = None
        # Navigation coordinates to visit, skipping masked positions.
        self.coordinates = [
            coordinate
            for coordinate in signal.axes_manager._am_indices_generator() if
            (navigation_mask is None or not navigation_mask[coordinate[::-1]])
        ]
        self.signal = signal
        self.line = signal._plot.signal_plot.ax_lines[0]
        self.ax = signal._plot.signal_plot.ax
        signal._plot.auto_update_plot = False
        if len(self.coordinates) > 1:
            signal.axes_manager.indices = self.coordinates[0]
        self.threshold = 400
        self.index = 0
        self.argmax = None
        self.derivmax = None
        # interp1d ``kind``: 'linear' or an integer spline order.
        self.kind = "linear"
        self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
        self.signal_mask = signal_mask
        self.navigation_mask = navigation_mask
        md = self.signal.metadata
        from hyperspy.signal import BaseSignal

        if "Signal.Noise_properties" in md:
            if "Signal.Noise_properties.variance" in md:
                self.noise_variance = md.Signal.Noise_properties.variance
                if isinstance(md.Signal.Noise_properties.variance, BaseSignal):
                    self.noise_type = "heteroscedastic"
                else:
                    self.noise_type = "white"
            else:
                self.noise_type = "shot noise"
        else:
            self.noise_type = "shot noise"

    def _threshold_changed(self, old, new):
        self.index = 0
        self.update_plot()

    def _click_to_show_instructions_fired(self):
        information(None, "\nTo remove spikes from the data:\n\n"
                    "   1. Click \"Show derivative histogram\" to "
                    "determine at what magnitude the spikes are present.\n"
                    "   2. Enter a suitable threshold (lower than the "
                    "lowest magnitude outlier in the histogram) in the "
                    "\"Threshold\" box, which will be the magnitude "
                    "from which to search. \n"
                    "   3. Click \"Find next\" to find the first spike.\n"
                    "   4. If desired, the width and position of the "
                    "boundaries used to replace the spike can be "
                    "adjusted by clicking and dragging on the displayed "
                    "plot.\n "
                    "   5. View the spike (and the replacement data that "
                    "will be added) and click \"Remove spike\" in order "
                    "to alter the data as shown. The tool will "
                    "automatically find the next spike to replace.\n"
                    "   6. Repeat this process for each spike throughout "
                    "the dataset, until the end of the dataset is "
                    "reached.\n"
                    "   7. Click \"OK\" when finished to close the spikes "
                    "removal tool.\n\n"
                    "Note: Various settings can be configured in "
                    "the \"Advanced settings\" section. Hover the "
                    "mouse over each parameter for a description of what "
                    "it does."
                    "\n",
                    title="Instructions")

    def _show_derivative_histogram_fired(self):
        self.signal._spikes_diagnosis(signal_mask=self.signal_mask,
                                      navigation_mask=self.navigation_mask)

    def detect_spike(self):
        """Look for a spike in the spectrum at the current navigation index.

        Returns ``True`` (and sets ``argmax``/``derivmax``) when the largest
        first difference, ignoring ``signal_mask`` and ranges already visited
        at this index, reaches ``threshold``; otherwise ``False``.
        """
        derivative = np.diff(self.signal())
        if self.signal_mask is not None:
            derivative[self.signal_mask[:-1]] = 0
        if self.argmax is not None:
            # Mask the range around the previous spike so the search moves on.
            left, right = self.get_interpolation_range()
            self._temp_mask[left:right] = True
            derivative[self._temp_mask[:-1]] = 0
        max_derivative = derivative.max()
        if abs(max_derivative) >= self.threshold:
            self.argmax = derivative.argmax()
            self.derivmax = abs(max_derivative)
            return True
        else:
            return False

    def _reset_line(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
            self.reset_span_selector()

    def find(self, back=False):
        """Move to the next (or previous, if ``back``) spike and display it."""
        self._reset_line()
        ncoordinates = len(self.coordinates)
        spike = self.detect_spike()
        while not spike and ((self.index < ncoordinates - 1 and back is False)
                             or (self.index > 0 and back is True)):
            if back is False:
                self.index += 1
            else:
                self.index -= 1
            spike = self.detect_spike()

        if spike is False:
            messages.information('End of dataset reached')
            self.index = 0
            self._reset_line()
            return
        else:
            minimum = max(0, self.argmax - 50)
            maximum = min(len(self.signal()) - 1, self.argmax + 50)
            # Raw string: the LaTeX backslashes are not Python escapes.
            thresh_label = DerivativeTextParameters(
                text=r"$\mathsf{\delta}_\mathsf{max}=$", color="black")
            self.ax.legend([thresh_label], [repr(int(self.derivmax))],
                           handler_map={
                               DerivativeTextParameters:
                               DerivativeTextHandler()
                           },
                           loc='best')
            self.ax.set_xlim(
                self.signal.axes_manager.signal_axes[0].index2value(minimum),
                self.signal.axes_manager.signal_axes[0].index2value(maximum))
            self.update_plot()
            self.create_interpolation_line()

    def update_plot(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None
        self.reset_span_selector()
        self.update_spectrum_line()
        if len(self.coordinates) > 1:
            self.signal._plot.pointer._update_patch_position()

    def update_spectrum_line(self):
        self.line.auto_update = True
        self.line.update()
        self.line.auto_update = False

    def _index_changed(self, old, new):
        self.signal.axes_manager.indices = self.coordinates[new]
        self.argmax = None
        self._temp_mask[:] = False

    def on_disabling_span_selector(self):
        if self.interpolated_line is not None:
            self.interpolated_line.close()
            self.interpolated_line = None

    def _spline_order_changed(self, old, new):
        # The order only matters while the spline interpolator is selected;
        # otherwise keep the linear interpolation shown in the UI.
        if self.interpolator_kind == 'Spline':
            self.kind = self.spline_order
        self.span_selector_changed()

    def _add_noise_changed(self, old, new):
        self.span_selector_changed()

    def _interpolator_kind_changed(self, old, new):
        # The Enum values are capitalised ('Linear'/'Spline'), whereas
        # ``kind`` must be 'linear' or an integer spline order.
        if new == 'Linear':
            self.kind = 'linear'
        else:
            self.kind = self.spline_order
        self.span_selector_changed()

    def _ss_left_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def _ss_right_value_changed(self, old, new):
        if not (np.isnan(self.ss_right_value) or np.isnan(self.ss_left_value)):
            self.span_selector_changed()

    def create_interpolation_line(self):
        self.interpolated_line = drawing.signal1d.Signal1DLine()
        self.interpolated_line.data_function = self.get_interpolated_spectrum
        self.interpolated_line.set_line_properties(color='blue', type='line')
        self.signal._plot.signal_plot.add_line(self.interpolated_line)
        self.interpolated_line.autoscale = False
        self.interpolated_line.plot()

    def get_interpolation_range(self):
        """Return ``(left, right)`` channel indices of the range to replace.

        Uses the span selector when both its ends are set, otherwise
        ``argmax +/- default_spike_width``; clipped to the signal axis.
        """
        axis = self.signal.axes_manager.signal_axes[0]
        if np.isnan(self.ss_left_value) or np.isnan(self.ss_right_value):
            left = self.argmax - self.default_spike_width
            right = self.argmax + self.default_spike_width
        else:
            left = axis.value2index(self.ss_left_value)
            right = axis.value2index(self.ss_right_value)

        # Clip to the axis dimensions
        nchannels = self.signal.axes_manager.signal_shape[0]
        left = left if left >= 0 else 0
        right = right if right < nchannels else nchannels - 1

        return left, right

    def get_interpolated_spectrum(self, axes_manager=None):
        """Return a copy of the current spectrum with the spike range replaced.

        ``pad`` neighbouring channels on each side feed the interpolation
        (1 for linear, 10 for spline). At the left edge the range is filled
        with ``data[right + 1]``. Noise is added according to ``noise_type``
        when ``add_noise`` is set.
        """
        data = self.signal().copy()
        axis = self.signal.axes_manager.signal_axes[0]
        left, right = self.get_interpolation_range()
        if self.kind == 'linear':
            pad = 1
        else:
            pad = 10
        ileft = left - pad
        iright = right + pad
        ileft = np.clip(ileft, 0, len(data))
        iright = np.clip(iright, 0, len(data))
        left = int(np.clip(left, 0, len(data)))
        right = int(np.clip(right, 0, len(data)))
        x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
        y = np.hstack((data[ileft:left], data[right:iright]))
        if ileft == 0:
            # Extrapolate to the left
            data[left:right] = data[right + 1]

        elif iright == (len(data) - 1):
            # Extrapolate to the right
            data[left:right] = data[left - 1]

        else:
            # Interpolate
            intp = sp.interpolate.interp1d(x, y, kind=self.kind)
            data[left:right] = intp(axis.axis[left:right])

        # Add noise
        if self.add_noise is True:
            if self.noise_type == "white":
                data[left:right] += np.random.normal(scale=np.sqrt(
                    self.noise_variance),
                                                     size=right - left)
            elif self.noise_type == "heteroscedastic":
                noise_variance = self.noise_variance(
                    axes_manager=self.signal.axes_manager)[left:right]
                noise = [
                    np.random.normal(scale=np.sqrt(item))
                    for item in noise_variance
                ]
                data[left:right] += noise
            else:
                data[left:right] = np.random.poisson(
                    np.clip(data[left:right], 0, np.inf))

        return data

    def span_selector_changed(self):
        if self.interpolated_line is None:
            return
        else:
            self.interpolated_line.update()

    def apply(self):
        """Write the interpolated spectrum into the signal and find the next spike."""
        self.signal()[:] = self.get_interpolated_spectrum()
        self.signal.events.data_changed.trigger(obj=self.signal)
        self.update_spectrum_line()
        self.interpolated_line.close()
        self.interpolated_line = None
        self.reset_span_selector()
        self.find()
Exemplo n.º 30
0
class Matplotlibify(traits.HasTraits):
    """Produce a matplotlib version of one (or two) log file plots.

    A plot script is generated by replacing ``{{placeholder}}`` keys in the
    template file ``templateFile`` with values collected from this object
    (global settings) and from the ``PlotProperties`` entries in
    ``plotPropertiesList``.  The script can then be written to disk, executed
    to show the plot, executed with an appended ``plt.savefig`` call, saved
    via the log file plot to OneNote, or printed.
    """

    logFilePlotReference = traits.Instance(
        logFilePlots.plotObjects.logFilePlot.LogFilePlot)
    plotPropertiesList = traits.List(PlotProperties)
    logFilePlot1 = traits.Any()
    logFilePlot2 = traits.Any()
    logFilePlotsReference = traits.Instance(
        logFilePlots.LogFilePlots)  # reference to the logFilePlots object
    isPriviliged = traits.Bool(False)
    hardCodeLegendBool = traits.Bool(
        False,
        desc=
        "click if you want to write your own legend otherwise it will generate legend based on series and legend replacement dict"
    )
    hardCodeLegendString = traits.String(
        "", desc="comma seperated string for each legend entry")
    replacementStrings = {}
    savedPrintsDirectory = traits.Directory(
        os.path.join("\\\\ursa", "AQOGroupFolder", "Experiment Humphry",
                     "Data", "savedPrints"))
    showWaterMark = traits.Bool(True)

    matplotlibifyMode = traits.Enum("default", "dual plot")
    generatePlotScriptButton = traits.Button("generate plot")
    showPlotButton = traits.Button("show")
    templatesFolder = os.path.join("\\\\ursa", "AQOGroupFolder",
                                   "Experiment Humphry",
                                   "Experiment Control And Software",
                                   "LogFilePlots", "matplotlibify",
                                   "templates")
    # Folder whose existence marks the "privileged" machine and where
    # generated scripts are placed by default.  "C:\\" (not "C:") is required:
    # on Windows os.path.join("C:", "Users") yields the drive-relative
    # "C:Users", which depends on the current directory of drive C.
    privilegedScriptsFolder = os.path.join("C:\\", "Users", "tharrison",
                                           "Google Drive", "Thesis",
                                           "python scripts", "matplotlibify")
    templateFile = traits.File(
        os.path.join(templatesFolder, "matplotlibifyDefaultTemplate.py"))
    generatedScriptLocation = traits.File(
        os.path.join(os.path.expanduser('~'), "Google Drive", "Thesis",
                     "python scripts", "matplotlibify", "debug.py"))
    saveToOneNote = traits.Button("Save to OneNote")
    printButton = traits.Button("print")
    dualPlotMode = traits.Enum('sharedXY', 'sharedX', 'sharedY', 'stacked',
                               'stackedX', 'stackedY')
    logLibrarianReference = None

    secondPlotGroup = traitsui.VGroup(
        traitsui.Item("matplotlibifyMode", label="mode"),
        traitsui.HGroup(
            traitsui.Item("logFilePlot1",
                          visible_when="matplotlibifyMode=='dual plot'"),
            traitsui.Item("logFilePlot2",
                          visible_when="matplotlibifyMode=='dual plot'"),
            traitsui.Item('dualPlotMode',
                          visible_when="matplotlibifyMode=='dual plot'",
                          show_label=False)),
    )
    plotPropertiesGroup = traitsui.Item(
        "plotPropertiesList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True)

    generalGroup = traitsui.VGroup(
        traitsui.Item("showWaterMark", label="show watermark"),
        traitsui.HGroup(
            traitsui.Item("hardCodeLegendBool", label="hard code legend?"),
            traitsui.Item("hardCodeLegendString",
                          show_label=False,
                          visible_when="hardCodeLegendBool")),
        traitsui.Item("templateFile"),
        traitsui.Item("generatedScriptLocation", visible_when='isPriviliged'),
        traitsui.Item('generatePlotScriptButton', visible_when='isPriviliged'),
        traitsui.Item('showPlotButton'),
        # enabled_when='True': this button was disabled for a while while
        # an error was being debugged.
        traitsui.Item('saveToOneNote', enabled_when='True'),
        traitsui.Item('printButton'))

    traits_view = traitsui.View(secondPlotGroup,
                                plotPropertiesGroup,
                                generalGroup,
                                resizable=True,
                                kind='live')

    def __init__(self, **traitsDict):
        """Initialise traits, the first PlotProperties and the plot selectors.

        ``logFilePlot1``/``logFilePlot2`` are re-added as mapped traits whose
        displayed value is a tab name and whose ``_``-suffixed shadow value
        is the corresponding LogFilePlot from ``logFilePlotsReference.lfps``.
        """
        super(Matplotlibify, self).__init__(**traitsDict)
        self.plotPropertiesList = [PlotProperties(self.logFilePlotReference)]
        self.generateReplacementStrings()
        self.add_trait(
            "logFilePlot1",
            traits.Trait(
                self.logFilePlotReference.logFilePlotsTabName, {
                    lfp.logFilePlotsTabName: lfp
                    for lfp in self.logFilePlotsReference.lfps
                }))
        self.add_trait(
            "logFilePlot2",
            traits.Trait(
                self.logFilePlotReference.logFilePlotsTabName, {
                    lfp.logFilePlotsTabName: lfp
                    for lfp in self.logFilePlotsReference.lfps
                }))

    def generateReplacementStrings(self):
        """Rebuild ``self.replacementStrings`` for the current mode.

        In 'default' mode the first PlotProperties supplies un-prefixed
        placeholders; in 'dual plot' mode the first and second supply
        placeholders prefixed with ``lfp1.`` and ``lfp2.``.  The global
        settings from getGlobalReplacementStrings are added last, so they win
        on key clashes.  Finally every string value that does not start with
        ``"def "`` is wrapped in double quotes (and converted to unicode) so
        it lands in the template as a string literal; function definitions
        are left verbatim.
        """
        self.replacementStrings = {}
        if self.matplotlibifyMode == 'default':
            specific = self.plotPropertiesList[
                0].getReplacementStringsSpecific(identifier="")
            generic = self.getGlobalReplacementStrings()
            self.replacementStrings.update(specific)
            self.replacementStrings.update(generic)

        elif self.matplotlibifyMode == 'dual plot':
            specific1 = self.plotPropertiesList[
                0].getReplacementStringsSpecific(identifier="lfp1.")
            specific2 = self.plotPropertiesList[
                1].getReplacementStringsSpecific(identifier="lfp2.")
            generic = self.getGlobalReplacementStrings()
            self.replacementStrings.update(specific1)
            self.replacementStrings.update(specific2)
            self.replacementStrings.update(generic)

        for key, value in list(self.replacementStrings.items()):
            logger.info("%s = %s" % (value, type(value)))
            if not isinstance(value, (str, unicode)):
                continue
            if value.startswith("def "):
                continue  # function definitions must not be quoted
            self.replacementStrings[key] = unicode(self.wrapInQuotes(value))

    def getGlobalReplacementStrings(self, identifier=""):
        """Return the placeholders that come from this object's own settings.

        These are independent of any particular log file plot.  ``identifier``
        is inserted after ``{{`` in every key.
        """
        return {
            '{{%shardCodeLegendBool}}' % identifier: self.hardCodeLegendBool,
            '{{%shardCodeLegendString}}' % identifier:
            self.hardCodeLegendString,
            '{{%smatplotlibifyMode}}' % identifier: self.matplotlibifyMode,
            '{{%sshowWaterMark}}' % identifier: self.showWaterMark,
            '{{%sdualPlotMode}}' % identifier: self.dualPlotMode
        }

    def wrapInQuotes(self, string):
        """Return *string* surrounded by double quotes (no escaping)."""
        return '"%s"' % string

    def _isPriviliged_default(self):
        """Privileged mode is enabled when the author's script folder exists."""
        return os.path.exists(self.privilegedScriptsFolder)

    def _generatedScriptLocation_default(self):
        """Return an unused ``.py`` path inside ``privilegedScriptsFolder``.

        The name is ``<log file stem>-<y label>-vs-<x label>.py``; if that
        file already exists, ``-0``, ``-1``, ... suffixes are tried in turn.
        """
        _, tail = os.path.split(self.logFilePlotReference.logFile)
        matplotlibifyName = os.path.splitext(tail)[0] + "-%s-vs-%s" % (
            self.plotPropertiesList[0]._yAxisLabel_default(),
            self.plotPropertiesList[0]._xAxisLabel_default())
        baseName = os.path.join(self.privilegedScriptsFolder,
                                matplotlibifyName)
        filename = baseName + ".py"
        c = 0
        # Test the candidate itself (it already ends in ".py") so an existing
        # script is never silently overwritten.
        while os.path.exists(filename):
            filename = baseName + "-%s.py" % c
            c += 1
        return filename

    def replace_all(self, text, replacementDictionary):
        """Replace every placeholder key in *text* with ``str()`` of its value."""
        for placeholder, new in replacementDictionary.items():
            text = text.replace(placeholder, str(new))
        return text

    def _renderTemplate(self):
        """Refresh the replacement strings and return the filled-in template."""
        self.generateReplacementStrings()
        with open(self.templateFile, "rb") as template:
            return self.replace_all(template.read(), self.replacementStrings)

    def _generatePlotScriptButton_fired(self):
        self.writePlotScriptToFile(self.generatedScriptLocation)

    def writePlotScriptToFile(self, path):
        """Write the rendered plot script to *path*."""
        logger.info("attempting to generate matplotlib script...")
        text = self._renderTemplate()
        with open(path, "wb") as output:
            output.write(text)
        logger.info("succesfully generated matplotlib script at location %s " %
                    path)

    def autoSavePlotWithMatplotlib(self, path):
        """Run the rendered script with ``plt.savefig(path)`` and ``plt.close('all')`` appended."""
        logger.info("attempting to save matplotlib plot...")
        text = self._renderTemplate()
        # %r produces a valid Python string literal even for paths that
        # contain quotes or end with a backslash (r'%s' breaks on both).
        saveCode = "\n\nplt.savefig(%r, dpi=300)\nplt.close('all')" % (path,)
        logger.info("executing save statement:%s" % saveCode)
        text += saveCode
        ns = {}
        # NOTE: executes arbitrary code taken from the template file.
        exec(text, ns)
        logger.info("exec completed succesfully...")

    def _showPlotButton_fired(self):
        """Execute the rendered script so the plot is displayed."""
        logger.info("attempting to show matplotlib plot...")
        text = self._renderTemplate()
        ns = {}
        # NOTE: executes arbitrary code taken from the template file.
        exec(text, ns)
        logger.info("exec completed succesfully...")

    def _saveToOneNote_fired(self):
        """Delegate saving (log folder + OneNote) to the log file plot.

        Keeps all OneNote-specific code out of matplotlibify.
        """
        if self.logLibrarianReference is None:
            self.logFilePlotReference.savePlotAsImage(self)
        else:
            self.logFilePlotReference.savePlotAsImage(
                self, self.logLibrarianReference)

    def _matplotlibifyMode_changed(self):
        """Switch template and PlotProperties list to match the new mode."""
        if self.matplotlibifyMode == "default":
            self.templateFile = os.path.join(
                self.templatesFolder, "matplotlibifyDefaultTemplate.py")
            self.plotPropertiesList = [
                PlotProperties(self.logFilePlotReference)
            ]
        elif self.matplotlibifyMode == "dual plot":
            self.templateFile = os.path.join(
                self.templatesFolder, "matplotlibifyDualPlotTemplate.py")
            # logFilePlot2_ is the mapped LogFilePlot object (logFilePlot2
            # itself only holds its tab name).
            if len(self.plotPropertiesList) > 1:
                self.plotPropertiesList[1] = PlotProperties(
                    self.logFilePlot2_)
                logger.info("chanigng second element of plot properties list")
            elif len(self.plotPropertiesList) == 1:
                self.plotPropertiesList.append(
                    PlotProperties(self.logFilePlot2_))
                logger.info("appending to plot properties list")
            else:
                logger.error(
                    "there only be 1 or 2 elements in plot properties but found %s elements"
                    % len(self.plotPropertiesList))

    def _logFilePlot1_changed(self):
        """logFilePlot1 changed so rebuild the first PlotProperties entry."""
        logger.info("logFilePlot1 changed. updating plotPropertiesList")
        self.plotPropertiesList[0] = PlotProperties(self.logFilePlot1_)

    def _logFilePlot2_changed(self):
        """logFilePlot2 changed so rebuild the second PlotProperties entry.

        Assumes the list already has two entries (i.e. 'dual plot' mode).
        """
        logger.info("logFilePlot2 changed. updating plotPropertiesList")
        self.plotPropertiesList[1] = PlotProperties(self.logFilePlot2_)

    def dualPlotModeUpdates(self):
        """Pick ``dualPlotMode`` from which axes the two plots share.

        Intended to be called when logFilePlot1 or logFilePlot2 change
        (nothing in this class calls it).
        """
        if self.logFilePlot1_.xAxis == self.logFilePlot2_.xAxis:
            # Twin X, two y axes unless the y axes match too.
            if self.logFilePlot1_.yAxis == self.logFilePlot2_.yAxis:
                self.dualPlotMode = 'sharedXY'
            else:
                self.dualPlotMode = 'sharedX'
        elif self.logFilePlot1_.yAxis == self.logFilePlot2_.yAxis:
            self.dualPlotMode = 'sharedY'
        else:
            self.dualPlotMode = 'stacked'

    def _printButton_fired(self):
        """Print a PNG of the plot via Windows and keep a copy in savedPrints."""
        logFolder, _ = os.path.split(self.logFilePlotReference.logFile)
        imageFileName = os.path.join(logFolder, "temporary_print.png")
        self.logFilePlotReference.savePlotAsImage(self,
                                                  name=imageFileName,
                                                  oneNote=False)
        logger.info("attempting to use windows native printing dialog")
        os.startfile(os.path.normpath(imageFileName), "print")
        logger.info("saving to savedPrints folder")
        _, tail = os.path.split(self._generatedScriptLocation_default())
        tail = tail.replace(".py", ".png")
        dst = os.path.join(self.savedPrintsDirectory, tail)
        shutil.copyfile(os.path.normpath(imageFileName), dst)
        logger.info("saved to savedPrints folder")