Ejemplo n.º 1
0
class PlotPlotPanel(wx.Panel):
    """Panel embedding a matplotlib figure that marks the sample nearest the cursor.

    The top-level parent frame must expose ``plot_panel`` with the attributes
    ``plot_positions``, ``plot_im_values`` and ``cursor_position_textctrl``;
    they are only read here, never created.
    """

    def __init__(self, parent, dpi=None, **kwargs):
        """Create the figure, its wx canvas and the mouse-motion hook.

        Parameters
        ----------
        parent
            wx window that owns this panel.
        dpi
            Resolution passed to the matplotlib ``Figure``; ``None`` keeps
            matplotlib's default.
        **kwargs
            Forwarded unchanged to ``wx.Panel.__init__``.
        """
        wx.Panel.__init__(self, parent, wx.ID_ANY, wx.DefaultPosition, wx.DefaultSize, **kwargs)
        self.ztv_frame = self.GetTopLevelParent()
        # Pass the requested dpi through; it used to be silently discarded.
        self.figure = Figure(dpi=dpi, figsize=(1., 1.))
        self.axes = self.figure.add_subplot(111)
        self.canvas = FigureCanvasWxAgg(self, -1, self.figure)
        self.Bind(wx.EVT_SIZE, self._onSize)
        self.axes_widget = AxesWidget(self.figure.gca())
        self.axes_widget.connect_event('motion_notify_event', self.on_motion)
        # Marker artist for the highlighted sample; created on first motion.
        self.plot_point = None

    def on_motion(self, evt):
        """Show the sample closest in x to the cursor and move the marker to it."""
        if evt.xdata is None:
            # Cursor is outside the axes: no data coordinates available.
            return
        plot_panel = self.ztv_frame.plot_panel
        offsets = np.abs(plot_panel.plot_positions - evt.xdata)
        if np.size(offsets) == 0:
            # Nothing plotted yet; argmin() on an empty array would raise.
            return
        xarg = offsets.argmin()
        ydata = plot_panel.plot_im_values[xarg]
        plot_panel.cursor_position_textctrl.SetValue(
            '{0:.6g},{1:.6g}'.format(evt.xdata, ydata))
        if self.plot_point is None:
            self.plot_point, = self.axes.plot([evt.xdata], [ydata], 'xm')
        else:
            self.plot_point.set_data([evt.xdata], [ydata])
        self.figure.canvas.draw()

    def _onSize(self, event):
        """wx size-event handler: resize canvas and figure to the panel."""
        self._SetSize()

    def _SetSize(self):
        """Match canvas and figure size to the panel's current client size."""
        pixels = tuple(self.GetClientSize())
        self.SetSize(pixels)
        self.canvas.SetSize(pixels)
        dpi = self.figure.get_dpi()
        self.figure.set_size_inches(float(pixels[0]) / dpi, float(pixels[1]) / dpi)
Ejemplo n.º 2
0
def link_ngl_wdgt_to_ax_pos(ax, pos, ngl_widget):
    r"""Link a matplotlib axes and an NGL widget in both directions.

    Releasing a mouse button over ``ax`` moves a cross-hair to the click
    position, finds the row of ``pos`` nearest to it and sets
    ``ngl_widget.frame`` to that row index. Conversely, a change of the
    widget's ``frame`` trait moves the red dot to ``pos[frame]``.

    Initial idea for this function comes from @arose, the rest is @gph82 and @clonker

    Parameters
    ----------
    ax
        Matplotlib axes to draw on and to listen to.
    pos
        Array indexable as ``pos[i, j]`` and exposing ``.T``; column 0 is
        used as x and column 1 as y.
    ngl_widget
        Object with a ``frame`` attribute and an ``observe`` method; an
        ``isClick`` attribute is set on it.
    """
    from matplotlib.widgets import AxesWidget
    from scipy.spatial import cKDTree

    kdtree = cKDTree(pos)
    #assert ngl_widget.trajectory_0.n_frames == pos.shape[0]
    x, y = pos.T

    lineh = ax.axhline(ax.get_ybound()[0], c="black", ls='--')
    linev = ax.axvline(ax.get_xbound()[0], c="black", ls='--')
    dot, = ax.plot(pos[0, 0], pos[0, 1], 'o', c='red', ms=7)

    ngl_widget.isClick = False

    def _move_dot(new_x, new_y):
        # Line2D setters expect sequences; bare scalars are rejected by
        # recent matplotlib releases.
        dot.set_xdata([new_x])
        dot.set_ydata([new_y])

    def onclick(event):
        # A release outside any axes carries no data coordinates.
        if event.xdata is None or event.ydata is None:
            return
        linev.set_xdata((event.xdata, event.xdata))
        lineh.set_ydata((event.ydata, event.ydata))
        data = [event.xdata, event.ydata]
        _, index = kdtree.query(x=data, k=1)
        _move_dot(x[index], y[index])
        ngl_widget.isClick = True
        ngl_widget.frame = index

    def my_observer(change):
        r"""Move the dot to the position of the widget's new frame."""
        ngl_widget.isClick = False
        _idx = change["new"]
        try:
            new_x, new_y = x[_idx], y[_idx]
        except IndexError:
            # Frame outside the range of ``pos``: fall back to the first row.
            new_x, new_y = x[0], y[0]
            print("caught index error with index %s (new=%s, old=%s)" %
                  (_idx, change["new"], change["old"]))
        _move_dot(new_x, new_y)

    # Connect axes to widget
    axes_widget = AxesWidget(ax)
    axes_widget.connect_event('button_release_event', onclick)

    # Connect widget to axes
    ngl_widget.observe(my_observer, "frame", "change")
Ejemplo n.º 3
0
class MPLFigureEditor(BasicEditorFactory):
    """Editor factory that produces ``_MPLFigureEditor`` instances."""

    klass = _MPLFigureEditor

    def setup_mpl_events(self):
        """Hook mouse motion and right-click handlers onto the image canvas."""
        self.image_axeswidget = AxesWidget(self.image_axes)
        self.image_axeswidget.connect_event('motion_notify_event',
                                            self.image_on_motion)
        #self.image_axeswidget.connect_event('figure_leave_event', self.on_cursor_leave)
        #self.image_axeswidget.connect_event('figure_enter_event', self.on_cursor_enter)
        # Use Bind(): the call-style ``wx.EVT_RIGHT_DOWN(window, handler)``
        # form is not available in wxPython Phoenix.
        self.image_figure.canvas.Bind(wx.EVT_RIGHT_DOWN, self.on_right_down)

    def on_right_down(self, event):
        """Right-click handler; creates a menu when no popup menu exists yet."""
        if self.image_popup_menu is None:
            # NOTE(review): the menu is built but not stored or shown in the
            # code visible here - confirm against the full implementation.
            menu = wx.Menu()

    def image_on_motion(self, event):
        """Mouse-motion handler; ignores events without data coordinates."""
        if event.xdata is None or event.ydata is None:
            return
Ejemplo n.º 4
0
class PlotPlotPanel(wx.Panel):
    """wx panel with a matplotlib plot that highlights the sample under the cursor.

    Reads ``plot_positions``, ``plot_im_values`` and
    ``cursor_position_textctrl`` from ``plot_panel`` on the top-level parent
    frame; none of them is created here.
    """

    def __init__(self, parent, dpi=None, **kwargs):
        """Set up figure, canvas, resize handling and motion tracking.

        Parameters
        ----------
        parent
            Owning wx window.
        dpi
            Figure resolution; ``None`` selects matplotlib's default.
        **kwargs
            Passed straight to ``wx.Panel.__init__``.
        """
        wx.Panel.__init__(self, parent, wx.ID_ANY, wx.DefaultPosition,
                          wx.DefaultSize, **kwargs)
        self.ztv_frame = self.GetTopLevelParent()
        # The dpi argument was previously ignored (always None).
        self.figure = Figure(dpi=dpi, figsize=(1., 1.))
        self.axes = self.figure.add_subplot(111)
        self.canvas = FigureCanvasWxAgg(self, -1, self.figure)
        self.Bind(wx.EVT_SIZE, self._onSize)
        self.axes_widget = AxesWidget(self.figure.gca())
        self.axes_widget.connect_event('motion_notify_event', self.on_motion)
        # Line2D used as the cursor marker; None until the first motion event.
        self.plot_point = None

    def on_motion(self, evt):
        """Report and mark the plotted sample whose x is nearest the cursor."""
        if evt.xdata is None:
            # Pointer is not over the axes.
            return
        source = self.ztv_frame.plot_panel
        distances = np.abs(source.plot_positions - evt.xdata)
        if np.size(distances) == 0:
            # No samples to snap to; argmin() would raise on an empty array.
            return
        xarg = distances.argmin()
        ydata = source.plot_im_values[xarg]
        source.cursor_position_textctrl.SetValue(
            '{0:.6g},{1:.6g}'.format(evt.xdata, ydata))
        if self.plot_point is None:
            self.plot_point, = self.axes.plot([evt.xdata], [ydata], 'xm')
        else:
            self.plot_point.set_data([evt.xdata], [ydata])
        self.figure.canvas.draw()

    def _onSize(self, event):
        """Handle wx resize events by resizing the embedded figure."""
        self._SetSize()

    def _SetSize(self):
        """Resize panel, canvas and figure to the current client size."""
        pixels = tuple(self.GetClientSize())
        self.SetSize(pixels)
        self.canvas.SetSize(pixels)
        dpi = self.figure.get_dpi()
        self.figure.set_size_inches(
            float(pixels[0]) / dpi,
            float(pixels[1]) / dpi)
