Пример #1
0
 def classifier_choice(self, text):
     """Switch the active classifier to the type named by ``text``.

     The handler always logs that it was triggered. The switch itself is
     only performed while ``self.prediction_started`` is falsy: ``text`` is
     stored in ``self.classifier_type``, resolved to a classifier ID with
     ``selectClassifierID`` and handed to ``self.client.setStagePredictor``.
     Once prediction has started the request is ignored.
     """
     print('classifier_choice() activated.')
     if self.prediction_started:
         # Changing the model mid-prediction is not allowed.
         return

     self.classifier_type = text
     selected_id = selectClassifierID(self.params.finalClassifierDir,
                                      self.classifier_type)
     self.client.setStagePredictor(selected_id)
     print('classifier_type changed to', self.classifier_type)
Пример #2
0
def generateClassifier(params, chamberID, observed_samplingFreq,
                       observed_epochTime):
    """Build a GUI-less ``ClassifierClient`` for one chamber.

    The classifier is looked up with ``selectClassifierID`` for the fixed
    network name ``'UTSN-L'`` and the observed sampling frequency / epoch
    time. The sampling frequency and epoch time returned by that lookup are
    the values the client is configured with.

    Args:
        params: settings object; ``finalClassifierDir``, ``writeWholeWaves``,
            ``extractorType`` and ``classifierType`` are read from it.
        chamberID: forwarded to the client as ``chamberID``.
        observed_samplingFreq: sampling frequency handed to the lookup.
        observed_epochTime: epoch time handed to the lookup.

    Returns:
        The new client, after ``predictionStateOn()`` has been called on it
        and its ``hasGUI`` attribute has been set to ``False``.
    """
    selection = selectClassifierID(params.finalClassifierDir, 'UTSN-L',
                                   observed_samplingFreq, observed_epochTime)
    classifierID, model_samplingFreq, model_epochTime = selection

    # NOTE: nothing here verifies that the values returned by the lookup
    # match the observed sampling frequency and epoch time.
    client_options = dict(chamberID=chamberID,
                          samplingFreq=model_samplingFreq,
                          epochTime=model_epochTime)
    client = ClassifierClient(params.writeWholeWaves, params.extractorType,
                              params.classifierType, classifierID,
                              **client_options)
    client.predictionStateOn()
    client.hasGUI = False
    return client
