Example #1
0
class PlotView(View):
    '''
    View that embeds a single :class:`PlotWidget` and redraws it whenever
    :meth:`update_slot` receives a dict describing one or more line plots.
    '''

    def __init__(self, *args, **kwargs):
        '''
        Constructor.

        All positional and keyword arguments are forwarded unchanged to
        ``View.__init__``. Creates the PlotWidget, adds it to this view,
        titles the view "PlotView" and prepares the default pen.
        '''
        View.__init__(self, *args, **kwargs)
        self._plot = PlotWidget()
        self.addWidget(self._plot)
        self.setTitle("PlotView")
        # Pen used for every curve unless the caller passes its own "pen".
        self._pen = {'width': 5}

    def update_slot(self, args):
        '''
        Method called to update the plot. In this case, arguments to specify how to draw a line plot.

        Parameters:

        - args (:class:`dict`) : a `dict` with at least two fields - "xAxis" and "yAxis". "yAxis" may be a list, in which case all the arrays within will be overlaid. Any extra fields within `args` will be passed along to the :class:`PlotWidget.plot` function. When no "pen" field is given, the view's default pen is used. The caller's dict is left unmodified.

        Returns:

        - N/A

        Raises:

        - AssertionError : if `args` is not a dict or lacks "xAxis"/"yAxis" (only while assertions are enabled).
        '''
        assert isinstance(args, dict), 'PlotView did not receive a dict while calling update.'
        assert 'xAxis' in args, 'PlotView did not receive an x-axis'
        assert 'yAxis' in args, 'PlotView did not receive any y-axis'
        # Work on a shallow copy. Deleting keys from (and adding "pen" to) the
        # caller's dict made a second update with the same dict fail.
        plot_kwargs = dict(args)
        xAxis = plot_kwargs.pop('xAxis')
        yAxis = plot_kwargs.pop('yAxis')
        plot_kwargs.setdefault('pen', self._pen)
        if isinstance(yAxis, list):
            # Overlay every dataset on a freshly cleared plot.
            self._plot.clear()
            for dset in yAxis:
                self._plot.plot(xAxis, dset, **plot_kwargs)
        else:
            self._plot.plot(xAxis, yAxis, clear=True, **plot_kwargs)