Ejemplo n.º 5
0
class AtmosViewer(HasTraits):
    """Traits UI viewer that overplots molecular line lists on atmospheric curves.

    The plot window is centred on ``central_wavenumber`` and spans
    ``bandwidth``. Selected molecules are drawn as stacked stick spectra;
    clicking in the axes marks the nearest line and scrolling shifts the
    window by one bandwidth per step.

    Relies on the module-level names ``molecules``, ``title``, ``size``,
    ``MPLFigureEditor`` and ``MPLInitHandler``, which are defined elsewhere.
    """

    central_wavenumber = CFloat(1000)
    bandwidth = CFloat(10)

    # -1 means "no line selected".
    selected_line_wavenumber = Float(-1.)

    figure = Instance(Figure, ())

    all_on = Button()
    all_off = Button()
    # list(...) so the editor receives a real list on Python 3 as well.
    selected_molecules = List(editor=CheckListEditor(
        values=list(molecules.keys()), cols=2, format_str='%s'))

    mplFigureEditor = MPLFigureEditor()

    trait_view = View(VGroup(
        Item('figure', editor=mplFigureEditor, show_label=False),
        HGroup(
            '10',
            VGroup('40',
                   Item(name='central_wavenumber',
                        editor=TextEditor(auto_set=False, enter_set=True)),
                   Item(name='bandwidth',
                        editor=TextEditor(auto_set=False, enter_set=True)),
                   HGroup(Item(name='selected_line_wavenumber'),
                          show_border=True),
                   show_border=True),
            HGroup(VGroup('20', Heading("Molecules"),
                          Item(name='all_on', show_label=False),
                          Item(name='all_off', show_label=False)),
                   Item(name='selected_molecules',
                        style='custom',
                        show_label=False),
                   show_border=True), '10'), '10'),
                      handler=MPLInitHandler,
                      resizable=True,
                      title=title,
                      width=size[0],
                      height=size[1])

    def __init__(self):
        """Load the packaged data files and draw the static part of the plot."""
        super(AtmosViewer, self).__init__()
        self.colors = {'telluric': 'black', 'orders': 'black'}
        self.molecules = molecules
        self.selected_molecules = []
        orders_filename = resource_filename(__name__, 'orders.txt')
        self.texes_orders = pandas.io.parsers.read_csv(orders_filename,
                                                       sep='\t',
                                                       header=None,
                                                       skiprows=3)
        atmos_filename = resource_filename(__name__, 'atmos.txt.gz')
        # read_csv consumes the whole stream, so the handle can be closed
        # as soon as it returns.
        with gzip.open(atmos_filename, 'r') as atmos_file:
            self.atmos = pandas.io.parsers.read_csv(atmos_file,
                                                    sep='\t',
                                                    skiprows=7,
                                                    index_col='# wn')
        #  keys are e.g. 'O3', with a dict of {'wn':..., 'y':...}
        self.molecule_lookup_points = {}
        self.axes = self.figure.add_subplot(111)
        self.axes.plot(self.atmos.index,
                       self.atmos['trans1mm'],
                       color=self.colors['telluric'])
        self.axes.plot(self.atmos.index,
                       self.atmos['trans4mm'],
                       color=self.colors['telluric'])
        for i in self.texes_orders.index:
            # .loc replaces the .ix indexer removed in pandas 1.0; the lookup
            # is by index label in both cases.
            self.axes.plot(self.texes_orders.loc[i].values, [0.05, 0.07],
                           color=self.colors['orders'])
        self.axes.set_xlim(self.central_wavenumber - self.bandwidth / 2.,
                           self.central_wavenumber + self.bandwidth / 2.)
        self.axes.set_ylim(0, 1.0)
        self.axes.set_xlabel('Wavenumber (cm-1)')
        self.axes.xaxis.set_major_formatter(FormatStrFormatter('%6.1f'))
        # Event connection is done in mpl_setup(); the original author noted
        # that connecting from __init__ did not work.
        self.onclick_connected = False
        self.selected_line = None
        self.selected_line_text = None

    def _discard_artist(self, artist):
        """Remove *artist* from the axes if it is currently drawn there.

        Uses ``Artist.remove()`` rather than mutating ``Axes.lines`` /
        ``Axes.texts``, which current matplotlib no longer allows.
        """
        if artist in self.axes.lines or artist in self.axes.texts:
            artist.remove()

    def on_click(self, event):
        """Mark the molecular line closest to the click position.

        The previous marker is always cleared. Distances are measured with
        the y offset rescaled into wavenumber units through the axes' size,
        so that "closest" matches what is seen on screen.
        """
        if event.xdata is None or event.ydata is None:
            return
        self._discard_artist(self.selected_line)
        self._discard_artist(self.selected_line_text)
        self.selected_line = None
        self.selected_line_text = None
        self.selected_line_wavenumber = -1
        if not self.molecule_lookup_points:
            # Redraw so the marker removed above disappears from the screen.
            self.redraw()
            return

        # The scale factors do not depend on the molecule; compute them once.
        xlim = self.axes.get_xlim()
        bounds = self.axes.get_position().bounds
        fig = self.axes.figure
        scale = ((xlim[1] - xlim[0]) /  # this is like wavenumbers/inch
                 (fig.get_figwidth() * bounds[2]))
        axes_height = fig.get_figheight() * bounds[3]

        closest = {'name': None, 'wn': -1., 'dist': 9e9}
        for cur_molecule, points in self.molecule_lookup_points.items():
            # Plain arrays make the argmin below positional; on a filtered
            # pandas Series a positional argmin would not be a valid label.
            wn = np.asarray(points['wn'], dtype=float)
            ys = np.asarray(points['y'], dtype=float)
            dist = np.sqrt((wn - event.xdata)**2 +
                           ((ys - event.ydata) * axes_height * scale)**2)
            if not np.isfinite(dist).any():
                # No usable line for this molecule (empty or all-NaN).
                continue
            nearest = int(np.nanargmin(dist))
            if dist[nearest] < closest['dist']:
                closest = {
                    'name': cur_molecule,
                    'wn': wn[nearest],
                    'dist': dist[nearest]
                }
        if closest['name'] is None:
            # No candidate line in any selected molecule.
            self.redraw()
            return
        self.selected_line_wavenumber = closest['wn']
        self.selected_line = self.axes.plot([closest['wn'], closest['wn']],
                                            [0, 1],
                                            '-.',
                                            color='black')[0]
        self.selected_line_text = self.axes.annotate(
            closest['name'] + ('%11.5f' % closest['wn']),
            (closest['wn'], 1.03),
            ha='center',
            annotation_clip=False)
        self.redraw()

    def on_scroll(self, event):
        """Shift the window by one bandwidth per scroll step."""
        self.central_wavenumber += self.bandwidth * event.step

    def _all_on_fired(self):
        """Select every known molecule."""
        # A List trait needs a real list; dict.keys() is a view on Python 3.
        self.selected_molecules = list(self.molecules)

    def _all_off_fired(self):
        """Deselect all molecules."""
        self.selected_molecules = []

    def mpl_setup(self):
        """Connect click and scroll handlers to the figure's current axes."""
        self.axes_widget = AxesWidget(self.figure.gca())
        self.axes_widget.connect_event('button_press_event', self.on_click)
        self.axes_widget.connect_event('scroll_event', self.on_scroll)

    @on_trait_change("central_wavenumber, bandwidth")
    def replot_molecular_overplots(self):
        """Redraw the stick spectrum and label of every selected molecule.

        Only lines inside the current window and within two orders of
        magnitude of the strongest one are kept. Intensities are
        log-normalised to a 0.1-high band and each molecule is offset by
        0.1 so the bands stack.
        """
        for i, cur_molecule in enumerate(self.selected_molecules):
            entry = self.molecules[cur_molecule]
            if entry['hitran'] is None:
                # Load the line list on first use and close the file.
                with gzip.open(entry['hitran_filename'], 'r') as hitran_file:
                    entry['hitran'] = pandas.io.parsers.read_csv(hitran_file,
                                                                 skiprows=2)
            wn = entry['hitran']['wavenumber']
            intensity = entry['hitran']['intensity']
            w = ((wn >= self.central_wavenumber - self.bandwidth / 2.) &
                 (wn <= self.central_wavenumber + self.bandwidth / 2.))
            wn = wn[w]
            intensity = intensity[w]
            plot_orders_of_magnitude = 2.
            max_line_intensity = intensity.max()
            min_line_intensity = max_line_intensity / 10**plot_orders_of_magnitude
            strong_enough = intensity >= min_line_intensity
            wn = wn[strong_enough]
            intensity = intensity[strong_enough]
            intensity = (
                (np.log10(intensity) - np.log10(min_line_intensity)) /
                (np.log10(max_line_intensity) - np.log10(min_line_intensity)))
            intensity = intensity * 0.1
            baseline = (i * 0.1) + 0.05
            self.molecule_lookup_points[cur_molecule] = {
                'wn': wn,
                'y': intensity + baseline
            }
            # Each line becomes three points (base, peak, base) so that one
            # polyline draws all sticks.
            wn = wn.repeat(3)
            intensity = np.column_stack(
                (np.zeros(len(intensity)), intensity, np.zeros(
                    len(intensity)))).flatten() + baseline

            # Drop the previous artists for this molecule, then draw anew.
            self._discard_artist(entry['plot_lines'])
            self._discard_artist(entry['plot_text'])
            entry['plot_lines'] = self.axes.plot(wn, intensity,
                                                 entry['color'])[0]
            entry['plot_text'] = self.axes.annotate(
                cur_molecule, (self.central_wavenumber + self.bandwidth * 0.51,
                               i * 0.1 + 0.065),
                ha='left',
                va='center',
                annotation_clip=False,
                color=entry['color'])
        self.redraw()

    def _selected_molecules_changed(self, old, new):
        """Replot the selection and erase molecules that were deselected."""
        self.replot_molecular_overplots()
        for cur_molecule in old:
            if cur_molecule in new:
                continue
            entry = self.molecules[cur_molecule]
            self._discard_artist(entry['plot_lines'])
            self._discard_artist(entry['plot_text'])
            entry['plot_lines'] = None
            entry['plot_text'] = None
            self.molecule_lookup_points.pop(cur_molecule, None)
        self.redraw()

    @on_trait_change("central_wavenumber, bandwidth")
    def redraw(self):
        """Reset the axis limits to the current window and repaint the canvas."""
        self.axes.set_xlim(self.central_wavenumber - self.bandwidth / 2.,
                           self.central_wavenumber + self.bandwidth / 2.)
        self.axes.set_ylim(0, 1.0)
        self.figure.canvas.draw()