Пример #3
0
    def initUI(self):
        """Create all widgets of the main window and the classifier client.

        Widgets are created in this order: the Predict / Quit / Test
        connection buttons, the terminal, model, EEG-mode, Ch2-mode and
        Ch2-usage combo boxes, the overwrite-threshold line edit, the y-max
        boxes and sliders for both channels, the threshold slider, the
        static labels, and finally ``self.graphNum`` prediction labels with
        two rows of graph canvases. Positions are multiplied by
        ``self.scale``. The window is then sized, shown and activated.

        After that ``self.args`` is interpreted to build ``self.client``:

        * more than 5 entries: a message is printed and the process exits;
        * ``args[4]``: sleep time (float, default ``self.defaultSleepTime``);
        * ``args[3]``: offset window ID (int, default 0);
        * ``args[1] == 'm'``: classifier chosen by ``selectClassifierID``,
          input file ID taken from ``args[2]`` or picked by
          ``self.randomlySelectInputFileID()``;
        * ``args[1] == 'o'``: classifier chosen by ``selectClassifierID``,
          empty input file ID;
        * any other ``args[1]``: used directly as the classifier ID;
        * no ``args[1]``: ``self.readFromDaq`` is set to ``True`` and the
          classifier is chosen by ``selectClassifierID``.

        Raises:
            Exception: anything raised while parsing the arguments or
                creating the client is shown in the status bar and re-raised.
            SystemExit: when too many arguments were supplied.
        """

        self.startPredictionButton = QtWidgets.QPushButton('Predict', self)
        self.startPredictionButton.setCheckable(True)
        self.startPredictionButton.clicked.connect(self.startPrediction)
        self.startPredictionButton.resize(
            self.startPredictionButton.sizeHint())
        self.startPredictionButton.move(int(5 * self.scale),
                                        int(10 * self.scale))

        quitButton = QtWidgets.QPushButton('Quit', self)
        quitButton.clicked.connect(QtCore.QCoreApplication.instance().quit)
        quitButton.resize(quitButton.sizeHint())
        quitButton.move(int(85 * self.scale), int(10 * self.scale))

        checkConnectionButton = QtWidgets.QPushButton('Test connection', self)
        checkConnectionButton.clicked.connect(self.check_connection)
        checkConnectionButton.resize(checkConnectionButton.sizeHint())
        checkConnectionButton.move(int(160 * self.scale), int(10 * self.scale))

        # choose the terminal configuration
        self.nameLabel_terminal_label = QLabel(self)
        self.nameLabel_terminal_label.setText('Terminal:')
        self.nameLabel_terminal_label.move(int(250 * self.scale),
                                           int(35 * self.scale))
        self.terminal_combobox = QtWidgets.QComboBox(self)
        self.terminal_combobox.addItem(self.terminal_str_diff)
        self.terminal_combobox.addItem(self.terminal_str_rse)
        self.terminal_combobox.addItem(self.terminal_str_nrse)
        # self.terminal_combobox.addItem(self.terminal_str_pseudo) # not available for NI DAQ USB-6210
        self.terminal_combobox.move(int(310 * self.scale),
                                    int(38 * self.scale))
        self.terminal_combobox.resize(self.terminal_combobox.sizeHint())
        self.terminal_combobox.activated[str].connect(self.terminal_choice)
        self.terminal_combobox.setCurrentText(self.terminal_config)

        # change model
        self.nameLabel_classifier = QLabel(self)
        self.nameLabel_classifier.setText('Model:')
        self.nameLabel_classifier.move(int(300 * self.scale),
                                       int(10 * self.scale))

        self.classifier_combobox = QtWidgets.QComboBox(self)
        for classifier_type in self.classifier_types:
            self.classifier_combobox.addItem(classifier_type)
        self.classifier_combobox.resize(self.classifier_combobox.sizeHint())
        self.classifier_combobox.move(int(340 * self.scale),
                                      int(10 * self.scale))
        self.classifier_combobox.activated[str].connect(self.classifier_choice)
        # Disabled 'Overwrite by Ch2' toggle button, kept as an inert string
        # literal (it is never executed).
        '''
        self.nameLabel_ch2_overwrite_label = QLabel(self)
        self.nameLabel_ch2_overwrite_label.setText('Overwrite by Ch2:')
        self.nameLabel_ch2_overwrite_label.move(int(450 * self.scale), int(5 * self.scale))

        self.overwriteOrNotButton = QtWidgets.QPushButton(self.label_notOverwrite, self)
        self.overwriteOrNotButton.clicked.connect(self.toggleOverwriteOrNot)
        self.overwriteOrNotButton.resize(self.overwriteOrNotButton.sizeHint())
        self.overwriteOrNotButton.update()
        self.overwriteOrNotButton.move(int(530 * self.scale), int(5 * self.scale))
        self.overwriteOrNotButton.setCheckable(True)
        '''

        # change standardization of eeg
        self.nameLabel_eeg_visualize_mode_label = QLabel(self)
        self.nameLabel_eeg_visualize_mode_label.setText('EEG mode:')
        self.nameLabel_eeg_visualize_mode_label.move(int(610 * self.scale),
                                                     int(2 * self.scale))
        self.eeg_visualize_mode_combobox = QtWidgets.QComboBox(self)
        self.eeg_visualize_mode_combobox.addItem(
            self.eeg_visualize_mode_str_none)
        self.eeg_visualize_mode_combobox.addItem(
            self.eeg_visualize_mode_str_normalize)
        self.eeg_visualize_mode_combobox.move(int(680 * self.scale),
                                              int(5 * self.scale))
        self.eeg_visualize_mode_combobox.resize(
            self.eeg_visualize_mode_combobox.sizeHint())
        self.eeg_visualize_mode_combobox.activated[str].connect(
            self.eeg_visualize_mode_choice)

        # change standardization of ch2
        self.nameLabel_ch2_visualize_mode_label = QLabel(self)
        self.nameLabel_ch2_visualize_mode_label.setText('Ch2 mode:')
        self.nameLabel_ch2_visualize_mode_label.move(int(610 * self.scale),
                                                     int(30 * self.scale))
        self.ch2_visualize_mode_combobox = QtWidgets.QComboBox(self)
        self.ch2_visualize_mode_combobox.addItem(
            self.ch2_visualize_mode_str_none)
        self.ch2_visualize_mode_combobox.addItem(
            self.ch2_visualize_mode_str_normalize)
        self.ch2_visualize_mode_combobox.move(int(680 * self.scale),
                                              int(33 * self.scale))
        self.ch2_visualize_mode_combobox.resize(
            self.ch2_visualize_mode_combobox.sizeHint())
        self.ch2_visualize_mode_combobox.activated[str].connect(
            self.ch2_visualize_mode_choice)

        # set overwrite threshold to W
        self.nameLabel_ch2_thresh = QLabel(self)
        self.nameLabel_ch2_thresh.setText('Overwrite threshold to W:')
        self.nameLabel_ch2_thresh.resize(self.nameLabel_ch2_thresh.sizeHint())
        self.nameLabel_ch2_thresh.move(int(840 * self.scale),
                                       int(15 * self.scale))
        self.ch2_thresh = QLineEdit(self)
        self.ch2_thresh.setText(str(self.params.ch2_thresh_default))
        self.ch2_thresh.move(int(1000 * self.scale), int(15 * self.scale))
        self.ch2_thresh.resize(50, 20)
        self.ch2_thresh.textChanged.connect(self.ch2_thresh_text_change)

        # change usage of ch2
        self.nameLabel_ch2_usage_label = QLabel(self)
        self.nameLabel_ch2_usage_label.setText('Ch2 usage:')
        self.nameLabel_ch2_usage_label.move(int(840 * self.scale),
                                            int(60 * self.scale))
        self.ch2_usage_combobox = QtWidgets.QComboBox(self)
        self.ch2_usage_combobox.addItem(self.ch2_usage_str_dontshowCh2)
        self.ch2_usage_combobox.addItem(self.ch2_usage_str_showCh2)
        self.ch2_usage_combobox.addItem(self.ch2_usage_str_overwrite)
        self.ch2_usage_combobox.move(int(910 * self.scale),
                                     int(60 * self.scale))
        self.ch2_usage_combobox.resize(self.ch2_usage_combobox.sizeHint())
        self.ch2_usage_combobox.activated[str].connect(self.ch2_usage_choice)
        self.ch2_usage_combobox_setup()

        # y-axis maximum for the EEG graphs: text box plus slider (0..100)
        self.ylim_label_eeg = QLabel(self)
        self.ylim_label_eeg.setText('eeg y-max:')
        self.ylim_label_eeg.move(int(40 * self.scale), int(40 * self.scale))
        self.ylim_label_eeg.resize(self.ylim_label_eeg.sizeHint())
        self.ylim_value_eeg_box = QLineEdit(self)
        self.ylim_value_eeg_box.move(int(45 * self.scale),
                                     int(60 * self.scale))
        self.ylim_value_eeg_box.resize(40, 20)

        self.ylim_slider_eeg = QSlider(Qt.Vertical, self)
        self.ylim_slider_eeg.move(int(10 * self.scale), int(40 * self.scale))
        self.ylim_slider_eeg.resize(20, 80)
        self.ylim_slider_eeg.setMinimum(0)
        self.ylim_slider_eeg.setMaximum(100)
        self.ylim_slider_eeg.setTickPosition(QSlider.TicksBelow)
        self.ylim_slider_eeg.setTickInterval(1)
        self.ylim_slider_eeg.valueChanged.connect(self.ylim_change_eeg)

        # y-axis maximum for the Ch2 graphs: text box plus slider (0..100)
        self.ylim_label_ch2 = QLabel(self)
        self.ylim_label_ch2.setText('ch2 y-max:')
        self.ylim_label_ch2.move(int(160 * self.scale), int(40 * self.scale))
        self.ylim_label_ch2.resize(self.ylim_label_ch2.sizeHint())
        self.ylim_value_ch2_box = QLineEdit(self)
        self.ylim_value_ch2_box.move(int(165 * self.scale),
                                     int(60 * self.scale))
        self.ylim_value_ch2_box.resize(40, 20)

        self.ylim_slider_ch2 = QSlider(Qt.Vertical, self)
        self.ylim_slider_ch2.move(int(130 * self.scale), int(40 * self.scale))
        self.ylim_slider_ch2.resize(20, 40)
        self.ylim_slider_ch2.setMinimum(0)
        self.ylim_slider_ch2.setMaximum(100)
        self.ylim_slider_ch2.setTickPosition(QSlider.TicksBelow)
        self.ylim_slider_ch2.setTickInterval(1)
        self.ylim_slider_ch2.valueChanged.connect(self.ylim_change_ch2)

        # Threshold slider: the slider works in integer ticks, so its range
        # is the threshold range multiplied by the tick factor.
        self.ch2_thresh_slider_tick_factor = 4
        self.ch2_thresh_slider = QSlider(Qt.Horizontal, self)
        self.ch2_thresh_slider.move(int(1035 * self.scale),
                                    int(20 * self.scale))
        self.ch2_thresh_slider.resize(190, 20)
        self.ch2_thresh_minimum = -2
        self.ch2_thresh_maximum = 10
        self.ch2_thresh_slider.setMinimum(self.ch2_thresh_minimum *
                                          self.ch2_thresh_slider_tick_factor)
        self.ch2_thresh_slider.setMaximum(self.ch2_thresh_maximum *
                                          self.ch2_thresh_slider_tick_factor)
        self.ch2_thresh_slider.setValue(self.ch2_thresh_slider_tick_factor)
        self.ch2_thresh_slider.setTickPosition(QSlider.TicksBelow)
        self.ch2_thresh_slider.setTickInterval(1)
        self.ch2_thresh_slider.valueChanged.connect(self.ch2_thresh_change)

        # Heading above the prediction labels. It has its own attribute so
        # that the 'Ch2' label created below does not overwrite the only
        # reference to it.
        self.label_prediction_header = QLabel(self)
        self.label_prediction_header.setFont(QtGui.QFont('Courier New', 24))
        self.label_prediction_header.setText('Epoch# : Prediction')
        self.label_prediction_header.resize(
            self.label_prediction_header.sizeHint())
        self.label_prediction_header.move(int(500 * self.scale),
                                          int(60 * self.scale))

        self.label_graph_eeg = QLabel(self)
        self.label_graph_eeg.setFont(QtGui.QFont('Courier New', 20))
        self.label_graph_eeg.setText('EEG')
        self.label_graph_eeg.move(int(5 * self.scale), int(205 * self.scale))

        self.label_graph_ch2 = QLabel(self)
        self.label_graph_ch2.setFont(QtGui.QFont('Courier New', 20))
        self.label_graph_ch2.setText('Ch2')
        self.label_graph_ch2.move(int(5 * self.scale), int(405 * self.scale))

        # NOTE(review): if this class derives from QWidget, assigning to
        # self.font shadows the inherited font() accessor on this instance.
        # The attribute name is kept because code outside this method may
        # read it - confirm before renaming.
        self.font = QtGui.QFont()
        self.font.setPointSize(18)
        self.font.setBold(True)
        self.font.setWeight(75)

        # listOfGraphs[chanID][graphID]: one list of canvases per channel.
        self.listOfPredictionResults = []
        self.listOfGraphs = []
        self.listOfGraphs.append([])
        self.listOfGraphs.append([])

        for graphID in range(self.graphNum):

            self.listOfPredictionResults.append(PredictionResultLabel(self))
            predXLoc = (graphID * 300) + 125
            predYLoc = 90
            self.listOfPredictionResults[graphID].move(
                int(predXLoc * self.scale), int(predYLoc * self.scale))

            xLoc = (graphID * 300) + 50
            for chanID in range(2):
                yLoc = (chanID * 200) + 120
                self.listOfGraphs[chanID].append(
                    DynamicGraphCanvas(self,
                                       width=int(3 * self.scale),
                                       height=int(2 * self.scale),
                                       dpi=100))
                self.listOfGraphs[chanID][graphID].move(
                    int(xLoc * self.scale), int(yLoc * self.scale))

        self.setWindowTitle('Sleep stage classifier')
        xSize = self.graphNum * 310
        ySize = 550
        self.resize(int(xSize * self.scale), int(ySize * self.scale))
        self.show()
        self.activateWindow()
        statusbar = self.statusBar()
        self.readFromDaq = False
        try:
            if len(self.args) > 5:
                print('Too many arguments for running app.py.')
                # quit() is only injected by the site module and is missing
                # under ``python -S`` or in frozen builds; raising SystemExit
                # is what it does anyway.
                raise SystemExit
            self.sleepTime = float(
                self.args[4]) if len(self.args) > 4 else self.defaultSleepTime
            self.offsetWindowID = int(
                self.args[3]) if len(self.args) > 3 else 0
            if len(self.args) > 1:
                optionID = self.args[1]
                if optionID == 'm':
                    classifierID = selectClassifierID(
                        self.params.finalClassifierDir, self.classifier_type)
                    self.inputFileID = self.args[2] if len(
                        self.args) > 2 else self.randomlySelectInputFileID()
                    print('demo mode: reading inputFileID=', self.inputFileID)
                elif optionID == 'o':
                    classifierID = selectClassifierID(
                        self.params.finalClassifierDir, self.classifier_type)
                    self.inputFileID = ''
                else:
                    # Any other value is taken to be the classifier ID itself.
                    classifierID = optionID
                    self.inputFileID = ''
                self.client = ClassifierClient(self.recordWaves,
                                               self.extractorType,
                                               self.classifierType,
                                               classifierID, self.inputFileID,
                                               self.offsetWindowID)
            else:  # Neither classifierID nor inputFileID are specified.
                self.readFromDaq = True
                classifierID = selectClassifierID(
                    self.params.finalClassifierDir, self.classifier_type)
                print(
                    'Data is read from DAQ. classifier ID is randomly selected.'
                )
                self.client = ClassifierClient(self.recordWaves,
                                               self.extractorType,
                                               self.classifierType,
                                               classifierID)
            self.client.hasGUI = True
            # Initialise the y-max widgets from the client's current limits;
            # the sliders hold the limit multiplied by 10.
            self.ylim_value_eeg_box.setText(str(self.client.ylim_max_eeg))
            self.ylim_value_ch2_box.setText(str(self.client.ylim_max_ch2))
            self.ylim_slider_eeg.setValue(int(self.client.ylim_max_eeg * 10))
            self.ylim_slider_ch2.setValue(int(self.client.ylim_max_ch2 * 10))

        except Exception as e:
            print('Exception in self.client = ...')
            statusbar.showMessage(str(e))
            raise
