def get_last_trial(self, filename_live):
    """Wrap the samples of the most recent trial in a new NST_EEG_LIVE object.

    Per the original author's note, add_trial() is called when the motor
    imagery starts and not at the beginning of the trial, so the slice of
    ``self.X`` that is handed over starts ``Fs * trial_offset`` samples
    before the stored trial index.

    Args:
        filename_live: file name passed on to ``NST_EEG_LIVE``.

    Returns:
        The new ``NST_EEG_LIVE`` instance (also stored in
        ``self.nst_eeg_live``).

    Raises:
        IndexError: if no trial has been registered yet.
    """
    if len(self.trial) == 0:
        raise IndexError("no trial has been recorded yet")

    last_label = self.Y[-1:]

    # The stored trial index may be a plain number or a one-element sequence;
    # np.ravel() accepts both, so the start index is always a scalar int.
    shifted = self.trial[-1] - self.Fs * self.trial_offset
    first_sample = int(np.ravel(shifted)[0])
    # A negative start would silently slice from the END of X; clamp it so a
    # trial that began less than one offset after recording start begins at 0.
    first_sample = max(first_sample, 0)

    X = np.array(self.X[first_sample:])

    # Index 0 is handed over because the trial starts at the first row of the
    # slice X that the dataset object receives.
    self.nst_eeg_live = NST_EEG_LIVE(os.getcwd(), filename_live)
    self.nst_eeg_live.load_from_mat(last_label, 0, X, self.Fs)
    return self.nst_eeg_live
    def get_last_trial(self, filename_live=""):
        """Wrap the samples of the most recent trial in a new NST_EEG_LIVE object.

        The returned dataset can then be used for live classification.
        Per the original author's note, add_trial() is called when the motor
        imagery starts and not at the beginning of the trial, so the slice of
        ``self.X`` that is handed over starts ``Fs * trial_offset`` samples
        before the stored trial index.

        Args:
            filename_live: file name passed on to ``NST_EEG_LIVE``.

        Returns:
            The new ``NST_EEG_LIVE`` instance (also stored in
            ``self.nst_eeg_live``).

        Raises:
            IndexError: if no trial has been registered yet.
        """
        if len(self.trial) == 0:
            raise IndexError("no trial has been recorded yet")

        last_label = self.Y[-1:]

        ### int() keeps the slice valid when Fs is a float; max() prevents a
        ### negative start from silently slicing from the END of X when the
        ### trial began less than one offset after the recording started
        first_sample = int(self.trial[-1] - self.Fs * self.trial_offset)
        first_sample = max(first_sample, 0)
        X = np.array(self.X[first_sample:])

        ### hand over 0 as index to the dataset object, because the trial
        ### starts at the first row of the slice X
        self.nst_eeg_live = NST_EEG_LIVE(self.datapath, filename_live)
        self.nst_eeg_live.load_from_mat(last_label, 0, X, self.Fs)
        return self.nst_eeg_live
class RecordData_liveEEG_JB():
    """Buffer live EEG samples in a background thread and keep trial metadata.

    The instance owns the sample list ``X`` and the ``time_stamps`` list that
    are handed to the recording thread, plus per-trial bookkeeping (start
    index, start time stamp, label).  Iterating over the instance yields
    ``(key, value)`` pairs, so ``dict(self)`` is what the ``*_dump`` methods
    write to the .mat file.
    """

    def __init__(self,
                 Fs,
                 age,
                 gender="male",
                 with_feedback=False,
                 record_func=record_and_process,
                 control_func=control):
        """Create the buffers, the data folder and the (not yet started) threads.

        Args:
            Fs: sampling frequency, stored as ``self.Fs``.
            age: stored and written to every dump.
            gender: stored and written to every dump.
            with_feedback: selects the ``add_info`` text written to the dump.
            record_func: target of the recording thread; it is called with
                ``(stop_event, X, time_stamps)``.
            control_func: target of the optional control thread; it is called
                with ``(stop_event, X)``.  Only used if ``self.docontrol``.
        """
        ### decide whether control thread should also be started
        ### currently not used
        self.docontrol = False

        ### indices in X indicating the beginning of a new trial
        self.trial = []
        ### list including all the recorded EEG data
        self.X = []

        ### time stamp indicating the beginning of each trial
        self.trial_time_stamps = []
        ### all time stamps, one is added for each data point
        self.time_stamps = []
        ### label of each trial: 0: left, 1: right, 2: both
        self.Y = []
        ### currently not used # 0 negative_feedback # 1 positive feedback
        self.feedbacks = []
        ### sampling frequency
        self.Fs = Fs
        ### trial offset in motor imagery paradigm. Used in get_last_trial()
        self.trial_offset = 4

        self.gender = gender
        self.age = age
        self.add_info = "with feedback" if with_feedback else "with no feedback"

        ### initialise a subfolder where the data is to be saved
        ### if it does not yet exist, create it
        self.datapath = os.path.join(os.getcwd(), '00_DATA')
        if not os.path.exists(self.datapath):
            os.makedirs(self.datapath)

        ### stop events used to pause and resume the recording & processing thread
        self.stop_event_rec = threading.Event()
        self.stop_event_con = threading.Event()

        ### initialise recording thread. It does not run yet; use start_recording()
        self.recording_thread = threading.Thread(
            target=record_func,
            args=(self.stop_event_rec, self.X, self.time_stamps),
        )
        self.recording_thread.daemon = True

        ### initialise control thread.
        if self.docontrol:
            self.control_thread = threading.Thread(
                target=control_func,
                args=(self.stop_event_con, self.X),
            )
            self.control_thread.daemon = True

###############################################################################

    def __iter__(self):
        """Yield ``(key, value)`` pairs so that ``dict(self)`` can be saved."""
        yield 'trial', self.trial
        yield 'age', self.age
        yield 'X', self.X
        yield 'time_stamps', self.time_stamps
        yield 'trial_time_stamps', self.trial_time_stamps
        yield 'Y', self.Y
        yield 'Fs', self.Fs
        yield 'gender', self.gender
        yield 'add_info', self.add_info
        yield 'feedbacks', self.feedbacks

###############################################################################

    def _session_file_name(self, file_name, prefix):
        """Return ``file_name``, or a freshly time-stamped default name."""
        ### the default has to be built at call time: a default argument
        ### containing time_str() is evaluated only once, when the class is
        ### defined, so every dump would reuse the same (import-time) name
        if file_name is None:
            return prefix + time_str() + ".mat"
        return file_name

###############################################################################

    def add_trial(self, label):
        """Register the start of a new trial with the given label."""
        ### called whenever a new trial is started
        self.trial_time_stamps.append(pylsl.local_clock())
        self.Y.append(label)
        ### index of the most recent sample in X at the moment the trial begins
        ### (this is -1 if no sample has been recorded yet)
        self.trial.append(len(self.X) - 1)

###############################################################################

    def add_feedback(self, feedback):
        """Append one feedback value to ``self.feedbacks``."""
        self.feedbacks.append(feedback)

###############################################################################

    def start_recording(self, len_X=0, len_f=0):
        """Start the recording thread and verify that samples arrive.

        Args:
            len_X: length of ``self.X`` before the thread was started; used to
                detect new samples after a restart.
            len_f: unused, kept for backward compatibility.

        Raises:
            NoRecordingDataError: if ``self.X`` did not grow within 2 seconds.
        """
        self.recording_thread.start()

        if self.docontrol:
            self.control_thread.start()

        time.sleep(2)
        ### check whether data arrived, if not raise error
        if len(self.X) == len_X:
            raise NoRecordingDataError()

