Beispiel #1
0
def _do_plot_boilerplate(kwargs, image=False):
    """Shared setup used by the plotting functions.

    Applies and removes a ``hold`` entry from *kwargs* if present, opens a
    window when the session has none (an image window when *image* is true,
    otherwise a figure), and returns the container of the active window.  A
    new ``Plot`` over ``session.data`` is installed when the window has no
    container.  A ``PanTool`` and a ``ZoomTool`` are attached if missing, and
    all existing plots are deleted unless ``session.hold`` is set.
    """
    if "hold" in kwargs:
        requested_hold = kwargs["hold"]
        hold(requested_hold)
        del kwargs["hold"]

    # No window yet: create one of the requested kind.
    if len(session.windows) == 0:
        if not image:
            figure()
        else:
            image_window = session.new_window(is_image=True)
            activate(image_window)

    cont = session.active_window.get_container()
    if not cont:
        cont = Plot(session.data)
        session.active_window.set_container(cont)

    # Types of everything already attached, so tools are not added twice.
    attached_types = [type(item) for item in (cont.tools + cont.overlays)]
    if PanTool not in attached_types:
        cont.tools.append(PanTool(cont))
    if ZoomTool not in attached_types:
        zoom = ZoomTool(cont, tool_mode="box", always_on=True, drag_button="right")
        cont.overlays.append(zoom)

    if not session.hold:
        plot_names = list(cont.plots.keys())
        cont.delplot(*plot_names)

    return cont
Beispiel #2
0
def _do_plot_boilerplate(kwargs, image=False):
    """Common preamble for the plotting functions.

    Consumes an optional ``hold`` key from *kwargs*, guarantees that a window
    is open, and returns the active window's container (creating a ``Plot``
    on ``session.data`` when there is none).  Pan and zoom tools are added if
    they are not attached yet; unless ``session.hold`` is set, every existing
    plot in the container is deleted.
    """
    if "hold" in kwargs:
        hold_value = kwargs["hold"]
        hold(hold_value)
        del kwargs["hold"]

    # Open a window if the session does not have one.
    if len(session.windows) == 0:
        if not image:
            figure()
        else:
            new_win = session.new_window(is_image=True)
            activate(new_win)

    cont = session.active_window.get_container()
    if not cont:
        cont = Plot(session.data)
        session.active_window.set_container(cont)

    present = [type(entry) for entry in (cont.tools + cont.overlays)]
    if PanTool not in present:
        cont.tools.append(PanTool(cont))
    if ZoomTool not in present:
        box_zoom = ZoomTool(cont, tool_mode="box", always_on=True, drag_button="right")
        cont.overlays.append(box_zoom)

    if not session.hold:
        names = list(cont.plots.keys())
        cont.delplot(*names)

    return cont
