Example #1
0
    def __init__(self, depth, data_series, **kw):
        """Plot *data_series* against *depth* on a vertically oriented Plot.

        The Plot uses ``orientation='v'`` with its origin at the top left.
        Any extra keyword arguments are handed to the base-class constructor.
        """
        super(MyPlot, self).__init__(**kw)

        source = ArrayPlotData(index=depth, data_series=data_series)
        self.plot = Plot(source, orientation='v', origin='top left')
        self.plot.plot(('index', 'data_series'))
Example #2
0
def _create_plot_component():
    """Return a scatter plot of 5000 random points with pan and box-zoom tools."""
    n_points = 5000
    xs = sort(random(n_points))
    ys = random(n_points)

    source = ArrayPlotData()
    source.set_data("index", xs)
    source.set_data("value", ys)

    scatter_style = dict(
        type="scatter",
        marker="circle",
        index_sort="ascending",
        color="orange",
        marker_size=3,
        bgcolor="white",
    )
    plot = Plot(source)
    plot.plot(("index", "value"), **scatter_style)

    # Cosmetic settings on the Plot container.
    plot.title = "Scatter Plot"
    plot.line_width = 0.5
    plot.padding = 50

    # Pan tool (constrained via the shift key) plus a box-mode zoom overlay.
    plot.tools.append(PanTool(plot, constrain_key="shift"))
    zoom_tool = ZoomTool(component=plot, tool_mode="box", always_on=False)
    plot.overlays.append(zoom_tool)

    return plot
def _create_plot_component():
    """Return a colormapped image of exp(-(x**2 + y**2) / 100) with pan/zoom.

    The field is sampled on a 600x600 grid spanning x in [0, 10] and
    y in [0, 5].
    """
    grid_x, grid_y = meshgrid(linspace(0, 10, 600), linspace(0, 5, 600))
    field = exp(-(grid_x**2 + grid_y**2) / 100)

    source = ArrayPlotData()
    source.set_data("imagedata", field)

    plot = Plot(source)
    renderers = plot.img_plot("imagedata",
                              xbounds=(0, 10),
                              ybounds=(0, 5),
                              colormap=jet)
    image_renderer = renderers[0]

    plot.title = "My First Image Plot"
    plot.padding = 50

    # Panning is attached to the Plot; the box zoom to the image renderer.
    plot.tools.append(PanTool(plot))
    zoom_tool = ZoomTool(component=image_renderer, tool_mode="box", always_on=False)
    image_renderer.overlays.append(zoom_tool)
    return plot
def _create_plot_component():
    """Return a 3x3 GridContainer holding line plots of J_0 .. J_8.

    Each cell's aspect ratio is derived from its row and column in the grid,
    and every cell gets its own pan and box-zoom tools.
    """
    container = GridContainer(padding=20, fill_padding=True,
                              bgcolor="lightgray", use_backbuffer=True,
                              shape=(3, 3), spacing=(12, 12))

    x = linspace(-5, 15.0, 100)
    source = ArrayPlotData(index=x)

    for order in range(9):
        series = "y%d" % order
        source.set_data(series, jn(order, x))

        cell = Plot(source)
        cell.plot(("index", series),
                  color=tuple(COLOR_PALETTE[order]), line_width=2.0,
                  bgcolor="white", border_visible=True)
        cell.border_width = 1
        cell.padding = 10

        # Aspect ratio = (row + 1) / (column + 1) for this cell.
        row, col = divmod(order, 3)
        cell.aspect_ratio = float(row + 1) / (col + 1)

        cell.tools.append(PanTool(cell))
        cell.overlays.append(ZoomTool(cell, tool_mode="box", always_on=False))

        container.add(cell)
    return container
Example #5
0
def _create_plot_component():
    """Return a plot of a synthetic 200x400 RGBA image with inspector tools."""
    rgba = zeros((200, 400, 4), dtype=uint8)
    rgba[:, 0:40, 0] += 255      # red vertical stripe in the first 40 columns
    rgba[0:25, :, 1] += 255      # green horizontal stripe; yellow where it overlaps red
    rgba[-80:, -160:, 2] += 255  # blue square in the last 80 rows / 160 columns
    rgba[:, :, 3] = 255          # alpha channel fully opaque

    source = ArrayPlotData()
    source.set_data("imagedata", rgba)

    plot = Plot(source, default_origin="top left")
    plot.x_axis.orientation = "top"
    image_renderer = plot.img_plot("imagedata")[0]

    plot.bgcolor = "white"

    plot.tools.append(PanTool(plot, constrain_key="shift"))
    plot.overlays.append(ZoomTool(component=plot,
                                  tool_mode="box", always_on=False))

    # The overlay is wired to the inspector tool installed on the renderer.
    inspector = ImageInspectorTool(image_renderer)
    image_renderer.tools.append(inspector)
    plot.overlays.append(ImageInspectorOverlay(component=image_renderer,
                                               image_inspector=inspector))
    return plot
Example #6
0
class DataChooser(HasTraits):
    """Line plot whose y data can be switched between J_0, J_1 and J_2.

    Changing ``data_name`` replaces the "y" array in the shared
    ArrayPlotData, which the plot renders.
    """

    plot = Instance(Plot)
    data_name = Enum("jn0", "jn1", "jn2")
    traits_view = View(Item('data_name', label="Y data"),
                       Item('plot', editor=ComponentEditor(), show_label=False),
                       width=800, height=600, resizable=True,
                       title="Data Chooser")

    def __init__(self, **traits):
        """Build the data and plot, then complete HasTraits initialisation.

        Keyword arguments are trait values (e.g. ``data_name="jn1"``). They
        are applied by ``HasTraits.__init__``, which is called last so that
        ``_data_name_changed`` can rely on ``self.data``/``self.plotdata``.
        """
        x = linspace(-5, 10, 100)
        self.data = {"jn0": jn(0, x),
                     "jn1": jn(1, x),
                     "jn2": jn(2, x)}

        # Create the data and the PlotData object
        self.plotdata = ArrayPlotData(x=x, y=self.data["jn0"])

        # Create a Plot and associate it with the PlotData
        plot = Plot(self.plotdata)
        # Create a line plot in the Plot
        plot.plot(("x", "y"), type="line", color="blue")
        self.plot = plot

        # The original never called HasTraits.__init__, so trait keyword
        # arguments were rejected and Traits' post-init setup was skipped.
        super(DataChooser, self).__init__(**traits)

    def _data_name_changed(self, old, new):
        """Swap the plotted y series to the newly selected function."""
        self.plotdata.set_data("y", self.data[new])
Example #7
0
def _create_plot_component():
    """Return an image plot of sin(x) * y with pan, zoom and an image inspector."""
    # (low, high, sample count); the first two entries double as image bounds.
    x_spec = (-2*pi, 2*pi, 600)
    y_spec = (-1.5*pi, 1.5*pi, 300)
    grid_x, grid_y = meshgrid(linspace(*x_spec), linspace(*y_spec))

    source = ArrayPlotData()
    source.set_data("imagedata", sin(grid_x) * grid_y)

    plot = Plot(source)
    image_renderer = plot.img_plot("imagedata",
                                   xbounds=x_spec[:2],
                                   ybounds=y_spec[:2],
                                   colormap=jet)[0]

    plot.title = "My First Image Plot"
    plot.padding = 50

    plot.tools.append(PanTool(plot))
    plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))

    inspector = ImageInspectorTool(image_renderer)
    image_renderer.tools.append(inspector)
    inspector_overlay = ImageInspectorOverlay(component=image_renderer,
                                              image_inspector=inspector,
                                              bgcolor="white",
                                              border_visible=True)
    image_renderer.overlays.append(inspector_overlay)
    return plot
Example #8
0
def _create_plot_component():
    """Return a plot showing line and scatter renderers over NaN gaps.

    Three runs of the index array are replaced by NaN. J_0 is drawn as a red
    line and J_1 as green scatter markers, with a legend and pan/zoom tools.
    """
    # Create some x-y data series (with NaNs) to plot
    x = linspace(-5.0, 15.0, 500)
    x[75:125] = nan
    x[200:250] = nan
    x[300:330] = nan
    pd = ArrayPlotData(index=x)
    pd.set_data("value1", jn(0, x))
    pd.set_data("value2", jn(1, x))

    # Create some line and scatter plots of the data.
    plot = Plot(pd)
    # ``line_width`` is the line renderer's stroke width. A ``width`` keyword
    # would set the Enable component's bounds width instead of line thickness.
    plot.plot(("index", "value1"), name="j_0(x)", color="red", line_width=2.0)
    plot.plot(("index", "value2"), type="scatter", marker_size=1,
              name="j_1(x)", color="green")

    # Tweak some of the plot properties
    plot.title = "Plots with NaNs"
    plot.padding = 50
    plot.legend.visible = True

    # Attach some tools to the plot
    plot.tools.append(PanTool(plot))
    zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
    plot.overlays.append(zoom)

    return plot
Example #9
0
def _create_plot_component():
    """Return a VPlotContainer stacking a Bessel line plot above a scrollbar.

    The horizontal PlotScrollBar is bound to the line plot's index axis.
    """
    x = linspace(-2.0, 10.0, 100)
    source = ArrayPlotData(index=x)
    for order in range(5):
        source.set_data("y%d" % order, jn(order, x))

    # y0, y1 and y2 share one name and colour; y3 is drawn separately.
    line_plot = Plot(source, padding=50)
    line_plot.plot(("index", "y0", "y1", "y2"), name="j_n, n<3", color="red")
    line_plot.plot(("index", "y3"), name="j_3", color="blue")

    line_plot.tools.append(PanTool(line_plot))
    line_plot.overlays.append(
        ZoomTool(component=line_plot, tool_mode="box", always_on=False))

    scrollbar = PlotScrollBar(component=line_plot, axis="index",
                              resizable="h", height=15)
    line_plot.padding_top = 0
    scrollbar.force_data_update()

    container = VPlotContainer()
    for component in (line_plot, scrollbar):
        container.add(component)
    return container
Example #10
0
    def __init__(self, index, data_series, **kw):
        """Plot *data_series* against *index*; *kw* goes to the base class."""
        super(MyPlot, self).__init__(**kw)

        source = ArrayPlotData(index=index, data_series=data_series)
        self.plot = Plot(source)
        self.plot.plot(('index', 'data_series'))
Example #11
0
def _create_plot_component():
    """Return two side-by-side plots that share one 2D data range.

    The left plot shows Bessel line plots with a legend and pan/zoom tools.
    The right scatter plot of J_3 is built with the left plot's ``range2d``,
    linking the two ranges.
    """
    x = linspace(-2.0, 10.0, 100)
    source = ArrayPlotData(index=x)
    for order in range(5):
        source.set_data("y%d" % order, jn(order, x))

    line_plot = Plot(source, title="Line Plot", padding=50, border_visible=True)
    line_plot.legend.visible = True
    line_plot.plot(("index", "y0", "y1", "y2"), name="j_n, n<3", color="red")
    line_plot.plot(("index", "y3"), name="j_3", color="blue")

    line_plot.tools.append(PanTool(line_plot))
    line_plot.overlays.append(
        ZoomTool(component=line_plot, tool_mode="box", always_on=False))

    scatter_plot = Plot(source, range2d=line_plot.range2d, title="Scatter plot",
                        padding=50, border_visible=True)
    scatter_plot.plot(('index', 'y3'), type="scatter", color="blue",
                      marker="circle")

    container = HPlotContainer()
    container.add(line_plot)
    container.add(scatter_plot)
    return container
Example #12
0
def _create_plot_component():
    """Return an image plot of sin(x) * y with a lasso selection tool.

    ``lasso_updated`` (not defined in this function) is registered as a
    listener on the lasso's ``disjoint_selections`` trait.
    """
    x_low, x_high, x_count = -2*pi, 2*pi, 600
    y_low, y_high, y_count = -1.5*pi, 1.5*pi, 300
    grid_x, grid_y = meshgrid(linspace(x_low, x_high, x_count),
                              linspace(y_low, y_high, y_count))

    source = ArrayPlotData()
    source.set_data("imagedata", sin(grid_x) * grid_y)

    plot = Plot(source)
    image_renderer = plot.img_plot("imagedata",
                                   xbounds=(x_low, x_high),
                                   ybounds=(y_low, y_high),
                                   colormap=jet)[0]

    plot.title = "Image Plot with Lasso"
    plot.padding = 50

    lasso = LassoSelection(component=image_renderer)
    lasso.on_trait_change(lasso_updated, "disjoint_selections")
    overlay = LassoOverlay(lasso_selection=lasso, component=image_renderer)
    image_renderer.tools.append(lasso)
    image_renderer.overlays.append(overlay)
    return plot
Example #13
0
def _create_plot_component():
    """Return a colormapped scatter plot placed beside its colorbar.

    Marker colour encodes exp(-(x**2 + y**2)) for 1000 random points. The
    scatter renderer gets a ColormappedSelectionOverlay, and the colorbar
    shares the plot's colour mapper and vertical padding.
    """
    n_points = 1000
    xs = sort(random(n_points))
    ys = random(n_points)
    shade = exp(-(xs**2 + ys**2))

    source = ArrayPlotData()
    for key, values in (("index", xs), ("value", ys), ("color", shade)):
        source.set_data(key, values)

    plot = Plot(source)
    plot.plot(("index", "value", "color"),
              type="cmap_scatter",
              name="my_plot",
              color_mapper=jet,
              marker="square",
              fill_alpha=0.5,
              marker_size=6,
              outline_color="black",
              border_visible=True,
              bgcolor="white")

    plot.title = "Colormapped Scatter Plot"
    plot.padding = 50
    for grid in (plot.x_grid, plot.y_grid):
        grid.visible = False
    for axis in (plot.x_axis, plot.y_axis):
        axis.font = "modern 16"

    # The selection overlay and colorbar need the ColormappedScatterPlot
    # renderer itself, not the enclosing Plot.
    renderer = plot.plots["my_plot"][0]

    plot.tools.append(PanTool(plot, constrain_key="shift"))
    plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))
    renderer.overlays.append(
        ColormappedSelectionOverlay(renderer, fade_alpha=0.35,
                                    selection_type="mask"))

    colorbar = create_colorbar(plot.color_mapper)
    colorbar.plot = renderer
    colorbar.padding_top = plot.padding_top
    colorbar.padding_bottom = plot.padding_bottom

    container = HPlotContainer(use_backbuffer=True)
    container.add(plot)
    container.add(colorbar)
    container.bgcolor = "lightgray"
    return container
Example #14
0
 def _pd_default(self):
     """Default ArrayPlotData: a 300x400 image of ones plus two series pairs.

     'h_index'/'h_value' have one entry per image column (400), and
     'v_index'/'v_value' one entry per image row (300); values are all ones.
     """
     n_rows, n_cols = 300, 400
     pd = ArrayPlotData()
     pd.set_data("imagedata", ones(shape=(n_rows, n_cols)))
     pd.set_data('h_index', numpy.arange(n_cols))
     pd.set_data('h_value', numpy.ones((n_cols, )))
     pd.set_data('v_index', numpy.arange(n_rows))
     pd.set_data('v_value', numpy.ones((n_rows, )))
     return pd
Example #15
0
    def __init__(self, index, value, *args, **kw):
        """Plot *value* vs *index* with an XRayOverlay and a BoxSelectTool.

        Remaining positional and keyword arguments are forwarded to the
        base-class constructor.
        """
        super(PlotExample, self).__init__(*args, **kw)

        source = ArrayPlotData(index=index, value=value)

        self.plot = Plot(source)
        renderer = self.plot.plot(("index", "value"))[0]

        renderer.overlays.append(XRayOverlay(renderer))
        renderer.tools.append(BoxSelectTool(renderer))
def _create_plot_component():
    """Return a Plot hosting a hand-built VariableSizeScatterPlot renderer.

    ``Plot.plot`` cannot create this renderer type, so the data sources,
    range registrations and mappers it would normally provide are wired up
    manually here.
    """
    n_points = 1000
    xs = numpy.arange(0, n_points)
    ys = numpy.random.random(n_points)
    # Sizes are drawn from N(4, 4). NOTE(review): some will be negative.
    sizes = numpy.random.normal(4.0, 4.0, n_points)

    source = ArrayPlotData()
    source.set_data("index", xs)
    source.set_data("value", ys)

    index_source = ArrayDataSource(xs)
    value_source = ArrayDataSource(ys)

    plot = Plot(source)
    plot.index_range.add(index_source)
    plot.value_range.add(value_source)

    # Mappers are built from the Plot's ranges so the renderer follows them.
    scatter = VariableSizeScatterPlot(
        index=index_source,
        value=value_source,
        index_mapper=LinearMapper(range=plot.index_range),
        value_mapper=LinearMapper(range=plot.value_range),
        marker='circle',
        marker_size=sizes,
        color=(1.0, 0.0, 0.75, 0.4),
    )

    # Add the renderer as a component and register it under a name.
    plot.add(scatter)
    plot.plots['var_size_scatter'] = [scatter]

    plot.title = "Scatter Plot"
    plot.line_width = 0.5
    plot.padding = 50

    plot.tools.append(PanTool(plot, constrain_key="shift"))
    plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))
    return plot
 def _plot_default(self):
     """Build a one-column GridContainer with one line plot per outcome.

     Each row of an outcome's result array becomes a series "y<row>", with at
     most 1000 series plotted per outcome. Every plot gets a
     LineSelectorTool. Each tool is given the list of the other tools and a
     reference to ``self``.
     """
     outcomes, results, time_axis = self._prepare_data()

     # One ArrayPlotData per outcome, all sharing the same time index.
     pds = []
     for outcome in outcomes:
         outcome_data = ArrayPlotData(index=time_axis)
         result = results.get(outcome)
         for row in range(result.shape[0]):
             outcome_data.set_data("y%d" % row, result[row, :])
         pds.append(outcome_data)

     container = GridContainer(bgcolor="white", use_backbuffer=True,
                               shape=(len(outcomes), 1))

     tools = []
     for outcome, outcome_data in zip(outcomes, pds):
         plot = Plot(outcome_data, title=outcome, border_visible=True,
                     border_width=1)
         plot.legend.visible = False

         # All arrays except "index" are series; cap the number drawn.
         n_series = min(len(outcome_data.arrays) - 1, 1000)
         for idx in range(n_series):
             series = "y%d" % idx
             plot.plot(("index", series), name=series,
                       color=colors[idx % len(colors)])

         for renderers in plot.plots.values():
             for renderer in renderers:
                 renderer.index.sort_order = 'ascending'

         selector = LineSelectorTool(component=plot)
         plot.tools.append(selector)
         tools.append(selector)
         container.add(plot)

     # Give every selector a list of its peers and a back-reference to self.
     for tool in tools:
         tool._other_selectors = list(set(tools) - {tool})
         tool._demo = self

     return container
Example #18
0
class LinePlot(QtGui.QWidget):
    """Qt widget hosting a Chaco ToolbarPlot inside an Enable ``Window``.

    ``update_plot`` replaces the initial (x, y) series. ``plot_hold_on``
    layers extra series on top, cycling through a fixed list of colours.
    """

    def __init__(self, parent, title, x, y, xtitle, ytitle, type="line", color="blue"):
        # NOTE(review): ``parent`` is accepted but not passed to QWidget, so
        # the widget is created without a Qt parent. Confirm this is intended.
        QtGui.QWidget.__init__(self)

        # Forward by keyword. _create_window declares its parameters as
        # (title, xtitle, x, ytitle, y, ...), so positional forwarding in this
        # constructor's (title, x, y, xtitle, ytitle) order would hand the
        # data arrays to the axis titles and vice versa.
        self.enable_win = self._create_window(title, x=x, y=y,
                                              xtitle=xtitle, ytitle=ytitle,
                                              type=type, color=color)

        layout = QtGui.QVBoxLayout()
        layout.setMargin(0)
        layout.addWidget(self.enable_win.control)
        self.setLayout(layout)

        self.resize(650, 650)
        self.show()

    def _create_window(self, title, xtitle, x, ytitle, y, type="line", color="blue"):
        """Create the plot, its pan/zoom tools and the Enable Window.

        The parameter order is (title, xtitle, x, ytitle, y), so pass the
        data and axis titles by keyword. This method also initialises
        ``self.plotdata``, ``self.plot`` and the counters used by
        ``plot_hold_on``. It returns the Enable ``Window`` wrapping the plot.
        """
        self.plotdata = ArrayPlotData(x=x, y=y)
        plot = ToolbarPlot(self.plotdata)
        plot.plot(('x', 'y'), type=type, color=color)
        plot.title = title
        plot.x_axis.title = xtitle
        plot.y_axis.title = ytitle
        self.plot = plot
        # Number of extra series added by plot_hold_on, and the colours cycled.
        self._hid = 0
        self._colors = ['blue', 'red', 'black', 'green', 'magenta', 'yellow']

        # Add some tools
        self.plot.tools.append(PanTool(self.plot, constrain_key="shift"))
        self.plot.overlays.append(ZoomTool(component=self.plot, tool_mode="box", always_on=False))

        return Window(self, -1, component=plot)

    def update_plot(self, x, y):
        '''
        Replace the primary ('x', 'y') series and request a redraw.
        '''
        self.plotdata.set_data('x', x)
        self.plotdata.set_data('y', y)
        self.plot.data = self.plotdata
        self.plot.request_redraw()

    def plot_hold_on(self, x, y, type="line"):
        '''
        Add (x, y) as an extra series while keeping the existing ones.

        The data is stored under new keys 'x<n>'/'y<n>'. The colour is the
        next entry of self._colors, wrapping around.
        '''
        self._hid = self._hid + 1
        self.plotdata.set_data('x' + str(self._hid), x)
        self.plotdata.set_data('y' + str(self._hid), y)
        self.plot.plot(('x' + str(self._hid), 'y' + str(self._hid)), type=type, color=self._colors[self._hid%len(self._colors)])
        self.plot.request_redraw()
Example #19
0
def create_plot():
    """Return a colormapped scatter plot of 200 random points.

    Besides pan and box zoom, the Plot gets a right-button MoveTool and the
    scatter renderer a ColormappedSelectionOverlay.
    """
    n_points = 200
    xs = sort(random(n_points))
    ys = random(n_points)

    source = ArrayPlotData()
    source.set_data("index", xs)
    source.set_data("value", ys)
    source.set_data("color", exp(-(xs**2 + ys**2)))

    plot = Plot(source)
    plot.plot(("index", "value", "color"),
              type="cmap_scatter", name="my_plot", color_mapper=jet,
              marker="square", fill_alpha=0.5, marker_size=6,
              outline_color="black", border_visible=True, bgcolor="white")

    plot.title = "Colormapped Scatter Plot"
    plot.padding = 50
    plot.x_grid.visible = False
    plot.y_grid.visible = False
    plot.x_axis.font = "modern 16"
    plot.y_axis.font = "modern 16"

    # Alternative white text colours (disabled):
    #   plot.title_color = "white"
    #   for axis in plot.x_axis, plot.y_axis:
    #       axis.set(title_color="white", tick_label_color="white")

    # The selection overlay is attached to the scatter renderer itself.
    renderer = plot.plots["my_plot"][0]

    plot.tools.append(PanTool(plot, constrain_key="shift"))
    plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))
    renderer.overlays.append(
        ColormappedSelectionOverlay(renderer, fade_alpha=0.35,
                                    selection_type="mask"))
    plot.tools.append(MoveTool(plot, drag_button="right"))
    return plot
