### imports needed by this module; NST_EEG_LIVE and NST_EEG_TEST are
### project-local dataset classes and must be imported from wherever they
### are defined in this repository
import os
import sys
from datetime import datetime

import numpy as np
from scipy.signal import decimate

import keras
import keras.utils as ku
from keras.callbacks import CSVLogger, ModelCheckpoint
from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                          Dropout, Flatten, MaxPooling2D)
from keras.models import Sequential, model_from_json

import kapre
from kapre.time_frequency import Spectrogram
from kapre.utils import Normalization2D

import gumpy

DEBUG = False   # set to True for verbose shape printing


class liveEEG_CNN():

    def __init__(self, data_dir, filename_notlive, n_classes=2):
        self.print_version_info()
        self.data_dir = data_dir
        self.cwd = os.getcwd()
        self.n_classes = n_classes
        kwargs = {'n_classes': self.n_classes}

        ### initialise the dataset
        self.data_notlive = NST_EEG_LIVE(self.data_dir, filename_notlive, **kwargs)
        self.data_notlive.load()
        self.data_notlive.print_stats()

        self.MODELNAME = "CNN_STFT"
        self.x_stacked = np.zeros(
            (1, self.data_notlive.sampling_freq * self.data_notlive.trial_total, 3))
        self.y_stacked = np.zeros((1, self.n_classes))

        self.fs = 256
        self.lowcut = 2
        self.highcut = 60
        self.anti_drift = 0.5
        self.f0 = 50.0        # frequency to be removed from the signal (Hz) for the notch filter
        self.Q = 30.0         # quality factor for the notch filter
        self.AXIS = 0
        self.CUTOFF = 50.0
        self.w0 = self.CUTOFF / (self.fs / 2)   # normalized cutoff: w0 = f0 / (fs / 2) = 50 / 128
        self.dropout = 0.5

        ### reduce the sampling frequency to 256 Hz:
        ### most previous data was recorded at 256 Hz, but new data has to be
        ### recorded at 512 Hz due to the combination of EMG and EEG;
        ### hence, the EEG is downsampled by a factor of 2 here
        if self.data_notlive.sampling_freq > self.fs:
            self.data_notlive.raw_data = decimate(
                self.data_notlive.raw_data,
                int(self.data_notlive.sampling_freq / self.fs),
                axis=0, zero_phase=True)
            self.data_notlive.sampling_freq = self.fs
            self.data_notlive.trials = np.floor(self.data_notlive.trials / 2).astype(int)

        ### filter the data
        self.data_notlive_filt = gumpy.signal.notch(
            self.data_notlive.raw_data, self.CUTOFF, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_highpass(
            self.data_notlive_filt, self.anti_drift, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_bandpass(
            self.data_notlive_filt, self.lowcut, self.highcut, self.AXIS)

        ### clip and normalise the data; the normalisation constants are kept
        ### for the live data later on (hence gumpy.signal.normalize cannot be used)
        self.sigma = np.min(np.std(self.data_notlive_filt, axis=0))
        self.data_notlive_clip = np.clip(
            self.data_notlive_filt, self.sigma * (-6), self.sigma * 6)
        self.notlive_mean = np.mean(self.data_notlive_clip, axis=0)
        self.notlive_std_dev = np.std(self.data_notlive_clip, axis=0)
        self.data_notlive_clip = (self.data_notlive_clip - self.notlive_mean) / self.notlive_std_dev

        ### extract the time within the 10 s trials for each class
        # TODO: correct function extract_trials(): trial length & trial offset
        self.class1_mat, self.class2_mat = gumpy.utils.extract_trials_corrJB(
            self.data_notlive, filtered=self.data_notlive_clip)

        ### concatenate the data for training and create the labels
        self.x_train = np.concatenate((self.class1_mat, self.class2_mat))
        self.labels_c1 = np.zeros((self.class1_mat.shape[0], ))
        self.labels_c2 = np.ones((self.class2_mat.shape[0], ))
        self.y_train = np.concatenate((self.labels_c1, self.labels_c2))

        ### categorical crossentropy as the CNN output requires one-hot encoded labels
        self.y_train = ku.to_categorical(self.y_train)

        if DEBUG:
            print("Shape of x_train: ", self.x_train.shape)
            print("Shape of y_train: ", self.y_train.shape)
        print("EEG data loaded and processed successfully!")

        ### roll the axes to match the CNN input shape
        self.x_rolled = np.rollaxis(self.x_train, 2, 1)
        if DEBUG:
            print('X shape: ', self.x_train.shape)
            print('X rolled shape: ', self.x_rolled.shape)

        ### augment the data via sliding windows to obtain more samples for training
        self.x_augmented, self.y_augmented = gumpy.signal.sliding_window(
            data=self.x_train,
            labels=self.y_train,
            window_sz=4 * self.fs,
            n_hop=self.fs // 8,
            n_start=self.fs * 3)

        ### roll the axes to match the CNN input shape
        self.x_augmented_rolled = np.rollaxis(self.x_augmented, 2, 1)
        print("Shape of x_augmented: ", self.x_augmented_rolled.shape)
        print("Shape of y_augmented: ", self.y_augmented.shape)

        ### try to load the .json model file, otherwise build a new model
        self.loaded = 0
        if os.path.isfile(os.path.join(self.cwd, self.MODELNAME + ".json")):
            self.load_CNN_model()
            if self.model:
                self.loaded = 1
        if self.loaded == 0:
            print("Could not load model, will build model.")
            self.build_CNN_model()
            if self.model:
                self.loaded = 1

        ### create callbacks for saving
        saved_model_name = self.MODELNAME
        TMP_NAME = self.MODELNAME + "_" + "_C" + str(self.n_classes)
        for i in range(99):
            if os.path.isfile(saved_model_name + ".csv"):
                saved_model_name = TMP_NAME + "_run{0}".format(i)

        ### save the model architecture to a .json file
        json_string = self.model.to_json()
        model_file = saved_model_name + ".json"
        open(model_file, 'w').write(json_string)

        ### define where to save the weights to
        model_file = saved_model_name + 'monitoring' + '.h5'
        checkpoint = ModelCheckpoint(model_file, monitor='val_loss', verbose=1,
                                     save_best_only=True, mode='min')
        log_file = saved_model_name + '.csv'
        csv_logger = CSVLogger(log_file, append=True, separator=';')
        self.callbacks_list = [csv_logger, checkpoint]   # callback list

    ###########################################################################
    ### train the model with the notlive data or simply load a pretrained model
    def fit(self, load=False):
        # TODO: use the method train_on_batch() to update the model
        self.batch_size = 32
        self.model.compile(loss='categorical_crossentropy',
                           optimizer='adam',
                           metrics=['accuracy'])
        if not load:
            print('Train...')
            self.model.fit(self.x_augmented_rolled, self.y_augmented,
                           batch_size=self.batch_size,
                           epochs=100,
                           shuffle=True,
                           validation_split=0.2,
                           callbacks=self.callbacks_list)
        else:
            print('Load...')
            self.model = keras.models.load_model(
                'CNN_STFTmonitoring.h5',   # e.g. CNN_STFT__C2_run4monitoring.h5
                custom_objects={
                    'Spectrogram': kapre.time_frequency.Spectrogram,
                    'Normalization2D': kapre.utils.Normalization2D
                })

    ###########################################################################
    ### do the live classification
    def classify_live(self, data_live):
        ### perform the same preprocessing steps as in __init__():
        ### again, downsampling from 512 Hz to 256 Hz (see above)
        if data_live.sampling_freq > self.fs:
            data_live.raw_data = decimate(
                data_live.raw_data,
                int(data_live.sampling_freq / self.fs),
                axis=0, zero_phase=True)
            data_live.sampling_freq = self.fs

        self.y_live = data_live.labels
        self.data_live_filt = gumpy.signal.notch(
            data_live.raw_data, self.CUTOFF, self.AXIS)
        self.data_live_filt = gumpy.signal.butter_highpass(
            self.data_live_filt, self.anti_drift, self.AXIS)
        self.data_live_filt = gumpy.signal.butter_bandpass(
            self.data_live_filt, self.lowcut, self.highcut, self.AXIS)

        ### clip and normalise with the constants obtained from the notlive data
        self.data_live_clip = np.clip(
            self.data_live_filt, self.sigma * (-6), self.sigma * 6)
        self.data_live_clip = (self.data_live_clip - self.notlive_mean) / self.notlive_std_dev

        class1_mat, class2_mat = gumpy.utils.extract_trials_corrJB(
            data_live, filtered=self.data_live_clip)

        ### concatenate the data, create the labels and keep only the
        ### motor-imagery interval of each trial
        self.x_live = np.concatenate((class1_mat, class2_mat))
        self.x_live = self.x_live[:, data_live.mi_interval[0] * data_live.sampling_freq:
                                  data_live.mi_interval[1] * data_live.sampling_freq, :]
        self.x_live = np.rollaxis(self.x_live, 2, 1)

        ### do the prediction
        pred_valid = 0
        y_pred = []
        pred_true = []
        if self.loaded and self.x_live.any():
            y_pred = self.model.predict(self.x_live, batch_size=64)
            print(y_pred)
            ### alternative: sum the predicted class probabilities over all
            ### windows and pick the class with the larger total
            # classes = self.model.predict(self.x_live_augmented, batch_size=64)
            # pref0, pref1 = sum(classes[:, 0]), sum(classes[:, 1])
            # y_pred = 1 if pref1 > pref0 else 0
            ### argmax because the output is a softmax over the classes
            y_pred = y_pred.argmax()
            pred_true = self.y_live == y_pred
            print('Real=', self.y_live)
            pred_valid = 1
        return y_pred, pred_true, pred_valid

    ###########################################################################
    def load_CNN_model(self):
        print('Load model', self.MODELNAME)
        model_path = self.MODELNAME + ".json"
        if not os.path.isfile(model_path):
            raise IOError('file "%s" does not exist' % (model_path))
        self.model = model_from_json(
            open(model_path).read(),
            custom_objects={
                'Spectrogram': kapre.time_frequency.Spectrogram,
                'Normalization2D': kapre.utils.Normalization2D
            })
        # TODO: also restore the trained weights here (not urgently required)

    ###########################################################################
    def build_CNN_model(self):
        ### define the CNN architecture
        print('Build model...')
        self.model = Sequential()

        ### a static (non-trainable) STFT layer turns the raw time series
        ### into a spectrogram that the conv blocks operate on
        self.model.add(
            Spectrogram(n_dft=128, n_hop=16,
                        input_shape=self.x_augmented_rolled.shape[1:],
                        return_decibel_spectrogram=False,
                        power_spectrogram=2.0,
                        trainable_kernel=False,
                        name='static_stft'))
        self.model.add(Normalization2D(str_axis='freq'))

        # Conv block 1
        self.model.add(Conv2D(filters=24, kernel_size=(12, 12),
                              strides=(1, 1), name='conv1', padding='same'))
        self.model.add(BatchNormalization(axis=1))
        self.model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2),
                                    padding='valid', data_format='channels_last'))
        self.model.add(Activation('relu'))
        self.model.add(Dropout(self.dropout))

        # Conv block 2
        self.model.add(Conv2D(filters=48, kernel_size=(8, 8),
                              name='conv2', padding='same'))
        self.model.add(BatchNormalization(axis=1))
        self.model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2),
                                    padding='valid', data_format='channels_last'))
        self.model.add(Activation('relu'))
        self.model.add(Dropout(self.dropout))

        # Conv block 3
        self.model.add(Conv2D(filters=96, kernel_size=(4, 4),
                              name='conv3', padding='same'))
        self.model.add(BatchNormalization(axis=1))
        self.model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2),
                                    padding='valid', data_format='channels_last'))
        self.model.add(Activation('relu'))
        self.model.add(Dropout(self.dropout))

        # classifier
        self.model.add(Flatten())
        self.model.add(Dense(self.n_classes))   # two classes by default
        self.model.add(Activation('softmax'))

        print(self.model.summary())
        self.saved_model_name = self.MODELNAME

    ###########################################################################
    def print_version_info(self):
        now = datetime.now()
        print('%s/%s/%s' % (now.year, now.month, now.day))
        print('Keras version: {}'.format(keras.__version__))
        if keras.backend._BACKEND == 'tensorflow':
            import tensorflow
            print('Keras backend: {}: {}'.format(
                keras.backend._backend, tensorflow.__version__))
        else:
            import theano
            print('Keras backend: {}: {}'.format(
                keras.backend._backend, theano.__version__))
        print('Keras image dim ordering: {}'.format(
            keras.backend.image_dim_ordering()))
        print('Kapre version: {}'.format(kapre.__version__))
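
###############################################################################
### Minimal usage sketch for liveEEG_CNN (an illustration, not part of the
### original module): construct with a recorded "notlive" .mat file, fit
### (or load) the CNN, then classify freshly recorded live data. The paths,
### file names and the get_live_data() helper below are hypothetical
### placeholders.
#
# cnn = liveEEG_CNN(data_dir='/path/to/recordings',
#                   filename_notlive='session01_notlive.mat')
# cnn.fit(load=False)              # train; load=True restores the .h5 checkpoint
# data_live = get_live_data()      # hypothetical: returns an NST_EEG_LIVE object
# y_pred, pred_true, pred_valid = cnn.classify_live(data_live)
# if pred_valid:
#     print('Predicted class:', y_pred, 'correct:', pred_true)
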
class liveEEG():

    def __init__(self, cwd, filename_notlive, flag=0):
        ### cwd and filename_notlive specify the location of one .mat file
        ### where the recorded notlive data has been stored
        self.flag = flag

        ### flag == 1 is only used for testing
        if self.flag == 1:
            self.data_notlive = NST_EEG_TEST(cwd, filename_notlive)
            self.data_notlive.load()
            self.data_notlive.print_stats()

        ### load the notlive data from the path that has been specified
        if self.flag == 0:
            self.data_notlive = NST_EEG_LIVE(cwd, filename_notlive)
            self.data_notlive.load()
            self.data_notlive.print_stats()

        ### re-reference: subtract channel 2 (Cz) from channels 0 and 1 (C3, C4)
        self.data_notlive.raw_data[:, 0] -= self.data_notlive.raw_data[:, 2]
        self.data_notlive.raw_data[:, 1] -= self.data_notlive.raw_data[:, 2]

        self.labels_notlive = self.data_notlive.labels

        ### butter-bandpass filtered version of the notlive data
        self.data_notlive_filtbp = gumpy.signal.butter_bandpass(
            self.data_notlive, lo=2, hi=60)

        ### apply the notch filter
        self.notch_f0 = 50     # frequency to be removed from the signal
        self.notch_Q = 50.0    # quality factor
        # the normalized cutoff w0 = f0 / (fs / 2) is derived internally from fs
        self.data_notlive_filtno = gumpy.signal.notch(
            self.data_notlive_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=self.data_notlive.sampling_freq)

        self.alpha_bands = np.array(self.alpha_subBP_features(self.data_notlive_filtno))
        self.beta_bands = np.array(self.beta_subBP_features(self.data_notlive_filtno))

        ### feature extraction using sub-bands
        ### Method 1: logarithmic sub-band power
        self.w1 = [0, 125]
        self.w2 = [125, 250]
        self.features1 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands,
            self.data_notlive.trials, self.data_notlive.sampling_freq, self.w1)
        self.features2 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands,
            self.data_notlive.trials, self.data_notlive.sampling_freq, self.w2)

        ### concatenate the features and normalize the data
        self.features_notlive = np.concatenate((self.features1.T, self.features2.T)).T
        self.features_notlive -= np.mean(self.features_notlive)
        self.features_notlive = gumpy.signal.normalize(self.features_notlive, 'min_max')

        ### Method 2: DWT (disabled by default); works with the data that was
        ### postprocessed using the butter bandpass
        if False:
            self.w = [0, 256]
            self.trials = self.data_notlive.trials
            self.fs = self.data_notlive.sampling_freq
            self.features1 = np.array(self.dwt_features(
                self.data_notlive_filtno, self.trials, 5, self.fs, self.w, 3, "db4"))
            self.features2 = np.array(self.dwt_features(
                self.data_notlive_filtno, self.trials, 5, self.fs, self.w, 4, "db4"))
            # concatenate the features and normalize
            self.features_notlive = np.concatenate((self.features1.T, self.features2.T)).T
            self.features_notlive -= np.mean(self.features_notlive)
            self.features_notlive = gumpy.signal.normalize(self.features_notlive, 'min_max')

        self.pos_fit = False

    def fit(self):
        ### Sequential Feature Selection algorithm
        out_realtime = gumpy.features.sequential_feature_selector_realtime(
            self.features_notlive, self.labels_notlive, 'SVM', 1, 2, 'SFFS')
        print('\n\nAverage score:', out_realtime[1] * 100)
        self.sfs_object = out_realtime[3]
        self.estimator_object = self.sfs_object.est_

        ### fit the estimator with the selected features and test it on the notlive data
        self.estimator_object.fit(
            self.sfs_object.transform(self.features_notlive),
            self.labels_notlive)
        self.labels_pred_notlive = self.estimator_object.predict(
            self.sfs_object.transform(self.features_notlive))
        self.acc_notlive = 100 * np.sum(
            np.abs(self.labels_pred_notlive - self.labels_notlive) < 1) \
            / self.labels_pred_notlive.shape[0]
        print('\nAccuracy of notlive fit:', self.acc_notlive, '\n')
        self.pos_fit = True

    def classify_live(self, data_live):
        ### data_live should be an object of class NST_EEG_LIVE
        self.labels_live = data_live.labels

        ### butter-bandpass filtered version of the live data
        self.data_live_filtbp = gumpy.signal.butter_bandpass(data_live, lo=2, hi=60)
        ### apply the notch filter
        self.data_live_filtno = gumpy.signal.notch(
            self.data_live_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=self.data_notlive.sampling_freq)

        self.alpha_bands = np.array(self.alpha_subBP_features(self.data_live_filtno))
        self.beta_bands = np.array(self.beta_subBP_features(self.data_live_filtno))

        ### feature extraction using sub-bands
        ### Method 1: logarithmic sub-band power
        self.w1 = [0, 125]
        self.w2 = [125, 250]
        self.features1_live = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands,
            data_live.trials, data_live.sampling_freq, self.w1)
        self.features2_live = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands,
            data_live.trials, data_live.sampling_freq, self.w2)

        ### concatenate the features and normalize the data
        self.features_live = np.concatenate((self.features1_live.T, self.features2_live.T)).T
        self.features_live -= np.mean(self.features_live)
        self.features_live = gumpy.signal.normalize(self.features_live, 'min_max')

        ### Method 2: DWT (disabled by default, see __init__)
        if False:
            self.w = [0, 256]
            self.trials_live = data_live.trials
            self.fs_live = data_live.sampling_freq
            self.features1_live = np.array(self.dwt_features(
                self.data_live_filtno, self.trials_live, 5, self.fs_live, self.w, 3, "db4"))
            self.features2_live = np.array(self.dwt_features(
                self.data_live_filtno, self.trials_live, 5, self.fs_live, self.w, 4, "db4"))
            # concatenate the features and normalize
            self.features_live = np.concatenate((self.features1_live.T, self.features2_live.T)).T
            self.features_live -= np.mean(self.features_live)
            self.features_live = gumpy.signal.normalize(self.features_live, 'min_max')

        ### predict the label of the live trial and check whether it is correct
        if self.pos_fit:
            labels_pred_live = self.estimator_object.predict(
                self.sfs_object.transform(self.features_live))
            # a prediction counts as correct when it matches the true label
            pred_true = self.labels_live == labels_pred_live
        else:
            print("No fit was performed yet. Please train the model first.\n")
            sys.exit()
        return labels_pred_live, pred_true

    ### alpha and beta sub-bands
    def alpha_subBP_features(self, data):
        ### filter the data into sub-bands by specifying low- and high-cut frequencies
        alpha1 = gumpy.signal.butter_bandpass(data, 8.5, 11.5, order=6)
        alpha2 = gumpy.signal.butter_bandpass(data, 9.0, 12.5, order=6)
        alpha3 = gumpy.signal.butter_bandpass(data, 9.5, 11.5, order=6)
        alpha4 = gumpy.signal.butter_bandpass(data, 8.0, 10.5, order=6)
        ### return a list of sub-bands
        return [alpha1, alpha2, alpha3, alpha4]

    def beta_subBP_features(self, data):
        beta1 = gumpy.signal.butter_bandpass(data, 14.0, 30.0, order=6)
        beta2 = gumpy.signal.butter_bandpass(data, 16.0, 17.0, order=6)
        beta3 = gumpy.signal.butter_bandpass(data, 17.0, 18.0, order=6)
        beta4 = gumpy.signal.butter_bandpass(data, 18.0, 19.0, order=6)
        return [beta1, beta2, beta3, beta4]

    ### feature extraction using sub-bands (the following methods show how
    ### the sub-bands can be used to extract features)
    ### Method 1: logarithmic sub-band power
    def powermean(self, data, trial, fs, w):
        ### mean power of channels C3, C4 and Cz within the window w,
        ### offset by 5 s into the trial
        seg = data[trial + fs * 5 + w[0]:trial + fs * 5 + w[1]]
        return (np.power(seg[:, 0], 2).mean(),
                np.power(seg[:, 1], 2).mean(),
                np.power(seg[:, 2], 2).mean())

    def log_subBP_feature_extraction(self, alpha, beta, trials, fs, w):
        ### number of features combined for all trials
        n_features = 15
        ### initialize the feature matrix
        X = np.zeros((len(trials), n_features))

        ### extract the features: log band power of the four alpha sub-bands
        ### and the first beta sub-band, per channel (C3, C4, Cz)
        for t, trial in enumerate(trials):
            power_c31, power_c41, power_cz1 = self.powermean(alpha[0], trial, fs, w)
            power_c32, power_c42, power_cz2 = self.powermean(alpha[1], trial, fs, w)
            power_c33, power_c43, power_cz3 = self.powermean(alpha[2], trial, fs, w)
            power_c34, power_c44, power_cz4 = self.powermean(alpha[3], trial, fs, w)
            power_c31_b, power_c41_b, power_cz1_b = self.powermean(beta[0], trial, fs, w)
            X[t, :] = np.array([
                np.log(power_c31), np.log(power_c41), np.log(power_cz1),
                np.log(power_c32), np.log(power_c42), np.log(power_cz2),
                np.log(power_c33), np.log(power_c43), np.log(power_cz3),
                np.log(power_c34), np.log(power_c44), np.log(power_cz4),
                np.log(power_c31_b), np.log(power_c41_b), np.log(power_cz1_b)
            ])
        return X

    ### Method 2: DWT
    def dwt_features(self, data, trials, level, sampling_freq, w, n, wavelet):
        import pywt

        ### number of features per trial
        n_features = 9
        ### allocate memory to store the features
        X = np.zeros((len(trials), n_features))

        ### extract the features: std, mean power and mean of the n-th
        ### wavelet decomposition level, per channel (C3, C4, Cz)
        for t, trial in enumerate(trials):
            signals = data[trial + sampling_freq * 5 + w[0]:
                           trial + sampling_freq * 5 + w[1]]
            coeffs_c3 = pywt.wavedec(data=signals[:, 0], wavelet=wavelet, level=level)
            coeffs_c4 = pywt.wavedec(data=signals[:, 1], wavelet=wavelet, level=level)
            coeffs_cz = pywt.wavedec(data=signals[:, 2], wavelet=wavelet, level=level)
            X[t, :] = np.array([
                np.std(coeffs_c3[n]), np.mean(coeffs_c3[n] ** 2),
                np.std(coeffs_c4[n]), np.mean(coeffs_c4[n] ** 2),
                np.std(coeffs_cz[n]), np.mean(coeffs_cz[n] ** 2),
                np.mean(coeffs_c3[n]), np.mean(coeffs_c4[n]), np.mean(coeffs_cz[n])
            ])
        return X
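
###############################################################################
### Minimal usage sketch for liveEEG (an illustration, not part of the
### original module): same call order as for liveEEG_CNN — construct with the
### notlive recording, fit the SVM on the SFFS-selected sub-band power
### features, then classify live trials. Paths, file names and get_live_data()
### are hypothetical placeholders.
#
# clf = liveEEG(cwd='/path/to/recordings', filename_notlive='session01_notlive.mat')
# clf.fit()                        # SFFS feature selection + SVM fit on notlive data
# data_live = get_live_data()      # hypothetical: returns an NST_EEG_LIVE object
# labels_pred_live, pred_true = clf.classify_live(data_live)
# print('Predicted:', labels_pred_live, 'correct:', pred_true)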