Example #2
0
class QtPlotView(QtControl, ProxyPlotView):
    """Qt proxy that renders a PlotView declaration with a pyqtgraph PlotWidget."""
    # NOTE(review): presumably needed for the multiple Atom bases; confirm before removing.
    __weakref__ = None
    # The underlying pyqtgraph widget.
    widget = Typed(PlotWidget)
    # Secondary ViewBoxes created for datasets 2..4 in multi-axis mode.
    _views = List()
    # Extra right-hand AxisItems added to the layout for datasets 3 and 4.
    _axes = List()
    # Default pen colours, cycled for datasets that specify no pen.
    _colors = List(default=['r', 'g', 'b'])

    def create_widget(self):
        """Create the PlotWidget (white background) under the parent widget."""
        self.widget = PlotWidget(self.parent_widget(), background='w')

    def init_widget(self):
        """Initialise the widget from the declaration's current state."""
        super(QtPlotView, self).init_widget()
        d = self.declaration
        # Connect once. The slot reads self._views on every call, so it stays
        # valid across data updates without piling up connections.
        self.widget.plotItem.vb.sigResized.connect(self._sync_views)
        self.set_data(d.data)
        self.set_antialiasing(d.antialiasing)
        self.set_aspect_locked(d.aspect_locked)
        self.set_axis_scales(d.axis_scales)
        self.set_labels(d.labels)
        self.set_grid(d.grid)

        d.setup(self.widget)

    def set_title(self, title):
        """Re-apply the labels. The title itself is read from the declaration."""
        self.set_labels(self.declaration.labels)

    def set_labels(self, labels):
        """Apply axis labels, adding the declaration's title when it is set.

        The mapping is copied so the declaration's own ``labels`` dict is
        never modified.
        """
        labels = dict(labels or {})
        title = self.declaration.title
        if title:
            labels['title'] = title
        self.widget.setLabels(**labels)

    def set_antialiasing(self, enabled):
        """Enable or disable antialiasing on the plot."""
        self.widget.setAntialiasing(enabled)

    def set_aspect_locked(self, locked):
        """Lock or unlock the plot's aspect ratio."""
        self.widget.setAspectLocked(locked)

    def set_axis_scales(self, scales):
        """Apply per-axis scale factors given as ``{axis_name: scale}``.

        Names that are not axes of the plot item are ignored.
        """
        if not scales:
            return
        axes = self.widget.plotItem.axes
        for name, scale in scales.items():
            if name in axes:
                axes[name]['item'].setScale(scale)

    def set_grid(self, grid):
        """Show/hide the (x, y) grid using the declaration's grid alpha."""
        d = self.declaration
        self.widget.showGrid(grid[0], grid[1], d.grid_alpha)

    def set_grid_alpha(self, alpha):
        """Change the grid alpha, keeping the declaration's x/y grid flags."""
        d = self.declaration
        self.widget.showGrid(d.grid[0], d.grid[1], alpha)

    def set_data(self, data):
        """Replace the plotted content.

        A list/tuple whose first element is a GraphicsObject is added item by
        item. Any other non-empty value is handed to ``_set_numeric_data``.
        Secondary views from a previous multi-axis plot are always removed.
        """
        self.widget.clear()
        self._clear_views()
        if not data:
            return

        if isinstance(data, (list, tuple)) and isinstance(data[0], GraphicsObject):
            self._set_graphic_items(data)
        else:
            self._set_numeric_data(data)

    def _set_graphic_items(self, items):
        """Clear the plot and add each graphics item to it."""
        self.widget.clear()
        for item in items:
            self.widget.addItem(item)

    def _clear_views(self):
        """Remove the secondary ViewBoxes and extra axes of a previous plot."""
        layout = self.widget.plotItem.layout
        for axis in self._axes:
            layout.removeItem(axis)
            scene = axis.scene()
            if scene is not None:
                scene.removeItem(axis)
        for view in self._views:
            view.clear()
            view.setXLink(None)
            scene = view.scene()
            if scene is not None:
                scene.removeItem(view)
        self._axes = []
        self._views = []

    def _set_numeric_data(self, data):
        """Plot numeric data.

        In multi-axis mode ``data`` is iterated as keyword dicts for the
        plot calls. The first dict is drawn on the main axes. Each of the
        next (at most three) gets its own ViewBox and right-hand axis, and
        further entries are ignored. Dicts without a ``pen`` get a colour
        from ``_colors``, cycled. The dicts themselves are not modified.
        """
        self.widget.plotItem.clear()
        self._clear_views()
        if not self.declaration.multi_axis:
            # NOTE(review): nothing is plotted when multi_axis is False --
            # confirm this is intended.
            return

        plot_item = self.widget.plotItem
        views = []
        axes = []
        for i, plot in enumerate(data):
            if i > 3:
                break
            kwargs = dict(plot)
            if 'pen' not in kwargs and self._colors:
                # Cycle so a 4th dataset doesn't index past the 3 colours.
                kwargs['pen'] = self._colors[i % len(self._colors)]
            if i == 0:
                self.widget.plot(**kwargs)
                continue
            view = ViewBox()
            views.append(view)
            plot_item.scene().addItem(view)
            if i == 1:
                axis = plot_item.getAxis('right')
            else:
                axis = AxisItem('right')
                axis.setZValue(-10000)
                # One layout column per extra axis (3, 4) so they don't overlap.
                plot_item.layout.addItem(axis, 2, i + 1)
                axes.append(axis)
            axis.linkToView(view)
            view.setXLink(plot_item)
            view.addItem(PlotCurveItem(**kwargs))
        self._views = views
        self._axes = axes
        self._sync_views()

    def _sync_views(self, *args):
        """Align the secondary ViewBoxes with the main ViewBox.

        Also connected to the main ViewBox's ``sigResized``. Signal
        arguments are ignored.
        """
        widget = self.widget
        if widget is None:
            return
        main_vb = widget.plotItem.vb
        for view in self._views:
            view.setGeometry(main_vb.sceneBoundingRect())
            view.linkedViewChanged(main_vb, view.XAxis)