Beispiel #3
0
class DataPlotter(traits.HasTraits):
    """Traits-based plot controller built around a single ``fid`` object.

    The declarations below define the plot and its data source, the current
    selection, the display controls (offsets, scale, x range) and one button
    per processing step; the ``_<name>_fired`` / ``_<name>_changed`` methods
    of this class are the corresponding handlers.
    """

    plot = traits.Instance(
        Plot
    )  # the attribute 'plot' of class DataPlotter is a trait that has to be an instance of the chaco class Plot.
    plot_data = traits.Instance(ArrayPlotData)
    # indices of all data rows, and the subset currently selected for plotting
    data_index = traits.List(traits.Int)
    data_selected = traits.List(traits.Int)

    # display controls
    y_offset = traits.Range(0.0, 20.0, value=0)
    x_offset = traits.Range(-10.0, 10.0, value=0)
    y_scale = traits.Range(1e-3, 2.0, value=1.0)
    x_range_up = traits.Float()
    x_range_dn = traits.Float()
    x_range_btn = traits.Button(label="Set range")

    reset_plot_btn = traits.Button(label="Reset plot")
    select_all_btn = traits.Button(label="All")
    select_none_btn = traits.Button(label="None")

    # basic processing
    lb = traits.Float(10.0)
    lb_btn = traits.Button(label="Apodisation")
    lb_plt_btn = traits.Button(label="Plot Apod.")
    zf_btn = traits.Button(label="Zero-fill")
    ft_btn = traits.Button(label="FT")

    # phasing
    ph_auto_btn = traits.Button(label="Auto: all")
    ph_auto_single_btn = traits.Button(label="Auto: selected")
    ph_mp = traits.Bool(True, label="Parallelise")
    ph_man_btn = traits.Button(label="Manual")
    ph_global = traits.Bool(label="apply globally")
    # Marker/plot registries keyed by position or plot name.
    # NOTE(review): these are class-level dicts that the methods mutate in
    # place, so they appear to be shared by all instances unless an instance
    # rebinds them -- confirm this is intended.
    _peak_marker_overlays = {}
    _bl_marker_overlays = {}
    _bl_range_plots = {}

    # baseline correction
    bl_cor_btn = traits.Button(label="BL correct")
    bl_sel_btn = traits.Button(label="Select points")

    # peak-picking
    peak_pick_btn = traits.Button(label="Peak-picking")
    peak_pick_clear = traits.Button(label="Clear peaks")
    deconvolute_btn = traits.Button(label="Deconvolute")
    plot_decon_btn = traits.Button(label="Show Lineshapes")
    lorgau = traits.Range(0.0, 1.0, value=0, label="Lorentzian = 0, Gaussian = 1")
    deconvolute_mp = traits.Bool(True, label="Parallelise")

    plot_ints_btn = traits.Button(label="Plot integrals")
    save_fids_btn = traits.Button(label="Save FIDs")
    save_ints_btn = traits.Button(label="Integrals -> csv")

    # working copies of the current peak / range / baseline selections
    _peaks_now = []
    _ranges_now = []
    _bl_ranges_now = []
    _bl_ranges = []
    _bl_indices = []

    # interaction-mode flags toggled by the button handlers
    # NOTE(review): class-level dict mutated in place (see note above).
    _flags = {"manphasing": False, "bl_selecting": False, "picking": False, "lineshapes_vis": False}

    # progress bar
    # progress_val = traits.Int()
    # NOTE(review): hard-coded absolute path to a developer's home directory.
    loading_animation = traits.File("/home/jeicher/Apps/NMRPy/nmrpy/loading.gif")
    busy_animation = traits.Bool(False)

    # def _metadata_handler(self):
    # return #print self.metadata_source.metadata.get('selections')

    def _bl_handler(self):
        """Handle a change of ``ranges_now`` on the baseline selection tool.

        Stores a copy of the tool's ranges and, when the padding-corrected
        ranges differ from the stored ``_bl_ranges``, records them, marks both
        ends of the most recent range and draws a horizontal line for it.
        """
        tool_ranges = self.bl_tool.ranges_now
        self._bl_ranges_now = list(tool_ranges)
        if not tool_ranges:
            return

        picked_ranges = [self.correct_padding(r) for r in tool_ranges]
        # Compare by value: ``picked_ranges`` is a freshly built list, so the
        # original identity test (``is not``) was always true.
        if self._bl_ranges == picked_ranges:
            return

        self._bl_ranges = picked_ranges
        self.bl_marker(self._bl_ranges[-1][0], colour="blue")
        self.bl_marker(self._bl_ranges[-1][1], colour="blue")

        last = len(self._bl_ranges) - 1
        x_key = "bl_range_x_%i" % last
        y_key = "bl_range_y_%i" % last
        # The line is drawn at 20 % of the current upper y limit.
        level = 0.2 * self.plot.range2d.y_range.high
        self.plot_data.set_data(x_key, self._bl_ranges[last])
        self.plot_data.set_data(y_key, [level, level])
        self._bl_range_plots[x_key] = self.plot.plot(
            (x_key, y_key),
            type="line",
            name="bl_range_%i" % last,
            line_width=0.5,
            color="blue",
        )[0]

    def _peak_handler(self):
        """Handle a change of ``peaks_now`` / ``ranges_now`` on the picking tool.

        Keeps local copies of the tool's selections, stores the converted
        peaks and padding-corrected ranges on ``self.fid`` when they changed,
        and adds a marker for the newest peak and for both ends of the newest
        range.
        """
        tool = self.picking_tool
        self._peaks_now = list(tool.peaks_now)
        self._ranges_now = list(tool.ranges_now)

        picked_peaks = self.index2ppm(np.array(tool.peaks_now))
        if picked_peaks != self.fid.peaks:
            self.fid.peaks = picked_peaks
            if self.fid.peaks:
                self.plot_marker(self.fid.peaks[-1])

        if not tool.ranges_now:
            return
        picked_ranges = [self.correct_padding(r) for r in tool.ranges_now]
        # Compare by value, as is done for the peaks above.  The original
        # identity test was always true for this freshly built list, so the
        # range markers were re-added on every peak event as well.
        if self.fid.ranges != picked_ranges:
            self.fid.ranges = picked_ranges
            self.range_marker(self.fid.ranges[-1][0], colour="blue")
            self.range_marker(self.fid.ranges[-1][1], colour="blue")

    def index2ppm(self, index):
        """Convert *index* to the x scale defined by ``sw_left`` and ``sw``.

        Computes ``sw_left - (index - padding_left) / width * sw`` from the
        plot geometry and the fid parameters and returns the result as a list
        of float16 values.
        """
        params = self.fid.params
        fraction = (index - self.plot.padding_left) / self.plot.width
        converted = np.array(params["sw_left"] - fraction * params["sw"], dtype="float16")
        return list(converted)

    def correct_padding(self, values):
        """Shift *values* by the plot's left padding expressed in ``sw`` units.

        Adds ``padding_left / width * sw`` to every value and returns the
        result as a list of float16 values.
        """
        shift = float(self.plot.padding_left) / self.plot.width * self.fid.params["sw"]
        shifted = np.array(np.array(values) + shift, dtype="float16")
        return list(shifted)

    def __init__(self, fid):
        """Build the plot data and the Plot for *fid*, then initialise the plot.

        The x axis runs from ``sw_left`` down to ``sw_left - sw`` when the
        fid's ``ft`` flag is set (origin bottom right), otherwise from 0 to
        ``at`` (origin bottom left).
        """
        super(DataPlotter, self).__init__()
        # The flag dict and the marker registries are declared on the class
        # and mutated in place by the handlers; give every instance its own
        # copies so that several plotters do not share interaction state.
        self._flags = dict(self._flags)
        self._peak_marker_overlays = {}
        self._bl_marker_overlays = {}
        self._bl_range_plots = {}

        self.fid = fid
        data = fid.data
        # A real list, so the assignment also works where range() is lazy.
        self.data_index = list(range(len(data)))

        params = fid.params
        n_points = len(data[0])
        if fid._flags["ft"]:
            self.x = np.linspace(params["sw_left"], params["sw_left"] - params["sw"], n_points)
            origin = "bottom right"
        else:
            self.x = np.linspace(0, params["at"], n_points)
            origin = "bottom left"

        self.plot_data = ArrayPlotData(x=self.x, *np.real(data))
        self.plot = Plot(self.plot_data, default_origin=origin, padding=[5, 0, 0, 35])
        self.plot_init()

    def plot_init(self, index=None):
        """Set up axes, tools and the initial line plots on ``self.plot``.

        index: data rows to plot; defaults to ``[0]``.  ``None`` replaces the
        former mutable default list, so the same list object is no longer
        shared between calls.
        """
        if index is None:
            index = [0]

        self.plot.x_axis.title = "ppm." if self.fid._flags["ft"] else "sec."

        self.zoomtool = BetterZoom(self.plot, zoom_to_mouse=False, x_min_zoom_factor=1, zoom_factor=1.5)
        self.pantool = PanTool(self.plot)
        self.phase_dragtool = PhaseDragTool(self.plot)
        self.plot.tools.append(self.zoomtool)
        self.plot.tools.append(self.pantool)
        self.plot.y_axis.visible = False

        for i in index:
            self.plot.plot(("x", "series%i" % (i + 1)), type="line", line_width=0.5, color="black")
        self.plot.request_redraw()

        self.old_y_scale = self.y_scale
        self.index_array = np.arange(len(self.fid.data))
        self.y_offsets = self.index_array * self.y_offset
        self.x_offsets = self.index_array * self.x_offset
        self.data_selected = index
        self.x_range_up = round(self.x[0], 3)
        self.x_range_dn = round(self.x[-1], 3)
        # this is necessary for phasing:
        self.plot._ps = self.fid.ps
        self.plot._data_complex = self.fid.data

    def text_marker(self, text, colour="black"):
        """Attach a bordered text label to the ``plot0`` renderer.

        The label is anchored 10 % into the current x range and at 80 % of
        the current upper y limit.  *colour* is accepted but not used.
        """
        x_range = self.plot.range2d.x_range
        x_low, x_high = x_range.low, x_range.high
        y_high = self.plot.range2d.y_range.high
        anchor = (x_low + 0.1 * (x_high - x_low), 0.8 * y_high)

        label = DataLabel(
            self.plot.plots["plot0"][0],
            data_point=anchor,
            label_position="top",
            label_text=text,
            show_label_coords=False,
            marker_visible=False,
            border_visible=True,
            arrow_visible=False,
        )
        self.plot.plots["plot0"][0].overlays.append(label)
        self.plot.request_redraw()

    def _add_arrow_marker(self, ppm, colour, height_fraction, registry):
        """Add an arrow ``DataLabel`` at ``(ppm, 0.0)`` to the ``plot0`` renderer.

        height_fraction: fraction of the plot height used as bottom padding.
        registry: dict in which the new label is stored under the key *ppm*.
        """
        renderer = self.plot.plots["plot0"][0]
        marker = DataLabel(
            renderer,
            data_point=(ppm, 0.0),
            arrow_color=colour,
            arrow_size=10,
            label_position="top",
            label_format="%(x).3f",
            padding_bottom=int(self.plot.height * height_fraction),
            marker_visible=False,
            border_visible=False,
            arrow_visible=True,
        )
        renderer.overlays.append(marker)
        registry[ppm] = marker
        self.plot.request_redraw()

    def plot_marker(self, ppm, colour="red"):
        """Mark a peak at *ppm*; the label is kept in ``_peak_marker_overlays``."""
        self._add_arrow_marker(ppm, colour, 0.1, self._peak_marker_overlays)

    def range_marker(self, ppm, colour="blue"):
        """Mark a range edge at *ppm*; the label is kept in ``_peak_marker_overlays``."""
        self._add_arrow_marker(ppm, colour, 0.25, self._peak_marker_overlays)

    def bl_marker(self, ppm, colour="blue"):
        """Mark a baseline range edge at *ppm*; the label is kept in ``_bl_marker_overlays``."""
        self._add_arrow_marker(ppm, colour, 0.25, self._bl_marker_overlays)

    def _x_range_btn_fired(self):
        """Button handler: apply ``x_range_up`` / ``x_range_dn`` to the plot.

        The two values are swapped first when ``x_range_up`` is the smaller.
        """
        if self.x_range_up < self.x_range_dn:
            self.x_range_up, self.x_range_dn = self.x_range_dn, self.x_range_up
        upper, lower = self.x_range_up, self.x_range_dn
        self.set_x_range(up=upper, dn=lower)
        self.plot.request_redraw()

    def set_x_range(self, up=None, dn=None):
        """Set the plot's index range to the interval spanned by *up* and *dn*.

        up, dn: interval bounds; when omitted, the current ``x_range_up`` /
            ``x_range_dn`` values are used.  The former defaults were the
            class-level trait definitions themselves, not values.

        The bounds are ordered here, so the range is never inverted.  The
        original non-FT branch inverted it after ``_x_range_btn_fired`` had
        swapped the two values.
        """
        if up is None:
            up = self.x_range_up
        if dn is None:
            dn = self.x_range_dn
        index_range = self.plot.index_range
        index_range.high = max(up, dn)
        index_range.low = min(up, dn)

    def _y_scale_changed(self):
        """Trait handler: forward the new ``y_scale`` to ``set_y_scale``."""
        new_scale = self.y_scale
        self.set_y_scale(scale=new_scale)

    def set_y_scale(self, scale=None):
        """Rescale the upper y limit by the ratio of the old to the new scale.

        scale: new scale factor; defaults to the current ``y_scale`` value.
            The former default was the class-level ``traits.Range``
            definition, which cannot be divided.
        """
        if scale is None:
            scale = self.y_scale
        self.plot.value_range.high /= scale / self.old_y_scale
        self.plot.request_redraw()
        # Remembered so the next call can compute a relative change.
        self.old_y_scale = scale

    def reset_plot(self):
        """Reset offsets, scale and axis ranges to their initial state.

        The index range is restored to the full extent of ``self.x``.  The
        value range is set to the min/max of the first selected series; with
        an empty selection it is left untouched instead of raising IndexError.
        """
        self.x_offset, self.y_offset = 0, 0
        self.y_scale = 1.0

        index_range = self.plot.index_range
        if self.fid._flags["ft"]:
            index_range.low, index_range.high = self.x[-1], self.x[0]
        else:
            index_range.low, index_range.high = self.x[0], self.x[-1]

        if not self.data_selected:
            return
        series = self.plot.data.arrays["series%i" % (self.data_selected[0] + 1)]
        self.plot.value_range.low = series.min()
        self.plot.value_range.high = series.max()

    def _reset_plot_btn_fired(self):
        """Button handler: announce and perform a plot reset."""
        print("resetting plot...")
        self.reset_plot()

    def _select_all_btn_fired(self):
        """Button handler: select every row of ``self.fid.data``."""
        row_count = len(self.fid.data)
        self.data_selected = list(range(row_count))

    def _select_none_btn_fired(self):
        """Button handler: clear the selection."""
        nothing_selected = []
        self.data_selected = nothing_selected

    def set_plot_offset(self, x=None, y=None):
        """Apply per-series x/y offsets to the plotted lines.

        x, y: offset increment per series.  When both are omitted nothing is
            changed; when only one is omitted, the current ``x_offset`` /
            ``y_offset`` value is used for it.  The original guard was a bare
            ``pass``, so ``None`` reached the multiplication and raised.
        """
        if x is None and y is None:
            return
        if x is None:
            x = self.x_offset
        if y is None:
            y = self.y_offset

        self.old_x_offsets = self.x_offsets
        self.old_y_offsets = self.y_offsets
        self.x_offsets = self.index_array * x
        self.y_offsets = self.index_array * y

        # Renderers are addressed as "plot0", "plot1", ... up to the number
        # of plot names containing "plot".
        plot_count = len([name for name in self.plot.plots if "plot" in name])
        for i in range(plot_count):
            self.plot.plots["plot%i" % i][0].position = [self.x_offsets[i], self.y_offsets[i]]
        self.plot.request_redraw()

    def _y_offset_changed(self):
        """Trait handler: re-apply both offsets after ``y_offset`` changed."""
        x_step, y_step = self.x_offset, self.y_offset
        self.set_plot_offset(x=x_step, y=y_step)

    def _x_offset_changed(self):
        """Trait handler: re-apply both offsets after ``x_offset`` changed."""
        x_step, y_step = self.x_offset, self.y_offset
        self.set_plot_offset(x=x_step, y=y_step)

    # for some mysterious reason, selecting new data to plot doesn't retain the plot offsets even if you set them explicitly
    def _data_selected_changed(self):
        """Trait handler: redraw the plot for the new selection.

        Ends an active manual-phasing or picking session, deletes every
        existing plot, plots each selected series and, if lineshapes were
        visible, redraws them.
        """
        if self._flags["manphasing"]:
            self.end_man_phasing()
        if self._flags["picking"]:
            self.end_picking()

        current_names = tuple(self.plot.plots)
        self.plot.delplot(*current_names)
        self.plot.request_redraw()

        for row in self.data_selected:
            # The position is passed explicitly; the original author noted
            # ("FIX: this isn't working") that the offsets are not retained.
            offset = [self.x_offsets[row], self.y_offsets[row]]
            self.plot.plot(
                ("x", "series%i" % (row + 1)),
                type="line",
                line_width=0.5,
                color="black",
                position=offset,
            )
        # self.reset_plot() # this is due to the fact that the plot automatically resets anyway

        if self._flags["lineshapes_vis"]:
            self.clear_lineshapes()
            self.plot_deconv()

    # processing buttons

    # plot the current apodisation function based on lb, and do apodisation
    # =================================================
    def _lb_plt_btn_fired(self):
        """Button handler: toggle the "lb1" apodisation curve (non-FT data only)."""
        if self.fid._flags["ft"]:
            return
        if "lb1" not in self.plot.plots:
            self.plot_lb()
            return
        # Curve is currently shown: remove it.
        self.plot.delplot("lb1")
        self.plot.request_redraw()

    def plot_lb(self):
        """Plot the exponential apodisation curve for the first selected row.

        The curve is ``exp(-pi * n * lb / sw_hz)`` scaled by the first data
        point; its real part is stored as "lb1" and drawn as a blue line.
        Does nothing when the fid's ``ft`` flag is set.
        """
        if self.fid._flags["ft"]:
            return
        trace = self.fid.data[self.data_selected[0]]
        rate = self.lb / self.fid.params["sw_hz"]
        envelope = np.exp(-np.pi * np.arange(len(trace)) * rate) * trace[0]
        self.plot_data.set_data("lb1", np.real(envelope))
        self.plot.plot(("x", "lb1"), type="line", name="lb1", line_width=1, color="blue")
        self.plot.request_redraw()

    def _lb_changed(self):
        """Trait handler: recompute the "lb1" curve data for the new ``lb``.

        Only the data array is updated; no plot is created here.  Does
        nothing when the fid's ``ft`` flag is set.
        """
        if self.fid._flags["ft"]:
            return
        trace = self.fid.data[self.data_selected[0]]
        rate = self.lb / self.fid.params["sw_hz"]
        envelope = np.exp(-np.pi * np.arange(len(trace)) * rate) * trace[0]
        self.plot_data.set_data("lb1", np.real(envelope))

    def _lb_btn_fired(self):
        """Button handler: call ``fid.emhz`` with ``lb`` and refresh (non-FT only)."""
        already_transformed = self.fid._flags["ft"]
        if not already_transformed:
            self.fid.emhz(self.lb)
            self.update_plot_data_from_fid()

    def _zf_btn_fired(self):
        """Button handler: call ``fid.zf`` and refresh the plot (non-FT only).

        A visible "lb1" curve is removed first.
        """
        if self.fid._flags["ft"]:
            return
        lb_curve_shown = "lb1" in self.plot.plots
        if lb_curve_shown:
            self.plot.delplot("lb1")
        self.fid.zf()
        self.update_plot_data_from_fid()

    def _ft_btn_fired(self):
        """Button handler: call ``fid.ft`` and rebuild the plot.

        A visible "lb1" curve is removed in any case; the transform itself is
        skipped when the fid's ``ft`` flag is already set.
        """
        lb_curve_shown = "lb1" in self.plot.plots
        if lb_curve_shown:
            self.plot.delplot("lb1")
        if self.fid._flags["ft"]:
            return
        self.fid.ft()
        self.update_plot_data_from_fid()
        # A fresh Plot with the origin at the bottom right replaces the old one.
        rebuilt = Plot(self.plot_data, default_origin="bottom right", padding=[5, 0, 0, 35])
        self.plot = rebuilt
        self.plot_init(index=self.data_selected)
        self.reset_plot()

    def _ph_auto_btn_fired(self):
        """Button handler: automatic phase correction of all data (FT only).

        Aborts with a message when any value that ``np.iscomplex`` reports as
        non-complex is not of type ``np.complex128``.
        """
        if not self.fid._flags["ft"]:
            return
        # as np.iscomplex returns False for 0+0j, we need to check manually
        suspects = self.fid.data[np.logical_not(np.iscomplex(self.fid.data))]
        for value in suspects:
            if type(value) != np.complex128:
                print("Cannot perform phase correction on non-imaginary data.")
                return
        self.fid.phase_auto(mp=self.ph_mp, discard_imaginary=False)
        self.update_plot_data_from_fid()

    def _ph_auto_single_btn_fired(self):
        """Button handler: automatic phase correction of the selected rows (FT only).

        Aborts with a message when any value that ``np.iscomplex`` reports as
        non-complex is not of type ``np.complex128``.
        """
        if not self.fid._flags["ft"]:
            return
        # as np.iscomplex returns False for 0+0j, we need to check manually
        suspects = self.fid.data[np.logical_not(np.iscomplex(self.fid.data))]
        for value in suspects:
            if type(value) != np.complex128:
                print("Cannot perform phase correction on non-imaginary data.")
                return
        for row in self.data_selected:
            self.fid._phase_area_single(row)
        self.update_plot_data_from_fid()

    def _ph_man_btn_fired(self):
        """Button handler: toggle manual (drag) phasing of the selection (FT only).

        Entering the mode colours the plots red, replaces the plot tools with
        a ``PhaseDragTool`` and hands the selected complex rows to the plot;
        a second press ends the mode.
        """
        if not self.fid._flags["ft"]:
            return
        # as np.iscomplex returns False for 0+0j, we need to check manually
        suspects = self.fid.data[np.logical_not(np.iscomplex(self.fid.data))]
        for value in suspects:
            if type(value) != np.complex128:
                print("Cannot perform phase correction on non-imaginary data.")
                return

        if self._flags["manphasing"]:
            self.end_man_phasing()
            return

        self._flags["manphasing"] = True
        self.text_marker("Drag to phase:\n up/down - p0\n left/right - p1")
        self.change_plot_colour(colour="red")
        self.disable_plot_tools()
        self.plot._data_selected = self.data_selected
        self.plot._data_complex = self.fid.data[np.array(self.data_selected)]
        self.plot.tools.append(PhaseDragTool(self.plot))

    def end_man_phasing(self):
        """Leave manual phasing mode and store the result.

        Reads ``p0`` / ``p1`` from the first plot tool.  With ``ph_global``
        set they are applied to all data through ``fid.ps``; otherwise the
        rows held by the plot are copied back into ``fid.data``.  The regular
        plot tools are restored afterwards.
        """
        self._flags["manphasing"] = False
        self.remove_all_overlays()
        self.change_plot_colour(colour="black")
        print("p0: %f, p1: %f" % (self.plot.tools[0].p0, self.plot.tools[0].p1))
        if not self.ph_global:
            for row, phased in zip(self.plot._data_selected, self.plot._data_complex):
                self.fid.data[row] = phased
        else:
            drag_tool = self.plot.tools[0]
            self.fid.data = self.fid.ps(self.fid.data, p0=drag_tool.p0, p1=drag_tool.p1)
            self.update_plot_data_from_fid()
        self.disable_plot_tools()
        self.enable_plot_tools()

    def remove_extra_overlays(self):
        """Keep only the first plot overlay and clear the ``plot0`` overlays."""
        first_overlay = self.plot.overlays[0]
        self.plot.overlays = [first_overlay]
        self.plot.plots["plot0"][0].overlays = []

    def remove_all_overlays(self):
        """Clear the overlays of the plot and of the ``plot0`` renderer."""
        self.plot.overlays = list()
        main_renderer = self.plot.plots["plot0"][0]
        main_renderer.overlays = list()

    def disable_plot_tools(self):
        """Detach every tool from the plot."""
        self.plot.tools = list()

    def enable_plot_tools(self):
        """Attach the stored zoom tool and pan tool to the plot, in that order."""
        for tool in (self.zoomtool, self.pantool):
            self.plot.tools.append(tool)

    def change_plot_colour(self, colour="black"):
        """Set the colour of the first renderer of every plot."""
        for name in self.plot.plots:
            first_renderer = self.plot.plots[name][0]
            first_renderer.color = colour

    def _bl_sel_btn_fired(self):
        """Button handler: toggle baseline range selection (FT data only).

        Entering the mode adds a scatter plot ("bl_plot") of the first
        selected series, replaces the plot tools with a ``BlSelectTool``,
        adds selection overlays, restores stored baseline markers and range
        lines, and registers ``_bl_handler`` on the tool's ``ranges_now``.
        A second press ends the mode.
        """
        if not self.fid._flags["ft"]:
            return
        if self._flags["bl_selecting"]:
            self.end_bl_select()
            return

        self._flags["bl_selecting"] = True
        series_name = "series%i" % (self.data_selected[0] + 1)
        self.plot.plot(
            ("x", series_name),
            name="bl_plot",
            type="scatter",
            alpha=1.0,
            line_width=0,
            selection_line_width=0,
            marker_size=2,
            selection_marker_size=2,
            selection_color="black",  # change this to make selected points visible
            color="black",
        )

        self.text_marker("Select ranges for baseline correction:\n drag right - select range")
        self.disable_plot_tools()

        select_tool = BlSelectTool(
            self.plot.plots["plot0"][0], metadata_name="selections", append_key=KeySpec(None, "control")
        )
        self.plot.tools.append(select_tool)

        range_overlay = RangeSelectionOverlay(
            component=self.plot.plots["bl_plot"][0], metadata_name="selections", axis="index", fill_color="blue"
        )
        self.plot.overlays.append(range_overlay)

        inspector = LineInspector(
            component=self.plot, axis="index_x", inspect_mode="indexed", write_metadata=True, color="blue"
        )
        self.plot.overlays.append(inspector)

        if self._bl_marker_overlays:
            self.plot.plots["plot0"][0].overlays = list(self._bl_marker_overlays.values())

        # The stored range plots are redrawn from their data rather than
        # re-added: the original author found that re-added plot objects could
        # not be deleted visually afterwards.
        for n in range(len(self._bl_ranges)):
            self.plot.plot(
                ("bl_range_x_%i" % n, "bl_range_y_%i" % n),
                type="line",
                name="bl_range_%i" % n,
                line_width=0.5,
                color="blue",
            )

        self.plot.request_redraw()
        self.bl_tool = self.plot.tools[0]
        self.bl_tool.on_trait_change(self._bl_handler, name=["ranges_now"])

    def end_bl_select(self):
        """Leave baseline selection mode and store the selected points.

        ``fid.bl_points`` receives the indices of ``self.x`` that lie strictly
        inside exactly one of the stored ranges.  All plots whose name
        contains "bl_" are removed and the regular tools are restored.
        """
        self._flags["bl_selecting"] = False
        self._bl_indices = [(low < self.x) * (self.x < high) for low, high in self._bl_ranges]
        coverage = sum(self._bl_indices, 0)
        self.fid.bl_points = np.where(coverage == 1)[0]
        # self.plot.delplot('bl_plot')
        baseline_plots = [name for name in self.plot.plots if "bl_" in name]
        for name in baseline_plots:
            self.plot.delplot(name)
        self.plot.request_redraw()
        self.remove_extra_overlays()
        self.disable_plot_tools()
        self.enable_plot_tools()

    def _bl_cor_btn_fired(self):
        """Button handler: call ``fid.bl_fit`` and refresh the plot (FT only).

        An active baseline selection is ended first.
        """
        if not self.fid._flags["ft"]:
            return
        selecting = self._flags["bl_selecting"]
        if selecting:
            self.end_bl_select()
        self.fid.bl_fit()
        self.remove_all_overlays()
        self.update_plot_data_from_fid()

    def _peak_pick_btn_fired(self):
        """Button handler: toggle interactive peak/range picking (FT data only).

        Entering the mode resets the plot, restores stored markers, replaces
        the plot tools with a ``PeakSelectTool`` seeded with the remembered
        selections, adds the picker and range overlays, and registers
        ``_peak_handler`` on the tool.  A second press ends the mode.
        """
        if not self.fid._flags["ft"]:
            return
        if self._flags["picking"]:
            self.end_picking()
            return
        self._flags["picking"] = True

        self.reset_plot()
        if self._peak_marker_overlays:
            self.plot.plots["plot0"][0].overlays = list(self._peak_marker_overlays.values())
        self.text_marker("Peak-picking:\n left click - select peak\n drag right - select range")
        self.disable_plot_tools()

        select_tool = PeakSelectTool(self.plot.plots["plot0"][0], left_button_selects=True)
        select_tool.peaks_now = self._peaks_now
        select_tool.ranges_now = self._ranges_now

        picker = PeakPicker(
            component=self.plot,
            axis="index_x",
            inspect_mode="indexed",
            metadata_name="peaks",
            write_metadata=True,
            color="blue",
        )
        self.plot.overlays.append(picker)
        self.plot.tools.append(select_tool)

        range_overlay = RangeSelectionOverlay(
            component=self.plot.plots["plot0"][0], metadata_name="selections", axis="index", fill_color="blue"
        )
        self.plot.overlays.append(range_overlay)

        # Set up the trait handler for peak/range selections
        self.picking_tool = self.plot.tools[0]
        self.picking_tool.on_trait_change(self._peak_handler, name=["peaks_now", "ranges_now"])

    def end_picking(self):
        """Leave picking mode.

        When both peaks and ranges exist, the invalid ones are pruned;
        otherwise everything is cleared.  Overlays are removed and the
        regular plot tools restored.
        """
        self._flags["picking"] = False
        has_both = self.fid.peaks and self.fid.ranges
        if not has_both:
            self.clear_all_peaks_ranges()
        else:
            self.clear_invalid_peaks_ranges()

        self.remove_all_overlays()
        self.disable_plot_tools()
        self.enable_plot_tools()
        self.plot.request_redraw()

    def clear_invalid_peaks_ranges(self):
        """Drop peaks that lie in no range and ranges that contain no peak.

        The matching entries are removed from ``_peak_marker_overlays``, from
        ``fid.peaks`` / ``fid.ranges`` and from the remembered tool state.
        """
        orphan_peaks = self.peaks_outside_of_ranges()
        empty_ranges = self.ranges_without_peaks()

        # remove uncoupled peak markers and empty range markers
        for i in orphan_peaks:
            self._peak_marker_overlays.pop(self.fid.peaks[i])
        for i in empty_ranges:
            for edge in self.fid.ranges[i]:
                self._peak_marker_overlays.pop(edge)

        # remove uncoupled peaks and empty ranges
        self.fid.peaks = [p for i, p in enumerate(self.fid.peaks) if i not in orphan_peaks]
        self._peaks_now = [p for i, p in enumerate(self._peaks_now) if i not in orphan_peaks]
        self.fid.ranges = [r for i, r in enumerate(self.fid.ranges) if i not in empty_ranges]
        self._ranges_now = [r for i, r in enumerate(self._ranges_now) if i not in empty_ranges]

    def clear_all_peaks_ranges(self):
        """Forget every picked peak and range and remove their markers.

        ``picking_tool`` is only assigned once a picking session has been
        started, so its state is reset only when the attribute exists.
        Previously, clearing before the first session raised AttributeError
        and left the markers in place.
        """
        self.fid.peaks = []
        self.fid.ranges = []
        picking_tool = getattr(self, "picking_tool", None)
        if picking_tool is not None:
            picking_tool.peaks_now = []
            picking_tool.ranges_now = []
        self._peak_marker_overlays = {}
        self.plot.plots["plot0"][0].overlays = []
        self.plot.request_redraw()
        print("Selected peaks and ranges cleared.")

    def peaks_ranges_matrix(self):
        """Return a boolean matrix: entry [r, p] is True when peak p lies in range r.

        A range is taken as the closed interval ``[range[0], range[1]]``.
        The peaks are converted to an array once, so the comparison is
        element-wise whatever the scalar type of the range bounds is.  The
        result always has shape ``(len(ranges), len(peaks))``, including when
        either is empty.
        """
        peaks = np.asarray(self.fid.peaks)
        ranges = self.fid.ranges
        rows = [(peaks >= rng[0]) * (peaks <= rng[1]) for rng in ranges]
        return np.array(rows, dtype=bool).reshape(len(ranges), peaks.size)

    def peaks_outside_of_ranges(self):
        """Return the indices of the peaks that fall into none of the ranges."""
        hits_per_peak = self.peaks_ranges_matrix().sum(0)
        unmatched = np.where(hits_per_peak == 0)
        peak_indices = np.arange(len(self.fid.peaks))
        return peak_indices[unmatched]

    def ranges_without_peaks(self):
        """Return the indices of the ranges that contain no peak."""
        hits_per_range = self.peaks_ranges_matrix().sum(1)
        unmatched = np.where(hits_per_range == 0)
        range_indices = np.arange(len(self.fid.ranges))
        return range_indices[unmatched]

    def _peak_pick_clear_fired(self):
        """Button handler: clear all picked peaks and ranges."""
        clear = self.clear_all_peaks_ranges
        clear()

    def busy(self):
        """Toggle the ``busy_animation`` flag."""
        self.busy_animation = not self.busy_animation

    def _deconvolute_btn_fired(self):
        """Button handler: deconvolute the picked peaks and toggle the lineshapes.

        Ends an active picking session first.  Without peaks only a message
        is printed.  Otherwise ``fid.real`` and ``fid.deconv`` are called and
        the lineshape display is toggled.
        """
        # self.busy()
        if self._flags["picking"]:
            self.end_picking()
        if self.fid.peaks == []:
            print("No peaks selected.")
            return
        print("Imaginary components discarded.")
        self.fid.real()
        lineshape_mix = self.fid._flags["gl"]
        self.fid.deconv(gl=lineshape_mix, mp=self.deconvolute_mp)
        # self.busy()
        self._plot_decon_btn_fired()

    def clear_lineshapes(self):
        """Delete every plot whose name contains "lineshape" or "residual"."""
        for keyword in ("lineshape", "residual"):
            matching = [name for name in self.plot.plots if keyword in name]
            for name in matching:
                self.plot.delplot(name)
        self.plot.request_redraw()

    def _plot_decon_btn_fired(self):
        """Button handler: toggle the visibility of the fitted lineshapes."""
        if not self._flags["lineshapes_vis"]:
            self._flags["lineshapes_vis"] = True
            self.plot_deconv()
            return
        self.clear_lineshapes()
        self._flags["lineshapes_vis"] = False

    def plot_deconv(self):
        """Plot the fitted lineshapes of the first selected row.

        Draws the sum of all fitted peaks (blue), the residual between the
        data and that sum (red) and every individual peak (green).  Returns
        without plotting when nothing is selected or no fits exist.

        Changes from the original: the unused nested helpers and locals are
        gone, the sum is computed once, the data is no longer reversed twice,
        and an empty selection no longer raises IndexError.
        """
        if not self.data_selected:
            return
        index = self.data_selected[0]
        if len(self.fid.fits) == 0:
            return

        spectrum = self.fid.data[index]
        x = np.arange(len(spectrum))

        # One reversed curve per fitted peak, over all ranges of this row.
        peakplots = []
        for irange in self.fid.fits[index]:
            for ipeak in irange:
                peakplots.append(f_pk(ipeak, x)[::-1])
        total = sum(peakplots, 0)

        # plot sum of all lines
        total_name = "lineshapes_%i" % index
        self.plot_data.set_data(total_name, total)
        self.plot.plot(("x", total_name), type="line", name=total_name, line_width=0.5, color="blue")

        # plot residual
        residual_name = "residuals_%i" % index
        self.plot_data.set_data(residual_name, spectrum - total)
        self.plot.plot(("x", residual_name), type="line", name=residual_name, line_width=0.5, color="red")

        # plot all individual lines
        for peak, curve in enumerate(peakplots):
            peak_name = "lineshape_%i_%i" % (index, peak)
            self.plot_data.set_data(peak_name, curve)
            self.plot.plot(("x", peak_name), type="line", name=peak_name, line_width=0.5, color="green")
        self.plot.request_redraw()

    def _lorgau_changed(self):
        """Trait handler: copy the new ``lorgau`` value into the fid's "gl" flag."""
        mix = self.lorgau
        self.fid._flags["gl"] = mix

    def update_plot_data_from_fid(self, index=None):
        """Copy the x axis and the real part of the fid data into the plot data.

        index: row to refresh; ``None`` (the default) refreshes every row in
            ``self.index_array``.

        The x axis runs from ``sw_left`` down to ``sw_left - sw`` when the
        fid's ``ft`` flag is set, otherwise from 0 to ``at``.
        """
        params = self.fid.params
        n_points = len(self.fid.data[0])
        if self.fid._flags["ft"]:
            self.x = np.linspace(params["sw_left"], params["sw_left"] - params["sw"], n_points)
        else:
            self.x = np.linspace(0, params["at"], n_points)
        self.plot_data.set_data("x", self.x)

        # Identity test rather than ``== None``.
        rows = self.index_array if index is None else [index]
        for i in rows:
            self.plot_data.set_data("series%i" % (i + 1), np.real(self.fid.data[i]))
        self.plot.request_redraw()

    def _plot_ints_btn_fired(self):
        """Button handler: call ``fid.plot_integrals`` when integrals exist."""
        integral_count = len(self.fid.integrals)
        if integral_count == 0:
            print("self.integrals does not exist")
            return
        self.fid.plot_integrals()

    def _save_fids_btn_fired(self):
        """Button handler: delegate to ``fid.savefid_dict``."""
        fid = self.fid
        fid.savefid_dict()

    def _save_ints_btn_fired(self):
        """Button handler: delegate to ``fid.save_integrals_csv``."""
        fid = self.fid
        fid.save_integrals_csv()

    def default_traits_view(self):
        """Assemble and return the TraitsUI ``View`` for this object.

        The root group has two children: the ``data_index`` table placed next
        to the plot, and a bordered strip of control groups labelled
        "Plotting", "Basic", "Baseline correction", "Phase correction",
        "Peak-picking and deconvolution" and "Save".  The view is resizable,
        titled "NMRPy" and handled by a fresh ``TC_Handler``.
        """

        def unlabeled(name):
            # Shorthand for the many items that are shown without a label.
            return Item(name, show_label=False)

        # ---- data_index table beside the plot ----
        display_row = Group(
            Item(
                "data_index",
                editor=TabularEditor(
                    show_titles=False,
                    selected="data_selected",
                    editable=False,
                    multi_select=True,
                    adapter=MultiSelectAdapter(),
                ),
                width=0.02,
                show_label=False,
                has_focus=True,
            ),
            Item("plot", editor=ComponentEditor(), show_label=False),
            padding=0,
            show_border=False,
            orientation="horizontal",
        )

        # ---- "Plotting" ----
        selection_buttons = Group(
            unlabeled("select_all_btn"),
            unlabeled("select_none_btn"),
            unlabeled("reset_plot_btn"),
            orientation="vertical",
        )
        offsets_and_range = Group(
            Item("y_offset"),
            Item("x_offset"),
            Item("y_scale", show_label=True),
            Group(
                unlabeled("x_range_btn"),
                unlabeled("x_range_up"),
                unlabeled("x_range_dn"),
                orientation="horizontal",
            ),
            orientation="vertical",
        )
        plotting_box = Group(
            selection_buttons,
            offsets_and_range,
            orientation="horizontal",
            show_border=True,
            label="Plotting",
        )

        # ---- "Basic" and "Baseline correction" ----
        basic_box = Group(
            Item("lb", show_label=False, format_str="%.2f Hz"),
            unlabeled("lb_btn"),
            unlabeled("lb_plt_btn"),
            unlabeled("zf_btn"),
            unlabeled("ft_btn"),
            orientation="horizontal",
            show_border=True,
            label="Basic",
        )
        baseline_box = Group(
            unlabeled("bl_cor_btn"),
            unlabeled("bl_sel_btn"),
            orientation="horizontal",
            show_border=True,
            label="Baseline correction",
        )
        basic_and_baseline = Group(
            basic_box,
            baseline_box,
            orientation="horizontal",
        )

        # ---- "Phase correction" ----
        auto_phase = Group(
            unlabeled("ph_auto_btn"),
            unlabeled("ph_auto_single_btn"),
            Item("ph_mp", show_label=True),
            orientation="horizontal",
            show_border=True,
        )
        manual_phase = Group(
            unlabeled("ph_man_btn"),
            Item("ph_global", show_label=True),
            orientation="horizontal",
            show_border=True,
        )
        phase_box = Group(
            auto_phase,
            manual_phase,
            orientation="horizontal",
            show_border=True,
            label="Phase correction",
        )
        processing_column = Group(basic_and_baseline, phase_box)

        # ---- "Peak-picking and deconvolution" and "Save" ----
        picking_row = Group(
            unlabeled("peak_pick_btn"),
            unlabeled("peak_pick_clear"),
            unlabeled("deconvolute_btn"),
            Item(
                "lorgau",
                show_label=False,
                editor=RangeEditor(low_label="Lorentz", high_label="Gauss"),
            ),
            Item("deconvolute_mp", show_label=True),
            orientation="horizontal",
            show_border=False,
        )
        result_plot_row = Group(
            unlabeled("plot_decon_btn"),
            unlabeled("plot_ints_btn"),
            orientation="horizontal",
            show_border=False,
        )
        peak_box = Group(
            picking_row,
            result_plot_row,
            orientation="vertical",
            show_border=True,
            label="Peak-picking and deconvolution",
        )
        save_box = Group(
            unlabeled("save_fids_btn"),
            unlabeled("save_ints_btn"),
            show_border=True,
            label="Save",
            orientation="horizontal",
        )
        analysis_column = Group(
            peak_box,
            save_box,
            orientation="vertical",
        )

        # ---- whole control strip, then the view itself ----
        controls = Group(
            plotting_box,
            processing_column,
            analysis_column,
            show_border=True,
            orientation="horizontal",
        )
        traits_view = View(
            Group(display_row, controls),
            width=1.0,
            height=1.0,
            resizable=True,
            handler=TC_Handler(),
            title="NMRPy",
        )
        return traits_view
