예제 #1
0
def run(file_flag):
    """
    Acquire streaming samples from Shimmer nodes over Bluetooth RFCOMM.

    The function scans for the hard-coded Shimmer addresses (or uses them
    directly when scanning is disabled), sends the set-sensors,
    set-sampling-rate and inquiry commands to every node, starts streaming
    and then reads samples in a loop until Ctrl+C is pressed.  Afterwards the
    stop-streaming command is sent and the output file and all sockets are
    closed.

    @params
    file_flag: 1 = write the node list to ../DATA/IMU_<timestamp>.log and
               every acquisition (one JSON line per iteration) to the
               matching .dat file; 0 = print one node's sample to stdout.

    The process exits with status 1 when scanning fails, no target node is
    found, a node cannot be connected/configured, or the output file / plot
    cannot be created.
    """
    # Shimmer MAC addresses
    shimm_addr = ["00:06:66:46:B7:D4"]
    emg_addr = []  # ["00:06:66:46:9A:1A", "00:06:66:46:BD:BF"]

    # Configuration parameters
    scan_flag = 1
    plot_flag = 0

    sock_port = 1
    nodes = []
    plt_axx = 500
    plt_ylim = 4000
    plt_rate = 20

    if plot_flag == 1:
        # plot parameters
        sample_idx = 0
        analogData = AnalogData(plt_axx)

    # Get the list of available nodes
    if scan_flag == 0:
        target_addr = shimm_addr
    else:
        try:
            target_addr = []
            print("Scanning bluetooth devices...")
            nearby_devices = bluetooth.discover_devices()
            for bdaddr in nearby_devices:
                # lookup_name() may return None; str() keeps the
                # concatenation from raising TypeError in that case.
                print("			" + str(bdaddr) + " - " +
                      str(bluetooth.lookup_name(bdaddr)))
                if bdaddr in shimm_addr:
                    target_addr.append(bdaddr)
        except Exception:
            # Not a bare except: Ctrl+C / SystemExit must not be reported
            # as a scanning problem.
            print("[Error] Problem while scanning bluetooth")
            sys.exit(1)

    n_nodes = len(target_addr)
    if n_nodes > 0:
        print(("Found %d target Shimmer nodes") % (len(target_addr)))
    else:
        print("Could not find target bluetooth device nearby. Exiting")
        sys.exit(1)

    print("Configuring the nodes...")
    for node_idx, bdaddr in enumerate(target_addr):
        try:
            # Connecting to the sensors
            sock = bluetooth.BluetoothSocket(bluetooth.RFCOMM)
            if bdaddr in emg_addr:
                n = shimmer_node(bdaddr, sock, 0x2)
            else:
                n = shimmer_node(bdaddr, sock, 0x1)
            nodes.append(n)

            print((bdaddr, sock_port), end=' ')
            nodes[-1].sock.connect((bdaddr, sock_port))
            print(" Shimmer %d (" % (node_idx) +
                  str(bluetooth.lookup_name(bdaddr)) + ") [Connected]")
            # send the set sensors command
            nodes[-1].sock.send(
                struct.pack('BBB', 0x08, nodes[-1].senscfg_hi,
                            nodes[-1].senscfg_lo))
            nodes[-1].wait_for_ack()

            # send the set sampling rate command
            nodes[-1].sock.send(struct.pack('BB', 0x05, 0x14))  # 51.2Hz
            nodes[-1].wait_for_ack()

            # Inquiry command
            print("	Shimmer %d (" % (node_idx) +
                  str(bluetooth.lookup_name(bdaddr)) + ") [Configured]")
            nodes[-1].sock.send(struct.pack('B', 0x01))
            nodes[-1].wait_for_ack()
            # The inquiry response is read from the socket but its content
            # is not used afterwards.
            nodes[-1].read_inquiry()
        except bluetooth.btcommon.BluetoothError as e:
            print(("BluetoothError during read_data: {0}".format(e.strerror)))
            print("Unable to connect to the nodes. Exiting")
            # Release the sockets opened so far before leaving.
            for shim in nodes:
                shim.sock.close()
            sys.exit(1)

    # Create file and plot
    try:
        if file_flag == 1:
            now = datetime.datetime.now()
            qc.make_dirs('../DATA')
            logname = "../DATA/IMU_" + now.strftime("%Y%m%d%H%M") + ".log"
            print("[cnbi_shimmer] Creating file: %s" % (logname))
            # The .log file only holds the node index -> address mapping.
            with open(logname, "w") as logfile:
                for node_idx, shim in enumerate(nodes):
                    logfile.write(
                        str(node_idx) + ": " + str(shim.addr) + "\n")

            fname = "../DATA/IMU_" + now.strftime("%Y%m%d%H%M") + ".dat"
            print("[cnbi_shimmer] Creating file: %s" % (fname))
            # Kept open for the whole acquisition; closed after the loop.
            outfile = open(fname, "w")

        # Create plot
        if plot_flag == 1:
            analogPlot = AnalogPlot(analogData)
            plt.axis([0, plt_axx, 0, plt_ylim])
            plt.ion()
            plt.show()
    except Exception:
        print("[Error]: Error creating file/plot!! Exiting")
        # close the socket
        print("Closing nodes")
        for node_idx, shim in enumerate(nodes):
            shim.sock.close()
            print("	Shimmer %d [Ok]" % (node_idx))
        sys.exit(1)

    print(
        "[cnbi_shimmer] Recording started. Press Ctrl+C to finish recording.")
    # send start streaming command
    for shim in nodes:
        shim.sock.send(struct.pack('B', 0x07))

    for node_idx, shim in enumerate(nodes):
        shim.wait_for_ack()
        shim.up = 1
        print("	Shimmer %d [Ok]" % (node_idx))

    # Main acquisition loop
    while True:
        try:
            sample = []
            sample_lslclock = []
            for shim in nodes:
                if shim.up == 1:
                    sample.append(shim.read_data())
                else:
                    sample.append([0] * (shim.n_fields))

            # Replace the first field of every node's sample with the
            # local LSL clock.
            for samp in sample:
                sample_lslclock.append([pylsl.local_clock()] + list(samp[1:]))

            if file_flag == 1:
                simplejson.dump(sample_lslclock,
                                outfile,
                                separators=(',', ';'))
                outfile.write('\n')

            if plot_flag == 1:
                analogData.add([sample[0][1], sample[0][2]])
                sample_idx = sample_idx + 1
                if sample_idx % plt_rate == 0:
                    analogPlot.update(analogData)

            if file_flag == 0:
                # Node index 1 is shown as before when it exists; with a
                # single connected node fall back to index 0 instead of
                # raising IndexError.
                shown_idx = min(1, len(sample_lslclock) - 1)
                print(qc.list2string(sample_lslclock[shown_idx], '%9.1f',
                                     ' '))

        # Exit if key is pressed
        except KeyboardInterrupt:
            print("\n[cnbi_shimmer] Stopping acquisition....")
            break
        except bluetooth.btcommon.BluetoothError as e:
            print(("[Error] BluetoothError during read_data: {0}".format(
                e.strerror)))

    # send stop streaming command
    print("[cnbi_shimmer] Stopping streaming")
    try:
        for shim in nodes:
            shim.sock.send(struct.pack('B', 0x20))
        for node_idx, shim in enumerate(nodes):
            shim.wait_for_ack()
            print("	Shimmer %d [Ok]" % (node_idx))
    except bluetooth.btcommon.BluetoothError as e:
        print(("[Error] BluetoothError during read_data: {0}".format(
            e.strerror)))

    # Closing	file
    if file_flag == 1:
        print("[cnbi_shimmer] Closing file: %s" % (fname))
        try:
            outfile.close()
        except Exception:
            print("			[Error] Problem closing file!")

    # close the socket
    print("[cnbi_shimmer] Closing nodes")
    for node_idx, shim in enumerate(nodes):
        shim.sock.close()
        print("	Shimmer %d [Ok]" % (node_idx))

    print("[cnbi_shimmer] Recording Finished. Please close this window.")
    getch()