Example #20
0
def _create_plot_component():
    """Return two plots demonstrating clipped grid lines.

    On the left, grid lines are limited to the data box x in [1, 6] and
    y in [0.2, 0.80001]. On the right, vertical grid lines run from 0 up to
    J_3(x), computed by a callable ``transverse_bounds``.
    """
    x = linspace(-2.0, 10.0, 100)
    source = ArrayPlotData(index=x)
    for order in range(5):
        source.set_data("y%d" % order, jn(order, x))

    plot = Plot(source, title="Line Plot", padding=50, border_visible=True)
    plot.legend.visible = True
    plot.plot(("index", "y0", "y1", "y2"), name="j_n, n<3", color="auto")
    plot.plot(("index", "y3"), name="j_3", color="auto")

    plot.x_grid.line_color = "black"
    plot.y_grid.line_color = "black"
    x_lo, x_hi = 1.0, 6.0
    y_lo, y_hi = 0.2, 0.80001
    plot.x_grid.set(data_min=x_lo, data_max=x_hi,
                    transverse_bounds=(y_lo, y_hi),
                    transverse_mapper=plot.y_mapper)
    plot.y_grid.set(data_min=y_lo, data_max=y_hi,
                    transverse_bounds=(x_lo, x_hi),
                    transverse_mapper=plot.x_mapper)

    plot.tools.append(PanTool(plot))
    plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))

    def clip_to_j3(ticks):
        """Return an (N, 2) array of (low, high) per tick: 0 and jn(3, tick)."""
        return array([zeros(len(ticks)), jn(3, ticks)]).T

    func_plot = Plot(source, padding=50, border_visible=True)
    func_plot.plot(("index", "y3"), color="red")
    func_plot.x_grid.set(transverse_bounds=clip_to_j3,
                         transverse_mapper=func_plot.y_mapper,
                         line_color="black")
    func_plot.tools.append(PanTool(func_plot))

    container = HPlotContainer()
    container.add(plot)
    container.add(func_plot)
    return container
Example #21
0
def _create_plot_component():
    """Return a 100-point scatter plot with hover/toggle-select inspection.

    The inspector, pan and zoom tools are all attached to the ScatterPlot
    renderer rather than to the enclosing Plot.
    """
    n_points = 100
    xs = sort(random(n_points))
    ys = random(n_points)

    source = ArrayPlotData()
    source.set_data("index", xs)
    source.set_data("value", ys)

    plot = Plot(source)
    plot.plot(("index", "value"), type="scatter", name="my_plot",
              marker="circle", index_sort="ascending", color="slategray",
              marker_size=6, bgcolor="white")
    plot.title = "Scatter Plot With Selection"
    plot.line_width = 1
    plot.padding = 50

    scatter = plot.plots["my_plot"][0]

    scatter.tools.append(ScatterInspector(scatter, selection_mode="toggle",
                                          persistent_hover=False))
    inspector_overlay = ScatterInspectorOverlay(
        scatter,
        hover_color="transparent",
        hover_marker_size=10,
        hover_outline_color="purple",
        hover_line_width=2,
        selection_marker_size=8,
        selection_color="lawngreen",
    )
    scatter.overlays.append(inspector_overlay)

    scatter.tools.append(PanTool(scatter))
    scatter.overlays.append(ZoomTool(scatter, drag_button="right"))
    return plot
    def _plot_default(self):
        """Default ``plot``: an image plot of ``self.fid`` in unit bounds.

        Side effects: ``self.fid._data`` is set to ``self.panner.buffer`` and
        ``self.fid.data_range`` to the plot's 2D range. ``self.helper.index``
        receives the image renderer's index, and the renderer is stored as
        ``self.img_plot``.
        """
        source = ArrayPlotData()
        plot = Plot(source, padding=0)
        self.fid._data = self.panner.buffer

        source.set_data("imagedata", self.fid)

        renderer = plot.img_plot("imagedata", colormap=algae,
                                 interpolation='nearest',
                                 xbounds=(0.0, 1.0),
                                 ybounds=(0.0, 1.0))[0]
        self.fid.data_range = plot.range2d
        self.helper.index = renderer.index
        self.img_plot = renderer
        return plot
Example #23
0
def main():
    """Overlay an interactive Chaco plot on a Mayavi/VTK scene.

    Builds a Bessel-function line plot and gives it a fixed size and
    position inside a transparent OverlayPlotContainer. That container is
    hosted in an EnableVTKWindow attached to the current mlab scene's
    interactor. The function calls ``mlab.show()`` and then returns
    ``(window, render_window)``.
    """
    x = linspace(-2.0, 10.0, 100)
    source = ArrayPlotData(index=x)
    for order in range(5):
        source.set_data("y%d" % order, jn(order, x))

    plot = Plot(source, bgcolor="none", padding=30, border_visible=True,
                overlay_border=True, use_backbuffer=False)
    plot.legend.visible = True
    plot.plot(("index", "y0", "y1", "y2"), name="j_n, n<3", color="auto")
    plot.plot(("index", "y3"), name="j_3", color="auto")
    plot.tools.append(PanTool(plot))
    plot.overlays.append(ZoomTool(component=plot, tool_mode="box", always_on=False))

    # Create the mlab figure and test mesh, then grab the VTK pipeline pieces.
    _figure = mlab.figure(size=(600, 500))
    _mesh = mlab.test_mesh()
    scene = mlab.gcf().scene
    render_window = scene.render_window
    renderer = scene.renderer
    interactor = scene.interactor

    # Fixed-size plot at a fixed position; right-drag moves it.
    plot.resizable = ""
    plot.bounds = [200, 200]
    plot.padding = 25
    plot.outer_position = [30, 30]
    plot.tools.append(MoveTool(component=plot, drag_button="right"))

    container = OverlayPlotContainer(bgcolor="transparent", fit_window=True)
    container.add(plot)

    # Other interactor styles that could be used here:
    # tvtk.InteractorStyleSwitch, tvtk.InteractorStyle.
    window = EnableVTKWindow(interactor, renderer,
                             component=container,
                             istyle_class=tvtk.InteractorStyleTrackballCamera,
                             bgcolor="transparent",
                             event_passthrough=True)

    mlab.show()
    return window, render_window
Example #24
0
    def _plot_default(self):
        """Build the default image Plot showing ``self.fid`` (nearest-neighbour, algae).

        The image spans (0, 1) on both axes. Before plotting, ``self.fid`` is
        pointed at ``self.panner.buffer``. Afterwards it is bound to the
        plot's 2D range, ``self.helper.index`` is set to the image renderer's
        index, and the renderer is kept as ``self.img_plot``.
        """
        pd = ArrayPlotData()
        plot = Plot(pd, padding=0)
        self.fid._data = self.panner.buffer

        pd.set_data("imagedata", self.fid)

        image_kwargs = dict(colormap=algae,
                            interpolation='nearest',
                            xbounds=(0.0, 1.0),
                            ybounds=(0.0, 1.0))
        image_renderer = plot.img_plot("imagedata", **image_kwargs)[0]

        self.fid.data_range = plot.range2d
        self.helper.index = image_renderer.index
        self.img_plot = image_renderer
        return plot
Example #25
0
def _create_plot_component():
    """Return a plot of regular polygons placed around a ring of centres.

    The first polygon is a triangle and each subsequent one has one more
    side. Each polygon renderer can be dragged with the right mouse button.
    """
    # Use n_gon to compute the centre locations for the polygons.
    centres = n_gon(center=(0, 0), r=4, nsides=8)

    # Face colour keyed by number of sides.
    face_colors = {
        3: 0xAABBCC,
        4: "orange",
        5: "yellow",
        6: "lightgreen",
        7: "green",
        8: "blue",
        9: "lavender",
        10: "purple",
    }

    source = ArrayPlotData()
    polyplot = Plot(source)

    for sides, centre in enumerate(centres, start=3):
        xs, ys = transpose(n_gon(center=centre, r=2, nsides=sides))
        x_key = "x" + str(sides)
        y_key = "y" + str(sides)
        source.set_data(x_key, xs)
        source.set_data(y_key, ys)
        renderer = polyplot.plot((x_key, y_key), type="polygon",
                                 face_color=face_colors[sides],
                                 hittest_type="poly")[0]
        renderer.tools.append(DataspaceMoveTool(renderer, drag_button="right"))

    polyplot.padding = 50
    polyplot.title = "Polygon Plot"

    polyplot.tools.append(PanTool(polyplot))
    polyplot.overlays.append(ZoomTool(polyplot, tool_mode="box", always_on=False))
    return polyplot
def create_plot():
    """Return a legend-enabled plot of J_0 .. J_9 with DPI-scaled line widths.

    Series line widths and grid line widths are multiplied by ``dpi_scale``
    (not defined in this function).
    """
    numpoints = 100
    low = -5
    high = 15.0
    x = linspace(low, high, numpoints)
    pd = ArrayPlotData(index=x)
    p = Plot(pd, bgcolor="lightgray", padding=50, border_visible=True)
    for i in range(10):
        pd.set_data("y" + str(i), jn(i, x))
        # ``line_width`` is the renderer's stroke width. A ``width`` keyword
        # would set the component's bounds width, and the DPI scaling would
        # never reach the drawn lines.
        p.plot(("index", "y" + str(i)), color=tuple(COLOR_PALETTE[i]),
               line_width=2.0 * dpi_scale)
    p.x_grid.visible = True
    p.x_grid.line_width *= dpi_scale
    p.y_grid.visible = True
    p.y_grid.line_width *= dpi_scale
    p.legend.visible = True
    return p
Example #27
0
def _create_plot_component():
    """Return a 2000-point scatter plot with lasso selection and inspection."""
    n_points = 2000
    xs = sort(random(n_points))
    ys = random(n_points)

    source = ArrayPlotData()
    source.set_data("index", xs)
    source.set_data("value", ys)

    plot = Plot(source)
    plot.plot(("index", "value"), type="scatter", name="my_plot",
              marker="circle", index_sort="ascending", color="red",
              marker_size=4, bgcolor="white")
    plot.title = "Scatter Plot With Selection"
    plot.line_width = 1
    plot.padding = 50

    # The lasso and inspector operate on the ScatterPlot renderer itself.
    scatter = plot.plots["my_plot"][0]

    # The lasso is the renderer's active tool, using the renderer's index
    # data source as its selection datasource.
    lasso = LassoSelection(component=scatter,
                           selection_datasource=scatter.index)
    scatter.active_tool = lasso
    scatter.tools.append(ScatterInspector(scatter))
    scatter.overlays.append(LassoOverlay(lasso_selection=lasso,
                                         component=scatter))

    # Set ``lasso.incremental_select = True`` to see incremental updates.
    return plot
Example #28
0
class TracePlot(BasePlot):
    """Plot that accumulates (x, y) points, optionally keeping only the newest N.

    Keyword options consumed here (all others are passed on to BasePlot):
      type          -- renderer type for the trace, default 'scatter'
      nr_of_points  -- maximum number of points kept; 0 (default) keeps all
    """

    def __init__(self, parent, **kw):

        self._type = kw.pop('type', 'scatter')
        self.nr_of_points = kw.pop('nr_of_points', 0)
        # TODO: more options

        BasePlot.__init__(self, parent, **kw)

    def add_point(self, x, y):
        """Append one point and push the (possibly trimmed) data to the plot."""
        self._x.append(x)
        self._y.append(y)
        self._set_data()

    def set_nr_of_points(self, n):
        """Change the retention limit and trim the existing data to it."""
        self.nr_of_points = n
        self._set_data()

    def reset(self):
        """Discard all accumulated points."""
        self._x = []
        self._y = []
        self._set_data()

    def _set_data(self):
        """Drop the oldest points beyond ``nr_of_points`` and update the data."""
        limit = self.nr_of_points
        if limit > 0 and len(self._x) > limit:
            # Remove the excess from the front of both lists in one slice.
            # Popping one element at a time would re-copy the lists once per
            # removed point, which is quadratic when the limit shrinks a lot.
            excess = len(self._x) - limit
            self._x = self._x[excess:]
            self._y = self._y[excess:]
        self.data.set_data('x', self._x)
        self.data.set_data('y', self._y)

    def _create_window(self, **kw):
        """Create the ArrayPlotData, the 'trace' renderer and the Enable Window."""
        self.data = ArrayPlotData()
        self.plot = Plot(self.data)

        self._x = []
        self._y = []
        self.data.set_data('x', self._x)
        self.data.set_data('y', self._y)

        self.plot.plot(('x', 'y'),
                       type = self._type,
                       name = 'trace')

        return Window(self, -1, component=self.plot)
Example #29
0
    def create_plot(self):
        """Build ``self.plot`` (temperature) and ``self.plot_P`` (Q/U polarization).

        Both plots share one ArrayPlotData keyed on the ``rev`` index, and
        each gets a box-zoom overlay and a pan tool.
        """
        source = ArrayPlotData(rev=self.get_rev())
        self.plot = Plot(source, title='Temperature')

        source.set_data('T', self.get_sci('T'))
        self.plot.plot(('rev', 'T'))

        # Per-series colours; the dict is also interpolated into the title.
        color = {'Q': 'red', 'U': 'green'}
        self.plot_P = Plot(source, title='Polarization %s' % color)

        for stokes in ('Q', 'U'):
            source.set_data(stokes, self.get_sci(stokes))
            self.plot_P.plot(('rev', stokes), color=color[stokes])

        for target in (self.plot, self.plot_P):
            target.overlays.append(ZoomTool(target, tool_mode="box", always_on=False))
            target.tools.append(PanTool(target))
Example #30
0
    def __init__(self, **kw):
        """Show the downloaded world-map image with location markers on top.

        Locations are (latitude, longitude) pairs, currently hard-coded to
        Austin. They are converted to image pixel coordinates with the origin
        at the top left, matching ``default_origin="top left"`` on the Plot.
        """
        super(WorldMapPlot, self).__init__(**kw)

        self._download_map_image()
        image = ImageData.fromfile(self.image_path)
        n_rows, n_cols = image.data.shape[0], image.data.shape[1]

        # Hard-coded for now; could instead come from command-line args,
        # a file, or some other source.
        austin_loc = (30.16, -97.44)   # (latitude, longitude)
        lons = numpy.array([austin_loc[1]])
        lats = numpy.array([austin_loc[0]])

        # Longitude -> pixel column. Latitude is negated before mapping to a
        # pixel row because the image origin is at the top left.
        locations_x = (lons + 180) * n_cols / 360
        locations_y = (-lats + 90) * n_rows / 180

        plot_data = ArrayPlotData()
        plot_data.set_data("imagedata", image._data)
        plot_data.set_data("locations_x", locations_x)
        plot_data.set_data("locations_y", locations_y)

        self.plot = Plot(plot_data, default_origin="top left")
        self.plot.img_plot('imagedata')

        # Scatter markers drawn over the map.
        # NOTE(review): ScatterPlot's size trait is ``marker_size``; confirm
        # that ``size=3`` has any effect.
        loc_plot = self.plot.plot(('locations_x', 'locations_y'),
                                  type='scatter', size=3, color='yellow',
                                  marker='dot')[0]

        # Pin both mapper ranges to the image extent in pixels.
        loc_plot.x_mapper.range.high = n_cols
        loc_plot.x_mapper.range.low = 0
        loc_plot.y_mapper.range.high = n_rows
        loc_plot.y_mapper.range.low = 0

        self.plot.overlays.append(
            ZoomTool(component=self.plot, tool_mode="box", always_on=False))
Example #31
0
class sinDataView(HasTraits):
    """Interactive view of sin(freqY1 * x) and sin(freqY2 * x).

    Moving either frequency slider recomputes the matching curve and pushes
    it into the shared ArrayPlotData, which redraws the plot.
    """

    # Sample points shared by both curves.
    x = Array
    # Current values of the two curves.
    y1 = Array
    y2 = Array
    # Frequencies of the two curves, editable in the UI.
    freqY1 = Range(low=1.0, high=10.0, value=1.0)
    freqY2 = Range(low=1.0, high=10.0, value=2.0)
    myPlot = Instance(Plot)

    traits_view = View(
        Item('myPlot',
             editor=ComponentEditor(),
             show_label=False),
        Item(name='freqY1'),
        Item(name='freqY2'),
        buttons=[OKButton],
        resizable=True,
        title='Window Title',
        width=500, height=600)

    def __init__(self):
        # Overriding __init__ must still run HasTraits' own initialisation.
        super(sinDataView, self).__init__()

        # data ranges
        self.x = arange(-10, 10, 0.01)
        self.y1 = sin(self.freqY1 * self.x)
        self.y2 = sin(self.freqY2 * self.x)

        self.plotdata = ArrayPlotData(x=self.x, y1=self.y1, y2=self.y2)
        self.myPlot = Plot(self.plotdata)
        self.myPlot.plot(("x", "y1"), type="line", color="blue", name='Y1')
        self.myPlot.plot(("x", "y2"), type="line", color="red", name='Y2')
        self.myPlot.legend.visible = True
        self.myPlot.title = "ArrayPlotData Example"

    def _update_curve(self, name, freq):
        """Recompute curve ``name`` ("y1" or "y2") with frequency ``freq``."""
        values = sin(freq * self.x)
        setattr(self, name, values)
        # set_data is necessary to update the ArrayPlotData object,
        # which then updates the plots
        self.plotdata.set_data(name, values)

    def _freqY1_changed(self):
        self._update_curve("y1", self.freqY1)

    def _freqY2_changed(self):
        self._update_curve("y2", self.freqY2)
Example #32
0
def _create_plot_component():
    """Create a Bessel-function line plot with a linked, movable inset.

    ``y0``..``y4`` hold j_0..j_4 sampled on [-2, 10]; the main plot draws
    j_0..j_3 and the inset scatter plot draws j_3 while sharing the main
    plot's 2-D range.

    Returns
    -------
    OverlayPlotContainer
        Container holding the main plot and the inset.
    """
    # Create some x-y data series to plot
    x = linspace(-2.0, 10.0, 100)
    pd = ArrayPlotData(index=x)
    for i in range(5):
        pd.set_data("y" + str(i), jn(i, x))

    # Create some line plots of some of the data
    plot1 = Plot(pd)
    plot1.plot(("index", "y0", "y1", "y2"), name="j_n, n<3", color="red")
    plot1.plot(("index", "y3"), name="j_3", color="blue")

    # Tweak some of the plot properties
    plot1.title = "Inset Plot"
    plot1.padding = 50

    # Attach some tools to the plot
    plot1.tools.append(PanTool(plot1))
    zoom = ZoomTool(component=plot1, tool_mode="box", always_on=False)
    plot1.overlays.append(zoom)

    # Create a second scatter plot of one of the datasets, linking its
    # range to the first plot
    plot2 = Plot(pd, range2d=plot1.range2d, padding=50)
    plot2.plot(('index', 'y3'), type="scatter", color="blue", marker="circle")
    # trait_set() is the supported bulk setter; HasTraits.set() is a
    # deprecated alias.
    plot2.trait_set(resizable="",
                    bounds=[250, 250],
                    position=[550, 150],
                    bgcolor="white",
                    border_visible=True,
                    unified_draw=True)
    plot2.tools.append(PanTool(plot2))
    plot2.tools.append(MoveTool(plot2, drag_button="right"))
    inset_zoom = ZoomTool(component=plot2, tool_mode="box", always_on=False)
    plot2.overlays.append(inset_zoom)

    # Create a container and add our plots
    container = OverlayPlotContainer()
    container.add(plot1)
    container.add(plot2)
    return container
def create_plot(num_plots=8, type="line"):
    """Create a single plot object with multiple Bessel-function renderers.

    Parameters
    ----------
    num_plots : int
        Controls how many series are drawn. NOTE(review): the loop draws
        ``num_plots + 1`` series (j_1 .. j_{num_plots+1}); confirm whether
        the extra series is intended.
    type : str
        Renderer type passed through to ``Plot.plot`` (e.g. "line").

    Returns
    -------
    Plot
        White-background plot with grids and legend visible.
    """
    # This is a bit of a hack to work around that line widths don't scale
    # with the GraphicsContext's CTM.
    dpi_scale = DPI / 72.0
    numpoints = 100
    low = -5
    high = 15.0
    x = linspace(low, high, numpoints)
    pd = ArrayPlotData(index=x)
    p = Plot(pd, bgcolor="white", padding=50, border_visible=True)
    n_colors = len(COLOR_PALETTE)
    for i in range(1, num_plots + 2):
        name = "y" + str(i)
        pd.set_data(name, jn(i, x))
        # Wrap around the palette so large num_plots values don't raise
        # IndexError; identical to COLOR_PALETTE[i] while i < n_colors.
        color = tuple(COLOR_PALETTE[i % n_colors])
        p.plot(("index", name), color=color, width=2.0 * dpi_scale, type=type)
    p.x_grid.visible = True
    p.x_grid.line_width *= dpi_scale
    p.y_grid.visible = True
    p.y_grid.line_width *= dpi_scale
    p.legend.visible = True
    return p
Example #34
0
    def activate_template(self):
        """ Builds the scatter plot from the bound index/value data sources.

            Does nothing while either data source's ``context_data`` is still
            ``Undefined``. Otherwise a scatter Plot styled from this object's
            marker/color/size traits is created and stored on ``self.plot``.

            Returns
            -------
            None
        """
        # If our data sources are still unbound, then just exit; someone must
        # have marked them as optional:
        if ((self.index.context_data is Undefined)
                or (self.value.context_data is Undefined)):
            return

        # Create a plot data object and give it this data:
        pd = ArrayPlotData()
        pd.set_data('index', self.index.context_data)
        pd.set_data('value', self.value.context_data)

        # Create the plot:
        self.plot = plot = Plot(pd)
        plot.plot(('index', 'value'),
                  type='scatter',
                  index_sort='ascending',
                  marker=self.marker,
                  color=self.color,
                  outline_color=self.outline_color,
                  marker_size=self.marker_size,
                  line_width=self.line_width,
                  bgcolor='white')
        # trait_set() is the supported bulk setter; HasTraits.set() is a
        # deprecated alias.
        plot.trait_set(padding_left=50,
                       padding_right=0,
                       padding_top=0,
                       padding_bottom=20)

        # Attach some tools to the plot:
        plot.tools.append(PanTool(plot, constrain_key='shift'))
        zoom = SimpleZoom(component=plot, tool_mode='box', always_on=False)
        plot.overlays.append(zoom)
def _create_plot_component():
    """Create a filled-contour plot of z = x*y overlaid with contour lines.

    Returns
    -------
    Plot
        The contour plot, with pan and box-zoom tools attached.
    """
    # Create a scalar field to contour
    xs = linspace(-2 * pi, 2 * pi, 600)
    ys = linspace(-1.5 * pi, 1.5 * pi, 300)
    x, y = meshgrid(xs, ys)
    z = x * y

    # Create a plot data object and give it this data
    pd = ArrayPlotData()
    pd.set_data("imagedata", z)

    # Create a contour polygon plot of the data
    plot = Plot(pd, default_origin="top left")
    plot.contour_plot("imagedata",
                      type="poly",
                      poly_cmap=jet,
                      xbounds=(xs[0], xs[-1]),
                      ybounds=(ys[0], ys[-1]))

    # Create a contour line plot for the data, too
    plot.contour_plot("imagedata",
                      type="line",
                      xbounds=(xs[0], xs[-1]),
                      ybounds=(ys[0], ys[-1]))

    # Tweak some of the plot properties
    plot.title = "My First Contour Plot"
    plot.padding = 50
    # The background trait is ``bgcolor``; assigning ``bg_color`` only
    # created an unused attribute.
    plot.bgcolor = "white"
    plot.fill_padding = True

    # Attach some tools to the plot
    plot.tools.append(PanTool(plot))
    zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
    plot.overlays.append(zoom)
    return plot
