Example #1
0
class XASNormPanel(TaskPanel):
    """XAS normalization Panel"""
    def __init__(self, parent, controller=None, **kws):
        """Create the XAS normalization task panel.

        All work is delegated to ``TaskPanel.__init__`` with this panel's
        config name ('xasnorm_config') and the module-level ``defaults``
        dict; extra keyword arguments are passed through unchanged.
        """
        TaskPanel.__init__(self, parent, controller,
                           configname='xasnorm_config', config=defaults, **kws)

    def build_display(self):
        """Create the normalization panel's widgets and lay them out.

        Widgets that other methods read or update are kept in ``self.wids``
        (plus ``self.plotone_op`` / ``self.plotsel_op`` for the two plot-type
        choices).  ``self.add_floatspin`` is used for numeric entries; other
        methods later access them as ``self.wids[name]``.  The layout is
        built row by row on ``self.panel``:

        * title and plot controls (current group / selected groups, Y offset),
        * the non-XAS data scale,
        * element/edge, E0, and edge step,
        * pre-edge range and Victoreen order,
        * normalization method, range, and polynomial type,
        * the "Freeze Group" box and the "Save as Default Settings" button.

        Most parameter rows end with a "Copy" button that calls
        ``onCopyParam`` to copy that parameter set to the other checked
        groups.
        """
        panel = self.panel
        self.wids = {}

        # plot-type choices for "Plot Current Group" and "Plot Selected Groups"
        self.plotone_op = Choice(panel,
                                 choices=list(PlotOne_Choices.keys()),
                                 action=self.onPlotOne,
                                 size=(200, -1))
        self.plotsel_op = Choice(panel,
                                 choices=list(PlotSel_Choices.keys()),
                                 action=self.onPlotSel,
                                 size=(200, -1))

        self.plotone_op.SetSelection(1)
        self.plotsel_op.SetSelection(1)

        plot_one = Button(panel,
                          'Plot Current Group',
                          size=(170, -1),
                          action=self.onPlotOne)

        plot_sel = Button(panel,
                          'Plot Selected Groups',
                          size=(170, -1),
                          action=self.onPlotSel)

        # E0 'auto?' and 'show?' checkboxes share a small sub-panel
        e0panel = wx.Panel(panel)
        self.wids['auto_e0'] = Check(e0panel,
                                     default=True,
                                     label='auto?',
                                     action=self.onSet_XASE0)
        self.wids['showe0'] = Check(e0panel,
                                    default=True,
                                    label='show?',
                                    action=self.onSet_XASE0)
        sx = wx.BoxSizer(wx.HORIZONTAL)
        sx.Add(self.wids['auto_e0'], 0, LEFT, 4)
        sx.Add(self.wids['showe0'], 0, LEFT, 4)
        pack(e0panel, sx)

        self.wids['auto_step'] = Check(panel,
                                       default=True,
                                       label='auto?',
                                       action=self.onNormMethod)

        # Victoreen order: the selection index is used as the value (read_form)
        self.wids['nvict'] = Choice(panel,
                                    choices=('0', '1', '2', '3'),
                                    size=(100, -1),
                                    action=self.onNormMethod,
                                    default=0)

        self.wids['nnorm'] = Choice(panel,
                                    choices=list(Nnorm_choices.values()),
                                    size=(100, -1),
                                    action=self.onNormMethod,
                                    default=0)

        # shared options for the pre-edge / normalization range spinners
        opts = {
            'size': (100, -1),
            'digits': 2,
            'increment': 5.0,
            'action': self.onSet_Ranges
        }

        xas_pre1 = self.add_floatspin('pre1', value=defaults['pre1'], **opts)
        xas_pre2 = self.add_floatspin('pre2', value=defaults['pre2'], **opts)
        xas_norm1 = self.add_floatspin('norm1',
                                       value=defaults['norm1'],
                                       **opts)
        xas_norm2 = self.add_floatspin('norm2',
                                       value=defaults['norm2'],
                                       **opts)

        # this opts dict is reused: Y offset, E0 and step start at 0,
        # and the non-XAS scale starts at 1.0 (see opts['value'] below)
        opts = {'digits': 3, 'increment': 0.1, 'value': 0}
        plot_voff = self.add_floatspin('plot_voff',
                                       with_pin=False,
                                       size=(80, -1),
                                       action=self.onVoffset,
                                       **opts)

        xas_e0 = self.add_floatspin('e0', action=self.onSet_XASE0Val, **opts)
        xas_step = self.add_floatspin('step',
                                      action=self.onSet_XASStep,
                                      with_pin=False,
                                      min_val=0.0,
                                      **opts)

        opts['value'] = 1.0
        scale = self.add_floatspin('scale', action=self.onSet_Scale, **opts)

        self.wids['norm_method'] = Choice(
            panel,
            choices=('polynomial', 'mback'),  # , 'area'),
            size=(120, -1),
            action=self.onNormMethod)
        self.wids['norm_method'].SetSelection(0)
        # '?' is the "element not set" entry at the head of the element list
        atsyms = ['?'] + self.larch.symtable._xray._xraydb.atomic_symbols
        edges = ('K', 'L3', 'L2', 'L1', 'M5')

        self.wids['atsym'] = Choice(panel, choices=atsyms, size=(75, -1))
        self.wids['edge'] = Choice(panel, choices=edges, size=(60, -1))

        self.wids['is_frozen'] = Check(panel,
                                       default=False,
                                       label='Freeze Group',
                                       action=self.onFreezeGroup)

        saveconf = Button(panel,
                          'Save as Default Settings',
                          size=(200, -1),
                          action=self.onSaveConfigBtn)

        use_auto = Button(panel,
                          'Use Default Settings',
                          size=(200, -1),
                          action=self.onAutoNorm)
        copy_auto = Button(panel,
                           'Copy',
                           size=(60, -1),
                           action=self.onCopyAuto)

        # "Copy" button that copies parameter set `name` to checked groups
        def CopyBtn(name):
            return Button(panel,
                          'Copy',
                          size=(60, -1),
                          action=partial(self.onCopyParam, name))

        add_text = self.add_text
        HLINEWID = 575
        # ---- layout: title and plot controls ----
        panel.Add(SimpleText(panel,
                             'XAS Pre-edge subtraction and Normalization',
                             **self.titleopts),
                  dcol=4)
        panel.Add(SimpleText(panel, 'Copy to Selected Groups:'),
                  style=RIGHT,
                  dcol=2)

        panel.Add(plot_sel, newrow=True)
        panel.Add(self.plotsel_op, dcol=3)
        panel.Add(SimpleText(panel, 'Y Offset:'), style=RIGHT)
        panel.Add(plot_voff, style=RIGHT)

        panel.Add(plot_one, newrow=True)
        panel.Add(self.plotone_op, dcol=4)
        panel.Add(CopyBtn('plotone_op'), dcol=1, style=RIGHT)

        # ---- non-XAS scale ----
        panel.Add(HLine(panel, size=(HLINEWID, 3)), dcol=6, newrow=True)
        add_text('Non-XAS Data Scale:')
        panel.Add(scale, dcol=2)

        # ---- XAS: element/edge, E0, edge step ----
        panel.Add(HLine(panel, size=(HLINEWID, 3)), dcol=6, newrow=True)
        add_text('XAS Data:')
        panel.Add(use_auto, dcol=4)
        panel.Add(copy_auto, dcol=1, style=RIGHT)

        add_text('Element and Edge: ', newrow=True)
        panel.Add(self.wids['atsym'])
        panel.Add(self.wids['edge'], dcol=3)
        panel.Add(CopyBtn('atsym'), dcol=1, style=RIGHT)

        add_text('E0 : ')
        panel.Add(xas_e0)
        panel.Add(e0panel, dcol=3)
        panel.Add(CopyBtn('xas_e0'), dcol=1, style=RIGHT)

        add_text('Edge Step: ')
        panel.Add(xas_step)
        panel.Add(self.wids['auto_step'], dcol=3)
        panel.Add(CopyBtn('xas_step'), dcol=1, style=RIGHT)

        # ---- pre-edge range and Victoreen order ----
        panel.Add((5, 5), newrow=True)
        panel.Add(HLine(panel, size=(HLINEWID, 3)), dcol=6, newrow=True)

        add_text('Pre-edge range: ')
        panel.Add(xas_pre1)
        add_text(' : ', newrow=False)
        panel.Add(xas_pre2, dcol=2)
        panel.Add(CopyBtn('xas_pre'), dcol=1, style=RIGHT)

        panel.Add(SimpleText(panel, 'Victoreen order:'), newrow=True)
        panel.Add(self.wids['nvict'], dcol=4)

        # ---- normalization method, range, polynomial type ----
        panel.Add((5, 5), newrow=True)
        panel.Add(HLine(panel, size=(HLINEWID, 3)), dcol=6, newrow=True)

        add_text('Normalization method: ')
        panel.Add(self.wids['norm_method'], dcol=4)
        panel.Add(CopyBtn('xas_norm'), dcol=1, style=RIGHT)

        add_text('Normalization range: ')
        panel.Add(xas_norm1)
        add_text(' : ', newrow=False)
        panel.Add(xas_norm2, dcol=2)
        panel.Add(SimpleText(panel, 'Polynomial Type:'), newrow=True)
        panel.Add(self.wids['nnorm'], dcol=4)

        # ---- freeze / save defaults ----
        panel.Add(HLine(panel, size=(HLINEWID, 3)), dcol=6, newrow=True)
        panel.Add((5, 5), newrow=True)
        panel.Add(self.wids['is_frozen'], newrow=True)
        panel.Add(saveconf, dcol=5)

        panel.Add((5, 5), newrow=True)
        panel.Add(HLine(panel, size=(HLINEWID, 3)), dcol=6, newrow=True)
        panel.pack()

        # outer sizer: small spacers around the grid panel
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((5, 5), 0, LEFT, 3)
        sizer.Add(panel, 0, LEFT, 3)
        sizer.Add((5, 5), 0, LEFT, 3)

        pack(self, sizer)

    def get_config(self, dgroup=None):
        """Return the normalization config dict for a group, attaching it.

        With ``dgroup`` None the controller's current group is used; if
        there is still no group, a fresh default config is returned (and
        nothing is attached).  A group without its own config starts from
        the defaults, seeded from Athena's ``bkg_params`` when present.
        Element and edge are then taken from the group's attributes, guessed
        from E0 while the element is still '?', and finally overridden by
        ``mback_params`` if the group has them.  The dict is stored on the
        group under ``self.configname`` and returned.
        """
        if dgroup is None:
            dgroup = self.controller.get_group()
        if dgroup is None:
            return self.get_defaultconfig()

        if hasattr(dgroup, self.configname):
            conf = getattr(dgroup, self.configname)
        else:
            conf = self.get_defaultconfig()
            if hasattr(dgroup, 'bkg_params'):  # group read from Athena
                bkg = dgroup.bkg_params
                # (our key, Athena key): Athena names the norm range nor1/nor2
                key_pairs = (('e0', 'e0'), ('pre1', 'pre1'), ('pre2', 'pre2'),
                             ('nnorm', 'nnorm'), ('norm1', 'nor1'),
                             ('norm2', 'nor2'))
                for key, athena_key in key_pairs:
                    conf[key] = getattr(bkg, athena_key, conf[key])
                fixed_step = float(getattr(bkg, 'fixstep', 0.0))
                conf['auto_step'] = fixed_step < 0.5
                conf['edge_step'] = getattr(bkg, 'step', conf['edge_step'])

        if conf['edge_step'] is None:
            conf['edge_step'] = getattr(dgroup, 'edge_step', None)

        for key in ('atsym', 'edge'):
            conf[key] = getattr(dgroup, key, conf[key])
        if conf['atsym'] == '?' and hasattr(dgroup, 'e0'):
            conf['atsym'], conf['edge'] = guess_edge(dgroup.e0)

        if hasattr(dgroup, 'mback_params'):
            mback = dgroup.mback_params
            for key in ('atsym', 'edge'):
                conf[key] = getattr(mback, key, conf[key])

        setattr(dgroup, self.configname, conf)
        return conf

    def fill_form(self, dgroup):
        """fill in form from a data group

        Loads the group's config (``get_config``) into the widgets.  XAS
        groups get the XAS plot choices and all normalization widgets
        enabled, with the non-XAS scale disabled.  Other groups get the
        non-XAS plot choices and only the scale enabled.  The frozen state
        (from the group's ``is_frozen`` if present, else the config) is then
        applied.  Processing is suppressed while widgets are being set, and
        is re-enabled via ``wx.CallAfter``.
        """
        opts = self.get_config(dgroup)

        # block process() from reacting to the widget changes below
        self.skip_process = True
        if dgroup.datatype == 'xas':
            self.plotone_op.SetChoices(list(PlotOne_Choices.keys()))
            self.plotsel_op.SetChoices(list(PlotSel_Choices.keys()))

            self.plotone_op.SetStringSelection(opts['plotone_op'])
            self.plotsel_op.SetStringSelection(opts['plotsel_op'])
            self.wids['e0'].SetValue(opts['e0'])
            edge_step = opts.get('edge_step', None)
            if edge_step is None:
                edge_step = 1.0

            # NOTE(review): get_config() already performs this same guess when
            # atsym is '?', so this looks redundant; it also mutates the
            # group's config dict in place.
            if hasattr(dgroup, 'e0') and opts['atsym'] == '?':
                atsym, edge = guess_edge(dgroup.e0)
                opts['atsym'] = atsym
                opts['edge'] = edge

            self.wids['step'].SetValue(edge_step)
            autoset_fs_increment(self.wids['step'], edge_step)
            # None means "automatic" for the ranges; leave the widget as is
            for attr in ('pre1', 'pre2', 'norm1', 'norm2'):
                val = opts.get(attr, None)
                if val is not None:
                    self.wids[attr].SetValue(val)

            self.set_nnorm_widget(opts.get('nnorm'))

            self.wids['nvict'].SetSelection(opts['nvict'])
            self.wids['showe0'].SetValue(opts['show_e0'])
            self.wids['auto_e0'].SetValue(opts['auto_e0'])
            self.wids['auto_step'].SetValue(opts['auto_step'])
            self.wids['edge'].SetStringSelection(opts['edge'].title())
            self.wids['atsym'].SetStringSelection(opts['atsym'].title())
            self.wids['norm_method'].SetStringSelection(
                opts['norm_method'].lower())
            for attr in ('pre1', 'pre2', 'norm1', 'norm2', 'nnorm', 'edge',
                         'atsym', 'step', 'norm_method'):
                self.wids[attr].Enable()
            self.wids['scale'].Disable()

        else:
            # NOTE(review): unlike the XAS branch, the plot-choice selections
            # are not restored from the config after SetChoices here.
            self.plotone_op.SetChoices(list(PlotOne_Choices_nonxas.keys()))
            self.plotsel_op.SetChoices(list(PlotSel_Choices_nonxas.keys()))
            self.wids['scale'].SetValue(opts['scale'])
            for attr in ('pre1', 'pre2', 'norm1', 'norm2', 'nnorm', 'edge',
                         'atsym', 'step', 'norm_method'):
                self.wids[attr].Disable()
            self.wids['scale'].Enable()

        # the group's own flag takes precedence over the config value
        frozen = opts.get('is_frozen', False)
        if hasattr(dgroup, 'is_frozen'):
            frozen = dgroup.is_frozen

        self.wids['is_frozen'].SetValue(frozen)
        self._set_frozen(frozen)
        # re-enable processing once pending widget events have been handled
        wx.CallAfter(self.unset_skip_process)

    def set_nnorm_widget(self, nnorm=None):
        """Select the 'Polynomial Type' entry matching ``nnorm``.

        ``nnorm`` is converted with ``int()``.  'auto' is selected when
        ``nnorm`` is None, when the conversion raises ValueError, or when
        the integer has no entry in ``Nnorm_choices``.
        """
        label = 'auto'
        if nnorm is not None:
            try:
                key = int(nnorm)
            except ValueError:
                key = None
            label = Nnorm_choices.get(key, 'auto')
        self.wids['nnorm'].SetStringSelection(label)

    def unset_skip_process(self):
        """Clear ``skip_process`` so that processing is allowed again.

        Scheduled through ``wx.CallAfter`` after widgets have been filled.
        """
        self.skip_process = False

    def read_form(self):
        """Collect the current widget values into a settings dict.

        Range spinners that read 0 are reported as None (automatic range).
        The polynomial type is mapped back through ``Nnorm_names``.  The
        method is lower-cased, while element and edge are title-cased.
        """
        wids = self.wids
        opts = {
            'e0': wids['e0'].GetValue(),
            'edge_step': wids['step'].GetValue(),
        }
        for name in ('pre1', 'pre2', 'norm1', 'norm2'):
            value = wids[name].GetValue()
            opts[name] = None if value == 0 else value

        opts['nnorm'] = Nnorm_names.get(wids['nnorm'].GetStringSelection(),
                                        None)
        opts['nvict'] = int(wids['nvict'].GetSelection())
        opts['plotone_op'] = self.plotone_op.GetStringSelection()
        opts['plotsel_op'] = self.plotsel_op.GetStringSelection()
        opts['plot_voff'] = wids['plot_voff'].GetValue()
        # (config key, widget name) for the checkboxes
        for key, wname in (('show_e0', 'showe0'), ('auto_e0', 'auto_e0'),
                           ('auto_step', 'auto_step')):
            opts[key] = wids[wname].IsChecked()

        opts['norm_method'] = wids['norm_method'].GetStringSelection().lower()
        opts['edge'] = wids['edge'].GetStringSelection().title()
        opts['atsym'] = wids['atsym'].GetStringSelection().title()
        opts['scale'] = wids['scale'].GetValue()
        return opts

    def onNormMethod(self, evt=None):
        """Store the selected normalization method and schedule reprocessing.

        When an 'mback' method is chosen while no element is set, element
        and edge are guessed from the current group's E0.  ``process``
        passes ``atomic_number(atsym)`` to ``mback_norm``, so a real element
        is needed.
        """
        method = self.wids['norm_method'].GetStringSelection().lower()
        self.update_config({'norm_method': method})
        if method.startswith('mback'):
            dgroup = self.controller.get_group()
            cur_elem = self.wids['atsym'].GetStringSelection()
            # '?' is the "unset" entry at the head of the element list;
            # 'H' is still treated as unset, as before.
            if hasattr(dgroup, 'e0') and cur_elem in ('?', 'H'):
                atsym, edge = guess_edge(dgroup.e0)
                self.wids['edge'].SetStringSelection(edge)
                self.wids['atsym'].SetStringSelection(atsym)
                self.update_config({'edge': edge, 'atsym': atsym})
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def _set_frozen(self, frozen):
        """Record ``frozen`` on the current group and toggle its widgets.

        All normalization widgets are disabled when ``frozen`` is true and
        enabled otherwise.  If there is no current group, or the group lookup
        fails, only the widgets are updated.
        """
        try:
            dgroup = self.controller.get_group()
            dgroup.is_frozen = frozen
        except (AttributeError, TypeError):
            # no current group (None) or lookup failure: best effort only
            pass
        enable = not frozen
        for wattr in ('e0', 'step', 'pre1', 'pre2', 'norm1', 'norm2', 'nvict',
                      'nnorm', 'showe0', 'auto_e0', 'auto_step', 'norm_method',
                      'edge', 'atsym'):
            self.wids[wattr].Enable(enable)

    def onFreezeGroup(self, evt=None):
        """Handle the 'Freeze Group' checkbox.

        The event's checked state is used when an event is given.  Otherwise
        the checkbox is read directly, so the handler also works when called
        without an event (``evt`` defaults to None).
        """
        if evt is None:
            frozen = self.wids['is_frozen'].IsChecked()
        else:
            frozen = evt.IsChecked()
        self._set_frozen(frozen)

    def onPlotOne(self, evt=None):
        """Plot the controller's current group."""
        current = self.controller.get_group()
        self.plot(current)

    def onVoffset(self, evt=None):
        """Replot the selected groups after the Y-offset value changes.

        The replot is deferred through ``wx.CallAfter``.
        """
        time.sleep(0.01)
        wx.CallAfter(self.onPlotSel)

    def onPlotSel(self, evt=None):
        """Overplot the chosen array for every checked group.

        The plot-choice table (XAS or non-XAS) is chosen from the datatype
        of the last checked group.  For 'norm', the array follows the
        normalization method: 'norm_mback' for mback, 'norm_area' for area.
        Group ``ix`` in the checked list is shifted by ``ix * Y-offset`` and
        labelled with its filename.  Groups that cannot be found are
        skipped, and the legend is shown at the end.
        """
        group_ids = self.controller.filelist.GetCheckedStrings()
        if len(group_ids) < 1:
            return

        last_name = self.controller.file_groups[str(group_ids[-1])]
        last_group = self.controller.get_group(last_name)
        if getattr(last_group, 'datatype', None) == 'xas':
            plot_choices = PlotSel_Choices
        else:
            plot_choices = PlotSel_Choices_nonxas

        yarray_name = plot_choices[self.plotsel_op.GetStringSelection()]
        if yarray_name == 'norm':
            norm_method = self.wids['norm_method'].GetStringSelection().lower()
            if norm_method.startswith('mback'):
                yarray_name = 'norm_mback'
            elif norm_method.startswith('area'):
                yarray_name = 'norm_area'

        voff = self.wids['plot_voff'].GetValue()
        newplot = True
        for ix, checked in enumerate(group_ids):
            groupname = self.controller.file_groups[str(checked)]
            dgroup = self.controller.get_group(groupname)
            # check before touching dgroup.filename
            if dgroup is None:
                continue
            dgroup.plot_extras = []
            self.plot(dgroup,
                      title='',
                      new=newplot,
                      multi=True,
                      yoff=ix * voff,
                      plot_yarrays=[(yarray_name, PLOTOPTS_1, dgroup.filename)],
                      with_extras=False,
                      delay_draw=True)
            newplot = False
        ppanel = self.controller.get_display(stacked=False).panel
        ppanel.conf.show_legend = True
        ppanel.conf.draw_legend()
        ppanel.unzoom_all()

    def onAutoNorm(self, evt=None):
        """Reset the current group's form to automatic normalization settings.

        The pre-edge and normalization ranges are set to 0 (which
        ``read_form`` reports as None), auto E0 and auto step are checked,
        the Victoreen order is set to 0, and the method to the first choice.
        The polynomial order is estimated from the post-edge span
        ``norm2 - norm1``, with ``norm2 = max(energy) - e0`` and
        ``norm1 = 5 * int(norm2 / 15)``: 0 if the span is below 50, 1 if
        below 350, else 2.  If that estimate cannot be made (no group, no
        energy/e0, empty or non-finite data), the order widget is left
        unchanged.  The group is then reprocessed.
        """
        dgroup = self.controller.get_group()
        try:
            norm2 = max(dgroup.energy) - dgroup.e0
            norm1 = 5.0 * int(norm2 / 15.0)
        except (AttributeError, TypeError, ValueError, OverflowError):
            nnorm = None
        else:
            span = norm2 - norm1
            if span < 50:
                nnorm = 0
            elif span < 350:
                nnorm = 1
            else:
                nnorm = 2
        self.wids['auto_step'].SetValue(1)
        self.wids['auto_e0'].SetValue(1)
        self.wids['nvict'].SetSelection(0)
        for attr in ('pre1', 'pre2', 'norm1', 'norm2'):
            self.wids[attr].SetValue(0)
        if nnorm is not None:
            self.set_nnorm_widget(nnorm)
        self.wids['norm_method'].SetSelection(0)
        self.onReprocess()

    def onCopyAuto(self, evt=None):
        """Apply automatic normalization settings to the other checked groups.

        Each checked group is updated and reprocessed unless it is missing,
        is the current group, or is frozen.  The update sets zero (automatic)
        pre-edge and normalization ranges, Victoreen order 0, the
        'polynomial' method with nnorm=2, and auto E0 / auto step.
        """
        opts = dict(pre1=0,
                    pre2=0,
                    nvict=0,
                    norm1=0,
                    norm2=0,
                    norm_method='polynomial',
                    nnorm=2,
                    auto_e0=1,
                    auto_step=1)
        for checked in self.controller.filelist.GetCheckedStrings():
            groupname = self.controller.file_groups[str(checked)]
            grp = self.controller.get_group(groupname)
            # groups never frozen or displayed may lack `is_frozen`
            if (grp is not None and grp != self.controller.group
                    and not getattr(grp, 'is_frozen', False)):
                self.update_config(opts, dgroup=grp)
                # process() reads its parameters from the form, so load the
                # group's updated settings into the widgets first.
                # NOTE(review): afterwards the form shows the last processed
                # group rather than the current one.
                self.fill_form(grp)
                self.process(grp, noskip=True)

    def onSaveConfigBtn(self, evt=None):
        """Save the current group's config, overlaid with the form, as defaults."""
        merged = self.get_config()
        merged.update(self.read_form())
        self.set_defaultconfig(merged)

    def onCopyParam(self, name=None, evt=None):
        """Copy one parameter set from the current group to the other checked groups.

        The current group's config is first refreshed from the form.
        ``name`` identifies the "Copy" button: 'plotone_op', 'xas_e0',
        'xas_step', 'xas_pre', 'atsym' or 'xas_norm'.  An unknown name copies
        nothing, but the target groups are still reprocessed.  Targets that
        are missing, frozen, or are the current group are skipped.
        """
        conf = self.get_config()
        conf.update(self.read_form())
        dgroup = self.controller.get_group()
        self.update_config(conf)
        self.fill_form(dgroup)

        # config keys copied by each "Copy" button
        copy_keys = {
            'plotone_op': ('plotone_op',),
            'xas_e0': ('e0', 'show_e0', 'auto_e0'),
            'xas_step': ('edge_step', 'auto_step'),
            'xas_pre': ('pre1', 'pre2', 'nvict'),
            'atsym': ('atsym', 'edge'),
            'xas_norm': ('norm_method', 'nnorm', 'norm1', 'norm2'),
        }
        opts = {key: conf[key] for key in copy_keys.get(str(name), ())}

        for checked in self.controller.filelist.GetCheckedStrings():
            groupname = self.controller.file_groups[str(checked)]
            grp = self.controller.get_group(groupname)
            # groups never frozen or displayed may lack `is_frozen`
            if (grp is not None and grp != self.controller.group
                    and not getattr(grp, 'is_frozen', False)):
                self.update_config(opts, dgroup=grp)
                # process() reads the form, so show grp's settings first
                self.fill_form(grp)
                self.process(grp, noskip=True)

    def onSet_XASE0(self, evt=None, value=None):
        """Handle the E0 'auto?' and 'show?' checkboxes.

        Stores E0 and both checkbox states in the group config, so that
        ``fill_form`` restores them when the group is shown again, then
        schedules reprocessing.
        """
        self.update_config({
            'e0': self.wids['e0'].GetValue(),
            'auto_e0': self.wids['auto_e0'].GetValue(),
            'show_e0': self.wids['showe0'].IsChecked(),
        })
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def onSet_XASE0Val(self, evt=None, value=None):
        """Handle a manually entered E0.

        Unchecks 'auto?', stores E0 and the auto flag in the group config,
        and schedules reprocessing.
        """
        auto_box = self.wids['auto_e0']
        auto_box.SetValue(0)
        new_conf = {'e0': self.wids['e0'].GetValue(),
                    'auto_e0': auto_box.GetValue()}
        self.update_config(new_conf)
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def onSet_XASStep(self, evt=None, value=None):
        """Handle a manually entered edge step.

        A negative entry is flipped to its absolute value.  'auto?' is
        unchecked, the step is stored in the config, the spinner increment
        is adapted to the new value, and reprocessing is scheduled.
        """
        step_widget = self.wids['step']
        entered = step_widget.GetValue()
        magnitude = abs(entered)
        if entered < 0:
            step_widget.SetValue(magnitude)
        self.wids['auto_step'].SetValue(0)
        self.update_config({'edge_step': magnitude, 'auto_step': False})
        autoset_fs_increment(step_widget, magnitude)
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def onSet_Scale(self, evt=None, value=None):
        """Store the non-XAS scale factor and schedule reprocessing."""
        scale_value = self.wids['scale'].GetValue()
        self.update_config({'scale': scale_value})
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def onSet_Ranges(self, evt=None, **kws):
        """Store the pre-edge and normalization ranges and schedule reprocessing."""
        ranges = {name: self.wids[name].GetValue()
                  for name in ('pre1', 'pre2', 'norm1', 'norm2')}
        self.update_config(ranges)
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def onSelPoint(self, evt=None, opt='__', relative_e0=True, win=None):
        """Fill widget ``opt`` from the last cursor position of a plot window.

        The x position comes from ``last_cursor_pos(win=win, ...)``.  For
        'e0' it becomes E0 and 'auto?' is unchecked.  For 'pre1', 'pre2',
        'norm1' and 'norm2' it is stored relative to the E0 currently in the
        form.  Nothing happens if ``opt`` is not a known widget or no cursor
        position is available; otherwise reprocessing is scheduled.
        ``relative_e0`` is accepted but not used here.
        """
        if opt not in self.wids:
            return None

        xpos, _ypos = last_cursor_pos(win=win, _larch=self.larch)
        if xpos is None:
            return
        current_e0 = self.wids['e0'].GetValue()
        if opt == 'e0':
            self.wids['e0'].SetValue(xpos)
            self.wids['auto_e0'].SetValue(0)
        elif opt in ('pre1', 'pre2', 'norm1', 'norm2'):
            self.wids[opt].SetValue(xpos - current_e0)
        time.sleep(0.01)
        wx.CallAfter(self.onReprocess)

    def onReprocess(self, evt=None, value=None, **kws):
        """Reprocess and replot the current group.

        Does nothing while ``skip_process`` is set, when the group lookup
        raises TypeError, or when the group has no config for this panel yet.
        """
        if self.skip_process:
            return
        try:
            dgroup = self.controller.get_group()
        except TypeError:
            return
        if not hasattr(dgroup, self.configname):
            return
        # process() reads the form itself; no separate read_form() needed
        self.process(dgroup=dgroup)
        self.plot(dgroup)

    def make_dnormde(self, dgroup):
        """Define ``<group>.dnormde = <group>.dmude / <group>.edge_step`` in Larch."""
        expr = "{0:s}.dnormde={0:s}.dmude/{0:s}.edge_step"
        self.larch_eval(expr.format(dgroup.groupname))

    def process(self, dgroup=None, force_mback=False, noskip=False, **kws):
        """Run pre-edge subtraction / normalization for a group from the form.

        Parameters
        ----------
        dgroup : group, optional
            Group to process; defaults to the controller's current group.
        force_mback : bool
            Also run ``mback_norm`` even when the form's method is not mback.
        noskip : bool
            Process even while ``self.skip_process`` is set.

        Non-XAS groups are only scaled (``norm = scale * ydat``).  XAS
        groups go through several steps:

        * the energy units are checked, with a dialog if they are not eV,
        * ``pre_edge`` is run in Larch with the form values,
        * mback or area normalization follows if selected,
        * ``dnormde`` is rebuilt,
        * the widgets are refreshed with the resulting values, and the
          values are written back to the group config.
        """
        if self.skip_process and not noskip:
            return
        if dgroup is None:
            dgroup = self.controller.get_group()
        if dgroup is None:
            return
        # NOTE(review): skip_process is only cleared at the non-XAS return or
        # by the CallAfter at the end; if anything in between raises, it
        # appears to stay True and block later reprocessing -- confirm.
        self.skip_process = True
        conf = self.get_config(dgroup)
        dgroup.custom_plotopts = {}

        form = self.read_form()
        form['group'] = dgroup.groupname

        # non-XAS data: only apply the scale factor
        if dgroup.datatype != 'xas':
            self.skip_process = False
            dgroup.mu = dgroup.ydat * 1.0
            opts = {'group': dgroup.groupname, 'scale': conf.get('scale', 1.0)}
            self.larch_eval("{group:s}.scale = {scale:.8f}".format(**opts))
            self.larch_eval(
                "{group:s}.norm = {scale:.8f}*{group:s}.ydat".format(**opts))
            return

        # make sure energies are in eV, asking the user if they are not
        en_units = getattr(dgroup, 'energy_units', None)
        if en_units is None:
            en_units = guess_energy_units(dgroup.energy)

        if en_units != 'eV':
            mono_dspace = getattr(dgroup, 'mono_dspace', 1)
            dlg = EnergyUnitsDialog(self.parent,
                                    dgroup.energy,
                                    unitname=en_units,
                                    dspace=mono_dspace)
            res = dlg.GetResponse()
            dlg.Destroy()
            if res.ok:
                en_units = res.units
                dgroup.mono_dspace = res.dspace
                dgroup.xdat = dgroup.energy = res.energy
        dgroup.energy_units = en_units

        e0 = form['e0']
        edge_step = form['edge_step']

        # build the pre_edge() argument list; a fixed E0 is only passed when
        # it lies strictly inside the energy range
        copts = [dgroup.groupname]
        if not form['auto_e0']:
            if e0 < max(dgroup.energy) and e0 > min(dgroup.energy):
                copts.append("e0=%.4f" % float(e0))

        if not form['auto_step']:
            copts.append("step=%s" % gformat(float(edge_step)))

        # NOTE(review): nvict and nnorm are integers but are formatted with
        # %.2f (e.g. "nnorm=2.00"); confirm pre_edge accepts float orders.
        for attr in ('pre1', 'pre2', 'nvict', 'nnorm', 'norm1', 'norm2'):
            if form[attr] is None:
                copts.append("%s=None" % attr)
            else:
                copts.append("%s=%.2f" % (attr, form[attr]))

        self.larch_eval("pre_edge(%s)" % (', '.join(copts)))
        # keep a copy of the polynomial normalization before mback/area
        self.larch_eval(
            "{group:s}.norm_poly = 1.0*{group:s}.norm".format(**form))

        norm_method = form['norm_method'].lower()
        form['normmeth'] = 'poly'
        if force_mback or norm_method.startswith('mback'):
            form['normmeth'] = 'mback'
            copts = [dgroup.groupname]
            # NOTE(review): if atsym is still '?', atomic_number() may not
            # return an int and the "%d" format would fail -- confirm.
            copts.append("z=%d" % atomic_number(form['atsym']))
            copts.append("edge='%s'" % form['edge'])
            for attr in ('pre1', 'pre2', 'nvict', 'nnorm', 'norm1', 'norm2'):
                if form[attr] is None:
                    copts.append("%s=None" % attr)
                else:
                    copts.append("%s=%.2f" % (attr, form[attr]))

            self.larch_eval("mback_norm(%s)" % (', '.join(copts)))

            # auto step: adopt mback's norm and step; fixed step: rescale
            # mback's norm by edge_step_mback / user edge step
            if form['auto_step']:
                norm_expr = """{group:s}.norm = 1.0*{group:s}.norm_{normmeth:s}
{group:s}.edge_step = 1.0*{group:s}.edge_step_{normmeth:s}"""
                self.larch_eval(norm_expr.format(**form))
            else:
                norm_expr = """{group:s}.norm = 1.0*{group:s}.norm_{normmeth:s}
{group:s}.norm *= {group:s}.edge_step_{normmeth:s}/{edge_step:.8f}"""
                self.larch_eval(norm_expr.format(**form))

        # area normalization: adopt norm_area / edge_step_area
        if norm_method.startswith('area'):
            form['normmeth'] = 'area'
            expr = """{group:s}.norm = 1.0*{group:s}.norm_{normmeth:s}
{group:s}.edge_step = 1.0*{group:s}.edge_step_{normmeth:s}"""
            self.larch_eval(expr.format(**form))

        self.make_dnormde(dgroup)

        # refresh the widgets with the values found by the fit
        if form['auto_e0']:
            self.wids['e0'].SetValue(dgroup.e0)
        if form['auto_step']:
            self.wids['step'].SetValue(dgroup.edge_step)
            autoset_fs_increment(self.wids['step'], dgroup.edge_step)

        self.wids['atsym'].SetStringSelection(dgroup.atsym)
        self.wids['edge'].SetStringSelection(dgroup.edge)

        self.set_nnorm_widget(dgroup.pre_edge_details.nnorm)
        # write the results back into the group's config
        for attr in ('e0', 'edge_step'):
            conf[attr] = getattr(dgroup, attr)
        for attr in ('pre1', 'pre2', 'norm1', 'norm2'):
            conf[attr] = val = getattr(dgroup.pre_edge_details, attr, None)
            if val is not None:
                self.wids[attr].SetValue(val)

        if hasattr(dgroup, 'mback_params'):  # from mback
            conf['atsym'] = getattr(dgroup.mback_params, 'atsym')
            conf['edge'] = getattr(dgroup.mback_params, 'edge')
        self.update_config(conf, dgroup=dgroup)
        # re-enable processing after the widget events above are handled
        wx.CallAfter(self.unset_skip_process)

    def get_plot_arrays(self, dgroup):
        """Set ``dgroup.plot_yarrays`` and axis labels for 'Plot Current Group'.

        Non-XAS groups get y, dy/dx or scaled-y arrays computed here.  For
        XAS groups, the arrays named by the selected plot choice are used.
        The group is (re)processed if a required attribute is missing, and an
        E0 marker is added to ``plot_extras`` when 'show?' is checked.
        Returns None.
        """
        lab = plotlabels.norm
        if dgroup is None:
            return

        dgroup.plot_y2label = None
        dgroup.plot_xlabel = plotlabels.energy
        dgroup.plot_yarrays = [('norm', PLOTOPTS_1, lab)]

        # non-XAS data: simple x/y plots, arrays computed here
        if dgroup.datatype != 'xas':
            pchoice = PlotOne_Choices_nonxas[
                self.plotone_op.GetStringSelection()]
            dgroup.plot_xlabel = 'x'
            dgroup.plot_ylabel = 'y'
            dgroup.plot_yarrays = [('ydat', PLOTOPTS_1, 'ydat')]
            dgroup.dmude = np.gradient(dgroup.ydat) / np.gradient(dgroup.xdat)
            if not hasattr(dgroup, 'scale'):
                dgroup.scale = 1.0

            dgroup.norm = dgroup.ydat * dgroup.scale
            if pchoice == 'dmude':
                dgroup.plot_ylabel = 'dy/dx'
                dgroup.plot_yarrays = [('dmude', PLOTOPTS_1, 'dy/dx')]
            elif pchoice == 'norm':
                dgroup.plot_ylabel = 'scaled y'
                dgroup.plot_yarrays = [('norm', PLOTOPTS_1, 'y/scale')]
            elif pchoice == 'norm+dnormde':
                # NOTE(review): 'dnormde' is not computed for non-XAS groups
                # in this method (make_dnormde only runs on the XAS path of
                # process) -- confirm it exists before plotting.
                lab = plotlabels.norm
                dgroup.plot_y2label = 'dy/dx'
                dgroup.plot_yarrays = [('ydat', PLOTOPTS_1, 'y'),
                                       ('dnormde', PLOTOPTS_D, 'dy/dx')]
            return

        # attributes that must exist on the group, else it is reprocessed
        req_attrs = ['e0', 'norm', 'dmude', 'pre_edge']

        pchoice = PlotOne_Choices[self.plotone_op.GetStringSelection()]
        if pchoice in ('mu', 'norm', 'flat', 'dmude'):
            lab = getattr(plotlabels, pchoice)
            dgroup.plot_yarrays = [(pchoice, PLOTOPTS_1, lab)]

        elif pchoice == 'prelines':
            dgroup.plot_yarrays = [('mu', PLOTOPTS_1, plotlabels.mu),
                                   ('pre_edge', PLOTOPTS_2, 'pre edge'),
                                   ('post_edge', PLOTOPTS_2, 'post edge')]
        elif pchoice == 'preedge':
            lab = r'pre-edge subtracted $\mu$'
            dgroup.pre_edge_sub = dgroup.norm * dgroup.edge_step
            dgroup.plot_yarrays = [('pre_edge_sub', PLOTOPTS_1, lab)]

        elif pchoice == 'mu+dmude':
            lab = plotlabels.mu
            lab2 = plotlabels.dmude
            dgroup.plot_yarrays = [('mu', PLOTOPTS_1, lab),
                                   ('dmude', PLOTOPTS_D, lab2)]
            dgroup.plot_y2label = lab2

        elif pchoice == 'norm+dnormde':
            lab = plotlabels.norm
            lab2 = plotlabels.dmude + ' (normalized)'
            dgroup.plot_yarrays = [('norm', PLOTOPTS_1, lab),
                                   ('dnormde', PLOTOPTS_D, lab2)]
            dgroup.plot_y2label = lab2

        # mback comparisons force an mback run if it has not been done yet
        elif pchoice == 'mback_norm':
            req_attrs.append('mback_norm')
            lab = r'$\mu$'
            if not hasattr(dgroup, 'mback_mu'):
                self.process(dgroup=dgroup, force_mback=True)
            dgroup.plot_yarrays = [('mu', PLOTOPTS_1, lab),
                                   ('mback_mu', PLOTOPTS_2,
                                    r'tabulated $\mu(E)$')]

        elif pchoice == 'mback_poly':
            req_attrs.append('mback_norm')
            lab = plotlabels.norm
            if not hasattr(dgroup, 'mback_mu'):
                self.process(dgroup=dgroup, force_mback=True)
            dgroup.plot_yarrays = [('norm_mback', PLOTOPTS_1, 'mback'),
                                   ('norm_poly', PLOTOPTS_2, 'polynomial')]

        elif pchoice == 'area_norm':
            dgroup.plot_yarrays = [('norm_area', PLOTOPTS_1, 'area'),
                                   ('norm_poly', PLOTOPTS_2, 'polynomial')]

        dgroup.plot_ylabel = lab
        needs_proc = False
        for attr in req_attrs:
            needs_proc = needs_proc or (not hasattr(dgroup, attr))

        if needs_proc:
            self.process(dgroup=dgroup, noskip=True)

        # NOTE(review): this overwrites dgroup.ydat with the first plotted
        # array (falling back to mu); confirm that is intended for XAS groups.
        y4e0 = dgroup.ydat = getattr(dgroup, dgroup.plot_yarrays[0][0],
                                     dgroup.mu)
        dgroup.plot_extras = []

        # E0 marker on the first plotted array
        if self.wids['showe0'].IsChecked():
            ie0 = index_of(dgroup.energy, dgroup.e0)
            dgroup.plot_extras.append(('marker', dgroup.e0, y4e0[ie0], {}))

    def plot(self,
             dgroup,
             title=None,
             plot_yarrays=None,
             yoff=0,
             delay_draw=False,
             multi=False,
             new=True,
             zoom_out=True,
             with_extras=True,
             **kws):
        """Plot arrays of a data group in the (non-stacked) plot display.

        Parameters
        ----------
        dgroup : data group
            group to plot; nothing is drawn if it has no ``groupname`` or
            no ``xdat`` array.
        title : str or None
            plot title; for a new plot without a title the file name part
            of ``dgroup.filename`` is used.
        plot_yarrays : list of (attribute name, plot options, label) or None
            arrays to plot; defaults to ``dgroup.plot_yarrays`` as set up
            by ``get_plot_arrays``.
        yoff : float
            offset added to every plotted y array.
        delay_draw : bool
            if True, the canvas is not redrawn at the end.
        multi : bool
            if True, the y label comes from the "plot selected" choice.
        new : bool
            start a new plot (True) or overplot (False).
        zoom_out : bool
            if False, the current view limits are kept unless the data lies
            entirely outside them.
        with_extras : bool
            draw ``dgroup.plot_extras`` (markers / vertical lines); these
            are only collected for new plots.
        **kws :
            additional plot options passed to the plot command.
        """
        if self.skip_plotting:
            return
        ppanel = self.controller.get_display(stacked=False).panel
        viewlims = ppanel.get_viewlimits()
        plotcmd = ppanel.plot if new else ppanel.oplot

        groupname = getattr(dgroup, 'groupname', None)
        if groupname is None:
            return

        if not hasattr(dgroup, 'xdat'):
            # every plot command below uses dgroup.xdat: stop here instead
            # of raising AttributeError after reporting the problem
            print("Cannot plot group ", groupname)
            return

        # (re)process the group if any array needed for plotting is missing
        needed = ('plot_yarrays', 'energy', 'mu', 'e0', 'dmude', 'norm')
        if any(getattr(dgroup, attr, None) is None for attr in needed):
            self.process(dgroup=dgroup)
        self.get_plot_arrays(dgroup)

        if plot_yarrays is None and hasattr(dgroup, 'plot_yarrays'):
            plot_yarrays = dgroup.plot_yarrays

        popts = kws
        fname = os.path.basename(dgroup.filename)
        if 'label' not in popts:
            popts['label'] = dgroup.plot_ylabel

        # viewlims is (xmin, xmax, ymin, ymax): zoom out when asked to, or
        # when the data would be entirely outside the current view
        zoom_out = (zoom_out or min(dgroup.xdat) >= viewlims[1]
                    or max(dgroup.xdat) <= viewlims[0]
                    or min(dgroup.ydat) >= viewlims[3]
                    or max(dgroup.ydat) <= viewlims[2])

        if not zoom_out:
            popts['xmin'] = viewlims[0]
            popts['xmax'] = viewlims[1]
            popts['ymin'] = viewlims[2]
            popts['ymax'] = viewlims[3]

        popts['xlabel'] = dgroup.plot_xlabel
        popts['ylabel'] = dgroup.plot_ylabel
        if getattr(dgroup, 'plot_y2label', None) is not None:
            popts['y2label'] = dgroup.plot_y2label

        plot_choices = PlotSel_Choices
        if dgroup.datatype != 'xas':
            plot_choices = PlotSel_Choices_nonxas

        if multi:
            ylabel = self.plotsel_op.GetStringSelection()
            yarray_name = plot_choices[ylabel]
            if dgroup.datatype == 'xas':
                ylabel = getattr(plotlabels, yarray_name, ylabel)
            popts['ylabel'] = ylabel

        plot_extras = None
        if new:
            if title is None:
                title = fname
            plot_extras = getattr(dgroup, 'plot_extras', None)

        popts['title'] = title
        popts['delay_draw'] = delay_draw
        if hasattr(dgroup, 'custom_plotopts'):
            popts.update(dgroup.custom_plotopts)

        popts['show_legend'] = len(plot_yarrays) > 1
        narr = len(plot_yarrays) - 1
        for i, pydat in enumerate(plot_yarrays):
            yaname, yopts, yalabel = pydat
            popts.update(yopts)
            if yalabel is not None:
                popts['label'] = yalabel

            # only the last array of the set triggers a redraw
            popts['delay_draw'] = delay_draw or (i != narr)
            if yaname == 'dnormde' and not hasattr(dgroup, yaname):
                self.make_dnormde(dgroup)
            if yaname == 'norm_mback' and not hasattr(dgroup, yaname):
                self.process(dgroup=dgroup, noskip=True, force_mback=True)

            plotcmd(dgroup.xdat, getattr(dgroup, yaname) + yoff, **popts)
            # every array after the first is overplotted
            plotcmd = ppanel.oplot

        if with_extras and plot_extras is not None:
            axes = ppanel.axes
            for etype, x, y, opts in plot_extras:
                if etype == 'marker':
                    xpopts = {
                        'marker': 'o',
                        'markersize': 4,
                        'label': '_nolegend_',
                        'markerfacecolor': 'red',
                        'markeredgecolor': '#884444'
                    }
                    xpopts.update(opts)
                    axes.plot([x], [y], **xpopts)
                elif etype == 'vline':
                    xpopts = {
                        'ymin': 0,
                        'ymax': 1.0,
                        'label': '_nolegend_',
                        'color': '#888888'
                    }
                    xpopts.update(opts)
                    axes.axvline(x, **xpopts)
        if not popts['delay_draw']:
            ppanel.canvas.draw()