예제 #2
0
    out_path = DATA_PATH + '/epochs'
    qc.make_dirs(out_path)

    # load data
    raw, events = pu.load_multi(rawlist, multiplier=MULTIPLIER)
    raw.pick_types(meg=False, eeg=True, stim=False)
    sfreq = raw.info['sfreq']
    if REF_CH_NEW is not None:
        pu.rereference(raw, REF_CH_NEW, REF_CH_OLD)

    # pick channels
    if CHANNEL_PICKS is None:
        picks = [raw.ch_names.index(c) for c in raw.ch_names if c not in EXCLUDES]
    elif type(CHANNEL_PICKS[0]) == str:
        picks = [raw.ch_names.index(c) for c in CHANNEL_PICKS]
    else:
        assert type(CHANNEL_PICKS[0]) is int
        picks = CHANNEL_PICKS

    # do epoching
    for epname, epval in EPOCHS.items():
        epochs = mne.Epochs(raw, events, dict(epname=epval), tmin=TMIN, tmax=TMAX,
                            proj=False, picks=picks, baseline=None, preload=True)
        data = epochs.get_data()  # epochs x channels x times
        for i, ep_data in enumerate(data):
            fout = '%s/%s-%d.txt' % (out_path, epname, i + 1)
            with open(fout, 'w') as f:
                for t in range(ep_data.shape[1]):
                    f.write(qc.list2string(ep_data[:, t], '%.6f') + '\n')
            logger.info('Exported %s' % fout)