Ejemplo n.º 6
0
class PhotPlotPanel(wx.Panel):
    """Panel with a matplotlib plot on which aperture radii can be dragged.

    A button press within 20 x-units of one of the three radii held by the
    top-level frame's ``phot_panel`` (``aprad``, ``skyradin``,
    ``skyradout``) starts a drag. Moving or releasing updates that radius,
    and leaving the figure mid-drag restores the value it had at press time.
    """

    def __init__(self, parent, dpi=None, **kwargs):
        """Create figure and canvas and connect the mouse handlers.

        Parameters
        ----------
        parent
            Owning wx window.
        dpi
            Figure resolution; ``None`` selects matplotlib's default.
        **kwargs
            Passed straight to ``wx.Panel.__init__``.
        """
        wx.Panel.__init__(self, parent, wx.ID_ANY, wx.DefaultPosition, wx.DefaultSize, **kwargs)
        self.ztv_frame = self.GetTopLevelParent()
        # The dpi argument was previously ignored (always None).
        self.figure = Figure(dpi=dpi, figsize=(1., 1.))
        self.axes = self.figure.add_subplot(111)
        self.canvas = FigureCanvasWxAgg(self, -1, self.figure)
        self.Bind(wx.EVT_SIZE, self._onSize)
        self.axes_widget = AxesWidget(self.figure.gca())
        self.axes_widget.connect_event('motion_notify_event', self.on_motion)
        self.axes_widget.connect_event('button_press_event', self.on_button_press)
        self.axes_widget.connect_event('button_release_event', self.on_button_release)
        self.axes_widget.connect_event('figure_leave_event', self.on_cursor_leave)
        self.button_down = False

    def _set_cur_aper_radius(self, radius):
        """Store *radius* for the aperture being dragged and refresh the UI.

        Writes the value to ``phot_panel.<name>``, mirrors it in
        ``phot_panel.<name>_textctrl`` and triggers ``recalc_phot()``.
        """
        phot_panel = self.ztv_frame.phot_panel
        setattr(phot_panel, self.cur_aper_name, radius)
        textctrl = getattr(phot_panel, self.cur_aper_name + '_textctrl')
        # Read the value back so the text shows what the panel actually holds.
        textctrl.SetValue('{0:.2f}'.format(getattr(phot_panel, self.cur_aper_name)))
        set_textctrl_background_color(textctrl, 'ok')
        phot_panel.recalc_phot()

    def _dragged_radius(self, event):
        """Return the press-time radius shifted by the cursor's x displacement."""
        return (self.aper_last_radii[self.cur_aper_index] +
                (event.xdata - self.button_press_xdata))

    def on_button_press(self, event):
        """Start a drag if the press is close enough to one of the radii."""
        if event.xdata is None:
            # Press outside the axes: there is nothing to compare against.
            return
        phot_panel = self.ztv_frame.phot_panel
        self.aper_names = ['aprad', 'skyradin', 'skyradout']
        self.aper_last_radii = np.array([phot_panel.aprad,
                                         phot_panel.skyradin,
                                         phot_panel.skyradout])
        self.button_press_xdata = event.xdata
        self.cur_aper_index = np.abs(self.aper_last_radii - event.xdata).argmin()
        self.cur_aper_name = self.aper_names[self.cur_aper_index]
        # but, click must be within +-N pix to be valid
        if np.abs(event.xdata - self.aper_last_radii[self.cur_aper_index]) <= 20:
            self.button_down = True

    def on_motion(self, event):
        """While dragging, follow the cursor with the selected radius."""
        if self.button_down and event.xdata is not None:
            self._set_cur_aper_radius(self._dragged_radius(event))

    def on_button_release(self, event):
        """Commit the final radius (if over the axes) and end the drag."""
        if self.button_down and event.xdata is not None:
            self._set_cur_aper_radius(self._dragged_radius(event))
        self.button_down = False

    def on_cursor_leave(self, event):
        """Cancel a drag in progress, restoring the press-time radius."""
        if self.button_down:
            self._set_cur_aper_radius(self.aper_last_radii[self.cur_aper_index])
        self.button_down = False

    def _onSize(self, event):
        """Handle wx resize events by resizing the embedded figure."""
        self._SetSize()

    def _SetSize(self):
        """Resize panel, canvas and figure to the current client size."""
        pixels = tuple(self.GetClientSize())
        self.SetSize(pixels)
        self.canvas.SetSize(pixels)
        dpi = self.figure.get_dpi()
        self.figure.set_size_inches(float(pixels[0]) / dpi, float(pixels[1]) / dpi)