Example #2
0
class XASNormPanel(TaskPanel):
    """XAS normalization Panel.

    Task panel for pre-edge subtraction and normalization of XAS data
    groups (through the larch ``pre_edge`` command), with controls to plot
    the current group, overplot all checked groups, save the settings as
    defaults, and copy parameter sets to the checked groups.
    """
    def __init__(self, parent, controller=None, **kws):
        TaskPanel.__init__(self,
                           parent,
                           controller,
                           configname='xasnorm_config',
                           config=defaults,
                           **kws)

    def build_display(self):
        """Create the panel widgets and lay them out on the grid panel."""
        titleopts = dict(font=Font(12), colour='#AA0000')

        xas = self.panel
        self.wids = {}

        # plot-type selectors for the current group and for checked groups
        self.plotone_op = Choice(xas,
                                 choices=list(PlotOne_Choices.keys()),
                                 action=self.onPlotOne,
                                 size=(175, -1))
        self.plotsel_op = Choice(xas,
                                 choices=list(PlotSel_Choices.keys()),
                                 action=self.onPlotSel,
                                 size=(175, -1))

        self.plotone_op.SetSelection(1)
        self.plotsel_op.SetSelection(1)

        plot_one = Button(xas,
                          'Plot This Group',
                          size=(150, -1),
                          action=self.onPlotOne)

        plot_sel = Button(xas,
                          'Plot Selected Groups',
                          size=(150, -1),
                          action=self.onPlotSel)

        # `opts` grows as widgets are created: the check boxes only get the
        # reprocess action, the choices also a size, and the range spins
        # additionally digits / increment
        opts = dict(action=self.onReprocess)

        e0opts_panel = wx.Panel(xas)
        self.wids['autoe0'] = Check(e0opts_panel,
                                    default=True,
                                    label='auto?',
                                    **opts)
        self.wids['showe0'] = Check(e0opts_panel,
                                    default=True,
                                    label='show?',
                                    **opts)
        sx = wx.BoxSizer(wx.HORIZONTAL)
        sx.Add(self.wids['autoe0'], 0, LCEN, 4)
        sx.Add(self.wids['showe0'], 0, LCEN, 4)
        pack(e0opts_panel, sx)

        self.wids['autostep'] = Check(xas, default=True, label='auto?', **opts)

        opts['size'] = (50, -1)
        self.wids['vict'] = Choice(xas, choices=('0', '1', '2', '3'), **opts)
        self.wids['nnor'] = Choice(xas, choices=('0', '1', '2', '3'), **opts)
        self.wids['vict'].SetSelection(1)
        self.wids['nnor'].SetSelection(1)

        opts.update({'size': (100, -1), 'digits': 2, 'increment': 5.0})

        # pre-edge and normalization ranges
        xas_pre1 = self.add_floatspin('pre1', value=-1000, **opts)
        xas_pre2 = self.add_floatspin('pre2', value=-30, **opts)
        xas_nor1 = self.add_floatspin('nor1', value=50, **opts)
        xas_nor2 = self.add_floatspin('nor2', value=5000, **opts)

        # e0 and edge step: editing either one clears its 'auto?' box
        opts = {'digits': 2, 'increment': 0.1, 'value': 0}
        xas_e0 = self.add_floatspin('e0', action=self.onSet_XASE0, **opts)
        xas_step = self.add_floatspin('step',
                                      action=self.onSet_XASStep,
                                      with_pin=False,
                                      **opts)

        saveconf = Button(xas,
                          'Save as Default Settings',
                          size=(200, -1),
                          action=self.onSaveConfigBtn)

        def CopyBtn(name):
            """Button copying parameter set `name` to the checked groups."""
            return Button(xas,
                          'Copy',
                          size=(50, -1),
                          action=partial(self.onCopyParam, name))

        add_text = self.add_text

        xas.Add(SimpleText(xas, ' XAS Pre-edge subtraction and Normalization',
                           **titleopts),
                dcol=4)
        xas.Add(SimpleText(xas, 'Copy to Selected Groups?'),
                style=RCEN,
                dcol=3)

        xas.Add(plot_sel, newrow=True)
        xas.Add(self.plotsel_op, dcol=6)

        xas.Add(plot_one, newrow=True)
        xas.Add(self.plotone_op, dcol=4)
        xas.Add((10, 10))
        xas.Add(CopyBtn('plotone_op'), style=RCEN)

        add_text('E0 : ')
        xas.Add(xas_e0)
        xas.Add(e0opts_panel, dcol=3)
        xas.Add((10, 1))
        xas.Add(CopyBtn('xas_e0'), style=RCEN)

        add_text('Edge Step: ')
        xas.Add(xas_step)
        xas.Add(self.wids['autostep'], dcol=3)
        xas.Add((10, 1))
        xas.Add(CopyBtn('xas_step'), style=RCEN)

        add_text('Pre-edge range: ')
        xas.Add(xas_pre1)
        add_text(' : ', newrow=False)
        xas.Add(xas_pre2)
        xas.Add(SimpleText(xas, 'Victoreen:'))
        xas.Add(self.wids['vict'])
        xas.Add(CopyBtn('xas_pre'), style=RCEN)

        add_text('Normalization range: ')
        xas.Add(xas_nor1)
        add_text(' : ', newrow=False)
        xas.Add(xas_nor2)
        xas.Add(SimpleText(xas, 'Poly Order:'))
        xas.Add(self.wids['nnor'])
        xas.Add(CopyBtn('xas_norm'), style=RCEN)

        xas.Add(saveconf, dcol=6, newrow=True)
        xas.pack()

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((5, 5), 0, LCEN, 3)
        sizer.Add(HLine(self, size=(550, 2)), 0, LCEN, 3)
        sizer.Add(xas, 0, LCEN, 3)
        sizer.Add((5, 5), 0, LCEN, 3)
        sizer.Add(HLine(self, size=(550, 2)), 0, LCEN, 3)
        pack(self, sizer)

    def get_config(self, dgroup=None):
        """custom get_config to possibly inherit from Athena settings

        Return the config dict stored on `dgroup` (default: the
        controller's current group).  A group without one gets the dict
        returned by ``get_defaultconfig()``, overridden by any Athena
        ``bkg_params`` the group carries; the dict is stored on the group.
        """
        if dgroup is None:
            dgroup = self.controller.get_group()

        if hasattr(dgroup, self.configname):
            conf = getattr(dgroup, self.configname)
        else:
            conf = self.get_defaultconfig()
            if hasattr(dgroup, 'bkg_params'):  # from Athena
                bkg = dgroup.bkg_params
                # (Athena attribute, config key) pairs
                for athena_name, key in (('e0', 'e0'),
                                         ('pre1', 'pre1'),
                                         ('pre2', 'pre2'),
                                         ('nor1', 'norm1'),
                                         ('nor2', 'norm2'),
                                         ('nnor', 'nnorm'),
                                         ('nvic', 'nvict')):
                    conf[key] = getattr(bkg, athena_name, conf[key])
                # Athena 'fixstep' is a flag for a fixed (non-auto) step.
                # The key must be 'auto_step': that is the key read by
                # fill_form() and written by read_form().
                conf['auto_step'] = (float(getattr(bkg, 'fixstep', 0.0))
                                     < 0.5)

        setattr(dgroup, self.configname, conf)
        return conf

    def fill_form(self, dgroup):
        """fill in form from a data group

        XAS groups get their settings written into the widgets.  For other
        data types, the plot choices are switched to the non-XAS set and the
        parameter widgets are disabled.  Reprocessing is suppressed while
        the form is filled.
        """
        opts = self.get_config(dgroup)
        self.skip_process = True
        try:
            if dgroup.datatype == 'xas':
                for wid in self.wids.values():
                    wid.Enable()

                self.plotone_op.SetChoices(list(PlotOne_Choices.keys()))
                self.plotsel_op.SetChoices(list(PlotSel_Choices.keys()))

                self.plotone_op.SetStringSelection(opts['plotone_op'])
                self.plotsel_op.SetStringSelection(opts['plotsel_op'])
                self.wids['e0'].SetValue(opts['e0'])
                edge_step = opts.get('edge_step', None)
                if edge_step is None:
                    edge_step = 1.0

                # choose the spin precision from the magnitude of the step.
                # log10 is undefined for a zero or non-finite step (int()
                # of the result would raise), so use the precision of a
                # unit step in that case.
                ndigits = 2
                if edge_step != 0 and np.isfinite(edge_step):
                    ndigits = int(2 - round(np.log10(abs(edge_step))))
                self.wids['step'].SetDigits(ndigits + 1)
                self.wids['step'].SetIncrement(0.2 * 10**(-ndigits))
                self.wids['step'].SetValue(edge_step)

                self.wids['pre1'].SetValue(opts['pre1'])
                self.wids['pre2'].SetValue(opts['pre2'])
                self.wids['nor1'].SetValue(opts['norm1'])
                self.wids['nor2'].SetValue(opts['norm2'])
                self.wids['vict'].SetSelection(opts['nvict'])
                self.wids['nnor'].SetSelection(opts['nnorm'])
                self.wids['showe0'].SetValue(opts['show_e0'])
                self.wids['autoe0'].SetValue(opts['auto_e0'])
                self.wids['autostep'].SetValue(opts['auto_step'])
            else:
                self.plotone_op.SetChoices(
                    list(PlotOne_Choices_nonxas.keys()))
                self.plotsel_op.SetChoices(
                    list(PlotSel_Choices_nonxas.keys()))
                self.plotone_op.SetStringSelection('Raw Data')
                self.plotsel_op.SetStringSelection('Raw Data')
                for wid in self.wids.values():
                    wid.Disable()
        finally:
            # never leave the panel stuck with processing disabled
            self.skip_process = False

    def read_form(self):
        "read form, return dict of values"
        form_opts = {}
        form_opts['e0'] = self.wids['e0'].GetValue()
        form_opts['edge_step'] = self.wids['step'].GetValue()
        form_opts['pre1'] = self.wids['pre1'].GetValue()
        form_opts['pre2'] = self.wids['pre2'].GetValue()
        form_opts['norm1'] = self.wids['nor1'].GetValue()
        form_opts['norm2'] = self.wids['nor2'].GetValue()
        # the choice index equals the polynomial order ('0'..'3')
        form_opts['nnorm'] = int(self.wids['nnor'].GetSelection())
        form_opts['nvict'] = int(self.wids['vict'].GetSelection())

        form_opts['plotone_op'] = self.plotone_op.GetStringSelection()
        form_opts['plotsel_op'] = self.plotsel_op.GetStringSelection()

        form_opts['show_e0'] = self.wids['showe0'].IsChecked()
        form_opts['auto_e0'] = self.wids['autoe0'].IsChecked()
        form_opts['auto_step'] = self.wids['autostep'].IsChecked()

        return form_opts

    def onPlotOne(self, evt=None):
        """Plot the controller's current group."""
        self.plot(self.controller.get_group())

    def onPlotSel(self, evt=None):
        """Overplot the 'plot selected' array of every checked group.

        Drawing is delayed until the last checked group, and the canvas is
        redrawn once at the end.
        """
        group_ids = self.controller.filelist.GetCheckedStrings()
        if len(group_ids) < 1:
            return
        last_id = group_ids[-1]

        yarray_name = PlotSel_Choices[self.plotsel_op.GetStringSelection()]

        newplot = True
        for checked in group_ids:
            groupname = self.controller.file_groups[str(checked)]
            dgroup = self.controller.get_group(groupname)
            # check for a missing group *before* touching its attributes
            if dgroup is None:
                continue
            plot_yarrays = [(yarray_name, PLOTOPTS_1, dgroup.filename)]
            dgroup.plot_extras = []
            self.plot(dgroup,
                      title='',
                      new=newplot,
                      multi=True,
                      plot_yarrays=plot_yarrays,
                      show_legend=True,
                      with_extras=False,
                      delay_draw=(last_id != checked))
            newplot = False
        self.controller.get_display(stacked=False).panel.canvas.draw()

    def onSaveConfigBtn(self, evt=None):
        """Save the current form values as the default settings."""
        conf = self.get_config()
        conf.update(self.read_form())
        self.set_defaultconfig(conf)

    def onCopyParam(self, name=None, evt=None):
        """Copy one parameter set from the form to the other checked groups.

        `name` selects the set: 'plotone_op', 'xas_e0', 'xas_step',
        'xas_pre' or 'xas_norm'.  Each checked group other than the current
        one gets those values and is reprocessed.
        """
        conf = self.get_config()
        conf.update(self.read_form())
        name = str(name)

        param_sets = {'plotone_op': ('plotone_op',),
                      'xas_e0': ('e0', 'show_e0', 'auto_e0'),
                      'xas_step': ('edge_step', 'auto_step'),
                      'xas_pre': ('pre1', 'pre2', 'nvict'),
                      'xas_norm': ('nnorm', 'norm1', 'norm2')}
        opts = {attr: conf[attr] for attr in param_sets.get(name, ())}

        updated = False
        for checked in self.controller.filelist.GetCheckedStrings():
            groupname = self.controller.file_groups[str(checked)]
            grp = self.controller.get_group(groupname)
            if grp != self.controller.group:
                self.set_config(grp, opts)
                # process() takes its settings from the form, so the form
                # must show this group's updated config first
                self.fill_form(grp)
                self.process(grp)
                updated = True

        # the loop left the form showing the last copied-to group; show the
        # current group's settings again so later edits are not applied to
        # it together with another group's values
        if updated:
            self.fill_form(self.controller.get_group())

    def onSet_XASE0(self, evt=None, value=None):
        "handle setting e0"
        self.wids['autoe0'].SetValue(0)
        self.onReprocess()

    def onSet_XASStep(self, evt=None, value=None):
        "handle setting edge step"
        self.wids['autostep'].SetValue(0)
        self.onReprocess()

    def onSelPoint(self, evt=None, opt='__', win=None):
        """
        get last selected point from a specified plot window
        and fill in the value for the widget defined by `opt`.

        by default it finds the latest cursor position from the
        cursor history of the first 20 plot windows.
        """
        if opt not in self.wids:
            return None

        _x, _y = last_cursor_pos(win=win, _larch=self.larch)
        if _x is None:
            return

        e0 = self.wids['e0'].GetValue()
        if opt == 'e0':
            self.wids['e0'].SetValue(_x)
            self.wids['autoe0'].SetValue(0)
        elif opt in ('pre1', 'pre2', 'nor1', 'nor2'):
            # range widgets hold energies relative to e0
            self.wids[opt].SetValue(_x - e0)

        self.onReprocess()

    def onReprocess(self, evt=None, value=None, **kws):
        "handle request reprocess"
        if self.skip_process:
            return
        try:
            dgroup = self.controller.get_group()
        except TypeError:
            return
        self.process(dgroup=dgroup)
        self.plot(dgroup)

    def make_dnormde(self, dgroup):
        """Set ``dnormde = dmude / edge_step`` on `dgroup` via larch."""
        form = dict(group=dgroup.groupname)
        self.larch_eval(
            "{group:s}.dnormde={group:s}.dmude/{group:s}.edge_step".format(
                **form))

    def process(self, dgroup=None, **kws):
        """ handle process (pre-edge/normalize) of XAS data from XAS form

        Runs larch ``pre_edge`` on `dgroup` (default: the current group)
        with the values in the form.  It then writes the resulting e0,
        edge step and ranges back to the form and stores them in the
        group's config.  Non-XAS groups only get ``mu`` copied from
        ``ydat``.  Calls made while processing is suppressed are ignored.
        """
        if self.skip_process:
            return

        if dgroup is None:
            dgroup = self.controller.get_group()

        self.skip_process = True
        try:
            self.get_config(dgroup)
            dgroup.custom_plotopts = {}

            if dgroup.datatype != 'xas':
                dgroup.mu = dgroup.ydat * 1.0
                return

            en_units = getattr(dgroup, 'energy_units', None)
            if en_units is None:
                en_units = 'eV'
                units = guess_energy_units(dgroup.energy)

                if units != 'eV':
                    # let the user confirm / convert non-eV energy units
                    dlg = EnergyUnitsDialog(self.parent, units, dgroup.energy)
                    try:
                        res = dlg.GetResponse()
                    finally:
                        dlg.Destroy()
                    if res.ok:
                        en_units = res.units
                        dgroup.xdat = dgroup.energy = res.energy
                dgroup.energy_units = en_units

            form = self.read_form()
            e0 = form['e0']
            edge_step = form['edge_step']

            copts = [dgroup.groupname]
            if not form['auto_e0']:
                # a fixed e0 is only passed on if it is inside the data range
                if min(dgroup.energy) < e0 < max(dgroup.energy):
                    copts.append("e0=%.4f" % float(e0))

            if not form['auto_step']:
                copts.append("step=%.4f" % float(edge_step))

            for attr in ('pre1', 'pre2', 'nvict', 'nnorm', 'norm1', 'norm2'):
                copts.append("%s=%.2f" % (attr, form[attr]))

            self.larch_eval("pre_edge(%s)" % (', '.join(copts)))
            self.make_dnormde(dgroup)

            # show the values pre_edge actually used
            if form['auto_e0']:
                self.wids['e0'].SetValue(dgroup.e0)
            if form['auto_step']:
                self.wids['step'].SetValue(dgroup.edge_step)

            details = dgroup.pre_edge_details
            self.wids['pre1'].SetValue(details.pre1)
            self.wids['pre2'].SetValue(details.pre2)
            self.wids['nor1'].SetValue(details.norm1)
            self.wids['nor2'].SetValue(details.norm2)

            conf = {attr: getattr(dgroup, attr) for attr in ('e0', 'edge_step')}
            for attr in ('pre1', 'pre2', 'nnorm', 'norm1', 'norm2'):
                conf[attr] = getattr(details, attr)

            self.set_config(dgroup, conf)
        finally:
            # reset even when pre_edge or a widget update fails; otherwise
            # the panel would silently refuse to process from then on
            self.skip_process = False

    def get_plot_arrays(self, dgroup):
        """Set up the arrays and labels used to plot `dgroup`.

        Based on the 'plot this group' choice, sets ``plot_xlabel``,
        ``plot_ylabel``, ``plot_y2label`` and ``plot_yarrays`` (a list of
        (attribute name, plot options, label)) on the group.  Non-XAS
        groups also get ``dmude`` as the gradient of ``ydat`` vs ``xdat``.
        XAS groups also get ``ydat`` (the first plotted array) and
        ``plot_extras`` (an E0 marker when 'show?' is checked).
        """
        if dgroup is None:
            return
        form = self.read_form()

        lab = plotlabels.norm
        dgroup.plot_y2label = None
        dgroup.plot_xlabel = plotlabels.energy
        dgroup.plot_yarrays = [('norm', PLOTOPTS_1, lab)]

        if dgroup.datatype != 'xas':
            pchoice = PlotOne_Choices_nonxas[
                self.plotone_op.GetStringSelection()]
            dgroup.plot_xlabel = 'x'
            dgroup.plot_ylabel = 'y'
            dgroup.plot_yarrays = [('ydat', PLOTOPTS_1, 'ydat')]
            dgroup.dmude = np.gradient(dgroup.ydat) / np.gradient(dgroup.xdat)
            if pchoice == 'dmude':
                dgroup.plot_ylabel = 'dy/dx'
                dgroup.plot_yarrays = [('dmude', PLOTOPTS_1, 'dy/dx')]
            elif pchoice == 'norm+deriv':
                dgroup.plot_y2label = 'dy/dx'
                dgroup.plot_yarrays = [('ydat', PLOTOPTS_1, 'y'),
                                       ('dmude', PLOTOPTS_D, 'dy/dx')]
            return

        pchoice = PlotOne_Choices[self.plotone_op.GetStringSelection()]
        if pchoice in ('mu', 'norm', 'flat', 'dmude'):
            lab = getattr(plotlabels, pchoice)
            dgroup.plot_yarrays = [(pchoice, PLOTOPTS_1, lab)]

        elif pchoice == 'prelines':
            dgroup.plot_yarrays = [('mu', PLOTOPTS_1, plotlabels.mu),
                                   ('pre_edge', PLOTOPTS_2, 'pre edge'),
                                   ('post_edge', PLOTOPTS_2, 'post edge')]
        elif pchoice == 'preedge':
            dgroup.pre_edge_sub = dgroup.norm * dgroup.edge_step
            dgroup.plot_yarrays = [('pre_edge_sub', PLOTOPTS_1,
                                    r'pre-edge subtracted $\mu$')]
            lab = r'pre-edge subtracted $\mu$'

        elif pchoice == 'norm+deriv':
            lab = plotlabels.norm
            lab2 = plotlabels.dmude
            dgroup.plot_yarrays = [('norm', PLOTOPTS_1, lab),
                                   ('dmude', PLOTOPTS_D, lab2)]
            dgroup.plot_y2label = lab2

        dgroup.plot_ylabel = lab
        y4e0 = dgroup.ydat = getattr(dgroup, dgroup.plot_yarrays[0][0],
                                     dgroup.mu)
        dgroup.plot_extras = []
        if form['show_e0']:
            ie0 = index_of(dgroup.energy, dgroup.e0)
            dgroup.plot_extras.append(('marker', dgroup.e0, y4e0[ie0], {}))

    def plot(self,
             dgroup,
             title=None,
             plot_yarrays=None,
             delay_draw=False,
             multi=False,
             new=True,
             zoom_out=True,
             with_extras=True,
             **kws):
        """Plot arrays of a data group in the (non-stacked) plot display.

        `plot_yarrays` defaults to ``dgroup.plot_yarrays``.  `new` starts a
        new plot instead of overplotting.  With `zoom_out` False, the
        current view limits are kept unless the data lies entirely outside
        them.  `multi` takes the y label from the 'plot selected' choice.
        `with_extras` draws ``dgroup.plot_extras`` (collected only for new
        plots).  Extra keywords are passed on as plot options.
        """
        self.get_plot_arrays(dgroup)
        ppanel = self.controller.get_display(stacked=False).panel
        viewlims = ppanel.get_viewlimits()
        plotcmd = ppanel.plot if new else ppanel.oplot

        groupname = dgroup.groupname

        if not hasattr(dgroup, 'xdat'):
            # every plot command below uses dgroup.xdat: stop here instead
            # of raising AttributeError after reporting the problem
            print("Cannot plot group ", groupname)
            return

        if ((getattr(dgroup, 'plot_yarrays', None) is None
             or getattr(dgroup, 'energy', None) is None
             or getattr(dgroup, 'mu', None) is None)):
            self.process(dgroup=dgroup)

        if plot_yarrays is None and hasattr(dgroup, 'plot_yarrays'):
            plot_yarrays = dgroup.plot_yarrays

        popts = kws
        fname = os.path.basename(dgroup.filename)
        if 'label' not in popts:
            popts['label'] = dgroup.plot_ylabel

        # viewlims is (xmin, xmax, ymin, ymax): zoom out when asked to, or
        # when the data would be entirely outside the current view
        zoom_out = (zoom_out or min(dgroup.xdat) >= viewlims[1]
                    or max(dgroup.xdat) <= viewlims[0]
                    or min(dgroup.ydat) >= viewlims[3]
                    or max(dgroup.ydat) <= viewlims[2])

        if not zoom_out:
            popts['xmin'] = viewlims[0]
            popts['xmax'] = viewlims[1]
            popts['ymin'] = viewlims[2]
            popts['ymax'] = viewlims[3]

        popts['xlabel'] = dgroup.plot_xlabel
        popts['ylabel'] = dgroup.plot_ylabel
        if getattr(dgroup, 'plot_y2label', None) is not None:
            popts['y2label'] = dgroup.plot_y2label

        if multi:
            sel_label = self.plotsel_op.GetStringSelection()
            yarray_name = PlotSel_Choices[sel_label]
            # fall back to the choice text if there is no matching label
            popts['ylabel'] = getattr(plotlabels, yarray_name, sel_label)

        plot_extras = None
        if new:
            if title is None:
                title = fname
            plot_extras = getattr(dgroup, 'plot_extras', None)

        popts['title'] = title
        popts['delay_draw'] = delay_draw
        if hasattr(dgroup, 'custom_plotopts'):
            popts.update(dgroup.custom_plotopts)

        narr = len(plot_yarrays) - 1
        for i, pydat in enumerate(plot_yarrays):
            yaname, yopts, yalabel = pydat
            popts.update(yopts)
            if yalabel is not None:
                popts['label'] = yalabel

            # only the last array of the set triggers a redraw
            popts['delay_draw'] = delay_draw or (i != narr)
            if yaname == 'dnormde' and not hasattr(dgroup, yaname):
                self.make_dnormde(dgroup)

            plotcmd(dgroup.xdat, getattr(dgroup, yaname), **popts)
            # every array after the first is overplotted
            plotcmd = ppanel.oplot

        if with_extras and plot_extras is not None:
            axes = ppanel.axes
            for etype, x, y, opts in plot_extras:
                if etype == 'marker':
                    xpopts = {
                        'marker': 'o',
                        'markersize': 4,
                        'label': '_nolegend_',
                        'markerfacecolor': 'red',
                        'markeredgecolor': '#884444'
                    }
                    xpopts.update(opts)
                    axes.plot([x], [y], **xpopts)
                elif etype == 'vline':
                    xpopts = {
                        'ymin': 0,
                        'ymax': 1.0,
                        'label': '_nolegend_',
                        'color': '#888888'
                    }
                    xpopts.update(opts)
                    axes.axvline(x, **xpopts)
        if not popts['delay_draw']:
            ppanel.canvas.draw()
