Esempio n. 1
0
class BCIDecoder(object):
    """
    Decoder class

    The label order of self.labels and self.label_names matches the order of the
    likelihoods computed by get_prob().

    In fake mode no classifier or stream receiver is created; get_prob() returns
    random likelihoods for the hard-coded labels [11, 9] ('LEFT_GO', 'RIGHT_GO').
    """

    def __init__(self, classifier=None, buffer_size=1.0, fake=False, amp_serial=None, amp_name=None):
        """
        Params
        ------
        classifier: classifier file, loaded with qc.load_obj()
        buffer_size: length of the PSD buffer in seconds
        fake: if True, skip model/stream setup and emulate a left-right decoder
        amp_serial: amplifier serial, passed to StreamReceiver
        amp_name: amplifier name, passed to StreamReceiver

        Raises
        ------
        SystemExit: the classifier file could not be loaded
        RuntimeError: int(sfreq * w_seconds) != w_frames in the model, or the
            amplifier sampling rate differs from the model sampling rate
        """
        self.classifier = classifier
        self.buffer_sec = buffer_size
        self.fake = fake
        self.amp_serial = amp_serial
        self.amp_name = amp_name

        if not self.fake:
            model = qc.load_obj(self.classifier)
            if model is None:
                # Report the file name; `model` is None at this point and says nothing useful.
                self.print('Error loading %s' % self.classifier)
                sys.exit(-1)
            self.cls = model['cls']
            self.psde = model['psde']
            self.labels = list(self.cls.classes_)
            self.label_names = [model['classes'][k] for k in self.labels]
            self.spatial = model['spatial']
            self.spectral = model['spectral']
            self.notch = model['notch']
            self.w_seconds = model['w_seconds']
            self.w_frames = model['w_frames']
            self.wstep = model['wstep']
            self.sfreq = model['sfreq']
            if not int(self.sfreq * self.w_seconds) == self.w_frames:
                raise RuntimeError('sfreq * w_sec %d != w_frames %d' % (int(self.sfreq * self.w_seconds), self.w_frames))

            # Models saved without a multiplier use a scale of 1.
            self.multiplier = model.get('multiplier', 1)

            # Stream Receiver
            self.sr = StreamReceiver(window_size=self.w_seconds, amp_name=self.amp_name, amp_serial=self.amp_serial)
            if self.sfreq != self.sr.sample_rate:
                raise RuntimeError('Amplifier sampling rate (%.1f) != model sampling rate (%.1f). Stop.' % (
                    self.sr.sample_rate, self.sfreq))

            # Map channel indices stored in the model onto the channel order of the
            # streaming server, matching by channel name.
            self.ch_names = self.sr.get_channel_names()
            mc = model['ch_names']

            def stream_index(model_index):
                # model index -> channel name -> index in the stream
                return self.ch_names.index(mc[model_index])

            def stream_indices(model_indices):
                # None means "not configured" and is kept as None.
                if model_indices is None:
                    return None
                return [stream_index(p) for p in model_indices]

            self.picks = [stream_index(p) for p in model['picks']]
            self.spatial_ch = stream_indices(model['spatial_ch'])
            self.spectral_ch = stream_indices(model['spectral_ch'])
            self.notch_ch = stream_indices(model['notch_ch'])
            self.ref_new = None if model['ref_new'] is None else stream_index(model['ref_new'])
            self.ref_old = None if model['ref_old'] is None else stream_index(model['ref_old'])

            # PSD buffer: probe the estimator with an all-zero window to learn its output shape.
            psd_temp = self.psde.transform(np.zeros((1, len(self.picks), self.w_frames)))
            self.psd_shape = psd_temp.shape
            self.psd_size = psd_temp.size
            self.psd_buffer = np.zeros((0, self.psd_shape[1], self.psd_shape[2]))
            self.ts_buffer = []

        else:
            # Fake left-right decoder
            self.psd_shape = None
            self.psd_size = None
            # TODO: parameterize directions using fake_dirs
            self.labels = [11, 9]
            self.label_names = ['LEFT_GO', 'RIGHT_GO']

    def print(self, *args):
        """Print args prefixed with '[BCIDecoder] ' (a bare call prints an empty line)."""
        if len(args) > 0: print('[BCIDecoder] ', end='')
        print(*args)

    def get_labels(self):
        """
        Returns
        -------
        Class labels numbers in the same order as the likelihoods returned by get_prob()
        """
        return self.labels

    def get_label_names(self):
        """
        Returns
        -------
        Class label names in the same order as get_labels()
        """
        return self.label_names

    def start(self):
        pass

    def stop(self):
        pass

    def get_prob(self):
        """
        Read the latest window and classify it.

        In real mode the computed PSD is also appended to the PSD buffer
        (see get_psd()).

        Returns
        -------
        The likelihood P(X|C), where X=window, C=model
        """
        if self.fake:
            # fake decoder: random likelihood for the first class,
            # the remaining probability mass split equally among the other classes
            p_first = random.uniform(0.0, 1.0)
            n_others = len(self.labels) - 1
            probs = [p_first] + [(1 - p_first) / n_others] * n_others
            time.sleep(0.0625)  # simulated delay for PSD + RF
        else:
            self.sr.acquire()
            w, ts = self.sr.get_window()  # w = times x channels
            w = w.T  # -> channels x times

            # apply filters (return value unused, so presumably in place).
            # Important: maintain the original channel order at this point.
            pu.preprocess(w, sfreq=self.sfreq, spatial=self.spatial, spatial_ch=self.spatial_ch,
                          spectral=self.spectral, spectral_ch=self.spectral_ch, notch=self.notch,
                          notch_ch=self.notch_ch, multiplier=self.multiplier)

            # select the same channels used for training
            w = w[self.picks]

            # psd = 1 x channels x freqs
            psd = self.psde.transform(w.reshape((1, w.shape[0], w.shape[1])))

            self._update_psd_buffer(psd, ts[0])

            # make a feature vector and classify
            feats = np.concatenate(psd[0]).reshape(1, -1)

            # compute likelihoods
            probs = self.cls.predict_proba(feats)[0]

        return probs

    def _update_psd_buffer(self, psd, t_latest):
        """
        Append `psd` to the PSD buffer and drop entries older than self.buffer_sec.

        psd: array stacked onto self.psd_buffer along axis 0 (1 x channels x freqs in get_prob())
        t_latest: timestamp of `psd`; timestamps must be non-decreasing because
            np.searchsorted requires a sorted buffer.
        """
        # < 1 msec overhead
        self.psd_buffer = np.concatenate((self.psd_buffer, psd), axis=0)
        self.ts_buffer.append(t_latest)
        if t_latest - self.ts_buffer[0] > self.buffer_sec:
            # keep only entries with timestamp >= t_latest - buffer_sec
            # search speed comparison for ordered arrays:
            # http://stackoverflow.com/questions/16243955/numpy-first-occurence-of-value-greater-than-existing-value
            t_index = np.searchsorted(self.ts_buffer, t_latest - self.buffer_sec)
            self.ts_buffer = self.ts_buffer[t_index:]
            self.psd_buffer = self.psd_buffer[t_index:, :, :]  # numpy delete is slower

    def get_prob_unread(self):
        return self.get_prob()

    def get_psd(self):
        """
        Returns
        -------
        The latest computed PSD, flattened to shape (1, -1).
        Only available in real mode after at least one get_prob() call.
        """
        return self.psd_buffer[-1].reshape((1, -1))

    def is_ready(self):
        """
        Ready to decode? Returns True if buffer is not empty.
        A fake decoder is always ready.
        """
        if self.fake:
            return True
        return self.sr.is_ready()