Example #3
0
class QtPlotView(QtControl, ProxyPlotView):
    """Qt proxy that renders a PlotView declaration with a pyqtgraph PlotWidget."""
    # NOTE(review): presumably needed for the multiple Atom bases; confirm before removing.
    __weakref__ = None
    # The underlying pyqtgraph widget.
    widget = Typed(PlotWidget)
    # Secondary ViewBoxes created for datasets 2..4 in multi-axis mode.
    _views = List()
    # Extra right-hand AxisItems added to the layout for datasets 3 and 4.
    _axes = List()
    # Default pen colours, cycled for datasets that specify no pen.
    _colors = List(default=['r', 'g', 'b'])

    def create_widget(self):
        """Create the PlotWidget (white background) under the parent widget."""
        self.widget = PlotWidget(self.parent_widget(), background='w')

    def init_widget(self):
        """Initialise the widget from the declaration's current state."""
        super(QtPlotView, self).init_widget()
        d = self.declaration
        # Connect once. The slot reads self._views on every call, so it stays
        # valid across data updates without piling up connections.
        self.widget.plotItem.vb.sigResized.connect(self._sync_views)
        self.set_data(d.data)
        self.set_antialiasing(d.antialiasing)
        self.set_aspect_locked(d.aspect_locked)
        self.set_axis_scales(d.axis_scales)
        self.set_labels(d.labels)
        self.set_grid(d.grid)

        d.setup(self.widget)

    def set_title(self, title):
        """Re-apply the labels. The title itself is read from the declaration."""
        self.set_labels(self.declaration.labels)

    def set_labels(self, labels):
        """Apply axis labels, adding the declaration's title when it is set.

        The mapping is copied so the declaration's own ``labels`` dict is
        never modified.
        """
        labels = dict(labels or {})
        title = self.declaration.title
        if title:
            labels['title'] = title
        self.widget.setLabels(**labels)

    def set_antialiasing(self, enabled):
        """Enable or disable antialiasing on the plot."""
        self.widget.setAntialiasing(enabled)

    def set_aspect_locked(self, locked):
        """Lock or unlock the plot's aspect ratio."""
        self.widget.setAspectLocked(locked)

    def set_axis_scales(self, scales):
        """Apply per-axis scale factors given as ``{axis_name: scale}``.

        Names that are not axes of the plot item are ignored.
        """
        if not scales:
            return
        axes = self.widget.plotItem.axes
        for name, scale in scales.items():
            if name in axes:
                axes[name]['item'].setScale(scale)

    def set_grid(self, grid):
        """Show/hide the (x, y) grid using the declaration's grid alpha."""
        d = self.declaration
        self.widget.showGrid(grid[0], grid[1], d.grid_alpha)

    def set_grid_alpha(self, alpha):
        """Change the grid alpha, keeping the declaration's x/y grid flags."""
        d = self.declaration
        self.widget.showGrid(d.grid[0], d.grid[1], alpha)

    def set_data(self, data):
        """Replace the plotted content.

        A list/tuple whose first element is a GraphicsObject is added item by
        item. Any other non-empty value is handed to ``_set_numeric_data``.
        Secondary views from a previous multi-axis plot are always removed.
        """
        self.widget.clear()
        self._clear_views()
        if not data:
            return

        if isinstance(data, (list, tuple)) and isinstance(data[0], GraphicsObject):
            self._set_graphic_items(data)
        else:
            self._set_numeric_data(data)

    def _set_graphic_items(self, items):
        """Clear the plot and add each graphics item to it."""
        self.widget.clear()
        for item in items:
            self.widget.addItem(item)

    def _clear_views(self):
        """Remove the secondary ViewBoxes and extra axes of a previous plot."""
        layout = self.widget.plotItem.layout
        for axis in self._axes:
            layout.removeItem(axis)
            scene = axis.scene()
            if scene is not None:
                scene.removeItem(axis)
        for view in self._views:
            view.clear()
            view.setXLink(None)
            scene = view.scene()
            if scene is not None:
                scene.removeItem(view)
        self._axes = []
        self._views = []

    def _set_numeric_data(self, data):
        """Plot numeric data.

        In multi-axis mode ``data`` is iterated as keyword dicts for the
        plot calls. The first dict is drawn on the main axes. Each of the
        next (at most three) gets its own ViewBox and right-hand axis, and
        further entries are ignored. Dicts without a ``pen`` get a colour
        from ``_colors``, cycled. The dicts themselves are not modified.
        """
        self.widget.plotItem.clear()
        self._clear_views()
        if not self.declaration.multi_axis:
            # NOTE(review): nothing is plotted when multi_axis is False --
            # confirm this is intended.
            return

        plot_item = self.widget.plotItem
        views = []
        axes = []
        for i, plot in enumerate(data):
            if i > 3:
                break
            kwargs = dict(plot)
            if 'pen' not in kwargs and self._colors:
                # Cycle so a 4th dataset doesn't index past the 3 colours.
                kwargs['pen'] = self._colors[i % len(self._colors)]
            if i == 0:
                self.widget.plot(**kwargs)
                continue
            view = ViewBox()
            views.append(view)
            plot_item.scene().addItem(view)
            if i == 1:
                axis = plot_item.getAxis('right')
            else:
                axis = AxisItem('right')
                axis.setZValue(-10000)
                # One layout column per extra axis (3, 4) so they don't overlap.
                plot_item.layout.addItem(axis, 2, i + 1)
                axes.append(axis)
            axis.linkToView(view)
            view.setXLink(plot_item)
            view.addItem(PlotCurveItem(**kwargs))
        self._views = views
        self._axes = axes
        self._sync_views()

    def _sync_views(self, *args):
        """Align the secondary ViewBoxes with the main ViewBox.

        Also connected to the main ViewBox's ``sigResized``. Signal
        arguments are ignored.
        """
        widget = self.widget
        if widget is None:
            return
        main_vb = widget.plotItem.vb
        for view in self._views:
            view.setGeometry(main_vb.sceneBoundingRect())
            view.linkedViewChanged(main_vb, view.XAxis)
