Example #1
0
class CutScreen(QWidget):
    """
    Widget plotting a data cut (blue) together with its fit (red).

    parent: parent widget (QWidget)
    xdata, ydata: abscissa / ordinate of the data cut (default: empty)
    xfit, yfit: abscissa / ordinate of the fitted curve (default: empty)
    sizex, sizey: preferred size; stored only, not applied to the widget

    Assign new sequences to the x/y attributes, then call update_plot().
    """
    def __init__(self,
                 parent,
                 xdata=None,
                 ydata=None,
                 xfit=None,
                 yfit=None,
                 sizex=100,
                 sizey=100):

        QWidget.__init__(self, parent)

        self.sizex = sizex
        self.sizey = sizey

        # None sentinels instead of `[]` defaults: a default list is created
        # once and would be shared by every CutScreen built without data.
        self.xdata = [] if xdata is None else xdata
        self.ydata = [] if ydata is None else ydata
        self.xfit = [] if xfit is None else xfit
        self.yfit = [] if yfit is None else yfit

        self.plot = None
        self.data_cut = None
        self.data_fit = None
        # Older attribute name for the fit curve; kept pointing at data_fit.
        self.fit_cut = None

        self.setup_widget()

    def setup_widget(self):
        """Create the plot with the data (blue) and fit (red) curves and lay it out."""
        self.plot = CurvePlot(self)
        self.data_cut = make.curve(self.xdata, self.ydata, color="b")
        self.data_fit = make.curve(self.xfit, self.yfit, color="r")
        self.fit_cut = self.data_fit

        self.plot.add_item(self.data_cut)
        self.plot.add_item(self.data_fit)

        vlayout = QVBoxLayout()
        vlayout.addWidget(self.plot)
        self.setLayout(vlayout)

    def update_plot(self):
        """
        Push the current data/fit sequences into the curves and replot.

        The plot limits are fitted to the data cut (not the fit). They are left
        unchanged when the data cut is empty, since it has no min/max.
        """
        # Local import so plain lists (the defaults) work as well as arrays;
        # lists have no .min()/.max() methods.
        import numpy as np

        self.data_cut.set_data(self.xdata, self.ydata)
        self.data_fit.set_data(self.xfit, self.yfit)

        xdata = np.asarray(self.xdata)
        ydata = np.asarray(self.ydata)
        if xdata.size and ydata.size:
            self.plot.set_plot_limits(xdata.min(), xdata.max(),
                                      ydata.min(), ydata.max())
        self.plot.replot()
Example #2
0
class FilterTestWidget(QWidget):
    """
    Widget that plots a signal and applies a filter to it on demand.

    parent: parent widget (QWidget)
    x, y: NumPy arrays (abscissa and signal values)
    func: callable taking the current y and returning the filtered y

    Call setup_widget(title) once after construction to build the UI.
    """

    def __init__(self, parent, x, y, func):
        QWidget.__init__(self, parent)
        self.setMinimumSize(320, 200)
        self.x, self.y, self.func = x, y, func
        # guiqwt objects, created later by setup_widget()
        self.plot = self.curve_item = None

    def setup_widget(self, title):
        """Build the plot and the 'Test filter' button, lay them out, draw y."""
        plot = CurvePlot(self)
        self.plot = plot
        curve = make.curve([], [], color="b")
        self.curve_item = curve
        plot.add_item(curve)
        plot.set_antialiasing(True)

        trigger = QPushButton("Test filter: %s" % title)
        trigger.clicked.connect(self.process_data)

        layout = QVBoxLayout()
        for child in (plot, trigger):
            layout.addWidget(child)
        self.setLayout(layout)

        self.update_curve()

    def process_data(self):
        """Replace y by func(y) and redraw; each click filters again."""
        filtered = self.func(self.y)
        self.y = filtered
        self.update_curve()

    def update_curve(self):
        """Copy the current (x, y) into the curve item and replot."""
        curve, plot = self.curve_item, self.plot
        curve.set_data(self.x, self.y)
        plot.replot()