Example #3
0
class PrePeakPanel(wx.Panel):
    """Panel for fitting pre-edge peaks of XAS spectra.

    The panel holds a form with the energy ranges used for the baseline
    fit and the model fit, buttons to run those fits, plot options, and a
    notebook with one page per model component (peaks, baseline
    functions) that together make up the lmfit model used for the fit.
    """
    def __init__(self, parent=None, controller=None, **kws):

        wx.Panel.__init__(self, parent, -1, size=(550, 625), **kws)
        self.parent = parent
        self.controller = controller
        self.larch = controller.larch
        # prefix -> Group of widgets and model info for each model component
        self.fit_components = OrderedDict()
        self.fit_model = None
        self.fit_params = None
        self.user_added_params = None
        self.summary = None
        self.sizer = wx.GridBagSizer(10, 6)
        self.build_display()

        # timer polling for the two cursor picks of 'Pick Values from Data'
        self.pick2_timer = wx.Timer(self)
        self.pick2_group = None
        self.Bind(wx.EVT_TIMER, self.onPick2Timer, self.pick2_timer)
        self.pick2_t0 = 0.
        self.pick2_timeout = 15.   # seconds to wait for the two picks

        # timer that erases the trace showing the 'Pick 2' guess
        self.pick2erase_timer = wx.Timer(self)
        self.pick2erase_panel = None
        self.Bind(wx.EVT_TIMER, self.onPick2EraseTimer, self.pick2erase_timer)

    def onPanelExposed(self, **kws):
        """Fill the form from the currently selected group.

        Called when the notebook page holding this panel is selected.
        This is best-effort: any failure (no selection, group without
        pre-edge data, ...) leaves the form unchanged.
        """
        try:
            fname = self.controller.filelist.GetStringSelection()
            gname = self.controller.file_groups[fname]
            dgroup = self.controller.get_group(gname)
            self.fill_form(dgroup)
        except Exception:
            pass  # deliberately best-effort: keep the current form values

    def larch_eval(self, cmd):
        """evaluate a command string in the controller's larch interpreter"""
        self.controller.larch.eval(cmd)

    def build_display(self):
        """create the widgets of the form and the model-component notebook"""
        self.mod_nb = flat_nb.FlatNotebook(self, -1, agwStyle=FNB_STYLE)
        self.mod_nb.SetTabAreaColour(wx.Colour(250,250,250))
        self.mod_nb.SetActiveTabColour(wx.Colour(254,254,195))

        self.mod_nb.SetNonActiveTabTextColour(wx.Colour(10,10,128))
        self.mod_nb.SetActiveTabTextColour(wx.Colour(128,0,0))
        self.mod_nb.Bind(wx.EVT_NOTEBOOK_PAGE_CHANGED, self.onNBChanged)

        pan = self.panel = GridPanel(self, ncols=4, nrows=4, pad=2, itemstyle=LCEN)

        self.wids = {}

        def FloatSpinWithPin(name, value, **kws):
            """FloatSpin stored in self.wids[name], plus a 'pin' button
            that fills it from the last point selected on a plot"""
            s = wx.BoxSizer(wx.HORIZONTAL)
            self.wids[name] = FloatSpin(pan, value=value, **kws)
            bb = BitmapButton(pan, get_icon('pin'), size=(25, 25),
                              action=partial(self.onSelPoint, opt=name),
                              tooltip='use last point selected from plot')
            s.Add(self.wids[name])
            s.Add(bb)
            return s

        opts = dict(digits=2, increment=0.1)
        ppeak_e0   = FloatSpinWithPin('ppeak_e0', value=0, **opts)
        ppeak_elo  = FloatSpinWithPin('ppeak_elo', value=-15, **opts)
        ppeak_ehi  = FloatSpinWithPin('ppeak_ehi', value=-5, **opts)
        ppeak_emin = FloatSpinWithPin('ppeak_emin', value=-30, **opts)
        ppeak_emax = FloatSpinWithPin('ppeak_emax', value=0, **opts)

        self.fitbline_btn  = Button(pan,'Fit Baseline', action=self.onFitBaseline,
                                    size=(150, 25))
        self.fitmodel_btn = Button(pan, 'Fit Model',
                                   action=self.onFitModel,  size=(150, 25))
        self.fitsel_btn = Button(pan, 'Fit Selected Groups',
                                 action=self.onFitSelected,  size=(150, 25))
        # model fits only make sense after a baseline fit has been done
        self.fitmodel_btn.Disable()
        self.fitsel_btn.Disable()

        self.array_choice = Choice(pan, size=(150, -1),
                                   choices=list(Array_Choices.keys()))
        self.array_choice.SetSelection(1)

        models_peaks = Choice(pan, size=(150, -1),
                              choices=ModelChoices['peaks'],
                              action=self.addModel)

        models_other = Choice(pan, size=(150, -1),
                              choices=ModelChoices['other'],
                              action=self.addModel)

        self.plot_choice = Choice(pan, size=(150, -1),
                                  choices=PlotChoices,
                                  action=self.onPlot)

        self.message = SimpleText(pan,
                                 'first fit baseline, then add peaks to fit model.')

        self.msg_centroid = SimpleText(pan, '----')

        opts = dict(default=True, size=(75, -1), action=self.onPlot)
        self.show_centroid  = Check(pan, label='show?', **opts)
        self.show_peakrange = Check(pan, label='show?', **opts)
        self.show_fitrange  = Check(pan, label='show?', **opts)
        self.show_e0        = Check(pan, label='show?', **opts)

        opts = dict(default=False, size=(200, -1), action=self.onPlot)
        self.plot_sub_bline = Check(pan, label='Subtract Baseline?', **opts)

        def add_text(text, dcol=1, newrow=True):
            pan.Add(SimpleText(pan, text), dcol=dcol, newrow=newrow)

        titleopts = dict(font=Font(12), colour='#AA0000')
        pan.Add(SimpleText(pan, ' Pre-edge Peak Fitting', **titleopts), dcol=5)
        add_text(' Run Fit:', newrow=False)

        add_text('Array to fit: ')
        pan.Add(self.array_choice, dcol=3)
        pan.Add((10, 10))
        pan.Add(self.fitbline_btn)

        add_text('E0: ')
        pan.Add(ppeak_e0)
        pan.Add((10, 10), dcol=2)
        pan.Add(self.show_e0)
        pan.Add(self.fitmodel_btn)


        add_text('Fit Energy Range: ')
        pan.Add(ppeak_emin)
        add_text(' : ', newrow=False)
        pan.Add(ppeak_emax)
        pan.Add(self.show_fitrange)
        pan.Add(self.fitsel_btn)

        t = SimpleText(pan, 'Pre-edge Peak Range: ')
        t.SetToolTip('Range used as mask for background')

        pan.Add(t, newrow=True)
        pan.Add(ppeak_elo)
        add_text(' : ', newrow=False)
        pan.Add(ppeak_ehi)
        pan.Add(self.show_peakrange)

        add_text( 'Peak Centroid: ')
        pan.Add(self.msg_centroid, dcol=3)
        pan.Add(self.show_centroid, dcol=1)


        #  plot buttons
        ts = wx.BoxSizer(wx.HORIZONTAL)
        ts.Add(self.plot_choice)
        ts.Add(self.plot_sub_bline)

        pan.Add(SimpleText(pan, 'Plot: '), newrow=True)
        pan.Add(ts, dcol=7)

        #  add model
        ts = wx.BoxSizer(wx.HORIZONTAL)
        ts.Add(models_peaks)
        ts.Add(models_other)

        pan.Add(SimpleText(pan, 'Add Component: '), newrow=True)
        pan.Add(ts, dcol=7)

        pan.Add(SimpleText(pan, 'Messages: '), newrow=True)
        pan.Add(self.message, dcol=7)

        pan.pack()

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((5,5), 0, LCEN, 3)
        sizer.Add(HLine(self, size=(550, 2)), 0, LCEN, 3)
        sizer.Add(pan,   0, LCEN, 3)
        sizer.Add((5,5), 0, LCEN, 3)
        sizer.Add(HLine(self, size=(550, 2)), 0, LCEN, 3)
        sizer.Add((5,5), 0, LCEN, 3)
        sizer.Add(self.mod_nb,  1, LCEN|wx.GROW, 10)

        pack(self, sizer)

    def get_config(self, dgroup=None):
        """get processing configuration for a group

        If the group has no ``prepeak_config`` containing 'e0', a default
        configuration is created from ``dgroup.e0``.  The configuration is
        stored back on the group, and an empty ``prepeaks`` Group is
        attached if the group has none.

        Returns the configuration dict.
        """
        if dgroup is None:
            dgroup = self.controller.get_group()

        conf = getattr(dgroup, 'prepeak_config', {})
        if 'e0' not in conf:
            conf = dict(e0 = dgroup.e0, elo=-10, ehi=-5,
                        emin=-40, emax=0, yarray='norm')

        dgroup.prepeak_config = conf
        if not hasattr(dgroup, 'prepeaks'):
            dgroup.prepeaks = Group()

        return conf

    def fill_form(self, dat):
        """Fill the form widgets from a data group or a dict of options.

        Parameters
        ----------
        dat : Group or dict
            A data group: its ``e0`` is used, plus the energy ranges in
            ``dat.prepeaks`` where present.  Or a dict of form options as
            returned by `read_form` (e.g. the ``user_options`` saved with
            a fit result).
        """
        if isinstance(dat, Group):
            self.wids['ppeak_e0'].SetValue(dat.e0)
            ppeaks = getattr(dat, 'prepeaks', None)
            if ppeaks is not None:
                # prepeaks can be an empty Group (see get_config) before a
                # baseline fit has been run: only fill values that exist.
                for attr in ('emin', 'emax', 'elo', 'ehi'):
                    if hasattr(ppeaks, attr):
                        self.wids['ppeak_%s' % attr].SetValue(getattr(ppeaks, attr))

        elif isinstance(dat, dict):
            for attr in ('e0', 'emin', 'emax', 'elo', 'ehi'):
                self.wids['ppeak_%s' % attr].SetValue(dat[attr])

            # read_form() saves the array description lower-cased, so match
            # it to the choice label without regard to case.
            array_desc = dat['array_desc']
            for label in Array_Choices:
                if label.lower() == array_desc.lower():
                    array_desc = label
                    break
            self.array_choice.SetStringSelection(array_desc)

            # these were saved with IsChecked(): restore the check state
            # (Enable() would instead grey out the widget).
            self.show_e0.SetValue(dat['show_e0'])
            self.show_centroid.SetValue(dat['show_centroid'])
            self.show_fitrange.SetValue(dat['show_fitrange'])
            self.show_peakrange.SetValue(dat['show_peakrange'])
            self.plot_sub_bline.SetValue(dat['plot_sub_bline'])

    def read_form(self):
        """read form, returning dict of values

        The dict holds the current group name, the array to fit (its
        lower-cased description and its attribute name), the baseline
        form, the energy values of the form, and the plot check boxes.
        """
        dgroup = self.controller.get_group()
        array_desc = self.array_choice.GetStringSelection()
        form_opts = {'gname': dgroup.groupname,
                     'array_desc': array_desc.lower(),
                     'array_name': Array_Choices[array_desc],
                     'baseline_form': 'lorentzian'}

        form_opts['e0'] = self.wids['ppeak_e0'].GetValue()
        form_opts['emin'] = self.wids['ppeak_emin'].GetValue()
        form_opts['emax'] = self.wids['ppeak_emax'].GetValue()
        form_opts['elo'] = self.wids['ppeak_elo'].GetValue()
        form_opts['ehi'] = self.wids['ppeak_ehi'].GetValue()
        form_opts['plot_sub_bline'] = self.plot_sub_bline.IsChecked()
        form_opts['show_centroid'] = self.show_centroid.IsChecked()
        form_opts['show_peakrange'] = self.show_peakrange.IsChecked()
        form_opts['show_fitrange'] = self.show_fitrange.IsChecked()
        form_opts['show_e0'] = self.show_e0.IsChecked()
        return form_opts


    def onFitBaseline(self, evt=None):
        """fit the pre-edge baseline with larch's pre_edge_baseline()

        After the fit, the centroid message is updated.  Lorentzian
        ('loren_') and Linear ('line_') baseline components are added to
        the model if missing and filled with the baseline fit parameters.
        The model fit button is enabled and the baseline is plotted.
        """
        opts = self.read_form()

        cmd = """{gname:s}.ydat = 1.0*{gname:s}.{array_name:s}
pre_edge_baseline(energy={gname:s}.energy, norm={gname:s}.ydat, group={gname:s},
form='{baseline_form:s}', with_line=True,
elo={elo:.3f}, ehi={ehi:.3f}, emin={emin:.3f}, emax={emax:.3f})
"""
        self.larch_eval(cmd.format(**opts))

        dgroup = self.controller.get_group()
        ppeaks = dgroup.prepeaks
        dgroup.centroid_msg = "%.4f +/- %.4f eV" % (ppeaks.centroid,
                                                    ppeaks.delta_centroid)

        self.msg_centroid.SetLabel(dgroup.centroid_msg)

        if 'loren_' not in self.fit_components:
            self.addModel(model='Lorentzian', prefix='loren_', isbkg=True)
        if 'line_' not in self.fit_components:
            self.addModel(model='Linear', prefix='line_', isbkg=True)

        for prefix in ('loren_', 'line_'):
            self.fill_model_params(prefix, dgroup.prepeaks.fit_details.params)

        self.fill_form(dgroup)
        self.fitmodel_btn.Enable()

        # start with a zero "fit" over the fit range, so onPlot has arrays
        i1, i2 = self.get_xranges(dgroup.energy)
        dgroup.yfit = dgroup.xfit = 0.0*dgroup.energy[i1:i2]

        self.plot_choice.SetStringSelection(PLOT_BASELINE)
        self.onPlot()

    def fill_model_params(self, prefix, params):
        """fill the parameter widgets of component `prefix` from `params`

        The names in `params` are taken without the prefix (the prefix is
        prepended here).  Value, bounds, and vary/fix/constrain state are set.
        """
        comp = self.fit_components[prefix]
        parwids = comp.parwids
        for pname, par in params.items():
            pname = prefix + pname
            if pname in parwids:
                wids = parwids[pname]
                if wids.minval is not None:
                    wids.minval.SetValue(par.min)
                if wids.maxval is not None:
                    wids.maxval.SetValue(par.max)
                wids.value.SetValue(par.value)
                varstr = 'vary' if par.vary else 'fix'
                if par.expr is not None:
                    varstr = 'constrain'
                if wids.vary is not None:
                    wids.vary.SetStringSelection(varstr)

    def onPlot(self, evt=None):
        """plot data with baseline or fit, according to the plot choice

        Does nothing if the current group has no ``prepeaks`` yet.  For
        the residual plot choice, a stacked plot frame is used with the
        residual in the bottom panel.
        """
        plot_choice = self.plot_choice.GetStringSelection()

        opts = self.read_form()
        dgroup = self.controller.get_group()

        ppeaks = getattr(dgroup, 'prepeaks', None)
        if ppeaks is None:
            return

        i1, i2 = self.get_xranges(dgroup.xdat)

        # keep the fit-range slice no longer than the baseline array
        if len(dgroup.yfit) > len(ppeaks.baseline):
            i2 = i1 + len(ppeaks.baseline)

        # full-length arrays: outside the fit range, yfit and baseline
        # are copies of the data
        xdat = 1.0*dgroup.energy
        ydat = 1.0*dgroup.ydat
        yfit = 1.0*dgroup.ydat
        baseline = 1.0*dgroup.ydat
        yfit[i1:i2] = dgroup.yfit[:i2-i1]
        baseline[i1:i2] = ppeaks.baseline[:i2-i1]


        if opts['plot_sub_bline']:
            ydat = ydat - baseline
            if plot_choice in (PLOT_FIT, PLOT_RESID):
                yfit = yfit - baseline
        if plot_choice == PLOT_RESID:
            resid = ydat - yfit

        # plot limits: fit range plus a 5% margin on each side
        _xs = dgroup.energy[i1:i2]
        xrange = max(_xs) - min(_xs)
        pxmin = min(_xs) - 0.05 * xrange
        pxmax = max(_xs) + 0.05 * xrange

        jmin = index_of(dgroup.energy, pxmin)
        jmax = index_of(dgroup.energy, pxmax) + 1

        _ys = ydat[jmin:jmax]
        yrange = max(_ys) - min(_ys)
        pymin = min(_ys) - 0.05 * yrange
        pymax = max(_ys) + 0.05 * yrange

        title = ' pre-edge fit'
        if plot_choice == PLOT_BASELINE:
            title = ' pre-edge baseline'
            if opts['plot_sub_bline']:
                title = ' pre-edge peaks'

        plotopts = {'xmin': pxmin, 'xmax': pxmax,
                    'ymin': pymin, 'ymax': pymax,
                    'title': '%s: %s' % (opts['gname'], title),
                    'xlabel': 'Energy (eV)',
                    'ylabel': r'%s $\mu$' % opts['array_desc'],
                    'label': r'%s $\mu$' % opts['array_desc'],
                    'delay_draw': True,
                    'show_legend': True}

        # extra decorations as (type, x, y, options) tuples
        plot_extras = []
        if opts['show_fitrange']:
            popts = {'color': '#DDDDCC'}
            plot_extras.append(('vline', opts['emin'], None, popts))
            plot_extras.append(('vline', opts['emax'], None, popts))

        if opts['show_peakrange']:
            popts = {'marker': '+', 'markersize': 6}
            elo = opts['elo']
            ehi = opts['ehi']
            ilo = index_of(dgroup.xdat, elo)
            ihi = index_of(dgroup.xdat, ehi)

            plot_extras.append(('marker', elo, ydat[ilo], popts))
            plot_extras.append(('marker', ehi, ydat[ihi], popts))

        if opts['show_centroid']:
            popts = {'color': '#EECCCC'}
            ecen = getattr(dgroup.prepeaks, 'centroid', -1)
            if ecen > min(dgroup.energy):
                plot_extras.append(('vline', ecen, None,  popts))


        pframe = self.controller.get_display(win=2,
                                             stacked=(plot_choice==PLOT_RESID))
        ppanel = pframe.panel
        axes = ppanel.axes

        plotopts.update(PLOTOPTS_1)

        ppanel.plot(xdat, ydat, **plotopts)
        if plot_choice == PLOT_BASELINE:
            if not opts['plot_sub_bline']:
                ppanel.oplot(dgroup.prepeaks.energy,
                             dgroup.prepeaks.baseline,
                             label='baseline', **PLOTOPTS_2)

        elif plot_choice in (PLOT_FIT, PLOT_RESID):
            ppanel.oplot(dgroup.energy, yfit,
                         label='fit', **PLOTOPTS_1)

            if hasattr(dgroup, 'ycomps'):
                ncomp = len(dgroup.ycomps)
                for icomp, (label, ycomp) in enumerate(dgroup.ycomps.items(), 1):
                    # a component may have been deleted since the last fit
                    fcomp = self.fit_components.get(label, None)
                    is_bkg = fcomp is not None and fcomp.bkgbox.IsChecked()
                    # baseline components are hidden when it is subtracted
                    if not (is_bkg and opts['plot_sub_bline']):
                        ppanel.oplot(dgroup.xfit, ycomp, label=label,
                                     delay_draw=(icomp!=ncomp), style='short dashed')

            if plot_choice == PLOT_RESID:
                _ys = resid
                yrange = max(_ys) - min(_ys)
                plotopts['ymin'] = min(_ys) - 0.05 * yrange
                plotopts['ymax'] = max(_ys) + 0.05 * yrange
                plotopts['delay_draw'] = False
                plotopts['ylabel'] = 'data-fit'
                plotopts['label'] = '_nolegend_'

                pframe.plot(dgroup.energy, resid, panel='bot', **plotopts)
                pframe.Show()


        for etype, x, y, eopts in plot_extras:
            if etype == 'marker':
                popts = {'marker': 'o', 'markersize': 4,
                         'label': '_nolegend_',
                         'markerfacecolor': 'red',
                         'markeredgecolor': '#884444'}
                popts.update(eopts)
                axes.plot([x], [y], **popts)
            elif etype == 'vline':
                popts = {'ymin': 0, 'ymax': 1.0, 'color': '#888888',
                         'label': '_nolegend_'}
                popts.update(eopts)
                axes.axvline(x, **popts)
        ppanel.canvas.draw()

    def onNBChanged(self, event=None):
        """model-component notebook page changed: currently nothing to do"""
        pass

    def addModel(self, event=None, model=None, prefix=None, isbkg=False):
        """add a model component and its notebook page of parameter widgets

        Parameters
        ----------
        event : wx event, optional
            Choice event; the model name is taken from it when `model`
            is None.
        model : str, optional
            model name.  Names starting with '<' (placeholder entries)
            are ignored.
        prefix : str, optional
            parameter prefix.  Default: first 5 characters of the model
            name, lower-cased, plus the first unused integer and '_'.
        isbkg : bool
            initial state of the 'Is Baseline?' check box.
        """
        if model is None and event is not None:
            model = event.GetString()
        if model is None or model.startswith('<'):
            return

        if prefix is None:
            p = model[:5].lower()
            curmodels = ["%s%i_" % (p, i+1) for i in range(1+len(self.fit_components))]
            for comp in self.fit_components:
                if comp in curmodels:
                    curmodels.remove(comp)

            prefix = curmodels[0]

        label = "%s(prefix='%s')" % (model, prefix)
        title = prefix[:-1]
        mclass_kws = {'prefix': prefix}
        if 'step' in model.lower():
            form = model.lower().replace('step', '').strip()

            if form.startswith('err'): form = 'erf'
            label = "Step(form='%s', prefix='%s')" % (form, prefix)
            title = "%s: Step %s" % (prefix[:-1], form[:3])
            mclass = lm_models.StepModel
            mclass_kws['form'] = form
            minst = mclass(form=form, prefix=prefix)
        else:
            if model in ModelFuncs:
                mclass = getattr(lm_models, ModelFuncs[model])
            else:
                mclass = getattr(lm_models, model+'Model')

            minst = mclass(prefix=prefix)

        panel = GridPanel(self.mod_nb, ncols=2, nrows=5, pad=2, itemstyle=CEN)

        def SLabel(label, size=(80, -1), **kws):
            return  SimpleText(panel, label,
                               size=size, style=wx.ALIGN_LEFT, **kws)
        usebox = Check(panel, default=True, label='Use in Fit?', size=(100, -1))
        bkgbox = Check(panel, default=False, label='Is Baseline?', size=(125, -1))
        if isbkg:
            bkgbox.SetValue(1)

        delbtn = Button(panel, 'Delete Component', size=(125, -1),
                        action=partial(self.onDeleteComponent, prefix=prefix))

        pick2msg = SimpleText(panel, "    ", size=(125, -1))
        pick2btn = Button(panel, 'Pick Values from Data', size=(150, -1),
                          action=partial(self.onPick2Points, prefix=prefix))

        SetTip(usebox,   'Use this component in fit?')
        SetTip(bkgbox,   'Label this component as "background" when plotting?')
        SetTip(delbtn,   'Delete this model component')
        SetTip(pick2btn, 'Select X range on Plot to Guess Initial Values')

        panel.Add(SLabel(label, size=(275, -1), colour='#0000AA'),
                  dcol=3,  style=wx.ALIGN_LEFT, newrow=True)
        panel.Add(usebox, dcol=1)
        panel.Add(bkgbox, dcol=2, style=LCEN)
        panel.Add(delbtn, dcol=1, style=wx.ALIGN_LEFT)

        panel.Add(pick2btn, dcol=2, style=wx.ALIGN_LEFT, newrow=True)
        panel.Add(pick2msg, dcol=2, style=wx.ALIGN_RIGHT)

        panel.Add(SLabel("Parameter "), style=wx.ALIGN_LEFT,  newrow=True)
        panel.AddMany((SLabel(" Value"), SLabel(" Type"), SLabel(' Bounds'),
                       SLabel("  Min", size=(60, -1)),
                       SLabel("  Max", size=(60, -1)),  SLabel(" Expression")))

        parwids = OrderedDict()
        parnames = sorted(minst.param_names)

        # also offer function arguments that have param hints but are
        # not in param_names (and are not independent variables)
        for a in minst._func_allargs:
            pname = "%s%s" % (prefix, a)
            if (pname not in parnames and
                a in minst.param_hints and
                a not in minst.independent_vars):
                parnames.append(pname)

        for pname in parnames:
            sname = pname[len(prefix):]
            hints = minst.param_hints.get(sname, {})

            par = Parameter(name=pname, value=0, vary=True)
            if 'min' in hints:
                par.min = hints['min']
            if 'max' in hints:
                par.max = hints['max']
            if 'value' in hints:
                par.value = hints['value']
            if 'expr' in hints:
                par.expr = hints['expr']

            pwids = ParameterWidgets(panel, par, name_size=100, expr_size=125,
                                     float_size=80, prefix=prefix,
                                     widgets=('name', 'value',  'minval',
                                              'maxval', 'vary', 'expr'))
            parwids[par.name] = pwids
            panel.Add(pwids.name, newrow=True)

            panel.AddMany((pwids.value, pwids.vary, pwids.bounds,
                           pwids.minval, pwids.maxval, pwids.expr))

        # derived (constraint-only) parameters: value shown but not editable
        for sname, hint in minst.param_hints.items():
            pname = "%s%s" % (prefix, sname)
            if 'expr' in hint and pname not in parnames:
                par = Parameter(name=pname, value=0, expr=hint['expr'])
                pwids = ParameterWidgets(panel, par, name_size=100, expr_size=225,
                                         float_size=80, prefix=prefix,
                                         widgets=('name', 'value', 'expr'))
                parwids[par.name] = pwids
                panel.Add(pwids.name, newrow=True)
                panel.Add(pwids.value)
                panel.Add(pwids.expr, dcol=4, style=wx.ALIGN_RIGHT)
                pwids.value.Disable()

        fgroup = Group(prefix=prefix, title=title, mclass=mclass,
                       mclass_kws=mclass_kws, usebox=usebox, panel=panel,
                       parwids=parwids, float_size=65, expr_size=150,
                       pick2_msg=pick2msg, bkgbox=bkgbox)


        self.fit_components[prefix] = fgroup
        panel.pack()

        self.mod_nb.AddPage(panel, title, True)
        # nudge the size to force a re-layout with the new page
        sx,sy = self.GetSize()
        self.SetSize((sx, sy+1))
        self.SetSize((sx, sy))

    def onDeleteComponent(self, evt=None, prefix=None):
        """delete model component `prefix` and its notebook page"""
        fgroup = self.fit_components.get(prefix, None)
        if fgroup is None:
            return

        for i in range(self.mod_nb.GetPageCount()):
            if fgroup.title == self.mod_nb.GetPageText(i):
                self.mod_nb.DeletePage(i)
                # page indices shift after a deletion: stop here so no
                # longer-existing index is queried
                break

        # drop references to the (now destroyed) widgets
        for attr in dir(fgroup):
            setattr(fgroup, attr, None)

        self.fit_components.pop(prefix)

    def onPick2EraseTimer(self, evt=None):
        """erases line trace showing automated 'Pick 2' guess """
        self.pick2erase_timer.Stop()
        panel = self.pick2erase_panel
        ntrace = panel.conf.ntrace - 1
        panel.conf.get_mpl_line(ntrace).set_data(np.array([]), np.array([]))
        panel.conf.ntrace = ntrace
        panel.draw()

    def onPick2Timer(self, evt=None):
        """checks for 'Pick 2' events, and initiates 'Pick 2' guess
        for a model from the selected data range

        Gives up after `self.pick2_timeout` seconds.  Once two points have
        been picked, the model's guess() over that x range fills the
        parameter values, and the guessed curve is drawn.  An erase timer
        is then started that removes the curve after 5 seconds.
        """
        try:
            plotframe = self.controller.get_display()
            curhist = plotframe.cursor_hist[:]
            plotframe.Raise()
        except Exception:
            return

        if (time.time() - self.pick2_t0) > self.pick2_timeout:
            self.pick2_group.pick2_msg.SetLabel(" ")
            plotframe.cursor_hist = []
            self.pick2_timer.Stop()
            return

        if len(curhist) < 2:
            self.pick2_group.pick2_msg.SetLabel("%i/2" % (len(curhist)))
            return

        self.pick2_group.pick2_msg.SetLabel("done.")
        self.pick2_timer.Stop()

        # guess param values from the data between the two picked points
        xcur = (curhist[0][0], curhist[1][0])
        xmin, xmax = min(xcur), max(xcur)

        dgroup = getattr(self.larch.symtable, self.controller.groupname)
        i0 = index_of(dgroup.xdat, xmin)
        i1 = index_of(dgroup.xdat, xmax)
        x, y = dgroup.xdat[i0:i1+1], dgroup.ydat[i0:i1+1]

        mod = self.pick2_group.mclass(prefix=self.pick2_group.prefix)
        parwids = self.pick2_group.parwids
        try:
            guesses = mod.guess(y, x=x)
        except Exception:
            # not every model supports guess(): keep the current values
            return

        for name, param in guesses.items():
            if name in parwids:
                parwids[name].value.SetValue(param.value)

        dgroup._tmp = mod.eval(guesses, x=dgroup.xdat)
        plotframe = self.controller.get_display()
        plotframe.cursor_hist = []
        plotframe.oplot(dgroup.xdat, dgroup._tmp)
        self.pick2erase_panel = plotframe.panel

        self.pick2erase_timer.Start(5000)


    def onPick2Points(self, evt=None, prefix=None):
        """start waiting for two points to be picked on the plot, used to
        guess initial values for model component `prefix`"""
        fgroup = self.fit_components.get(prefix, None)
        if fgroup is None:
            return

        plotframe = self.controller.get_display()
        plotframe.Raise()

        plotframe.cursor_hist = []
        fgroup.npts = 0
        self.pick2_group = fgroup

        if fgroup.pick2_msg is not None:
            fgroup.pick2_msg.SetLabel("0/2")

        self.pick2_t0 = time.time()
        self.pick2_timer.Start(250)


    def onSaveFitResult(self, event=None):
        """ask for a file name and save the latest fit result to it"""
        dgroup = self.controller.get_group()
        fit_history = getattr(dgroup, 'fit_history', None)
        if not fit_history:
            print('no fit result to save')
            return

        deffile = dgroup.filename.replace('.', '_') + '.modl'

        outfile = FileSave(self, 'Save Fit Result',
                           default_file=deffile,
                           wildcard=ModelWcards)

        if outfile is not None:
            try:
                self.save_fit_result(fit_history[-1], outfile)
            except IOError:
                print('could not write %s' % outfile)

    def onLoadFitResult(self, event=None):
        """ask for a file name and load a saved fit result from it"""
        mfile = FileOpen(self, 'Load Fit Result',
                         default_file='', wildcard=ModelWcards)
        if mfile is not None:
            self.load_modelresult(mfile)

    def save_fit_result(self, fitresult, outfile):
        """saves a customized ModelResult"""
        save_modelresult(fitresult, outfile)

    def load_modelresult(self, inpfile):
        """read a customized ModelResult

        All current model components are replaced by those of the saved
        model.  Their widgets are filled with the saved bounds and initial
        values, and the form is filled from the saved ``user_options``.
        Returns the loaded result.
        """
        # the module-level load_modelresult() does the file reading
        result = load_modelresult(inpfile)

        for prefix in list(self.fit_components.keys()):
            self.onDeleteComponent(prefix=prefix)

        for comp in result.model.components:
            isbkg = comp.prefix in result.user_options['bkg_components']
            self.addModel(model=comp.func.__name__,
                          prefix=comp.prefix, isbkg=isbkg)

        for comp in result.model.components:
            parwids = self.fit_components[comp.prefix].parwids
            for pname, par in result.params.items():
                if pname in parwids:
                    wids = parwids[pname]
                    if wids.minval is not None:
                        wids.minval.SetValue(par.min)
                    if wids.maxval is not None:
                        wids.maxval.SetValue(par.max)
                    # restore the starting values, not the best-fit ones
                    val = result.init_values.get(pname, par.value)
                    wids.value.SetValue(val)

        self.fill_form(result.user_options)
        return result

    def onExportFitResult(self, event=None):
        """ask for a file name and export the latest fit result with the
        data over the fit range"""
        dgroup = self.controller.get_group()
        fit_history = getattr(dgroup, 'fit_history', None)
        if not fit_history:
            print('no fit result to export')
            return

        deffile = dgroup.filename.replace('.', '_') + '_result.xdi'
        wcards = 'All files (*.*)|*.*'

        outfile = FileSave(self, 'Export Fit Result',
                           default_file=deffile, wildcard=wcards)

        if outfile is not None:
            i1, i2 = self.get_xranges(dgroup.xdat)
            x = dgroup.xdat[i1:i2]
            y = dgroup.ydat[i1:i2]
            yerr = None
            if hasattr(dgroup, 'yerr'):
                yerr = 1.0*dgroup.yerr
                if not isinstance(yerr, np.ndarray):
                    yerr = yerr * np.ones(len(y))
                else:
                    yerr = yerr[i1:i2]

            export_modelresult(fit_history[-1],
                               filename=outfile,
                               datafile=dgroup.filename,
                               ydata=y, yerr=yerr, x=x)


    def onSelPoint(self, evt=None, opt='__', relative_e0=False, win=None):
        """
        get last selected point from a specified plot window
        and fill in the value for the widget defined by `opt`.

        by default it finds the latest cursor position from the
        cursor history of the first 20 plot windows.
        """
        if opt not in self.wids:
            return None

        _x, _y = last_cursor_pos(win=win, _larch=self.larch)

        if _x is not None:
            if relative_e0 and 'e0' in self.wids:
                _x -= self.wids['e0'].GetValue()
            self.wids[opt].SetValue(_x)

    def get_xranges(self, x):
        """return (i1, i2) so that x[i1:i2] spans the form's fit range

        A small epsilon (1/5 of the smallest energy step) guards against
        round-off at the range boundaries.
        """
        opts = self.read_form()
        dgroup = self.controller.get_group()
        en_eps = min(np.diff(dgroup.energy)) / 5.

        i1 = index_of(x, opts['emin'] + en_eps)
        i2 = index_of(x, opts['emax'] + en_eps) + 1
        return i1, i2

    def build_fitmodel(self):
        """ use fit components to build model

        Combines all components with 'Use in Fit?' checked into
        `self.fit_model` and their parameters into `self.fit_params`.  For
        non-baseline components that have both an amplitude and a center,
        a 'fit_centroid' parameter (amplitude-weighted mean center) is
        added.  If a group is selected and a model could be built, the
        initial model and its components are evaluated over the fit range.
        Returns the current group.
        """
        dgroup = self.controller.get_group()
        fullmodel = None
        params = Parameters()
        self.summary = {'components': [], 'options': {}}
        peaks = []
        for comp in self.fit_components.values():
            _cen, _amp = None, None
            if comp.usebox is not None and comp.usebox.IsChecked():
                for parwids in comp.parwids.values():
                    params.add(parwids.param)
                    if parwids.param.name.endswith('_center'):
                        _cen = parwids.param.name
                    elif parwids.param.name.endswith('_amplitude'):
                        _amp = parwids.param.name

                self.summary['components'].append((comp.mclass.__name__, comp.mclass_kws))
                thismodel = comp.mclass(**comp.mclass_kws)
                if fullmodel is None:
                    fullmodel = thismodel
                else:
                    fullmodel += thismodel
                if not comp.bkgbox.IsChecked() and _cen is not None and _amp is not None:
                    peaks.append((_amp, _cen))

        if len(peaks) > 0:
            denom = '+'.join([p[0] for p in peaks])
            numer = '+'.join(["%s*%s "% p for p in peaks])
            params.add('fit_centroid', expr="(%s)/(%s)" %(numer, denom))

        self.fit_model = fullmodel
        self.fit_params = params

        # no model when no component is checked 'Use in Fit?'
        if dgroup is not None and fullmodel is not None:
            i1, i2 = self.get_xranges(dgroup.xdat)
            xsel = dgroup.xdat[i1:i2]
            dgroup.xfit = xsel
            dgroup.yfit = self.fit_model.eval(self.fit_params, x=xsel)
            dgroup.ycomps = self.fit_model.eval_components(params=self.fit_params,
                                                           x=xsel)
        return dgroup

    def onFitSelected(self, event=None):
        """fit selected groups: not yet implemented beyond building the model"""
        self.build_fitmodel()
        print("fitting selected groups in progress")

    def onFitModel(self, event=None):
        """fit the model to the current group over the fit range

        Data are weighted by 1/yerr when the group has ``yerr``.  The
        result is saved (autosave and the group's ``fit_history``), then
        plotted and shown in the fit result frame.
        """
        dgroup = self.build_fitmodel()
        if dgroup is None or self.fit_model is None:
            self.message.SetLabel('no model components selected for fit.')
            return
        opts = self.read_form()

        i1, i2 = self.get_xranges(dgroup.xdat)
        dgroup.xfit = dgroup.xdat[i1:i2]
        ysel = dgroup.ydat[i1:i2]
        weights = np.ones(len(ysel))

        if hasattr(dgroup, 'yerr'):
            yerr = 1.0*dgroup.yerr
            if not isinstance(yerr, np.ndarray):
                yerr = yerr * np.ones(len(ysel))
            else:
                yerr = yerr[i1:i2]
            # floor tiny uncertainties to avoid huge or infinite weights
            yerr_min = 1.e-9*ysel.mean()
            yerr[np.where(yerr < yerr_min)] = yerr_min
            weights = 1.0/yerr

        result = self.fit_model.fit(ysel, params=self.fit_params,
                                    x=dgroup.xfit, weights=weights,
                                    method='leastsq')
        # x range actually fitted: xdat[i1:i2], so the last point is i2-1
        self.summary['xmin'] = dgroup.xdat[i1]
        self.summary['xmax'] = dgroup.xdat[i2-1]
        for attr in ('aic', 'bic', 'chisqr', 'redchi', 'ci_out', 'covar',
                     'flatchain', 'success', 'nan_policy', 'nfev', 'ndata',
                     'nfree', 'nvarys', 'init_values'):
            self.summary[attr] = getattr(result, attr)
        self.summary['params'] = result.params

        dgroup.yfit = result.best_fit
        dgroup.ycomps = self.fit_model.eval_components(params=result.params,
                                                       x=dgroup.xfit)

        result.model_repr = self.fit_model._reprstring(long=True)

        ## hacks to save user options
        result.user_options = opts
        bkg_comps = []
        for label, comp in self.fit_components.items():
            if comp.bkgbox.IsChecked():
                bkg_comps.append(label)
        result.user_options['bkg_components'] = bkg_comps

        self.autosave_modelresult(result)
        if not hasattr(dgroup, 'fit_history'):
            dgroup.fit_history = []

        dgroup.fit_history.append(result)
        self.plot_choice.SetStringSelection(PLOT_FIT)
        self.onPlot()

        self.parent.show_subframe('prepeak_result_frame', FitResultFrame,
                                  datagroup=dgroup, peakframe=self)

        self.parent.subframes['prepeak_result_frame'].show_fitresult()
        for menu in self.parent.afterfit_menus:
            menu.Enable(True)

    def update_start_values(self, params):
        """fill parameters with best fit values"""
        allparwids = {}
        for comp in self.fit_components.values():
            if comp.usebox is not None and comp.usebox.IsChecked():
                for name, parwids in comp.parwids.items():
                    allparwids[name] = parwids

        for pname, par in params.items():
            if pname in allparwids:
                allparwids[pname].value.SetValue(par.value)

    def autosave_modelresult(self, result, fname=None):
        """autosave model result to user larch folder

        Saves to <usr_larchdir>/xasgui/<fname> (default
        'autosave.fitresult').  Prints a warning and returns without
        saving if the folder cannot be created or HAS_MODELSAVE is false.
        """
        xasguidir = os.path.join(site_config.usr_larchdir, 'xasgui')
        if not os.path.exists(xasguidir):
            try:
                os.makedirs(xasguidir)
            except OSError:
                print("Warning: cannot create XAS GUI user folder")
                return
        if not HAS_MODELSAVE:
            print("Warning: cannot save model results: upgrade lmfit")
            return
        if fname is None:
            fname = 'autosave.fitresult'
        fname = os.path.join(xasguidir, fname)

        self.save_fit_result(result, fname)