Beispiel #4
0
class SamplePlotter(sample.SampleModel):
    """Sample model that draws its elements as scatter plots in an Enable window.

    One scatter renderer is kept per key of ``self.transformed_elements``.
    Its coordinates are stored in ``pd`` under the names ``"x" + key``,
    ``"y" + key`` and ``"z" + key``.
    """
    pd = Instance(ArrayPlotData)
    plot = Instance(Plot)
    window = Instance(Window)

    def __init__(self, parent):
        """Create the plot inside *parent* and expose its control as ``widget``.

        NOTE(review): the base-class ``__init__`` is not called here --
        confirm that this is intentional.
        """
        self.parent = parent
        self.pd = ArrayPlotData()
        self.window = self.create_plot(parent)
        self.widget = self.window.control

    def load_file(self):
        """Ask the user for a model file, load it and show the top view.

        Does nothing when the dialog returns an empty file name (the user
        cancelled).

        Raises
        ------
        NotImplementedError
            If the selected file does not have a ``.cfg`` extension.
        """
        file_types = [("AtomEye", "cfg"),
                      ("Crystallographic Information File", "cif")]
        file_type_string = ";;".join(
            label + " (*." + extension + ")" for label, extension in file_types)
        filename = QtGui.QFileDialog.getOpenFileName(self.parent,
                                                     'Open Model File', '.',
                                                     file_type_string)
        selected = filename[0]
        if not selected:
            # Cancelled dialog: there is no file, so there is nothing to
            # load and no unsupported format to complain about.
            return
        if path.splitext(selected)[1].lower() == '.cfg':
            self.loadCfg(selected)
        else:
            raise NotImplementedError("Only cfg file import is implemented right now.")
        self._update_coordinates()
        self.top_plot()

    def create_plot(self, parent):
        """Build ``self.plot`` with pan and zoom tools and wrap it in a ``Window``.

        Returns the ``Window`` whose component is the new plot.
        """
        self.plot = Plot(self.pd, padding=[40, 10, 0, 40], border_visible=True)
        self.plot.legend.visible = True

        # Attach some tools to the plot
        self.plot.tools.append(PanTool(self.plot))
        zoom = ZoomTool(component=self.plot, tool_mode="box", always_on=False)
        self.plot.overlays.append(zoom)

        # This Window object bridges the Enable and Qt4 worlds, and handles events
        # and drawing.  We can create whatever hierarchy of nested containers we
        # want, as long as the top-level item gets set as the .component attribute
        # of a Window.
        return Window(parent, -1, component=self.plot)

    def _replot_projection(self, vertical):
        """Redraw every element as a scatter of ``x`` against *vertical*.

        *vertical* is the one-letter prefix (``"y"`` or ``"z"``) of the data
        series used for the value axis.
        """
        for key in self.transformed_elements:
            # Replace, rather than stack on top of, a renderer of that name.
            if key in self.plot.plots:
                self.plot.delplot(key)
            self.plot.plot(("x" + key, vertical + key), name=key, type="scatter")

    def top_plot(self):
        """Show the x/y projection of every element."""
        self._replot_projection("y")

    def front_plot(self):
        """Show the x/z projection of every element."""
        self._replot_projection("z")

    # NOTE(review): the listener is registered with an empty trait name --
    # confirm which change events, if any, this is meant to react to.
    @on_trait_change("")
    def _update_coordinates(self):
        """Copy columns 0, 1 and 2 of each element array into ``pd`` as x, y, z."""
        for key, coordinates in self.transformed_elements.items():
            self.pd.set_data("x" + key, coordinates[:, 0])
            self.pd.set_data("y" + key, coordinates[:, 1])
            self.pd.set_data("z" + key, coordinates[:, 2])