class ScatterPlotWidget(pg.QtGui.QSplitter):
    """
    This is a high-level widget for exploring relationships in tabular data.

    Given a multi-column record array, the widget displays a scatter plot of a
    specific subset of the data. Includes controls for selecting the columns to
    plot, filtering data, and determining symbol color and shape.

    The widget consists of four components:

    1) A list of column names from which the user may select 1 or 2 columns
       to plot. If one column is selected, the data for that column will be
       plotted in a histogram-like manner by using :func:`pseudoScatter()
       <pyqtgraph.pseudoScatter>`. If two columns are selected, then the
       scatter plot will be generated with x determined by the first column
       that was selected and y by the second.
    2) A DataFilter that allows the user to select a subset of the data by
       specifying multiple selection criteria.
    3) A ColorMap that allows the user to determine how points are colored by
       specifying multiple criteria.
    4) A PlotWidget for displaying the data.
    """
    # Emitted as (self, points, ev) when scatter points are clicked. Each
    # point carries an ``originalIndex`` into the array given to setData().
    sigScatterPlotClicked = pg.QtCore.Signal(object, object, object)

    def __init__(self, parent=None):
        # Forward *parent*; it was previously accepted but silently ignored.
        pg.QtGui.QSplitter.__init__(self, pg.QtCore.Qt.Horizontal, parent)
        self.ctrlPanel = pg.QtGui.QSplitter(pg.QtCore.Qt.Vertical)
        self.addWidget(self.ctrlPanel)
        self.fieldList = pg.QtGui.QListWidget()
        self.fieldList.setSelectionMode(self.fieldList.ExtendedSelection)
        self.ptree = pg.parametertree.ParameterTree(showHeader=False)
        self.filter = DataFilterParameter()
        self.colorMap = ColorMapParameter()
        self.style = StyleMapParameter()
        self.params = pg.parametertree.Parameter.create(name='params', type='group', children=[self.filter, self.colorMap, self.style])
        self.ptree.setParameters(self.params, showTop=False)

        self.plot = PlotWidget()
        self.ctrlPanel.addWidget(self.fieldList)
        self.ctrlPanel.addWidget(self.ptree)
        self.addWidget(self.plot)

        fg = pg.mkColor(pg.getConfigOption('foreground'))
        fg.setAlpha(150)
        self.filterText = pg.TextItem(border=pg.getConfigOption('foreground'), color=fg)
        self.filterText.setPos(60, 20)
        self.filterText.setParentItem(self.plot.plotItem)

        # Initialise everything the public methods read, so calling them
        # before setFields()/setData() cannot raise AttributeError.
        self.fields = OrderedDict()
        self.data = None
        self.indices = None
        self.filtered = None  # cached filtered records (None = recompute)
        self.filteredIndices = None
        self.mouseOverField = None
        self.scatterPlot = None
        self.selectionScatter = None
        self.selectedIndices = []
        self._visibleXY = None  # currently plotted points
        self._visibleData = None  # currently plotted records
        self._visibleIndices = None
        self._indexMap = None

        self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged)
        self.filter.sigFilterChanged.connect(self.filterChanged)
        self.colorMap.sigColorMapChanged.connect(self.updatePlot)
        self.style.sigStyleChanged.connect(self.updatePlot)

    def setFields(self, fields, mouseOverField=None):
        """
        Set the list of field names/units to be processed.

        The format of *fields* is the same as used by
        :func:`ColorMapWidget.setFields <pyqtgraph.widgets.ColorMapWidget.ColorMapParameter.setFields>`
        """
        # Materialise once. *fields* is iterated several times below, which
        # produced an empty field list when a generator was passed.
        fields = list(fields)
        self.fields = OrderedDict(fields)
        self.mouseOverField = mouseOverField
        self.fieldList.clear()
        for name, opts in fields:
            item = pg.QtGui.QListWidgetItem(name)
            item.opts = opts
            self.fieldList.addItem(item)
        self.filter.setFields(fields)
        self.colorMap.setFields(fields)
        self.style.setFields(fields)

    def setSelectedFields(self, *fields):
        """Select the named fields in the field list, then replot once.

        Raises ValueError if a name is not among the fields given to
        setFields().
        """
        names = list(self.fields.keys())
        # Suppress per-item replots while the selection is rebuilt.
        self.fieldList.itemSelectionChanged.disconnect(self.fieldSelectionChanged)
        try:
            self.fieldList.clearSelection()
            for f in fields:
                self.fieldList.item(names.index(f)).setSelected(True)
        finally:
            self.fieldList.itemSelectionChanged.connect(self.fieldSelectionChanged)
        self.fieldSelectionChanged()

    def setData(self, data):
        """
        Set the data to be processed and displayed.
        Argument must be a numpy record array.
        """
        self.data = data
        self.indices = np.arange(len(data))
        self.filtered = None
        self.filteredIndices = None
        self.updatePlot()

    def setSelectedIndices(self, inds):
        """Mark the specified indices as selected.

        Must be a sequence of integers that index into the array given in setData().
        """
        self.selectedIndices = inds
        self.updateSelected()

    def setSelectedPoints(self, points):
        """Mark the specified points as selected.

        Must be a list of points as generated by the sigScatterPlotClicked signal.
        """
        self.setSelectedIndices([pt.originalIndex for pt in points])

    def fieldSelectionChanged(self):
        """Keep at most two fields selected (first and last), then replot."""
        sel = self.fieldList.selectedItems()
        if len(sel) > 2:
            self.fieldList.blockSignals(True)
            try:
                for item in sel[1:-1]:
                    item.setSelected(False)
            finally:
                self.fieldList.blockSignals(False)

        self.updatePlot()

    def filterChanged(self, f):
        """Invalidate the filter cache, replot and refresh the filter summary."""
        self.filtered = None
        self.updatePlot()
        desc = self.filter.describe()
        if len(desc) == 0:
            self.filterText.setVisible(False)
        else:
            self.filterText.setText('\n'.join(desc))
            self.filterText.setVisible(True)

    def _resetVisible(self):
        """Forget the plotted items and the visible-point bookkeeping."""
        if self.scatterPlot is not None:
            try:
                self.scatterPlot.sigPointsClicked.disconnect(self.plotClicked)
            except (TypeError, RuntimeError):
                # Not connected, or the underlying item was already deleted.
                pass
        self.scatterPlot = None
        self.selectionScatter = None
        self._visibleXY = None
        self._visibleData = None
        self._visibleIndices = None
        self._indexMap = None

    def updatePlot(self):
        """Redraw the plot from the current data, filter, colour map, style
        and field selection."""
        self.plot.clear()
        # clear() removed the old scatter/selection items. Drop every
        # reference to them so updateSelected() cannot draw stale points
        # after one of the early returns below.
        self._resetVisible()
        if self.data is None or len(self.data) == 0:
            return

        if self.filtered is None:
            mask = self.filter.generateMask(self.data)
            self.filtered = self.data[mask]
            self.filteredIndices = self.indices[mask]
        data = self.filtered
        if len(data) == 0:
            return

        colors = np.array([pg.mkBrush(*x) for x in self.colorMap.map(data)])

        style = self.style.map(data)

        ## Look up selected columns and units
        selected = self.fieldList.selectedItems()
        sel = [str(item.text()) for item in selected]
        units = [item.opts.get('units', '') for item in selected]
        if len(sel) == 0:
            self.plot.setTitle('')
            return

        if len(sel) == 1:
            self.plot.setLabels(left=('N', ''), bottom=(sel[0], units[0]), title='')
            xy = [data[sel[0]], None]
        elif len(sel) == 2:
            self.plot.setLabels(left=(sel[1], units[1]), bottom=(sel[0], units[0]))
            xy = [data[sel[0]], data[sel[1]]]
        else:
            # fieldSelectionChanged() keeps at most two fields selected;
            # any other selection cannot be plotted.
            return

        ## convert enum-type fields (and string columns) to float, set axis labels
        enum = [False, False]
        for i in [0, 1]:
            axis = self.plot.getAxis(['bottom', 'left'][i])
            if xy[i] is not None and (self.fields[sel[i]].get('mode', None) == 'enum' or xy[i].dtype.kind in ('S', 'U', 'O')):
                vals = self.fields[sel[i]].get('values', list(set(xy[i])))
                # Unknown values map to len(vals), one past the last tick.
                xy[i] = np.array([vals.index(x) if x in vals else len(vals) for x in xy[i]], dtype=float)
                axis.setTicks([list(enumerate(vals))])
                enum[i] = True
            else:
                axis.setTicks(None)  # reset to automatic ticking

        ## mask out any nan values
        mask = np.ones(len(xy[0]), dtype=bool)
        if xy[0].dtype.kind == 'f':
            mask &= np.isfinite(xy[0])
        if xy[1] is not None and xy[1].dtype.kind == 'f':
            mask &= np.isfinite(xy[1])

        xy[0] = xy[0][mask]

        for k in style.keys():
            if style[k] is None:
                continue
            style[k] = style[k][mask]
        style['symbolBrush'] = colors[mask]
        data = data[mask]
        indices = self.filteredIndices[mask]

        ## Scatter y-values for a histogram-like appearance
        if xy[1] is None:
            ## column scatter plot
            xy[1] = pg.pseudoScatter(xy[0])
        else:
            xy[1] = xy[1][mask]

        ## beeswarm plots: spread points sharing an enum value by up to +/-0.2
        for ax in [0, 1]:
            if not enum[ax]:
                continue
            imax = int(xy[ax].max()) if len(xy[ax]) > 0 else 0
            for i in range(imax + 1):
                keymask = xy[ax] == i
                scatter = pg.pseudoScatter(xy[1 - ax][keymask], bidir=True)
                if len(scatter) == 0:
                    continue
                smax = np.abs(scatter).max()
                if smax != 0:
                    scatter *= 0.2 / smax
                xy[ax][keymask] += scatter

        self._visibleXY = xy
        self._visibleData = data
        self._visibleIndices = indices
        self._indexMap = None
        self.scatterPlot = self.plot.plot(xy[0], xy[1], data=data, **style)
        self.scatterPlot.sigPointsClicked.connect(self.plotClicked)
        self.updateSelected()

    def updateSelected(self):
        """Draw yellow square markers over the selected points that are visible."""
        if self._visibleXY is None:
            return
        # map from global index to visible index
        indMap = self._getIndexMap()
        inds = [indMap[i] for i in self.selectedIndices if i in indMap]
        x, y = self._visibleXY[0][inds], self._visibleXY[1][inds]

        if self.selectionScatter is not None:
            self.plot.plotItem.removeItem(self.selectionScatter)
            self.selectionScatter = None
        if len(x) == 0:
            return
        self.selectionScatter = self.plot.plot(x, y, pen=None, symbol='s', symbolSize=12, symbolBrush=None, symbolPen='y')

    def _getIndexMap(self):
        """Return (and cache) a map from original data index to visible point index."""
        if self._indexMap is None:
            self._indexMap = {j: i for i, j in enumerate(self._visibleIndices)}
        return self._indexMap

    def plotClicked(self, plot, points, ev):
        """Tag each clicked point with its original index and re-emit the click."""
        for pt in points:
            pt.originalIndex = self._visibleIndices[pt.index()]
        self.sigScatterPlotClicked.emit(self, points, ev)