Пример #4
0
    def serve(self):
        """Run the blocking TCP classification server on port 45123.

        Exactly one client is accepted. Incoming values are decoded with
        ``struct`` in native byte order; replies are little-endian.

        Handshake: the first message carries two unsigned 16-bit values, the
        observed sampling frequency and epoch time. ``selectClassifierID`` is
        asked for a matching classifier; the reply is the 2-byte value 1 when
        one is found, or 0 (after which this method returns) when the
        returned ID is -1.

        Afterwards every received message is dispatched on its length:

        * 0 bytes - the peer closed the connection; ``SystemExit`` is raised.
        * 1 byte  - connection test; the reply is ``b'Connection OK'``.
        * 6 bytes - command (uint32) followed by chamber ID (uint16). Command
          901 (re)creates the classifier client of that chamber. The reply is
          the 2-byte value 1 for a reset, 0 for any other command.
        * anything else - one epoch: chamber ID (uint16 at offset 0), epoch ID
          (uint32 at offset 2), start time (six uint16 and one uint32 at
          offset 6, passed positionally to ``datetime``) and the float
          samples from offset 22. The samples are resampled to the model's
          sample count and fed to the chamber's client in graph-update sized
          segments. The reply is chamber ID (2 bytes) + epoch ID (4 bytes) +
          judge code (2 bytes; w=0, n=1, r=2, anything else 3).

        Both sockets are closed when the method is left, however that happens.

        Raises:
            SystemExit: when the client disconnects.
        """

        PORT = 45123
        BUFSIZE = 10240 * 16  # must be this big to process 1024 Hz
        RESET_COMMAND = 901
        networkName = 'UTSN-L'
        # for encoding judge result to a number
        encode_judge = defaultdict(lambda: 3, w=0, n=1, r=2)
        ai_clients = {}  # chamberID -> classifier client

        tcpServSock = socket(AF_INET, SOCK_STREAM)
        tcp_client = None
        try:
            # bind, listen, and accept a client
            tcpServSock.bind(('', PORT))
            tcpServSock.listen(1)
            tcp_client, sendAddr = tcpServSock.accept()

            # NOTE(review): recv() may return fewer bytes than one whole
            # message. Message boundaries are not reconstructed here, so a
            # split packet would fail in unpack_from - confirm that the
            # client/transport always delivers whole messages.

            # for setting up sampling frequency and epoch width
            received_data = tcp_client.recv(BUFSIZE)
            observed_samplingFreq = struct.unpack_from('H', received_data,
                                                       0)[0]  # WORD
            observed_epochTime = struct.unpack_from('H', received_data,
                                                    2)[0]  # WORD
            print('observed_samplingFreq =', observed_samplingFreq)
            print('observed_epochTime =', observed_epochTime)
            observed_samplePointNum = observed_samplingFreq * observed_epochTime
            # one 'f' per sample; used for unpacking EEG from received data
            fmt = 'f' * observed_samplePointNum
            classifierID, model_samplingFreq, model_epochTime = selectClassifierID(
                self.params_for_classifier.finalClassifierDir, networkName,
                observed_samplingFreq, observed_epochTime)
            model_samplePointNum = model_samplingFreq * model_epochTime

            if classifierID == -1:
                # No usable classifier: tell the client and stop serving.
                tcp_client.sendall((0).to_bytes(2, 'little'))
                return
            tcp_client.sendall((1).to_bytes(2, 'little'))

            while True:
                received_data = tcp_client.recv(BUFSIZE)

                if len(received_data) == 0:
                    # Peer closed the connection. SystemExit is what the
                    # site-provided exit() raises, without depending on site.
                    raise SystemExit

                elif len(received_data) == 1:
                    # connection test (received 1 byte)
                    tcp_client.sendall('Connection OK'.encode('utf-8'))

                elif len(received_data) == 6:
                    # the received data is for resetting
                    commandID = struct.unpack_from('I', received_data,
                                                   0)[0]  # DWORD
                    chamberID = struct.unpack_from('H', received_data,
                                                   4)[0]  # WORD
                    if commandID == RESET_COMMAND:
                        print('resetting chamber', chamberID)
                        ai_clients[chamberID] = generateClassifier(
                            self.params_for_classifier, chamberID,
                            observed_samplingFreq, observed_epochTime)
                        reset_status = 1
                    else:
                        reset_status = 0
                    tcp_client.sendall(reset_status.to_bytes(2, 'little'))

                else:
                    # the received data is signal + metadata

                    # obtain the chamber number
                    chamberID = struct.unpack_from('H', received_data,
                                                   0)[0]  # WORD

                    # obtain the epoch number
                    epochID = struct.unpack_from('I', received_data,
                                                 2)[0]  # DWORD

                    # obtain the time that the record started
                    dt = struct.unpack_from('HHHHHHI', received_data, 6)
                    startDT = datetime(dt[0], dt[1], dt[2], dt[3], dt[4],
                                       dt[5], dt[6])

                    # EEG data
                    signalW = struct.unpack_from(fmt, received_data,
                                                 22)  # float
                    signal_rawarray = np.array(signalW, dtype='float64')
                    signal_rawarray = up_or_down_sampling(
                        signal_rawarray, model_samplePointNum,
                        observed_samplePointNum)

                    # generate a new classifierClient when new chamberID comes.
                    # The observed sampling frequency is passed, as on the
                    # reset path, so that generateClassifier repeats the same
                    # classifier lookup that was done in the handshake.
                    if chamberID not in ai_clients:
                        ai_clients[chamberID] = generateClassifier(
                            self.params_for_classifier, chamberID,
                            observed_samplingFreq, observed_epochTime)

                    # Loops because classifierClients accepts segments, not full epochs, in order to visualize waves in GUI.
                    # Before the final segment, judgeStr is '-'.
                    assert model_samplingFreq % self.params_for_classifier.graphUpdateFreqInHz == 0
                    # Builtin int: the np.int alias was removed from NumPy.
                    updateGraph_samplePointNum = int(
                        model_samplingFreq /
                        self.params_for_classifier.graphUpdateFreqInHz)
                    assert updateGraph_samplePointNum > 0

                    # Defined up front so that an empty signal still yields a
                    # reply (encoded as 3) instead of a NameError.
                    judgeStr = '-'
                    for startID in range(0, signal_rawarray.shape[0],
                                         updateGraph_samplePointNum):
                        dataToAIClient = formatRawArray(
                            startDT, model_samplingFreq,
                            signal_rawarray[startID:startID +
                                            updateGraph_samplePointNum])
                        judgeStr = ai_clients[chamberID].process(
                            dataToAIClient)

                    cByte = chamberID.to_bytes(2, 'little')
                    eByte = epochID.to_bytes(4, 'little')
                    jByte = encode_judge[judgeStr].to_bytes(2, 'little')

                    # return to the client
                    tcp_client.sendall(cByte + eByte + jByte)
        finally:
            if tcp_client is not None:
                tcp_client.close()
            tcpServSock.close()