Beispiel #5
0
class MainPlot(HasTraits):
    """Plot of datetime-stamped series with a calendar-scaled bottom axis.

    Series are first registered with :meth:`addData` under a data key, then
    drawn with :meth:`showData` and removed again with :meth:`hideData`.
    After either of the last two the visible range is refitted to the data
    still on the plot.
    """

    plot = Instance(Plot)

    traits_view = View(
        Item('plot', editor=ComponentEditor(), show_label=False),
        width=800, height=600, resizable=True,
        title="Plot")

    def __init__(self):
        """Create the plot, its time axis, grids and pan/zoom tools.

        NOTE(review): ``HasTraits.__init__`` is not called here -- confirm
        that this is intentional.
        """
        # Registered series: self.data[dataKey] = [x, y], holding the values
        # returned by ArrayPlotData.set_data (used as data names by showData).
        self.data = {}

        # Index into COLOR_PALETTE of the next colour to hand out.
        self.colNr = 0

        self.plotdata = ArrayPlotData()

        self.plot = Plot(self.plotdata)
        self.plot.legend.visible = True

        self.__existingData = []  # legend labels

        # Time axis
        time_axis = PlotAxis(self.plot, orientation="bottom", tick_generator=ScalesTickGenerator(scale=CalendarScaleSystem()))
        self.plot.x_axis = time_axis

        _hgrid, vgrid = add_default_grids(self.plot)
        self.plot.x_grid = None
        vgrid.tick_generator = time_axis.tick_generator

        # Pan tool, not constrained to a single direction.
        self.plot.tools.append(PanTool(self.plot, constrain=False))

        # Box zoom tool that is not always on.
        self.plot.overlays.append(
            ZoomTool(self.plot, tool_mode="box", always_on=False))

        # Draw a throw-away plot of the module-level testTime/testData and
        # delete it again.
        # NOTE(review): presumably this initialises the axes and ranges --
        # confirm; testTime and testData are defined outside this class.
        self.plot.plot(
            (
                self.plotdata.set_data(name=None, new_data=self._to_timestamps(testTime), generate_name=True),
                self.plotdata.set_data(name=None, new_data=testData, generate_name=True),
            ),
            name='temp')
        self.plot.request_redraw()
        self.plot.delplot('temp')

    @staticmethod
    def _to_timestamps(moments):
        """Return ``time.mktime`` seconds for each datetime-like in *moments*."""
        return [time.mktime(moment.timetuple()) for moment in moments]

    def addData(self, _time, y, dataKey, overwriteData=False):
        """Register a series under *dataKey* without drawing it.

        An existing entry for *dataKey* is kept unless *overwriteData* is
        true, in which case it is replaced.

        Parameters
        ----------
        _time : sequence of objects with a ``timetuple()`` method
            Time stamps of the samples.
        y : sequence
            Sample values.
        dataKey : hashable
            Key under which the series is stored in ``self.data``.
        overwriteData : bool
            Replace an existing entry for *dataKey*.
        """
        if dataKey not in self.data or overwriteData:
            x = self._to_timestamps(_time)
            self.data[dataKey] = [
                self.plotdata.set_data(name=None, new_data=x, generate_name=True),
                self.plotdata.set_data(name=None, new_data=y, generate_name=True),
            ]

    def showData(self, legendLabel, dataKey):
        """Draw the series registered under *dataKey*, labelled *legendLabel*.

        Raises
        ------
        Exception
            If no series was registered under *dataKey*.
        """
        if dataKey not in self.data:
            raise Exception('No entry for that dataKey plz first use addData')

        x_name, y_name = self.data[dataKey]
        self.plot.plot((x_name, y_name), name=legendLabel, color=self._getColor())

        self._fit_range_to_data()
        self.plot.request_redraw()

    def _fit_range_to_data(self):
        """Set the plot's 2-D range to the extent of the plotted data, if any."""
        zoomrange = self._get_zoomRange()
        if zoomrange is not None:
            self.plot.range2d.set_bounds(zoomrange[0], zoomrange[1])

    def _get_zoomRange(self):
        """Return the bounding box of all plotted data.

        Returns
        -------
        tuple or None
            ``((index_min, value_min), (index_max, value_max))`` over every
            renderer on the plot, or ``None`` when there is no data to
            measure.
        """
        index_arrays = []
        value_arrays = []
        for renderers in self.plot.plots.values():
            for renderer in renderers:
                index_arrays.append(renderer.index.get_data())
                value_arrays.append(renderer.value.get_data())

        # Empty arrays have no extremes; leave them out of the measurement.
        index_arrays = [arr for arr in index_arrays if len(arr)]
        value_arrays = [arr for arr in value_arrays if len(arr)]
        if not index_arrays or not value_arrays:
            return None

        # Plain min/max over the per-array extremes: an extreme of exactly 0
        # is a legitimate bound and must not be mistaken for "not set yet".
        ind_min = min(min(arr) for arr in index_arrays)
        ind_max = max(max(arr) for arr in index_arrays)
        val_min = min(min(arr) for arr in value_arrays)
        val_max = max(max(arr) for arr in value_arrays)
        return ((ind_min, val_min), (ind_max, val_max))

    def hideData(self, legendLabel):
        """Remove the plot named *legendLabel* and refit the visible range."""
        self.plot.delplot(legendLabel)
        self._fit_range_to_data()
        self.plot.request_redraw()

    def _getColor(self):
        """Return the next colour of ``COLOR_PALETTE`` as a tuple, cycling."""
        color = tuple(COLOR_PALETTE[self.colNr])
        self.colNr = (self.colNr + 1) % len(COLOR_PALETTE)
        return color