예제 #3
0
    def classify(self,
                 decoder,
                 true_label,
                 title_text,
                 bar_dirs,
                 state='start',
                 prob_history=None):
        """
        Run a single trial.

        The trial is a small state machine driven by self.tm_trigger:
        'start' -> 'gap_s' -> 'gap' -> 'cue' -> 'dir_r' -> 'dir' ->
        'feedback' (-> 'return' when cfg.FEEDBACK_TYPE is 'BODY').
        During 'dir' the decoder probabilities move the feedback bar; when
        the classification time is over (or the bar is full and
        self.premature_end is set) the result is shown, logged and,
        depending on the configuration, sent to the STIMO / FES devices.

        @params
        decoder: object providing get_prob_smooth_unread(), which returns
                 per-direction probabilities or None when nothing new is
                 available.
        true_label: cued direction; must be an element of bar_dirs.
        title_text: text displayed while entering the gap period when
                    cfg.TIMINGS['GAP'] > 0.
        bar_dirs: list of direction labels, ordered like the decoder output.
        state: state in which the trial starts.
        prob_history: if not None, prob_history[true_label] receives the
                      probability of the true label for every new decoder
                      output.

        @return
        The label the bar points to at the end of the trial, or None when
        the escape key was pressed.
        """
        true_label_index = bar_dirs.index(true_label)
        self.tm_trigger.reset()
        if self.bar_bias is not None:
            bias_idx = bar_dirs.index(self.bar_bias[0])

        if self.logf is not None:
            self.logf.write('True label: %s\n' % true_label)

        tm_classify = qc.Timer(autoreset=True)
        self.stimo_timer = qc.Timer()
        while True:
            self.tm_display.sleep_atleast(self.refresh_delay)
            self.tm_display.reset()
            if state == 'start' and \
                    self.tm_trigger.sec() > self.cfg.TIMINGS['INIT']:
                state = 'gap_s'
                if self.cfg.TRIALS_PAUSE:
                    self.viz.put_text('Press any key')
                    self.viz.update()
                    key = cv2.waitKeyEx()
                    if key == KEYS['esc']:
                        return
                self.viz.fill()
                self.tm_trigger.reset()
                self.trigger.signal(self.tdef.INIT)

            elif state == 'gap_s':
                if self.cfg.TIMINGS['GAP'] > 0:
                    self.viz.put_text(title_text)
                state = 'gap'
                self.tm_trigger.reset()

            elif state == 'gap' and \
                    self.tm_trigger.sec() > self.cfg.TIMINGS['GAP']:
                state = 'cue'
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                self.trigger.signal(self.tdef.CUE)
                self.tm_trigger.reset()

            elif state == 'cue' and \
                    self.tm_trigger.sec() > self.cfg.TIMINGS['READY']:
                state = 'dir_r'

                if self.cfg.SHOW_CUE is True:
                    if self.cfg.FEEDBACK_TYPE == 'BAR':
                        self.viz.move(true_label,
                                      100,
                                      overlay=False,
                                      barcolor='G')
                    elif self.cfg.FEEDBACK_TYPE == 'BODY':
                        self.viz.put_text(DIRS[true_label], 'R')
                    if true_label == 'L':  # left
                        self.trigger.signal(self.tdef.LEFT_READY)
                    elif true_label == 'R':  # right
                        self.trigger.signal(self.tdef.RIGHT_READY)
                    elif true_label == 'U':  # up
                        self.trigger.signal(self.tdef.UP_READY)
                    elif true_label == 'D':  # down
                        self.trigger.signal(self.tdef.DOWN_READY)
                    elif true_label == 'B':  # both hands
                        self.trigger.signal(self.tdef.BOTH_READY)
                    else:
                        raise RuntimeError('Unknown direction %s' % true_label)
                self.tm_trigger.reset()

            elif state == 'dir_r' and \
                    self.tm_trigger.sec() > self.cfg.TIMINGS['DIR_CUE']:
                self.viz.fill()
                self.viz.draw_cue()
                self.viz.glass_draw_cue()
                state = 'dir'

                # initialize bar scores
                bar_label = bar_dirs[0]
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_label, bar_score, overlay=False)
                probs_acc = np.zeros(len(probs))

                if true_label == 'L':  # left
                    self.trigger.signal(self.tdef.LEFT_GO)
                elif true_label == 'R':  # right
                    self.trigger.signal(self.tdef.RIGHT_GO)
                elif true_label == 'U':  # up
                    self.trigger.signal(self.tdef.UP_GO)
                elif true_label == 'D':  # down
                    self.trigger.signal(self.tdef.DOWN_GO)
                elif true_label == 'B':  # both
                    self.trigger.signal(self.tdef.BOTH_GO)
                else:
                    raise RuntimeError('Unknown truedirection %s' % true_label)

                self.tm_watchdog.reset()
                self.tm_trigger.reset()

            elif state == 'dir':
                if self.tm_trigger.sec() > self.cfg.TIMINGS['CLASSIFY'] or (
                        self.premature_end and bar_score >= 100):
                    # The result colour is decided before branching on
                    # SHOW_RESULT: it used to be assigned only inside the
                    # SHOW_RESULT branch, so the other branch failed with
                    # UnboundLocalError.
                    if self.cfg.WITH_STIMO is True:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None and bar_label == 'U':
                            res_color = 'G'
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            res_color = 'G'
                        else:
                            res_color = 'Y'
                    else:
                        res_color = 'Y'

                    # A missing SHOW_RESULT attribute counts as True.
                    if getattr(self.cfg, 'SHOW_RESULT', True) is True:
                        # show classfication result
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption=DIRS[bar_label],
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          100,
                                          overlay=False,
                                          barcolor=res_color)
                    else:
                        if self.cfg.FEEDBACK_TYPE == 'BODY':
                            self.viz.move(bar_label,
                                          bar_score,
                                          overlay=False,
                                          barcolor=res_color,
                                          caption='TRIAL END',
                                          caption_color=res_color)
                        else:
                            self.viz.move(bar_label,
                                          0,
                                          overlay=False,
                                          barcolor=res_color)
                    self.trigger.signal(self.tdef.FEEDBACK)

                    # STIMO
                    if self.cfg.WITH_STIMO is True and self.cfg.STIMO_CONTINUOUS is False:
                        if self.cfg.STIMO_FULLGAIT_CYCLE is not None:
                            if bar_label == 'U':
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[0])
                                logger.info('STIMO: Sent 1')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                                self.ser.write(
                                    self.cfg.STIMO_FULLGAIT_PATTERN[1])
                                logger.info('STIMO: Sent 2')
                                time.sleep(self.cfg.STIMO_FULLGAIT_CYCLE)
                        elif self.cfg.TRIALS_RETRY is False or bar_label == true_label:
                            if bar_label == 'L':
                                self.ser.write(b'1')
                                logger.info('STIMO: Sent 1')
                            elif bar_label == 'R':
                                self.ser.write(b'2')
                                logger.info('STIMO: Sent 2')

                    # FES event mode mode
                    if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is False:
                        if bar_label == 'L':
                            stim_code = [0, 30, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            logger.info('FES: Sent Left')
                            time.sleep(0.5)
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                        elif bar_label == 'R':
                            stim_code = [30, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)
                            time.sleep(0.5)
                            logger.info('FES: Sent Right')
                            stim_code = [0, 0, 0, 0, 0, 0, 0, 0]
                            self.stim.UpdateChannelSettings(stim_code)

                    if self.cfg.DEBUG_PROBS:
                        msg = 'DEBUG: Accumulated probabilities = %s' % qc.list2string(
                            probs_acc, '%.3f')
                        logger.info(msg)
                        if self.logf is not None:
                            self.logf.write(msg + '\n')
                    if self.logf is not None:
                        self.logf.write('%s detected as %s (%d)\n\n' %
                                        (true_label, bar_label, bar_score))
                        self.logf.flush()

                    # end of trial
                    state = 'feedback'
                    self.tm_trigger.reset()
                else:
                    # classify
                    probs_new = decoder.get_prob_smooth_unread()
                    if probs_new is None:
                        if self.tm_watchdog.sec() > 3:
                            logger.warning(
                                'No classification being done. Are you receiving data streams?'
                            )
                            self.tm_watchdog.reset()
                    else:
                        self.tm_watchdog.reset()

                        if prob_history is not None:
                            prob_history[true_label].append(
                                probs_new[true_label_index])

                        probs_acc += np.array(probs_new)

                        # New decoder: already smoothed by the decoder so
                        # the bias is applied afterwards and the result is
                        # renormalized to sum to 1.
                        probs = list(probs_new)
                        if self.bar_bias is not None:
                            probs[bias_idx] += self.bar_bias[1]
                            newsum = sum(probs)
                            probs = [p / newsum for p in probs]

                        # determine the direction
                        # TODO: np.argmax(probs)
                        max_pidx = qc.get_index_max(probs)
                        max_label = bar_dirs[max_pidx]

                        # With POSITIVE_FEEDBACK the bar only moves when the
                        # winning direction is the cued one.
                        if self.cfg.POSITIVE_FEEDBACK is False or \
                                (self.cfg.POSITIVE_FEEDBACK and true_label == max_label):
                            dx = probs[max_pidx]
                            if max_label == 'R':
                                dx *= self.bar_step_right
                            elif max_label == 'L':
                                dx *= self.bar_step_left
                            elif max_label == 'U':
                                dx *= self.bar_step_up
                            elif max_label == 'D':
                                dx *= self.bar_step_down
                            elif max_label == 'B':
                                dx *= self.bar_step_both
                            else:
                                logger.debug('Direction %s using bar step %d' %
                                             (max_label, self.bar_step_left))
                                dx *= self.bar_step_left

                            # slow start
                            # The selected entry is indexed with [0] for the
                            # scaling, so the elapsed time is compared with
                            # that same element (comparing a float with the
                            # whole sequence raises TypeError on Python 3).
                            selected = self.cfg.BAR_SLOW_START['selected']
                            slow_start = self.cfg.BAR_SLOW_START[selected]
                            if slow_start and \
                                    self.tm_trigger.sec() < slow_start[0]:
                                dx *= self.tm_trigger.sec() / slow_start[0]

                            # add likelihoods
                            if max_label == bar_label:
                                bar_score += dx
                            else:
                                bar_score -= dx
                                # change of direction
                                if bar_score < 0:
                                    bar_score = -bar_score
                                    bar_label = max_label
                            bar_score = int(bar_score)
                            if bar_score > 100:
                                bar_score = 100
                            if self.cfg.FEEDBACK_TYPE == 'BODY':
                                if self.cfg.SHOW_CUE:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False,
                                                  caption=DIRS[true_label],
                                                  caption_color='G')
                                else:
                                    self.viz.move(bar_label,
                                                  bar_score,
                                                  overlay=False)
                            else:
                                self.viz.move(bar_label,
                                              bar_score,
                                              overlay=False)

                            # send the confidence value continuously
                            if self.cfg.WITH_STIMO and self.cfg.STIMO_CONTINUOUS:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'U':
                                        stimo_code = bar_score
                                    else:
                                        stimo_code = 0
                                    self.ser.write(bytes([stimo_code]))
                                    logger.info('Sent STIMO code %d' %
                                                stimo_code)
                                    self.stimo_timer.reset()
                            # with FES
                            if self.cfg.WITH_FES is True and self.cfg.FES_CONTINUOUS is True:
                                if self.stimo_timer.sec() >= self.cfg.STIMO_COOLOFF:
                                    if bar_label == 'L':
                                        stim_code = [
                                            bar_score, 0, 0, 0, 0, 0, 0, 0
                                        ]
                                    else:
                                        stim_code = [
                                            0, bar_score, 0, 0, 0, 0, 0, 0
                                        ]
                                    self.stim.UpdateChannelSettings(stim_code)
                                    logger.info('Sent FES code %d' % bar_score)
                                    self.stimo_timer.reset()

                        if self.cfg.DEBUG_PROBS:
                            if self.bar_bias is not None:
                                biastxt = '[Bias=%s%.3f]  ' % (
                                    self.bar_bias[0], self.bar_bias[1])
                            else:
                                biastxt = ''
                            msg = '%s%s  prob %s   acc %s   bar %s%d  (%.1f ms)' % \
                                  (biastxt, bar_dirs, qc.list2string(probs_new, '%.2f'), qc.list2string(probs, '%.2f'),
                                   bar_label, bar_score, tm_classify.msec())
                            logger.info(msg)
                            if self.logf is not None:
                                self.logf.write(msg + '\n')

            elif state == 'feedback' and \
                    self.tm_trigger.sec() > self.cfg.TIMINGS['FEEDBACK']:
                self.trigger.signal(self.tdef.BLANK)
                if self.cfg.FEEDBACK_TYPE == 'BODY':
                    state = 'return'
                    self.tm_trigger.reset()
                else:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            elif state == 'return':
                # Move the bar back in steps of 5 per refresh until it
                # reaches zero, then finish the trial.
                self.viz.set_glass_feedback(False)
                if self.cfg.WITH_STIMO:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='B')
                else:
                    self.viz.move(bar_label,
                                  bar_score,
                                  overlay=False,
                                  barcolor='Y')
                self.viz.set_glass_feedback(True)
                bar_score -= 5
                if bar_score <= 0:
                    state = 'gap_s'
                    self.viz.fill()
                    self.viz.update()
                    return bar_label

            self.viz.update()
            key = cv2.waitKeyEx(1)
            if key == KEYS['esc']:
                return
            elif key == KEYS['space']:
                dx = 0
                bar_score = 0
                probs = [1.0 / len(bar_dirs)] * len(bar_dirs)
                self.viz.move(bar_dirs[0], bar_score, overlay=False)
                self.viz.update()
                logger.info('probs and dx reset.')
                self.tm_trigger.reset()
            elif key in ARROW_KEYS and ARROW_KEYS[key] in bar_dirs:
                # change bias on the fly
                if self.bar_bias is None:
                    self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                else:
                    if ARROW_KEYS[key] == self.bar_bias[0]:
                        self.bar_bias[1] += BIAS_INCREMENT
                    elif self.bar_bias[1] >= BIAS_INCREMENT:
                        self.bar_bias[1] -= BIAS_INCREMENT
                    else:
                        self.bar_bias = [ARROW_KEYS[key], BIAS_INCREMENT]
                if self.bar_bias[1] == 0:
                    self.bar_bias = None
                else:
                    bias_idx = bar_dirs.index(self.bar_bias[0])