###############################################################################

    def pause_recording(self):
        """Set the stop event so that the recording loop ends."""
        ### raise stop_event to break the loop in record() while the
        ### classification of notlive data is done
        self.stop_event_rec.set()
        print('Recording has been paused.')

###############################################################################

    def restart_recording(self):
        """Create a new recording thread and start it again."""
        self.stop_event_rec.clear()
        ### NOTE(review): this uses the module-level record() and not the
        ### record_func given to __init__() - confirm that this is intended
        self.recording_thread = threading.Thread(
            target=record,
            args=(self.stop_event_rec, self.X, self.time_stamps),
        )
        self.recording_thread.daemon = True
        self.start_recording(len_X=len(self.X))

        print('Recording has been restarted.')

###############################################################################
### this function is not required anymore, because self.trial is updated in add_trial()
### kept for historical reasons

    def set_trial_start_indexes(self):
        """Recompute ``self.trial`` from the recorded time stamps."""
        ### since it can be called twice during one recording (because of live
        ### processing), the result of an earlier call is discarded first
        if len(self.trial) > 0:
            self.trial = []
        ### for each trial, find the first sample whose time stamp is not
        ### earlier than the trial's time stamp and store the index before it;
        ### the search for the next trial continues where the last one ended
        i = 0
        for trial_time_stamp in self.trial_time_stamps:
            for j in range(i, len(self.time_stamps)):
                if trial_time_stamp <= self.time_stamps[j]:
                    self.trial.append(j - 1)
                    i = j
                    break

###############################################################################

    def stop_recording_and_dump(self, file_name=None):
        """Stop recording and control, then save all data to a .mat file.

        Args:
            file_name: name of the .mat file; defaults to
                ``"EEG_session_<time_str()>.mat"`` built at call time.

        Returns:
            Tuple ``(file_name, datapath)``.
        """
        file_name = self._session_file_name(file_name, "EEG_session_")
        self.pause_recording()
        self.stop_event_con.set()
        sio.savemat(os.path.join(self.datapath, file_name), dict(self))
        print('Recording will shut down.')
        return file_name, self.datapath

###############################################################################

    def stop_recording_and_dump_live(self, file_name=None):
        """Alias of continue_recording_and_dump(), kept for older callers."""
        ### still there for historic reasons, to support run_session by Mirjam Hemberger
        ### the given file name is now forwarded instead of being ignored
        return self.continue_recording_and_dump(file_name)

###############################################################################

    def continue_recording_and_dump(self, file_name=None):
        """Save the data recorded so far without stopping the recording.

        Args:
            file_name: name of the .mat file; defaults to
                ``"EEG_session_live_<time_str()>.mat"`` built at call time.

        Returns:
            Tuple ``(file_name, datapath)``.
        """
        ### only save data while still keeping the recording thread alive
        ### the data can then be used to classify the notlive data
        file_name = self._session_file_name(file_name, "EEG_session_live_")
        sio.savemat(os.path.join(self.datapath, file_name), dict(self))
        return file_name, self.datapath

###############################################################################

    def pause_recording_and_dump(self, file_name=None):
        """Save the data to a .mat file, then pause the recording.

        Args:
            file_name: name of the .mat file; defaults to
                ``"EEG_session_live_<time_str()>.mat"`` built at call time.

        Returns:
            Tuple ``(file_name, datapath)``.
        """
        file_name = self._session_file_name(file_name, "EEG_session_live_")
        sio.savemat(os.path.join(self.datapath, file_name), dict(self))
        self.pause_recording()
        return file_name, self.datapath

###############################################################################

    def get_last_trial(self, filename_live=""):
        """Wrap the samples of the most recent trial in a new NST_EEG_LIVE object.

        The returned dataset can then be used for live classification.

        Args:
            filename_live: file name passed on to ``NST_EEG_LIVE``.

        Returns:
            The new ``NST_EEG_LIVE`` instance (also stored in
            ``self.nst_eeg_live``).

        Raises:
            IndexError: if no trial has been registered yet.
        """
        if len(self.trial) == 0:
            raise IndexError("no trial has been recorded yet")

        last_label = self.Y[-1:]

        ### subtract one trial offset, because add_trial() is always called
        ### when the motor imagery starts and not at the beginning of each trial
        ### int() keeps the slice valid when Fs is a float; max() prevents a
        ### negative start from silently slicing from the END of X
        first_sample = int(self.trial[-1] - self.Fs * self.trial_offset)
        first_sample = max(first_sample, 0)
        X = np.array(self.X[first_sample:])

        ### hand over 0 as index to the dataset object, because the trial
        ### starts at the first row of the slice X
        self.nst_eeg_live = NST_EEG_LIVE(self.datapath, filename_live)
        self.nst_eeg_live.load_from_mat(last_label, 0, X, self.Fs)
        return self.nst_eeg_live

###############################################################################

    def startAccumulate(self):
        """Remember the current length of ``X`` as the accumulation start."""
        self.accStart = len(self.X)
        print("starting accumulation")

###############################################################################

    def stopAccumulate(self):
        """Placeholder; currently does nothing."""
        pass