Example #36
0
 def render_scatplot(self):
     """Create a colour-mapped scatter plot of the current image's peaks.

     Columns 0, 1 and 2 of ``self.peaks[self.img_idx]`` supply the x, y and
     colour values (colours mapped through ``jet`` over [0, 1]). The plot
     reuses ``self.img_plot``'s aspect ratio and 2-D range, is stored on
     ``self.scatplot`` (its data on ``self.peakdata``) and returned.
     """
     current_peaks = self.peaks[self.img_idx]
     data = ArrayPlotData()
     for column, key in enumerate(("index", "value", "color")):
         data.set_data(key, current_peaks[:, column])

     plot = Plot(data,
                 aspect_ratio=self.img_plot.aspect_ratio,
                 default_origin="top left")
     plot.plot(("index", "value", "color"),
               type="cmap_scatter",
               name="my_plot",
               color_mapper=jet(DataRange1D(low=0.0, high=1.0)),
               marker="circle",
               fill_alpha=0.5,
               marker_size=6)
     plot.x_grid.visible = False
     plot.y_grid.visible = False
     # Reuse the image plot's 2-D data range so both map data identically.
     plot.range2d = self.img_plot.range2d

     self.scatplot = plot
     self.peakdata = data
     return plot
Example #37
0
 def _plotdata_default(self):
     """Default ArrayPlotData holding three axis-aligned slices of ``self.data``.

     The "xy", "xz" and "zy" entries are ``data[0]``, ``data[:, 0]`` and
     ``data[:, :, 0].T`` (so ``self.data`` must have at least 3 dimensions);
     "zz" is a 1x1 placeholder and the remaining names start out empty.
     """
     volume = self.data
     empty_traces = dict(z_x=[], z_y=[], y_x=[], y_z=[], x_y=[], x_z=[])
     plotdata = ArrayPlotData(zz=numpy.array([[0]]), **empty_traces)
     slices = (
         ('xy', volume[0]),
         ('xz', volume[:, 0]),
         ('zy', volume[:, :, 0].T),
     )
     for key, plane in slices:
         plotdata.set_data(key, plane)
     return plotdata
