Esempio n. 1
0
class MplCanvas(FigureCanvas):
    """Qt canvas hosting a matplotlib figure with one subplot and a toolbar.

    Note: creating an instance sets the process-wide matplotlib
    ``font.size`` rcParam to 8.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        # Global side effect: applies to every figure created afterwards.
        matplotlib.rcParams['font.size'] = 8
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.figure = fig
        self.axes = fig.add_subplot(111)

        FigureCanvas.__init__(self, fig)
        self.setParent(parent)

        self.toolbar = toolbar = NavigationToolbar(self, parent)
        toolbar.setIconSize(QSize(16, 16))

        expanding = QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, expanding, expanding)
        FigureCanvas.updateGeometry(self)

    def getToolbar(self):
        """Return the NavigationToolbar built in the constructor."""
        return self.toolbar

    def clear(self):
        """Wipe the figure and replace ``self.axes`` with a fresh subplot."""
        fig = self.figure
        fig.clear()
        self.axes = fig.add_subplot(111)

    def test(self):
        """Smoke test: plot the values 1, 2, 3, 4 on the current axes."""
        self.axes.plot([1, 2, 3, 4])

    def saveAs(self, fname):
        """Save the figure to ``fname`` via ``Figure.savefig``."""
        self.figure.savefig(fname)
Esempio n. 2
0
class MplWidgetT(QtGui.QWidget):
    """Container widget: an MplCanvas stacked above a NavigationToolbar."""

    def __init__(self, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.canvas = canvas = MplCanvas()
        self.ntb = toolbar = NavigationToolbar(canvas, self)
        toolbar.setIconSize(QtCore.QSize(16, 16))
        self.vbl = box = QtGui.QVBoxLayout()
        # canvas on top, toolbar underneath
        for child in (canvas, toolbar):
            box.addWidget(child)
        self.setLayout(box)
Esempio n. 3
0
class MplWidgetT(QtGui.QWidget):
    """Widget pairing a new MplCanvas with a NavigationToolbar below it."""

    def __init__(self, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.canvas = canvas = MplCanvas()
        self.ntb = toolbar = NavigationToolbar(canvas, self)
        toolbar.setIconSize(QtCore.QSize(16, 16))
        self.vbl = box = QtGui.QVBoxLayout()
        # vertical stacking order: canvas first, then toolbar
        for child in (canvas, toolbar):
            box.addWidget(child)
        self.setLayout(box)
 def NewFigure_proc(self):
     """Add a new 'Figure N' tab holding an MplWidget and its toolbar.

     The MplWidget is appended to ``self.Figures``. Both the tab page and
     the figure widget get, as object name, the tab count measured right
     after the tab is inserted.
     """
     tabs = self.MainFigTabWidget
     page = QtGui.QWidget(tabs)
     layout = QtGui.QVBoxLayout(page)
     figure = MplWidget(page)
     self.Figures.append(figure)
     toolbar = NavToolbar(figure, parent=page)
     toolbar.setIconSize(QtCore.QSize(15, 15))
     layout.setSpacing(0)
     layout.setMargin(0)
     layout.addWidget(figure)
     layout.addWidget(toolbar)
     page.setLayout(layout)
     tabs.addTab(page, 'Figure %d' % len(self.Figures))
     name = str(tabs.count())
     page.setObjectName(name)
     figure.setObjectName(name)
    def __init__(self):
        """Build the NeuroExplorer main window.

        Creates:
        - the File and Windows menus;
        - three dockable panels: figure control, analysis-centred and
          unit-centred analysis;
        - a closable, movable tab widget holding an initial 'Figure 1' tab,
          which becomes the central widget.
        """
        QtGui.QMainWindow.__init__(self)
        self.setWindowTitle("NeuroExplorer GUI")
        self.MainWidget = QtGui.QWidget()
        self.MainLayout = QtGui.QHBoxLayout(self.MainWidget)
        # ``pth`` is a module-level name defined outside this method.
        self.WorkingDir = pth

        # Most widgets share this font. QWidget.setFont stores its own copy,
        # so a single QFont instance can be reused.
        small_font = QtGui.QFont('', 8)

        # create a file dialog
        self.fileDialog = QtGui.QFileDialog()

        # File menu actions
        openAction = QtGui.QAction('&Open File', self)
        openAction.setShortcut('Ctrl+O')
        openAction.triggered.connect(self.OpenFile_proc)

        settingsAction = QtGui.QAction('&Settings', self)
        settingsAction.setShortcut('Ctrl+S')
        settingsAction.triggered.connect(self.Settings_proc)

        closeAction = QtGui.QAction('&Close H5File', self)
        closeAction.setShortcut('Ctrl+X')
        closeAction.triggered.connect(self.CloseFile_proc)

        menubar = self.menuBar()
        fileMenu = menubar.addMenu('&File')
        fileMenu.addAction(openAction)
        fileMenu.addAction(settingsAction)
        fileMenu.addAction(closeAction)

        # Windows menu: each action calls show() on the matching dock below.
        showDock0Action = QtGui.QAction('&Show Figure Control', self)
        showDock1Action = QtGui.QAction('&Show Analisys Centered', self)
        showDock2Action = QtGui.QAction('&Show Unit Centered', self)

        windowsMenu = menubar.addMenu('&Windows')
        windowsMenu.addAction(showDock0Action)
        windowsMenu.addAction(showDock1Action)
        windowsMenu.addAction(showDock2Action)

        ############ dockable figure control widget
        dock0 = QtGui.QDockWidget('Figure Control', self)
        dock0.setAllowedAreas(QtCore.Qt.LeftDockWidgetArea | QtCore.Qt.RightDockWidgetArea)
        # Floor division: setMinimumWidth needs an int, and ``/`` yields a
        # float on Python 3.
        dock0.setMinimumWidth(QtGui.QApplication.desktop().availableGeometry().width() // 6)
        showDock0Action.triggered.connect(dock0.show)
        w = QtGui.QWidget(dock0)
        w.setMaximumHeight(60)
        vlay = QtGui.QVBoxLayout(w)
        vlay.setSpacing(0)
        vlay.setMargin(0)

        hlay = QtGui.QHBoxLayout()
        self.NewFigBtn = QtGui.QPushButton('New Figure')
        self.NewFigBtn.setFont(small_font)
        self.NewFigBtn.clicked.connect(self.NewFigure_proc)
        hlay.addWidget(self.NewFigBtn)

        # The save button gets its own attribute so self.NewFigBtn keeps
        # referring to the 'New Figure' button.
        self.SaveFigBtn = QtGui.QPushButton('Save Figure')
        self.SaveFigBtn.setFont(small_font)
        self.SaveFigBtn.clicked.connect(self.SaveFig_proc)
        hlay.addWidget(self.SaveFigBtn)
        vlay.addLayout(hlay)

        # Created without a parent: ``w`` already has ``vlay`` as its layout.
        # A layout constructed with a widget parent cannot later be nested
        # with addLayout().
        hlay = QtGui.QHBoxLayout()
        self.FigNameText = QtGui.QLineEdit()
        self.FigNameText.setFont(small_font)
        self.FigNameText.setMaximumHeight(25)
        self.ChangeTabLabelBtn = QtGui.QPushButton('Set Current Fig Label')
        self.ChangeTabLabelBtn.setFont(small_font)
        self.ChangeTabLabelBtn.clicked.connect(self.ChangeCurTabLabel_proc)
        hlay.addWidget(self.FigNameText)
        hlay.addWidget(self.ChangeTabLabelBtn)
        vlay.addLayout(hlay)
        dock0.setWidget(w)

        ############ dockable analysis-centred toolbox widget
        dock1 = QtGui.QDockWidget('Analisys Centered', self)
        dock1.setAllowedAreas(QtCore.Qt.LeftDockWidgetArea | QtCore.Qt.RightDockWidgetArea)
        w = QtGui.QWidget(dock1)
        showDock1Action.triggered.connect(dock1.show)

        vlay = QtGui.QVBoxLayout(w)
        vlay.setMargin(2)
        vlay.setSpacing(2)

        # analysis type selector
        hlay = QtGui.QHBoxLayout()
        self.AnalisysTypeCombo = QtGui.QComboBox()
        self.AnalisysTypeCombo.addItems(['PSTH', 'Autocorrelation', 'Crosscorrelation', 'Spectrum'])
        self.AnalisysTypeCombo.setFont(small_font)
        lbl = QtGui.QLabel('Analisys Type')
        lbl.setFont(small_font)
        hlay.addWidget(lbl)
        hlay.addWidget(self.AnalisysTypeCombo)
        vlay.addLayout(hlay)

        # spin box selecting the number of axes columns
        hlay = QtGui.QHBoxLayout()
        self.nColumnsAxesSpin = QtGui.QSpinBox()
        self.nColumnsAxesSpin.setRange(0, 10)
        self.nColumnsAxesSpin.setValue(6)
        lbl = QtGui.QLabel('Number of Columns')
        lbl.setFont(small_font)
        hlay.addWidget(lbl)
        hlay.addWidget(self.nColumnsAxesSpin)
        vlay.addLayout(hlay)

        # time window limits and bin resolution
        hlay = QtGui.QHBoxLayout()
        self.tWin1Spin = QtGui.QDoubleSpinBox()
        self.tWin1Spin.setRange(0, 5)
        self.tWin1Spin.setValue(1)
        self.tWin1Spin.setSingleStep(0.1)
        self.tWin1Spin.setFont(small_font)

        self.tWin2Spin = QtGui.QDoubleSpinBox()
        self.tWin2Spin.setRange(0, 5)
        self.tWin2Spin.setValue(2)
        self.tWin2Spin.setSingleStep(0.1)
        self.tWin2Spin.setFont(small_font)

        self.timePerBinSpin = QtGui.QSpinBox()
        self.timePerBinSpin.setRange(5, 500)
        self.timePerBinSpin.setValue(20)

        lbl = QtGui.QLabel('Twin1')
        lbl.setFont(small_font)
        hlay.addWidget(lbl)
        hlay.addWidget(self.tWin1Spin)
        hlay.addStretch(1)
        lbl = QtGui.QLabel('Twin2')
        lbl.setFont(small_font)
        hlay.addWidget(lbl)
        hlay.addWidget(self.tWin2Spin)
        lbl = QtGui.QLabel('Resolution (milisec/bin)')
        lbl.setFont(small_font)
        hlay.addWidget(lbl)
        hlay.addWidget(self.timePerBinSpin)
        vlay.addLayout(hlay)

        # PSTH y-limit
        hlay = QtGui.QHBoxLayout()
        self.ylimSpin = QtGui.QSpinBox()
        self.ylimSpin.setFont(small_font)
        self.ylimSpin.setValue(50)
        hlay.addWidget(QtGui.QLabel('PSTH Ylim'))
        hlay.addWidget(self.ylimSpin)
        vlay.addLayout(hlay)

        # raster plot button
        self.PlotRasterBtn = QtGui.QPushButton('Plot Raster')
        self.PlotRasterBtn.setFont(QtGui.QFont('', 8, weight=0))
        self.PlotRasterBtn.clicked.connect(self.PlotRaster_proc)
        vlay.addWidget(self.PlotRasterBtn)

        # units table (columns 'Unit' and 'Plot?'); starts with no rows
        self.UnitsTable = QtGui.QTableWidget(0, 2, self)
        vlay.addWidget(self.UnitsTable)
        for k in range(self.UnitsTable.columnCount()):
            self.UnitsTable.setColumnWidth(k, 60)
        self.UnitsTable.setHorizontalHeaderLabels(['Unit', 'Plot?'])

        # select all / select none buttons
        hlay = QtGui.QHBoxLayout()
        self.SelectAllBtn = QtGui.QPushButton('Select All')
        self.SelectAllBtn.setFont(small_font)
        self.SelectAllBtn.setMaximumHeight(20)
        self.SelectAllBtn.clicked.connect(self.SelectAll_proc)
        self.SelectNoneBtn = QtGui.QPushButton('Select None')
        self.SelectNoneBtn.setFont(small_font)
        self.SelectNoneBtn.setMaximumHeight(20)
        self.SelectNoneBtn.clicked.connect(self.SelectNone_proc)
        hlay.addWidget(self.SelectAllBtn)
        hlay.addWidget(self.SelectNoneBtn)
        vlay.addLayout(hlay)

        # event selector
        hlay = QtGui.QHBoxLayout()
        self.EventSelectCombo = QtGui.QComboBox()
        self.EventSelectCombo.setFont(small_font)
        self.EventSelectCombo.setMaximumHeight(20)
        hlay.addWidget(QtGui.QLabel('Event'))
        hlay.addWidget(self.EventSelectCombo)
        vlay.addLayout(hlay)

        dock1.setWidget(w)

        ############ dockable unit-centred toolbox widget
        dock2 = QtGui.QDockWidget('Unit Centered Analisys', self)
        dock2.setAllowedAreas(QtCore.Qt.LeftDockWidgetArea | QtCore.Qt.RightDockWidgetArea)
        showDock2Action.triggered.connect(dock2.show)
        w = QtGui.QWidget(dock2)
        vlay = QtGui.QVBoxLayout(w)
        vlay.setMargin(2)
        vlay.setSpacing(2)

        # channel / unit selectors; changing the channel refreshes the units
        hlay = QtGui.QHBoxLayout()
        self.ChannelSelectCombo = QtGui.QComboBox(w)
        self.ChannelSelectCombo.setFont(small_font)
        self.ChannelSelectCombo.currentIndexChanged.connect(self.UpdateUnitSelectCombo_proc)
        self.UnitSelectCombo = QtGui.QComboBox(w)
        self.UnitSelectCombo.setFont(small_font)
        hlay.addWidget(self.ChannelSelectCombo)
        hlay.addWidget(self.UnitSelectCombo)
        vlay.addLayout(hlay)

        # number of columns for the unit analysis
        hlay = QtGui.QHBoxLayout()
        self.nColumnsUnitAnalisys = QtGui.QSpinBox()
        self.nColumnsUnitAnalisys.setRange(0, 10)
        self.nColumnsUnitAnalisys.setValue(3)
        self.nColumnsUnitAnalisys.setFont(small_font)
        lbl = QtGui.QLabel('Number of Columns')
        lbl.setFont(small_font)
        hlay.addWidget(lbl)
        hlay.addWidget(self.nColumnsUnitAnalisys)
        vlay.addLayout(hlay)

        # analysis table: 5 columns, the TWin columns narrower than the rest
        self.AnalisysTable = QtGui.QTableWidget(0, 5, w)
        for k in range(self.AnalisysTable.columnCount()):
            self.AnalisysTable.setColumnWidth(k, 80)
        self.AnalisysTable.setColumnWidth(2, 50)
        self.AnalisysTable.setColumnWidth(3, 50)
        self.AnalisysTable.verticalHeader().setVisible(False)
        self.AnalisysTable.setHorizontalHeaderLabels(['Analisys',
                                                      'Event',
                                                      'TWin1',
                                                      'TWin2',
                                                      'Unit 2'])
        self.AnalisysTable.horizontalHeader().setFont(small_font)
        vlay.addWidget(self.AnalisysTable)
        dock2.setWidget(w)

        ### add / remove analysis buttons
        hlay = QtGui.QHBoxLayout()
        self.AddAnalisysBtn = QtGui.QPushButton('Add Analisys', w)
        self.AddAnalisysBtn.setFont(small_font)
        self.AddAnalisysBtn.setMaximumHeight(20)
        self.AddAnalisysBtn.clicked.connect(self.AddAnalisys_proc)
        hlay.addWidget(self.AddAnalisysBtn)

        self.RemoveAnalisysBtn = QtGui.QPushButton('Remove Analisys', w)
        self.RemoveAnalisysBtn.setFont(small_font)
        self.RemoveAnalisysBtn.setMaximumHeight(20)
        self.RemoveAnalisysBtn.clicked.connect(self.RemoveAnalisys_proc)
        hlay.addWidget(self.RemoveAnalisysBtn)
        vlay.addLayout(hlay)

        # analysis plot button
        self.PlotAnalisysBtn = QtGui.QPushButton('Plot Analisys')
        self.PlotAnalisysBtn.setFont(QtGui.QFont('', 8, weight=0))
        self.PlotAnalisysBtn.clicked.connect(self.PlotAnalisys_proc)
        vlay.addWidget(self.PlotAnalisysBtn)

        ############# figure tab widget with an initial 'Figure 1' tab
        self.MainFigTabWidget = QtGui.QTabWidget()
        self.MainFigTabWidget.setTabsClosable(True)
        self.MainFigTabWidget.setMovable(True)
        self.MainFigTabWidget.tabCloseRequested.connect(self.closeTab_proc)
        widget = QtGui.QWidget()

        vlay = QtGui.QVBoxLayout(widget)
        self.Figures = []
        self.Figures.append(MplWidget(widget))
        ntb = NavToolbar(self.Figures[-1], parent=widget)
        ntb.setIconSize(QtCore.QSize(15, 15))
        vlay.setSpacing(0)
        vlay.setMargin(0)
        vlay.addWidget(self.Figures[-1])
        vlay.addWidget(ntb)
        widget.setLayout(vlay)

        self.MainFigTabWidget.addTab(widget, 'Figure 1')
        # object names are the tab count after insertion ("1" here)
        widget.setObjectName(str(self.MainFigTabWidget.count()))
        self.Figures[-1].setObjectName(str(self.MainFigTabWidget.count()))

        # all docks start in the left docking area
        self.addDockWidget(QtCore.Qt.LeftDockWidgetArea, dock0)
        self.addDockWidget(QtCore.Qt.LeftDockWidgetArea, dock1)
        self.addDockWidget(QtCore.Qt.LeftDockWidgetArea, dock2)
        self.setCentralWidget(self.MainFigTabWidget)

        # On Linux use the Plastique style. sys.platform is 'linux2' on
        # Python 2 but 'linux' on Python 3.3+, so match on the prefix.
        if sys.platform.startswith('linux'):
            QtGui.QApplication.setStyle(QtGui.QStyleFactory.create('Plastique'))
Esempio n. 6
0
class SimulationGui(QMainWindow):
    """Main window that visualises and drives a network simulation.

    The network is drawn on a matplotlib canvas placed inside
    ``ui.networkDisplayWidget``. Drawn layers: environment image, tree,
    edges, propagation-error circles, nodes, messages and labels. UI
    checkboxes toggle layer visibility. Menu actions open, save, run,
    step and reset the simulation.
    """

    def __init__(self, net=None, parent=None, fname=None):
        # Forward ``parent`` to Qt instead of silently discarding it.
        QMainWindow.__init__(self, parent)

        self.ui = Ui_SimulationWindow()
        self.ui.setupUi(self)

        if fname:
            self.set_title(fname)

        # node inspector context menu
        self.ui.nodeInspector.addAction(self.ui.actionCopyInspectorData)
        self.ui.nodeInspector.addAction(self.ui.actionShowLocalizedSubclusters)
        # callbacks
        self.ui.actionCopyInspectorData.activated\
            .connect(self.on_actionCopyInspectorData_triggered)
        self.ui.actionShowLocalizedSubclusters.activated\
            .connect(self.on_actionShowLocalizedSubclusters_triggered)

        self.dpi = 72
        # Figure sized to 700x731 px (presumably the networkDisplayWidget
        # size). Float division: on Python 2, 700 / 72 truncated to 9 inches.
        self.fig = Figure((700.0 / self.dpi, 731.0 / self.dpi), self.dpi,
                          facecolor='0.9')
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.ui.networkDisplayWidget)
        self.nav = NavigationToolbar(self.canvas, self.ui.networkDisplayWidget,
                                     coordinates=True)
        self.nav.setGeometry(QRect(0, 0, 651, 36))
        self.nav.setIconSize(QSize(24, 24))

        self.axes = self.fig.add_subplot(111)
        # matplotlib.org/api/figure_api.html#matplotlib.figure.SubplotParams
        self.fig.subplots_adjust(left=0.03, right=0.99, top=0.92)

        if net:
            self.init_sim(net)

        # visibility toggles and redraw triggers (old-style signal API)
        self.connect(self.ui.showNodes, SIGNAL('stateChanged(int)'),
                     self.refresh_visibility)
        self.connect(self.ui.showEdges, SIGNAL('stateChanged(int)'),
                     self.refresh_visibility)
        self.connect(self.ui.showMessages, SIGNAL('stateChanged(int)'),
                     self.refresh_visibility)
        self.connect(self.ui.showLabels, SIGNAL('stateChanged(int)'),
                     self.refresh_visibility)
        self.connect(self.ui.redrawNetworkButton, SIGNAL('clicked(bool)'),
                     self.redraw)
        self.connect(self.ui.treeGroupBox, SIGNAL('toggled(bool)'),
                     self.refresh_visibility)
        self.connect(self.ui.treeKey, SIGNAL('textEdited(QString)'),
                     self.redraw)
        self.connect(self.ui.propagationError, SIGNAL('toggled(bool)'),
                     self.refresh_visibility)
        self.connect(self.ui.locKey, SIGNAL('textEdited(QString)'),
                     self.redraw)
        # menu action callbacks
        self.ui.actionOpenNetwork.activated\
            .connect(self.on_actionOpenNetwork_triggered)
        self.ui.actionSaveNetwork.activated\
            .connect(self.on_actionSaveNetwork_triggered)
        self.ui.actionRun.activated.connect(self.on_actionRun_triggered)
        self.ui.actionStep.activated.connect(self.on_actionStep_triggered)
        self.ui.actionReset.activated.connect(self.on_actionReset_triggered)

        self.canvas.mpl_connect('pick_event', self.on_pick)

    def handleInspectorMenu(self, pos):
        """Pop up an Add/Delete menu at the mouse cursor (``pos`` is unused)."""
        menu = QMenu()
        menu.addAction('Add')
        menu.addAction('Delete')
        menu.exec_(QCursor.pos())

    def init_sim(self, net):
        """Attach ``net``, create its Simulation, wire its signals, redraw."""
        self.net = net
        self.sim = Simulation(net)
        self.connect(self.sim, SIGNAL("redraw()"), self.redraw)
        self.connect(self.sim, SIGNAL("updateLog(QString)"), self.update_log)
        self.redraw()

    def update_log(self, text):
        """Print ``text`` and insert it at the top of the log list widget."""
        print('Add: ' + text)
        self.ui.logListWidget.insertItem(0, text)

    def redraw(self):
        """Refresh the inspector, redraw the network and reset the zoom."""
        self.refresh_network_inspector()
        self.draw_network()
        self.reset_zoom()
        self.refresh_visibility()

    def draw_network(self, net=None, clear=True, subclusters=None,
                     drawMessages=True):
        """Draw ``net`` (default ``self.net``) on ``self.axes``.

        Node colours come from ``subclusters`` when given; otherwise they
        come from the current algorithm of ``self.net``. The axes title is
        the algorithm name, plus the step number for a NodeAlgorithm.
        """
        if net is None:
            net = self.net
        currentAlgorithm = self.net.get_current_algorithm()
        if clear:
            self.axes.clear()
        self.axes.imshow(net.environment.im, vmin=0, cmap='binary_r',
                         origin='lower')

        self.draw_tree(str(self.ui.treeKey.text()), net)
        self.draw_edges(net)
        self.draw_propagation_errors(str(self.ui.locKey.text()), net)
        if subclusters:
            node_colors = self.get_node_colors(net, subclusters=subclusters)
        else:
            node_colors = self.get_node_colors(net, algorithm=currentAlgorithm)
        self.node_collection = self.draw_nodes(net, node_colors)
        if drawMessages:
            self.draw_messages(net)
        self.draw_labels(net)
        self.drawnNet = net
        step_text = ' (step %d)' % self.net.algorithmState['step'] \
                    if isinstance(currentAlgorithm, NodeAlgorithm) else ''
        self.axes.set_title((currentAlgorithm.name
                             if currentAlgorithm else '') + step_text)

        self.refresh_visibility()
        # To save multiple figs of the simulation uncomment next two lines:
        #self.fig.savefig('network-alg-%d-step-%d.png' %
        #                 (self.net.algorithmState['index'], self.net.algorithmState['step']))

    def draw_nodes(self, net=None, node_colors=None, node_radius=None):
        """Draw every node of ``net`` as a pickable NodeCircle.

        ``node_colors`` is either a dict node -> colour, or one colour
        string applied to all nodes. Nodes without an entry are drawn 'r'.
        ``node_radius`` maps node -> radius (default 8.0). Returns the
        PatchCollection added to ``self.axes``.
        """
        if net is None:
            net = self.net
        if node_colors is None:
            node_colors = {}
        elif isinstance(node_colors, str):
            node_colors = dict.fromkeys(net.nodes(), node_colors)
        if node_radius is None:
            node_radius = {}
        nodeCircles = [NodeCircle(tuple(net.pos[n]), node_radius.get(n, 8.0),
                                  color=node_colors.get(n, 'r'),
                                  ec='k', lw=1.0, ls='solid', picker=3)
                       for n in net.nodes()]
        node_collection = PatchCollection(nodeCircles, match_original=True)
        node_collection.set_picker(3)
        self.axes.add_collection(node_collection)
        return node_collection

    def get_node_colors(self, net, algorithm=None, subclusters=None,
                        drawLegend=True):
        """Return a dict node -> colour.

        With ``algorithm``: colour nodes by status. For a NodeAlgorithm,
        colours come from its STATUS keys ('IDLE' is black) and a status
        legend is optionally drawn. Unknown or empty statuses are 'r'.
        With ``subclusters``: colour by subcluster index; nodes found in
        more than one subcluster are black.
        """
        COLORS = 'rgbcmyw' * 100
        node_colors = {}
        if algorithm:
            color_map = {}
            if isinstance(algorithm, NodeAlgorithm):
                for ind, status in enumerate(algorithm.STATUS.keys()):
                    color_map[status] = 'k' if status == 'IDLE' else COLORS[ind]
                if drawLegend:
                    proxy = []
                    labels = []
                    for status, color in color_map.items():
                        proxy.append(Circle((0, 0), radius=8.0,
                                            color=color, ec='k',
                                            lw=1.0, ls='solid'))
                        labels.append(status)
                    self.fig.legends = []
                    self.fig.legend(proxy, labels, loc=8,
                                    prop={'size': '10.0'}, ncol=len(proxy),
                                    title='Statuses for %s:'
                                          % algorithm.name)
            for n in net.nodes():
                if n.status == '' or n.status not in color_map:
                    node_colors[n] = 'r'
                else:
                    node_colors[n] = color_map[n.status]
        elif subclusters:
            for i, sc in enumerate(subclusters):
                for n in sc:
                    node_colors[n] = 'k' if n in node_colors else COLORS[i]
        return node_colors

    def draw_edges(self, net=None):
        """Draw the edges of ``net`` (default ``self.net``) on ``self.axes``."""
        if net is None:
            net = self.net
        self.edge_collection = nx.draw_networkx_edges(net, net.pos, alpha=0.6,
                                                      edgelist=None,
                                                      ax=self.axes)

    def draw_messages(self, net=None):
        """Draw outbox (blue) and inbox (green) messages of every node.

        A broadcast (destination None) is drawn once per neighbour, using a
        copy of the message addressed to that neighbour. The drawn messages
        are stored in ``self.messages``.
        """
        if net is None:
            net = self.net
        self.messages = []
        msgCircles = []
        for node in net.nodes():
            for msg in node.outbox:
                if msg.destination is None:
                    # broadcast
                    for neighbor in net.adj[node]:
                        nbr_msg = msg.copy()
                        nbr_msg.destination = neighbor
                        c = MessageCircle(nbr_msg, net, 'out', 3.0, lw=0,
                                          picker=3, zorder=3, color='b')
                        self.messages.append(nbr_msg)
                        msgCircles.append(c)
                else:
                    c = MessageCircle(msg, net, 'out', 3.0, lw=0, picker=3,
                                      zorder=3, color='b')
                    self.messages.append(msg)
                    msgCircles.append(c)
            for msg in node.inbox:
                c = MessageCircle(msg, net, 'in', 3.0, lw=0, picker=3,
                                  zorder=3, color='g')
                self.messages.append(msg)
                msgCircles.append(c)
        if self.messages:
            self.message_collection = PatchCollection(msgCircles,
                                                      match_original=True)
            self.message_collection.set_picker(3)
            self.axes.add_collection(self.message_collection)

    def draw_labels(self, net=None):
        """Draw node labels, offset by +10 in each coordinate from the node."""
        if net is None:
            net = self.net
        label_pos = {}
        for n in net.nodes():
            label_pos[n] = net.pos[n].copy() + 10
        self.label_collection = nx.draw_networkx_labels(net, label_pos,
                                                        labels=net.labels,
                                                        ax=self.axes)

    def refresh_visibility(self):
        """Apply the UI visibility checkboxes to the drawn layers and redraw.

        Each layer is handled independently. A layer that was never drawn
        (attribute missing or None) is skipped. A layer stored as a list or
        dict of artists is handled element by element. One absent layer
        therefore no longer prevents the others from being updated.
        """
        show_errors = self.ui.propagationError.isChecked()
        layers = (
            ('node_collection', self.ui.showNodes.isChecked()),
            ('edge_collection', self.ui.showEdges.isChecked()),
            ('label_collection', self.ui.showLabels.isChecked()),
            ('tree_collection', self.ui.treeGroupBox.isChecked()),
            ('ini_error_collection', show_errors),
            ('propagation_error_collection', show_errors),
            ('message_collection', self.ui.showMessages.isChecked()),
        )
        for attr_name, visible in layers:
            self._set_artists_visible(getattr(self, attr_name, None), visible)
        self.canvas.draw()

    @staticmethod
    def _set_artists_visible(artists, visible):
        """Call set_visible on one artist, or on each item of a list/tuple/dict."""
        if artists is None:
            return
        if isinstance(artists, dict):
            artists = artists.values()
        elif not isinstance(artists, (list, tuple)):
            artists = (artists,)
        for artist in artists:
            artist.set_visible(visible)

    def draw_tree(self, treeKey, net=None):
        """
        Show tree representation of network.

        Attributes:
            treeKey (str):
                key in nodes memory (dictionary) where tree data is stored
                storage format can be a list of tree neighbors or a dict:
                    {'parent': parent_node,
                     'children': [child_node1, child_node2 ...]}
        """
        if net is None:
            net = self.net
        treeNet = net.get_tree_net(treeKey)
        if treeNet:
            self.tree_collection = draw_networkx_edges(treeNet, treeNet.pos,
                                                       treeNet.edges(),
                                                       width=1.8, alpha=0.6,
                                                       ax=self.axes)

    def draw_propagation_errors(self, locKey, net):
        """Draw per-node localization-error circles.

        Green circles have radius proportional to the RMS of
        ``memory[locKey]``; blue circles to the RMS of ``memory['iniLocs']``.
        Both are measured against ``self.net.pos``. If any node lacks
        either key, both error layers are set to empty lists instead.
        """
        SCALE_FACTOR = 0.6
        if net is None:
            net = self.net
        if any(locKey not in node.memory or 'iniLocs' not in node.memory
               for node in net.nodes()):
            self.propagation_error_collection = []
            self.ini_error_collection = []
            return

        rms = {'iniRms': {}, 'stitchRms': {}}
        for node in net.nodes():
            rms['iniRms'][node] = get_rms(self.net.pos,
                                          node.memory['iniLocs'],
                                          True) * SCALE_FACTOR
            rms['stitchRms'][node] = get_rms(self.net.pos, node.memory[locKey],
                                             True) * SCALE_FACTOR
        self.propagation_error_collection = \
            self.draw_nodes(net=net, node_colors='g',
                            node_radius=rms['stitchRms'])
        self.ini_error_collection = self.draw_nodes(net=net, node_colors='b',
                                                    node_radius=rms['iniRms'])

    def reset_zoom(self):
        """Fit the axes limits to the environment image of ``self.net``."""
        self.axes.set_xlim((0, self.net.environment.im.shape[1]))
        self.axes.set_ylim((0, self.net.environment.im.shape[0]))

    def set_title(self, fname):
        """Set the window title to '<base title> - <fname>'."""
        new = ' - '.join([str(self.windowTitle()).split(' - ')[0], str(fname)])
        self.setWindowTitle(new)

    def refresh_network_inspector(self):
        """Show ``self.net.get_dic()`` in the network inspector tree view."""
        niModel = DictionaryTreeModel(dic=self.net.get_dic())
        self.ui.networkInspector.setModel(niModel)
        self.ui.networkInspector.expandToDepth(0)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def on_actionRun_triggered(self):
        """Clear the log and run the whole simulation."""
        self.ui.logListWidget.clear()
        print('running ...')
        self.sim.stepping = True
        self.sim.run_all()

    def on_actionStep_triggered(self):
        """Advance the simulation by the step size chosen in the UI."""
        print('next step ...')
        self.sim.run(self.ui.stepSize.value())

    def on_actionReset_triggered(self):
        """Reset the simulation and redraw."""
        print('reset ...')
        self.sim.reset()
        self.redraw()

    def on_actionCopyInspectorData_triggered(self):
        """Copy the selected node-inspector items to the clipboard as text."""
        string = 'Node inspector data\n-------------------'
        for qModelIndex in self.ui.nodeInspector.selectedIndexes():
            string += '\n' + qModelIndex.internalPointer().toString('    ')

        # ``app`` is a module-level QApplication defined outside this class.
        clipboard = app.clipboard()
        clipboard.setText(string)
        event = QEvent(QEvent.Clipboard)
        app.sendEvent(clipboard, event)

    def on_actionShowLocalizedSubclusters_triggered(self):
        """Draw the first localized subcluster of the selected Positions item.

        The estimate is aligned to the true positions. Alignment errors are
        drawn as red lines, and the RMS and localized fraction are logged.
        Does nothing unless exactly one Positions item is selected.
        """
        selected = self.ui.nodeInspector.selectedIndexes()
        if len(selected) != 1:
            return
        treeItem = selected[0].internalPointer()
        # Only Positions items can be visualised; ignore anything else rather
        # than raising from a UI callback.
        if not isinstance(treeItem.itemDataValue, Positions):
            return

        estimated = deepcopy(treeItem.itemDataValue)
        estimatedsub = estimated.subclusters[0]
        # rotate, translate and optionally scale
        # w.r.t. original positions (pos)
        align_clusters(Positions.create(self.net.pos), estimated, True)
        net = self.net.subnetwork(estimatedsub.keys(), pos=estimatedsub)

        self.draw_network(net=net, drawMessages=False)

        edge_pos = numpy.asarray([(self.net.pos[node], estimatedsub[node][:2])
                                  for node in net])
        error_collection = LineCollection(edge_pos, colors='r')
        self.axes.add_collection(error_collection)

        rms = get_rms(self.net.pos, estimated, scale=False)
        self.update_log('rms = %.3f' % rms)
        self.update_log('localized = %.2f%% (%d/%d)' %
                        (len(estimatedsub) * 1. / len(self.net.pos) * 100,
                         len(estimatedsub), len(self.net.pos)))

    def on_actionSaveNetwork_triggered(self, *args):
        """Ask for a file name and pickle ``self.net`` to it.

        The default name is ``YYYYMMDD.gz``. On failure a critical message
        box shows the error. On success the window title is updated.
        """
        default_filetype = 'gz'
        start = datetime.now().strftime('%Y%m%d') + '.' + default_filetype

        filters = ['Network pickle (*.gz)', 'All files (*)']
        # Must match one of the filter strings exactly to be preselected.
        selectedFilter = filters[0]
        filters = ';;'.join(filters)

        fname = QFileDialog.getSaveFileName(self, "Choose a filename",
                                            start, filters, selectedFilter)[0]
        if fname:
            try:
                write_pickle(self.net, fname)
            except Exception as e:
                QMessageBox.critical(
                    self, "Error saving file", str(e),
                    QMessageBox.Ok, QMessageBox.NoButton)
            else:
                self.set_title(fname)
Esempio n. 7
0
class SimulationGui(QMainWindow):
    """Main window for visualising and stepping a network simulation.

    The network is drawn on an embedded matplotlib canvas (environment
    image, edges, nodes, messages, labels and optional tree / localization
    error overlays).  UI check boxes toggle the visibility of each layer and
    menu actions save, run, step and reset the simulation.
    """

    def __init__(self, net=None, parent=None, fname=None):
        # forward parent to Qt (it used to be accepted but silently ignored)
        QMainWindow.__init__(self, parent)

        self.ui = Ui_SimulationWindow()
        self.ui.setupUi(self)

        if fname:
            self.set_title(fname)

        # context menu of the node inspector
        self.ui.nodeInspector.addAction(self.ui.actionCopyInspectorData)
        self.ui.nodeInspector.addAction(self.ui.actionShowLocalizedSubclusters)
        # callbacks
        self.ui.actionCopyInspectorData.activated\
            .connect(self.on_actionCopyInspectorData_triggered)
        self.ui.actionShowLocalizedSubclusters.activated\
            .connect(self.on_actionShowLocalizedSubclusters_triggered)

        self.dpi = 72
        # size the figure to networkDisplayWidget (700 x 731 px); float
        # division keeps Python 2 from truncating the size in inches
        self.fig = Figure((700. / self.dpi, 731. / self.dpi),
                          self.dpi,
                          facecolor='0.9')
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.ui.networkDisplayWidget)
        self.nav = NavigationToolbar(self.canvas,
                                     self.ui.networkDisplayWidget,
                                     coordinates=True)
        self.nav.setGeometry(QRect(0, 0, 651, 36))
        self.nav.setIconSize(QSize(24, 24))

        self.axes = self.fig.add_subplot(111)
        # matplotlib.org/api/figure_api.html#matplotlib.figure.SubplotParams
        self.fig.subplots_adjust(left=0.03, right=0.99, top=0.92)

        if net is not None:
            self.init_sim(net)

        # old-style signal connections: (widget, signal signature, slot)
        for widget, signature, slot in (
                (self.ui.showNodes, 'stateChanged(int)',
                 self.refresh_visibility),
                (self.ui.showEdges, 'stateChanged(int)',
                 self.refresh_visibility),
                (self.ui.showMessages, 'stateChanged(int)',
                 self.refresh_visibility),
                (self.ui.showLabels, 'stateChanged(int)',
                 self.refresh_visibility),
                (self.ui.redrawNetworkButton, 'clicked(bool)', self.redraw),
                (self.ui.treeGroupBox, 'toggled(bool)',
                 self.refresh_visibility),
                (self.ui.treeKey, 'textEdited(QString)', self.redraw),
                (self.ui.propagationError, 'toggled(bool)',
                 self.refresh_visibility),
                (self.ui.locKey, 'textEdited(QString)', self.redraw)):
            self.connect(widget, SIGNAL(signature), slot)

        # menu / toolbar callbacks
        self.ui.actionOpenNetwork.activated\
            .connect(self.on_actionOpenNetwork_triggered)
        self.ui.actionSaveNetwork.activated\
            .connect(self.on_actionSaveNetwork_triggered)
        self.ui.actionRun.activated.connect(self.on_actionRun_triggered)
        self.ui.actionStep.activated.connect(self.on_actionStep_triggered)
        self.ui.actionReset.activated.connect(self.on_actionReset_triggered)

        self.canvas.mpl_connect('pick_event', self.on_pick)

    def handleInspectorMenu(self, pos):
        """Pop up an Add/Delete menu at the cursor (``pos`` is unused)."""
        menu = QMenu()
        menu.addAction('Add')
        menu.addAction('Delete')
        menu.exec_(QCursor.pos())

    def init_sim(self, net):
        """Attach *net*, create its Simulation, hook its signals and draw."""
        self.net = net
        self.sim = Simulation(net)
        self.connect(self.sim, SIGNAL("redraw()"), self.redraw)
        self.connect(self.sim, SIGNAL("updateLog(QString)"), self.update_log)
        self.redraw()

    def update_log(self, text):
        """Print *text* and insert it at the top of the log list widget."""
        print("Add: " + text)
        self.ui.logListWidget.insertItem(0, text)

    def redraw(self):
        """Refresh the inspector, redraw the network and reset the zoom."""
        self.refresh_network_inspector()
        self.draw_network()
        self.reset_zoom()
        self.refresh_visibility()

    def draw_network(self,
                     net=None,
                     clear=True,
                     subclusters=None,
                     drawMessages=True):
        """Draw *net* (default ``self.net``) with all of its overlays.

        Args:
            net: network to draw; ``None`` means ``self.net``.
            clear: clear the axes first.
            subclusters: if given, colour nodes per subcluster instead of by
                algorithm status.
            drawMessages: also draw in/out message markers.
        """
        # `is None` rather than `not net`: an empty networkx graph is falsy
        # and would otherwise be replaced by the full network
        if net is None:
            net = self.net
        currentAlgorithm = self.net.get_current_algorithm()
        if clear:
            self.axes.clear()
        self.axes.imshow(net.environment.im,
                         vmin=0,
                         cmap='binary_r',
                         origin='lower')

        self.draw_tree(str(self.ui.treeKey.text()), net)
        self.draw_edges(net)
        self.draw_propagation_errors(str(self.ui.locKey.text()), net)
        if subclusters:
            node_colors = self.get_node_colors(net, subclusters=subclusters)
        else:
            node_colors = self.get_node_colors(net, algorithm=currentAlgorithm)
        self.node_collection = self.draw_nodes(net, node_colors)
        if drawMessages:
            self.draw_messages(net)
        self.draw_labels(net)
        self.drawnNet = net
        step_text = ' (step %d)' % self.net.algorithmState['step'] \
                    if isinstance(currentAlgorithm, NodeAlgorithm) else ''
        self.axes.set_title(
            (currentAlgorithm.name if currentAlgorithm else '') + step_text)

        self.refresh_visibility()
        # To save multiple figs of the simulation uncomment next two lines:
        #self.fig.savefig('network-alg-%d-step-%d.png' %
        #                 (self.net.algorithmState['index'], self.net.algorithmState['step']))

    def draw_nodes(self, net=None, node_colors=None, node_radius=None):
        """Draw one pickable circle per node and return the collection.

        Args:
            net: network to draw; ``None`` means ``self.net``.
            node_colors: ``{node: color}`` mapping or a single color string
                applied to every node; missing nodes are red.
            node_radius: ``{node: radius}`` mapping; missing nodes use 8.0.
        """
        if net is None:
            net = self.net
        if node_colors is None:
            node_colors = {}
        elif isinstance(node_colors, str):
            node_colors = dict.fromkeys(net.nodes(), node_colors)
        if node_radius is None:
            node_radius = {}
        circles = [NodeCircle(tuple(net.pos[n]),
                              node_radius.get(n, 8.0),
                              color=node_colors.get(n, 'r'),
                              ec='k',
                              lw=1.0,
                              ls='solid',
                              picker=3)
                   for n in net.nodes()]
        node_collection = PatchCollection(circles, match_original=True)
        node_collection.set_picker(3)
        self.axes.add_collection(node_collection)
        return node_collection

    def get_node_colors(self,
                        net,
                        algorithm=None,
                        subclusters=None,
                        drawLegend=True):
        """Return a ``{node: color}`` mapping for *net*.

        With *algorithm*, nodes are coloured by ``node.status`` ('IDLE' is
        black; an empty or uncoloured status is red, which is the case for
        every node when *algorithm* is not a NodeAlgorithm).  For a
        NodeAlgorithm a figure legend of the statuses is drawn when
        *drawLegend* is true.  With *subclusters*, each subcluster gets its
        own colour and nodes belonging to several subclusters are black.
        Otherwise the mapping is empty.
        """
        palette = 'rgbcmyw'
        node_colors = {}
        if algorithm:
            color_map = {}
            if isinstance(algorithm, NodeAlgorithm):
                for ind, status in enumerate(algorithm.STATUS.keys()):
                    # colours cycle through the palette instead of running out
                    color_map[status] = ('k' if status == 'IDLE'
                                         else palette[ind % len(palette)])
                if drawLegend:
                    proxy = []
                    labels = []
                    for status, color in color_map.items():
                        proxy.append(
                            Circle((0, 0),
                                   radius=8.0,
                                   color=color,
                                   ec='k',
                                   lw=1.0,
                                   ls='solid'))
                        labels.append(status)
                    # replace any legend left over from a previous draw
                    self.fig.legends = []
                    self.fig.legend(proxy,
                                    labels,
                                    loc=8,
                                    prop={'size': '10.0'},
                                    ncol=len(proxy),
                                    title='Statuses for %s:' % algorithm.name)
            for n in net.nodes():
                if n.status != '' and n.status in color_map:
                    node_colors[n] = color_map[n.status]
                else:
                    node_colors[n] = 'r'
        elif subclusters:
            for i, sc in enumerate(subclusters):
                for n in sc:
                    node_colors[n] = ('k' if n in node_colors
                                      else palette[i % len(palette)])
        return node_colors

    def draw_edges(self, net=None):
        """Draw the network edges and keep them in ``self.edge_collection``."""
        if net is None:
            net = self.net
        self.edge_collection = nx.draw_networkx_edges(net,
                                                      net.pos,
                                                      alpha=0.6,
                                                      edgelist=None,
                                                      ax=self.axes)

    def draw_messages(self, net=None):
        """Draw a marker for every message in node outboxes and inboxes.

        Broadcast messages (``destination is None``) are drawn once per
        neighbor, using a copy whose destination is that neighbor.  Outgoing
        markers are blue, incoming ones green.  Drawn messages are kept in
        ``self.messages`` and the markers in ``self.message_collection``.
        """
        if net is None:
            net = self.net
        self.messages = []
        msg_circles = []

        def add_marker(msg, direction, color):
            # shared MessageCircle styling for every kind of message
            msg_circles.append(MessageCircle(msg,
                                             net,
                                             direction,
                                             3.0,
                                             lw=0,
                                             picker=3,
                                             zorder=3,
                                             color=color))
            self.messages.append(msg)

        for node in net.nodes():
            for msg in node.outbox:
                if msg.destination is None:
                    # broadcast: one copy per neighbor
                    for neighbor in net.adj[node]:
                        nbr_msg = msg.copy()
                        nbr_msg.destination = neighbor
                        add_marker(nbr_msg, 'out', 'b')
                else:
                    add_marker(msg, 'out', 'b')
            for msg in node.inbox:
                add_marker(msg, 'in', 'g')
        if self.messages:
            self.message_collection = PatchCollection(msg_circles,
                                                      match_original=True)
            self.message_collection.set_picker(3)
            self.axes.add_collection(self.message_collection)

    def draw_labels(self, net=None):
        """Draw node labels offset by +10 from each node position."""
        if net is None:
            net = self.net
        label_pos = {n: net.pos[n].copy() + 10 for n in net.nodes()}
        self.label_collection = nx.draw_networkx_labels(net,
                                                        label_pos,
                                                        labels=net.labels,
                                                        ax=self.axes)

    def refresh_visibility(self):
        """Apply the UI check boxes to the drawn layers and redraw.

        Each layer is handled independently: a layer that was never drawn,
        or was stored as a placeholder without ``set_visible`` (e.g. the
        ``[]`` set by draw_propagation_errors), is skipped.  Previously the
        first such layer aborted the update of all remaining layers.
        """
        ui = self.ui
        layers = (
            ('node_collection', ui.showNodes),
            ('edge_collection', ui.showEdges),
            ('tree_collection', ui.treeGroupBox),
            ('ini_error_collection', ui.propagationError),
            ('propagation_error_collection', ui.propagationError),
            ('message_collection', ui.showMessages),
        )
        for attr, toggle in layers:
            set_visible = getattr(getattr(self, attr, None), 'set_visible',
                                  None)
            if set_visible is not None:
                set_visible(toggle.isChecked())
        show_labels = ui.showLabels.isChecked()
        for label in (getattr(self, 'label_collection', None) or {}).values():
            label.set_visible(show_labels)
        self.canvas.draw()

    def draw_tree(self, treeKey, net=None):
        """
        Show tree representation of network.

        Attributes:
            treeKey (str):
                key in nodes memory (dictionary) where tree data is stored
                storage format can be a list off tree neighbors or a dict:
                    {'parent': parent_node,
                     'children': [child_node1, child_node2 ...]}
        """
        if net is None:
            net = self.net
        treeNet = net.get_tree_net(treeKey)
        if treeNet:
            self.tree_collection = draw_networkx_edges(treeNet,
                                                       treeNet.pos,
                                                       treeNet.edges(),
                                                       width=1.8,
                                                       alpha=0.6,
                                                       ax=self.axes)

    def draw_propagation_errors(self, locKey, net):
        """Draw per-node localization error circles.

        For every node the RMS error (w.r.t. ``self.net.pos``) of its
        initial localization (``memory['iniLocs']``, blue) and of the
        localization stored under *locKey* (green) is drawn as a circle
        scaled by SCALE_FACTOR.  If any node lacks either entry, nothing is
        drawn and both collections are set to ``[]``.
        """
        SCALE_FACTOR = 0.6
        if net is None:
            net = self.net
        if any(locKey not in node.memory or 'iniLocs' not in node.memory
               for node in net.nodes()):
            self.propagation_error_collection = []
            self.ini_error_collection = []
            return

        ini_rms = {}
        stitch_rms = {}
        for node in net.nodes():
            ini_rms[node] = get_rms(self.net.pos, node.memory['iniLocs'],
                                    True) * SCALE_FACTOR
            stitch_rms[node] = get_rms(self.net.pos, node.memory[locKey],
                                       True) * SCALE_FACTOR
        self.propagation_error_collection = self.draw_nodes(
            net=net, node_colors='g', node_radius=stitch_rms)
        self.ini_error_collection = self.draw_nodes(
            net=net, node_colors='b', node_radius=ini_rms)

    def reset_zoom(self):
        """Fit the axes to the environment image."""
        height, width = self.net.environment.im.shape[:2]
        self.axes.set_xlim((0, width))
        self.axes.set_ylim((0, height))

    def set_title(self, fname):
        """Show *fname* after the base window title ('<base> - <fname>')."""
        new = ' - '.join([str(self.windowTitle()).split(' - ')[0], str(fname)])
        self.setWindowTitle(new)

    def refresh_network_inspector(self):
        """Rebuild the network inspector tree from ``self.net.get_dic()``."""
        niModel = DictionaryTreeModel(dic=self.net.get_dic())
        self.ui.networkInspector.setModel(niModel)
        self.ui.networkInspector.expandToDepth(0)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def on_actionRun_triggered(self):
        """Clear the log and run the whole simulation."""
        self.ui.logListWidget.clear()
        print('running ...')
        self.sim.stepping = True
        self.sim.run_all()

    def on_actionStep_triggered(self):
        """Advance the simulation by the step size chosen in the UI."""
        print('next step ...')
        self.sim.run(self.ui.stepSize.value())

    def on_actionReset_triggered(self):
        """Reset the simulation and redraw."""
        print('reset ...')
        self.sim.reset()
        self.redraw()

    def on_actionCopyInspectorData_triggered(self):
        """Copy the selected node-inspector items to the clipboard as text."""
        string = 'Node inspector data\n-------------------'
        for qModelIndex in self.ui.nodeInspector.selectedIndexes():
            string += '\n' + qModelIndex.internalPointer().toString('    ')

        # NOTE(review): relies on a module-level `app` (QApplication)
        clipboard = app.clipboard()
        clipboard.setText(string)
        event = QEvent(QEvent.Clipboard)
        app.sendEvent(clipboard, event)

    def on_actionShowLocalizedSubclusters_triggered(self):
        """Draw the selected localization aligned on the true positions.

        Requires exactly one selected inspector item holding Positions.
        Its first subcluster is aligned to ``self.net.pos``; the matching
        subnetwork is drawn together with red lines from true to estimated
        positions, and RMS / coverage figures are logged.
        """
        selected = self.ui.nodeInspector.selectedIndexes()
        if len(selected) != 1:
            return
        treeItem = selected[0].internalPointer()
        assert isinstance(treeItem.itemDataValue, Positions)

        estimated = deepcopy(treeItem.itemDataValue)
        estimatedsub = estimated.subclusters[0]
        # rotate, translate and optionally scale
        # w.r.t. original positions (pos)
        align_clusters(Positions.create(self.net.pos), estimated, True)
        net = self.net.subnetwork(estimatedsub.keys(), pos=estimatedsub)

        self.draw_network(net=net, drawMessages=False)

        edge_pos = numpy.asarray([
            (self.net.pos[node], estimatedsub[node][:2]) for node in net
        ])
        error_collection = LineCollection(edge_pos, colors='r')
        self.axes.add_collection(error_collection)

        rms = get_rms(self.net.pos, estimated, scale=False)
        self.update_log('rms = %.3f' % rms)
        self.update_log('localized = %.2f%% (%d/%d)' %
                        (len(estimatedsub) * 1. / len(self.net.pos) * 100,
                         len(estimatedsub), len(self.net.pos)))

    def on_actionSaveNetwork_triggered(self, *args):
        """Ask for a file name and pickle ``self.net`` to it.

        The dialog is pre-filled with today's date plus the default
        extension (e.g. ``20240131.gz``); on success the window title is
        updated, on failure an error box is shown.
        """
        default_filetype = 'gz'
        # the extension needs its leading dot ('20240131.gz', not '20240131gz')
        start = '%s.%s' % (datetime.now().strftime('%Y%m%d'), default_filetype)

        filter_list = ['Network pickle (*.gz)', 'All files (*)']
        # must equal one of the entries verbatim so the dialog preselects it
        selectedFilter = filter_list[0]
        filters = ';;'.join(filter_list)

        # NOTE(review): [0] assumes getSaveFileName returns a tuple
        fname = QFileDialog.getSaveFileName(self, "Choose a filename", start,
                                            filters, selectedFilter)[0]
        if not fname:
            return  # dialog cancelled
        try:
            write_pickle(self.net, fname)
        except Exception as e:
            QMessageBox.critical(self, "Error saving file", str(e),
                                 QMessageBox.Ok, QMessageBox.NoButton)
        else:
            self.set_title(fname)
Esempio n. 8
0
class MplWidget(QtGui.QWidget):
    """Matplotlib plotting widget (promoted from Qt Designer).

    Wraps an ``MplCanvas`` in a vertical layout.  Optionally it adds a
    navigation toolbar, which starts hidden and has its icons replaced, and
    a context menu with legend / copy / plot / autoscale actions.

    Args:
        tools: reference to the tools frame (stored as ``self.tool``).
        toolbar: create the navigation toolbar.
        menu: create the context-menu actions.
        parent: parent widget.
    """
    def __init__(self, tools, toolbar=True, menu=True, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.canvas = MplCanvas()
        # NOTE(review): this attribute shadows QWidget.layout(); kept for
        # backward compatibility with code reading ``widget.layout``
        self.layout = QtGui.QVBoxLayout()
        self.layout.addWidget(self.canvas)
        self.tool = tools

        if toolbar:
            self.toolbar = NavigationToolbar(self.canvas, self)
            self.layout.addWidget(self.toolbar)
            # hover events are enabled for the installed event filter
            # (presumably it shows/hides the toolbar on hover -- see Filter)
            self.setAttribute(Qt.WA_Hover)
            self.filter = Filter(self)
            self.installEventFilter(self.filter)
            # start with the toolbar hidden and re-iconed
            self.initComponents()
        else:
            self.toolbar = None

        self.setLayout(self.layout)
        # lines managed by updatePlot()
        self.lines = []
        self.legend = None
        self.canvas.ax.autoscale_view(True, True, True)

        if menu:
            self.setContextMenuPolicy(Qt.ActionsContextMenu)
            self.initActions()
            self.alwaysAutoScale.setChecked(True)

    #-------------- initialization ---------------#
    def initComponents(self):
        """Hide the toolbar (if any) and install the custom icons."""
        if self.toolbar is not None:
            self.toolbar.hide()
            self.newIcons()

    def initActions(self):
        """Create the legend toggle, context-menu and autoscale actions."""
        # legend toggle: shown in the toolbar (if any) and the context menu
        self.toggleLegendAction = QtGui.QAction(
            QtGui.QIcon(RES + ICONS + LEGEND), 'Toggle legend',
            self, triggered=self.toggleLegend)
        self.toggleLegendAction.setCheckable(True)
        if self.toolbar is not None:
            self.toolbar.addAction(self.toggleLegendAction)

        # context menu
        self.addAction(self.toggleLegendAction)
        for icon, text, slot in ((COPY, 'Copy data to table', self.toTable),
                                 (GRAPH, 'Plot data in tools',
                                  self.toGraphTool),
                                 (SCALE, 'Autoscale', self.updateScale)):
            self.addAction(QtGui.QAction(QtGui.QIcon(RES + ICONS + icon), text,
                                         self, triggered=slot))

        self.alwaysAutoScale = QtGui.QAction('Scale on update', self)
        self.alwaysAutoScale.setCheckable(True)

        self.selectLinesMenu = QtGui.QMenu()
        self.selectLines = QtGui.QAction('Plots', self)
        self.selectLines.setMenu(self.selectLinesMenu)

        separator = QtGui.QAction('', self)
        separator.setSeparator(True)
        self.addAction(separator)
        self.addAction(self.selectLines)
        self.addAction(self.alwaysAutoScale)

    def newIcons(self):
        """Replace toolbar tool-button icons with TOOLBAR_ICONS by position."""
        toolbar_layout = self.toolbar.layout()
        for position in range(toolbar_layout.count()):
            widget = toolbar_layout.itemAt(position).widget()
            if isinstance(widget, QtGui.QToolButton):
                widget.setIcon(
                    QtGui.QIcon(RES + ICONS + TOOLBAR_ICONS[position]))

        self.toolbar.setIconSize(QSize(ICO_GRAPH, ICO_GRAPH))

    def resetGraphicEffect(self):
        """Disable the widget's graphics effect, if one is set."""
        if self.graphicsEffect() is not None:
            self.graphicsEffect().setEnabled(False)

    #------------- plotting methods ---------------#

    ## Hides axes in widget.
    #  @param axes Widget axes form canvas.
    @staticmethod
    def hideAxes(axes):
        axes.get_xaxis().set_visible(False)
        axes.get_yaxis().set_visible(False)

    ## Clears widget canvas, removing all data and clearing figure.
    #  @param repaint_axes Add standard plot after clearing figure.
    def clearCanvas(self, repaint_axes=True):
        self.canvas.ax.clear()
        self.canvas.fig.clear()
        if repaint_axes:
            self.canvas.ax = self.canvas.fig.add_subplot(111)

    ## Update existing data or plot anew.
    #  @param data List or array to plot/update.
    #  @param line Which line (by index) to update (if any); ignored when
    #         label is given (the line is then looked up or created by label).
    #  @param label Data label (new or existing).
    #  @param style Line style (solid, dashed, dotted).
    #  @param color Line color.
    def updatePlot(self, data, line=0, label=None, style='solid', color=None):
        ax = self.canvas.ax
        # only forward label/color when given so matplotlib defaults apply
        plot_kwargs = {'linestyle': style}
        if label is not None:
            plot_kwargs['label'] = label
        if color is not None:
            plot_kwargs['color'] = color

        if not ax.has_data():
            self.lines = ax.plot(data, **plot_kwargs)
        else:
            if not self.lines:
                # copy: get_lines() is not guaranteed to support extend()
                self.lines = list(ax.get_lines())
            if label is not None:
                labels = [l.get_label() for l in self.lines]
                if label not in labels:
                    self.lines.extend(ax.plot(data, **plot_kwargs))
                    line = len(self.lines) - 1
                else:
                    line = labels.index(label)
            line_to_update = self.lines[line]
            xdata = line_to_update.get_xdata()
            if len(data) != len(xdata):
                # x, y ~ data in y
                line_to_update.set_data(np.arange(len(data)), data)
            else:
                # in case data length stays the same
                line_to_update.set_data(xdata, data)
        # draw in both branches: the first plot used to rely on the legend
        # update (which may fail silently) to be rendered
        self.canvas.draw()

        self.updateLegend()
        self.updateLinesSubmenu()

        # the autoscale action only exists when the context menu was built
        auto_scale = getattr(self, 'alwaysAutoScale', None)
        if auto_scale is not None and auto_scale.isChecked():
            self.updateScale()

    ## Plots scalogram for wavelet decomposition.
    #  @param data Wavelet coefficients in matrix.
    #  @param top Axis position.
    #  @param colorbar Shows colorbar for data levels.
    #  @param power Plots squared magnitudes instead of magnitudes.
    def scalogram(self, data, top=True, colorbar=True, power=False):
        self.clearCanvas()
        ax = self.canvas.ax

        x = np.arange(len(data[0]))
        y = np.arange(len(data))

        magnitude = np.abs(data)
        contour = ax.contourf(x, y, magnitude ** 2 if power else magnitude)

        if colorbar:
            self.canvas.fig.colorbar(contour, ax=ax, orientation='vertical',
                                     format='%2.1f')

        # top=True puts the first row (scale) at the top of the plot
        ax.set_ylim((y[-1], y[0]) if top else (y[0], y[-1]))
        ax.set_xlim((x[0], x[-1]))

        self.canvas.draw()

    ## Plots list of arrays with shared x/y axes, one stacked axes per array.
    #  @param data Arrays to plot (list or matrix).
    def multiline(self, data):
        n_levels = len(data)
        if n_levels == 0:
            return  # nothing to plot (and avoids dividing by zero below)

        # y ticks hidden; each level is named through its y label instead
        axprops = dict(yticks=[])
        yprops = dict(rotation=0,
                      horizontalalignment='right',
                      verticalalignment='center',
                      x=-0.01)

        # positioning (fractions of total figure)
        left = 0.1
        width = 0.85
        space = 0.035
        height = 1. / n_levels - space
        label = 'Lvl %d'

        bottom = 1.0 - height
        # the first axes is shifted down by one extra `space`
        ax = self.canvas.fig.add_axes([left, bottom - space, width, height],
                                      **axprops)
        ax.plot(data[0])
        setp(ax.get_xticklabels(), visible=False)
        ax.set_ylabel(label % 0, **yprops)

        # the remaining levels share the first axes' scales
        axprops['sharex'] = ax
        axprops['sharey'] = ax

        for i in range(1, n_levels):
            bottom -= height
            ax = self.canvas.fig.add_axes([left, bottom, width, height],
                                          **axprops)
            ax.plot(data[i], label='Lvl' + str(i))
            ax.set_ylabel(label % i, **yprops)
            # only the bottom axes keeps its x tick labels
            if i != n_levels - 1:
                setp(ax.get_xticklabels(), visible=False)

    #----------------- actions -----------------#
    def getTopParent(self):
        """Return the top-most ancestor widget (``self`` if it has none)."""
        widget = self
        while widget.parentWidget() is not None:
            widget = widget.parentWidget()
        return widget

    def toggleLegend(self):
        self.updateLegend()

    def updateLegend(self):
        """Rebuild the draggable legend and show it per the toggle action.

        Without a legend toggle action (menu disabled) the legend is hidden.
        """
        # NB: matplotlib's legend.py -> offsetbox.py sometimes raises random
        # exceptions here; the legend is cosmetic, so failures are ignored
        try:
            prop = font_manager.FontProperties(size=11)
            self.legend = DraggableLegend(
                self.canvas.ax.legend(fancybox=True, shadow=True, prop=prop))
            toggle = getattr(self, 'toggleLegendAction', None)
            self.legend.legend.set_visible(
                toggle is not None and toggle.isChecked())
            self.canvas.draw()
        except Exception:
            pass