Example #4
0
class PrePeakPanel(TaskPanel):
    def __init__(self, parent=None, controller=None, **kws):
        """Create the pre-edge peak fitting panel.

        Besides the TaskPanel setup (config name 'prepeaks_config'), this
        initialises the component registry and the two timers used by the
        'Pick 2' parameter guessing: one polls for picked points, the
        other later erases the plotted guess.
        """
        TaskPanel.__init__(self, parent, controller,
                           configname='prepeaks_config',
                           config=defaults, **kws)

        # fit-model components, keyed by parameter prefix
        self.fit_components = OrderedDict()
        self.user_added_params = None

        # 'Pick 2' guessing state, polled by onPick2Timer
        self.pick2_group = None
        self.pick2_t0, self.pick2_timeout = 0., 15.
        self.pick2_timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.onPick2Timer, self.pick2_timer)

        # timer that removes the plotted 'Pick 2' guess (onPick2EraseTimer)
        self.pick2erase_panel = None
        self.pick2erase_timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.onPick2EraseTimer, self.pick2erase_timer)

    def onPanelExposed(self, **kws):
        # called when notebook is selected
        try:
            fname = self.controller.filelist.GetStringSelection()
            gname = self.controller.file_groups[fname]
            dgroup = self.controller.get_group(gname)
            self.fill_form(dgroup)
        except:
            pass # print(" Cannot Fill prepeak panel from group ")

    def build_display(self):
        """Build the widgets and layout of the pre-edge peak fitting panel.

        Creates the notebook that holds one page per fit component (pages
        are added by addModel), the energy float-spins, the fit/plot/load
        buttons, the array choice, the model choices and the 'show?'
        check boxes, and lays them out on a GridPanel.
        """
        self.mod_nb = flatnotebook(self, {})
        pan = self.panel = GridPanel(self, ncols=4, nrows=4, pad=2, itemstyle=LEFT)

        self.wids = {}

        # energy widgets; read_form/fill_form access them via
        # self.wids['ppeak_*'], so add_floatspin presumably registers them
        # there -- confirm in TaskPanel
        fsopts = dict(digits=2, increment=0.1, with_pin=True)
        ppeak_e0   = self.add_floatspin('ppeak_e0', value=0, **fsopts)
        ppeak_elo  = self.add_floatspin('ppeak_elo', value=-15, **fsopts)
        ppeak_ehi  = self.add_floatspin('ppeak_ehi', value=-5, **fsopts)
        ppeak_emin = self.add_floatspin('ppeak_emin', value=-30, **fsopts)
        ppeak_emax = self.add_floatspin('ppeak_emax', value=0, **fsopts)

        self.fitbline_btn  = Button(pan,'Fit Baseline', action=self.onFitBaseline,
                                    size=(125, -1))
        self.plotmodel_btn = Button(pan, 'Plot Model',
                                   action=self.onPlotModel,  size=(125, -1))
        self.fitmodel_btn = Button(pan, 'Fit Model',
                                   action=self.onFitModel,  size=(125, -1))
        self.loadmodel_btn = Button(pan, 'Load Model',
                                    action=self.onLoadFitResult,  size=(125, -1))
        # 'Fit Model' stays disabled until a component exists
        # (addModel / onFitBaseline enable it)
        self.fitmodel_btn.Disable()

        self.array_choice = Choice(pan, size=(175, -1),
                                   choices=list(Array_Choices.keys()))
        self.array_choice.SetSelection(1)

        # both model choices add a component via addModel
        models_peaks = Choice(pan, size=(150, -1),
                              choices=ModelChoices['peaks'],
                              action=self.addModel)

        models_other = Choice(pan, size=(150, -1),
                              choices=ModelChoices['other'],
                              action=self.addModel)

        self.models_peaks = models_peaks
        self.models_other = models_other


        self.message = SimpleText(pan,
                                 'first fit baseline, then add peaks to fit model.')

        self.msg_centroid = SimpleText(pan, '----')

        # 'show?' boxes: checked by default, toggling them replots
        opts = dict(default=True, size=(75, -1), action=self.onPlot)
        self.show_centroid  = Check(pan, label='show?', **opts)
        self.show_peakrange = Check(pan, label='show?', **opts)
        self.show_fitrange  = Check(pan, label='show?', **opts)
        self.show_e0        = Check(pan, label='show?', **opts)

        # NOTE(review): this second `opts` dict is not used below
        opts = dict(default=False, size=(200, -1), action=self.onPlot)

        def add_text(text, dcol=1, newrow=True):
            # add a plain text label to the grid, by default on a new row
            pan.Add(SimpleText(pan, text), dcol=dcol, newrow=newrow)

        # the grid layout follows the order of the pan.Add calls below
        pan.Add(SimpleText(pan, ' Pre-edge Peak Fitting',
                           **self.titleopts), dcol=5)
        add_text(' Run Fit:', newrow=False)

        add_text('Array to fit: ')
        pan.Add(self.array_choice, dcol=3)
        pan.Add((10, 10))
        pan.Add(self.fitbline_btn)

        add_text('E0: ')
        pan.Add(ppeak_e0)
        pan.Add((10, 10), dcol=2)
        pan.Add(self.show_e0)
        pan.Add(self.plotmodel_btn)


        add_text('Fit Energy Range: ')
        pan.Add(ppeak_emin)
        add_text(' : ', newrow=False)
        pan.Add(ppeak_emax)
        pan.Add(self.show_fitrange)
        pan.Add(self.fitmodel_btn)

        t = SimpleText(pan, 'Pre-edge Peak Range: ')
        SetTip(t, 'Range used as mask for background')

        pan.Add(t, newrow=True)
        pan.Add(ppeak_elo)
        add_text(' : ', newrow=False)
        pan.Add(ppeak_ehi)
        pan.Add(self.show_peakrange)


        # pan.Add(self.fitsel_btn)

        add_text( 'Peak Centroid: ')
        pan.Add(self.msg_centroid, dcol=3)
        pan.Add(self.show_centroid, dcol=1)
        pan.Add(self.loadmodel_btn)

        #  add model
        ts = wx.BoxSizer(wx.HORIZONTAL)
        ts.Add(models_peaks)
        ts.Add(models_other)

        pan.Add(SimpleText(pan, 'Add Component: '), newrow=True)
        pan.Add(ts, dcol=7)

        pan.Add(SimpleText(pan, 'Messages: '), newrow=True)
        pan.Add(self.message, dcol=7)

        pan.pack()

        # outer layout: the grid panel, a separator line, then the
        # component notebook, which takes the remaining space
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add((10, 10), 0, LEFT, 3)
        sizer.Add(pan, 0, LEFT, 3)
        sizer.Add((10, 10), 0, LEFT, 3)
        sizer.Add(HLine(self, size=(550, 2)), 0, LEFT, 3)
        sizer.Add((10, 10), 0, LEFT, 3)
        sizer.Add(self.mod_nb,  1, LEFT|wx.GROW, 10)

        pack(self, sizer)

    def get_config(self, dgroup=None):
        """get processing configuration for a group"""
        if dgroup is None:
            dgroup = self.controller.get_group()

        conf = getattr(dgroup, 'prepeak_config', {})
        if 'e0' not in conf:
            conf = defaults
            conf['e0'] = getattr(dgroup, 'e0', -1)

        dgroup.prepeak_config = conf
        if not hasattr(dgroup, 'prepeaks'):
            dgroup.prepeaks = Group()

        return conf

    def fill_form(self, dat):
        if isinstance(dat, Group):
            self.wids['ppeak_e0'].SetValue(dat.e0)
            if hasattr(dat, 'prepeaks'):
                self.wids['ppeak_emin'].SetValue(dat.prepeaks.emin)
                self.wids['ppeak_emax'].SetValue(dat.prepeaks.emax)
                self.wids['ppeak_elo'].SetValue(dat.prepeaks.elo)
                self.wids['ppeak_ehi'].SetValue(dat.prepeaks.ehi)
        elif isinstance(dat, dict):
            self.wids['ppeak_e0'].SetValue(dat['e0'])
            self.wids['ppeak_emin'].SetValue(dat['emin'])
            self.wids['ppeak_emax'].SetValue(dat['emax'])
            self.wids['ppeak_elo'].SetValue(dat['elo'])
            self.wids['ppeak_ehi'].SetValue(dat['ehi'])

            self.array_choice.SetStringSelection(dat['array_desc'])
            self.show_e0.Enable(dat['show_e0'])
            self.show_centroid.Enable(dat['show_centroid'])
            self.show_fitrange.Enable(dat['show_fitrange'])
            self.show_peakrange.Enable(dat['show_peakrange'])

    def read_form(self):
        "read for, returning dict of values"
        dgroup = self.controller.get_group()
        array_desc = self.array_choice.GetStringSelection()
        form_opts = {'gname': dgroup.groupname,
                     'filename': dgroup.filename,
                     'array_desc': array_desc.lower(),
                     'array_name': Array_Choices[array_desc],
                     'baseline_form': 'lorentzian',
                     'bkg_components': []}

        form_opts['e0'] = self.wids['ppeak_e0'].GetValue()
        form_opts['emin'] = self.wids['ppeak_emin'].GetValue()
        form_opts['emax'] = self.wids['ppeak_emax'].GetValue()
        form_opts['elo'] = self.wids['ppeak_elo'].GetValue()
        form_opts['ehi'] = self.wids['ppeak_ehi'].GetValue()
        form_opts['plot_sub_bline'] = False # self.plot_sub_bline.IsChecked()
        form_opts['show_centroid'] = self.show_centroid.IsChecked()
        form_opts['show_peakrange'] = self.show_peakrange.IsChecked()
        form_opts['show_fitrange'] = self.show_fitrange.IsChecked()
        form_opts['show_e0'] = self.show_e0.IsChecked()
        return form_opts

    def onFitBaseline(self, evt=None):
        opts = self.read_form()
        cmd = """{gname:s}.ydat = 1.0*{gname:s}.{array_name:s}
pre_edge_baseline(energy={gname:s}.energy, norm={gname:s}.ydat, group={gname:s}, form='{baseline_form:s}',
                  with_line=True, elo={elo:.3f}, ehi={ehi:.3f}, emin={emin:.3f}, emax={emax:.3f})"""
        self.larch_eval(cmd.format(**opts))

        dgroup = self.controller.get_group()
        ppeaks = dgroup.prepeaks
        dgroup.centroid_msg = "%.4f +/- %.4f eV" % (ppeaks.centroid,
                                                    ppeaks.delta_centroid)

        self.msg_centroid.SetLabel(dgroup.centroid_msg)

        if 'bpeak_' not in self.fit_components:
            self.addModel(model='Lorentzian', prefix='bpeak_', isbkg=True)
        if 'bline_' not in self.fit_components:
            self.addModel(model='Linear', prefix='bline_', isbkg=True)

        for prefix in ('bpeak_', 'bline_'):
            cmp = self.fit_components[prefix]
            # cmp.bkgbox.SetValue(1)
            self.fill_model_params(prefix, dgroup.prepeaks.fit_details.params)

        self.fill_form(dgroup)
        self.fitmodel_btn.Enable()
        # self.fitallmodel_btn.Enable()

        i1, i2 = self.get_xranges(dgroup.energy)
        dgroup.yfit = dgroup.xfit = 0.0*dgroup.energy[i1:i2]

        # self.plot_choice.SetStringSelection(PLOT_BASELINE)
        self.onPlot(baseline_only=True)

    def fill_model_params(self, prefix, params):
        comp = self.fit_components[prefix]
        parwids = comp.parwids
        for pname, par in params.items():
            pname = prefix + pname
            if pname in parwids:
                wids = parwids[pname]
                if wids.minval is not None:
                    wids.minval.SetValue(par.min)
                if wids.maxval is not None:
                    wids.maxval.SetValue(par.max)
                wids.value.SetValue(par.value)
                varstr = 'vary' if par.vary else 'fix'
                if par.expr is not None:
                    varstr = 'constrain'
                if wids.vary is not None:
                    wids.vary.SetStringSelection(varstr)

    def onPlotModel(self, evt=None):
        dgroup = self.controller.get_group()
        g = self.build_fitmodel(dgroup)
        self.onPlot(show_init=True)

    def onPlot(self, evt=None, baseline_only=False, show_init=False):
        opts = self.read_form()
        dgroup = self.controller.get_group()

        opts['group'] = opts['gname']
        self.larch_eval(COMMANDS['prepeaks_setup'].format(**opts))

        cmd = "plot_prepeaks_fit"
        args = ['{gname}']
        if baseline_only:
            cmd = "plot_prepeaks_baseline"
        else:
            args.append("show_init=%s" % (show_init))
        cmd = "%s(%s)" % (cmd, ', '.join(args))
        self.larch_eval(cmd.format(**opts))

    def addModel(self, event=None, model=None, prefix=None, isbkg=False):
        """Add a fit-model component and its parameter page to the notebook.

        Parameters
        ----------
        event : wx event or None
            when `model` is None, the model name is taken from
            `event.GetString()` (the model choice widgets call this).
        model : str or None
            model name; None or names starting with '<' are ignored
            (presumably the choice widgets' placeholder entry).
        prefix : str or None
            parameter prefix; when None, one is generated from the first
            five lower-cased characters of `model` plus the lowest unused
            number, e.g. 'gauss1_' for 'Gaussian'.
        isbkg : bool
            pre-check the component's "Is Baseline?" box.

        The component is stored in `self.fit_components[prefix]` as a
        Group holding the model class and keywords, its check boxes,
        notebook panel, 'Pick 2' message label and parameter widgets.
        """
        if model is None and event is not None:
            model = event.GetString()
        if model is None or model.startswith('<'):
            return

        # reset both model choices to their first entry
        self.models_peaks.SetSelection(0)
        self.models_other.SetSelection(0)

        if prefix is None:
            p = model[:5].lower()
            # candidates p1_ .. p(N+1)_ for N existing components, so at
            # least one of them is always still free
            curmodels = ["%s%i_" % (p, i+1) for i in range(1+len(self.fit_components))]
            for comp in self.fit_components:
                if comp in curmodels:
                    curmodels.remove(comp)

            prefix = curmodels[0]

        label = "%s(prefix='%s')" % (model, prefix)
        # NOTE: this first `title` value is immediately overwritten; the
        # page title is the prefix without its trailing '_' (Step models
        # override it below)
        title = "%s: %s " % (prefix[:-1], model)
        title = prefix[:-1]
        mclass_kws = {'prefix': prefix}
        if 'step' in model.lower():
            # Step models: the form is the model name with 'step' removed,
            # and names starting with 'err' are mapped to 'erf'
            form = model.lower().replace('step', '').strip()
            if form.startswith('err'):
                form = 'erf'
            label = "Step(form='%s', prefix='%s')" % (form, prefix)
            title = "%s: Step %s" % (prefix[:-1], form[:3])
            mclass = lm_models.StepModel
            mclass_kws['form'] = form
            minst = mclass(form=form, prefix=prefix)
        else:
            # look up the model class in lm_models: via ModelFuncs when the
            # name is mapped there, otherwise as '<model>Model'
            if model in ModelFuncs:
                mclass = getattr(lm_models, ModelFuncs[model])
            else:
                mclass = getattr(lm_models, model+'Model')

            minst = mclass(prefix=prefix)

        # one notebook page per component
        panel = GridPanel(self.mod_nb, ncols=2, nrows=5, pad=1, itemstyle=CEN)

        def SLabel(label, size=(80, -1), **kws):
            # left-aligned text label on this component's panel
            return  SimpleText(panel, label,
                               size=size, style=wx.ALIGN_LEFT, **kws)
        usebox = Check(panel, default=True, label='Use in Fit?', size=(100, -1))
        bkgbox = Check(panel, default=False, label='Is Baseline?', size=(125, -1))
        if isbkg:
            bkgbox.SetValue(1)

        delbtn = Button(panel, 'Delete This Component', size=(200, -1),
                        action=partial(self.onDeleteComponent, prefix=prefix))

        # pick2msg shows 'Pick 2' progress (set by onPick2Points/onPick2Timer)
        pick2msg = SimpleText(panel, "    ", size=(125, -1))
        pick2btn = Button(panel, 'Pick Values from Plotted Data', size=(200, -1),
                          action=partial(self.onPick2Points, prefix=prefix))

        # SetTip(mname,  'Label for the model component')
        SetTip(usebox,   'Use this component in fit?')
        SetTip(bkgbox,   'Label this component as "background" when plotting?')
        SetTip(delbtn,   'Delete this model component')
        SetTip(pick2btn, 'Select X range on Plot to Guess Initial Values')

        panel.Add(SLabel(label, size=(275, -1), colour='#0000AA'),
                  dcol=4,  style=wx.ALIGN_LEFT, newrow=True)
        panel.Add(usebox, dcol=2)
        panel.Add(bkgbox, dcol=1, style=RIGHT)

        panel.Add(pick2btn, dcol=2, style=wx.ALIGN_LEFT, newrow=True)
        panel.Add(pick2msg, dcol=3, style=wx.ALIGN_RIGHT)
        panel.Add(delbtn, dcol=2, style=wx.ALIGN_RIGHT)

        # panel.Add(HLine(panel, size=(150,  3)), dcol=4, style=wx.ALIGN_CENTER)

        # column headers of the parameter table
        panel.Add(SLabel("Parameter "), style=wx.ALIGN_LEFT,  newrow=True)
        panel.AddMany((SLabel(" Value"), SLabel(" Type"), SLabel(' Bounds'),
                       SLabel("  Min", size=(60, -1)),
                       SLabel("  Max", size=(60, -1)),  SLabel(" Expression")))

        parwids = OrderedDict()
        parnames = sorted(minst.param_names)

        # also list model-function arguments that have a parameter hint and
        # are not independent variables (uses the private attribute
        # `_func_allargs` of the model instance)
        for a in minst._func_allargs:
            pname = "%s%s" % (prefix, a)
            if (pname not in parnames and
                a in minst.param_hints and
                a not in minst.independent_vars):
                parnames.append(pname)

        # one full row of widgets per parameter, initialised from the
        # model's param_hints (min, max, value, expr)
        for pname in parnames:
            sname = pname[len(prefix):]
            hints = minst.param_hints.get(sname, {})

            par = Parameter(name=pname, value=0, vary=True)
            if 'min' in hints:
                par.min = hints['min']
            if 'max' in hints:
                par.max = hints['max']
            if 'value' in hints:
                par.value = hints['value']
            if 'expr' in hints:
                par.expr = hints['expr']

            pwids = ParameterWidgets(panel, par, name_size=100, expr_size=150,
                                     float_size=80, prefix=prefix,
                                     widgets=('name', 'value',  'minval',
                                              'maxval', 'vary', 'expr'))
            parwids[par.name] = pwids
            panel.Add(pwids.name, newrow=True)

            panel.AddMany((pwids.value, pwids.vary, pwids.bounds,
                           pwids.minval, pwids.maxval, pwids.expr))

        # hinted parameters defined by an expression and not listed above
        # (presumably derived quantities): name, disabled value, expression
        for sname, hint in minst.param_hints.items():
            pname = "%s%s" % (prefix, sname)
            if 'expr' in hint and pname not in parnames:
                par = Parameter(name=pname, value=0, expr=hint['expr'])
                pwids = ParameterWidgets(panel, par, name_size=100, expr_size=400,
                                         float_size=80, prefix=prefix,
                                         widgets=('name', 'value', 'expr'))
                parwids[par.name] = pwids
                panel.Add(pwids.name, newrow=True)
                panel.Add(pwids.value)
                panel.Add(pwids.expr, dcol=5, style=wx.ALIGN_RIGHT)
                pwids.value.Disable()

        # everything build_fitmodel, onDeleteComponent and the 'Pick 2'
        # handlers need to know about this component
        fgroup = Group(prefix=prefix, title=title, mclass=mclass,
                       mclass_kws=mclass_kws, usebox=usebox, panel=panel,
                       parwids=parwids, float_size=65, expr_size=150,
                       pick2_msg=pick2msg, bkgbox=bkgbox)


        self.fit_components[prefix] = fgroup
        panel.pack()

        self.mod_nb.AddPage(panel, title, True)
        # grow by one pixel and shrink back, presumably to force a re-layout
        sx,sy = self.GetSize()
        self.SetSize((sx, sy+1))
        self.SetSize((sx, sy))
        self.fitmodel_btn.Enable()


    def onDeleteComponent(self, evt=None, prefix=None):
        fgroup = self.fit_components.get(prefix, None)
        if fgroup is None:
            return

        for i in range(self.mod_nb.GetPageCount()):
            if fgroup.title == self.mod_nb.GetPageText(i):
                self.mod_nb.DeletePage(i)

        for attr in dir(fgroup):
            setattr(fgroup, attr, None)

        self.fit_components.pop(prefix)
        if len(self.fit_components) < 1:
            self.fitmodel_btn.Disable()

        # sx,sy = self.GetSize()
        # self.SetSize((sx, sy+1))
        # self.SetSize((sx, sy))

    def onPick2EraseTimer(self, evt=None):
        """erases line trace showing automated 'Pick 2' guess """
        self.pick2erase_timer.Stop()
        panel = self.pick2erase_panel
        ntrace = panel.conf.ntrace - 1
        trace = panel.conf.get_mpl_line(ntrace)
        panel.conf.get_mpl_line(ntrace).set_data(np.array([]), np.array([]))
        panel.conf.ntrace = ntrace
        panel.draw()

    def onPick2Timer(self, evt=None):
        """checks for 'Pick 2' events, and initiates 'Pick 2' guess
        for a model from the selected data range
        """
        try:
            plotframe = self.controller.get_display(win=1)
            curhist = plotframe.cursor_hist[:]
            plotframe.Raise()
        except:
            return

        if (time.time() - self.pick2_t0) > self.pick2_timeout:
            msg = self.pick2_group.pick2_msg.SetLabel(" ")
            plotframe.cursor_hist = []
            self.pick2_timer.Stop()
            return

        if len(curhist) < 2:
            self.pick2_group.pick2_msg.SetLabel("%i/2" % (len(curhist)))
            return

        self.pick2_group.pick2_msg.SetLabel("done.")
        self.pick2_timer.Stop()

        # guess param values
        xcur = (curhist[0][0], curhist[1][0])
        xmin, xmax = min(xcur), max(xcur)

        dgroup = getattr(self.larch.symtable, self.controller.groupname)
        x, y = dgroup.xdat, dgroup.ydat
        i0 = index_of(dgroup.xdat, xmin)
        i1 = index_of(dgroup.xdat, xmax)
        x, y = dgroup.xdat[i0:i1+1], dgroup.ydat[i0:i1+1]

        mod = self.pick2_group.mclass(prefix=self.pick2_group.prefix)
        parwids = self.pick2_group.parwids
        try:
            guesses = mod.guess(y, x=x)
        except:
            return

        for name, param in guesses.items():
            if name in parwids:
                parwids[name].value.SetValue(param.value)

        dgroup._tmp = mod.eval(guesses, x=dgroup.xdat)
        plotframe = self.controller.get_display(win=1)
        plotframe.cursor_hist = []
        plotframe.oplot(dgroup.xdat, dgroup._tmp)
        self.pick2erase_panel = plotframe.panel

        self.pick2erase_timer.Start(5000)


    def onPick2Points(self, evt=None, prefix=None):
        fgroup = self.fit_components.get(prefix, None)
        if fgroup is None:
            return

        plotframe = self.controller.get_display(win=1)
        plotframe.Raise()

        plotframe.cursor_hist = []
        fgroup.npts = 0
        self.pick2_group = fgroup

        if fgroup.pick2_msg is not None:
            fgroup.pick2_msg.SetLabel("0/2")

        self.pick2_t0 = time.time()
        self.pick2_timer.Start(250)


    def onLoadFitResult(self, event=None):
        dlg = wx.FileDialog(self, message="Load Saved File Model",
                            wildcard=ModelWcards, style=wx.FD_OPEN)
        rfile = None
        if dlg.ShowModal() == wx.ID_OK:
            rfile = dlg.GetPath()
        dlg.Destroy()

        if rfile is None:
            return

        self.larch_eval("# peakmodel = lm_load_modelresult('%s')" %rfile)

        result = load_modelresult(str(rfile))
        for prefix in list(self.fit_components.keys()):
            self.onDeleteComponent(self, prefix=prefix)

        for comp in result.model.components:
            isbkg = comp.prefix in result.user_options['bkg_components']
            self.addModel(model=comp.func.__name__,
                          prefix=comp.prefix, isbkg=isbkg)

        for comp in result.model.components:
            parwids = self.fit_components[comp.prefix].parwids
            for pname, par in result.params.items():
                if pname in parwids:
                    wids = parwids[pname]
                    if wids.minval is not None:
                        wids.minval.SetValue(par.min)
                    if wids.maxval is not None:
                        wids.maxval.SetValue(par.max)
                    val = result.init_values.get(pname, par.value)
                    wids.value.SetValue(val)
        self.fill_form(result.user_options)


    def onSelPoint(self, evt=None, opt='__', relative_e0=False, win=None):
        """
        get last selected point from a specified plot window
        and fill in the value for the widget defined by `opt`.

        by default it finds the latest cursor position from the
        cursor history of the first 20 plot windows.
        """
        if opt not in self.wids:
            return None

        _x, _y = last_cursor_pos(win=win, _larch=self.larch)

        if _x is not None:
            if relative_e0 and 'e0' in self.wids:
                _x -= self.wids['e0'].GetValue()
            self.wids[opt].SetValue(_x)

    def get_xranges(self, x):
        opts = self.read_form()
        dgroup = self.controller.get_group()
        en_eps = min(np.diff(dgroup.energy)) / 5.

        i1 = index_of(x, opts['emin'] + en_eps)
        i2 = index_of(x, opts['emax'] + en_eps) + 1
        return i1, i2

    def build_fitmodel(self, dgroup):
        """ use fit components to build model"""
        # self.summary = {'components': [], 'options': {}}
        peaks = []
        cmds = ["## set up pre-edge peak parameters", "peakpars = Parameters()"]
        modcmds = ["## define pre-edge peak model"]
        modop = " ="
        opts = self.read_form()


        opts['group'] = opts['gname']
        self.larch_eval(COMMANDS['prepeaks_setup'].format(**opts))


        for comp in self.fit_components.values():
            _cen, _amp = None, None
            if comp.usebox is not None and comp.usebox.IsChecked():
                for parwids in comp.parwids.values():
                    this = parwids.param
                    pargs = ["'%s'" % this.name, 'value=%f' % (this.value),
                             'min=%f' % (this.min), 'max=%f' % (this.max)]
                    if this.expr is not None:
                        pargs.append("expr='%s'" % (this.expr))
                    elif not this.vary:
                        pargs.pop()
                        pargs.pop()
                        pargs.append("vary=False")

                    cmds.append("peakpars.add(%s)" % (', '.join(pargs)))
                    if this.name.endswith('_center'):
                        _cen = this.name
                    elif parwids.param.name.endswith('_amplitude'):
                        _amp = this.name
                compargs = ["%s='%s'" % (k,v) for k,v in comp.mclass_kws.items()]
                modcmds.append("peakmodel %s %s(%s)" % (modop, comp.mclass.__name__,
                                                        ', '.join(compargs)))

                modop = "+="
                if not comp.bkgbox.IsChecked() and _cen is not None and _amp is not None:
                    peaks.append((_amp, _cen))

        if len(peaks) > 0:
            denom = '+'.join([p[0] for p in peaks])
            numer = '+'.join(["%s*%s "% p for p in peaks])
            cmds.append("peakpars.add('fit_centroid', expr='(%s)/(%s)')" % (numer, denom))

        cmds.extend(modcmds)
        cmds.append(COMMANDS['prepfit'].format(group=dgroup.groupname,
                                               user_opts=repr(opts)))

        self.larch_eval("\n".join(cmds))

    def onFitSelected(self, event=None):
        dgroup = self.controller.get_group()
        self.build_fitmodel(dgroup)

    def onFitModel(self, event=None):
        dgroup = self.controller.get_group()
        if dgroup is None:
            return
        self.build_fitmodel(dgroup)
        opts = self.read_form()

        dgroup = self.controller.get_group()
        opts['group'] = opts['gname']
        self.larch_eval(COMMANDS['prepeaks_setup'].format(**opts))

        ppeaks = dgroup.prepeaks


        # add bkg_component to saved user options
        bkg_comps = []
        for label, comp in self.fit_components.items():
            if comp.bkgbox.IsChecked():
                bkg_comps.append(label)
        opts['bkg_components'] = bkg_comps

        imin, imax = self.get_xranges(dgroup.xdat)

        cmds = ["## do peak fit: "]

        yerr_type = 'set_yerr_const'
        yerr = getattr(dgroup, 'yerr', None)
        if yerr is None:
            if hasattr(dgroup, 'norm_std'):
                cmds.append("{group}.yerr = {group}.norm_std")
                yerr_type = 'set_yerr_array'
            elif hasattr(dgroup, 'mu_std'):
                cmds.append("{group}.yerr = {group}.mu_std/(1.e-15+{group}.edge_step)")
                yerr_type = 'set_yerr_array'
            else:
                cmds.append("{group}.yerr = 1")
        elif isinstance(dgroup.yerr, np.ndarray):
                yerr_type = 'set_yerr_array'


        cmds.extend([COMMANDS[yerr_type], COMMANDS['dofit']])

        cmd = '\n'.join(cmds)
        self.larch_eval(cmd.format(group=dgroup.groupname,
                                   imin=imin, imax=imax,
                                   user_opts=repr(opts)))

        self.autosave_modelresult(self.larch_get("peakresult"))

        self.onPlot()
        self.show_subframe('prepeak_result_frame', FitResultFrame,
                                  datagroup=dgroup, peakframe=self)
        self.subframes['prepeak_result_frame'].show_results()

    def update_start_values(self, params):
        """fill parameters with best fit values"""
        allparwids = {}
        for comp in self.fit_components.values():
            if comp.usebox is not None and comp.usebox.IsChecked():
                for name, parwids in comp.parwids.items():
                    allparwids[name] = parwids

        for pname, par in params.items():
            if pname in allparwids:
                allparwids[pname].value.SetValue(par.value)

    def autosave_modelresult(self, result, fname=None):
        """autosave model result to user larch folder"""
        confdir = os.path.join(site_config.usr_larchdir, 'xas_viewer')
        if not os.path.exists(confdir):
            try:
                os.makedirs(confdir)
            except OSError:
                print("Warning: cannot create XAS_Viewer user folder")
                return
        if not HAS_MODELSAVE:
            print("Warning: cannot save model results: upgrade lmfit")
            return
        if fname is None:
            fname = 'autosave.fitmodel'
        save_modelresult(result, os.path.join(confdir, fname))