Example #38
0
    def get_plot(self):
        """Build a grid of line plots, one per group of related table columns.

        Columns of ``self.table`` are grouped by name: 'stage ...' and
        'contact position' -> "stages"; 'fiber ...' columns split into
        "fiber deformation", "fiber position" and "fiber"; 'sarcomere ...'
        -> "sarcomere" (except 'sarcomere orientation', which is skipped);
        any other column gets its own group. A "position-force" group is
        added when both 'force' and 'stage right current' exist. The x axis
        is the 'relative time' column when present, otherwise a 0..N-1
        sample index.

        Returns
        -------
        GridContainer
            One Plot per group.

        Raises
        ------
        NotImplementedError
            If the number of groups has no predefined grid shape.
        """
        colors = ['purple', 'blue', 'green', 'gold', 'orange', 'red', 'black']
        # group name -> [(index_label, label, index_key, key), ...]
        groups = defaultdict(list)

        pd = ArrayPlotData()

        if 'relative time' in self.table:
            index_key = 'relative time'
            index_values = self.table[index_key]
        else:
            index_key = 'index'
            index_values = None
        index_label = index_key

        for key, values in self.table.items():
            if index_values is None:
                # No time column: use sample numbers. list() keeps the
                # Python 2 behaviour of range() (a concrete sequence).
                index_values = list(range(len(values)))
            if key == index_label:
                continue
            if key.startswith('stage '):
                label = key[6:].strip()
                group = groups['stages']
            elif key == 'contact position':
                label = key
                group = groups['stages']
            elif key.startswith('fiber '):
                if key.endswith('deformation'):
                    label = key[5:-11].strip()
                    group = groups['fiber deformation']
                elif key.endswith('position'):
                    label = key[5:-8].strip()
                    group = groups['fiber position']
                else:
                    label = key[5:].strip()
                    group = groups['fiber']
            elif key.startswith('sarcomere '):
                label = key[10:].strip()
                if label == 'orientation':  # this is artificial information
                    continue
                group = groups['sarcomere']
            else:
                label = key
                group = groups[key]

            group.append((index_label, label, index_key, key))
            pd.set_data(key, values)

        pd.set_data(index_key, index_values)

        if 'force' in self.table and 'stage right current' in self.table:
            groups['position-force'].append(
                ('position', 'force', 'stage right current', 'force'))

        # Pick a grid shape for the number of groups.
        n = len(groups)
        if n in (0, 1, 2, 3, 5, 7):
            shape = (n, 1)
        elif n in (4, 6, 8, 10):
            shape = (n // 2, 2)
        elif n == 9:
            shape = (n // 3, 3)
        else:
            # repr() replaces the Python-2-only backtick syntax.
            raise NotImplementedError(repr(n))

        container = GridContainer(padding=10, shape=shape, spacing=(0, 0))

        for group_label, group_info in groups.items():
            plot = Plot(pd)
            for j, (x_label, label, x_key, key) in enumerate(group_info):
                color = colors[j % len(colors)]
                plot.plot((x_key, key), name=label, color=color,
                          x_label=x_label)
            plot.legend.visible = True
            plot.title = group_label
            plot.tools.append(PanTool(plot))
            zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
            plot.overlays.append(zoom)
            container.add(plot)

        return container
class MatrixViewer(HasTraits):
    """Show one of several named matrices as a colour-mapped image.

    When ``CustomTool`` reports a cursor position (its ``xval``/``yval``
    traits), the row ("From"), column ("To") and value of the cell under it
    are shown. A colorbar with a range selection sits next to the image.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    edge_para = Any
    # Key into ``data`` of the matrix currently displayed.
    data_name = Enum("a", "b")

    fro = Int
    to = Int
    data = None
    val = Float

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=(800, 600)),
                                  show_label=False),
                             HGroup(
                                 Item('fro',
                                      label="From",
                                      style='readonly',
                                      springy=True),
                                 Item('to',
                                      label="To",
                                      style='readonly',
                                      springy=True),
                                 Item('val',
                                      label="Value",
                                      style='readonly',
                                      springy=True),
                             ),
                             orientation="vertical"),
                       Item('data_name', label="Image data"),
                       handler=CustomHandler(),
                       resizable=True,
                       title="Matrix Viewer")

    def __init__(self, data, **traits):
        """ Data maps names (values of ``data_name``) to 2-D numpy arrays """
        # Use this class in super() so HasTraits' own __init__ is not skipped.
        super(MatrixViewer, self).__init__(**traits)

        # Store the data before choosing a key: setting data_name may fire
        # _data_name_changed, which must be able to see it.
        self.data = data
        # First key of the mapping (data.keys()[0] fails on Python 3).
        self.data_name = next(iter(data))
        self.plot = self._create_plot_component()

        # set trait notification on customtool
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

    def _data_name_changed(self, old, new):
        # May fire from __init__ before the plot data exists; nothing to
        # update in that case.
        if self.data is None or getattr(self, "pd", None) is None:
            return
        self.pd.set_data("imagedata", self.data[self.data_name])
        # NOTE(review): hard-coded value selection; confirm it suits all data.
        self.my_plot.set_value_selection((0, 2))

    def _update_fields(self):
        from numpy import floor

        # Map the mouse location to an array index. floor() rather than
        # trunc(): positions in (-1, 0) must map to -1 and be rejected instead
        # of snapping to index 0. (Assumes xval/yval are in image data
        # coordinates, one unit per cell -- TODO confirm in CustomTool.)
        frotmp = int(floor(self.custtool.yval))
        totmp = int(floor(self.custtool.xval))

        # assume matrix whose shape is (# of rows, # of columns)
        matrix = self.data[self.data_name]
        sh = matrix.shape
        if 0 <= frotmp < sh[0] and 0 <= totmp < sh[1]:
            self.fro = frotmp
            self.to = totmp
            self.val = matrix[self.fro, self.to]

    def _create_plot_component(self):
        """Build the image plot, its tools and the colorbar; return the container."""
        # Create a plot data object and give it this data
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.data[self.data_name])

        # Create the plot
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot(
            "imagedata",
            name="my_plot",
            colormap=jet)

        # Tweak some of the plot properties
        self.tplot.title = "Matrix"
        self.tplot.padding = 50

        # Right now, some of the tools are a little invasive, and we need the
        # actual CMapImage object to give to them
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # my custom tool to get the connection information
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colorbar, handing in the appropriate range and colormap
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # create a range selection for the colorbar
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # we also want the range selection to inform the cmap plot of
        # the selection, so set that up as well
        self.range_selection.listeners.append(self.my_plot)

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container
Example #40
0
class TriangleWave(HasTraits):
    """Interactive demo: a triangle wave, its FFT spectrum, and the wave
    re-synthesised from its first ``N`` frequency components."""

    # Narrowest and widest allowed wave width. A Range cannot mix a constant
    # with a trait name, so both bounds are constant-valued traits.
    low = Float(0.02)
    hi = Float(1.0)

    # Width of the triangle wave
    wave_width = Range("low", "hi", 0.5)

    # x coordinate of the apex C of the triangle
    length_c = Range("low", "wave_width", 0.5)

    # y coordinate of the apex
    height_c = Float(1.0)

    # Number of samples used for the FFT; an Enum so the user picks from a list
    fftsize = Enum([(2**x) for x in range(6, 12)])

    # Upper x-axis limit of the FFT spectrum plot
    fft_graph_up_limit = Range(0, 400, 20)

    # Text listing of the FFT peaks
    peak_list = Str

    # Number of frequency components used to synthesise the wave
    N = Range(1, 40, 4)

    # Data object holding all plot data
    plot_data = Instance(AbstractPlotData)

    # Plot of the waveforms
    plot_wave = Instance(Component)

    # Plot of the FFT spectrum
    plot_fft = Instance(Component)

    # Container holding both plots
    container = Instance(Component)

    # UI layout. The window size must be given explicitly so the plot
    # container initialises correctly.
    view = View(HSplit(
        VSplit(
            VGroup(Item("wave_width", editor=scrubber, label="波形宽度"),
                   Item("length_c", editor=scrubber, label="最高点x坐标"),
                   Item("height_c", editor=scrubber, label="最高点y坐标"),
                   Item("fft_graph_up_limit", editor=scrubber, label="频谱图范围"),
                   Item("fftsize", label="FFT点数"), Item("N", label="合成波频率数")),
            Item("peak_list",
                 style="custom",
                 show_label=False,
                 width=100,
                 height=250)),
        VGroup(Item("container",
                    editor=ComponentEditor(size=(600, 300)),
                    show_label=False),
               orientation="vertical")),
                resizable=True,
                width=800,
                height=600,
                title="三角波FFT演示")

    # Helper shared by the waveform and spectrum plots, which need much of
    # the same setup.
    def _create_plot(self, data, name, type="line"):
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        return p

    def __init__(self):
        # Run the parent initialiser first
        super(TriangleWave, self).__init__()

        # No data yet: create empty entries so the Plots have names to
        # reference.
        self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[])

        # Vertical container stacking the waveform and spectrum plots
        self.container = VPlotContainer()

        # Waveform plot: original wave (x, y) and synthesised wave (x2, y2)
        self.plot_wave = self._create_plot(("x", "y"), "Triangle Wave")
        self.plot_wave.plot(("x2", "y2"), color="red")

        # Spectrum plot from the f and p entries
        self.plot_fft = self._create_plot(("f", "p"), "FFT", type="scatter")

        self.container.add(self.plot_wave)
        self.container.add(self.plot_fft)

        # Axis titles
        self.plot_wave.x_axis.title = "Samples"
        self.plot_fft.x_axis.title = "Frequency pins"
        self.plot_fft.y_axis.title = "(dB)"

        # An Enum defaults to the first list value (64); switch to 1024, which
        # also triggers update_plot.
        self.fftsize = 1024

    # Push the new limit to the spectrum plot's x-axis upper bound.
    def _fft_graph_up_limit_changed(self):
        self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit

    def _N_changed(self):
        self.plot_sin_combine()

    # Several traits share one change handler, hence @on_trait_change.
    @on_trait_change("wave_width, length_c, height_c, fftsize")
    def update_plot(self):
        # NOTE(review): y_data is published as a module global; kept in case
        # other module-level code reads it -- confirm and remove if unused.
        global y_data
        # Sample one period of the triangle wave
        x_data = np.arange(0, 1.0, 1.0 / self.fftsize)
        func = self.triangle_func()
        # frompyfunc returns an object array; convert it to float64
        # (np.cast was removed in NumPy 2.0).
        y_data = np.asarray(func(x_data), dtype=np.float64)

        # Normalised spectrum
        fft_parameters = np.fft.fft(y_data) / len(y_data)

        # Amplitude in dB of each frequency up to fftsize/2, clipped to
        # [-120, 120]. Integer division: a float slice bound is a TypeError
        # on Python 3.
        fft_data = np.clip(
            20 * np.log10(np.abs(fft_parameters))[:self.fftsize // 2 + 1],
            -120, 120)

        # Write the results into the data set
        self.plot_data.set_data("x", np.arange(0, self.fftsize))  # x = sample index
        self.plot_data.set_data("y", y_data)
        self.plot_data.set_data("f", np.arange(0, len(fft_data)))  # x = frequency bin
        self.plot_data.set_data("p", fft_data)

        # The synthesised wave is indexed by sample and shows 2 periods
        self.plot_data.set_data("x2", np.arange(0, 2 * self.fftsize))

        # Refresh the spectrum plot's x-axis upper limit
        self._fft_graph_up_limit_changed()

        # List (at most 20) frequencies whose amplitude exceeds -80 dB
        peak_index = (fft_data > -80)
        peak_value = fft_data[peak_index][:20]
        result = []
        for f, v in zip(np.flatnonzero(peak_index), peak_value):
            result.append("%s : %s" % (f, v))
        self.peak_list = "\n".join(result)

        # Keep the FFT result and recompute the synthesised wave
        self.fft_parameters = fft_parameters
        self.plot_sin_combine()

    # Synthesise the wave from sine components over 2 periods
    def plot_sin_combine(self):
        _index, data = fft_combine(self.fft_parameters, self.N, 2)
        self.plot_data.set_data("y2", data)

    # Return a ufunc computing the triangle wave for the current parameters
    def triangle_func(self):
        c = self.wave_width
        c0 = self.length_c
        hc = self.height_c

        def trifunc(x):
            x = x - int(x)  # period is 1, so only the fractional part of x matters
            if x >= c: r = 0.0
            elif x < c0: r = x / c0 * hc
            else: r = (c - x) / (c - c0) * hc
            return r

        # frompyfunc turns trifunc into a ufunc that works on arrays, but its
        # result is an object array, so callers must convert the dtype.
        return np.frompyfunc(trifunc, 1, 1)
Example #41
0
class EqualizerDesigner(HasTraits):
    """Main window of the equalizer designer."""

    equalizers = Instance(Equalizers)

    # Data object holding all plot data
    plot_data = Instance(AbstractPlotData)

    # Container holding the gain and phase plots
    container = Instance(Component)

    plot_gain = Instance(Component)
    plot_phase = Instance(Component)
    save_button = Button("Save")
    load_button = Button("Load")
    export_button = Button("Export")

    view = View(VGroup(
        HGroup(Item("load_button"),
               Item("save_button"),
               Item("export_button"),
               show_labels=False),
        HSplit(
            VGroup(
                Item("equalizers", style="custom", show_label=False),
                show_border=True,
            ),
            Item("container",
                 editor=ComponentEditor(size=(800, 300)),
                 show_label=False),
        )),
                resizable=True,
                width=800,
                height=500,
                title="Equalizer Designer")

    def _create_plot(self, data, name, type="line"):
        """Create a log-frequency plot of ``data`` with pan/zoom tools."""
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        p.index_scale = "log"
        return p

    def __init__(self):
        super(EqualizerDesigner, self).__init__()
        self.plot_data = ArrayPlotData(f=FREQS, gain=[], phase=[])
        self.plot_gain = self._create_plot(("f", "gain"), "Gain(dB)")
        self.plot_phase = self._create_plot(("f", "phase"), "Phase(degree)")
        self.container = VPlotContainer()
        self.container.add(self.plot_phase)
        self.container.add(self.plot_gain)
        self.plot_gain.padding_bottom = 20
        self.plot_phase.padding_top = 20

    def _equalizers_default(self):
        return Equalizers()

    @on_trait_change("equalizers.h")
    def redraw(self):
        """Recompute gain (dB) and phase (degrees) from the response ``h``."""
        gain = 20 * np.log10(np.abs(self.equalizers.h))
        phase = np.angle(self.equalizers.h, deg=True)
        self.plot_data.set_data("gain", gain)
        self.plot_data.set_data("phase", phase)

    def _save_button_fired(self):
        dialog = FileDialog(action="save as", wildcard='EQ files (*.eq)|*.eq')
        result = dialog.open()
        if result == OK:
            # ``with`` closes the file even if pickling fails; the Python 2
            # ``file()`` builtin no longer exists.
            with open(dialog.path, "wb") as f:
                pickle.dump(self.equalizers, f)

    def _load_button_fired(self):
        dialog = FileDialog(action="open", wildcard='EQ files (*.eq)|*.eq')
        result = dialog.open()
        if result == OK:
            # SECURITY: pickle.load can execute arbitrary code stored in the
            # file; only load .eq files from trusted sources.
            with open(dialog.path, "rb") as f:
                self.equalizers = pickle.load(f)

    def _export_button_fired(self):
        dialog = FileDialog(action="save as", wildcard='c files (*.c)|*.c')
        result = dialog.open()
        if result == OK:
            self.equalizers.export(dialog.path)
Example #42
0
class Plot3D(HasTraits):
    """Three orthogonal image slices (xy, yz, xz) of a 3-D dose cube.

    The slices are laid out in a 2x2 grid; line inspectors on each image
    choose the slice positions, which update the other two views.
    """

    plot = Instance(Component)
    name = 'Scan Plot'
    id = 'radpy.plugins.BeamAnalysis.ChacoPlot'
    current_dose = Float()

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=(400, 300)),
                                  show_label=False),
                             Item('current_dose'),
                             id='radpy.plugins.BeamAnalysis.ChacoPlotItems'),
                       resizable=True,
                       title="Scan Plot",
                       width=400,
                       height=300,
                       id='radpy.plugins.BeamAnalysis.ChacoPlotView')
    plot_type = String

    # Indices into the cube that each image plot view shows. They start at 0
    # and are updated by _index_callback. (Plain class attributes, not traits.)
    slice_x = 0
    slice_y = 0
    slice_z = 0

    num_levels = Int(15)
    colormap = Any
    colorcube = Any

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _cmap = Trait(jet, Callable)

    def _image_plots(self):
        """Return the xy, yz and xz image renderers created by add_plot()."""
        return (self.center, self.right, self.bottom)

    def _update_indices(self, token, axis, index_x):
        """Forward a new slice index to every MyLineInspector overlay."""
        for imgplot in self._image_plots():
            for overlay in imgplot.overlays:
                if isinstance(overlay, MyLineInspector):
                    overlay.update_index(token, axis, index_x)

    def _update_positions(self, token, value):
        """Relabel the text-box overlays whose token matches ``token``."""
        for imgplot in self._image_plots():
            for overlay in imgplot.overlays:
                if isinstance(overlay, TextBoxOverlay) and overlay.token == token:
                    # Label with the axis letter missing from the token,
                    # e.g. "yz" -> "x = ...".
                    overlay.text = set.difference(set("xyz"),
                                                  set(token)).pop() + ' = %.2f' % value

    def _index_callback(self, tool, axis, index, value):
        """Line-inspector callback: move a slice and sync the other views."""
        plane = tool.token
        if plane == "xy":
            if axis == "index" or axis == "index_x":
                self.slice_x = index
                self._update_indices("xz", "index_x", index)
                self._update_positions("yz", value)
            else:
                self.slice_y = index
                self._update_indices("yz", "index_y", index)
                self._update_positions("xz", value)
        elif plane == "yz":
            if axis == "index" or axis == "index_x":
                self.slice_z = index
                self._update_indices("xz", "index_y", index)
                self._update_positions("xy", value)
            else:
                self.slice_y = index
                self._update_indices("xy", "index_y", index)
                self._update_positions("xz", value)
        elif plane == "xz":
            if axis == "index" or axis == "index_x":
                self.slice_x = index
                self._update_indices("xy", "index_x", index)
                self._update_positions("yz", value)
            else:
                self.slice_z = index
                self._update_indices("yz", "index_x", index)
                self._update_positions("xy", value)
        else:
            warnings.warn("Unrecognized plane for _index_callback: %s" % plane)

        self._update_images()
        self.center.invalidate_and_redraw()
        self.right.invalidate_and_redraw()
        self.bottom.invalidate_and_redraw()
        return

    def _plot_type_default(self):
        return '3D_dose'

    def _plot_default(self):
        return self._create_plot_component()

    def _add_plot_tools(self, imgplot, token):
        """ Add a ZoomTool, two MyLineInspectors and a MyTextBoxOverlay to an image plot. """

        imgplot.overlays.append(
            ZoomTool(component=imgplot,
                     tool_mode="box",
                     enable_wheel=False,
                     always_on=False))
        imgplot.overlays.append(
            MyLineInspector(imgplot,
                            axis="index_y",
                            color="grey",
                            inspect_mode="indexed",
                            callback=self._index_callback,
                            token=token))

        imgplot.overlays.append(
            MyLineInspector(imgplot,
                            axis="index_x",
                            color="grey",
                            inspect_mode="indexed",
                            callback=self._index_callback,
                            token=token))
        imgplot.overlays.append(
            MyTextBoxOverlay(imgplot,
                             token=token,
                             align='lr',
                             bgcolor='white',
                             font='Arial 12'))

    def _create_plot_component(self):
        """Create the empty 2x2 grid container that add_plot() fills."""
        container = GridPlotContainer(padding=30,
                                      fill_padding=True,
                                      bgcolor="white",
                                      use_backbuffer=True,
                                      shape=(2, 2),
                                      spacing=(30, 30))
        self.container = container
        return container

    def add_plot(self, label, beam):
        """Add xy, yz and xz slice plots of ``beam.Data`` to the grid.

        Relies on ``self.container``, which ``_create_plot_component`` (the
        ``plot`` trait default) creates. Returns ``label`` unchanged.
        """
        self.plotdata = ArrayPlotData()
        self.model = beam.Data
        # NOTE(review): this rebinds beam.Data.z_axis to its reverse, which the
        # caller's object also sees; calling add_plot twice with the same beam
        # flips it back -- confirm intended.
        self.model.z_axis = self.model.z_axis[::-1]
        cmap = jet
        self._update_model(cmap)
        self.plotdata.set_data("xy", self.model.dose)
        self._update_images()

        # Center Plot
        centerplot = Plot(self.plotdata,
                          resizable='hv',
                          height=150,
                          width=150,
                          padding=0)
        centerplot.default_origin = 'top left'
        imgplot = centerplot.img_plot("xy",
                                      xbounds=(self.model.x_axis[0],
                                               self.model.x_axis[-1]),
                                      ybounds=(self.model.y_axis[0],
                                               self.model.y_axis[-1]),
                                      colormap=cmap)[0]

        imgplot.origin = 'top left'
        self._add_plot_tools(imgplot, "xy")
        left_axis = PlotAxis(centerplot, orientation='left', title='y')
        bottom_axis = PlotAxis(centerplot,
                               orientation='bottom',
                               title='x',
                               title_spacing=30)
        centerplot.underlays.append(left_axis)
        centerplot.underlays.append(bottom_axis)
        self.center = imgplot

        # Right Plot (shares the center plot's value range)
        rightplot = Plot(self.plotdata,
                         height=150,
                         width=150,
                         resizable="hv",
                         padding=0)
        rightplot.default_origin = 'top left'
        rightplot.value_range = centerplot.value_range
        imgplot = rightplot.img_plot("yz",
                                     xbounds=(self.model.z_axis[0],
                                              self.model.z_axis[-1]),
                                     ybounds=(self.model.y_axis[0],
                                              self.model.y_axis[-1]),
                                     colormap=cmap)[0]
        imgplot.origin = 'top left'
        self._add_plot_tools(imgplot, "yz")
        left_axis = PlotAxis(rightplot, orientation='left', title='y')
        bottom_axis = PlotAxis(rightplot,
                               orientation='bottom',
                               title='z',
                               title_spacing=30)
        rightplot.underlays.append(left_axis)
        rightplot.underlays.append(bottom_axis)
        self.right = imgplot

        # Bottom Plot (shares the center plot's index range)
        bottomplot = Plot(self.plotdata,
                          height=150,
                          width=150,
                          resizable="hv",
                          padding=0)
        bottomplot.index_range = centerplot.index_range
        imgplot = bottomplot.img_plot("xz",
                                      xbounds=(self.model.x_axis[0],
                                               self.model.x_axis[-1]),
                                      ybounds=(self.model.z_axis[0],
                                               self.model.z_axis[-1]),
                                      colormap=cmap)[0]
        self._add_plot_tools(imgplot, "xz")
        left_axis = PlotAxis(bottomplot, orientation='left', title='z')
        bottom_axis = PlotAxis(bottomplot,
                               orientation='bottom',
                               title='x',
                               title_spacing=30)
        bottomplot.underlays.append(left_axis)
        bottomplot.underlays.append(bottom_axis)
        self.bottom = imgplot

        # Add all plots to the grid container
        self.container.add(centerplot)
        self.container.add(rightplot)
        self.container.add(bottomplot)

        return label

    def _update_images(self):
        """ Updates the image data in self.plotdata to correspond to the
        slices given.
        """
        cube = self.colorcube
        pd = self.plotdata
        # These are transposed because img_plot() expects its data to be in
        # row-major order

        pd.set_data("xy", numpy.transpose(cube[:, :, self.slice_z], (1, 0, 2)))
        pd.set_data("xz", numpy.transpose(cube[:, self.slice_y, :], (1, 0, 2)))
        pd.set_data("yz", cube[self.slice_x, :, :])
        self.current_dose = self.model.dose[self.slice_x][self.slice_y][
            self.slice_z]
        return

    def _update_model(self, cmap):
        """Build the colormap over the dose range and the 8-bit colour cube."""
        dose_range = DataRange1D(low=numpy.amin(self.model.dose),
                                 high=numpy.amax(self.model.dose))
        self.colormap = cmap(dose_range)
        self.colorcube = (self.colormap.map_screen(self.model.dose) *
                          255).astype(numpy.uint8)
Example #43
0
class LFapplication(HasTraits):
    """Light-field image viewer that re-samples the image with a CUDA kernel.

    ``X_angle`` and ``Y_angle`` are added as traits in ``__init__`` (their
    range depends on the optics file); changing either re-runs the kernel and
    refreshes the ``LF_img`` plot.
    """

    traits_view = View(Item('LF_img',
                            editor=ComponentEditor(),
                            show_label=False),
                       Item('X_angle', label='Angle in the X axis'),
                       Item('Y_angle', label='Angle in the Y axis'),
                       resizable=True,
                       title="LF Image")

    def __init__(self, img_path):
        """Load the image and its calibration files and build the GPU pipeline.

        Parameters
        ----------
        img_path : str
            Path to the light-field image. Two sidecar files sharing its base
            name are also read: ``<base>-lenslet.txt`` (one line holding six
            numbers: x/y offset, right dx/dy, down dx/dy) and
            ``<base>-optics.txt`` (``<name> <value>`` lines, which must at
            least define ``pitch`` and ``flen``).
        """
        super(LFapplication, self).__init__()

        #
        # Load image data
        #
        base_path = os.path.splitext(img_path)[0]
        lenslet_path = base_path + '-lenslet.txt'
        optics_path = base_path + '-optics.txt'

        with open(lenslet_path, 'r') as f:
            # NOTE(review): eval() executes arbitrary code from the file. Only
            # open calibration files from trusted sources; ast.literal_eval
            # would be safer if the line is always a plain literal.
            tmp = eval(f.readline())
            x_offset, y_offset, right_dx, right_dy, down_dx, down_dy = \
                np.array(tmp, dtype=np.float32)

        with open(optics_path, 'r') as f:
            for line in f:
                fields = line.split()
                # Skip blank or malformed lines instead of aborting the load.
                if len(fields) != 2:
                    continue
                name, val = fields
                try:
                    value = np.float32(val)
                except ValueError:
                    # Non-numeric entries are ignored.
                    continue
                setattr(self, name, value)

        max_angle = math.atan(self.pitch / 2 / self.flen)

        #
        # Prepare image
        #
        im_pil = Image.open(img_path)
        if im_pil.mode == 'RGB':
            self.NCHANNELS = 3
            w, h = im_pil.size
            # Pad to four float channels (the 4th stays zero); this padded
            # array is what gets bound as the multichannel texture below.
            im = np.zeros((h, w, 4), dtype=np.float32)
            im[:, :, :3] = np.array(im_pil).astype(np.float32)
            # int(): ceil() may return a float (e.g. Python 2 math.ceil) and
            # LF_dim is used as an array shape and kernel size.
            self.LF_dim = (int(ceil(h / down_dy)), int(ceil(w / right_dx)), 3)
        else:
            self.NCHANNELS = 1
            im = np.array(im_pil.getdata()).reshape(im_pil.size[::-1]).astype(
                np.float32)
            h, w = im.shape
            self.LF_dim = (int(ceil(h / down_dy)), int(ceil(w / right_dx)))

        x_start = x_offset - int(x_offset / right_dx) * right_dx
        y_start = y_offset - int(y_offset / down_dy) * down_dy
        x_ratio = self.flen * right_dx / self.pitch
        y_ratio = self.flen * down_dy / self.pitch

        #
        # Generate the cuda kernel
        #
        mod_LFview = pycuda.compiler.SourceModule(
            _kernel_tpl.render(newiw=self.LF_dim[1],
                               newih=self.LF_dim[0],
                               oldiw=w,
                               oldih=h,
                               x_start=x_start,
                               y_start=y_start,
                               x_ratio=x_ratio,
                               y_ratio=y_ratio,
                               x_step=right_dx,
                               y_step=down_dy,
                               NCHANNELS=self.NCHANNELS))

        self.LFview_func = mod_LFview.get_function("LFview_kernel")
        self.texref = mod_LFview.get_texref("tex")

        #
        # Now generate the cuda texture
        #
        if self.NCHANNELS == 3:
            cuda.bind_array_to_texref(
                cuda.make_multichannel_2d_array(im, order="C"), self.texref)
        else:
            cuda.matrix_to_texref(im, self.texref, order="C")

        #
        # We could set the next if we wanted to address the image
        # in normalized coordinates ( 0 <= coordinate < 1.)
        # texref.set_flags(cuda.TRSF_NORMALIZED_COORDINATES)
        #
        self.texref.set_filter_mode(cuda.filter_mode.LINEAR)

        #
        # Prepare the traits
        #
        self.add_trait('X_angle', Range(-max_angle, max_angle, 0.0))
        self.add_trait('Y_angle', Range(-max_angle, max_angle, 0.0))

        self.plotdata = ArrayPlotData(LF_img=self.sampleLF())
        self.LF_img = Plot(self.plotdata)
        if self.NCHANNELS == 3:
            self.LF_img.img_plot("LF_img")
        else:
            self.LF_img.img_plot("LF_img", colormap=gray)

    def sampleLF(self):
        """Run the CUDA kernel for the current (X_angle, Y_angle).

        Returns
        -------
        numpy.ndarray
            uint8 array of shape ``self.LF_dim`` filled by the kernel.
        """
        output = np.zeros(self.LF_dim, dtype=np.uint8)

        #
        # The grid is derived from the output size. Integer ceil-division
        # covers partial tiles at the edges and keeps the grid dimensions
        # integral regardless of the Python version's division semantics.
        #
        blocks = (16, 16, 1)
        gridx = (self.LF_dim[1] + blocks[1] - 1) // blocks[1]
        gridy = (self.LF_dim[0] + blocks[0] - 1) // blocks[0]
        grid = (gridx, gridy)

        #
        # Call the kernel
        #
        self.LFview_func(np.float32(self.X_angle),
                         np.float32(self.Y_angle),
                         cuda.Out(output),
                         texrefs=[self.texref],
                         block=blocks,
                         grid=grid)

        return output

    @on_trait_change('X_angle, Y_angle')
    def updateImge(self):
        """Re-render the light field whenever a viewing angle changes."""
        self.plotdata.set_data('LF_img', self.sampleLF())
class GenerateProjectorCalibration(HasTraits):
    """Draw a projector viewport polygon and send it to the display server.

    Points added with the ``LineSegmentTool`` define a polygon that is
    rasterised into an RGB test image (with red/blue reference stripes). The
    image is blitted through ``blit_compressed_image_proxy``. Once at least
    three points exist, ``update_ROS_params`` (defined elsewhere) is called.
    """
    # NOTE(review): width/height are read throughout (e.g. _update_image) but
    # are not declared as traits here; presumably they are passed as keyword
    # arguments to the constructor -- confirm with callers.
    #width = traits.Int
    #height = traits.Int
    display_id = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String('viewport_0')
    display_mode = traits.Trait('white on black', 'black on white')
    client = traits.Any
    blit_compressed_image_proxy = traits.Any

    set_display_server_mode_proxy = traits.Any

    traits_view = View(
        Group(
            Item('display_mode'),
            Item('viewport_id'),
            Item('plot', editor=ComponentEditor(), show_label=False),
            orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        """Connect to the virtual display's parameters and load its viewport.

        Requires the keyword argument ``display_coords_filename``; all other
        arguments are forwarded to ``HasTraits``.
        """
        display_coords_filename = kwargs.pop('display_coords_filename')
        super(GenerateProjectorCalibration, self).__init__(*args, **kwargs)

        # NOTE(review): the unpickled data is never used below, and pickle can
        # execute arbitrary code -- only load trusted files. The load is kept
        # so that a missing or corrupt file still fails here. Binary mode is
        # required for pickle on Python 3 and for binary protocols.
        with open(display_coords_filename, 'rb') as fd:
            pickle.load(fd)

        self.param_name = 'virtual_display_config_json_string'
        self.fqdn = '/virtual_displays/' + self.display_id + '/' + self.viewport_id
        self.fqpn = self.fqdn + '/' + self.param_name
        self.client = dynamic_reconfigure.client.Client(self.fqdn)

        # First render: sets self._image and, through self.linedraw, creates
        # the plot and the line tool.
        self._update_image()

        virtual_display_json_str = rospy.get_param(self.fqpn)
        this_virtual_display = json.loads(virtual_display_json_str)
        viewport = this_virtual_display['viewport']

        # Only adopt the stored viewport if no vertex lies outside the image.
        out_of_bounds = any((x >= self.width) or (y >= self.height)
                            for (x, y) in viewport)
        if not out_of_bounds:
            self.linedraw.points = viewport
        self._update_image()

    def _update_image(self):
        """Rebuild the RGB test image and push it to the plot and the display.

        The polygon (when it has at least 3 points) is filled white; red
        horizontal and blue vertical 10-pixel stripes are drawn every 100
        pixels.
        """
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        # draw polygon
        if len(self.linedraw.points) >= 3:
            # posint presumably clamps to a non-negative int <= the given max.
            pts = [(posint(y, self.height - 1), posint(x, self.width - 1))
                   for (x, y) in self.linedraw.points]
            mahotas.polygon.fill_polygon(pts, self._image[:, :, 0])
            self._image[:, :, 0] *= 255
            self._image[:, :, 1] = self._image[:, :, 0]
            self._image[:, :, 2] = self._image[:, :, 0]

        # draw red horizontal stripes
        for i in range(0, self.height, 100):
            self._image[i:i + 10, :, 0] = 255

        # draw blue vertical stripes
        for i in range(0, self.width, 100):
            self._image[:, i:i + 10, 2] = 255

        # The plot data only exists once _plot_default has run.
        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        """Create the image plot (origin top-left) with pan and zoom tools."""
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        plot.img_plot("imagedata")

        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)

        return plot

    def _linedraw_default(self):
        """Create the point-drawing overlay and re-render on every point change."""
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        """Handler for changes to the drawn points."""
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Encode ``self._image`` as PNG and blit it to the display server."""
        # NOTE(review): bgcolor/color are derived from display_mode but never
        # applied to the image that is sent, so display_mode currently has no
        # visible effect.
        if self.display_mode.endswith(' on black'):
            bgcolor = (0, 0, 0, 1)
        elif self.display_mode.endswith(' on white'):
            bgcolor = (1, 1, 1, 1)

        if self.display_mode.startswith('black '):
            color = (0, 0, 0, 1)
        elif self.display_mode.startswith('white '):
            color = (1, 1, 1, 1)

        # mkstemp creates the file atomically (mktemp is racy/insecure); the
        # descriptor is closed because imsave reopens the file by name.
        fd, fname = tempfile.mkstemp('.png')
        os.close(fd)
        try:
            scipy.misc.imsave(fname, self._image)
            image = freemovr_engine.msg.FreemoVRCompressedImage()
            image.format = 'png'
            # PNG is binary: read in 'rb' mode and close the handle promptly.
            with open(fname, 'rb') as png_file:
                image.data = png_file.read()
            self.blit_compressed_image_proxy(image)
        finally:
            os.unlink(fname)

    def get_viewport_verts(self):
        """Return the drawn points as ``[[x, y], ...]`` integer lists."""
        # convert to integers
        pts = [(posint(x, self.width - 1), posint(y, self.height - 1))
               for (x, y) in self.linedraw.points]
        # convert to list of lists for maximal json compatibility
        return [list(x) for x in pts]
Example #45
0
class ConnectionMatrixViewer(HasTraits):
    """Heat-map viewer for a dictionary of named NxN connection matrices.

    The ``data_name`` selector switches the displayed matrix. Positions
    reported by ``CustomTool`` (``xval``/``yval``) are mapped to a cell, and
    its row label (``fro``), column label (``to``) and value (``val``) are
    shown.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    fro = Any
    to = Any
    data = None
    val = Float
    nodelabels = Any

    traits_view = View(
        Group(Item('plot',
                   editor=ComponentEditor(size=(800, 600)),
                   show_label=False),
              HGroup(
                  Item('fro', label="From", style='readonly', springy=True),
                  Item('to', label="To", style='readonly', springy=True),
                  Item('val', label="Value", style='readonly', springy=True),
              ),
              orientation="vertical"),
        Item('data_name', label="Edge key"),
        # handler=CustomHandler(),
        resizable=True,
        title="Connection Matrix Viewer")

    def __init__(self, nodelabels, matdict, **traits):
        """ Starts a matrix inspector

        Parameters
        ----------
        nodelabels : list
            List of strings of labels for the rows of the matrix (also used
            for the columns)
        matdict : dictionary
            Keys are the edge type and values are NxN Numpy arrays """
        # Anchor super() at this class so every base __init__ in the MRO,
        # including HasTraits', is run.
        super(ConnectionMatrixViewer, self).__init__(**traits)

        # list(): on Python 3 keys() is a view that cannot be indexed.
        edge_keys = list(matdict.keys())
        self.add_trait('data_name', Enum(edge_keys))

        self.data_name = edge_keys[0]
        self.data = matdict
        # Populate the declared ``nodelabels`` trait; the misspelt
        # ``nodelables`` attribute is kept as an alias for external code.
        self.nodelabels = nodelabels
        self.nodelables = nodelabels
        self.plot = self._create_plot_component()

        # set trait notification on customtool
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

    def _data_name_changed(self, old, new):
        """Swap in the newly selected matrix and update the title."""
        self.pd.set_data("imagedata", self.data[self.data_name])
        #self.my_plot.set_value_selection((0, 2))
        self.tplot.title = "Connection Matrix for %s" % self.data_name

    def _update_fields(self):
        """Show labels and value of the cell under the tool's position."""
        # map mouse location to array index (image bounds start at 0.5, so
        # cell k, 1-based, is centred on coordinate k)
        frotmp = int(round(self.custtool.yval) - 1)
        totmp = int(round(self.custtool.xval) - 1)

        # check if within range
        sh = self.data[self.data_name].shape
        # assume matrix whose shape is (# of rows, # of columns)
        if 0 <= frotmp < sh[0] and 0 <= totmp < sh[1]:
            row = " (index: %i" % (frotmp + 1) + ")"
            col = " (index: %i" % (totmp + 1) + ")"
            self.fro = " " + str(self.nodelabels[frotmp]) + row
            self.to = " " + str(self.nodelabels[totmp]) + col
            self.val = self.data[self.data_name][frotmp, totmp]

    def _create_plot_component(self):
        """Build the image plot, tools and colorbar; return their container."""

        # Create a plot data object and give it this data
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.data[self.data_name])

        # find dimensions
        xdim = self.data[self.data_name].shape[1]
        ydim = self.data[self.data_name].shape[0]

        # Create the plot; bounds are offset by 0.5 so cell k sits at k.
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot("imagedata",
                            name="my_plot",
                            xbounds=(0.5, xdim + 0.5),
                            ybounds=(0.5, ydim + 0.5),
                            colormap=jet)

        # Tweak some of the plot properties
        self.tplot.title = "Connection Matrix for %s" % self.data_name
        self.tplot.padding = 80

        # Right now, some of the tools are a little invasive, and we need the
        # actual CMapImage object to give to them
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # my custom tool to get the connection information
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colorbar, handing in the appropriate range and colormap
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # create a range selection for the colorbar
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # we also want to the range selection to inform the cmap plot of
        # the selection, so set that up as well
        self.range_selection.listeners.append(self.my_plot)

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container
Example #46
0
 def _pd_default(self):
     """Default plot data: a blank 300x400 image passed through toRGB."""
     blank = zeros((300, 400))
     return ArrayPlotData(imagedata=toRGB(blank))
Example #47
0
class ImagePlot(HasTraits):
    """Colormapped image of z over (x, y) with a linked vertical colorbar.

    ``_create_window`` builds everything once; ``update_plot`` replaces the
    image data afterwards.
    """

    def _create_window(self, title, xtitle, x, ytitle, y, z):
        '''
        Build the image plot, colorbar and container and return a Window.

        ``title`` is used both as the plot title and as the renderer name in
        ``self.plot.plots``. ``x`` and ``y`` are 1-D coordinate arrays and
        ``z`` is the image data passed to img_plot().

        - Left-drag pans the plot.
        - Mousewheel up and down zooms the plot in and out.
        - Pressing "z" brings up the Zoom Box, and you can click-drag a rectangular
        region to zoom.  If you use a sequence of zoom boxes, pressing alt-left-arrow
        and alt-right-arrow moves you forwards and backwards through the "zoom
        history".
        '''
        self._plotname = title
        # Create window
        self.data = ArrayPlotData()
        self.plot = Plot(self.data)
        self.update_plot(x, y, z)
        self.plot.title = title
        self.plot.x_axis.title = xtitle
        self.plot.y_axis.title = ytitle

        cmap_renderer = self.plot.plots[self._plotname][0]

        # Create colorbar
        self._create_colorbar()
        self._colorbar.plot = cmap_renderer
        self._colorbar.padding_top = self.plot.padding_top
        self._colorbar.padding_bottom = self.plot.padding_bottom

        # Add some tools
        self.plot.tools.append(PanTool(self.plot, constrain_key="shift"))
        self.plot.overlays.append(
            ZoomTool(component=self.plot, tool_mode="box", always_on=False))

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.plot)
        container.add(self._colorbar)
        self.container = container

        # Return a window containing our plot container
        return Window(self, -1, component=container)

    def update_plot(self, x, y, z):
        """Replace the data and re-create the image renderer.

        The image bounds are widened by half a sample spacing on each side,
        so every sample is drawn centred on its pixel.
        """
        self.data.set_data('x', x)
        self.data.set_data('y', y)
        self.data.set_data('z', z)

        # ``in`` instead of dict.has_key(), which no longer exists on Python 3.
        if self._plotname in self.plot.plots:
            self.plot.delplot(self._plotname)
        # NOTE(review): a colorbar created by _create_window keeps pointing at
        # the renderer deleted above; it is not rebound here.

        x0, x1 = self._pixel_bounds(x)
        y0, y1 = self._pixel_bounds(y)

        self.plot.img_plot('z',
                           name=self._plotname,
                           xbounds=(x0, x1),
                           ybounds=(y0, y1),
                           colormap=jet)

    @staticmethod
    def _pixel_bounds(coords):
        """Return (low, high) edges that centre each sample in its pixel.

        Computed in float so integer coordinates are not floor-divided on
        Python 2. A single sample, or identical samples, has no spacing to
        infer from and gets a unit-wide pixel instead of dividing by zero.
        """
        low = float(coords.min())
        high = float(coords.max())
        count = len(coords)
        if count > 1 and high > low:
            step = (high - low) / (count - 1)
        else:
            step = 1.0
        return low - step / 2.0, high + step / 2.0

    def _create_colorbar(self):
        """Create ``self._colorbar`` from the image renderer's colour mapper."""
        # Use the mapper installed on the renderer by img_plot(colormap=jet).
        # Plot.color_mapper is not guaranteed to be that mapper, or to be set.
        renderer = self.plot.plots[self._plotname][0]
        cmap = renderer.color_mapper
        self._colorbar = ColorBar(index_mapper=LinearMapper(range=cmap.range),
                                  color_mapper=cmap,
                                  orientation='v',
                                  resizable='v',
                                  width=30,
                                  padding=30)
Example #48
0
 def _pd_default(self):
     """Default plot data: an empty ``imagedata`` array."""
     empty = numpy.array([])
     return ArrayPlotData(imagedata=empty)