Esempio n. 2
0
class BCIDecoder(object):
    """
    Decoder class

    The label order of self.labels and self.label_names matches the order of the
    likelihoods computed by get_prob().

    In fake mode no classifier or stream receiver is created; get_prob() returns
    random likelihoods for the hard-coded labels [11, 9] ('LEFT_GO', 'RIGHT_GO').
    """

    def __init__(self, classifier=None, buffer_size=1.0, fake=False, amp_serial=None, amp_name=None):
        """
        Params
        ------
        classifier: classifier file, loaded with qc.load_obj()
        buffer_size: intended length of the PSD buffer in seconds (the buffer is
            currently disabled, see get_psd())
        fake: if True, skip model/stream setup and emulate a left-right decoder
        amp_serial: amplifier serial, passed to StreamReceiver
        amp_name: amplifier name, passed to StreamReceiver

        Raises
        ------
        ValueError: the classifier file could not be loaded
        RuntimeError: round(sfreq * w_seconds) != w_frames in the model, or the
            amplifier sampling rate differs from the model sampling rate
        """
        self.classifier = classifier
        self.buffer_sec = buffer_size
        self.fake = fake
        self.amp_serial = amp_serial
        self.amp_name = amp_name

        if not self.fake:
            model = qc.load_obj(self.classifier)
            if model is None:
                logger.error('Classifier model is None.')
                raise ValueError('Could not load classifier %s' % self.classifier)
            self.cls = model['cls']
            self.psde = model['psde']
            self.labels = list(self.cls.classes_)
            self.label_names = [model['classes'][k] for k in self.labels]
            self.spatial = model['spatial']
            self.spectral = model['spectral']
            self.notch = model['notch']
            self.w_seconds = model['w_seconds']
            self.w_frames = model['w_frames']
            self.wstep = model['wstep']
            self.sfreq = model['sfreq']
            # Models saved without a 'decim' entry default to no decimation.
            self.decim = model.get('decim', 1)

            w_frames_expected = int(round(self.sfreq * self.w_seconds))
            if not w_frames_expected == self.w_frames:
                msg = 'sfreq * w_sec %d != w_frames %d' % (w_frames_expected, self.w_frames)
                logger.error(msg)
                raise RuntimeError(msg)

            # Models saved without a multiplier use a scale of 1.
            self.multiplier = model.get('multiplier', 1)

            # Stream Receiver
            self.sr = StreamReceiver(window_size=self.w_seconds, amp_name=self.amp_name, amp_serial=self.amp_serial)
            if self.sfreq != self.sr.sample_rate:
                msg = 'Amplifier sampling rate (%.3f) != model sampling rate (%.3f). Stop.' % (self.sr.sample_rate, self.sfreq)
                logger.error(msg)
                raise RuntimeError(msg)

            # Map channel indices stored in the model onto the channel order of the
            # streaming server, matching by channel name.
            #self.ref_ch = model['ref_ch'] # not supported yet
            self.ch_names = self.sr.get_channel_names()
            mc = model['ch_names']

            def stream_indices(model_indices):
                # model index -> channel name -> index in the stream; None is kept as None.
                if model_indices is None:
                    return None
                return [self.ch_names.index(mc[p]) for p in model_indices]

            self.picks = [self.ch_names.index(mc[p]) for p in model['picks']]
            self.spatial_ch = stream_indices(model['spatial_ch'])
            self.spectral_ch = stream_indices(model['spectral_ch'])
            self.notch_ch = stream_indices(model['notch_ch'])

            # PSD buffering is disabled while under testing (see get_psd());
            # ts_buffer is kept for compatibility but is not filled.
            self.ts_buffer = []

            logger.info_green('Loaded classifier %s (sfreq=%.3f, decim=%d)' % (' vs '.join(self.label_names), self.sfreq, self.decim))
        else:
            # Fake left-right decoder
            self.psd_shape = None
            self.psd_size = None
            # TODO: parameterize directions using fake_dirs
            self.labels = [11, 9]
            self.label_names = ['LEFT_GO', 'RIGHT_GO']

    def get_labels(self):
        """
        Returns
        -------
        Class labels numbers in the same order as the likelihoods returned by get_prob()
        """
        return self.labels

    def get_label_names(self):
        """
        Returns
        -------
        Class label names in the same order as get_labels()
        """
        return self.label_names

    def start(self):
        pass

    def stop(self):
        pass

    def get_prob(self, timestamp=False):
        """
        Read the latest window and classify it.

        Input
        -----
        timestamp: If True, also return the LSL timestamp of the leading edge of the
            window used for decoding (the local LSL clock in fake mode).

        Returns
        -------
        The likelihood P(X|C), where X=window, C=model, or (likelihoods, timestamp)
        if timestamp is True.
        """
        if self.fake:
            # fake decoder: random likelihood for the first class,
            # the remaining probability mass split equally among the other classes
            p_first = random.uniform(0.0, 1.0)
            n_others = len(self.labels) - 1
            probs = [p_first] + [(1 - p_first) / n_others] * n_others
            time.sleep(0.0625)  # simulated delay
            t_prob = pylsl.local_clock()
        else:
            self.sr.acquire(blocking=True)
            w, ts = self.sr.get_window()  # w = times x channels
            t_prob = ts[-1]
            w = w.T  # -> channels x times

            # re-reference channels
            # TODO: add re-referencing function to preprocess()

            # apply filters. Important: maintain the original channel order at this point.
            w = pu.preprocess(w, sfreq=self.sfreq, spatial=self.spatial, spatial_ch=self.spatial_ch,
                              spectral=self.spectral, spectral_ch=self.spectral_ch, notch=self.notch,
                              notch_ch=self.notch_ch, multiplier=self.multiplier, decim=self.decim)

            # select the same channels used for training
            w = w[self.picks]

            # psd = 1 x channels x freqs
            psd = self.psde.transform(w.reshape((1, w.shape[0], w.shape[1])))

            # make a feature vector and classify
            feats = np.concatenate(psd[0]).reshape(1, -1)

            # compute likelihoods
            probs = self.cls.predict_proba(feats)[0]

            # NOTE: PSD buffering (appending psd and trimming entries older than
            # self.buffer_sec via np.searchsorted) is disabled while under testing.

        if timestamp:
            return probs, t_prob
        return probs

    def get_prob_unread(self, timestamp=False):
        return self.get_prob(timestamp)

    def get_psd(self):
        """
        Returns
        -------
        The latest computed PSD. Not available in this version: PSD buffering is
        disabled while under testing.

        Raises
        ------
        NotImplementedError: always
        """
        raise NotImplementedError('Sorry! PSD buffer is under testing.')

    def is_ready(self):
        """
        Ready to decode? Returns True if buffer is not empty.
        A fake decoder is always ready.
        """
        if self.fake:
            return True
        return self.sr.is_ready()