Пример #5
0
    def start(self):
        """Run the classifier over every unprocessed file in the post directory.

        Settings are copied from a fresh ``ParameterSetup`` onto ``self``.
        Each entry of ``self.postDir`` whose name does not start with a dot
        is then examined: if ``<predDir>/<fileID>_pred.txt`` already exists
        the file is skipped, otherwise a ``ClassifierClient`` (GUI-less, in
        prediction state) and an ``EEGFileReaderServer`` for the file are
        created. The client receives ``inputFileID`` only when
        ``self.args[1]`` is ``'--output_the_same_fileID'``.

        Raises:
            Exception: anything raised while creating the client or the
                server is printed and re-raised.
        """
        params = ParameterSetup()
        copied_settings = (
            ('recordWaves', 'writeWholeWaves'),
            ('extractorType', 'extractorType'),
            ('classifierType', 'classifierType'),
            ('postDir', 'postDir'),
            ('predDir', 'predDir'),
            ('finalClassifierDir', 'finalClassifierDir'),
            ('samplingFreq', 'samplingFreq'),
        )
        for own_name, param_name in copied_settings:
            setattr(self, own_name, getattr(params, param_name))

        for inputFileName in listdir(self.postDir):
            # Hidden entries (dot files) are not input data.
            if inputFileName.startswith('.'):
                continue

            print('inputFileName = ' + inputFileName)
            inputFileID = splitext(inputFileName)[0]
            print('inputFileID = ' + inputFileID)
            predFileFullPath = self.predDir + '/' + inputFileID + '_pred.txt'
            print('predFileFullPath = ' + predFileFullPath)

            if isfile(predFileFullPath):
                print('  skipping ' + inputFileID + ' because ' +
                      predFileFullPath + ' exists.')
                continue

            print('  processing ' + inputFileID)
            try:
                classifierID = selectClassifierID(self.finalClassifierDir,
                                                  self.classifier_type)
                # The file ID is forwarded only on explicit request.
                extra_options = {}
                if (len(self.args) > 1
                        and self.args[1] == '--output_the_same_fileID'):
                    extra_options['inputFileID'] = inputFileID
                self.client = ClassifierClient(self.recordWaves,
                                               self.extractorType,
                                               self.classifierType,
                                               classifierID, **extra_options)
                self.client.predictionStateOn()
                self.client.hasGUI = False

                eegFilePath = self.postDir + '/' + inputFileName
                self.server = EEGFileReaderServer(
                    self.client, eegFilePath, samplingFreq=self.samplingFreq)
            except Exception as e:
                print(str(e))
                raise e