Beispiel #6
0
class CameraImage(HasTraits):
    """Image display with an optional colormap, rotation and a text overlay.

    The image lives in ``data``; assigning to it re-renders the plot.  Text
    lines for the overlay are managed with :meth:`hud`.
    """

    data = Array()
    data_store = Instance(ArrayPlotData)
    plot = Instance(Plot)
    hud_overlay = Instance(PlotLabel)

    # How many 90-degree steps the image is rotated by before it is
    # displayed; must be between 0 and 3.
    rotate = Range(0, 3)

    # Display colormap.  None means the image's natural colors (RGB data)
    # or grayscale (monochrome data); any other value coerces the image to
    # monochrome.
    cmap = Enum(None, gray, bone, pink, jet, isoluminant, awesome)

    view = View(Item('plot', show_label=False, editor=ComponentEditor()))

    def __init__(self, **traits):
        """Set up the data store, the image plot and the HUD label."""
        super(CameraImage, self).__init__(**traits)
        self._dims = (200, 320)
        self.data_store = ArrayPlotData(image=self.data)
        self._hud = {}
        self.plot = Plot(self.data_store)

        # Initial rendering of the image.
        self._image = self.plot.img_plot(
            'image',
            name='camera_image',
            colormap=fix(gray, (0, 255)),
        )[0]
        self._apply_aspect_ratio()

        label = PlotLabel(
            text='',
            component=self.plot,
            hjustify='left',
            overlay_position='inside bottom',
            color='white',
        )
        self.hud_overlay = label
        self.plot.overlays.append(self.hud_overlay)

    def _apply_aspect_ratio(self):
        """Set the plot's aspect ratio to ``_dims[1] / _dims[0]``."""
        self.plot.aspect_ratio = float(self._dims[1]) / self._dims[0]

    def _data_default(self):
        """Return an all-zero ``uint8`` image of shape ``self._dims``."""
        return N.zeros(self._dims, dtype=N.uint8)

    def _data_changed(self, value):
        """Convert, rotate and store *value*, re-plotting if its shape changed."""
        has_channels = len(value.shape) != 2
        if has_channels and self.cmap is not None:
            # A selected colormap coerces the image to monochrome, using the
            # standard NTSC conversion formula.
            red = value[..., 0]
            green = value[..., 1]
            blue = value[..., 2]
            value = N.array(0.2989 * red + 0.5870 * green + 0.1140 * blue)
        value = N.rot90(value, self.rotate)
        self.data_store['image'] = value
        self.data = value

        new_shape = self.data.shape
        if self._dims != new_shape:
            # A different image size needs the axes redrawn.
            self.plot.delplot('camera_image')
            self._dims = self.data.shape
            # The colormap is ignored if the image is RGB or RGBA.
            self._image = self.plot.img_plot(
                'image',
                name='camera_image',
                colormap=self._get_cmap_function(),
            )[0]

        # Keep the aspect ratio right, even after a resize.
        self._apply_aspect_ratio()

    def _get_cmap_function(self):
        """Return the colormap, fixed to the value range of the data's dtype."""
        chosen = self.cmap
        if chosen is None:
            chosen = gray
        upper = 255
        if self.data.dtype == N.uint16:
            upper = 65535
        return fix(chosen, (0, upper))

    def _cmap_changed(self, old_value, value):
        """Re-render if needed and install the new color mapper."""
        # RGB data has to be redrawn when switching to or from "no colormap".
        switching_natural = old_value is None or value is None
        if switching_natural:
            self._data_changed(self.data)

        self._image.color_mapper = self._get_cmap_function()(self._image.value_range)

    def hud(self, key, text):
        """Set, or with ``text=None`` remove, the HUD entry stored under *key*.

        The overlay text is then rebuilt from all entries in sorted key
        order, each followed by a blank line.
        """
        if text is not None:
            self._hud[key] = text
        else:
            self._hud.pop(key, None)

        self.hud_overlay.text = ''.join(
            self._hud[name] + '\n\n' for name in sorted(self._hud))