# Example #4
# 0
    def __init__(self, data_dir, filename_notlive, n_classes=2):
        """Load and preprocess the recorded ("notlive") session and set up the CNN.

        Steps, in order: load the dataset, downsample it to ``self.fs`` if it
        was recorded at a higher rate, filter, clip and normalise it (the
        normalisation constants are kept on the instance for later use), cut
        it into trials, augment the trials with a sliding window, load or
        build the model, and create the training callbacks.

        Args:
            data_dir: directory handed to ``NST_EEG_LIVE``.
            filename_notlive: file name handed to ``NST_EEG_LIVE``.
            n_classes: number of classes; forwarded to the dataset and used as
                the size of the model's output layer.
        """
        self.print_version_info()

        self.data_dir = data_dir
        self.cwd = os.getcwd()
        self.n_classes = n_classes
        kwargs = {'n_classes': self.n_classes}

        ### initialise dataset
        self.data_notlive = NST_EEG_LIVE(self.data_dir, filename_notlive,
                                         **kwargs)
        self.data_notlive.load()
        self.data_notlive.print_stats()

        self.MODELNAME = "CNN_STFT"

        ### NOTE(review): sized with the sampling frequency *before* the
        ### downsampling below - confirm that this is intended
        self.x_stacked = np.zeros((1, self.data_notlive.sampling_freq *
                                   self.data_notlive.trial_total, 3))
        self.y_stacked = np.zeros((1, self.n_classes))

        ### preprocessing constants
        self.fs = 256
        self.lowcut = 2
        self.highcut = 60
        self.anti_drift = 0.5
        self.f0 = 50.0  # freq to be removed from signal (Hz) for notch filter
        self.Q = 30.0  # quality factor for notch filter
        self.AXIS = 0
        self.CUTOFF = 50.0
        self.w0 = self.CUTOFF / (self.fs / 2)
        self.dropout = 0.5

        ### reduce sampling frequency to self.fs
        ### most previous data is at 256 Hz, but now it has to be recorded at
        ### 512 Hz due to the combination of EMG and EEG
        if self.data_notlive.sampling_freq > self.fs:
            factor = int(self.data_notlive.sampling_freq / self.fs)
            self.data_notlive.raw_data = decimate(self.data_notlive.raw_data,
                                                  factor,
                                                  axis=0,
                                                  zero_phase=True)
            self.data_notlive.sampling_freq = self.fs
            ### the trial start indices shrink by the same factor as the data
            ### (previously a hard-coded 2, which is only right for 512 Hz)
            self.data_notlive.trials = np.floor(self.data_notlive.trials /
                                                factor).astype(int)

        ### filter the data
        self.data_notlive_filt = gumpy.signal.notch(self.data_notlive.raw_data,
                                                    self.CUTOFF, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_highpass(
            self.data_notlive_filt, self.anti_drift, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_bandpass(
            self.data_notlive_filt, self.lowcut, self.highcut, self.AXIS)

        ### clip and normalise the data
        ### keep normalisation constants for later on (hence no use of gumpy possible)
        self.sigma = np.min(np.std(self.data_notlive_filt, axis=0))
        self.data_notlive_clip = np.clip(self.data_notlive_filt,
                                         self.sigma * (-6), self.sigma * 6)

        self.notlive_mean = np.mean(self.data_notlive_clip, axis=0)
        self.notlive_std_dev = np.std(self.data_notlive_clip, axis=0)
        self.data_notlive_clip = (self.data_notlive_clip -
                                  self.notlive_mean) / self.notlive_std_dev

        ### extract the trials for each class
        #TODO: correct function extract_trials() trial len & trial offset
        self.class1_mat, self.class2_mat = gumpy.utils.extract_trials_corrJB(
            self.data_notlive, filtered=self.data_notlive_clip)

        ### concatenate data for training and create labels
        self.x_train = np.concatenate((self.class1_mat, self.class2_mat))
        self.labels_c1 = np.zeros((self.class1_mat.shape[0], ))
        self.labels_c2 = np.ones((self.class2_mat.shape[0], ))
        self.y_train = np.concatenate((self.labels_c1, self.labels_c2))

        ### for categorical crossentropy as an output of the CNN, another format of y is required
        self.y_train = ku.to_categorical(self.y_train)

        if DEBUG:
            print("Shape of x_train: ", self.x_train.shape)
            print("Shape of y_train: ", self.y_train.shape)

        print("EEG Data loaded and processed successfully!")

        ### roll shape to match to the CNN
        self.x_rolled = np.rollaxis(self.x_train, 2, 1)

        if DEBUG:
            print('X shape: ', self.x_train.shape)
            print('X rolled shape: ', self.x_rolled.shape)

        ### augment data to have more samples for training
        self.x_augmented, self.y_augmented = gumpy.signal.sliding_window(
            data=self.x_train,
            labels=self.y_train,
            window_sz=4 * self.fs,
            n_hop=self.fs // 8,
            n_start=self.fs * 3)

        ### roll shape to match to the CNN
        self.x_augmented_rolled = np.rollaxis(self.x_augmented, 2, 1)
        print("Shape of x_augmented: ", self.x_augmented_rolled.shape)
        print("Shape of y_augmented: ", self.y_augmented.shape)

        ### try to load the .json model file, otherwise build a new model
        self.loaded = 0
        if os.path.isfile(os.path.join(self.cwd, self.MODELNAME + ".json")):
            self.load_CNN_model()
            if self.model:
                self.loaded = 1

        if self.loaded == 0:
            print("Could not load model, will build model.")
            self.build_CNN_model()
            if self.model:
                self.loaded = 1

        ### pick a name for this run: as long as a .csv log with the current
        ### name exists, move on to the next "_run<i>" name
        saved_model_name = self.MODELNAME
        TMP_NAME = self.MODELNAME + "_" + "_C" + str(self.n_classes)
        for i in range(99):
            if os.path.isfile(saved_model_name + ".csv"):
                saved_model_name = TMP_NAME + "_run{0}".format(i)

        ### Save model -> json file (the with-block closes the file handle)
        json_string = self.model.to_json()
        model_file = saved_model_name + ".json"
        with open(model_file, 'w') as json_file:
            json_file.write(json_string)

        ### define where to save the parameters to
        model_file = saved_model_name + 'monitoring' + '.h5'
        checkpoint = ModelCheckpoint(model_file,
                                     monitor='val_loss',
                                     verbose=1,
                                     save_best_only=True,
                                     mode='min')
        log_file = saved_model_name + '.csv'
        csv_logger = CSVLogger(log_file, append=True, separator=';')
        self.callbacks_list = [csv_logger, checkpoint]  # callback list
# Example #5
# 0
class liveEEG_CNN():
    def __init__(self, data_dir, filename_notlive, n_classes=2):
        """Load and preprocess the recorded ("notlive") session and set up the CNN.

        Steps, in order: load the dataset, downsample it to ``self.fs`` if it
        was recorded at a higher rate, filter, clip and normalise it (the
        constants ``sigma``, ``notlive_mean`` and ``notlive_std_dev`` are
        reused by classify_live()), cut it into trials, augment the trials
        with a sliding window, load or build the model, and create the
        training callbacks.

        Args:
            data_dir: directory handed to ``NST_EEG_LIVE``.
            filename_notlive: file name handed to ``NST_EEG_LIVE``.
            n_classes: number of classes; forwarded to the dataset and used as
                the size of the model's output layer.
        """
        self.print_version_info()

        self.data_dir = data_dir
        self.cwd = os.getcwd()
        self.n_classes = n_classes
        kwargs = {'n_classes': self.n_classes}

        ### initialise dataset
        self.data_notlive = NST_EEG_LIVE(self.data_dir, filename_notlive,
                                         **kwargs)
        self.data_notlive.load()
        self.data_notlive.print_stats()

        self.MODELNAME = "CNN_STFT"

        ### NOTE(review): sized with the sampling frequency *before* the
        ### downsampling below - confirm that this is intended
        self.x_stacked = np.zeros((1, self.data_notlive.sampling_freq *
                                   self.data_notlive.trial_total, 3))
        self.y_stacked = np.zeros((1, self.n_classes))

        ### preprocessing constants
        self.fs = 256
        self.lowcut = 2
        self.highcut = 60
        self.anti_drift = 0.5
        self.f0 = 50.0  # freq to be removed from signal (Hz) for notch filter
        self.Q = 30.0  # quality factor for notch filter
        self.AXIS = 0
        self.CUTOFF = 50.0
        self.w0 = self.CUTOFF / (self.fs / 2)
        self.dropout = 0.5

        ### reduce sampling frequency to self.fs
        ### most previous data is at 256 Hz, but now it has to be recorded at
        ### 512 Hz due to the combination of EMG and EEG
        if self.data_notlive.sampling_freq > self.fs:
            factor = int(self.data_notlive.sampling_freq / self.fs)
            self.data_notlive.raw_data = decimate(self.data_notlive.raw_data,
                                                  factor,
                                                  axis=0,
                                                  zero_phase=True)
            self.data_notlive.sampling_freq = self.fs
            ### the trial start indices shrink by the same factor as the data
            ### (previously a hard-coded 2, which is only right for 512 Hz)
            self.data_notlive.trials = np.floor(self.data_notlive.trials /
                                                factor).astype(int)

        ### filter the data
        self.data_notlive_filt = gumpy.signal.notch(self.data_notlive.raw_data,
                                                    self.CUTOFF, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_highpass(
            self.data_notlive_filt, self.anti_drift, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_bandpass(
            self.data_notlive_filt, self.lowcut, self.highcut, self.AXIS)

        ### clip and normalise the data
        ### keep normalisation constants for later on (hence no use of gumpy possible)
        self.sigma = np.min(np.std(self.data_notlive_filt, axis=0))
        self.data_notlive_clip = np.clip(self.data_notlive_filt,
                                         self.sigma * (-6), self.sigma * 6)

        self.notlive_mean = np.mean(self.data_notlive_clip, axis=0)
        self.notlive_std_dev = np.std(self.data_notlive_clip, axis=0)
        self.data_notlive_clip = (self.data_notlive_clip -
                                  self.notlive_mean) / self.notlive_std_dev

        ### extract the trials for each class
        #TODO: correct function extract_trials() trial len & trial offset
        self.class1_mat, self.class2_mat = gumpy.utils.extract_trials_corrJB(
            self.data_notlive, filtered=self.data_notlive_clip)

        ### concatenate data for training and create labels
        self.x_train = np.concatenate((self.class1_mat, self.class2_mat))
        self.labels_c1 = np.zeros((self.class1_mat.shape[0], ))
        self.labels_c2 = np.ones((self.class2_mat.shape[0], ))
        self.y_train = np.concatenate((self.labels_c1, self.labels_c2))

        ### for categorical crossentropy as an output of the CNN, another format of y is required
        self.y_train = ku.to_categorical(self.y_train)

        if DEBUG:
            print("Shape of x_train: ", self.x_train.shape)
            print("Shape of y_train: ", self.y_train.shape)

        print("EEG Data loaded and processed successfully!")

        ### roll shape to match to the CNN
        self.x_rolled = np.rollaxis(self.x_train, 2, 1)

        if DEBUG:
            print('X shape: ', self.x_train.shape)
            print('X rolled shape: ', self.x_rolled.shape)

        ### augment data to have more samples for training
        self.x_augmented, self.y_augmented = gumpy.signal.sliding_window(
            data=self.x_train,
            labels=self.y_train,
            window_sz=4 * self.fs,
            n_hop=self.fs // 8,
            n_start=self.fs * 3)

        ### roll shape to match to the CNN
        self.x_augmented_rolled = np.rollaxis(self.x_augmented, 2, 1)
        print("Shape of x_augmented: ", self.x_augmented_rolled.shape)
        print("Shape of y_augmented: ", self.y_augmented.shape)

        ### try to load the .json model file, otherwise build a new model
        self.loaded = 0
        if os.path.isfile(os.path.join(self.cwd, self.MODELNAME + ".json")):
            self.load_CNN_model()
            if self.model:
                self.loaded = 1

        if self.loaded == 0:
            print("Could not load model, will build model.")
            self.build_CNN_model()
            if self.model:
                self.loaded = 1

        ### pick a name for this run: as long as a .csv log with the current
        ### name exists, move on to the next "_run<i>" name
        saved_model_name = self.MODELNAME
        TMP_NAME = self.MODELNAME + "_" + "_C" + str(self.n_classes)
        for i in range(99):
            if os.path.isfile(saved_model_name + ".csv"):
                saved_model_name = TMP_NAME + "_run{0}".format(i)

        ### Save model -> json file (the with-block closes the file handle)
        json_string = self.model.to_json()
        model_file = saved_model_name + ".json"
        with open(model_file, 'w') as json_file:
            json_file.write(json_string)

        ### define where to save the parameters to
        model_file = saved_model_name + 'monitoring' + '.h5'
        checkpoint = ModelCheckpoint(model_file,
                                     monitor='val_loss',
                                     verbose=1,
                                     save_best_only=True,
                                     mode='min')
        log_file = saved_model_name + '.csv'
        csv_logger = CSVLogger(log_file, append=True, separator=';')
        self.callbacks_list = [csv_logger, checkpoint]  # callback list

###############################################################################
### train the model with the notlive data or simply load a pretrained model

    def fit(self, load=False):
        """Compile ``self.model`` and then either train it or load a saved model.

        Args:
            load: if falsy, train on the augmented data with the callbacks
                created in __init__(); otherwise replace ``self.model`` with
                the model stored in 'CNN_STFTmonitoring.h5'.
        """
        #TODO: use method train_on_batch() to update model
        self.batch_size = 32
        compile_options = dict(loss='categorical_crossentropy',
                               optimizer='adam',
                               metrics=['accuracy'])
        self.model.compile(**compile_options)

        if load:
            print('Load...')
            ### the kapre layers have to be named explicitly for deserialisation
            kapre_layers = {
                'Spectrogram': kapre.time_frequency.Spectrogram,
                'Normalization2D': kapre.utils.Normalization2D
            }
            ### other runs are saved under names like CNN_STFT__C2_run4monitoring.h5
            self.model = keras.models.load_model('CNN_STFTmonitoring.h5',
                                                 custom_objects=kapre_layers)
            return

        print('Train...')
        self.model.fit(self.x_augmented_rolled,
                       self.y_augmented,
                       batch_size=self.batch_size,
                       epochs=100,
                       shuffle=True,
                       validation_split=0.2,
                       callbacks=self.callbacks_list)

###############################################################################
### do the live classification

    def classify_live(self, data_live):
        """Preprocess one live dataset like the training data and classify it.

        The same steps as in __init__() are applied: optional downsampling to
        ``self.fs``, notch / high-pass / band-pass filtering, clipping and
        normalisation with the constants computed from the notlive data.

        Args:
            data_live: dataset object providing ``raw_data``,
                ``sampling_freq``, ``labels`` and ``mi_interval``.  Its
                ``raw_data`` and ``sampling_freq`` are changed in place when
                downsampling is necessary.

        Returns:
            Tuple ``(y_pred, pred_true, pred_valid)``.  If a prediction was
            made, ``y_pred`` is the argmax of the model output, ``pred_true``
            the comparison of the labels with ``y_pred`` and ``pred_valid``
            is 1.  Otherwise the values are ``([], [], 0)``.
        """
        ### again, downsampling (e.g. from 512 to 256 Hz)
        ### the factor has to come from the live data itself: the sampling
        ### frequency of the notlive data was already set to self.fs in
        ### __init__(), which made the former factor 1 (no downsampling)
        if data_live.sampling_freq > self.fs:
            factor = int(data_live.sampling_freq / self.fs)
            data_live.raw_data = decimate(data_live.raw_data,
                                          factor,
                                          axis=0,
                                          zero_phase=True)
            data_live.sampling_freq = self.fs

        self.y_live = data_live.labels

        ### NOTE(review): the dataset object itself (not raw_data) is passed
        ### to notch() here, unlike in __init__() - confirm this is intended
        self.data_live_filt = gumpy.signal.notch(data_live, self.CUTOFF,
                                                 self.AXIS)
        self.data_live_filt = gumpy.signal.butter_highpass(
            self.data_live_filt, self.anti_drift, self.AXIS)
        self.data_live_filt = gumpy.signal.butter_bandpass(
            self.data_live_filt, self.lowcut, self.highcut, self.AXIS)

        ### clip and normalise with the constants of the notlive data
        self.data_live_clip = np.clip(self.data_live_filt, self.sigma * (-6),
                                      self.sigma * 6)
        self.data_live_clip = (self.data_live_clip -
                               self.notlive_mean) / self.notlive_std_dev

        class1_mat, class2_mat = gumpy.utils.extract_trials_corrJB(
            data_live, filtered=self.data_live_clip)

        ### concatenate data
        self.x_live = np.concatenate((class1_mat, class2_mat))

        ### keep only the samples inside mi_interval (given in seconds)
        first_sample = data_live.mi_interval[0] * data_live.sampling_freq
        last_sample = data_live.mi_interval[1] * data_live.sampling_freq
        self.x_live = self.x_live[:, first_sample:last_sample, :]

        self.x_live = np.rollaxis(self.x_live, 2, 1)

        ### do the prediction
        pred_valid = 0
        y_pred = []
        pred_true = []
        if self.loaded and self.x_live.any():
            y_pred = self.model.predict(self.x_live, batch_size=64)
            print(y_pred)

            ### argmax because output is crossentropy
            y_pred = y_pred.argmax()
            pred_true = self.y_live == y_pred
            print('Real=', self.y_live)
            pred_valid = 1

        return y_pred, pred_true, pred_valid

###############################################################################

    def load_CNN_model(self):
        print('Load model', self.MODELNAME)
        model_path = self.MODELNAME + ".json"
        if not os.path.isfile(model_path):
            raise IOError('file "%s" does not exist' % (model_path))
        self.model = model_from_json(open(model_path).read(),
                                     custom_objects={
                                         'Spectrogram':
                                         kapre.time_frequency.Spectrogram,
                                         'Normalization2D':
                                         kapre.utils.Normalization2D
                                     })
        #self.model = load_model(self.cwd,self.MODELNAME,self.MODELNAME+'monitoring')
        #TODO: get it to work, but not urgently required
        #self.model = []

###############################################################################

    def build_CNN_model(self):
        """Build the CNN architecture and store it in ``self.model``.

        The model consists of a static spectrogram layer with frequency-wise
        normalisation, three convolution blocks (convolution, batch
        normalisation, max pooling, ReLU, dropout) and a softmax classifier
        with ``self.n_classes`` outputs.  The input shape is taken from
        ``self.x_augmented_rolled``.
        """
        print('Build model...')
        self.model = Sequential()
        self.model.add(
            Spectrogram(n_dft=128,
                        n_hop=16,
                        input_shape=(self.x_augmented_rolled.shape[1:]),
                        return_decibel_spectrogram=False,
                        power_spectrogram=2.0,
                        trainable_kernel=False,
                        name='static_stft'))
        self.model.add(Normalization2D(str_axis='freq'))

        ### the three blocks differ only in the convolution parameters
        conv_blocks = (
            dict(filters=24, kernel_size=(12, 12), strides=(1, 1),
                 name='conv1'),
            dict(filters=48, kernel_size=(8, 8), name='conv2'),
            dict(filters=96, kernel_size=(4, 4), name='conv3'),
        )
        for conv_kwargs in conv_blocks:
            ### padding= is the Keras 2 name of the former border_mode=
            ### argument and matches the MaxPooling2D call below
            self.model.add(Conv2D(padding='same', **conv_kwargs))
            self.model.add(BatchNormalization(axis=1))
            self.model.add(
                MaxPooling2D(pool_size=(2, 2),
                             strides=(2, 2),
                             padding='valid',
                             data_format='channels_last'))
            self.model.add(Activation('relu'))
            self.model.add(Dropout(self.dropout))

        # classificator
        self.model.add(Flatten())
        self.model.add(Dense(self.n_classes))  # one output per class
        self.model.add(Activation('softmax'))

        ### summary() prints by itself and returns None, so it is not wrapped
        ### in print() (that only added a stray "None" line)
        self.model.summary()
        self.saved_model_name = self.MODELNAME

###############################################################################

    def print_version_info(self):
        """Print the date and the versions of Keras, its backend and Kapre."""
        now = datetime.now()

        print('%s/%s/%s' % (now.year, now.month, now.day))
        print('Keras version: {}'.format(keras.__version__))

        ### use the public accessor instead of mixing the private attributes
        ### keras.backend._BACKEND and keras.backend._backend
        backend_name = keras.backend.backend()
        if backend_name == 'tensorflow':
            import tensorflow
            backend_version = tensorflow.__version__
        else:
            import theano
            backend_version = theano.__version__
        print('Keras backend: {}: {}'.format(backend_name, backend_version))

        print('Keras image dim ordering: {}'.format(
            keras.backend.image_dim_ordering()))
        print('Kapre version: {}'.format(kapre.__version__))
        ### NOTE(review): the statements below look like a separately pasted
        ### evaluation loop; by indentation they currently belong to the body
        ### of the preceding method. myclass, myclass2, base_dir and file_name2
        ### are not defined in the visible code - confirm where this belongs.
        ### The loop replays every trial of myclass2.data_notlive through
        ### myclass.classify_live() and counts correct and wrong predictions.
        count_pred_true = 0
        count_pred_false = 0

        for i in range(0,myclass2.data_notlive.trials.shape[0]):
            fs = myclass2.data_notlive.sampling_freq
            label = myclass2.data_notlive.labels[i]
            trial = myclass2.data_notlive.trials[i]
            ### a trial spans from its start index to the start of the next
            ### trial; the last trial runs to the end of the data
            if i < (myclass2.data_notlive.trials.shape[0] - 1):
                next_trial = myclass2.data_notlive.trials[i+1]
                X = myclass2.data_notlive.raw_data[trial:next_trial]
            else:
                X = myclass2.data_notlive.raw_data[trial:]

            ### within the slice X the trial starts at index 0
            trial = 0
            nst_eeg_live = NST_EEG_LIVE(base_dir, file_name2)
            nst_eeg_live.load_from_mat(label,trial,X,fs)

            current_classifier, pred_true, pred_valid = myclass.classify_live(nst_eeg_live)

            ### skip trials for which no prediction was made
            if not pred_valid:
                continue

            #print('Classification result: ',current_classifier[0],'\n')
            if pred_true:
                #print('This is true!\n')
                count_pred_true = count_pred_true + 1
            else:
                count_pred_false += 1
                #print('This is false!\n')
class RecordData():
    def __init__(self,
                 Fs,
                 age,
                 gender="male",
                 with_feedback=False,
                 record_func=record):
        """Set up empty recording buffers and a (not yet started) recorder thread.

        Args:
            Fs: sampling frequency; stored as ``self.Fs``.
            age: stored as ``self.age``.
            gender: stored as ``self.gender``.
            with_feedback: selects the text stored in ``self.add_info``.
            record_func: callable run by the recording thread; it is handed
                the sample list ``self.X`` and the list ``self.time_stamps``.
        """
        # session metadata
        self.Fs = Fs
        self.age = age
        self.gender = gender
        if with_feedback:
            self.add_info = "with feedback"
        else:
            self.add_info = "with no feedback"

        # multiplied by Fs in get_last_trial() to step back from the
        # recorded trial index
        self.trial_offset = 4

        # buffers handed to the recording thread
        self.X = []
        self.time_stamps = []

        # per-trial bookkeeping filled by add_trial():
        #   trial             - index into X at the moment the trial was added
        #   trial_time_stamps - clock value at that moment
        #   Y                 - label (1: left, 2: right, 3: both hands)
        self.trial = []
        self.trial_time_stamps = []
        self.Y = []

        # 0: negative feedback, 1: positive feedback
        self.feedbacks = []

        # daemon thread; it is only started by start_recording()
        self.recording_thread = threading.Thread(
            target=record_func,
            args=(self.X, self.time_stamps),
        )
        self.recording_thread.daemon = True

    def __iter__(self):
        """Yield ``(name, value)`` pairs so that ``dict(self)`` builds the dump record."""
        exported_fields = (
            'trial',
            'age',
            'X',
            'time_stamps',
            'trial_time_stamps',
            'Y',
            'Fs',
            'gender',
            'add_info',
            'feedbacks',
        )
        for field_name in exported_fields:
            # each attribute is looked up only when its pair is requested
            yield field_name, getattr(self, field_name)

    def add_trial(self, label):
        """Register a new trial: its clock time, its label and its sample index."""
        marker_time = pylsl.local_clock()
        self.trial_time_stamps.append(marker_time)
        self.Y.append(label)
        # index of the most recent entry in X at this moment
        newest_sample = len(self.X) - 1
        self.trial.append(newest_sample)

    def add_feedback(self, feedback):
        """Append one feedback value to ``self.feedbacks``."""
        feedback_log = self.feedbacks
        feedback_log.append(feedback)

    def start_recording(self):
        """Start the recording thread and verify that samples arrive.

        Raises:
            NoRecordingDataError: if ``self.X`` is still empty after the
                two-second wait.
        """
        self.recording_thread.start()
        # give the recorder time to deliver its first samples
        time.sleep(2)
        if len(self.X) > 0:
            return
        raise NoRecordingDataError()

    ### this function is not required anymore, because self.trial is updated in add_trial()
    ### kept for historical reasons
    def set_trial_start_indexes(self):
        """Rebuild ``self.trial`` from the trial and sample time stamps.

        For every entry of ``self.trial_time_stamps`` the first sample time
        stamp that is not earlier is located, and the index just before it is
        stored. The search for each trial resumes where the previous one
        matched.
        """
        if len(self.trial) != 0:
            self.trial = []

        resume_at = 0
        for marker in self.trial_time_stamps:
            position = resume_at
            # the sample count is read once per trial
            sample_count = len(self.time_stamps)
            while position < sample_count:
                if marker <= self.time_stamps[position]:
                    self.trial.append(position - 1)
                    resume_at = position
                    break
                position += 1

    def stop_recording_and_dump(self,
                                file_name="session_" + time_str() + ".mat"):
        """Write everything recorded so far to a .mat file and return its name.

        The record is ``dict(self)``, i.e. the pairs produced by ``__iter__``.
        This method itself only saves; it does not touch the recording thread.

        NOTE: the default ``file_name`` is built once, when the method is
        defined, so every call that relies on the default uses the same name.

        Args:
            file_name: path of the .mat file to write.

        Returns:
            ``file_name``, unchanged.
        """
        # trial indexes are already collected in add_trial(), so
        # set_trial_start_indexes() is not called here
        session_record = dict(self)
        sio.savemat(file_name, session_record)
        return file_name

    ### used for live processing step 2: dump all data generated so far, another filename than stop_recording_and_dump
    def stop_recording_and_dump_live(self,
                                     file_name="session_live_" + time_str() +
                                     ".mat"):
        """Write everything recorded so far to a "live" .mat file.

        Same as ``stop_recording_and_dump`` apart from the default file name
        prefix (``session_live_``). The record is ``dict(self)``.

        NOTE: the default ``file_name`` is built once, when the method is
        defined, so every call that relies on the default uses the same name.

        Args:
            file_name: path of the .mat file to write.

        Returns:
            ``file_name``, unchanged.
        """
        # trial indexes are already collected in add_trial(), so
        # set_trial_start_indexes() is not called here
        session_record = dict(self)
        sio.savemat(file_name, session_record)
        return file_name

    def get_last_trial(self, filename_live):
        """Package the most recent trial as an ``NST_EEG_LIVE`` dataset.

        The samples from the start of the last trial up to the newest sample
        are copied into a fresh ``NST_EEG_LIVE`` object, which is also kept as
        ``self.nst_eeg_live``.

        Args:
            filename_live: file name handed to the ``NST_EEG_LIVE`` constructor
                together with the current working directory.

        Returns:
            The populated ``NST_EEG_LIVE`` instance.

        Raises:
            IndexError: if no trial has been added yet.
        """
        # one-element list holding the label of the last trial
        last_label = self.Y[-1:]

        # add_trial() is called when the motor imagery starts, not at the
        # beginning of the trial, so step back by one trial offset.
        # self.trial holds plain integers, so the result is a scalar and must
        # not be indexed.
        trial_start = self.trial[-1] - self.Fs * self.trial_offset
        # A negative slice start would wrap around and select samples from the
        # END of the buffer; clamp so an early trial starts at sample 0.
        trial_start = max(int(trial_start), 0)
        X = np.array(self.X[trial_start:])

        # the slice handed over starts at the trial, so its trial index is 0
        trial_index_in_slice = 0
        cwd = os.getcwd()
        self.nst_eeg_live = NST_EEG_LIVE(cwd, filename_live)
        self.nst_eeg_live.load_from_mat(last_label, trial_index_in_slice, X,
                                        self.Fs)
        return self.nst_eeg_live
    def __init__(self, cwd, filename_notlive, flag=0):
        """Load a recorded session and compute its sub-band power features.

        NOTE(review): this definition sits in the body of ``RecordData`` after
        an earlier ``__init__``, so it replaces that earlier constructor when
        the class body is executed. It appears to duplicate
        ``liveEEG.__init__`` from further down in this file, and it calls
        ``self.alpha_subBP_features``, ``self.beta_subBP_features`` and
        ``self.log_subBP_feature_extraction``, none of which are defined in
        ``RecordData`` in the visible source. It looks like a misplaced copy --
        confirm, then remove or relocate it.

        Args:
            cwd: first argument handed to the dataset class.
            filename_notlive: second argument handed to the dataset class.
            flag: 1 loads ``NST_EEG_TEST``, 0 loads ``NST_EEG_LIVE``; any
                other value loads nothing, so the first use of
                ``self.data_notlive`` below fails.
        """
        # only used for testing
        self.flag = flag
        if self.flag == 1:
            self.data_notlive = NST_EEG_TEST(cwd, filename_notlive)
            self.data_notlive.load()
            self.data_notlive.print_stats()

        # cwd and filename_notlive specify the location of one .mat file where the recorded notlive data has been stored
        # print(filename_notlive, '\n')
        # load notlive data from path that has been specified
        if self.flag == 0:
            self.data_notlive = NST_EEG_LIVE(cwd, filename_notlive)
            self.data_notlive.load()
            self.data_notlive.print_stats()

        # always executed: subtract column 2 of the raw data from columns 0
        # and 1, in place
        if 1:
            self.data_notlive.raw_data[:, 0] -= self.data_notlive.raw_data[:,
                                                                           2]
            self.data_notlive.raw_data[:, 1] -= self.data_notlive.raw_data[:,
                                                                           2]
            # expression statement whose value is discarded
            self.data_notlive.raw_data[:, 0].shape

        self.labels_notlive = self.data_notlive.labels

        # butter-bandpass filtered version of notlive data
        self.data_notlive_filtbp = gumpy.signal.butter_bandpass(
            self.data_notlive, lo=2, hi=60)

        # frequency to be removed from the signal
        self.notch_f0 = 50
        # quality factor
        self.notch_Q = 50.0
        # get the cutoff frequency
        # self.notch_w0 = self.notch_f0/(self.data_notlive.sampling_freq/2)
        # apply the notch filter
        self.data_notlive_filtno = gumpy.signal.notch(
            self.data_notlive_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=self.data_notlive.sampling_freq)

        self.alpha_bands = np.array(
            self.alpha_subBP_features(self.data_notlive_filtno))
        self.beta_bands = np.array(
            self.beta_subBP_features(self.data_notlive_filtno))

        # Feature extraction using sub-bands
        # Method 1: logarithmic sub-band power
        # w1 / w2 are [start, stop] pairs handed to the feature extractor
        self.w1 = [0, 125]
        self.w2 = [125, 250]

        self.features1 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, self.data_notlive.trials,
            self.data_notlive.sampling_freq, self.w1)

        self.features2 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, self.data_notlive.trials,
            self.data_notlive.sampling_freq, self.w2)

        # concatenate the features and normalize the data
        self.features_notlive = np.concatenate(
            (self.features1.T, self.features2.T)).T
        self.features_notlive -= np.mean(self.features_notlive)
        self.features_notlive = gumpy.signal.normalize(self.features_notlive,
                                                       'min_max')

        # Method 2: DWT
        # We'll work with the data that was postprocessed using a butter bandpass
        # (disabled: this branch never runs)
        if False:
            self.w = [0, 256]
            # extract the features
            self.trials = self.data_notlive.trials
            self.fs = self.data_notlive.sampling_freq
            self.features1 = np.array(
                self.dwt_features(self.data_notlive_filtno, self.trials, 5,
                                  self.fs, self.w, 3, "db4"))
            self.features2 = np.array(
                self.dwt_features(self.data_notlive_filtno, self.trials, 5,
                                  self.fs, self.w, 4, "db4"))
            # concat the features and normalize
            self.features_notlive = np.concatenate(
                (self.features1.T, self.features2.T)).T
            self.features_notlive -= np.mean(self.features_notlive)
            self.features_notlive = gumpy.signal.normalize(
                self.features_notlive, 'min_max')

        # no classifier has been fitted yet
        self.pos_fit = False
class liveEEG():
    def __init__(self, cwd, filename_notlive, flag=0):
        """Load a recorded ("not live") session and build its training features.

        Steps, in order: load the dataset, subtract column 2 of the raw data
        from columns 0 and 1, band-pass filter (``lo=2``, ``hi=60``), notch
        filter at ``self.notch_f0``, split into alpha and beta sub-bands, and
        compute logarithmic sub-band power features for the two windows
        ``self.w1`` and ``self.w2``. The concatenated features are
        mean-centred, min-max normalized and stored in
        ``self.features_notlive``; the labels go to ``self.labels_notlive``.

        An alternative DWT-based feature set is available through
        ``dwt_features``. It used to be sketched here in a branch that could
        never run (window ``[0, 256]``, level 5, wavelet ``"db4"``,
        coefficient indexes 3 and 4); that unreachable branch was removed.

        Args:
            cwd: first argument handed to the dataset class (location of the
                recorded .mat file).
            filename_notlive: second argument handed to the dataset class
                (name of the recorded .mat file).
            flag: 0 loads the data with ``NST_EEG_LIVE``; 1 loads it with
                ``NST_EEG_TEST`` (only used for testing).

        Raises:
            ValueError: if ``flag`` is neither 0 nor 1. Previously such a
                value loaded nothing and failed later with an
                ``AttributeError`` on ``self.data_notlive``.
        """
        # only used for testing
        self.flag = flag
        if self.flag == 1:
            self.data_notlive = NST_EEG_TEST(cwd, filename_notlive)
        elif self.flag == 0:
            self.data_notlive = NST_EEG_LIVE(cwd, filename_notlive)
        else:
            raise ValueError(
                "flag must be 0 (NST_EEG_LIVE) or 1 (NST_EEG_TEST), got %r" %
                (flag, ))
        self.data_notlive.load()
        self.data_notlive.print_stats()

        # subtract column 2 from columns 0 and 1, in place (the feature
        # extractors below name columns 0, 1, 2 as c3, c4, cz)
        for channel in (0, 1):
            self.data_notlive.raw_data[:, channel] -= \
                self.data_notlive.raw_data[:, 2]

        self.labels_notlive = self.data_notlive.labels

        # butter-bandpass filtered version of notlive data
        self.data_notlive_filtbp = gumpy.signal.butter_bandpass(
            self.data_notlive, lo=2, hi=60)

        # frequency to be removed from the signal
        self.notch_f0 = 50
        # quality factor
        self.notch_Q = 50.0
        # apply the notch filter
        self.data_notlive_filtno = gumpy.signal.notch(
            self.data_notlive_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=self.data_notlive.sampling_freq)

        self.alpha_bands = np.array(
            self.alpha_subBP_features(self.data_notlive_filtno))
        self.beta_bands = np.array(
            self.beta_subBP_features(self.data_notlive_filtno))

        # Feature extraction: logarithmic sub-band power over two windows.
        # Each window is a [start, stop] pair of sample offsets that
        # powermean() adds to trial + 5 * fs.
        self.w1 = [0, 125]
        self.w2 = [125, 250]

        self.features1 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, self.data_notlive.trials,
            self.data_notlive.sampling_freq, self.w1)

        self.features2 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, self.data_notlive.trials,
            self.data_notlive.sampling_freq, self.w2)

        # concatenate the features and normalize the data
        self.features_notlive = np.concatenate(
            (self.features1.T, self.features2.T)).T
        self.features_notlive -= np.mean(self.features_notlive)
        self.features_notlive = gumpy.signal.normalize(self.features_notlive,
                                                       'min_max')

        # no classifier has been fitted yet; fit() sets this to True
        self.pos_fit = False

    def fit(self):
        """Select features on the recorded data and train the estimator on them.

        Runs gumpy's realtime sequential feature selector, keeps the selector
        (``self.sfs_object``) and its estimator (``self.estimator_object``),
        refits the estimator on the selected features and reports how well it
        reproduces the training labels. Sets ``self.pos_fit`` to True so that
        ``classify_live`` may be used afterwards.
        """
        # Sequential Feature Selection Algorithm
        selection = gumpy.features.sequential_feature_selector_realtime(
            self.features_notlive, self.labels_notlive, 'SVM', 1, 2, 'SFFS')
        average_score = selection[1]
        print('\n\nAverage score:', average_score * 100)

        selector = selection[3]
        self.sfs_object = selector
        self.estimator_object = selector.est_

        ### fit the estimator object with the selected features and test
        training_features = selector.transform(self.features_notlive)
        self.estimator_object.fit(training_features, self.labels_notlive)

        check_features = selector.transform(self.features_notlive)
        predicted = self.estimator_object.predict(check_features)
        self.labels_pred_notlive = predicted

        # a prediction counts as a hit when it is less than 1 away from the label
        is_hit = abs(predicted - self.labels_notlive) < 1
        # np.shape() returns a tuple, so the quotient is a one-element array
        self.acc_notlive = 100 * np.sum(is_hit) / np.shape(predicted)
        print('\nAccuracy of notlive fit:', self.acc_notlive[0], '\n')
        self.pos_fit = True

    def classify_live(self, data_live):
        """Extract features from a live trial and predict its label.

        The live data goes through the same chain as the recorded data in
        ``__init__``: band-pass filter, notch filter, alpha/beta sub-bands,
        logarithmic sub-band power for the windows ``self.w1`` and
        ``self.w2``, mean-centring and min-max normalization. Intermediate
        results are kept on the instance (``self.features_live`` etc.); note
        that ``self.alpha_bands`` and ``self.beta_bands`` are overwritten with
        the live sub-bands.

        An alternative DWT-based feature set is available through
        ``dwt_features``; the branch that sketched it here could never run and
        was removed.

        Args:
            data_live: dataset object for the live trial (an ``NST_EEG_LIVE``
                instance); its ``labels``, ``trials`` and ``sampling_freq``
                are read.

        Returns:
            ``(labels_pred_live, pred_true)``: the predicted labels and the
            element-wise result of ``(labels - prediction) == 1``.

        Exits the process via ``sys.exit()`` (after printing a message) when
        ``fit()`` has not been run yet.
        """
        # data_live should be an object of class NST_EEG_LIVE
        self.labels_live = data_live.labels

        # butter-bandpass filtered version of the live data
        self.data_live_filtbp = gumpy.signal.butter_bandpass(data_live,
                                                             lo=2,
                                                             hi=60)

        # apply the notch filter; the sampling frequency is taken from the
        # recorded (notlive) data set
        self.data_live_filtno = gumpy.signal.notch(
            self.data_live_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=self.data_notlive.sampling_freq)

        self.alpha_bands = np.array(
            self.alpha_subBP_features(self.data_live_filtno))
        self.beta_bands = np.array(
            self.beta_subBP_features(self.data_live_filtno))

        # Feature extraction: logarithmic sub-band power over two windows
        self.w1 = [0, 125]
        self.w2 = [125, 250]

        self.features1_live = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, data_live.trials,
            data_live.sampling_freq, self.w1)

        self.features2_live = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, data_live.trials,
            data_live.sampling_freq, self.w2)

        # concatenate the features and normalize the data
        self.features_live = np.concatenate(
            (self.features1_live.T, self.features2_live.T)).T
        self.features_live -= np.mean(self.features_live)
        self.features_live = gumpy.signal.normalize(self.features_live,
                                                    'min_max')

        ### predict label of live trial and check whether it is correct
        if not self.pos_fit:
            print("No fit was performed yet. Please train the model first.\n")
            sys.exit()

        labels_pred_live = self.estimator_object.predict(
            self.sfs_object.transform(self.features_live))
        # NOTE(review): "correct" here means label - prediction == 1, whereas
        # fit() treats |prediction - label| < 1 as correct. Presumably the
        # live labels are offset by one relative to the training labels --
        # confirm against NST_EEG_LIVE.load / load_from_mat.
        pred_true = (self.labels_live - labels_pred_live) == 1

        return labels_pred_live, pred_true

    # Alpha and Beta sub-bands
    def alpha_subBP_features(self, data):
        """Band-pass ``data`` into four alpha sub-bands.

        Returns:
            A list with one filtered result per sub-band, in the order
            8.5-11.5, 9.0-12.5, 9.5-11.5 and 8.0-10.5 (low cut, high cut),
            each produced by an order-6 Butterworth band-pass.
        """
        # (low cut, high cut) of each alpha sub-band
        alpha_edges = (
            (8.5, 11.5),
            (9.0, 12.5),
            (9.5, 11.5),
            (8.0, 10.5),
        )
        return [
            gumpy.signal.butter_bandpass(data, low_cut, high_cut, order=6)
            for low_cut, high_cut in alpha_edges
        ]

    def beta_subBP_features(self, data):
        """Band-pass ``data`` into four beta sub-bands.

        Returns:
            A list with one filtered result per sub-band, in the order
            14.0-30.0, 16.0-17.0, 17.0-18.0 and 18.0-19.0 (low cut, high cut),
            each produced by an order-6 Butterworth band-pass.
        """
        # (low cut, high cut) of each beta sub-band
        beta_edges = (
            (14.0, 30.0),
            (16.0, 17.0),
            (17.0, 18.0),
            (18.0, 19.0),
        )
        return [
            gumpy.signal.butter_bandpass(data, low_cut, high_cut, order=6)
            for low_cut, high_cut in beta_edges
        ]

    # Feature extraction using sub-bands (The following examples show how the sub-bands can be used to extract features.)
    # Method 1: logarithmic sub-band power
    def powermean(self, data, trial, fs, w):
        return  np.power(data[trial+fs*5+w[0]: trial+fs*5+w[1],0],2).mean(), \
                np.power(data[trial+fs*5+w[0]: trial+fs*5+w[1],1],2).mean(), \
                np.power(data[trial+fs*5+w[0]: trial+fs*5+w[1],2],2).mean()

    def log_subBP_feature_extraction(self, alpha, beta, trials, fs, w):
        # number of features combined for all trials
        n_features = 15
        # initialize the feature matrix
        X = np.zeros((len(trials), n_features))

        # Extract features
        for t, trial in enumerate(trials):
            power_c31, power_c41, power_cz1 = self.powermean(
                alpha[0], trial, fs, w)
            power_c32, power_c42, power_cz2 = self.powermean(
                alpha[1], trial, fs, w)
            power_c33, power_c43, power_cz3 = self.powermean(
                alpha[2], trial, fs, w)
            power_c34, power_c44, power_cz4 = self.powermean(
                alpha[3], trial, fs, w)
            power_c31_b, power_c41_b, power_cz1_b = self.powermean(
                beta[0], trial, fs, w)

            X[t, :] = np.array([
                np.log(power_c31),
                np.log(power_c41),
                np.log(power_cz1),
                np.log(power_c32),
                np.log(power_c42),
                np.log(power_cz2),
                np.log(power_c33),
                np.log(power_c43),
                np.log(power_cz3),
                np.log(power_c34),
                np.log(power_c44),
                np.log(power_cz4),
                np.log(power_c31_b),
                np.log(power_c41_b),
                np.log(power_cz1_b)
            ])
        return X

    # Method 2: DWT
    def dwt_features(self, data, trials, level, sampling_freq, w, n, wavelet):
        """Build a feature matrix from discrete wavelet transform coefficients.

        For every trial the rows ``trial + sampling_freq*5 + w[0]`` up to
        ``trial + sampling_freq*5 + w[1]`` are cut out of ``data``; columns 0,
        1 and 2 (named c3, c4, cz here) are decomposed with ``pywt.wavedec``,
        and entry ``n`` of each coefficient list is summarized.

        Args:
            data: 2-D data indexed as ``data[row, column]``.
            trials: start indexes of the trials.
            level: decomposition level handed to ``pywt.wavedec``.
            sampling_freq: sampling frequency.
            w: ``[start, stop]`` window offsets.
            n: index into the coefficient list returned by ``pywt.wavedec``.
            wavelet: wavelet name handed to ``pywt.wavedec``.

        Returns:
            Array of shape ``(len(trials), 9)`` with the row layout
            std(c3), mean(c3^2), std(c4), mean(c4^2), std(cz), mean(cz^2),
            mean(c3), mean(c4), mean(cz).
        """
        import pywt

        # number of features per trial
        n_features = 9
        X = np.zeros((len(trials), n_features))

        for row, trial in enumerate(trials):
            window_origin = trial + sampling_freq * 5
            segment = data[window_origin + (w[0]):window_origin + (w[1])]

            # one decomposition per column, in the order c3, c4, cz
            decompositions = [
                pywt.wavedec(data=segment[:, column],
                             wavelet=wavelet,
                             level=level) for column in range(3)
            ]
            c3 = decompositions[0][n]
            c4 = decompositions[1][n]
            cz = decompositions[2][n]

            X[row, :] = np.array([
                np.std(c3),
                np.mean(c3**2),
                np.std(c4),
                np.mean(c4**2),
                np.std(cz),
                np.mean(cz**2),
                np.mean(c3),
                np.mean(c4),
                np.mean(cz)
            ])

        return X