Ejemplo n.º 7
0
class multicolorfits_viewer(HasTraits):
    """The main window. Has instructions for creating and destroying the app.
    """

    panel1 = Instance(ControlPanel)
    panel2 = Instance(ControlPanel)
    panel3 = Instance(ControlPanel)
    panel4 = Instance(ControlPanel)

    figure_combined = Instance(Figure, ())
    image = Array()
    image_axes = Instance(Axes)
    image_axesimage = Instance(AxesImage)
    image_xsize = Int(256)
    image_ysize = Int(256)

    gamma = Float(2.2)

    tickcolor = Str(
        '0.9'
    )  #,auto_set=False,enter_set=True) #Apparently need to set to TextEditor explicitly below...
    tickcolor_picker = ColorTrait((230, 230, 230))
    sexdec = Enum('Sexagesimal', 'Decimal')

    plotbutton_combined = Button(u"Plot Combined")
    plotbutton_inverted_combined = Button(u"Plot Inverted Combined")
    clearbutton_combined = Button(u"Clear Combined")
    save_the_image = Button(u"Save Image")
    save_the_fits = Button(u"Save RGB Fits")
    print_params = Button(u"Print Params")

    status_string_left = Str('')
    status_string_right = Str('')

    def _panel1_default(self):
        """Return a new ``ControlPanel`` (named per the Traits ``_<name>_default`` convention for ``panel1``)."""
        return ControlPanel()  #figure=self.figure)

    def _panel2_default(self):
        """Return a new ``ControlPanel`` (named per the Traits ``_<name>_default`` convention for ``panel2``)."""
        return ControlPanel()  #figure=self.figure)

    def _panel3_default(self):
        """Return a new ``ControlPanel`` (named per the Traits ``_<name>_default`` convention for ``panel3``)."""
        return ControlPanel()  #figure=self.figure)

    def _panel4_default(self):
        """Return a new ``ControlPanel`` (named per the Traits ``_<name>_default`` convention for ``panel4``)."""
        return ControlPanel()  #figure=self.figure)

    def __init__(self):
        """Set placeholder parameters and draw a blank image in the combined figure."""
        super(multicolorfits_viewer, self).__init__()

        # Set placeholder things like the WCS, tick color, map units...
        self._init_params()
        self.image = self._fresh_image()  # Sets a blank image
        self.image_axes = self.figure_combined.add_subplot(111, aspect=1)
        self.image_axesimage = self.image_axes.imshow(self.image,
                                                      cmap='gist_gray',
                                                      origin='lower',
                                                      interpolation='nearest')
        self.image_axes.set_xlabel(self.xlabel)
        self.image_axes.set_ylabel(self.ylabel)
        # colors=... would also set the label color
        self.image_axes.tick_params(axis='both', color=self.tickcolor)
        try:
            # Updates the frame color.  .coords won't exist until WCS set
            self.image_axes.coords.frame.set_color(self.tickcolor)
        except AttributeError:
            # Axes without ``.coords``: color the four spines instead.
            for side in ('top', 'bottom', 'left', 'right'):
                self.image_axes.spines[side].set_color(self.tickcolor)

    view = View(Item("gamma",label=u"Gamma",show_label=True),
                Item('_'),

                HSplit(
                  Group(
                    Group(Item('panel1', style="custom",show_label=False),label='Image 1'),
                    Group(Item('panel2', style="custom",show_label=False),label='Image 2'),
                    Group(Item('panel3', style="custom",show_label=False),label='Image 3'),
                    Group(Item('panel4', style="custom",show_label=False),label='Image 4'),
                  orientation='horizontal',layout='tabbed',springy=True),

                VGroup(
                    HGroup(
                      Item('tickcolor',label='Tick Color',show_label=True, \
                          tooltip='Color of ticks: standard name float[0..1], or #hex', \
                          editor=TextEditor(auto_set=False, enter_set=True,)),
                      Item('tickcolor_picker',label='Pick',show_label=True,editor=ColorEditor()),
                      Item('sexdec',label='Coordinate Style',tooltip=u'Display coordinates in sexagesimal or decimal', \
                           show_label=True),
                      ),
                    Item('figure_combined', editor=MPLFigureEditor(),show_label=False, width=900, height=800,resizable=True),
                  HGroup(

                    Item('plotbutton_combined', tooltip=u"Plot the image",show_label=False),
                    Item('plotbutton_inverted_combined', tooltip=u"Plot the inverted image",show_label=False),
                    Item('clearbutton_combined',tooltip=u'Clear the combined figure',show_label=False),
                    Item("save_the_image", tooltip=u"Save current image. Mileage may vary...",show_label=False),
                    Item("save_the_fits", tooltip=u"Save RGB frames as single fits file with header.",show_label=False),
                    Item("print_params", tooltip=u"Print out current settings for use in manual image scripting.",show_label=False),
                  ), #HGroup
                ), #VGroup
                show_labels=False,),
           resizable=True,
           height=0.75, width=0.75,
           statusbar = [StatusItem(name = 'status_string_left', width = 0.5),
                        StatusItem(name = 'status_string_right', width = 0.5)],
           title=u"Fits Multi-Color Combiner",handler=MPLInitHandler ) #View

    def _init_params(self):
        """Apply the matplotlib rc settings and initialise placeholder state.

        Sets the data-range attributes, an empty ``WCS`` and the default
        axis labels read by ``__init__``.
        """
        plt.rcParams.update({
            'font.family': 'serif',
            'xtick.major.size': 6,
            'ytick.major.size': 6,
            'xtick.major.width': 1.,
            'ytick.major.width': 1.,
            'xtick.direction': 'in',
            'ytick.direction': 'in',
        })
        try:
            plt.rcParams.update({
                'xtick.top': True,
                'ytick.right': True
            })  # apparently not in mpl v<2.0...
        except KeyError:
            # Unknown rc keys raise KeyError on older matplotlib; skip them.
            pass  # Make a workaround for mpl<2.0 later...
        self.datamin_initial = 0.
        self.datamax_initial = 1.
        self.datamin = 0.
        self.datamax = 1.  # This will be the displayed value of the scaling min/max
        #self.mapunits='Pixel Value'
        #self.tickcolor='0.5'#'white', 'black', '0.5'
        self.wcs = WCS()
        self.xlabel = 'x'
        self.ylabel = 'y'

    def _fresh_image(self):
        #self.norm=ImageNormalize(self.image,stretch=scaling_fns['linear']() )
        blankdata = np.zeros([100, 100])
        blankdata[-1, -1] = 1
        return blankdata

    def update_radecpars(self):
        """Restyle both coordinate axes of ``image_axes``.

        Stores ``coords[0]`` / ``coords[1]`` as ``rapars`` / ``decpars``,
        applies the tick color, requests six ticks, turns minor ticks on and
        picks the label format according to ``sexdec``.
        """
        coords = self.image_axes.coords
        self.rapars = coords[0]
        self.decpars = coords[1]
        both = (self.rapars, self.decpars)

        for helper in both:
            helper.set_ticks(color=self.tickcolor)
        for helper in both:
            helper.set_ticks(number=6)
        # Disabled alternatives kept for reference:
        #   set_ticklabel(size=8)  (size here means the tick length)
        #   set_ticks(spacing=10*u.arcmin, color='white', exclude_overlapping=True)
        #   rapars.set_minor_frequency(10)
        for helper in both:
            helper.display_minor_ticks(True)

        if self.sexdec != 'Sexagesimal':
            for helper in both:
                helper.set_major_formatter('d.dddddd')
            return

        self.rapars.set_major_formatter('hh:mm:ss.ss')
        self.decpars.set_major_formatter('dd:mm:ss.ss')
        self.rapars.set_separator(('H ', "' ", '" '))

    @on_trait_change('tickcolor')
    def update_tickcolor(self):
        """Apply the current ``tickcolor`` to the image axes and redraw.

        First attempt: treat ``tickcolor`` as the name of a module-level
        variable holding a color. If anything in that attempt raises, fall
        back to converting ``tickcolor`` with ``to_hex`` and restyling the
        axes. If that also fails, only an error message is written to
        ``status_string_right``. The color picker is then synced on a
        best-effort basis and the canvas is redrawn in every case.
        """
        try:
            #Catch case when you've predefined a color variable in hex string format, e.g., mynewred='#C11B17'
            #--> Need to do this first, otherwise traits throws a fit up the stack even despite the try/except check
            globals()[
                self.tickcolor]  #This check should catch undefined inputs
            self.image_axes.tick_params(axis='both',
                                        color=globals()[self.tickcolor])
            self.image_axes.coords.frame.set_color(self.tickcolor)
            self.tickcolor_picker = hex_to_rgb(globals()[self.tickcolor])
            self.status_string_right = 'Tick color changed to ' + self.tickcolor
        except:
            # Reached when the name is not a module global (KeyError) or any
            # later call in the block above raised.
            try:
                # NOTE(review): assigning tickcolor here presumably re-fires
                # this handler through on_trait_change - confirm.
                self.tickcolor = to_hex(self.tickcolor)
                try:
                    self.update_radecpars()
                except:
                    # update_radecpars() failed; style the axes directly.
                    self.image_axes.tick_params(axis='both',
                                                color=to_hex(self.tickcolor))
                    self.image_axes.coords.frame.set_color(
                        to_hex(self.tickcolor))
                self.status_string_right = 'Tick color changed to ' + self.tickcolor
            except:
                self.status_string_right = "Color name %s not recognized.  Must be standard mpl.colors string, float[0..1] or #hex string" % (
                    self.tickcolor)
        try:
            self.tickcolor_picker = hex_to_rgb(to_hex(
                self.tickcolor))  #update the picker color...
        except:
            # Best effort: leave the picker unchanged if conversion fails.
            pass
        self.figure_combined.canvas.draw()

    @on_trait_change('tickcolor_picker')
    def update_tickcolorpicker(self):
        """Copy the picker's color name into ``tickcolor``."""
        picked_name = self.tickcolor_picker.name()
        self.tickcolor = picked_name

    @on_trait_change('sexdec')
    def update_sexdec(self):
        """Re-apply the coordinate formatting, redraw and report the new style."""
        self.update_radecpars()
        canvas = self.figure_combined.canvas
        canvas.draw()
        message = 'Coordinate style changed to ' + self.sexdec
        self.status_string_right = message

    def _plotbutton_combined_fired(self):
        """Combine the images of all panels in use and show the result.

        Takes the WCS and header from ``panel1``, clears the figure and
        recreates the axes with that WCS as projection. If ``panel1`` has no
        data, only a status message is set.
        """
        try:
            self.panel1.data
        except Exception:
            # Any failure to read the data counts as "nothing loaded";
            # unlike a bare except, this lets KeyboardInterrupt and
            # SystemExit propagate.
            self.status_string_right = "No fits file loaded yet!"
            return
        #self.image=self.panel1.data
        self.wcs = WCS(self.panel1.hdr)
        self.hdr = self.panel1.hdr

        panels = [self.panel1, self.panel2, self.panel3, self.panel4]
        # NOTE(review): the explicit ``== True`` is kept because the type of
        # ``in_use`` is not visible here - confirm before simplifying.
        layers = [pan.image_colorRGB for pan in panels if pan.in_use == True]
        self.combined_RGB = combine_multicolor(layers, gamma=self.gamma)

        ###Using this command is preferable, as long as the projection doesn't need to be updated...
        #  The home zoom button will work, but no WCS labels because projection wasn't set during init.
        #self.image_axesimage.set_data(self.data)
        ###Using this set instead properly updates the axes labels to WCS, but the home zoom button won't work
        self.figure_combined.clf()
        self.image_axes = self.figure_combined.add_subplot(111,
                                                           aspect=1,
                                                           projection=self.wcs)
        self.image_axesimage = self.image_axes.imshow(self.combined_RGB,
                                                      origin='lower',
                                                      interpolation='nearest')

        self.update_radecpars()
        self.figure_combined.canvas.draw()
        self.status_string_right = "Plot updated"

    def _plotbutton_inverted_combined_fired(self):
        try:
            self.panel1.data
        except:
            self.status_string_right = "No fits file loaded yet!"
            return
        self.wcs = WCS(self.panel1.hdr)
        self.hdr = self.panel1.hdr
        self.combined_RGB = combine_multicolor(
            [
                pan.image_colorRGB for pan in
                [self.panel1, self.panel2, self.panel3, self.panel4]
                if pan.in_use == True
            ],
            inverse=True,
            gamma=self.gamma,
        )
        self.figure_combined.clf()
        self.image_axes = self.figure_combined.add_subplot(111,
                                                           aspect=1,
                                                           projection=self.wcs)
        self.image_axesimage = self.image_axes.imshow(self.combined_RGB,
                                                      origin='lower',
                                                      interpolation='nearest')
        self.update_radecpars()
        self.figure_combined.canvas.draw()
        self.status_string_right = "Plot updated"

    def _clearbutton_combined_fired(self):
        """Discard the combined image and reset the figure to a blank placeholder plot."""
        # If clear was already pressed once, the data will already have been deleted.
        if hasattr(self, 'combined_RGB'):
            del self.combined_RGB
        self.in_use = False
        self.figure_combined.clf()
        self.image = self._fresh_image()
        self.image_axes = self.figure_combined.add_subplot(111, aspect=1)
        self.image_axesimage = self.image_axes.imshow(self.image,
                                                      cmap='gist_gray',
                                                      origin='lower',
                                                      interpolation='nearest')
        self.xlabel = 'x'
        self.ylabel = 'y'
        self.image_axes.set_xlabel(self.xlabel)
        self.image_axes.set_ylabel(self.ylabel)
        self.image_axes.tick_params(axis='both', color=self.tickcolor)
        try:
            self.image_axes.coords.frame.set_color(self.tickcolor)
        except Exception:
            # NOTE(review): these axes are created without a WCS projection, so
            # presumably they have no ``coords`` and this fallback is the usual
            # path -- confirm.  Kept as best-effort, but no longer swallows
            # KeyboardInterrupt/SystemExit.
            self.tickcolor_picker = hex_to_rgb(to_hex(self.tickcolor))
        self.figure_combined.canvas.draw()
        self.status_string_right = "Plot cleared"

    def setup_mpl_events(self):
        """Create an ``AxesWidget`` on the image axes and hook up the mouse callbacks."""
        widget = AxesWidget(self.image_axes)
        self.image_axeswidget = widget
        # (event name, handler) pairs, connected in this order
        bindings = (
            ('motion_notify_event', self.image_on_motion),
            ('figure_leave_event', self.on_cursor_leave),
            ('figure_enter_event', self.on_cursor_enter),
            ('button_press_event', self.image_on_click),
        )
        for event_name, handler in bindings:
            widget.connect_event(event_name, handler)

    def image_on_motion(self, event):
        """Report the pixel under the cursor and its value in the left status string.

        Does nothing when the event carries no data coordinates; clears the
        status string when the rounded position falls outside ``self.image``.
        """
        if event.xdata is None or event.ydata is None:
            return
        x = int(np.round(event.xdata))
        y = int(np.round(event.ydata))
        inside = (0 <= x < self.image.shape[1]) and (0 <= y < self.image.shape[0])
        if not inside:
            self.status_string_left = ""
            return
        pixel_value = self.image[y, x]  # row index is y, column index is x
        self.status_string_left = "x,y={},{}  {:.5g}".format(x, y, pixel_value)

    def image_on_click(self, event):
        """Report pixel position, world coordinates and value of a left click.

        Clicks without data coordinates (outside the main plot) and clicks with
        any button other than the left one are ignored.  A left click outside
        ``self.image`` clears the right status string.

        Button numbering: left-click is 1, middle-click is 2, right-click is 3.
        For a double-click, ``event.dblclick`` is False on the first click and
        True on the second.
        """
        # Compare by value: ``is not 1`` tests object identity, which is never
        # satisfied by an int-like enum member even when it equals 1.
        if event.xdata is None or event.ydata is None or event.button != 1:
            return
        # xdata/ydata are the actual pixel position; event.x/event.y would be in
        # 'display space', i.e. pixels in the canvas.
        x = int(np.round(event.xdata))
        y = int(np.round(event.ydata))
        if (0 <= x < self.image.shape[1]) and (0 <= y < self.image.shape[0]):
            xwcs, ywcs = self.wcs.wcs_pix2world([[x, y]], 0)[0]
            imval = self.image[y, x]
            self.status_string_right = "x,y=[{},{}], RA,DEC=[{}, {}], value = {:.5g}".format(
                x, y, xwcs, ywcs, imval)
        else:
            self.status_string_right = ""

    def on_cursor_leave(self, event):
        """Restore the application's override cursor and blank the left status string."""
        QApplication.restoreOverrideCursor()
        blank = ''
        self.status_string_left = blank

    def on_cursor_enter(self, event):
        """Set a cross-shaped override cursor for the application."""
        cross_cursor = Qt.CrossCursor
        QApplication.setOverrideCursor(cross_cursor)

    def _save_the_image_fired(self):
        """Ask for a file name and save the figure there at 300 dpi."""
        dlg = FileDialog(action='save as')
        if dlg.open() == OK:
            # ``size`` is not a savefig() option: older matplotlib silently ignored
            # it and newer releases reject unknown keywords, so it is dropped here.
            # NOTE(review): plt.savefig() saves pyplot's *current* figure -- confirm
            # that this is the combined figure shown in the GUI.
            plt.savefig(dlg.path, dpi=300)

    def _save_the_fits_fired(self):
        """Ask for a file name and write the combined RGB array there as FITS.

        The array axes are swapped (0<->2, then 2<->1) before writing, and
        ``self.hdr`` is written as the header.
        """
        # TODO: generate a generic header with correct WCS and comments about the
        # colors that made up the image.
        dlg = FileDialog(action='save as')
        if dlg.open() != OK:
            return
        swapped_once = np.swapaxes(self.combined_RGB, 0, 2)
        cube = np.swapaxes(swapped_once, 2, 1)
        pyfits.writeto(dlg.path, cube, self.hdr)

    def _print_params_fired(self):
        print('\n\nRGB Image plot params:')
        pan_i = 0
        for pan in [self.panel1, self.panel2, self.panel3, self.panel4]:
            pan_i += 1
            if pan.in_use == True:
                print('image%i: ' % (pan_i))
                print('    vmin = %.3e , vmax = %.3e, scale = %s' %
                      (pan.datamin, pan.datamax, pan.image_scale))
                print("    image color = '%s'" % (pan.imagecolor))
        print("gamma = %.1f , tick color = '%s'\n" %
              (self.gamma, self.tickcolor))