예제 #4
0
def test_receiver():
    """
    Endlessly print what a StreamReceiver delivers from an LSL server.

    For every acquired window the function reports the difference between
    the local LSL clock and the newest timestamp, the unique trigger
    channel values among the samples newer than the previous iteration,
    and either the per-channel mean of the monitored channels or their
    value at a fixed time index.  Optionally the mean PSD per channel is
    printed as well.  The loop never terminates on its own.
    """
    import mne
    import os

    monitored_channels = [1]  # channel to monitor
    time_index = None  # integer or None. None = average of raw values of the current window
    show_psd = False
    mne.set_log_level('ERROR')
    os.environ['OMP_NUM_THREADS'] = '1'  # actually improves performance for multitaper

    # connect to LSL server
    amp_name, amp_serial = pu.search_lsl()
    receiver = StreamReceiver(window_size=1,
                              buffer_size=1,
                              amp_serial=amp_serial,
                              eeg_only=False,
                              amp_name=amp_name)
    sfreq = receiver.get_sample_rate()
    trigger_channel = receiver.get_trigger_channel()
    logger.info('Trigger channel = %d' % trigger_channel)

    # PSD init
    if show_psd:
        psd_estimator = mne.decoding.PSDEstimator(sfreq=sfreq,
                                                  fmin=1,
                                                  fmax=50,
                                                  bandwidth=None,
                                                  adaptive=False,
                                                  low_bias=True,
                                                  n_jobs=1,
                                                  normalization='length',
                                                  verbose=None)

    elapsed = qc.Timer()
    pacing = qc.Timer(autoreset=True)
    newest_seen = 0
    while True:
        receiver.acquire()
        samples, timestamps = receiver.get_window()  # [samples x channels]
        window = samples.T  # channels x samples

        lsl_lag = pylsl.local_clock() - timestamps[-1]
        qc.print_c('LSL Diff = %.3f' % lsl_lag, 'G')

        # indices of the samples that arrived since the last iteration
        fresh = np.where(np.array(timestamps) > newest_seen)[0]
        if len(fresh) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        # print event values
        trigger = np.unique(window[trigger_channel, fresh[0]:])
        if len(trigger) > 0:
            logger.info('Triggers: %s' % np.array(trigger))

        logger.info('[%.1f] Receiving data...' % elapsed.sec())

        if time_index is not None:
            # value of the monitored channels at one fixed sample
            datatxt = qc.list2string(window[monitored_channels, time_index],
                                     '%-15.6f')
            print('[%.3f]' % timestamps[time_index] + ' data: %s' % datatxt)
        else:
            # mean of the monitored channels over the whole window
            channel_means = np.mean(window[monitored_channels, :], axis=1)
            datatxt = qc.list2string(channel_means, '%-15.6f')
            print('[%.3f : %.3f]' % (timestamps[0], timestamps[-1]) +
                  ' data: %s' % datatxt)

        # show PSD
        if show_psd:
            n_channels, n_samples = window.shape[0], window.shape[1]
            psd = psd_estimator.transform(
                window.reshape((1, n_channels, n_samples)))
            psd = psd.reshape((psd.shape[1], psd.shape[2]))
            for channel_power in np.mean(psd, axis=1):
                print('%.1f' % channel_power, end=' ')

        newest_seen = timestamps[-1]
        pacing.sleep_atleast(0.05)