Example #3
0
class FilterTestWidget(QWidget):
    """
    Interactive filter tester: shows y(x) and re-filters y on each button click.

    parent: parent widget (QWidget)
    x, y: NumPy arrays
    func: the signal filter under test; called as func(y) -> new y
    """
    def __init__(self, parent, x, y, func):
        QWidget.__init__(self, parent)
        self.setMinimumSize(320, 200)
        self.x = x
        self.y = y
        self.func = func
        # Filled in by setup_widget():
        self.plot = None
        self.curve_item = None

    def _build_plot(self):
        """Create the antialiased CurvePlot holding one empty blue curve."""
        self.plot = CurvePlot(self)
        self.curve_item = make.curve([], [], color='b')
        self.plot.add_item(self.curve_item)
        self.plot.set_antialiasing(True)

    def setup_widget(self, title):
        """Assemble plot + button in a vertical layout and draw the signal."""
        self._build_plot()

        run_button = QPushButton("Test filter: %s" % title)
        run_button.clicked.connect(self.process_data)

        column = QVBoxLayout()
        column.addWidget(self.plot)
        column.addWidget(run_button)
        self.setLayout(column)

        self.update_curve()

    def process_data(self):
        """Apply func to the current signal (cumulatively) and refresh."""
        self.y = self.func(self.y)
        self.update_curve()

    def update_curve(self):
        """Redraw the curve from self.x / self.y."""
        item = self.curve_item
        item.set_data(self.x, self.y)
        self.plot.replot()