Ejemplo n.º 8
0
class ControlPanel(HasTraits):
    """Control panel where the parameters for a single image are specified.

    The panel loads a FITS file, rescales the data between ``datamin`` and
    ``datamax`` with the selected scaling function, converts the result to a
    greyscale RGB image and colorizes it with ``imagecolor``.  The derived
    arrays are stored as ``data_scaled``, ``image_greyRGB`` and
    ``image_colorRGB``; ``in_use`` flags whether the panel holds an image.
    """

    gamma = 2.2

    fitsfile = File(filter=[u"*.fits"])
    image_figure = Instance(Figure, ())
    image = Array()
    image_axes = Instance(Axes)
    image_axesimage = Instance(AxesImage)
    image_xsize = Int(256)
    image_ysize = Int(256)

    datamin = Float(0.0, auto_set=False, enter_set=True)  #Say, in mJy
    datamax = Float(
        1.0, auto_set=False, enter_set=True
    )  #auto_set=input set on each keystroke, enter_set=set after Enter
    percent_min = Range(value=0.0, low=0.0, high=100.)
    percent_max = Range(value=100.0, low=0.0,
                        high=100.)  #Percentile of data values for rescaling
    minmaxbutton = Button('Min/Max')
    zscalebutton = Button('Zscale')

    image_scale = Str('linear')
    scale_dropdown = Enum(
        ['linear', 'sqrt', 'squared', 'log', 'power', 'sinh', 'asinh'])

    imagecolor = Str('#FFFFFF')
    imagecolor_picker = ColorTrait((255, 255, 255))

    #plotbeam_button=Button('Add Beam (FWHM)')

    plotbutton_individual = Button(u"Plot Single")
    plotbutton_inverted_individual = Button(u"Plot Inverted Single")
    clearbutton_individual = Button(u"Clear Single")

    status_string_left = Str('')
    status_string_right = Str('')

    def __init__(self):
        """Set the default parameters and show a blank placeholder image."""
        self._init_params()  # rcParams, in_use flag and initial scaling min/max
        self.image = self._fresh_image()  # Sets a blank image
        self.image_axes = self.image_figure.add_subplot(111, aspect=1)
        self.image_axesimage = self.image_axes.imshow(self.image,
                                                      cmap='gist_gray',
                                                      origin='lower',
                                                      interpolation='nearest')
        self.image_axes.axis('off')

    view = View(
      HSplit(
        VGroup(
            Item("fitsfile", label=u"Select 2D FITS file", show_label=True), #,height=100),

            HGroup(
              VGroup(
                     Item('plotbutton_individual', tooltip=u"Plot the single image",show_label=False),
                     Item('plotbutton_inverted_individual', tooltip=u"Plot the single inverted image",show_label=False),
                     Item('clearbutton_individual', tooltip=u"Clear the single image",show_label=False),
                     Item('_'),

                     Item('imagecolor',label='Image Color',show_label=True, \
                      tooltip='Color of image: standard name float[0..1], or #hex', \
                      editor=TextEditor(auto_set=False, enter_set=True,)),
                     Item('imagecolor_picker',label='Pick',show_label=True,editor=ColorEditor()),

                     ),
              Item('image_figure', editor=MPLFigureEditor(), show_label=False, width=300, height=300,resizable=True),
            ),
            HGroup(Item('datamin', tooltip=u"Minimum data val for scaling", show_label=True),
                   Item('datamax', tooltip=u"Maximum data val for scaling", show_label=True)),
            Item('percent_min', tooltip=u"Min. percentile for scaling", show_label=True),
            Item('percent_max', tooltip=u"Max. percentile for scaling", show_label=True),
            HGroup(Item('minmaxbutton', tooltip=u"Reset to data min/max", show_label=False),
                   Item('zscalebutton', tooltip=u"Compute scale min/max from zscale algorithm", show_label=False)),
            Item('scale_dropdown',label='Scale',show_label=True),

        ), #End of Left column

      ), #End of HSplit
      resizable=True, #height=0.75, width=0.75, #title=u"Multi-Color Image Combiner",
      handler=MPLInitHandler,
      statusbar = [StatusItem(name = 'status_string_left', width = 0.5), StatusItem(name = 'status_string_right', width = 0.5)]
    ) #End of View

    def _init_params(self):
        """Set the matplotlib rcParams used by the panel and the initial scaling limits."""
        self.in_use = False
        plt.rcParams.update({'font.family': 'serif','xtick.major.size':6,'ytick.major.size':6, \
                             'xtick.major.width':1.,'ytick.major.width':1., \
                             'xtick.direction':'in','ytick.direction':'in'})
        try:
            plt.rcParams.update({
                'xtick.top': True,
                'ytick.right': True
            })  # apparently not in mpl v<2.0...
        except KeyError:
            pass  # Unknown rc parameter; make a workaround for mpl<2.0 later...
        self.datamin_initial = 0.
        self.datamax_initial = 1.
        self.datamin = 0.
        self.datamax = 1.  # This will be the displayed value of the scaling min/max

    def _fresh_image(self):
        """Return a blank 100x100 image whose last pixel is 1 (so the range is non-degenerate)."""
        blankdata = np.zeros([100, 100])
        blankdata[-1, -1] = 1
        return blankdata

    # ----- private helpers -------------------------------------------------

    def _has_data(self):
        """Return True if FITS data is loaded; otherwise say so in the status bar."""
        if hasattr(self, 'data'):
            return True
        self.status_string_right = "No fits file loaded yet!"
        return False

    def _percentile_of(self, value):
        """Return the percentile rank (rounded to 2 decimals) of ``value`` within the data."""
        return np.round(
            percentileofscore(self.data.ravel(), value, kind='strict'), 2)

    def _recompute_images(self, color=None):
        """Rebuild ``data_scaled``, ``image_greyRGB`` and ``image_colorRGB``.

        ``color`` is the hex color used for colorizing; it defaults to
        ``self.imagecolor``.
        """
        if color is None:
            color = self.imagecolor
        # Scale the data to the [0,1] range
        self.data_scaled = (
            scaling_fns[self.image_scale]() +
            ManualInterval(vmin=self.datamin, vmax=self.datamax))(self.data)
        # Convert the scaled image to a greyscale RGB image, then colorize it
        self.image_greyRGB = ski_color.gray2rgb(
            adjust_gamma(self.data_scaled, self.gamma))
        self.image_colorRGB = colorize_image(self.image_greyRGB,
                                             color,
                                             colorintype='hex',
                                             gammacorr_color=self.gamma)

    def _show_color_image(self):
        """Push the colorized image (raised to 1/gamma) into the displayed AxesImage."""
        self.image_axesimage.set_data(self.image_colorRGB**(1. / self.gamma))

    # ----- trait handlers --------------------------------------------------

    def _fitsfile_changed(self):
        """Load the selected FITS file and reset the scaling limits to its data range."""
        self.data, self.hdr = pyfits.getdata(self.fitsfile, header=True)
        force_hdr_floats(
            self.hdr
        )  # Ensure that WCS cards such as CDELT are floats instead of strings

        naxis = int(self.hdr['NAXIS'])
        if naxis > 2:
            self.hdr = force_hdr_to_2D(self.hdr)
            # Keep the first plane along every leading axis until a 2D image remains
            while self.data.ndim > 2:
                self.data = self.data[0]
            self.status_string_right = 'Dropped extra axes'

        # float() instead of np.asscalar(), which is no longer available in NumPy
        data_max = float(np.nanmax(self.data))
        data_min = float(np.nanmin(self.data))
        self.datamax_initial = data_max
        self.datamin_initial = data_min
        self.datamax = data_max
        self.datamin = data_min

        self.in_use = True

    def _imagecolor_changed(self):
        """Validate the new color, sync the color picker and recolor the displayed image."""
        try:
            # Catch case when you've predefined a color variable in hex string format, e.g., mynewred='#C11B17'
            # --> Need to do this first, otherwise traits throws a fit up the stack even despite the try/except check
            # The globals() lookup raises KeyError for undefined names.
            self.imagecolor_picker = hex_to_rgb(globals()[self.imagecolor])
            self.status_string_right = 'Image color changed to ' + self.imagecolor
        except Exception:
            try:
                self.imagecolor = to_hex(self.imagecolor)
                self.status_string_right = 'Image color changed to ' + self.imagecolor
            except Exception:
                self.status_string_right = "Color name %s not recognized.  Must be standard mpl.colors string, float[0..1] or #hex string" % (
                    self.imagecolor)
        try:
            self.imagecolor_picker = hex_to_rgb(to_hex(
                self.imagecolor))  # update the picker color...
        except Exception:
            pass  # unrecognized color: leave the picker as it is
        # self.image_greyRGB and self.image_colorRGB may not yet be instantiated
        # if the color is changed before clicking 'plot'
        try:
            self.image_colorRGB = colorize_image(self.image_greyRGB,
                                                 self.imagecolor,
                                                 colorintype='hex',
                                                 gammacorr_color=self.gamma)
        except Exception:
            pass
        try:
            self._show_color_image()
        except Exception:
            pass
        else:
            # Only flag the panel as in use when there really is an image to show
            self.in_use = True
        self.image_figure.canvas.draw()

    def _imagecolor_picker_changed(self):
        """Copy the name of the color selected in the picker into ``imagecolor``."""
        self.imagecolor = self.imagecolor_picker.name()

    def _percent_min_changed(self):
        """Set ``datamin`` from the lower percentile and redraw the image."""
        if not self._has_data():
            return
        self.datamin = np.nanpercentile(self.data, self.percent_min)
        self._recompute_images()
        self._show_color_image()
        self.image_figure.canvas.draw()
        self.status_string_right = "Updated scale using percentiles"

    def _percent_max_changed(self):
        """Set ``datamax`` from the upper percentile and redraw the image."""
        if not self._has_data():
            return
        self.datamax = np.nanpercentile(self.data, self.percent_max)
        self._recompute_images()
        self._show_color_image()
        self.image_figure.canvas.draw()
        self.status_string_right = "Updated scale using percentiles"

    ### Very slow to update datamin and datamax as well as percs... Can comment these if desired and just hit plot after datamin

    def _datamin_changed(self):
        """Keep ``percent_min`` in step with a manually entered ``datamin``."""
        if not self._has_data():
            return
        self.percent_min = self._percentile_of(self.datamin)

    def _datamax_changed(self):
        """Keep ``percent_max`` in step with a manually entered ``datamax``."""
        if not self._has_data():
            return
        self.percent_max = self._percentile_of(self.datamax)

    def _scale_dropdown_changed(self):
        """Apply the scaling function picked in the dropdown and redraw the image."""
        self.image_scale = self.scale_dropdown
        if not self._has_data():
            return
        self._recompute_images()
        self._show_color_image()

        self.in_use = True

        self.image_figure.canvas.draw()
        self.status_string_right = 'Image scale function changed to ' + self.image_scale

    def _minmaxbutton_fired(self):
        """Reset the scaling limits to the initial data min/max."""
        self.datamin = self.datamin_initial
        self.datamax = self.datamax_initial
        if not self._has_data():
            return
        self.percent_min = self._percentile_of(self.datamin)
        self.percent_max = self._percentile_of(self.datamax)
        self.status_string_right = "Scale reset to min/max"

    def _zscalebutton_fired(self):
        """Set the scaling limits from the zscale algorithm."""
        if not self._has_data():
            return
        tmpZscale = ZScaleInterval().get_limits(self.data)
        self.datamin = float(tmpZscale[0])
        self.datamax = float(tmpZscale[1])
        self.percent_min = self._percentile_of(self.datamin)
        self.percent_max = self._percentile_of(self.datamax)
        self.status_string_right = "Min/max determined by zscale"

    def _plotbutton_individual_fired(self):
        """Scale, colorize and display the loaded image."""
        if not self._has_data():
            return
        # Updating the existing AxesImage keeps the home zoom button working, but
        # gives no WCS labels because the projection wasn't set during init.
        self._recompute_images()
        self._show_color_image()

        self.percent_min = self._percentile_of(self.datamin)
        self.percent_max = self._percentile_of(self.datamax)

        self.in_use = True

        self.image_figure.canvas.draw()
        self.status_string_right = "Plot updated"

    def _plotbutton_inverted_individual_fired(self):
        """Scale and colorize the image with the inverse color and display it inverted."""
        if not self._has_data():
            return
        self._recompute_images(color=hexinv(self.imagecolor))
        self.image_axesimage.set_data(
            combine_multicolor([
                self.image_colorRGB,
            ],
                               gamma=self.gamma,
                               inverse=True))
        self.percent_min = self._percentile_of(self.datamin)
        self.percent_max = self._percentile_of(self.datamax)
        self.in_use = True
        self.image_figure.canvas.draw()
        self.status_string_right = "Plot updated"

    def _clearbutton_individual_fired(self):
        """Drop the loaded data and all derived images and show the blank placeholder again."""
        # Each attribute is removed independently: some may never have been
        # created, or clear may already have been pressed once.
        for name in ('data', 'data_scaled', 'image_greyRGB', 'image_colorRGB'):
            if hasattr(self, name):
                delattr(self, name)
        self.in_use = False
        self.image_figure.clf()
        self.image = self._fresh_image()
        self.image_axes = self.image_figure.add_subplot(111, aspect=1)
        self.image_axesimage = self.image_axes.imshow(self.image,
                                                      cmap='gist_gray',
                                                      origin='lower',
                                                      interpolation='nearest')
        self.image_axes.axis('off')
        self.image_figure.canvas.draw()
        self.status_string_right = "Plot cleared"

    # ----- matplotlib event handling ----------------------------------------

    def setup_mpl_events(self):
        """Create an ``AxesWidget`` on the image axes and hook up the mouse callbacks."""
        self.image_axeswidget = AxesWidget(self.image_axes)
        self.image_axeswidget.connect_event('motion_notify_event',
                                            self.image_on_motion)
        self.image_axeswidget.connect_event('figure_leave_event',
                                            self.on_cursor_leave)
        self.image_axeswidget.connect_event('figure_enter_event',
                                            self.on_cursor_enter)
        self.image_axeswidget.connect_event('button_press_event',
                                            self.image_on_click)

    def image_on_motion(self, event):
        """Report the pixel under the cursor and its ``self.image`` value in the left status string."""
        if event.xdata is None or event.ydata is None:
            return
        x = int(np.round(event.xdata))
        y = int(np.round(event.ydata))
        if (0 <= x < self.image.shape[1]) and (0 <= y < self.image.shape[0]):
            imval = self.image[y, x]
            self.status_string_left = "x,y={},{}  {:.5g}".format(x, y, imval)
        else:
            self.status_string_left = ""

    def image_on_click(self, event):
        """Report the clicked pixel, the exact cursor position and the value on a left click.

        Clicks without data coordinates (outside the main plot) and clicks with
        any button other than the left one are ignored.  Button numbering:
        left-click is 1, middle-click is 2, right-click is 3.  For a
        double-click, ``event.dblclick`` is False on the first click and True
        on the second.
        """
        # Compare by value: ``is not 1`` tests object identity, which is never
        # satisfied by an int-like enum member even when it equals 1.
        if event.xdata is None or event.ydata is None or event.button != 1:
            return
        # xdata/ydata are the actual pixel position; event.x/event.y would be in
        # 'display space', i.e. pixels in the canvas.
        x = int(np.round(event.xdata))
        y = int(np.round(event.ydata))
        if (0 <= x < self.image.shape[1]) and (0 <= y < self.image.shape[0]):
            imval = self.image[y, x]
            self.status_string_right = "x,y[{},{}] = {:.3f},{:.3f}  {:.5g}".format(
                x, y, event.xdata, event.ydata, imval)
        else:
            self.status_string_right = ""

    def on_cursor_leave(self, event):
        """Restore the application's override cursor and blank the left status string."""
        QApplication.restoreOverrideCursor()
        self.status_string_left = ''

    def on_cursor_enter(self, event):
        """Set a cross-shaped override cursor for the application."""
        QApplication.setOverrideCursor(Qt.CrossCursor)