예제 #5
0
def feature_importances_topo(featfile,
                             topo_layout_file=None,
                             channels=None,
                             channel_name_show=None):
    """
    Compute feature importances across frequency bands and channels

    The feature file is read as tab-separated text: the first line is
    skipped as a header and every following non-blank line supplies the
    importance (column 0), the channel name (column 1) and the frequency
    (column 2). A text summary table is printed and also written to
    '<dir>/<name>_summary.txt', where dir and name come from
    qc.parse_path(featfile).

    @params
    featfile: path of the feature importance file.
    topo_layout_file: if not None, topography map images will be generated
        via export_topo(). If the path does not exist as given, it is looked
        up under NEUROD_ROOT + '/layout/'. The layout is read as
        tab-separated text with x in column 1, y in column 2 and the channel
        name in column 5.
    channels: list of channel names defining the column order. If None, the
        channels found in featfile are used in order of first appearance.
    channel_name_show: list of channel names to show on topography map.
        Defaults to all channels.

    @raises
    FileNotFoundError: if topo_layout_file cannot be found in either place.
    KeyError: if featfile (or the layout) lacks a channel that is required.
    """
    logger.info('Loading %s' % featfile)

    # Parse the file once; blank lines (e.g. a trailing newline) are ignored.
    rows = []
    with open(featfile) as f:
        f.readline()  # header
        for line in f:
            line = line.strip()
            if not line:
                continue
            token = line.split('\t')
            rows.append((float(token[0]), token[1], float(token[2])))

    if channels is None:
        # Keep first-seen order so that the table columns are reproducible
        # between runs (iteration order of a set of strings is not).
        channels = []
        seen = set()
        for _, ch, _ in rows:
            if ch not in seen:
                seen.add(ch)
                channels.append(ch)

    # channel index lookup table
    ch2index = {ch: i for i, ch in enumerate(channels)}

    n_channels = len(channels)
    data_delta = np.zeros(n_channels)
    data_theta = np.zeros(n_channels)
    data_mu = np.zeros(n_channels)
    data_beta = np.zeros(n_channels)
    data_beta1 = np.zeros(n_channels)
    data_beta2 = np.zeros(n_channels)
    data_beta3 = np.zeros(n_channels)
    data_lgamma = np.zeros(n_channels)
    data_hgamma = np.zeros(n_channels)
    data_per_ch = np.zeros(n_channels)

    for importance, ch, fq in rows:
        idx = ch2index[ch]

        # main bands: every frequency falls into exactly one of them
        if fq <= 3:
            data_delta[idx] += importance
        elif fq <= 7:
            data_theta[idx] += importance
        elif fq <= 12:
            data_mu[idx] += importance
        elif fq <= 30:
            data_beta[idx] += importance
        elif fq <= 70:
            data_lgamma[idx] += importance
        else:
            data_hgamma[idx] += importance

        # beta sub-bands: each arm needs its own lower bound, otherwise
        # everything below 12.5 Hz would be counted as beta2.
        if 12.5 <= fq <= 16:
            data_beta1[idx] += importance
        elif 16 < fq <= 20:
            data_beta2[idx] += importance
        elif 20 < fq <= 28:
            data_beta3[idx] += importance

        data_per_ch[idx] += importance

    # The plain 'beta' band is accumulated (and plotted below) but, as
    # before, only its three sub-bands are listed in the text summary.
    summary_rows = [
        ('delta   ', data_delta),
        ('theta   ', data_theta),
        ('mu      ', data_mu),
        ('beta1   ', data_beta1),
        ('beta2   ', data_beta2),
        ('beta3   ', data_beta3),
        ('lgamma  ', data_lgamma),
        ('hgamma  ', data_hgamma),
    ]
    hlen = 18 + n_channels * 7
    result = '>> Feature importance distribution\n'
    result += 'bands   ' + qc.list2string(channels,
                                          '%6s') + ' | ' + 'per band\n'
    result += '-' * hlen + '\n'
    for label, data in summary_rows:
        result += label + qc.list2string(
            data, '%6.2f') + ' | %6.2f\n' % np.sum(data)
    result += '-' * hlen + '\n'
    # NOTE(review): the grand total is printed as a fixed 100.00, which
    # assumes the importances in featfile are percentages - confirm.
    result += 'per_ch  ' + qc.list2string(data_per_ch, '%6.2f') + ' | 100.00\n'
    print(result)
    p = qc.parse_path(featfile)
    with open('%s/%s_summary.txt' % (p.dir, p.name), 'w') as fout:
        fout.write(result)

    # export topo maps
    if topo_layout_file is not None:
        # select channel names to show
        if channel_name_show is None:
            channel_name_show = channels
        chan_vis = [''] * len(channels)
        for ch in channel_name_show:
            chan_vis[channels.index(ch)] = ch

        # set channel locations and reverse lookup table
        chanloc = {}
        if not os.path.exists(topo_layout_file):
            topo_layout_file = NEUROD_ROOT + '/layout/' + topo_layout_file
            if not os.path.exists(topo_layout_file):
                raise FileNotFoundError('Layout file %s not found.' %
                                        topo_layout_file)
        logger.info('Using layout %s' % topo_layout_file)
        with open(topo_layout_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                token = line.split('\t')
                chanloc[token[5]] = [float(token[1]), float(token[2])]
        pos = np.zeros((len(channels), 2))
        for i, ch in enumerate(channels):
            pos[i] = chanloc[ch]

        total = np.sum(data_per_ch)

        def _rate(data):
            # Share of the total importance in percent; 0 when there is
            # nothing to divide by.
            if total == 0:
                return 0.0
            return np.sum(data) * 100.0 / total

        rate_delta = _rate(data_delta)
        rate_theta = _rate(data_theta)
        rate_mu = _rate(data_mu)
        rate_beta = _rate(data_beta)
        rate_beta1 = _rate(data_beta1)
        rate_beta2 = _rate(data_beta2)
        rate_beta3 = _rate(data_beta3)
        rate_lgamma = _rate(data_lgamma)
        rate_hgamma = _rate(data_hgamma)
        export_topo(data_per_ch,
                    pos,
                    'features_topo_all.png',
                    xlabel='all bands 1-40 Hz',
                    chan_vis=chan_vis)
        export_topo(data_delta,
                    pos,
                    'features_topo_delta.png',
                    xlabel='delta 1-3 Hz (%.1f%%)' % rate_delta,
                    chan_vis=chan_vis)
        export_topo(data_theta,
                    pos,
                    'features_topo_theta.png',
                    xlabel='theta 4-7 Hz (%.1f%%)' % rate_theta,
                    chan_vis=chan_vis)
        export_topo(data_mu,
                    pos,
                    'features_topo_mu.png',
                    xlabel='mu 8-12 Hz (%.1f%%)' % rate_mu,
                    chan_vis=chan_vis)
        export_topo(data_beta,
                    pos,
                    'features_topo_beta.png',
                    xlabel='beta 13-30 Hz (%.1f%%)' % rate_beta,
                    chan_vis=chan_vis)
        export_topo(data_beta1,
                    pos,
                    'features_topo_beta1.png',
                    xlabel='beta 12.5-16 Hz (%.1f%%)' % rate_beta1,
                    chan_vis=chan_vis)
        export_topo(data_beta2,
                    pos,
                    'features_topo_beta2.png',
                    xlabel='beta 16-20 Hz (%.1f%%)' % rate_beta2,
                    chan_vis=chan_vis)
        export_topo(data_beta3,
                    pos,
                    'features_topo_beta3.png',
                    xlabel='beta 20-28 Hz (%.1f%%)' % rate_beta3,
                    chan_vis=chan_vis)
        export_topo(data_lgamma,
                    pos,
                    'features_topo_lowgamma.png',
                    xlabel='low gamma 31-40 Hz (%.1f%%)' % rate_lgamma,
                    chan_vis=chan_vis)