class Sorting_Quality_Widget(QtGui.QWidget):
    """Spike-sorting quality explorer.

    The left-hand control panel picks a reference unit and the units to
    compare.  The right-hand figure shows, for each selected unit, a
    histogram of its spike-time lags relative to the reference unit's
    spikes within (-LAG_WINDOW, LAG_WINDOW).

    Args:
        h5file: an open PyTables file, a path to one (any other string
            opens a file dialog), or None to show a "Load H5File" button.
    """

    # start directory of the "Load H5File" dialog
    DEFAULT_DIRECTORY = '/home/hachi/Desktop/Data/Recording'
    # half-width of the cross-correlogram window and its bin size
    LAG_WINDOW = 20
    BIN_SIZE = 1

    def __init__(self, h5file=None):
        import os  # only needed to recognise a path argument

        QtGui.QWidget.__init__(self)
        # per-unit widgets / IDs, filled by updateUnitsList()
        self.UnitChecks = []
        self.unitIDs = []
        self.unitNames = {}

        # control panel (added first, i.e. on the left of the main layout)
        gLay = QtGui.QGridLayout()
        row = 0

        if isinstance(h5file, tables.file.File):
            self.h5file = h5file
        elif isinstance(h5file, str):
            if not os.path.isfile(h5file):
                # not an existing file: let the user pick one
                h5file = str(
                    QtGui.QFileDialog.getOpenFileName(
                        caption='select an h5 file', filter='*.h5'))
            if h5file:
                self.h5file = tables.openFile(h5file, 'r')

        # without an open file, offer a button to load one later
        if not self._has_open_file():
            self.loadH5FileBtn = QtGui.QPushButton('Load H5File')
            self.loadH5FileBtn.clicked.connect(self.loadH5FileProc)
            gLay.addWidget(self.loadH5FileBtn, row, 0, 1, 2)
            row += 1
            self.setWindowTitle('Spike Sorting Quality Explorer')

        self.FirstUnitCombo = QtGui.QComboBox()
        gLay.addWidget(self.FirstUnitCombo, row, 0, 1, 2)
        row += 1

        self.selectBtn = QtGui.QPushButton('Select None')
        self.selectBtn.clicked.connect(self.selectProc)
        self.selectBtn.setCheckable(True)
        gLay.addWidget(self.selectBtn, row, 0)

        self.plotXCorrBtn = QtGui.QPushButton('Plot xCorr')
        self.plotXCorrBtn.clicked.connect(self.plotXCorr)
        gLay.addWidget(self.plotXCorrBtn, row, 1)
        row += 1

        self.UnitsSelector = QtGui.QTableWidget(0, 1)
        self.UnitsSelector.verticalHeader().setVisible(False)
        self.UnitsSelector.horizontalHeader().setVisible(False)
        self.UnitsSelector.setColumnWidth(0, 200)
        gLay.addWidget(self.UnitsSelector, row, 0, 1, 2)
        row += 1

        mainLay = QtGui.QHBoxLayout(self)
        mainLay.addLayout(gLay)

        # figure plus its navigation toolbar (right of the control panel)
        vLay = QtGui.QVBoxLayout()
        self.mainFig = MplWidget(self)
        self.mainFig.figure.set_facecolor('k')
        self.ntb = NavToolbar(self.mainFig, self)
        self.ntb.setIconSize(QtCore.QSize(15, 15))
        vLay.addWidget(self.mainFig)
        vLay.addWidget(self.ntb)

        mainLay.addLayout(vLay)

        self.show()

        # a file supplied up front must populate the selectors too
        if self._has_open_file():
            self.updateUnitsList()

    def _has_open_file(self):
        """Return True when ``self.h5file`` is an open PyTables file."""
        h5 = getattr(self, 'h5file', None)
        return isinstance(h5, tables.file.File) and bool(h5.isopen)

    def loadH5FileProc(self):
        """Let the user pick an h5 file, replace the current one and list units.

        Cancelling the dialog keeps the currently open file.
        """
        fname = str(
            QtGui.QFileDialog.getOpenFileName(
                caption='select an h5 file',
                filter='*.h5',
                directory=self.DEFAULT_DIRECTORY))
        if not fname:
            return
        if self._has_open_file():
            self.h5file.close()
        self.h5file = tables.openFile(fname, 'r')
        self.updateUnitsList()

    def selectProc(self):
        """Check or uncheck every unit, following the toggle button state."""
        select_all = not self.selectBtn.isChecked()
        self.selectBtn.setText('Select None' if select_all else 'Select All')
        for check in self.UnitChecks:
            check.setChecked(select_all)

    def updateUnitsList(self):
        """Refill the unit combo box and check-box table from /Spikes.

        Every child of a /Spikes group whose name contains 'Unit' becomes
        one entry with ID '<group name> <unit name>'.
        """
        if not self._has_open_file():
            return

        # clear the FirstUnit selector
        self.FirstUnitCombo.clear()

        # clean the table, kill the checkboxes
        self.UnitsSelector.setRowCount(0)
        for check in self.UnitChecks:
            check.deleteLater()
        self.UnitChecks = []
        self.unitIDs = []
        self.unitNames = {}

        try:
            nodes = self.h5file.listNodes('/Spikes')
        except Exception:
            # previously execution continued and failed on the unbound `nodes`
            print('There is a problem with the H5File')
            return

        for group in nodes:
            for member in group:
                if 'Unit' not in member._v_name:
                    continue
                count = len(self.UnitChecks)
                unitID = group._v_name + ' ' + member._v_name
                self.UnitsSelector.insertRow(count)
                self.UnitChecks.append(QtGui.QCheckBox(unitID))
                self.UnitsSelector.setCellWidget(count, 0,
                                                 self.UnitChecks[-1])
                self.UnitsSelector.setRowHeight(count, 20)
                self.FirstUnitCombo.addItem(unitID)
                self.unitIDs.append(unitID)
                self.unitNames[unitID] = (group._v_name, member._v_name)

    def _split_unit_id(self, unitID):
        """Return the (channel, unit) node names behind a unit ID."""
        try:
            return self.unitNames[unitID]
        except KeyError:
            # fallback: assumes channel names contain no spaces
            chan, _, unit = unitID.partition(' ')
            return chan, unit

    def _unit_timestamps(self, unitID):
        """Return the spike timestamps of the unit identified by *unitID*."""
        chan, unit = self._split_unit_id(unitID)
        node = self.h5file.getNode('/Spikes/' + chan)
        TS = node.TimeStamp.read()
        return TS[getattr(node, unit).Indx.read()]

    def plotXCorr(self):
        """Plot a lag histogram of every checked unit against the base unit.

        Subplots are laid out in rows of 8.  Each histogram counts the lags
        (unit spike time minus base spike time) strictly inside
        (-LAG_WINDOW, LAG_WINDOW) in bins of BIN_SIZE.
        """
        if not self._has_open_file():
            return
        baseUnit = str(self.FirstUnitCombo.currentText())
        if not baseUnit:
            return

        figure = self.mainFig.figure
        figure.clf()
        baseUnitTS = self._unit_timestamps(baseUnit)

        units2Plot = [str(check.text()) for check in self.UnitChecks
                      if check.isChecked()]

        # grid of subplots: 8 columns by as many rows as needed
        nRows = int(np.ceil(len(units2Plot) / 8.0))
        lag = self.LAG_WINDOW
        n_bins = int(2 * lag / self.BIN_SIZE)

        for j, unitID in enumerate(units2Plot):
            ax = figure.add_subplot(nRows, 8, j + 1)
            ax.set_title(unitID, color='w')

            UnitTS = self._unit_timestamps(unitID)
            lags = []
            for ts in baseUnitTS:
                t = UnitTS - ts
                lags.extend(t[(t > -lag) & (t < lag)])
            # fixed range: bins are aligned on the lag window and identical
            # for every unit, instead of depending on each unit's data
            counts, edges = np.histogram(lags, bins=n_bins, range=(-lag, lag))
            ax.bar(edges[:-1], counts, edgecolor='none', color='w')
            ax.set_xlim(-lag, lag)

            # white axes on the black figure background
            ax.tick_params(axis='both', colors='w')
            # set_axis_bgcolor was removed from newer matplotlib
            set_bg = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
            set_bg('none')
            for spine in ax.spines.values():
                spine.set_color('w')

        figure.tight_layout()
        figure.canvas.draw()