Ejemplo n.º 9
0
class AtmosViewer(HasTraits):
    """Traits UI viewer for an atmospheric transmission plot.

    The view embeds a matplotlib figure together with text fields for the
    central wavenumber and bandwidth of the displayed window, a read-out of
    the currently selected line, and a check list (plus all-on / all-off
    buttons) choosing which molecules are over-plotted.

    ``molecules``, ``title`` and ``size`` are module-level names defined
    outside this class.
    """

    # Centre and full width of the displayed wavenumber window.
    central_wavenumber = CFloat(1000)
    bandwidth = CFloat(10)

    # Wavenumber of the line picked by the last click; -1 means "none".
    selected_line_wavenumber = Float(-1.)

    figure = Instance(Figure, ())

    all_on = Button()
    all_off = Button()
    # list(...) rather than .keys(): on Python 3 ``dict.keys()`` is a view
    # object, not the list of values the editor is given on Python 2.
    selected_molecules = List(editor=CheckListEditor(values=list(molecules),
                                                     cols=2, format_str='%s'))

    mplFigureEditor = MPLFigureEditor()

    trait_view = View(VGroup(Item('figure', editor=mplFigureEditor, show_label=False),
                             HGroup('10',
                                    VGroup('40',
                                           Item(name='central_wavenumber',
                                                editor=TextEditor(auto_set=False, enter_set=True)),
                                           Item(name='bandwidth',
                                                editor=TextEditor(auto_set=False, enter_set=True)),
                                           HGroup(Item(name='selected_line_wavenumber'),
                                                  show_border=True),
                                           show_border=True),
                                    HGroup(
                                        VGroup('20', Heading("Molecules"),
                                               Item(name='all_on', show_label=False),
                                               Item(name='all_off', show_label=False)),
                                        Item(name='selected_molecules', style='custom', show_label=False),
                                        show_border=True), '10'),
                             '10'),
                      handler=MPLInitHandler,
                      resizable=True, title=title, width=size[0], height=size[1])


    def __init__(self):
        """Load the packaged data tables and draw the static plot content.

        Reads ``orders.txt`` and ``atmos.txt.gz`` from the package
        resources, plots the ``trans1mm`` and ``trans4mm`` columns of the
        atmosphere table against its index, draws one short segment per row
        of the orders table, and sets the initial axis limits from
        ``central_wavenumber`` and ``bandwidth``.
        """
        super(AtmosViewer, self).__init__()
        self.colors = {'telluric': 'black',
                       'orders': 'black'}
        self.molecules = molecules
        self.selected_molecules = []
        orders_filename = resource_filename(__name__, 'orders.txt')
        self.texes_orders = pandas.read_csv(orders_filename, sep='\t', header=None, skiprows=3)
        atmos_filename = resource_filename(__name__, 'atmos.txt.gz')
        # Close the gzip handle as soon as the table has been parsed.
        with gzip.open(atmos_filename, 'r') as atmos_file:
            self.atmos = pandas.read_csv(atmos_file, sep='\t', skiprows=7, index_col='# wn')
        self.molecule_lookup_points = {}  # keys are e.g. 'O3', with a dict of {'wn':..., 'y':...}
        self.axes = self.figure.add_subplot(111)
        self.axes.plot(self.atmos.index, self.atmos['trans1mm'], color=self.colors['telluric'])
        self.axes.plot(self.atmos.index, self.atmos['trans4mm'], color=self.colors['telluric'])
        for i in self.texes_orders.index:
            # .loc is the label-based replacement for the removed .ix indexer.
            self.axes.plot(self.texes_orders.loc[i].values, [0.05, 0.07], color=self.colors['orders'])
        self.axes.set_xlim(self.central_wavenumber - self.bandwidth / 2.,
                           self.central_wavenumber + self.bandwidth / 2.)
        self.axes.set_ylim(0, 1.0)
        self.axes.set_xlabel('Wavenumber (cm-1)')
        self.axes.xaxis.set_major_formatter(FormatStrFormatter('%6.1f'))
        # Event connections are made later in mpl_setup(), not here.
        self.onclick_connected = False
        self.selected_line = None
        self.selected_line_text = None

    def on_click(self, event):
        """Select and mark the plotted molecular line nearest to a click.

        Any previous selection marker and label are removed first.  The
        nearest point among ``self.molecule_lookup_points`` is then found,
        with the y offset rescaled by the axes' height and its
        wavenumbers-per-inch scale before being combined with the x offset.
        The chosen wavenumber is stored in ``selected_line_wavenumber`` and
        marked with a vertical dash-dot line and a text label.

        Clicks without data coordinates are ignored.  If there is no
        candidate line the selection is simply cleared.
        """
        if event.xdata is None or event.ydata is None:
            return
        # Artist.remove() detaches the artist from its axes; this avoids
        # mutating axes.lines / axes.texts, which newer matplotlib releases
        # do not allow.
        if self.selected_line in self.axes.lines:
            self.selected_line.remove()
        if self.selected_line_text in self.axes.texts:
            self.selected_line_text.remove()
        self.selected_line = None
        self.selected_line_text = None
        self.selected_line_wavenumber = -1
        if not self.molecule_lookup_points:
            self.redraw()  # make the cleared marker disappear
            return

        # Loop-invariant geometry of the axes.
        xlim = self.axes.get_xlim()
        bounds = self.axes.get_position().bounds
        figure = self.axes.figure
        scale = ((xlim[1] - xlim[0]) /  # this is like wavenumbers/inch
                 (figure.get_figwidth() * bounds[2]))
        y_factor = figure.get_figheight() * bounds[3] * scale

        closest = {'name': None, 'wn': -1., 'dist': 9e9}
        for cur_molecule, points in self.molecule_lookup_points.items():
            # Plain arrays so that argmin() positions index the data
            # positionally; a filtered pandas Series would be looked up by
            # label instead.
            wn = np.asarray(points['wn'], dtype=float)
            ys = np.asarray(points['y'], dtype=float)
            if wn.size == 0:
                continue
            dist = np.sqrt((wn - event.xdata)**2 +
                           ((ys - event.ydata) * y_factor)**2)
            nearest = dist.argmin()
            if dist[nearest] < closest['dist']:
                closest = {'name': cur_molecule, 'wn': float(wn[nearest]),
                           'dist': float(dist[nearest])}
        if closest['name'] is None:
            # Nothing is plotted in the current window: no line to select.
            self.redraw()
            return
        self.selected_line_wavenumber = closest['wn']
        self.selected_line = self.axes.plot([closest['wn'], closest['wn']], [0, 1], '-.', color='black')[0]
        self.selected_line_text = self.axes.annotate(closest['name'] + ('%11.5f' % closest['wn']),
                                                     (closest['wn'], 1.03), ha='center',
                                                     annotation_clip=False)
        self.redraw()

    def on_scroll(self, event):
        """Move the window centre by one bandwidth per scroll step."""
        shift = self.bandwidth * event.step
        self.central_wavenumber = self.central_wavenumber + shift

    def _all_on_fired(self):
        self.selected_molecules = self.molecules.keys()

    def _all_off_fired(self):
        self.selected_molecules = []

    def mpl_setup(self):
        """Create the axes widget and hook up the click and scroll handlers."""
        widget = AxesWidget(self.figure.gca())
        self.axes_widget = widget
        handlers = (('button_press_event', self.on_click),
                    ('scroll_event', self.on_scroll))
        for event_name, callback in handlers:
            widget.connect_event(event_name, callback)

    @on_trait_change("central_wavenumber, bandwidth")
    def replot_molecular_overplots(self):
        """Redraw the line overlay of every selected molecule.

        For each selected molecule the line table is loaded on first use,
        restricted to the current wavenumber window, and cut to the lines
        within two orders of magnitude of the strongest one.  The surviving
        intensities are log-scaled into a band 0.1 high, stacked by the
        molecule's position in ``selected_molecules`` starting at y = 0.05,
        and drawn as vertical sticks with a label just right of the window.
        The stick tops are stored in ``molecule_lookup_points`` for click
        picking, and the molecule's previous artists are removed.
        """
        window_lo = self.central_wavenumber - self.bandwidth / 2.
        window_hi = self.central_wavenumber + self.bandwidth / 2.
        plot_orders_of_magnitude = 2.
        for i, cur_molecule in enumerate(self.selected_molecules):
            info = self.molecules[cur_molecule]
            if info['hitran'] is None:
                # Load lazily and close the gzip handle once parsed.
                with gzip.open(info['hitran_filename'], 'r') as hitran_file:
                    info['hitran'] = pandas.read_csv(hitran_file, skiprows=2)
            wn = info['hitran']['wavenumber']
            intensity = info['hitran']['intensity']
            in_window = (wn >= window_lo) & (wn <= window_hi)
            wn = wn[in_window]
            intensity = intensity[in_window]
            max_line_intensity = intensity.max()
            min_line_intensity = max_line_intensity / 10**plot_orders_of_magnitude
            strong_enough = intensity >= min_line_intensity
            wn = wn[strong_enough]
            intensity = intensity[strong_enough]
            # Map log10(intensity) linearly from [min, max] onto [0, 1].
            intensity = ((np.log10(intensity) - np.log10(min_line_intensity)) /
                         (np.log10(max_line_intensity) - np.log10(min_line_intensity)))
            intensity = intensity * 0.1
            self.molecule_lookup_points[cur_molecule] = {'wn': wn, 'y': intensity + (i * 0.1) + 0.05}
            # Each line becomes three points (base, top, base) so that one
            # plot call draws all sticks.
            wn = wn.repeat(3)
            intensity = np.column_stack((np.zeros(len(intensity)),
                                         intensity,
                                         np.zeros(len(intensity)))).flatten() + (i * 0.1) + 0.05
            newplot = self.axes.plot(wn, intensity, info['color'])
            newtext = self.axes.annotate(cur_molecule, (self.central_wavenumber + self.bandwidth * 0.51,
                                                        i * 0.1 + 0.065), ha='left',
                                         va='center', annotation_clip=False, color=info['color'])
            # Artist.remove() detaches the old artists without mutating
            # axes.lines / axes.texts, which newer matplotlib releases do
            # not allow.
            old_lines = info['plot_lines']
            if old_lines in self.axes.lines:
                old_lines.remove()
            old_text = info['plot_text']
            if old_text in self.axes.texts:
                old_text.remove()
            info['plot_lines'] = newplot[0]
            info['plot_text'] = newtext
        self.redraw()

    def _selected_molecules_changed(self, old, new):
        """Refresh the overlays after the molecule selection changes.

        All currently selected molecules are replotted first; then, for
        every molecule that is in ``old`` but not in ``new``, its plotted
        line and label are removed from the axes and its lookup points are
        dropped.
        """
        self.replot_molecular_overplots()
        for cur_molecule in old:
            if cur_molecule in new:
                continue
            info = self.molecules[cur_molecule]
            # Artist.remove() detaches the artist without mutating
            # axes.lines / axes.texts, which newer matplotlib releases do
            # not allow.
            if info['plot_lines'] in self.axes.lines:
                info['plot_lines'].remove()
            if info['plot_text'] in self.axes.texts:
                info['plot_text'].remove()
            info['plot_lines'] = None
            info['plot_text'] = None
            self.molecule_lookup_points.pop(cur_molecule, None)
        self.redraw()

    @on_trait_change("central_wavenumber, bandwidth")
    def redraw(self):
        """Reset the axis limits to the current window and repaint the canvas."""
        lower_edge = self.central_wavenumber - self.bandwidth / 2.
        upper_edge = self.central_wavenumber + self.bandwidth / 2.
        self.axes.set_xlim(lower_edge, upper_edge)
        self.axes.set_ylim(0, 1.0)
        self.figure.canvas.draw()