Beispiel #7
0
class CameraImage(HasTraits):
    """Plot wrapper that shows a camera frame plus a heads-up text label.

    Assigning a new array to ``data`` updates the displayed image;
    :meth:`hud` maintains the lines of the text label.
    """

    data = Array()
    data_store = Instance(ArrayPlotData)
    plot = Instance(Plot)
    hud_overlay = Instance(PlotLabel)

    # Rotation applied before display, in steps of 90 degrees (0 to 3).
    rotate = Range(0, 3)

    # Colormap used for display.  With None the image keeps its natural
    # colors (RGB data) or is shown in grayscale (monochrome data); setting
    # a colormap coerces the image to monochrome.
    cmap = Enum(None, gray, bone, pink, jet, isoluminant, awesome)

    view = View(Item('plot', show_label=False, editor=ComponentEditor()))

    def __init__(self, **traits):
        """Create the data store, the image renderer and the HUD overlay."""
        super(CameraImage, self).__init__(**traits)
        self._dims = (200, 320)
        self.data_store = ArrayPlotData(image=self.data)
        self._hud = {}
        self.plot = Plot(self.data_store)

        # Draw the image once with the default grayscale map.
        initial_map = fix(gray, (0, 255))
        self._image = self.plot.img_plot('image',
                                         name='camera_image',
                                         colormap=initial_map)[0]
        self._refresh_aspect_ratio()

        self.hud_overlay = PlotLabel(text='',
                                     component=self.plot,
                                     hjustify='left',
                                     overlay_position='inside bottom',
                                     color='white')
        self.plot.overlays.append(self.hud_overlay)

    def _refresh_aspect_ratio(self):
        """Derive the plot's aspect ratio from the first two entries of ``_dims``."""
        self.plot.aspect_ratio = float(self._dims[1]) / self._dims[0]

    def _data_default(self):
        """Default image: ``uint8`` zeros shaped like ``self._dims``."""
        return N.zeros(self._dims, dtype=N.uint8)

    def _data_changed(self, value):
        """Store the (possibly converted and rotated) image and update the plot."""
        monochrome = len(value.shape) == 2
        if self.cmap is not None and not monochrome:
            # With a colormap selected the image is reduced to monochrome
            # via the standard NTSC conversion formula.
            value = N.array(0.2989 * value[..., 0]
                            + 0.5870 * value[..., 1]
                            + 0.1140 * value[..., 2])
        value = N.rot90(value, self.rotate)
        self.data_store['image'] = value
        self.data = value

        if self._dims != self.data.shape:
            # The image size changed, so the renderer is rebuilt.
            self.plot.delplot('camera_image')
            self._dims = self.data.shape
            # For RGB or RGBA images the colormap is ignored.
            rebuilt = self.plot.img_plot('image',
                                         name='camera_image',
                                         colormap=self._get_cmap_function())
            self._image = rebuilt[0]

        # Re-apply the aspect ratio so it stays right after a resize.
        self._refresh_aspect_ratio()

    def _get_cmap_function(self):
        """Return the selected colormap (gray if none) fixed to the dtype's range."""
        if self.cmap is None:
            colormap = gray
        else:
            colormap = self.cmap
        if self.data.dtype == N.uint16:
            limits = (0, 65535)
        else:
            limits = (0, 255)
        return fix(colormap, limits)

    def _cmap_changed(self, old_value, value):
        """Apply a colormap change to the current image renderer."""
        # Going to or from "no colormap" means RGB data must be redrawn.
        if None in (old_value, value):
            self._data_changed(self.data)

        make_mapper = self._get_cmap_function()
        self._image.color_mapper = make_mapper(self._image.value_range)

    def hud(self, key, text):
        """Update one HUD entry and refresh the overlay text.

        ``text=None`` deletes the entry for *key*; any other value stores
        it.  The overlay shows all entries in sorted key order, each one
        followed by a blank line.
        """
        if text is None:
            self._hud.pop(key, None)
        else:
            self._hud[key] = text

        lines = [self._hud[name] + '\n\n' for name in sorted(self._hud)]
        self.hud_overlay.text = ''.join(lines)