Example #49
0
class TemplatePicker(HasTraits):
    template = Array
    CC = Array
    peaks = List
    zero = Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x = Int(1023)
    max_pos_y = Int(1023)
    top = Range(low='zero', high='max_pos_x', value=20, cols=4)
    left = Range(low='zero', high='max_pos_y', value=20, cols=4)
    is_square = Bool
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    findpeaks = Button
    peak_width = Range(low=2, high=200, value=10)
    tab_selected = Event
    ShowCC = Bool
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar = Instance(Component)
    numpeaks_total = Int(0)
    numpeaks_img = Int(0)
    OK_custom = OK_custom_handler
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(0.0)
    numfiles = Int(1)
    img_idx = Int(0)
    tmp_img_idx = Int(0)

    csr = Instance(BaseCursorTool)

    traits_view = View(HFlow(
        VGroup(Item("img_container",
                    editor=ComponentEditor(),
                    show_label=False),
               Group(
                   Spring(),
                   Item("ShowCC",
                        editor=BooleanEditor(),
                        label="Show cross correlation image")),
               label="Original image",
               show_border=True,
               trait_modified="tab_selected"),
        VGroup(
            Group(HGroup(
                Item("left", label="Left coordinate", style="custom"),
                Item("top", label="Top coordinate", style="custom"),
            ),
                  Item("tmp_size", label="Template size", style="custom"),
                  Item("tmp_plot",
                       editor=ComponentEditor(height=256, width=256),
                       show_label=False,
                       resizable=True),
                  label="Template",
                  show_border=True),
            Group(Item("peak_width", label="Peak width", style="custom"),
                  Group(
                      Spring(),
                      Item("findpeaks",
                           editor=ButtonEditor(label="Find Peaks"),
                           show_label=False),
                      Spring(),
                  ),
                  HGroup(
                      Item("thresh_lower",
                           label="Threshold Lower Value",
                           editor=TextEditor(evaluate=float,
                                             format_str='%1.4f')),
                      Item("thresh_upper",
                           label="Threshold Upper Value",
                           editor=TextEditor(evaluate=float,
                                             format_str='%1.4f')),
                  ),
                  HGroup(
                      Item("numpeaks_img",
                           label="Number of Cells selected (this image)",
                           style='readonly'),
                      Spring(),
                      Item("numpeaks_total", label="Total", style='readonly'),
                      Spring(),
                  ),
                  label="Peak parameters",
                  show_border=True),
        )),
                       buttons=[
                           Action(name='OK',
                                  enabled_when='numpeaks_total > 0'),
                           CancelButton
                       ],
                       title="Template Picker",
                       handler=OK_custom,
                       kind='livemodal',
                       key_bindings=key_bindings,
                       width=960,
                       height=600)

    def __init__(self, signal_instance):
        """Set up the template and image plots for ``signal_instance``.

        ``signal_instance.data`` is sliced as (row, column, image index).
        Its ``mapped_parameters`` supply either ``original_files`` (one title
        per image) or a single ``name``. If OpenCV cannot be imported, a
        message is printed and initialisation stops early.
        """
        super(TemplatePicker, self).__init__()
        try:
            import cv  # noqa: F401 -- only probing that OpenCV is available
        except ImportError:
            print("OpenCV unavailable.  Can't do cross correlation without it.  Aborting.")
            return None
        self.OK_custom = OK_custom_handler()
        self.sig = signal_instance
        params = self.sig.mapped_parameters
        if not hasattr(params, "original_files"):
            self.sig.data = np.atleast_3d(self.sig.data)
            self.titles = [params.name]
        else:
            # list() keeps titles indexable on Python 3 (dict views are not).
            self.titles = list(params.original_files.keys())
            self.numfiles = len(self.titles)
        tmp_plot_data = ArrayPlotData(
            imagedata=self.sig.data[self.top:self.top + self.tmp_size,
                                    self.left:self.left + self.tmp_size,
                                    self.img_idx])
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.img_plotdata = ArrayPlotData(
            imagedata=self.sig.data[:, :, self.img_idx])
        self.img_container = self._image_plot_container()

        self.crop_sig = None

    def render_image(self):
        """Create the grey-scale image Plot with a cursor at the template corner.

        Also attaches right-drag panning and a box zoom that keeps the image
        aspect ratio; the plot is stored in ``self.img_plot`` and returned.
        """
        image_plot = Plot(self.img_plotdata, default_origin="top left")
        renderer = image_plot.img_plot("imagedata", colormap=gray)[0]
        heading = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        image_plot.title = heading + self.titles[self.img_idx]
        shape = self.sig.data.shape
        image_plot.aspect_ratio = float(shape[1]) / float(shape[0])

        cursor = CursorTool(renderer,
                            drag_button='left',
                            color='white',
                            line_width=2.0)
        self.csr = cursor
        cursor.current_position = self.left, self.top
        renderer.overlays.append(cursor)

        # panning on the right button, box zoom on demand
        image_plot.tools.append(PanTool(image_plot, drag_button="right"))
        box_zoom = ZoomTool(image_plot,
                            tool_mode="box",
                            always_on=False,
                            aspect_ratio=image_plot.aspect_ratio)
        image_plot.overlays.append(box_zoom)
        self.img_plot = image_plot
        return image_plot

    def render_scatplot(self):
        """Overlay plot of the current image's peaks, coloured by column 2.

        Shares ``range2d`` with the image plot so the two stay aligned; the
        plot and its data are stored on ``self.scatplot``/``self.peakdata``.
        """
        current = self.peaks[self.img_idx]
        peakdata = ArrayPlotData()
        for column, key in enumerate(("index", "value", "color")):
            peakdata.set_data(key, current[:, column])
        overlay_plot = Plot(peakdata,
                            aspect_ratio=self.img_plot.aspect_ratio,
                            default_origin="top left")
        overlay_plot.plot(("index", "value", "color"),
                          type="cmap_scatter",
                          name="my_plot",
                          color_mapper=jet(DataRange1D(low=0.0, high=1.0)),
                          marker="circle",
                          fill_alpha=0.5,
                          marker_size=6)
        for grid in (overlay_plot.x_grid, overlay_plot.y_grid):
            grid.visible = False
        overlay_plot.range2d = self.img_plot.range2d
        self.scatplot = overlay_plot
        self.peakdata = peakdata
        return overlay_plot

    def _image_plot_container(self):
        """Assemble the image plot (plus peaks and colorbar when present).

        The image and optional scatter overlay share ``self.container``; this
        sits inside ``self.img_container`` next to the optional colorbar.
        """
        image_plot = self.render_image()

        # image + optional peak overlay are stacked in one container...
        self.container = OverlayPlotContainer()
        self.container.add(image_plot)
        # ...which sits beside the colorbar in a horizontal container
        self.img_container = HPlotContainer(use_backbuffer=False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"

        if self.numpeaks_img <= 0:
            return self.img_container
        self.container.add(self.render_scatplot())
        self.img_container.add(self.draw_colorbar())
        return self.img_container

    def draw_colorbar(self):
        """Build a colorbar (0..1) with a range selection for peak thresholds.

        The scatter renderer gets a ``ColormappedSelectionOverlay`` that fades
        points outside the selected range; a previously chosen ``self.thresh``
        is re-applied. Stores ``cbar_selection``, ``cmap_renderer`` and
        ``colorbar`` on ``self`` and returns the colorbar.
        """
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer,
                                                fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {
                'selections': self.thresh
            }
        # Create the colorbar over the fixed 0..1 range used by the scatter
        # plot's colour mapper
        colorbar = ColorBar(
            index_mapper=LinearMapper(range=DataRange1D(low=0.0, high=1.0)),
            orientation='v',
            resizable='v',
            width=30,
            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        # list.append returns None, so the overlay is not kept in a local
        colorbar.overlays.append(
            RangeSelectionOverlay(component=colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray",
                                  metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    def _show_current_image(self):
        """Display the current image, or its cross-correlation when ShowCC."""
        frame = self.sig.data[:, :, self.img_idx]
        if self.ShowCC:
            template = self.sig.data[self.top:self.top + self.tmp_size,
                                     self.left:self.left + self.tmp_size,
                                     self.img_idx]
            self.CC = cv_funcs.xcorr(template, frame)
            self.img_plotdata.set_data("imagedata", self.CC)
        else:
            self.img_plotdata.set_data("imagedata", frame)

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the main view between raw image and cross-correlation."""
        self._show_current_image()
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """Show the newly selected image and update the plot title."""
        self._show_current_image()
        self.img_plot.title = "%s of %s: " % (
            self.img_idx + 1, self.numfiles) + self.titles[self.img_idx]
        self.redraw_plots()

    @on_trait_change('tmp_size')
    def update_max_pos(self):
        """Recompute the allowed template-corner range after a size change.

        ``top`` slices axis 0 of ``sig.data`` and is bounded by ``max_pos_x``;
        ``left`` slices axis 1 and is bounded by ``max_pos_y`` (see their
        Range declarations). Each coordinate is clamped against its own bound
        before the bound itself is updated. Previously ``left`` and ``top``
        were clamped against each other's bound, which was wrong for
        non-square images.
        """
        max_pos_x = self.sig.data.shape[0] - self.tmp_size - 1
        if self.top > max_pos_x:
            self.top = max_pos_x
        self.max_pos_x = max_pos_x
        max_pos_y = self.sig.data.shape[1] - self.tmp_size - 1
        if self.left > max_pos_y:
            self.left = max_pos_y
        self.max_pos_y = max_pos_y

    def increase_img_idx(self, info):
        """Step to the next image, wrapping from the last back to the first."""
        at_last = self.img_idx == (self.numfiles - 1)
        self.img_idx = 0 if at_last else self.img_idx + 1

    def decrease_img_idx(self, info):
        """Step to the previous image, wrapping from the first to the last."""
        at_first = self.img_idx == 0
        self.img_idx = self.numfiles - 1 if at_first else self.img_idx - 1

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to the current (left, top) template corner."""
        corner = (self.left, self.top)
        self.csr.current_position = corner

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """Copy a dragged cursor position back into ``left``/``top``."""
        position = self.csr.current_position
        self.left, self.top = position

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """Show the square region at (top, left) of the current image.

        Also resets the template plot's grid to 0..tmp_size-1 and records
        which image the template was cut from in ``tmp_img_idx``.
        """
        size = self.tmp_size
        rows = slice(self.top, self.top + size)
        cols = slice(self.left, self.left + size)
        self.tmp_plotdata.set_data("imagedata",
                                   self.sig.data[rows, cols, self.img_idx])
        grid = self.tmp_plot.range2d.sources[0]
        grid.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        """Recompute the cross-correlation when the template region moves.

        The template is cut from image ``tmp_img_idx`` and correlated with
        the current image. Any peaks found earlier are replaced with a single
        placeholder row, since they no longer match the template.
        """
        if self.ShowCC:
            data = self.sig.data
            template = data[self.top:self.top + self.tmp_size,
                            self.left:self.left + self.tmp_size,
                            self.tmp_img_idx]
            self.CC = cv_funcs.xcorr(template, data[:, :, self.img_idx])
            self.img_plotdata.set_data("imagedata", self.CC)
            grid = self.img_plot.range2d.sources[0]
            height, width = self.CC.shape[0], self.CC.shape[1]
            grid.set_data(np.arange(width), np.arange(height))
        if self.numpeaks_total > 0:
            self.peaks = [np.array([[0, 0, -1]])]

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """Apply the colorbar's range selection as the peak threshold.

        Copies the selection into ``thresh``, ``thresh_lower`` and
        ``thresh_upper``. Publishes it as the ``'selections'`` metadata of the
        colour-mapped renderer's colour data and redraws both containers.

        This is best-effort: if the selection cannot be used (for example it
        was presumably cleared to ``None``), the update stops silently. Only
        ordinary exceptions are swallowed, so ``KeyboardInterrupt`` and
        ``SystemExit`` still propagate (a bare ``except:`` hid them).
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {
                'selections': thresh
            }
            self.container.request_redraw()
            self.img_container.request_redraw()
        except Exception:
            # Deliberately best-effort: an unusable selection leaves the
            # previous threshold display in place.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresholds entered through ``thresh_lower``/``thresh_upper``.

        Stores them as ``thresh`` and publishes them as the ``'selections'``
        metadata of the colour-mapped renderer's colour data. Both containers
        are then redrawn.
        """
        self.thresh = [self.thresh_lower, self.thresh_upper]
        color_data = self.cmap_renderer.color_data
        color_data.metadata['selections'] = self.thresh
        color_data.metadata_changed = dict(selections=self.thresh)
        for target in (self.container, self.img_container):
            target.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count the peaks whose height lies inside the current threshold.

        Sets ``numpeaks_total`` (peaks over every entry of ``self.peaks``)
        and ``numpeaks_img`` (peaks of entry ``img_idx``; 0 when there is no
        such entry). The threshold is the colorbar selection. When there is
        no usable selection, the full ``(0, 1)`` range is used.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except Exception:
            # No selection tool yet, or the selection was rejected.
            thresh = None
        # ``thresh == []`` is ambiguous for numpy arrays and ``== None`` is
        # not an identity test; check for "no selection" explicitly.
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        lower, upper = thresh[0], thresh[1]

        def _count_inside(peak_array):
            # masked_inside masks the heights (column 2) that lie in
            # [lower, upper], so the mask flags the peaks that pass.
            return int(np.sum(np.ma.masked_inside(peak_array[:, 2],
                                                  lower, upper).mask))

        self.numpeaks_total = int(sum(_count_inside(p) for p in self.peaks))
        try:
            self.numpeaks_img = _count_inside(self.peaks[self.img_idx])
        except (IndexError, TypeError):
            # self.peaks can hold fewer entries than there are images (e.g.
            # the single placeholder row set when the template moves).
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Find correlation peaks of the current template in every image.

        For each of the ``numfiles`` images, the template is cut from image
        ``tmp_img_idx`` and cross-correlated with that image (stored in
        ``self.CC``). ``peak_char.two_dim_findpeaks`` is then run on the
        result. The per-image peak arrays become ``self.peaks``.

        NOTE(review): after the loop ``self.CC`` holds the map of the last
        image, not of ``img_idx`` — confirm this is intended.
        """
        from hyperspy import peak_char as pc
        found = []
        for image_index in range(self.numfiles):
            template = self.sig.data[self.top:self.top + self.tmp_size,
                                     self.left:self.left + self.tmp_size,
                                     self.tmp_img_idx]
            self.CC = cv_funcs.xcorr(template, self.sig.data[:, :, image_index])
            # The peak finder needs heights above 1, so the map is scaled by
            # 255 going in and column 2 (height) is scaled back afterwards.
            image_peaks = pc.two_dim_findpeaks(self.CC * 255,
                                               peak_width=self.peak_width,
                                               medfilt_radius=None)
            image_peaks[:, 2] = image_peaks[:, 2] / 255.
            found.append(image_peaks)
        self.peaks = found

    def mask_peaks(self, idx):
        """Return the peaks of entry ``idx`` with out-of-threshold heights masked.

        Parameters
        ----------
        idx : int
            Index into ``self.peaks``.

        Returns
        -------
        numpy.ma.MaskedArray
            Built from ``self.peaks[idx]``. Column 2 (peak height) is masked
            wherever it falls outside the colorbar selection, or outside
            ``(0, 1)`` when nothing is selected. ``np.ma.compress_rows``
            drops the masked peaks.
        """
        thresh = self.cbar_selection.selection
        # A cleared selection (None) or an empty tuple/list means "no
        # threshold"; previously only [] was recognised and None crashed.
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the correlation image, the peak overlay and the colorbar.

        The image plot is always re-rendered and swapped into ``container``.
        Any previous scatter overlay and colorbar are removed. New ones are
        added only when the current image has peaks inside the threshold.
        Both containers are redrawn at the end.
        """
        previous_image = self.img_plot
        self.container.remove(previous_image)
        fresh_image = self.render_image()
        self.container.add(fresh_image)
        self.img_plot = fresh_image

        # scatplot/colorbar do not exist until peaks were shown once, so the
        # removal is best-effort (kept exactly as broad as before).
        try:
            self.container.remove(self.scatplot)
            self.img_container.remove(self.colorbar)
        except BaseException:
            pass

        if self.numpeaks_img > 0:
            scatter = self.render_scatplot()
            self.container.add(scatter)
            self.scatplot = scatter
            bar = self.draw_colorbar()
            self.img_container.add(bar)
            self.colorbar = bar

        for target in (self.container, self.img_container):
            target.request_redraw()

    def crop_cells_stack(self):
        """Crop the thresholded peak locations of every image into ``crop_sig``.

        A single-image stack stores the result of ``crop_cells()`` directly.
        Larger stacks store an ``AggregateCells`` with one cropped signal
        per image.
        """
        from eelslab.signals.aggregate import AggregateCells
        if self.numfiles == 1:
            self.crop_sig = self.crop_cells()
        else:
            per_image = [self.crop_cells(image) for image in range(self.numfiles)]
            self.crop_sig = AggregateCells(*per_image)

    def crop_cells(self, idx=0):
        """Crop a ``tmp_size`` square at every thresholded peak of one image.

        Parameters
        ----------
        idx : int, optional
            Image index in the stack, also used to index ``self.peaks`` (via
            ``mask_peaks``) and ``self.titles``. Defaults to the first image.

        Returns
        -------
        hyperspy.signals.image.Image
            Data of shape ``(tmp_size, tmp_size, n_peaks)``. The crop
            locations are recorded in ``mapped_parameters.locations`` and the
            source signal in ``mapped_parameters.parent``.
        """
        print("cropping cells...")
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks = np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz = self.tmp_size
        data = np.zeros((tmp_sz, tmp_sz, peaks.shape[0]))
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            parent = self.sig
        else:
            parent = self.sig.mapped_parameters.original_files[
                self.titles[idx]]
        for i in range(peaks.shape[0]):
            # Column 1 is the window's top row and column 0 its left column.
            # The coordinates may be floats, and numpy rejects float slice
            # bounds, so cast them to int.
            # NOTE(review): a window running past the image edge gives a
            # smaller slice and this assignment raises ValueError.
            row = int(peaks[i, 1])
            col = int(peaks[i, 0])
            data[:, :, i] = self.sig.data[row:row + tmp_sz,
                                          col:col + tmp_sz, idx]
        # Build the result once, after all cells are cropped. Previously it was
        # rebuilt on every iteration and undefined when no peak passed.
        crop_sig = Image({
            'data': data,
            'mapped_parameters': {
                'name': 'Cropped cells from %s' % self.titles[idx],
                'record_by': 'image',
                'locations': peaks,
                'parent': parent,
            }
        })
        print("Complete.  ")
        return crop_sig
Example #50
0
class ConstructionEditor(pyface.SplitApplicationWindow):
    def __init__(self, **kwargs):
        """Create the editor; all keyword arguments go to the base window."""
        # Created before the base initializer runs, presumably so trait
        # handlers fired during construction can already use it.
        self._handle_by_joint = dict()
        super(ConstructionEditor, self).__init__(**kwargs)

#############################################################
# visualization related

    scene = Instance(Scene)  # vtk scene
    # Colour elements by simulation results and draw the deformed wireframe.
    show_simulation = Bool(True)
    # Per-element scalar used for colouring ('stress' is the Enum default).
    show_scalar = Enum('stress', 'strain')
    # Symmetric strain colour range used when auto_scalar_range is off.
    scalar_range_max = Float(50.)
    # When on, the strain colour range is fitted to the largest |strain|.
    auto_scalar_range = Bool(True)
    # Factor applied to joint displacements when drawing the deformed model.
    displacement_amplify_factor = Float(1000.)
    # Factor applied to material radii when sizing the element tubes.
    amplify_radius = Float(1.5)
    # NOTE(review): not referenced in the code visible here.
    show_cross_sectional_areas_labels = Bool(True)

    _on_init = Event(
    )  # internal event that is triggered when scene has been setup

    # Background grid plane, its axes and picker, plus the construction size
    # they were built for; recreated by _redraw_background_plane on resize.
    _bg_plane_actor = None
    _axis = None
    _bg_plane_picker = None
    _bg_plane_width = None
    _bg_plane_height = None

    # Set by _setup_scene; the next _before_render resets the zoom once.
    _reset_zoom_needed = Bool(False)

    def reset_zoom(self):
        """Delegate to ``scene.reset_zoom()``; a no-op without a scene."""
        scene = self.scene
        if not scene:
            return
        scene.reset_zoom()

    # schedule a render when model changes
    @on_trait_change(
        'construction.[joints,elements].+, construction.[joint_force_magnitude,width,height,weight_factor], selected_object, displacement_amplify_factor'
    )
    def render_if_not_interacting(self):
        """Re-render after a model change unless the scene is mid-interaction."""
        scene = self.scene
        if not scene or scene._interacting:
            return
        self.render_interacting()

    def render_interacting(self):
        """Render the scene now with ``scene._interacting`` forced to True.

        The previous ``_interacting`` value is restored afterwards, even if
        ``render()`` raises, so the scene cannot be left stuck in the
        interacting state. This is a no-op without a scene.
        """
        scene = self.scene
        if not scene:
            return
        old_interacting = scene._interacting
        scene._interacting = True
        try:
            scene.render()
        finally:
            scene._interacting = old_interacting

    def _setup_scene(self):
        """Create the VTK actors, filters and observers used by the editor.

        Expects ``self.scene`` to exist (``_create_splitter`` calls this after
        the base splitter, and with it the scene, has been created). Most
        pipelines start with empty inputs; the ``_redraw_*`` methods fill
        them on every render via the renderer ``StartEvent`` observer
        registered at the end. Firing ``_on_init`` triggers ``_init_plots``.
        """
        # scalar bar for strain
        lut_strain = tvtk.LookupTable(hue_range=(0.66, 0.0))
        lut_strain.build()
        self._scalar_bar_strain = tvtk.ScalarBarActor(
            lookup_table=lut_strain,
            orientation='horizontal',
            text_position='succeed_scalar_bar',
            maximum_number_of_colors=256,
            number_of_labels=9,
            position=(0.1, 0.01),
            position2=(0.8, 0.08),
            title='element strain (%)',
            visibility=False)
        self.scene.add_actor(self._scalar_bar_strain)
        # scalar bar for stress
        # lookup table from hue 0.33 (green) towards 0.1; the last two entries
        # are overridden with orange (254) and red (255)
        lut_stress = tvtk.LookupTable(hue_range=(0.33, 0.1),
                                      number_of_table_values=256)
        lut_stress.build()
        lut_stress.set_table_value(254, (1.0, 0.4, 0.0, 1.0))
        lut_stress.set_table_value(255, (1.0, 0.0, 0.0, 1.0))
        self._scalar_bar_stress = tvtk.ScalarBarActor(
            lookup_table=lut_stress,
            orientation='horizontal',
            text_position='succeed_scalar_bar',
            maximum_number_of_colors=256,
            number_of_labels=9,
            position=(0.1, 0.01),
            position2=(0.8, 0.08),
            title='element stress',
            visibility=False)
        self.scene.add_actor(self._scalar_bar_stress)
        # setup elements visualization
        # Tube radius follows the point scalars (per-element radii written by
        # _redraw_elements); colours come from the cell scalars.
        self._elements_polydata = tvtk.PolyData()
        self._tubes = tvtk.TubeFilter(
            input=self._elements_polydata,
            number_of_sides=6,
            vary_radius='vary_radius_by_absolute_scalar',
            radius_factor=1.5)
        mapper = tvtk.PolyDataMapper(input=self._tubes.output,
                                     lookup_table=lut_strain,
                                     interpolate_scalars_before_mapping=True,
                                     scalar_mode='use_cell_data')
        self._elements_actor = tvtk.Actor(mapper=mapper)
        self.scene.add_actor(self._elements_actor)
        # show elements in deformed state as wireframe
        self._deformed_elements_polydata = tvtk.PolyData()
        self._deformed_elements_actor = tvtk.Actor(mapper=tvtk.PolyDataMapper(
            input=self._deformed_elements_polydata))
        self._deformed_elements_actor.property.set(opacity=0.2,
                                                   representation='wireframe')
        self.scene.add_actor(self._deformed_elements_actor)
        # highlight one element via a ribbon outline
        # (created hidden; this method does not make it visible)
        self._hl_element_ribbons = tvtk.RibbonFilter(input=tvtk.PolyData(),
                                                     use_default_normal=True,
                                                     width=1.0)
        self._hl_element_actor = tvtk.Actor(
            mapper=tvtk.PolyDataMapper(input=self._hl_element_ribbons.output),
            visibility=False)
        self._hl_element_actor.property.set(ambient=1,
                                            ambient_color=(1, 1, 1),
                                            diffuse=0)
        self.scene.add_actor(self._hl_element_actor)
        # cross sectional radius labels
        # Labels sit at the element cell centres and only visible points are
        # labelled. NOTE: the label actor is configured but its add_actor call
        # below is commented out, so labels are not shown.
        self._elements_label_polydata = tvtk.PolyData()
        self._label_cellcenters = tvtk.CellCenters(
            input=self._elements_label_polydata)
        self._label_visps = tvtk.SelectVisiblePoints(
            renderer=self.scene.renderer,
            input=self._label_cellcenters.output,
            tolerance=10000)
        self._label_actor = tvtk.Actor2D(mapper=tvtk.Dynamic2DLabelMapper(
            input=self._label_visps.output, label_mode='label_scalars'))
        self._label_actor.mapper.label_text_property.set(
            bold=True, italic=False, justification='centered', font_size=14)
        #self.scene.add_actor(self._label_actor)
        # force glyphs (use arrows for that)
        # Arrows are scaled and coloured by the per-point force vectors.
        self._force_glyphs = tvtk.Glyph3D(scale_mode='scale_by_vector',
                                          vector_mode='use_vector',
                                          color_mode='color_by_vector',
                                          scale_factor=1.)
        self._force_glyphs.set_source(
            0,
            tvtk.ArrowSource(shaft_radius=0.04,
                             tip_resolution=8,
                             tip_radius=0.2).output)
        self._force_polydata = tvtk.PolyData()
        self._force_glyphs.set_input(0, self._force_polydata)
        self._force_actor = tvtk.Actor(mapper=tvtk.PolyDataMapper(
            input=self._force_glyphs.output, scalar_range=(0, 10)))
        self._force_actor.mapper.lookup_table.hue_range = (0.33, 0.0)
        self.scene.add_actor(self._force_actor)
        # current status display
        # Caption anchored at (0.5, 0.95) in normalized display coordinates.
        self._text_actor = tvtk.TextActor(position=(0.5, 0.95))
        self._text_actor.position_coordinate.coordinate_system = 'normalized_display'
        self._text_actor.text_property.set(font_size=14,
                                           justification='center')
        self.scene.add_actor(self._text_actor)
        # a nice gradient background
        self.scene.renderer.set(background2=(0.28, 0.28, 0.28),
                                background=(0.01, 0.01, 0.02),
                                gradient_background=True)
        # setup events
        # Every render first runs _before_render, which syncs all actors
        # with the current construction.
        self.interaction_mode = 'select'
        self.scene.interactor.add_observer('MouseMoveEvent', self._mouse_move)
        self.scene.interactor.add_observer('LeftButtonPressEvent',
                                           self._mouse_press)
        self.scene.renderer.add_observer('StartEvent', self._before_render)
        self._on_init = True
        self._reset_zoom_needed = True

    def bw_mode(self):
        """Switch to a black-on-white presentation of the construction.

        Uses a plain white background and hides the deformed wireframe, the
        axes and the background plane. Elements are drawn in a uniform
        ambient grey without scalar colouring, and caption and label text is
        set to black.

        NOTE(review): _redraw_caption sets ``text_property.color`` on every
        render, which may override the caption colour set here — confirm.
        """
        renderer = self.scene.renderer
        renderer.set(background=(1, 1, 1), gradient_background=False)
        self._deformed_elements_actor.visibility = False
        elements = self._elements_actor
        elements.mapper.scalar_visibility = False
        elements.property.set(ambient=1,
                              ambient_color=(.3, .3, .3),
                              diffuse_color=(0, 0, 0))
        self._axis.visibility = False
        self._bg_plane_actor.visibility = False
        black = (0, 0, 0)
        self._text_actor.property.color = black
        self._label_actor.mapper.label_text_property.set(color=black)

    def pic_mode(self):
        """Switch to a picture-style view and reset the zoom.

        Uses a plain white background and hides the deformed wireframe,
        axes, background plane and element labels. The caption text is
        black, and the stress lookup table becomes a single-hue (0) ramp
        with values from 0 to 1.
        """
        self.scene.renderer.set(background=(1, 1, 1),
                                gradient_background=False)
        for helper in (self._deformed_elements_actor, self._axis,
                       self._bg_plane_actor):
            helper.visibility = False
        self._text_actor.property.color = (0, 0, 0)
        self._label_actor.visibility = False
        stress_lut = self._scalar_bar_stress.lookup_table
        stress_lut.set(hue_range=(0, 0), value_range=(0, 1))
        stress_lut.force_build()
        self.scene.reset_zoom()

    def _before_render(self, *args):
        """Renderer ``StartEvent`` observer: sync every prop with the model.

        Runs the redraw steps in a fixed order. If a zoom reset has been
        requested, it is performed once and the request is then cleared.
        """
        redraw_steps = (
            self._redraw_background_plane,
            self._redraw_joint_handles,
            self._redraw_joint_forces,
            self._redraw_elements,
            self._redraw_element_labels,
            self._redraw_joints,
            self._redraw_caption,
            self._setup_elements_picker,
            self._show_scalar_bar_when_simulating,
        )
        for step in redraw_steps:
            step()
        if not self._reset_zoom_needed:
            return
        self.reset_zoom()
        self._reset_zoom_needed = False

    def _redraw_background_plane(self):
        """Rebuild the background grid plane, axes and picker on resize.

        Does nothing without a construction and scene, or when the plane
        was already built for the current ``width``/``height``. Otherwise
        the old plane and axes are removed and replaced by a wireframe plane
        of ``width`` x ``height``. That plane has ``int(width)`` x
        ``int(height)`` cells and 2-D axes spanning ``[-w/2, w/2]`` x
        ``[-h/2, h/2]``. A cell picker restricted to the plane is also set
        up.

        NOTE(review): if scene.add_actor/remove_actor trigger a render, this
        method may be re-entered before the size bookkeeping on the last line
        is updated — confirm.
        """
        if self.construction and self.scene:
            # Already built for this size: nothing to do.
            if self._bg_plane_width == self.construction.width and self._bg_plane_height == self.construction.height:
                return
            if self._bg_plane_actor:
                self.scene.remove_actor(self._bg_plane_actor)
            if self._axis:
                self.scene.remove_actor(self._axis)
            w, h = self.construction.width, self.construction.height
            # One grid cell per unit of width/height.
            plane = tvtk.PlaneSource(x_resolution=int(w), y_resolution=int(h))
            scalation = tvtk.Transform()
            scalation.scale((w, h, 1))
            scale_plane = tvtk.TransformPolyDataFilter(transform=scalation,
                                                       input=plane.output)
            self._bg_plane_actor = tvtk.Actor(mapper=tvtk.PolyDataMapper(
                input=scale_plane.output))
            self._bg_plane_actor.property.set(representation='wireframe',
                                              line_stipple_pattern=0xF0F0,
                                              opacity=0.15)
            self.scene.add_actor(self._bg_plane_actor)
            # NOTE(review): under Python 2, integer width/height make w / 2
            # an integer division — confirm the dimension types.
            self._axis = tvtk.CubeAxesActor2D(
                camera=self.scene.camera,
                z_axis_visibility=False,
                corner_offset=0,
                bounds=[-w / 2, w / 2, -h / 2, h / 2, 0, 0])
            self.scene.add_actor(self._axis)
            # Picker that only considers the background plane.
            self._bg_plane_picker = tvtk.CellPicker(tolerance=0,
                                                    pick_from_list=True)
            self._bg_plane_picker.pick_list.append(self._bg_plane_actor)
            self._bg_plane_width, self._bg_plane_height = self.construction.width, self.construction.height

    def _show_scalar_bar_when_simulating(self):
        """Show the scalar bar that matches ``show_scalar`` during simulation.

        Both bars are hidden when the simulation display is off or the
        construction has no elements. This is a no-op without a
        construction.
        """
        if not self.construction:
            return
        has_elements = len(self.construction.elements) > 0
        essential = self.show_simulation and has_elements
        self._scalar_bar_strain.visibility = essential and self.show_scalar == 'strain'
        self._scalar_bar_stress.visibility = essential and self.show_scalar == 'stress'

    def _representation_model_for_joint_handle(self, joint):
        """Pick the handle representation for a joint from its movability.

        Returns, from the ``data`` module:
        ``joint_model_movable`` when the joint is movable in x and y,
        ``joint_model_unmovable`` when it is movable in neither,
        ``joint_model_movable_x`` when it is movable in x only, and
        ``joint_model_movable_y`` when it is movable in y only.
        """
        representation_names = {
            (True, True): 'joint_model_movable',
            (False, False): 'joint_model_unmovable',
            (True, False): 'joint_model_movable_x',
            (False, True): 'joint_model_movable_y',
        }
        key = (bool(joint.movable_x), bool(joint.movable_y))
        return getattr(data, representation_names[key])

    def _redraw_joint_handles(self):
        """Create an interactive handle for every joint that does not have one."""
        if not self.construction:
            return
        for joint in self.construction.joints:
            if joint in self._handle_by_joint:
                continue
            new_handle = self._make_handle_for_joint(joint)
            self._handle_by_joint[joint] = new_handle

    def _redraw_joint_forces(self):
        """Feed the force-arrow glyphs from the joints' applied forces.

        Each joint with a non-zero force component contributes a point at
        the joint and a vector equal to its force. When no joint is loaded,
        points and vectors are cleared.
        """
        if not self.construction:
            return
        loaded = [j for j in self.construction.joints
                  if j.force_x != 0 or j.force_y != 0]
        if loaded:
            self._force_polydata.points = [(j.x, j.y, 0) for j in loaded]
            self._force_polydata.point_data.vectors = [
                (j.force_x, j.force_y, 0) for j in loaded]
        else:
            self._force_polydata.points = None
            self._force_polydata.point_data.vectors = None

    def _redraw_elements(self):
        """Rebuild element tubes, element colouring, deformed wireframe and highlight.

        Each element becomes its own two-point line whose endpoints are
        duplicated per element. Each line's two points carry that element's
        radius (times ``amplify_radius``) for the tube filter. When simulation
        results are shown, cells are coloured by one of two measures:
        percentual stress breakage (|stress| / max stress * 100, range
        0..100), or strain (range fitted to the largest |strain| or taken
        from ``scalar_range_max``). Does nothing without a construction.
        """
        if self.construction:
            # Endpoint coordinates of every element in element order, one
            # (x, y) row per endpoint.
            jp = self.construction.joint_positions[
                self.construction.element_index_table].reshape((-1, 2))
            self._elements_polydata.points = pts2dto3d(jp)
            self._elements_polydata.lines = N.r_[:len(
                self.construction.elements) * 2].reshape(
                    (-1, 2))  # => [[0,1],[2,3],[4,5],...]
            radii = N.array(
                [e.material.radius for e in self.construction.elements])
            # Same radius at both endpoints of each element's line.
            self._elements_polydata.point_data.scalars = N.column_stack(
                (radii, radii)).ravel() * self.amplify_radius
            # set scalars of elements when enabled by the user
            if self.show_simulation and self.construction.simulated and len(
                    self.construction.elements) > 0:
                # show element stress
                if self.show_scalar == 'stress':
                    self._elements_actor.mapper.lookup_table = self._scalar_bar_stress.lookup_table
                    breakage = (N.abs(
                        self.construction.simulation_result.element_stresses) /
                                self.construction.max_element_stress_array
                                ) * 100.  # percentual breakage
                    self._elements_polydata.cell_data.scalars = breakage
                    # Fixed 0..100 % range for both the bar and the mapper.
                    self._scalar_bar_stress.lookup_table.table_range = \
                            self._elements_actor.mapper.scalar_range = (0, 100)
                # show element strain
                elif self.show_scalar == 'strain':
                    self._elements_actor.mapper.lookup_table = self._scalar_bar_strain.lookup_table
                    strains = self.construction.simulation_result.element_strains
                    self._elements_polydata.cell_data.scalars = strains
                    # set strain scalar range
                    # (symmetric around zero in both cases)
                    if self.auto_scalar_range:
                        max_strain = N.abs(strains).max()
                        scalar_range = (-max_strain, max_strain)
                    else:
                        scalar_range = (-self.scalar_range_max,
                                        self.scalar_range_max)
                    self._scalar_bar_strain.lookup_table.table_range = \
                        self._elements_actor.mapper.scalar_range = scalar_range
            else:
                self._elements_polydata.cell_data.scalars = None
            # deformed wireframe model:
            if self.show_simulation and self.construction.simulated and len(self.construction.joints) > 0 \
                    and self.construction.simulation_result.okay:
                # amplify displacements so they become visible
                jp = self.construction.joint_positions + \
                        self.construction.simulation_result.joint_displacements * self.displacement_amplify_factor
                self._deformed_elements_polydata.points = pts2dto3d(jp)
                self._deformed_elements_polydata.lines = self.construction.element_index_table
                self._deformed_elements_actor.visibility = True
            else:
                self._deformed_elements_actor.visibility = False
            # redraw selected element highlight
            # NOTE(review): only the ribbon geometry is updated here; the
            # highlight actor's visibility is presumably toggled elsewhere.
            if isinstance(self.selected_object, model.Element):
                self._hl_element_ribbons.input = tvtk.PolyData(\
                        points=[list(self.selected_object.joint1.position)+[0], list(self.selected_object.joint2.position)+[0]], lines=[[0,1]])
                self._hl_element_ribbons.width = max(
                    self.selected_object.material.radius, 0.05) * 2.0

    def _redraw_element_labels(self):
        """Place a radius label at the midpoint of every element.

        Label values are the material radii times 100, rounded to three
        decimals. The label pipeline inputs are then re-attached.
        """
        if not self.construction:
            return
        construction = self.construction
        index_table = construction.element_index_table
        positions = construction.joint_positions
        midpoints = (positions[index_table[:, 0]] +
                     positions[index_table[:, 1]]) / 2
        self._elements_label_polydata.points = pts2dto3d(midpoints)
        radii = N.array([e.material.radius for e in construction.elements])
        self._elements_label_polydata.point_data.scalars = N.round(
            radii * 100, 3)
        self._label_visps.input = self._elements_label_polydata
        self._label_actor.mapper.input = self._label_visps.output

    def _redraw_caption(self):
        """Update the status caption with construction weight and stability.

        Negative stability is shown in red as statically unstable, and
        otherwise in white. A failed simulation shows its status instead.
        The caption is hidden when there are no joints or no elements.
        """
        if not self.construction:
            return
        construction = self.construction
        if not (len(construction.joints) > 0 and
                len(construction.elements) > 0):
            self._text_actor.visibility = False
            return
        result = construction.simulation_result
        if result.okay:
            stability = result.stability
            if stability < 0:
                stability_txt = 'STATICALLY UNSTABLE (%.0f%%)' % (
                    stability * 100)
                caption_color = (1.0, 0.0, 0.0)
            else:
                stability_txt = 'Stability %.0f%%' % (stability * 100)
                caption_color = (1.0, 1.0, 1.0)
            self._text_actor.text_property.color = caption_color
            self._text_actor.input = 'Weight: %.2fkg, %s' % (
                result.construction_mass, stability_txt)
        else:
            self._text_actor.input = 'Simulation Error (%s)' % result.status
        self._text_actor.visibility = True

    def _redraw_joints(self):
        """Sync the joint handles with the construction's joints.

        Handles whose joint is no longer part of the construction are
        dropped and switched off. Every remaining joint's handle gets the
        representation matching its movability and is moved to the joint.
        """
        if not (self.scene and self.construction):
            return
        joints = self.construction.joints
        stale = [(joint, handle)
                 for joint, handle in self._handle_by_joint.items()
                 if joint not in joints]
        for joint, handle in stale:
            del self._handle_by_joint[joint]
            handle.off()
        for joint in joints:
            representation = self._handle_by_joint[joint].representation
            representation.handle = self._representation_model_for_joint_handle(
                joint)
            representation.world_position = (joint.x, joint.y, 0)

#############################################################
# gui related

    # Used to schedule evolution steps (invoke_later) and pump events.
    gui = Instance(pyface.GUI)
    # Property view shown in the right pane; its wx control sizes the pane.
    property_view_contents = Any()
    title = String('Plane Truss Design via Genetic Algorithms')
    # Presumably the SplitApplicationWindow layout settings — confirm.
    ratio = Float(1.0)
    direction = Str('vertical')
    # NOTE(review): only referenced from commented-out code visible here.
    construction_fitness = Float(0.0)

    def _create_splitter(self, parent):
        """Let the base class build the splitter, then set up the VTK scene.

        ``_setup_scene`` needs ``self.scene``, which is presumably created by
        ``_create_lhs`` while the base class builds the panes.
        """
        splitter = super(ConstructionEditor, self)._create_splitter(parent)
        self._setup_scene()
        return splitter

    def _create_lhs(self, parent):
        """Create the left pane: a VTK scene looking down +z with image-style interaction."""
        self.scene = Scene(parent)
        scene = self.scene
        scene.interactor.interactor_style = tvtk.InteractorStyleImage()
        scene.z_plus_view()
        return scene.control

    def _create_rhs(self, parent):
        """Create the right pane: an empty wx panel for property views.

        The splitter starts collapsed (minimum pane size and sash size 0,
        zero width) and is remembered so it can be resized later.
        """
        panel = wx.Panel(parent)
        self.properties_panel = panel
        parent.SetMinimumPaneSize(0)
        parent.SetSashSize(0)
        parent.SetSize((0, 100))
        self._splitter_control = parent
        return panel

    def _property_view_contents_changed(self, old, new):
        """Dispose of the old property view and size the pane for the new one.

        The pane's minimum size grows to fit the new view's width; without a
        view it collapses back to zero.
        """
        if old:
            old.dispose()
        splitter = self._splitter_control
        if not new:
            splitter.SetMinimumPaneSize(0)
            return
        width, height = new.control.GetSize()
        pane_width = max(splitter.GetMinimumPaneSize(), width)
        splitter.SetMinimumPaneSize(pane_width)
        new.control.SetSize((pane_width, height))

    def close(self):
        """Close the VTK scene, if any, then the window itself."""
        scene = self.scene
        if scene is not None:
            scene.close()
        super(ConstructionEditor, self).close()

    @on_trait_change('show_evolution_progress, ga.on_step')
    def _update_evolve_gui(self):
        """Refill the individual selector from the GA's current population.

        Individuals are numbered from 1, and the first one becomes selected.
        """
        if not self.show_evolution_progress:
            return
        population = self.ga.current_population
        if not population:
            return
        self.selectable_individuals = [
            ui.IndividualSelect(member, number)
            for number, member in enumerate(population.individuals, 1)
        ]
        self.selected_individual = self.selectable_individuals[0]

#############################################################
# plots

    # How many of the most recent generations the fitness plot shows; -1
    # (the Enum default) presumably means "all of them".
    plot_num_generations = Enum(-1, 10, 50, 500, 5000)

    # Best/average raw-fitness curves, built by _init_plots.
    fitness_plot = Instance(Plot)
    fitness_plot_data = Instance(ArrayPlotData)
    #feasible_plot = Instance(Plot)
    #feasible_plot_data = Instance(ArrayPlotData)
    #stability_plot = Instance(Plot)
    #stability_plot_data = Instance(ArrayPlotData)

    @on_trait_change('_on_init')
    def _init_plots(self):
        """Create the fitness plot once the scene has been set up.

        The plot starts with one-point placeholder series ('generation',
        'best', 'average'), which _update_plots later replaces. It shows a
        legend, a thick green 'best' line and a black 'avg' line. The
        stability and feasibility plots are disabled (commented out).
        """
        # fitness
        self.fitness_plot_data = ArrayPlotData(generation=N.zeros(1),
                                               best=N.zeros(1),
                                               average=N.zeros(1))
        self.fitness_plot = Plot(self.fitness_plot_data)
        self.fitness_plot.legend.visible = True
        self.fitness_plot.set(padding_top=5,
                              padding_right=5,
                              padding_bottom=20,
                              padding_left=40)
        self.fitness_plot.plot(('generation', 'best'),
                               color='green',
                               line_width=2,
                               name='best')
        self.fitness_plot.plot(('generation', 'average'),
                               color='black',
                               name='avg')
        ## stability
        #self.stability_plot_data = ArrayPlotData(generation=N.zeros(1), min=N.zeros(1), max=N.zeros(1), average=N.zeros(1))
        #self.stability_plot = Plot(self.stability_plot_data, height=100)
        #self.stability_plot.legend.visible = True
        #self.stability_plot.set(padding_top = 15, padding_right = 5, padding_bottom = 20, padding_left = 40, title='Stability')
        #self.stability_plot.plot(('generation', 'min'), color='red', line_width=2, name='min')
        #self.stability_plot.plot(('generation', 'average'), color='black', name='avg')
        #self.stability_plot.plot(('generation', 'max'), color='green', line_width=2, name='max')
        ## feasible
        #self.feasible_plot_data = ArrayPlotData(generation=N.zeros(1), num=N.zeros(1))
        #self.feasible_plot = Plot(self.feasible_plot_data)
        #self.feasible_plot.set(padding_top = 15, padding_right = 5, padding_bottom = 20, padding_left = 40, title = 'Unfeasible Individuals')
        #self.feasible_plot.plot(('generation', 'num'), color='red', line_width=2)

    @on_trait_change('ga.on_init')
    def _reset_plots(self):
        """Start empty best/average fitness histories when the GA (re)initializes."""
        self._fitness_best_history, self._fitness_avg_history = [], []

    @on_trait_change('show_evolution_progress, ga.on_step, ga.on_init')
    def _update_plots(self):
        """Record the current population's fitness and refresh the plot.

        When a population exists, this appends its best raw fitness and the
        mean raw fitness of its feasible, unpenalized individuals to the
        histories. The mean of an empty selection is nan, and numpy warns.
        It then pushes the last ``plot_num_generations`` entries into
        ``fitness_plot_data``. A non-positive setting, such as the default
        -1, shows every generation.
        """
        mm = self.plot_num_generations
        # ``seq[-mm:]`` with mm == -1 evaluates to ``seq[1:]``, which would
        # silently drop the first generation instead of showing all of them.
        window = slice(-mm, None) if mm > 0 else slice(None)
        gens = N.r_[:self.ga.num_steps][window]
        # fitness plot
        population = self.ga.current_population
        if population:
            self._fitness_best_history.append(population.best.raw_fitness)
            self._fitness_avg_history.append(
                N.array([
                    i.raw_fitness
                    for i in population.individuals
                    if i.feasible and not i.got_penalized
                ]).mean())
        self.fitness_plot_data.set_data('generation', gens)
        self.fitness_plot_data.set_data('best',
                                        self._fitness_best_history[window])
        self.fitness_plot_data.set_data('average',
                                        self._fitness_avg_history[window])

#############################################################
# ga related

    # The genetic algorithm driving the evolution and its fitness function
    # (forwarded to the GA by _fitness_function_changed).
    ga = Instance(GA.GeneticAlgorithm)
    fitness_function = Callable()

    # Evolution controls; handled by the matching _*_fired methods.
    start_evolve = Button()
    pause_evolve = Button()
    reset_evolve = Button()
    # NOTE(review): not referenced in the code visible here.
    evolving = Bool(False)

    # Individual chosen in the selector and the selector's choices, refilled
    # from the current population by _update_evolve_gui.
    selected_individual = Instance(ui.IndividualSelect)
    selectable_individuals = List(ui.IndividualSelect)

    # Whether the editor follows the evolving population.
    show_evolution_progress = Bool(True)
    # Fired once by _start_evolve_fired after the GA has been initialized.
    ga_step = Event

    #def _internal_on_ga_step(self):
    #    self.ga_step = True

    #def _ga_changed(self, old, new):
    #    if old != new and new:
    #        new.on_trait_event(self._internal_on_ga_step, 'on_step', dispatch='fast_ui')

    def _start_evolve_fired(self):
        """Start (or resume) the evolution loop.

        On the first start, the current construction is handed to the GA's
        world. The editor then switches to a deep copy, so edits made while
        evolving leave the GA's reference construction untouched. Steps are
        driven from the GUI event loop by ``_evolve_it``.
        """
        # _evolve_it re-queues itself while the mode is 'evolve'. A repeated
        # start press would therefore launch a second, parallel chain and
        # double the step rate, so ignore it.
        if self.interaction_mode == 'evolve':
            return
        if not self.ga.inited:
            self.ga.world.construction = self.construction
            self.construction = self.construction.clone_traits(copy='deep')
            self.ga.init_evolution()
            self.ga_step = True
        self.interaction_mode = 'evolve'
        self.gui.invoke_later(self._evolve_it)

    def _fitness_function_changed(self):
        """Hand the new fitness callable to the GA and pause evolution."""
        new_fitness = self.fitness_function
        self.ga.fitness_function = new_fitness
        # Firing the pause button leaves 'evolve' mode (see _pause_evolve_fired).
        self.pause_evolve = True

    @on_trait_change('selected_individual, show_evolution_progress')
    def _show_current_individual(self):
        """Display the construction encoded by the selected individual.

        Runs whenever the selection or the show-progress toggle changes. Does
        nothing while progress display is off or when nothing is selected.
        """
        selected = self.selected_individual
        # Turning show_evolution_progress on before any individual has been
        # selected (or after the selection is cleared) would otherwise raise
        # AttributeError on `None.individual`.
        if not self.show_evolution_progress or selected is None:
            return
        self.construction = self.ga.world.construction_from_individual(
            selected.individual)

    def _evolve_it(self):
        """Run one GA step, then re-queue itself while still in 'evolve' mode."""
        self.ga.evolution_step()
        # Let pending GUI events (e.g. a pause click) run between steps.
        self.gui.process_events()
        still_evolving = self.interaction_mode == 'evolve'
        if still_evolving:
            self.gui.invoke_later(self._evolve_it)

    def _pause_evolve_fired(self):
        """Stop evolving by switching back to 'select' mode.

        _evolve_it stops re-scheduling itself once the mode is not 'evolve'.
        """
        self.interaction_mode = 'select'

    def _reset_evolve_fired(self):
        """Restore the GA world's construction and reset the GA.

        Only acts once the GA has been initialised and has a world.
        """
        ga = self.ga
        if not (ga.inited and ga.world):
            return
        self.construction = ga.world.construction
        ga.reset_evolution()

#############################################################
# model

    # The construction currently shown and edited.
    construction = Instance(model.Construction)

    # Buttons: load from / save to a .con file, or start a new construction.
    open_construction = Button()
    save_construction = Button()
    new_construction = Button()

    def _open_construction_fired(self):
        """Let the user pick a .con file and load its (construction, ga) pair.

        NOTE(review): pickle.load can execute arbitrary code embedded in the
        file, so only open .con files from trusted sources.
        """
        file_dialog = pyface.FileDialog(
            action='open', wildcard='Constructions (*.con)|*.con|')
        if file_dialog.open() != pyface.OK:
            return
        # Pickles are binary data. The context manager closes the handle even
        # if unpickling fails (the original never closed it).
        with open(file_dialog.path, 'rb') as stream:
            self.construction, self.ga = pickle.load(stream)

    def _save_construction_fired(self):
        """Let the user choose a path and pickle (construction, ga) to it."""
        file_dialog = pyface.FileDialog(
            action='save as', wildcard='Constructions (*.con)|*.con|')
        if file_dialog.open() != pyface.OK:
            return
        # Binary mode is required for pickle on Python 3 and for binary
        # protocols on Windows. `with` guarantees the data is flushed and the
        # handle closed.
        with open(file_dialog.path, 'wb') as stream:
            pickle.dump((self.construction, self.ga), stream)

    def _new_construction_fired(self):
        """Replace the current construction with a fresh one.

        ``data.steels`` are offered as element materials, and ``data.air`` is
        the material assigned to deleted elements.
        """
        fresh = model.Construction(
            available_element_materials=data.steels,
            element_deleted_material=data.air,
        )
        self.construction = fresh

    edit_available_materials = Button()

    def _edit_available_materials_fired(self):
        """Open the editor for the construction's available materials."""
        construction = self.construction
        construction.edit_traits(view=ui.edit_available_elements_view())

#############################################################
# editing

    # Editing buttons (see the matching *_fired handlers).
    enter_add_joint_mode = Button()
    enter_add_element_mode = Button()
    remove_selected_object = Button()
    build_full_connections = Button()
    snap_joints = Button()
    # Material given to elements created by the editor (_selected_object_changed,
    # _build_full_connections_fired).
    new_element_material = Instance(model.ElementMaterial)

    # The joint or element currently selected in the scene, or None.
    selected_object = Either(Instance(model.Joint), Instance(model.Element))
    # Current mouse-interaction mode of the scene.
    interaction_mode = Enum('select', 'add_joint', 'add_element', 'evolve')

    # Element under the cursor at the last hover check in _mouse_move; used to
    # restore the default cursor once the pointer leaves it.
    _last_hovered_element = None

    def _enter_add_joint_mode_fired(self):
        """Clear the selection, then switch to joint-placement mode."""
        self.selected_object = None
        self.interaction_mode = 'add_joint'

    def _enter_add_element_mode_fired(self):
        """Switch to element-creation mode.

        In this mode, selecting two joints in turn adds an element between them.
        """
        self.interaction_mode = 'add_element'

    def _remove_selected_object_fired(self):
        """Delete the selected joint or element from the construction.

        Clears the selection afterwards. Does nothing if nothing is selected.
        """
        target = self.selected_object
        if isinstance(target, model.Joint):
            owner = self.construction.joints
        elif isinstance(target, model.Element):
            owner = self.construction.elements
        else:
            return
        owner.remove(target)
        self.selected_object = None

    def _build_full_connections_fired(self):
        """Ask the construction to build all connections with the new-element material."""
        construction = self.construction
        construction.build_full_connections(self.new_element_material)

    def _snap_joints_fired(self):
        """Round every joint's x and y with ``N.round`` (zero decimals)."""
        for each_joint in self.construction.joints:
            each_joint.x = N.round(each_joint.x)
            each_joint.y = N.round(each_joint.y)

    @on_trait_change('selected_object, _on_init')
    def _edit_selected_object_properties(self):
        """Rebuild the properties sub-panel for the current selection.

        With nothing selected, the general editor is shown. Otherwise a
        joint- or element-specific editor is embedded in ``properties_panel``.
        The resulting UI is stored in ``property_view_contents``.
        """
        selected = self.selected_object
        # Identity test: `== None` would defer to a model class's __eq__.
        if selected is None:
            self.property_view_contents = self.edit_traits(
                kind='subpanel', parent=self.properties_panel,
                view=ui.general_edit_view(self._selectable_materials()))
        elif isinstance(selected, model.Joint):
            w, h = self.construction.width, self.construction.height
            self.property_view_contents = self.edit_traits(
                view=ui.joint_edit_view(w, h), kind='subpanel',
                parent=self.properties_panel,
                context={'joint': selected, 'object': self})
        elif isinstance(selected, model.Element):
            self.property_view_contents = self.edit_traits(
                view=ui.element_edit_view(self._selectable_materials()),
                kind='subpanel', parent=self.properties_panel,
                context={'element': selected, 'object': self})

    def _selectable_materials(self):
        """Return the deleted-element material followed by the available ones."""
        construction = self.construction
        return ([construction.element_deleted_material]
                + construction.available_element_materials)

    def _pick_element(self, x, y):
        """Return the element under display position (x, y), or None.

        The picked cell id indexes ``construction.elements`` directly.
        """
        picker = self._element_picker
        picker.pick((float(x), float(y), 0.0), self.scene.renderer)
        cell = picker.cell_id
        return self.construction.elements[cell] if cell > -1 else None

    def _interaction_mode_changed(self):
        """Sync joint handles and the mouse cursor with ``interaction_mode``.

        Also called directly by _make_handle_for_joint to refresh handles.
        """
        # Handles only react to the mouse in modes where clicking a joint is
        # meaningful.
        handles_active = self.interaction_mode in ('select', 'add_element')
        for h in self._handle_by_joint.values():
            h.process_events = handles_active
        window = self.scene.render_window
        # VTK cursor shapes: 10 = crosshair, 0 = default. The crosshair is
        # shown while placing joints. It used to remain after the mode
        # switched away, because nothing reset it. Only a crosshair is reset,
        # so the hover cursor set by _mouse_move is left alone.
        if self.interaction_mode == 'add_joint':
            window.current_cursor = 10
        elif window.current_cursor == 10:
            window.current_cursor = 0

    def _mouse_press(self, *args):
        """Handle a mouse click according to the interaction mode.

        In 'select' mode, the clicked element is selected, or the selection is
        cleared. In 'add_joint' mode, a click that hits the background plane
        adds a joint at that point and returns to 'select' mode.
        """
        x, y = self.scene.interactor.last_event_position
        mode = self.interaction_mode
        if mode == 'select':
            hit = self._pick_element(x, y)
            self.selected_object = hit if hit else None
            return
        if mode != 'add_joint':
            return
        plane_picker = self._bg_plane_picker
        plane_picker.pick((x, y, 0), self.scene.renderer)
        if plane_picker.cell_id <= -1:
            return
        wx, wy, _wz = plane_picker.pick_position
        self.construction.joints.append(model.Joint(x=wx, y=wy))
        self.interaction_mode = 'select'

    def _mouse_move(self, *args):
        """Update the hover cursor while in 'select' mode.

        Uses cursor 9 over an element and restores cursor 0 after leaving one.
        """
        if self.interaction_mode != 'select':
            return
        px, py = self.scene.interactor.event_position
        hovered = self._pick_element(px, py)
        if hovered:
            self.scene.render_window.current_cursor = 9
            self._last_hovered_element = hovered
        elif self._last_hovered_element:
            self.scene.render_window.current_cursor = 0
            self._last_hovered_element = None

    def _get_handle_by_joint(self, joint):
        """Return the handle widget registered for ``joint``, or None if absent."""
        try:
            found = self._handle_by_joint[joint]
        except KeyError:
            found = None
        return found

    def _selected_object_changed(self, old, new):
        """Update highlighting, and possibly add an element, on a selection change.

        The previously selected joint's handle is un-highlighted and the
        element highlight actor is hidden. Then either the new joint's handle
        is highlighted, or the highlight actor is shown for a new element.
        In 'add_element' mode, selecting a second, different joint creates an
        element between the two joints. If the construction does not accept
        the element, a warning is shown instead. The mode then returns to
        'select'.
        """
        if old and isinstance(old, model.Joint):
            old_handle = self._get_handle_by_joint(old)
            if old_handle:
                old_handle.representation.property.set(ambient=0, diffuse=1)
        self._hl_element_actor.visibility = False
        if not new:
            return
        if isinstance(new, model.Joint):
            new_handle = self._get_handle_by_joint(new)
            # _get_handle_by_joint may return None. Guard the same way as for
            # `old` above, instead of raising AttributeError.
            if new_handle:
                new_handle.representation.property.set(ambient=1, diffuse=0.3)
            if (self.interaction_mode == 'add_element' and old != new
                    and isinstance(old, model.Joint)):
                elem = model.Element(joint1=old,
                                     joint2=new,
                                     material=self.new_element_material)
                if self.construction.element_addable(elem):
                    self.construction.elements.append(elem)
                else:
                    pyface.MessageDialog(
                        message=
                        'Element can not be added because it already exists',
                        severity='warning',
                        title='Adding an Element').open()
                self.interaction_mode = 'select'
        elif isinstance(new, model.Element):
            self._hl_element_actor.visibility = True

    def _construct_handle_widget(self):
        """Create a tvtk HandleWidget whose handle cannot be resized."""
        return tvtk.HandleWidget(allow_handle_resize=False)

    def _make_handle_for_joint(self, joint):
        """Create, enable and return a draggable 3-D handle widget for ``joint``.

        Dragging the handle writes its world x/y back to the joint. Starting
        an interaction on it makes the joint the selected object.
        """
        handle = self._construct_handle_widget()
        # Handle geometry comes from _representation_model_for_joint_handle,
        # which is not defined in this part of the file.
        representation = tvtk.PolygonalHandleRepresentation3D(
            handle=self._representation_model_for_joint_handle(joint))
        # Un-highlighted look. _selected_object_changed raises `ambient` to
        # highlight a selected joint, which makes the red ambient colour show.
        representation.property.set(ambient=0,
                                    diffuse=1,
                                    diffuse_color=(0.9, 0.9, 0.9),
                                    ambient_color=(1, 0, 0))
        handle.set_representation(representation)
        handle.interactor = self.scene.interactor
        # Place the handle at the joint, on the z = 0 plane.
        handle.representation.world_position = (joint.position[0],
                                                joint.position[1], 0)
        handle.enabled = True

        # bind movements so that widget movement moves the joint
        def widget_move(*args):
            x, y = handle.representation.world_position[:2]
            joint.x = x
            joint.y = y

        handle.add_observer('InteractionEvent', widget_move)

        # bind selection
        def widget_select(*args):
            self.selected_object = joint

        handle.add_observer('StartInteractionEvent', widget_select)
        # Re-apply the per-mode process_events flag to registered handles.
        # NOTE(review): this method does not store `handle` in
        # self._handle_by_joint, so the new handle is presumably registered by
        # the caller after this returns. Confirm.
        self._interaction_mode_changed()
        handle.on()
        return handle

    @on_trait_change(
        'construction.elements, construction:joints:position, _on_init')
    def _setup_elements_picker(self):
        """(Re)build the line actor and cell picker used to pick elements.

        Re-run whenever the element list or any joint position changes. The
        picked cell id is used as an index into construction.elements (see
        _pick_element).
        """
        if self.scene and self.construction:
            #print "setup elements picker"
            # Points: joint positions (presumably lifted from 2-D to 3-D by
            # pts2dto3d). Lines: the construction's element index table.
            pd = tvtk.PolyData(points=pts2dto3d(
                self.construction.joint_positions),
                               lines=self.construction.element_index_table)
            self._pick_elements_actor = tvtk.Actor(mapper=tvtk.PolyDataMapper(
                input=pd))
            # Restrict picking to the element actor only.
            self._element_picker = tvtk.CellPicker(pick_from_list=True,
                                                   tolerance=0.005)
            self._element_picker.pick_list.append(self._pick_elements_actor)
Example #51
0
class FiberView(HasTraits):
    """TraitsUI viewer for a LabJack-backed FiberModel.

    A timer polls the model, the most recent samples are plotted live, and
    each recorded run is saved to an .npz file when recording stops.
    """

    timer = Instance(Timer)
#    model         =  FiberModel(options)

    plot_data = Instance(ArrayPlotData)
    plot = Instance(Plot)
    start_stop = Button()
    exit = Button()

    # Default TraitsUI view
    traits_view = View(
        Item('plot', editor=ComponentEditor(), show_label=False),
        # Items
        HGroup(spring,
               Item("start_stop", show_label=False),
               Item("exit", show_label=False),
               spring),
        HGroup(spring),

        # GUI window
        resizable=True,
        width=1000,
        height=700,
        kind='live')

    def __init__(self, options, **kwtraits):
        """Create the model from ``options``, build the plot and start the timer."""
        super(FiberView, self).__init__(**kwtraits)
        self.model = FiberModel(options)

        # debugging
        self.debug = options.debug
        self.model.debug = options.debug

        # timing parameters
        self.max_packets = options.max_packets
        self.hertz = options.hertz

        # extend options to model
        self.model.max_packets = options.max_packets
        self.model.preallocate_arrays()
        self.model.num_analog_channels = options.num_analog_channels

        # generate traits plot
        self.plot_data = ArrayPlotData(x=self.model._tdata, y=self.model._ydata)
        self.plot = Plot(self.plot_data)
        self.plot.plot(("x", "y"), type="line", name='old', color="green")

        # recording flags
        self.model.recording = False
        self.model.trialEnded = True
        # Always defined, so time_update never touches a missing attribute.
        self.ttl_received = False

        print('Viewer initialized.')

        # should we wait for a ttl input to start?
        # time_update reads model.ttl_start on every tick, so always set it.
        self.model.ttl_start = bool(options.ttl_start)
        if self.model.ttl_start:
            # initialize FIO0 for TTL input
            self.FIO0_DIR_REGISTER = 6100
            self.FIO0_STATE_REGISTER = 6000
            self.model.labjack.writeRegister(self.FIO0_DIR_REGISTER, 0)  # Set FIO0 low

        # initialize output array
        self.out_arr = None

        # keep track of number of runs
        self.run_number = 0

        # Calls time_update every model.dt (pyface Timer interval).
        self.timer = Timer(self.model.dt, self.time_update)

    def run(self):
        """Refresh the plot, then fetch the next data from the model."""
        self._plot_update()
        self.model._get_current_data()

    def time_update(self):
        """Timer callback that reads the next data point over the LabJack.

        If TTL triggering is armed and no TTL has arrived yet, poll for it and
        start recording when it does. While recording, acquire and plot. In
        TTL mode, this only happens after the TTL has been received.
        """
        if self.model.ttl_start and not self.ttl_received:
            if self.check_for_ttl():
                self.ttl_received = True
                self._start_stop_fired()

        if self.model.recording and (not self.model.ttl_start
                                     or self.ttl_received):
            self.run()

    def check_for_ttl(self):
        """Return the FIO0 state register value (truthy once a TTL is present)."""
        return self.model.labjack.readRegister(self.FIO0_STATE_REGISTER)

    def clean_time_series(self, time_series, blip_thresh=10.0):
        """Return a copy of ``time_series`` with blips replaced by the median.

        Samples greater than ``blip_thresh`` are replaced by the median of the
        series, computed before replacement. The input is copied first, so a
        slice that is a view of the model's acquisition buffer is never
        modified in place (the original wrote into it).
        """
        cleaned = np.array(time_series)
        blip_idxs = cleaned > blip_thresh
        cleaned[blip_idxs] = np.median(cleaned)
        return cleaned

    def _plot_update(self):
        """Plot the most recent samples (up to 2500) and request a redraw."""
        num_display_points = 100 * 25  # For 1000 Hz
        disp_end = self.model.master_index
        disp_begin = max(0, disp_end - num_display_points)

        ydata = self.clean_time_series(self.model._ydata[disp_begin:disp_end])
        xdata = np.array(self.model._tdata[disp_begin:disp_end])
        self.plot_data.set_data("y", ydata)
        self.plot_data.set_data("x", xdata)

        # NOTE(review): a new Plot is built on every update, as before.
        # reset_variables() deletes the 'old' renderer and relies on it being
        # re-created here.
        self.plot = Plot(self.plot_data)
        self.plot.plot(("x", "y"), type="line", name='old', color="green")
        self.plot.request_redraw()

    # Note: These should be moved to a proper handler (ModelViewController)

    def _start_stop_fired(self):
        """Toggle recording. Finishing a trial saves the collected data."""
        self.model.start_time = time.time()
        self.model.recording = not self.model.recording
        self.model.trialEnded = not self.model.trialEnded
        if self.model.trialEnded:
            self.save()

        #Quickly turn LED off to signal button push
        self.model.start_LED()

    def _exit_fired(self):
        """Stop acquisition and close the LabJack connection."""
        print("Closing connection to LabJack...")
        # time_update reads the TTL flag from the model. Disarm it there too,
        # or the timer keeps polling the closed device.
        self.model.ttl_start = False
        self.ttl_start = False
        self.recording = False
        self.model.sdr.running = False
        self.model.recording = False

#        self.model.labjack.streamStop()
        self.model.labjack.close()
        print("connection closed. Safe to close...")

    def save(self):
        """Assemble all acquired blocks into one array and save it.

        The array is written to an .npz file named from the model's
        savepath/filename, the local date and time, and the run number. The
        run is then plotted and per-run state is reset.
        """
        print("Saving!")
        # Finally, construct and save full data array
        print("Saving acquired data...")
        for i in range(len(self.model.data)):
            first_column = True
            block = self.model.data[i]
            for k in self.model.result.keys():
                print(k)
                if k == 0:
                    continue
                if first_column:
                    block_arr = np.array(block[k])
                    block_arr.shape = (len(block_arr), 1)
                    first_column = False
                    continue
                new_data = np.array(block[k])
                new_data.shape = (len(block[k]), 1)
                # if new_data is short by one, fill in with last entry
                if new_data.shape[0] < block_arr.shape[0]:
                    new_data = np.append(new_data, new_data[-1])
                    new_data.shape = (new_data.shape[0], 1)
                    print("Appended point to new_data, shape now: %s" % (new_data.shape,))
                if new_data.shape[0] > block_arr.shape[0]:
                    new_data = new_data[:-1]
                    print("Removed point from new_data, shape now: %s" % (new_data.shape,))
                print("array shape, new_data shape: %s %s" % (block_arr.shape, new_data.shape))
                block_arr = np.hstack((block_arr, new_data))
            if i == 0:
                self.out_arr = block_arr
            else:
                self.out_arr = np.vstack((self.out_arr, block_arr))
            print("Array shape: %s" % (self.out_arr.shape,))
        date = time.localtime()
        outfile = self.model.savepath + self.model.filename
        outfile += str(date[0]) + '_' + str(date[1]) + '_' + str(date[2]) + '_'
        outfile += str(date[3]) + '-' + str(date[4]) + '_run_number_' + str(self.run_number)
        np.savez(outfile, data=self.out_arr, time_stamps=self.model._tdata)
        print("Saved  %s" % (outfile,))

        # Plot the data collected this last run
        self.plot_last_data()

        # clean up
        self.reset_variables()

    def reset_variables(self):
        """Clear per-run state and advance the run counter."""
        self.out_arr = None
        self.ttl_received = False
        self.run_number += 1
        self.model.recording = False
        # The trial flag lives on the model (see _start_stop_fired). The
        # original set an unused attribute on the view instead.
        self.model.trialEnded = True
        if 'old' in self.plot.plots:
            self.plot.delplot('old')

    def plot_last_data(self):
        """Show the last run's four channels in a pylab figure.

        Does nothing unless the saved array is 2-D with exactly four columns.
        """
        import pylab as pl
        out = self.out_arr
        if out is None or out.ndim != 2 or out.shape[1] != 4:
            return
        pl.figure()
        for col in range(4):
            pl.subplot(411 + col)
            pl.plot(out[:, col])
        pl.show()
class ViewportDefiner(HasTraits):
    """Interactive editor for one virtual display's viewport polygon.

    The polygon is drawn with a LineSegmentTool over an image the size of the
    display. Every change re-renders the image and sends it to
    ``display_server``. Once the polygon has at least three vertices, it is
    also written to the ROS parameter
    ``<display_name>/display/virtualDisplays``.
    """
    width = traits.Int
    height = traits.Int
    display_name = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String
    display_mode = traits.Trait('white on black', 'black on white')
    display_server = traits.Any
    display_info = traits.Any
    show_grid = traits.Bool

    traits_view = View(
        Group(Item('display_mode'),
              Item('display_name'),
              Item('viewport_id'),
              Item('plot', editor=ComponentEditor(), show_label=False),
              orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        super(ViewportDefiner, self).__init__(*args, **kwargs)

        #find our index in the viewport list
        viewport_ids = []
        self.viewport_idx = -1
        for i, obj in enumerate(self.display_info['virtualDisplays']):
            viewport_ids.append(obj['id'])
            if obj['id'] == self.viewport_id:
                self.viewport_idx = i

        if self.viewport_idx == -1:
            raise Exception("Could not find viewport (available ids: %s)" %
                            ",".join(viewport_ids))

        # Creates self._image. This must happen before self.linedraw is first
        # touched, because building the plot reads the image.
        self._update_image()

        self.fqdn = self.display_name + '/display/virtualDisplays'
        self.this_virtual_display = self.display_info['virtualDisplays'][
            self.viewport_idx]

        # error check: every stored vertex must lie inside the image
        stored = self.this_virtual_display['viewport']
        all_points_ok = all(x < self.width and y < self.height
                            for (x, y) in stored)
        if all_points_ok:
            self.linedraw.points = stored
        else:
            self.linedraw.points = []
            rospy.logwarn('invalid points')
        self._update_image()

    def _update_image(self):
        """Re-render the polygon mask (plus optional grid) and push it out."""
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        fill_polygon.fill_polygon(self.linedraw.points, self._image)

        if self.show_grid:
            # draw red horizontal stripes
            for i in range(0, self.height, 100):
                self._image[i:i + 10, :, 0] = 255

            # draw blue vertical stripes
            for i in range(0, self.width, 100):
                self._image[:, i:i + 10, 2] = 255

        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        """Build the image plot with pan and box-zoom tools."""
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        plot.img_plot("imagedata")

        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)

        return plot

    def _linedraw_default(self):
        """Create the polygon-drawing overlay and re-render on point edits."""
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Send the current image to the display server."""
        # NOTE(review): display_mode does not influence the output. The
        # original computed foreground/background colours from it but never
        # used them, so that dead code was removed.
        # TODO: apply the mode to the image if inverted output is intended.
        self.display_server.show_pixels(self._image)

    def get_viewport_verts(self):
        """Return the polygon vertices as integer [x, y] lists clamped to the image."""
        # convert to integers
        pts = [(fill_polygon.posint(x, self.width - 1),
                fill_polygon.posint(y, self.height - 1))
               for (x, y) in self.linedraw.points]
        # convert to list of lists for maximal json compatibility
        return [list(x) for x in pts]

    def update_ROS_params(self):
        """Store the current vertices in this viewport's ROS parameter entry."""
        viewport_verts = self.get_viewport_verts()
        self.this_virtual_display['viewport'] = viewport_verts
        self.display_info['virtualDisplays'][
            self.viewport_idx] = self.this_virtual_display
        rospy.set_param(self.fqdn, self.display_info['virtualDisplays'])
Example #53
0
 def _plotdata_default(self):
     """Build the default plot data: the (0, 0) data slice stored under 'xy'."""
     initial_slice = self.get_data_slice(0, 0)
     container = ArrayPlotData()
     container.set_data('xy', initial_slice)
     return container
Example #54
0
class ImagePlot(QtGui.QWidget):
    """Qt widget embedding a Chaco image plot of ``z`` over ``x``/``y`` with a colour bar."""

    def __init__(self, parent, title, x, y, z, xtitle, ytitle, ztitle):
        # NOTE(review): `parent` is accepted but not passed to QWidget, so the
        # widget is top-level. This is kept unchanged for compatibility.
        QtGui.QWidget.__init__(self)

        # Create the subclass's window
        self.enable_win = self._create_window(title, x, y, z, xtitle, ytitle, ztitle)

        layout = QtGui.QVBoxLayout()

        layout.setMargin(0)
        layout.addWidget(self.enable_win.control)

        self.setLayout(layout)

        self.resize(650, 650)

        self.show()

    def _create_window(self, title, x, y, z, xtitle, ytitle, ztitle):
        """Build the plot and colour bar and return them in an Enable window.

        - Left-drag pans the plot.
        - Mousewheel up and down zooms the plot in and out.
        - Pressing "z" brings up the Zoom Box, and you can click-drag a rectangular
          region to zoom.  If you use a sequence of zoom boxes, pressing alt-left-arrow
          and alt-right-arrow moves you forwards and backwards through the "zoom
          history".

        ``ztitle`` is currently unused.
        """
        # Create window
        self._plotname = title
        self.data = ArrayPlotData()
        self.plot = ToolbarPlot(self.data, hiding=False, auto_hide=False)
        self.update_plot(x, y, z)
        self.plot.title = title
        self.plot.x_axis.title = xtitle
        self.plot.y_axis.title = ytitle

        cmap_renderer = self.plot.plots[self._plotname][0]

        # Create colorbar
        self._create_colorbar()
        self._colorbar.plot = cmap_renderer
        self._colorbar.padding_top = self.plot.padding_top
        self._colorbar.padding_bottom = self.plot.padding_bottom

        # Add some tools
        self.plot.tools.append(PanTool(self.plot, constrain_key="shift"))
        self.plot.overlays.append(ZoomTool(component=self.plot, tool_mode="box", always_on=False))

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.plot)
        container.add(self._colorbar)
        self.container = container

        # Return a window containing our plot container
        return Window(self, -1, component=self.container)

    def update_plot(self, x, y, z):
        """Replace the data and redraw the image with pixel-centred bounds.

        ``x`` and ``y`` are treated as pixel-centre coordinates. The image
        extent is widened by half a grid step on each side.
        """
        self.data.set_data('x', x)
        self.data.set_data('y', y)
        self.data.set_data('z', z)

        # `in` works on Python 2 and 3; dict.has_key is Python 2 only.
        if self._plotname in self.plot.plots:
            self.plot.delplot(self._plotname)

        # determine correct bounds
        x0, x1 = self._pixel_edges(x)
        y0, y1 = self._pixel_edges(y)

        self.plot.img_plot('z',
                           name=self._plotname,
                           xbounds=(x0, x1),
                           ybounds=(y0, y1),
                           colormap=jet)

    @staticmethod
    def _pixel_edges(coords):
        """Return (low, high) outer edges for the pixel centres in ``coords``.

        True division keeps integer coordinate arrays from truncating the
        step under Python 2. A single coordinate gets a unit step instead of
        dividing by zero.
        """
        lo, hi = coords.min(), coords.max()
        count = len(coords)
        step = float(hi - lo) / (count - 1) if count > 1 else 1.0
        return lo - step / 2.0, hi + step / 2.0

    def _create_colorbar(self):
        """Create a vertical colour bar sharing the plot's colour mapper."""
        cmap = self.plot.color_mapper
        self._colorbar = ColorBar(index_mapper=LinearMapper(range=cmap.range),
                                  color_mapper=cmap,
                                  orientation='v',
                                  resizable='v',
                                  width=30,
                                  padding=30)
Example #55
0
class TimerController(HasTraits):
    """Step a ``BehOrg.GraspArchitecture`` and live-plot its fields with Chaco.

    ``create_plot_component`` builds one ``Plot`` per field and lays the image
    plots out in two rows of a ``VPlotContainer`` (returned by
    ``get_container``). ``onTimer`` calls ``self.arch.step()`` and pushes the
    fresh activations into every image plot. The 1-D line plots (gripper cos
    fields and find-colour intention fields) are built and kept as attributes
    but are not currently added to the container.
    """

    def __init__(self):
        # HasTraits.__init__ runs the traits machinery's own initialisation;
        # the override must call it rather than silently skip it.
        super(TimerController, self).__init__()
        self.arch = BehOrg.GraspArchitecture()
        self._time_steps = 0
        self.create_plot_component()

    def get_container(self):
        """Return the top-level ``VPlotContainer`` holding the image plots."""
        return self._container

    @staticmethod
    def _clamp_value_range(plot, renderer_name, limit):
        """Fix the value range of renderer *renderer_name* to ``[-limit, limit]``."""
        value_range = plot.plots[renderer_name][0].value_mapper.range
        value_range.high = limit
        value_range.low = -limit

    def _create_line_plot(self, field, title, renderer_name, limit):
        """Build a blue line plot of a 1-D *field*'s activation.

        The x values are the sample indices ``0 .. n-1``, where ``n`` is the
        field's first output dimension. Returns ``(plotdata, plot)``.
        """
        x_axis = numpy.arange(field.get_output_dimension_sizes()[0])
        plotdata = ArrayPlotData(x=x_axis, y=field.get_activation())
        plot = Plot(plotdata)
        plot.title = title
        plot.plot(("x", "y"), name=renderer_name, type="line", color="blue")
        self._clamp_value_range(plot, renderer_name, limit)
        return plotdata, plot

    def _create_image_plot(self, fetch, title, renderer_name, sizes, limit):
        """Build a ``jet`` image plot fed by *fetch* and register it for refresh.

        *fetch* is a zero-argument callable returning the 2-D array to show;
        it is called now for the initial image and again on every
        ``onTimer``. *sizes* is ``(x_count, y_count)``; the image spans
        ``[0, x_count - 1] x [0, y_count - 1]`` in data space.
        Returns ``(plotdata, plot)``.
        """
        x_count, y_count = sizes
        plotdata = ArrayPlotData()
        plotdata.set_data('imagedata', fetch())
        plot = Plot(plotdata)
        plot.title = title
        plot.img_plot(
            'imagedata',
            name=renderer_name,
            xbounds=(0, x_count - 1),
            ybounds=(0, y_count - 1),
            colormap=jet,
        )
        self._clamp_value_range(plot, renderer_name, limit)
        self._live_image_plots.append((plotdata, plot, fetch))
        return plotdata, plot

    def create_plot_component(self):
        """Create every field plot and assemble the two-row image layout."""
        color_range_max_value = 10
        arch = self.arch
        # (plotdata, plot, fetch) for each image plot refreshed by onTimer;
        # reset here so a repeated call never refreshes stale plots.
        self._live_image_plots = []

        # ---- 1-D fields (line plots) --------------------------------------
        # gripper right cos field
        (self._gripper_right_cos_field_plotdata,
         self._gripper_right_cos_field_plot) = self._create_line_plot(
             arch._gripper_right_cos_field, 'gripper right cos',
             'gripper_right_cos', color_range_max_value)

        # gripper left cos field
        (self._gripper_left_cos_field_plotdata,
         self._gripper_left_cos_field_plot) = self._create_line_plot(
             arch._gripper_left_cos_field, 'gripper left cos',
             'gripper_left_cos', color_range_max_value)

        # find red color intention field
        (self._find_color_intention_field_plotdata,
         self._find_color_intention_field_plot) = self._create_line_plot(
             arch._find_color.get_intention_field(), 'find color int',
             'find_color_int', color_range_max_value)

        # find green color intention field
        (self._find_color_ee_intention_field_plotdata,
         self._find_color_ee_intention_field_plot) = self._create_line_plot(
             arch._find_color_ee.get_intention_field(), 'find color ee int',
             'find_color_ee_int', color_range_max_value)

        # ---- 2-D fields (image plots, refreshed every timer tick) ----------
        # The fetchers read ``self.arch`` at call time, exactly as the
        # per-tick updates always have.

        # camera
        (self._camera_field_plotdata,
         self._camera_field_plot) = self._create_image_plot(
             lambda: self.arch._camera_field.get_activation().max(2).transpose(),
             'camera', 'camera_field',
             (arch._camera_field_sizes[0], arch._camera_field_sizes[1]),
             color_range_max_value)

        # color space red
        (self._color_space_field_plotdata,
         self._color_space_field_plot) = self._create_image_plot(
             lambda: (self.arch._color_space_field.get_activation()
                      .max(1).transpose()),
             'color space', 'color_space_field',
             (arch._color_space_field_sizes[0],
              arch._color_space_field_sizes[2]),
             color_range_max_value)

        # color space green
        (self._color_space_ee_field_plotdata,
         self._color_space_ee_field_plot) = self._create_image_plot(
             lambda: (self.arch._color_space_ee_field.get_activation()
                      .max(2).transpose()),
             'color space ee', 'color_space_ee_field',
             (arch._color_space_ee_field_sizes[0],
              arch._color_space_ee_field_sizes[1]),
             color_range_max_value)

        # spatial target
        (self._spatial_target_field_plotdata,
         self._spatial_target_field_plot) = self._create_image_plot(
             lambda: self.arch._spatial_target_field.get_activation().transpose(),
             'spatial target', 'spatial_target_field',
             (arch._spatial_target_field_sizes[0],
              arch._spatial_target_field_sizes[1]),
             color_range_max_value)

        # move head intention
        (self._move_head_intention_field_plotdata,
         self._move_head_intention_field_plot) = self._create_image_plot(
             lambda: (self.arch._move_head.get_intention_field()
                      .get_activation().transpose()),
             'move head int', 'move_head_intention_field',
             (arch._move_head_field_sizes[0], arch._move_head_field_sizes[1]),
             color_range_max_value)

        # move head cos
        (self._move_head_cos_field_plotdata,
         self._move_head_cos_field_plot) = self._create_image_plot(
             lambda: (self.arch._move_head.get_cos_field()
                      .get_activation().transpose()),
             'move head cos', 'move_head_cos_field',
             (arch._move_head_field_sizes[0], arch._move_head_field_sizes[1]),
             color_range_max_value)

        # move right arm intention
        (self._move_right_arm_intention_field_plotdata,
         self._move_right_arm_intention_field_plot) = self._create_image_plot(
             lambda: (self.arch._move_right_arm_intention_field
                      .get_activation().transpose()),
             'move right arm int', 'move_right_arm_intention_field',
             (arch._move_arm_field_sizes[0], arch._move_arm_field_sizes[1]),
             color_range_max_value)

        # move right arm cos
        (self._move_right_arm_cos_field_plotdata,
         self._move_right_arm_cos_field_plot) = self._create_image_plot(
             lambda: self.arch._move_arm_cos_field.get_activation().transpose(),
             'move right arm cos', 'move_right_arm_cos_field',
             (arch._move_arm_field_sizes[0], arch._move_arm_field_sizes[1]),
             color_range_max_value)

        # visual servoing right intention
        (self._visual_servoing_right_intention_field_plotdata,
         self._visual_servoing_right_intention_field_plot
         ) = self._create_image_plot(
             lambda: (self.arch._visual_servoing_right.get_intention_field()
                      .get_activation().transpose()),
             'visual servoing right int',
             'visual_servoing_right_intention_field',
             (arch._visual_servoing_field_sizes[0],
              arch._visual_servoing_field_sizes[1]),
             color_range_max_value)

        # visual servoing right cos
        (self._visual_servoing_right_cos_field_plotdata,
         self._visual_servoing_right_cos_field_plot) = self._create_image_plot(
             lambda: (self.arch._visual_servoing_right.get_cos_field()
                      .get_activation().transpose()),
             'visual servoing right cos', 'visual_servoing_right_cos_field',
             (arch._visual_servoing_field_sizes[0],
              arch._visual_servoing_field_sizes[1]),
             color_range_max_value)

        # ---- layout: two horizontal rows stacked in a vertical container ---
        self._container = VPlotContainer()
        self._hcontainer_top = HPlotContainer()
        self._hcontainer_bottom = HPlotContainer()
        for plot in (self._camera_field_plot,
                     self._color_space_field_plot,
                     self._spatial_target_field_plot,
                     self._move_head_intention_field_plot,
                     self._move_right_arm_intention_field_plot):
            self._hcontainer_bottom.add(plot)
        for plot in (self._color_space_ee_field_plot,
                     self._visual_servoing_right_intention_field_plot,
                     self._visual_servoing_right_cos_field_plot,
                     self._move_head_cos_field_plot,
                     self._move_right_arm_cos_field_plot):
            self._hcontainer_top.add(plot)
        # The line plots built above are intentionally not added here.

        self._container.add(self._hcontainer_bottom)
        self._container.add(self._hcontainer_top)

    def onTimer(self, *args):
        """Advance the architecture one step and refresh every image plot.

        Any positional arguments (e.g. from a timer callback) are ignored.
        """
        self.arch.step()

        # Push all new data first, then request the redraws.
        for plotdata, _plot, fetch in self._live_image_plots:
            plotdata.set_data('imagedata', fetch())
        for _plotdata, plot, _fetch in self._live_image_plots:
            plot.request_redraw()
class CMatrixViewer(MatrixViewer):
    """Heat-map viewer for a network's edge-attribute matrices.

    Displays the matrix selected by ``edge_parameter_name`` as a ``jet``
    image plot next to a colour bar, and shows the row label (``fro``),
    column label (``to``) and cell value (``val``) under the cursor of the
    custom tool.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    edge_parameter = Instance(EdgeParameters)
    network_reference = Any
    matrix_data_ref = Any
    labels = Any
    fro = Any
    to = Any
    val = Float

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=(800, 600)),
                                  show_label=False),
                             HGroup(
                                 Item('fro',
                                      label="From",
                                      style='readonly',
                                      springy=True),
                                 Item('to',
                                      label="To",
                                      style='readonly',
                                      springy=True),
                                 Item('val',
                                      label="Value",
                                      style='readonly',
                                      springy=True),
                             ),
                             orientation="vertical"),
                       Item('edge_parameter_name', label="Choose edge"),
                       handler=CustomHandler(),
                       resizable=True,
                       title="Matrix Viewer")

    def __init__(self, net_ref, **traits):
        """Build the viewer for the network *net_ref*.

        *net_ref* is a reference to a cnetwork; its ``_edge_para`` and its
        ``datasourcemanager._srcobj`` (edge-attribute matrices and node
        labels) are read. Extra keyword arguments are forwarded as traits.
        """
        # NOTE(review): super(MatrixViewer, self) skips MatrixViewer.__init__
        # and runs the next initialiser in the MRO instead -- presumably on
        # purpose, since this class builds its own plot; confirm.
        super(MatrixViewer, self).__init__(**traits)

        self.network_reference = net_ref
        self.edge_parameter = self.network_reference._edge_para
        self.matrix_data_ref = self.network_reference.datasourcemanager._srcobj.edgeattributes_matrix_dict
        self.labels = self.network_reference.datasourcemanager._srcobj.labels

        # get the currently selected edge
        self.curr_edge = self.edge_parameter.parameterset.name
        # create plot (this also creates self.pd, self.tplot, self.custtool)
        self.plot = self._create_plot_component()

        # refresh the From/To/Value fields when the tool's position changes
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

        # Add the edge selector. list() so Enum always receives a plain list
        # of choices (keys() returns a view object on Python 3).
        self.add_trait('edge_parameter_name',
                       Enum(list(self.matrix_data_ref.keys())))
        self.edge_parameter_name = self.curr_edge

    def _edge_parameter_name_changed(self, new):
        """Switch the dialog, image data and title to the edge named *new*."""
        # update edge parameter dialog
        self.edge_parameter.set_to_edge_parameter(new)

        # update the data
        self.pd.set_data("imagedata", self.matrix_data_ref[new])

        # The title was initialised from the first edge; keep it (and
        # curr_edge) in sync so it names the matrix actually shown.
        self.curr_edge = new
        self.tplot.title = new

    def _update_fields(self):
        """Show the labels and value of the matrix cell under the cursor."""
        from numpy import floor

        # Map the cursor's data coordinates to array indices. floor, not
        # trunc: trunc sends (-1, 0) to 0 and would report cell 0 for a
        # cursor just outside the image.
        frotmp = int(floor(self.custtool.yval))
        totmp = int(floor(self.custtool.xval))

        matrix = self.matrix_data_ref[self.edge_parameter_name]
        # assume matrix whose shape is (# of rows, # of columns)
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        if 0 <= frotmp < n_rows and 0 <= totmp < n_cols:
            self.fro = self.labels[frotmp]
            self.to = self.labels[totmp]
            self.val = matrix[frotmp, totmp]

    def _create_plot_component(self):
        """Build the image plot, its tools and colour bar; return the container.

        Side effects: sets ``self.pd``, ``self.tplot``, ``self.my_plot``,
        ``self.custtool``, ``self.colorbar`` and ``self.range_selection``.
        """
        # Create a plot data object holding the currently selected matrix
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.matrix_data_ref[self.curr_edge])

        # Create the plot, row 0 at the top, column axis along the top edge
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot(
            "imagedata",
            name="my_plot",
            colormap=jet)

        # Tweak some of the plot properties
        self.tplot.title = self.curr_edge
        self.tplot.padding = 50

        # Right now, some of the tools are a little invasive, and we need the
        # actual CMapImage object to give to them
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # custom tool reporting the cursor position (drives _update_fields)
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colorbar, handing in the appropriate range and colormap
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # TODO: the range selection gives a Segmentation Fault,
        # but why, the matrix_viewer.py example works just fine!
        # create a range selection for the colorbar
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))
        # NOTE: the selection is not forwarded to the image plot
        # (range_selection.listeners is left empty).

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container
Example #57
0
class PlotView(wx.Panel):
    def __init__(self, parent, id=-1, **kwargs):
        """Create the panel and its Chaco overlay container with a legend.

        Parameters
        ----------
        parent : wx window
            Parent window; its top-level parent must expose ``statusBar``.
        id : int
            wx window id, forwarded to ``wx.Panel.__init__`` (default -1).
        **kwargs
            Forwarded unchanged to ``wx.Panel.__init__``.
        """
        wx.Panel.__init__(self, parent, id=id, **kwargs)
        top_level = self.GetTopLevelParent()
        self.statusBar = top_level.statusBar

        # Container into which add_plot places the shared Plot.
        overlay = OverlayPlotContainer(
            padding=50,
            fill_padding=True,
            bgcolor="lightgray",
            use_backbuffer=True,
        )
        self.container = overlay
        self.legend = Legend(component=overlay, padding=10, align="ur")
        overlay.overlays.append(self.legend)

        self.plot_window = Window(self, component=overlay)

        overlay.tools.append(TraitsTool(overlay))

        # Bookkeeping used when traces are added: first-plot flag, colour
        # cycling state and the list of traces.
        self.firstplot = True
        self._palette = ['red', 'blue', 'green', 'purple', 'yellow']
        self._current_palette_index = 0

        self._traces = []

        layout = wx.BoxSizer(wx.VERTICAL)
        layout.Add(self.plot_window.control, 1, wx.EXPAND)
        self.SetSizer(layout)
        self.SetAutoLayout(True)

    def _next_color(self):
        """Return the next palette colour, wrapping to the start after the last."""
        position = self._current_palette_index
        if position == len(self._palette):
            position = 0
        self._current_palette_index = position + 1
        return self._palette[position]

    def add_plot(self, signal, sweep_point=None):
##        waveform = signal.get_waveform()
##        x = waveform.get_x()[-1][0].tolist()
##        y = np.real(waveform.get_y()[0].tolist())

        if sweep_point is None:
            sweep_point = signal.get_circuit()._sweep_set._points[0]

        trace = Trace(signal, self, self._next_color(), sweep_point)

        x_name = trace.index_label
        y_name = trace.label

        x = trace.get_indices()
        y = trace.get_values()
        if type(y[0]) == complex:
            y = [value.real for value in y]
        #print x_name, len(x)
        #print y_name, len(y)
        #print x
        #print y

        if self.firstplot:
            self.plotdata = ArrayPlotData()
            self.plotdata.set_data(x_name, x)
            self.plotdata.set_data(y_name, y)

            plot = Plot(self.plotdata)
            plot.padding = 1

            plot.bgcolor = "white"
            plot.border_visible = True
            add_default_grids(plot)
            add_default_axes(plot)

            plot.tools.append(PanTool(plot))

            # The ZoomTool tool is stateful and allows drawing a zoom
            # box to select a zoom region.
            zoom = CustomZoomTool(plot)
            plot.overlays.append(zoom)

            # The DragZoom tool just zooms in and out as the user drags
            # the mouse vertically.
            dragzoom = DragZoom(plot, drag_button="right")
            plot.tools.append(dragzoom)

            #~ # Add a legend in the upper right corner, and make it relocatable
            #~ self.legend = Legend(component=plot, padding=10, align="ur")
            #~ self.legend.tools.append(LegendTool(self.legend, drag_button="right"))
            #~ plot.overlays.append(self.legend)

            #~ self.legend.plots = {}

            self.firstplot = False

            self.container.add(plot)

            self.plot = plot

        else:
            self.plotdata.set_data(x_name, x)
            self.plotdata.set_data(y_name, y)

        #self.plot.plot(self.plotdata.list_data())
        pl = self.plot.plot( (x_name, y_name), name=trace.label, type="line",
            color=trace.color, line_style=trace.line_style,
            line_width=trace.line_width, marker=trace.marker,
            marker_size=trace.marker_size, marker_color=trace.marker_color)

        self.legend.plots[trace.label] = pl

        self.Refresh()