class Sorting_Quality_Widget(QtGui.QWidget):
    
    def __init__(self, h5file = None):
        """Build the spike-sorting-quality explorer widget.

        Layout: a control panel (optional 'Load H5File' button, base-unit combo
        box, 'Select None'/'Plot xCorr' buttons and a table of unit check boxes)
        next to a matplotlib figure (black background) with its navigation
        toolbar.

        Args:
            h5file: an already opened ``tables.file.File``; or a ``str`` path to
                open read-only (an empty string asks the user with a file
                dialog); or ``None`` to start empty. Whenever no file ends up
                open, a 'Load H5File' button is offered instead.
        """
        QtGui.QWidget.__init__(self)
        # must exist before updateUnitsList() runs, since it deletes old boxes
        self.UnitChecks = []

        # define a right side control panel
        gLay = QtGui.QGridLayout()
        row = 0

        if isinstance(h5file, tables.file.File):
            self.h5file = h5file

        elif isinstance(h5file, str):
            # a given path is opened directly; only an empty string falls back
            # to asking the user for a file
            fname = h5file or str(QtGui.QFileDialog.getOpenFileName(caption='select an h5 file',
                                                                    filter='*.h5'))
            if fname:
                self.h5file = tables.openFile(fname, 'r')

        if not hasattr(self, 'h5file'):
            # no file yet (none given, or the dialog was cancelled): let the
            # user load one later
            self.loadH5FileBtn = QtGui.QPushButton('Load H5File')
            self.loadH5FileBtn.clicked.connect(self.loadH5FileProc)
            gLay.addWidget(self.loadH5FileBtn, row, 0, 1, 2)
            row += 1
            self.setWindowTitle('Spike Sorting Quality Explorer')

        # base unit against which the cross-correlations are computed
        self.FirstUnitCombo = QtGui.QComboBox()
        gLay.addWidget(self.FirstUnitCombo, row, 0, 1, 2)
        row += 1

        # checkable toggle: checks / unchecks every unit box (see selectProc)
        self.selectBtn = QtGui.QPushButton('Select None')
        self.selectBtn.clicked.connect(self.selectProc)
        self.selectBtn.setCheckable(True)
        gLay.addWidget(self.selectBtn, row, 0)

        self.plotXCorrBtn = QtGui.QPushButton('Plot xCorr')
        self.plotXCorrBtn.clicked.connect(self.plotXCorr)
        gLay.addWidget(self.plotXCorrBtn, row, 1)
        row += 1

        # one-column table holding a check box per unit, headers hidden
        self.UnitsSelector = QtGui.QTableWidget(0, 1)
        self.UnitsSelector.verticalHeader().setVisible(False)
        self.UnitsSelector.horizontalHeader().setVisible(False)
        self.UnitsSelector.setColumnWidth(0, 200)
        gLay.addWidget(self.UnitsSelector, row, 0, 1, 2)
        row += 1

        mainLay = QtGui.QHBoxLayout(self)
        mainLay.addLayout(gLay)

        # define a left side figure
        vLay = QtGui.QVBoxLayout()
        self.mainFig = MplWidget(self)
        self.mainFig.figure.set_facecolor('k')
        self.ntb = NavToolbar(self.mainFig, self)
        self.ntb.setIconSize(QtCore.QSize(15, 15))
        vLay.addWidget(self.mainFig)
        vLay.addWidget(self.ntb)

        mainLay.addLayout(vLay)

        # a file supplied up front must populate the unit selectors too;
        # otherwise there is nothing to pick and nothing to plot
        if hasattr(self, 'h5file') and self.h5file.isopen:
            self.updateUnitsList()

        self.show()
    
    def loadH5FileProc(self):
        """Ask the user for an .h5 file, open it read-only and list its units.

        The previously opened file (if any) is closed only after a new file has
        actually been chosen. Cancelling the dialog therefore leaves the current
        file, unit list and plots untouched. Before this change, cancelling
        closed the old file and replaced ``self.h5file`` with ``''``.
        """
        # NOTE(review): the start directory is a user-specific hard-coded path;
        # Qt falls back to its default location when it does not exist.
        fname = str(QtGui.QFileDialog.getOpenFileName(caption='select an h5 file',
                                                      filter='*.h5',
                                                      directory = '/home/hachi/Desktop/Data/Recording'))
        if not fname:
            return

        current = getattr(self, 'h5file', None)
        if isinstance(current, tables.file.File) and current.isopen:
            current.close()

        self.h5file = tables.openFile(fname, 'r')
        self.updateUnitsList()
    
    def selectProc(self):
        """Check or uncheck every unit box according to the select button.

        The button is checkable. Released (unchecked) means all units get
        checked and the label offers 'Select None'. Pressed means all units get
        unchecked and the label offers 'Select All'.
        """
        checkAll = not self.selectBtn.isChecked()
        self.selectBtn.setText('Select None' if checkAll else 'Select All')
        for box in self.UnitChecks:
            box.setChecked(checkAll)
    
    # NOTE(review): disabled draft of a unit-listing routine, kept as a bare
    # class-body string so it never executes. It is not valid Python as
    # written (`h5file,close()`). updateUnitsList is the live implementation
    # of this idea; this string can likely be deleted.
    '''def getUnitsProc(self):
        if not hasattr(self, 'h5file'): return
        if h5file,close(): return
        
        try:
            nodes = self.h5file.listNodes('/Spikes')
        except:
            print 'There is a problem with the H5File'
            
        count = 0
        units = []
        for group in nodes:
            for member in group:
                if member._v_name.find('Unit') != -1:
                    units.append(member)
                    self.UnitsSelector.insertRow(count)
                    count += 1'''

    def updateUnitsList(self):
        """Rebuild the unit check-box table and the base-unit combo box.

        Every child whose name contains 'Unit', in every group listed under
        '/Spikes' of ``self.h5file``, is listed under the label
        "<group name> <child name>". (Presumably one group per channel.
        Verify against the file writer.) The labels are also stored in
        ``self.unitIDs``.

        The old entries are always cleared. If no open file is available, or
        '/Spikes' cannot be listed, the lists are left empty.
        """
        # clear the FirstUnit Selector
        self.FirstUnitCombo.clear()

        # clean the table, kill the checkboxes
        self.UnitsSelector.setRowCount(0)
        for box in getattr(self, 'UnitChecks', []):
            box.deleteLater()
        self.UnitChecks = []
        self.unitIDs = []

        h5 = getattr(self, 'h5file', None)
        if not isinstance(h5, tables.file.File) or not h5.isopen:
            return

        try:
            nodes = h5.listNodes('/Spikes')
        except Exception:
            # previously execution continued and crashed on the unbound `nodes`
            print('There is a problem with the H5File')
            return

        for group in nodes:
            for member in group:
                if 'Unit' not in member._v_name:
                    continue
                unitID = group._v_name + ' ' + member._v_name
                row = len(self.UnitChecks)
                check = QtGui.QCheckBox(unitID)
                self.UnitsSelector.insertRow(row)
                self.UnitsSelector.setCellWidget(row, 0, check)
                self.UnitsSelector.setRowHeight(row, 20)
                self.FirstUnitCombo.addItem(unitID)
                self.UnitChecks.append(check)
                self.unitIDs.append(unitID)
    
    def _unitTimeStamps(self, unitID):
        """Return the spike timestamps of the unit labelled *unitID*.

        *unitID* is a label of the form "<channel group> <unit node>", as built
        by updateUnitsList. The unit node's ``Indx`` array indexes into the
        channel's ``TimeStamp`` array.
        """
        # split on the separator instead of fixed positions, so channel names
        # that are not exactly 8 characters long also work
        chan, unit = unitID.split(' ', 1)
        node = self.h5file.getNode('/Spikes/' + chan)
        TS = node.TimeStamp.read()
        return TS[getattr(node, unit).Indx.read()]

    def plotXCorr(self):
        """Plot cross-correlograms of the base unit against every checked unit.

        The base unit is the one selected in FirstUnitCombo. For each checked
        unit, the lags (unit timestamp - base timestamp) strictly inside
        (-20, 20) are counted in 1-wide bins spanning exactly [-20, 20]. The
        histograms are drawn white in a grid of subplots 8 columns wide.
        Without a loaded file or base unit, the figure is only cleared.

        NOTE(review): if the base unit is itself checked, each spike pairs with
        itself and inflates the zero-lag bin.
        """
        fig = self.mainFig.figure
        fig.clf()

        win = 20        # half-width of the lag window (timestamp units)
        bin_size = 1    # histogram bin width (same units)
        nCols = 8       # subplots per row

        baseUnit = str(self.FirstUnitCombo.currentText())
        h5 = getattr(self, 'h5file', None)
        if not baseUnit or not isinstance(h5, tables.file.File) or not h5.isopen:
            # nothing loaded yet: just show the cleared figure
            fig.canvas.draw()
            return

        baseUnitTS = self._unitTimeStamps(baseUnit)

        # check which units to plot
        units2Plot = [str(box.text()) for box in self.UnitChecks if box.isChecked()]

        # grid of subplots nCols wide; add_subplot needs integer dimensions
        nRows = int(np.ceil(len(units2Plot) / float(nCols)))

        for j, unitID in enumerate(units2Plot):
            ax = fig.add_subplot(nRows, nCols, j + 1)
            ax.set_title(unitID, color='w')
            unitTS = self._unitTimeStamps(unitID)

            # collect the lags of this unit relative to every base spike
            lags = []
            for ts in baseUnitTS:
                d = unitTS - ts
                lags.extend(d[(d > -win) & (d < win)])

            # a fixed range makes the bins bin_size wide and aligned on the
            # window, instead of spanning whatever min/max lag occurred
            counts, edges = np.histogram(lags, bins=int(2 * win / bin_size),
                                         range=(-win, win))
            ax.bar(edges[:-1], counts, width=bin_size, align='edge',
                   edgecolor='none', color='w')
            ax.set_xlim(-win, win)

            # white axes on a transparent background (the figure is black)
            ax.tick_params(axis='x', colors='w')
            ax.tick_params(axis='y', colors='w')
            if hasattr(ax, 'set_facecolor'):    # matplotlib >= 2.0
                ax.set_facecolor('none')
            else:                               # older matplotlib
                ax.set_axis_bgcolor('none')
            for spine in ax.spines.values():
                spine.set_color('w')

        fig.tight_layout()
        fig.canvas.draw()