Beispiel #8
0
class Plot1D(BasePlot):
    """One-dimensional plot: any number of curves drawn against a shared x axis.

    Each curve is described by an entry of ``y_infos``.  The entry's ``id``
    is used both as the data name in ``self.data`` and as the plot name in
    ``self.renderer``.
    """
    #: Name of the x axis used for labelling the plot.
    x_axis = Str()

    #: Infos characterising the plotted data.
    #: Should not be manipulated by user code.
    y_infos = List()

    def __init__(self, **kwargs):
        """Create the renderer, fill in the x data and attach zoom/pan tools."""
        super(Plot1D, self).__init__(**kwargs)
        self.renderer = Plot()
        self.renderer.data = self.data
        exp = self.experiment
        self.data.set_data('x', getattr(exp.model, exp.x_axis).linspace)
        # Add basic tools and ways to activate them in public API
        zoom = BetterSelectingZoom(self.renderer, tool_mode="box",
                                   always_on=False)
        self.renderer.overlays.append(zoom)
        self.renderer.tools.append(PanTool(self.renderer,
                                           restrict_to_data=True))

    @classmethod
    def build_view(cls, plot):
        """Return a ``Plot1DItem`` displaying *plot*."""
        return Plot1DItem(plot=plot)

    def add_curves(self, curves):
        """Register *curves* in ``y_infos`` and draw them."""
        self.y_infos.extend(curves)
        self._update_graph(added=curves)

    def remove_curves(self, curves):
        """Unregister *curves* from ``y_infos`` and remove them from the graph."""
        for c in curves:
            self.y_infos.remove(c)
        self._update_graph(removed=curves)

    def replace_curve(self, old, new):
        """Swap the curve *old* for *new*, in ``y_infos`` and on the graph."""
        self.y_infos.remove(old)
        self.y_infos.append(new)
        self._update_graph([new], [old])

    # For the time being stage is unused (will try to refine stuff if it is
    # needed)
    def update_data(self, stage):
        """Re-gather the data of every curve and store it under the curve's id."""
        exp = self.experiment
        for info in self.y_infos:
            data = info.gather_data(exp)
            self.data.set_data(info.id, data)

    def export_data(self, path):
        """Write the x data and every non-empty curve to a tab-separated file.

        ``.dat`` is appended to *path* when missing.  The file starts with
        the experiment and curve headers, each line prefixed with ``#``.
        One line of column names follows, then the values in ``%.6e``
        format.
        """
        if not path.endswith('.dat'):
            path += '.dat'
        exp = self.experiment
        header = exp.make_header()
        header += '\n' + '\n'.join([i.make_header(exp)
                                     for i in self.y_infos])

        # Keep every column's name next to its data, so that dropping an
        # empty column cannot leave the remaining names misaligned.
        columns = [(self.x_axis, self.data.get_data('x'))]
        columns.extend((i.m_name + str(i.indexes), i.gather_data(exp))
                       for i in self.y_infos)
        columns = [(name, values) for name, values in columns
                   if len(values) != 0]  # remove empty arrays
        arr = np.rec.fromarrays([values for _name, values in columns],
                                names=[name for name, _values in columns])

        with open(path, 'wb') as f:
            header = ['#' + l for l in header.split('\n') if l]
            f.write('\n'.join(header) + '\n')
            f.write('\t'.join(arr.dtype.names) + '\n')
            np.savetxt(f, arr, fmt='%.6e', delimiter='\t')

    def preferences_from_members(self):
        """Return the base preferences plus one ``curve_<n>`` entry per curve."""
        d = super(Plot1D, self).preferences_from_members()
        for i, c in enumerate(self.y_infos):
            d['curve_{}'.format(i)] = c.preferences_from_members()

        return d

    def update_members_from_preferences(self, config):
        """Restore members and the consecutive ``curve_<n>`` entries of *config*.

        Curves are read from ``curve_0`` upwards until a number is missing.
        Each one is instantiated from the ``CURVE_INFOS`` class whose name
        equals the entry's ``info_class``.
        """
        super(Plot1D, self).update_members_from_preferences(config)
        infos = []
        i = 0
        while 'curve_{}'.format(i) in config:
            c_config = config['curve_{}'.format(i)]
            curve = [c for c in CURVE_INFOS
                     if c.__name__ == c_config['info_class']][0]()
            curve.update_members_from_preferences(c_config)
            infos.append(curve)
            i += 1

        self.add_curves(infos)

    def _update_graph(self, added=(), removed=()):
        """Remove the plots and data of *removed*, then plot *added*.

        Both arguments are iterables of curve infos and default to empty.
        """
        exp = self.experiment
        # First we clean the old graphs
        self.renderer.delplot(*[c.id for c in removed])
        for r in removed:
            self.data.del_data(r.id)

        # Then we add new ones (this avoids messing up when replacing a graph)
        for a in added:
            y_data = a.gather_data(exp)
            name = a.id
            self.data.set_data(name, y_data)
            self.renderer.plot(('x', name), name=name, type=a.type,
                               color=a.color)

    def _post_setattr_x_axis(self, old, new):
        """Use the new ``x_axis`` name as the title of the renderer's x axis."""
        self.renderer.x_axis.title = new

    def _default_renderer(self):
        """Return a fresh ``Plot`` to serve as the renderer."""
        return Plot()