Example #4
0
class PlotWidget(qg.QWidget):
    """

    """
    def __init__(self, parent):
        qg.QWidget.__init__(self, parent)
        
        self.data = Dbase(DBaseName = None)
        self.y1Data = np.arange(0,100) #self.data.Query()
        self.y2Data = np.arange(0,100)
        self.xData = np.arange(0,100) #np.arange(0, len(self.yData))
        self.setMinimumSize(700, 600)
        #---guiqwt related attributes:
        self.plot = None
        self.curve_item = None
        #---
        
    def setup_widget(self, title):
        """Create the plot widget:
        """
        self.plot = CurvePlot(self)
        self.curve_item = (make.curve([], [], color='b'))
        self.second_curve_item = (make.curve([], [], color='g'))

        self.plot.add_item(self.curve_item)
        self.plot.add_item(self.second_curve_item)
        self.plot.set_antialiasing(True)

        self.databaseScroll = treeList(DBaseName = None)
        self.databaseScroll.setSortingEnabled(True)

        spacer = qg.QSpacerItem(30,40)
        
        preprocessData = qg.QGroupBox(u"preprocess data")
        
        # create buttons
        listButton = qg.QPushButton(u"Refresh list")
        processButton = qg.QPushButton(u"       run preprocessing       ")
        addwaveletButton = qg.QPushButton(u"  add wavelet spikes to DB  ")
        
        self.checkAddData = qg.QMessageBox()
        
        self.wavelet = qg.QCheckBox(u"wavelet filter      ")
        label1 = qg.QLabel("enter threshold value: ")
        self.waveletThreshold = qg.QLineEdit(u"10")

        

        # connect user actions with methods:
        self.connect(listButton, qc.SIGNAL('clicked()'), self.but_clicked)
        self.connect(processButton, qc.SIGNAL('clicked()'), 
                     self.run_preprocess)
        self.connect(self.databaseScroll, 
                     qc.SIGNAL("doubleClicked(QModelIndex)"), 
                     self.double_clicked)
        self.connect(addwaveletButton, qc.SIGNAL('clicked()'), 
                     self.add_data_to_DBase)
        
        vlayout = qg.QVBoxLayout()
        hlayout = qg.QHBoxLayout()

        
        vlayout.addWidget(self.databaseScroll)
        vlayout.addWidget(listButton)

        vlayout.addWidget(self.plot)
        vlayout.addSpacerItem(spacer)
        
        
        hlayout.addWidget(self.wavelet)
        hlayout.addWidget(label1)
        hlayout.addWidget(self.waveletThreshold)

        hlayout.addWidget(processButton)        
        hlayout.addWidget(addwaveletButton)
        preprocessData.setLayout(hlayout)
        vlayout.addWidget(preprocessData)
        
        self.setLayout(vlayout)
        
        self.update_curve()

                                             
    def run_preprocess(self):
        """
        """
        if self.wavelet.isChecked():
            try:
                yData = self.y1Data
                waveletFilt = pp.wavefilter(yData)
                self.y1Data = waveletFilt
                
                thresh = float(self.waveletThreshold.displayText())
                foo = self.y1Data > thresh
                self.y2Data = self.spikes = pp.spikeThreshold(foo)
                self.y2Data = self.y2Data * max(self.y1Data)
                
                index = self.databaseScroll
                root = index.parent().parent().row()
                neuron = index.parent().row()
                epoch = index.row()  
                
                r = self.databaseScroll
                self.data = Dbase(r, PRINT = 1)               
                n = self.databaseScroll
                epochs = self.data.GetTree(n)
                e = epochs[epoch]
                
                self.spikeOverwriteLoc = [r, n + '.' + e + '.spikes', n + '.' + e, 
                                          n, n + '_' + e + '_spikes']
                self.data.Data.CloseDatabase(PRINT=1)
                print 'spike handle: ', self.spikeOverwriteLoc[0:2]
                self.update_curve()
                
            except ValueError:
                print 'ValueError. Could not compute wave filter'
                pass


    def add_data_to_DBase(self):
        """
        """
        self.checkAddData.setText("Are you sure you want to add filtered \
                                spikes (green trace) to the DB?")
        self.checkAddData.setInformativeText("This will overwrite the existing\
                                                spike data.")
        self.checkAddData.setStandardButtons(qg.QMessageBox.Yes | qg.QMessageBox.No )
        self.checkAddData.setDefaultButton(qg.QMessageBox.Yes)
        choice = self.checkAddData.exec_() == qg.QMessageBox.Yes
        if choice:
            tim = dt.datetime.utcnow().ctime()
            
            self.data = Dbase(self.spikeOverwriteLoc[0], PRINT = 1)
            self.data.Data.RemoveChild(self.spikeOverwriteLoc[1], option=1)
            self.data.Data.AddData2Database('spikes', self.spikes,
                                            self.spikeOverwriteLoc[2])
            self.data.Data.AddGitVersion(self.spikeOverwriteLoc[3],
                                         Action = 'updated_{0}'.format(
                                         self.spikeOverwriteLoc[4] + 
                                         str(tim.day) + '_' +
                                         str(tim.month) + '_' +
                                         str(tim.year) ))
                                         
            self.data.Data.CloseDatabase(PRINT=1)
            
            """ add git version here """
            print 'spike data successfully overwritten.'

    
    def update_curve(self):
        """Update curve
        """
        self.curve_item.set_data(self.xData, self.y1Data)
        self.second_curve_item.set_data(self.xData, self.y2Data)
        self.plot.replot()
        self.plot.do_autoscale()
        
    def but_clicked(self):
        """
        when refresh button is clicked, the tree is refreshed
        """
        self.databaseScroll.refreshTree()
    
    def double_clicked(self):
        """
        when a name button is clicked, iterate over the model, 
        find the neuron with this name, and set the treeviews current item
        """
        index = self.databaseScroll.currentIndex()
        item = self.databaseScroll.currentItem()
        clickedCol = index.column()

        hasChild = True
        ind = index.parent()
        colCount = 1
        parents = []
        while hasChild:
            #find root index and parents
            root = ind.row()
            if root != -1:
                parent = str(self.databaseScroll.itemFromIndex(
                                        ind).text(clickedCol - colCount))
                parents.insert(0, parent)
            if root == -1:
                hasChild = False
                if parents == []:
                    parent = str(self.databaseScroll.itemFromIndex(
                                        index).text(0))
                    parents.insert(0, parent)
                    
            ind = ind.parent()
            colCount += 1

        name = str(item.text(clickedCol))
        print 'name: ', name
        print 'parents: ', parents
        
        if len(parents) == 1:
            parents.append('')
        if name[:-3] == '.h5':
            name = '/'
        
        kind = self.databaseScroll.GetDtype(parents[0], name, parents[1:])
        print kind
        #if dtype_ == 'np.array':
        #    clickedData = self.databaseScroll.GetData(parents[0], name, 
        #                                          parents[1:])
        #    print clickedData
        '''    
Example #5
0
class iScopeWidget(QWidget):
    """
    Integrate a curve over a draggable range after subtracting a fitted background.

    parent: parent widget (QWidget)
    x, y: NumPy arrays (x assumed sorted ascending -- TODO confirm)

    Two blue range selectors (lbgrange, rbgrange) pick the background points.
    A degree-3 polynomial is fitted to them and drawn in red. The absolute
    trapezoidal integral of (y - background) over the third range (intrange)
    is shown in the top-right label. Call setup_widget() to build the UI.
    """

    def __init__(self, parent, x, y):
        QWidget.__init__(self, parent)
        self.setMinimumSize(320, 200)
        self.x = x
        self.y = y
        # ---guiqwt related attributes:
        self.plot = None
        self.curve_item = None
        # ---
        # Background polynomial coefficients (None until a fit succeeds).
        self.coeff = None

    def setup_widget(self):
        """Build the plot, the three range selectors and the label, then fit."""
        x = self.x
        self.plot = CurvePlot(self)
        self.curve_item = make.curve([], [], color="b")
        self.plot.add_item(self.curve_item)
        self.plot.set_antialiasing(True)

        # Initial placement as fractions of the x span: integration range
        # 40%..60%, background ranges 30%..35% and 70%..90%.
        width = x[-1] - x[0]
        self.intrange = make.range(x[0] + 0.4 * width, x[-1] - 0.4 * width)
        self.plot.add_item(self.intrange)
        self.lbgrange = make.range(x[0] + 0.3 * width, x[-1] - 0.65 * width)
        self.plot.add_item(self.lbgrange)
        self.lbgrange.pen = Qt.QPen(Qt.QColor("blue"))
        self.lbgrange.brush = Qt.QBrush(Qt.QColor(0, 0, 120, 100))
        self.rbgrange = make.range(x[0] + 0.7 * width, x[-1] - 0.1 * width)
        self.rbgrange.pen = Qt.QPen(Qt.QColor("blue"))
        self.rbgrange.brush = Qt.QBrush(Qt.QColor(0, 0, 120, 100))
        # created before fit_bg(), which writes into it via update_label()
        self.label1 = make.label(r"", "TR", (0, 0), "TR")
        self.plot.add_item(self.rbgrange)
        self.bg_item = make.curve([], [], color="r")
        self.plot.add_item(self.bg_item)
        self.fit_bg()
        self.plot.add_item(self.label1)
        # refit whenever a range selector reports a change
        self.connect(self.plot, SIG_RANGE_CHANGED, self.fit_bg)
        # ---
        vlayout = QVBoxLayout()
        vlayout.addWidget(self.plot)
        self.setLayout(vlayout)
        self.update_curve()

    def fit_bg(self):
        """
        Refit the background polynomial and recompute the integral.

        Points are selected with boolean masks, which keep them in x order;
        trapz needs ordered x. Points lying in both background ranges count
        once. With too few background points for a degree-3 fit, the previous
        fit and integral are kept.
        """
        degree = 3
        xs, ys = array([self.x, self.y])

        def inside(selector):
            # sorted(): a range dragged past itself can report min > max
            low, high = sorted(selector.get_range())
            return (xs >= low) & (xs <= high)

        bg = inside(self.lbgrange) | inside(self.rbgrange)
        if bg.sum() <= degree:
            return
        self.coeff = polyfit(xs[bg], ys[bg], degree)

        sel = inside(self.intrange)
        xi = xs[sel]
        self.int = abs(trapz(ys[sel] - polyval(self.coeff, xi), x=xi))

        self.update_label()
        self.update_curve()

    def update_label(self):
        """Show the latest background-subtracted integral."""
        self.label1.set_text(u"""trapz(red) - int(blue) = %e""" % self.int)

    def update_curve(self):
        """Redraw the data curve and, once fitted, the background polynomial."""
        self.curve_item.set_data(self.x, self.y)
        if self.coeff is not None:
            self.bg_item.set_data(self.x, polyval(self.coeff, self.x))
        # replot last so the background update is displayed as well
        self.plot.replot()
Example #6
0
class UI(QtGui.QMainWindow):
    def __init__(self, parent=None):
        super(UI, self).__init__(parent)
        self.ui_obj = None
        self.y_position = 0
        self.slider_pos = 0
        self.image_temp = None
        self.image_w = 0
        self.image_h = 0
        self.image_path = None
        self.ui_init()
        self.init_connect()
        self.plot_init()

    def ui_init(self):
        self.ui_obj = Ui_MainWindow()
        self.ui_obj.setupUi(self)

    def init_connect(self):
        self.connect(self.ui_obj.y_verticalSlider,
                     QtCore.SIGNAL('valueChanged(int)'), self.plot_line_gray)
        self.ui_obj.open_image_button.clicked.connect(self.open_image)
        self.ui_obj.save_plot_button.clicked.connect(self.save_image)

    def open_image(self):
        self.file_name = QtGui.QFileDialog.getOpenFileName(
            self, "open file dialog", "C:\Users\Administrator\Desktop",
            "*.jpg *.png")
        print self.file_name
        if self.file_name != "":
            # if self.ui_obj.image_path_lineEdit.text() != "":
            #     self.image_path = self.ui_obj.image_path_lineEdit.text()
            self.image_path = self.file_name
            if os.path.exists(self.image_path):
                self.display_image()
            else:
                # self.image_clean()
                print "path is not exists!!\n"
        else:
            # self.image_clean()
            print "pls fill imagepath\n"

    def save_image(self):
        pass

    def image_clean(self):
        self.image_temp = None
        self.image_w = 0
        self.image_h = 0

    def display_image(self):
        # print "cv2 imread path :"+str(self.image_path)
        self.image_temp = cv2.imread(str(self.image_path))
        self.image_h = len(self.image_temp)
        self.image_w = len(self.image_temp[0])
        # #print image.shape[1],image.shape[0]
        # #print self.ui_obj.picture_display.width(),self.ui_obj.picture_display.height()
        # qimage = QtGui.QImage(image, image.shape[1], image.shape[0], QtGui.QImage.Format_RGB888)
        # qimage_resize = qimage.scaled(self.ui_obj.picture_display.width(), self.ui_obj.picture_display.height(), 0, 0)
        # self.ui_obj.picture_display.setPixmap(QtGui.QPixmap.fromImage(qimage_resize))
        # self.ui_obj.picture_display.show()

        pixmap = QtGui.QPixmap()
        pixmap.load(str(self.image_path))
        newpixmap = pixmap.scaled(self.ui_obj.picture_display.width(),
                                  self.ui_obj.picture_display.height(), 0, 0)
        self.ui_obj.picture_display.setPixmap(newpixmap)
        self.ui_obj.picture_display.setAlignment(QtCore.Qt.AlignCenter)
        self.ui_obj.picture_display.show()

    def plot_init(self):
        self.manager = PlotManager(self)
        self.plots = []
        self.plot = CurvePlot(xlabel="", ylabel="")
        # self.plot.axisScaleDraw(CurvePlot.Y_LEFT).setMinimumExtent(10)
        self.manager.add_plot(self.plot)
        self.plots.append(self.plot)
        self.plot.plot_id = id(self.plot)
        self.curve = make.curve([0], [0], color="blue", title="gray value")
        self.plot.add_item(self.curve)
        self.plot.add_item(make.legend("TR"))
        self.ui_obj.line_info_display.addWidget(self.plot)

    def plot_line_gray(self):
        self.slider_pos = self.ui_obj.y_verticalSlider.value()
        # print"print line ",self.slider_pos
        # print type(self.slider_pos), len(self.image_temp)
        self.y_position = (100 - self.slider_pos) * (int(self.image_h) / 100)
        print "y_position:", self.y_position

        line = [
            self.image_temp[self.y_position][i][1]
            for i in range(len(self.image_temp[self.y_position]))
        ]
        print len(line), [i for i in range(len(line))]
        self.plot.set_axis_limits('left', 0, 255)
        self.plot.set_axis_limits('bottom', 0, self.image_w)
        self.curve.set_data([i for i in range(len(line))], line)
        self.plot.replot()

        # print "image :", self.image_temp[3]

    def blur(self, kernel):
        pass

    def nosie(self, method):
        pass