Beispiel #9
0
class SamplePlotter(sample.SampleModel):
    """Sample model that draws its elements as scatter plots in an Enable window.

    Every key of ``self.transformed_elements`` gets one scatter renderer.
    The coordinates are kept in ``pd`` as ``"x" + key``, ``"y" + key`` and
    ``"z" + key``.
    """
    pd = Instance(ArrayPlotData)
    plot = Instance(Plot)
    window = Instance(Window)

    def __init__(self, parent):
        """Create the plot inside *parent* and expose its control as ``widget``.

        NOTE(review): the base-class ``__init__`` is not called here --
        confirm that this is intentional.
        """
        self.parent = parent
        self.pd = ArrayPlotData()
        self.window = self.create_plot(parent)
        self.widget = self.window.control

    def load_file(self):
        """Ask the user for a model file, load it and show the top view.

        Returns without doing anything when the dialog yields an empty file
        name (the user cancelled).

        Raises
        ------
        NotImplementedError
            If the selected file does not have a ``.cfg`` extension.
        """
        file_types = [("AtomEye", "cfg"),
                      ("Crystallographic Information File", "cif")]
        file_type_string = ";;".join(
            label + " (*." + extension + ")"
            for label, extension in file_types
        )
        filename = QtGui.QFileDialog.getOpenFileName(self.parent,
                                                     'Open Model File', '.',
                                                     file_type_string)
        selected = filename[0]
        if not selected:
            # A cancelled dialog is not an unsupported format; there is
            # simply nothing to load.
            return
        if path.splitext(selected)[1].lower() == '.cfg':
            self.loadCfg(selected)
        else:
            raise NotImplementedError(
                "Only cfg file import is implemented right now.")
        self._update_coordinates()
        self.top_plot()

    def create_plot(self, parent):
        """Build ``self.plot`` with pan and zoom tools and wrap it in a ``Window``.

        Returns the ``Window`` whose component is the new plot.
        """
        self.plot = Plot(self.pd, padding=[40, 10, 0, 40], border_visible=True)
        self.plot.legend.visible = True

        # Attach some tools to the plot
        self.plot.tools.append(PanTool(self.plot))
        zoom = ZoomTool(component=self.plot, tool_mode="box", always_on=False)
        self.plot.overlays.append(zoom)

        # This Window object bridges the Enable and Qt4 worlds, and handles events
        # and drawing.  We can create whatever hierarchy of nested containers we
        # want, as long as the top-level item gets set as the .component attribute
        # of a Window.
        return Window(parent, -1, component=self.plot)

    def _replot_projection(self, vertical):
        """Redraw every element as a scatter of ``x`` against *vertical*.

        *vertical* is the one-letter prefix (``"y"`` or ``"z"``) of the data
        series used for the value axis.
        """
        for key in self.transformed_elements:
            # Drop a renderer of the same name first so plots do not pile up.
            if key in self.plot.plots:
                self.plot.delplot(key)
            self.plot.plot(("x" + key, vertical + key), name=key,
                           type="scatter")

    def top_plot(self):
        """Show the x/y projection of every element."""
        self._replot_projection("y")

    def front_plot(self):
        """Show the x/z projection of every element."""
        self._replot_projection("z")

    # NOTE(review): the listener is registered with an empty trait name --
    # confirm which change events, if any, this is meant to react to.
    @on_trait_change("")
    def _update_coordinates(self):
        """Copy columns 0, 1 and 2 of each element array into ``pd`` as x, y, z."""
        for key, coordinates in self.transformed_elements.items():
            self.pd.set_data("x" + key, coordinates[:, 0])
            self.pd.set_data("y" + key, coordinates[:, 1])
            self.pd.set_data("z" + key, coordinates[:, 2])