# Example #1
0
class liveEEG_CNN():
    def __init__(self, data_dir, filename_notlive, n_classes=2):
        """Load and preprocess the recorded ("notlive") EEG data and set up the CNN.

        The constructor loads the dataset, downsamples it to ``self.fs`` when
        it was recorded at a higher rate, filters, clips and standardises it,
        extracts the per-class trials, augments them with a sliding window and
        finally loads or builds the Keras model and its training callbacks.

        Args:
            data_dir: Directory handed to ``NST_EEG_LIVE``.
            filename_notlive: File name handed to ``NST_EEG_LIVE``.
            n_classes: Number of classes; forwarded to the dataset and used as
                the width of the output layer in ``build_CNN_model``.
        """
        self.print_version_info()

        self.data_dir = data_dir
        self.cwd = os.getcwd()
        self.n_classes = n_classes
        kwargs = {'n_classes': self.n_classes}

        ### initialise dataset
        self.data_notlive = NST_EEG_LIVE(self.data_dir, filename_notlive,
                                         **kwargs)
        self.data_notlive.load()
        self.data_notlive.print_stats()

        self.MODELNAME = "CNN_STFT"

        self.x_stacked = np.zeros((1, self.data_notlive.sampling_freq *
                                   self.data_notlive.trial_total, 3))
        self.y_stacked = np.zeros((1, self.n_classes))

        self.fs = 256
        self.lowcut = 2
        self.highcut = 60
        self.anti_drift = 0.5
        self.f0 = 50.0  # freq to be removed from signal (Hz) for notch filter
        self.Q = 30.0  # quality factor for notch filter
        # w0 = f0 / (fs / 2)
        self.AXIS = 0
        self.CUTOFF = 50.0
        self.w0 = self.CUTOFF / (self.fs / 2)
        self.dropout = 0.5

        ### reduce sampling frequency to 256
        ### most previous data is at 256 Hz, but now it has to be recorded at
        ### 512 Hz due to the combination of EMG and EEG
        if self.data_notlive.sampling_freq > self.fs:
            # Compute the factor once, before sampling_freq is overwritten,
            # and use it for the signal AND the trial onsets so both stay
            # aligned for any integer ratio (the onsets used a fixed 2 before).
            factor = int(self.data_notlive.sampling_freq / self.fs)
            self.data_notlive.raw_data = decimate(
                self.data_notlive.raw_data,
                factor,
                axis=0,
                zero_phase=True)
            self.data_notlive.sampling_freq = self.fs
            self.data_notlive.trials = np.floor(self.data_notlive.trials /
                                                factor).astype(int)

        ### filter the data
        self.data_notlive_filt = gumpy.signal.notch(self.data_notlive.raw_data,
                                                    self.CUTOFF, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_highpass(
            self.data_notlive_filt, self.anti_drift, self.AXIS)
        self.data_notlive_filt = gumpy.signal.butter_bandpass(
            self.data_notlive_filt, self.lowcut, self.highcut, self.AXIS)

        ### clip and normalise the data
        ### keep normalisation constants for later on (hence no use of gumpy possible)
        # sigma is the smallest per-column standard deviation; values are
        # clipped to +/- 6 * sigma before the per-column standardisation.
        self.sigma = np.min(np.std(self.data_notlive_filt, axis=0))
        self.data_notlive_clip = np.clip(self.data_notlive_filt,
                                         self.sigma * (-6), self.sigma * 6)

        self.notlive_mean = np.mean(self.data_notlive_clip, axis=0)
        self.notlive_std_dev = np.std(self.data_notlive_clip, axis=0)
        self.data_notlive_clip = (self.data_notlive_clip -
                                  self.notlive_mean) / self.notlive_std_dev

        ### extract the time within the trials of 10s for each class
        #TODO: correct function extract_trials() trial len & trial offset
        self.class1_mat, self.class2_mat = gumpy.utils.extract_trials_corrJB(
            self.data_notlive, filtered=self.data_notlive_clip)

        ### concatenate data for training and create labels
        self.x_train = np.concatenate((self.class1_mat, self.class2_mat))
        self.labels_c1 = np.zeros((self.class1_mat.shape[0], ))
        self.labels_c2 = np.ones((self.class2_mat.shape[0], ))
        self.y_train = np.concatenate((self.labels_c1, self.labels_c2))

        ### for categorical crossentropy as an output of the CNN, another format of y is required
        self.y_train = ku.to_categorical(self.y_train)

        if DEBUG:
            print("Shape of x_train: ", self.x_train.shape)
            print("Shape of y_train: ", self.y_train.shape)

        print("EEG Data loaded and processed successfully!")

        ### roll shape to match to the CNN
        self.x_rolled = np.rollaxis(self.x_train, 2, 1)

        if DEBUG:
            print('X shape: ', self.x_train.shape)
            print('X rolled shape: ', self.x_rolled.shape)

        ### augment data to have more samples for training
        self.x_augmented, self.y_augmented = gumpy.signal.sliding_window(
            data=self.x_train,
            labels=self.y_train,
            window_sz=4 * self.fs,
            n_hop=self.fs // 8,
            n_start=self.fs * 3)

        ### roll shape to match to the CNN
        self.x_augmented_rolled = np.rollaxis(self.x_augmented, 2, 1)
        print("Shape of x_augmented: ", self.x_augmented_rolled.shape)
        print("Shape of y_augmented: ", self.y_augmented.shape)

        ### try to load the .json model file, otherwise build a new model
        self.loaded = 0
        if os.path.isfile(os.path.join(self.cwd, self.MODELNAME + ".json")):
            self.load_CNN_model()
            if self.model:
                self.loaded = 1

        if self.loaded == 0:
            print("Could not load model, will build model.")
            self.build_CNN_model()
            if self.model:
                self.loaded = 1

        ### Create callbacks for saving
        # Walk through the run names until one is found whose .csv log does
        # not exist yet; that name is used for all files written below.
        saved_model_name = self.MODELNAME
        TMP_NAME = self.MODELNAME + "_" + "_C" + str(self.n_classes)
        for i in range(99):
            if os.path.isfile(saved_model_name + ".csv"):
                saved_model_name = TMP_NAME + "_run{0}".format(i)

        ### Save model -> json file
        json_string = self.model.to_json()
        model_file = saved_model_name + ".json"
        # context manager guarantees the file is flushed and closed
        with open(model_file, 'w') as json_file:
            json_file.write(json_string)

        ### define where to save the parameters to
        model_file = saved_model_name + 'monitoring' + '.h5'
        checkpoint = ModelCheckpoint(model_file,
                                     monitor='val_loss',
                                     verbose=1,
                                     save_best_only=True,
                                     mode='min')
        log_file = saved_model_name + '.csv'
        csv_logger = CSVLogger(log_file, append=True, separator=';')
        self.callbacks_list = [csv_logger, checkpoint]  # callback list

###############################################################################
### train the model with the notlive data or simply load a pretrained model

    def fit(self, load=False):
        """Compile the model, then either train it or load stored weights.

        Args:
            load: When falsy, the model is trained on the augmented notlive
                data for 100 epochs with a 20 % validation split. When truthy,
                the model is replaced by the one stored in
                ``CNN_STFTmonitoring.h5``.
        """
        #TODO: use method train_on_batch() to update model
        self.batch_size = 32
        self.model.compile(loss='categorical_crossentropy',
                           optimizer='adam',
                           metrics=['accuracy'])

        if load:
            print('Load...')
            # the kapre layers are not known to Keras and have to be
            # supplied explicitly when deserialising the model
            self.model = keras.models.load_model(
                'CNN_STFTmonitoring.h5',
                custom_objects={
                    'Spectrogram': kapre.time_frequency.Spectrogram,
                    'Normalization2D': kapre.utils.Normalization2D
                })
            return

        print('Train...')
        self.model.fit(self.x_augmented_rolled,
                       self.y_augmented,
                       batch_size=self.batch_size,
                       epochs=100,
                       shuffle=True,
                       validation_split=0.2,
                       callbacks=self.callbacks_list)

        #CNN_STFT__C2_run4monitoring.h5

###############################################################################
### do the live classification

    def classify_live(self, data_live):
        """Preprocess a live recording like the training data and classify it.

        ``data_live`` is modified in place when it has to be downsampled
        (``raw_data``, ``sampling_freq`` and ``trials`` are overwritten).

        Args:
            data_live: Dataset object providing ``raw_data``,
                ``sampling_freq``, ``trials``, ``labels`` and ``mi_interval``.

        Returns:
            Tuple ``(y_pred, pred_true, pred_valid)``. When a prediction was
            made, ``y_pred`` is the argmax over the flattened network output,
            ``pred_true`` is the comparison of the live labels with ``y_pred``
            and ``pred_valid`` is 1. Otherwise ``([], [], 0)`` is returned.
        """
        ### perform the same preprocessing steps as in __init__()

        ### again, downsampling from 512 to 256 (see __init__)
        if data_live.sampling_freq > self.fs:
            # The factor has to come from the live recording itself:
            # __init__ sets self.data_notlive.sampling_freq to self.fs after
            # decimation, so deriving the factor from it yields 1 and the
            # live signal would not be downsampled at all.
            factor = int(data_live.sampling_freq / self.fs)
            data_live.raw_data = decimate(data_live.raw_data,
                                          factor,
                                          axis=0,
                                          zero_phase=True)
            data_live.sampling_freq = self.fs
            # keep the trial onsets aligned with the decimated signal,
            # exactly as done for the notlive data in __init__
            data_live.trials = np.floor(
                np.asarray(data_live.trials) / factor).astype(int)

        self.y_live = data_live.labels

        self.data_live_filt = gumpy.signal.notch(data_live, self.CUTOFF,
                                                 self.AXIS)
        self.data_live_filt = gumpy.signal.butter_highpass(
            self.data_live_filt, self.anti_drift, self.AXIS)
        self.data_live_filt = gumpy.signal.butter_bandpass(
            self.data_live_filt, self.lowcut, self.highcut, self.AXIS)

        # reuse the clipping bound and normalisation constants computed on
        # the notlive data so that both inputs share one scale
        self.data_live_clip = np.clip(self.data_live_filt, self.sigma * (-6),
                                      self.sigma * 6)
        self.data_live_clip = (self.data_live_clip -
                               self.notlive_mean) / self.notlive_std_dev

        class1_mat, class2_mat = gumpy.utils.extract_trials_corrJB(
            data_live, filtered=self.data_live_clip)

        ### concatenate data and create labels
        self.x_live = np.concatenate((class1_mat, class2_mat))

        # keep only the samples inside data_live.mi_interval (given in
        # seconds, hence the multiplication with the sampling frequency)
        start = data_live.mi_interval[0] * data_live.sampling_freq
        stop = data_live.mi_interval[1] * data_live.sampling_freq
        self.x_live = self.x_live[:, start:stop, :]

        self.x_live = np.rollaxis(self.x_live, 2, 1)

        ### do the prediction
        pred_valid = 0
        y_pred = []
        pred_true = []
        if self.loaded and self.x_live.any():
            y_pred = self.model.predict(self.x_live, batch_size=64)
            print(y_pred)

            ### argmax because output is crossentropy
            # NOTE(review): argmax() without an axis works on the flattened
            # output, which equals the class index only for a single trial.
            y_pred = y_pred.argmax()
            pred_true = self.y_live == y_pred
            print('Real=', self.y_live)
            pred_valid = 1

        return y_pred, pred_true, pred_valid

###############################################################################

    def load_CNN_model(self):
        print('Load model', self.MODELNAME)
        model_path = self.MODELNAME + ".json"
        if not os.path.isfile(model_path):
            raise IOError('file "%s" does not exist' % (model_path))
        self.model = model_from_json(open(model_path).read(),
                                     custom_objects={
                                         'Spectrogram':
                                         kapre.time_frequency.Spectrogram,
                                         'Normalization2D':
                                         kapre.utils.Normalization2D
                                     })
        #self.model = load_model(self.cwd,self.MODELNAME,self.MODELNAME+'monitoring')
        #TODO: get it to work, but not urgently required
        #self.model = []

###############################################################################

    def build_CNN_model(self):
        """Build the spectrogram-based CNN and store it in ``self.model``.

        The network consists of a static STFT layer, a frequency-wise
        normalisation, three convolution blocks (convolution, batch
        normalisation, max pooling, ReLU, dropout) and a softmax classifier
        with ``self.n_classes`` outputs. The input shape is taken from
        ``self.x_augmented_rolled``.
        """
        ### define CNN architecture
        print('Build model...')
        self.model = Sequential()
        self.model.add(
            Spectrogram(n_dft=128,
                        n_hop=16,
                        input_shape=(self.x_augmented_rolled.shape[1:]),
                        return_decibel_spectrogram=False,
                        power_spectrogram=2.0,
                        trainable_kernel=False,
                        name='static_stft'))
        self.model.add(Normalization2D(str_axis='freq'))

        # (filters, kernel_size, layer name) of the three convolution blocks
        conv_blocks = (
            (24, (12, 12), 'conv1'),
            (48, (8, 8), 'conv2'),
            (96, (4, 4), 'conv3'),
        )
        for filters, kernel_size, name in conv_blocks:
            # `padding` is the Keras 2 name of the former `border_mode`
            # argument and matches the other Keras 2 keywords used here.
            self.model.add(
                Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       strides=(1, 1),
                       name=name,
                       padding='same'))
            self.model.add(BatchNormalization(axis=1))
            self.model.add(
                MaxPooling2D(pool_size=(2, 2),
                             strides=(2, 2),
                             padding='valid',
                             data_format='channels_last'))
            self.model.add(Activation('relu'))
            self.model.add(Dropout(self.dropout))

        # classificator
        self.model.add(Flatten())
        self.model.add(Dense(self.n_classes))
        self.model.add(Activation('softmax'))

        # summary() prints by itself and returns None, so it is not wrapped
        # in print() (which would emit an extra "None" line)
        self.model.summary()
        self.saved_model_name = self.MODELNAME

###############################################################################

    def print_version_info(self):
        """Print the current date and the versions of Keras, its backend and kapre."""
        now = datetime.now()

        print('%s/%s/%s' % (now.year, now.month, now.day))
        print('Keras version: {}'.format(keras.__version__))
        # Use the public accessor instead of the private module attributes
        # `_BACKEND` / `_backend`, which were mixed here before.
        backend_name = keras.backend.backend()
        if backend_name == 'tensorflow':
            import tensorflow as backend_module
        else:
            import theano as backend_module
        print('Keras backend: {}: {}'.format(backend_name,
                                             backend_module.__version__))
        print('Keras image dim ordering: {}'.format(
            keras.backend.image_dim_ordering()))
        print('Kapre version: {}'.format(kapre.__version__))
class liveEEG():
    def __init__(self, cwd, filename_notlive, flag=0):
        """Load the recorded ("notlive") data and compute its feature matrix.

        The data is re-referenced, band-pass and notch filtered, split into
        alpha and beta sub-bands and converted into logarithmic sub-band power
        features for two consecutive windows. The normalised features are
        stored in ``self.features_notlive`` and the labels in
        ``self.labels_notlive``.

        Args:
            cwd: Directory of the .mat file with the recorded notlive data.
            filename_notlive: Name of that file.
            flag: 0 loads the data with ``NST_EEG_LIVE``; 1 (only used for
                testing) loads it with ``NST_EEG_TEST``.

        Raises:
            ValueError: If ``flag`` is neither 0 nor 1.
        """
        self.flag = flag
        if self.flag == 1:
            # only used for testing
            self.data_notlive = NST_EEG_TEST(cwd, filename_notlive)
        elif self.flag == 0:
            # load notlive data from the path that has been specified
            self.data_notlive = NST_EEG_LIVE(cwd, filename_notlive)
        else:
            # fail early instead of hitting a missing attribute further down
            raise ValueError('flag must be 0 or 1, got %r' % (flag, ))
        self.data_notlive.load()
        self.data_notlive.print_stats()

        # re-reference: subtract column 2 from columns 0 and 1 (the columns
        # are named C3, C4 and Cz in log_subBP_feature_extraction)
        self.data_notlive.raw_data[:, 0] -= self.data_notlive.raw_data[:, 2]
        self.data_notlive.raw_data[:, 1] -= self.data_notlive.raw_data[:, 2]

        self.labels_notlive = self.data_notlive.labels

        # butter-bandpass filtered version of notlive data
        self.data_notlive_filtbp = gumpy.signal.butter_bandpass(
            self.data_notlive, lo=2, hi=60)

        # frequency to be removed from the signal
        self.notch_f0 = 50
        # quality factor
        self.notch_Q = 50.0
        # apply the notch filter
        self.data_notlive_filtno = gumpy.signal.notch(
            self.data_notlive_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=self.data_notlive.sampling_freq)

        self.alpha_bands = np.array(
            self.alpha_subBP_features(self.data_notlive_filtno))
        self.beta_bands = np.array(
            self.beta_subBP_features(self.data_notlive_filtno))

        # Feature extraction using sub-bands
        # Method 1: logarithmic sub-band power
        # w1 / w2 are sample offsets relative to trial onset + 5 * fs
        # (see powermean)
        self.w1 = [0, 125]
        self.w2 = [125, 250]

        self.features1 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, self.data_notlive.trials,
            self.data_notlive.sampling_freq, self.w1)

        self.features2 = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, self.data_notlive.trials,
            self.data_notlive.sampling_freq, self.w2)

        # concatenate the features and normalize the data
        self.features_notlive = np.concatenate(
            (self.features1.T, self.features2.T)).T
        self.features_notlive -= np.mean(self.features_notlive)
        self.features_notlive = gumpy.signal.normalize(self.features_notlive,
                                                       'min_max')

        # Method 2 (DWT) is not used by this pipeline. It was previously kept
        # behind an `if False:` guard and called
        #   self.dwt_features(self.data_notlive_filtno, trials, 5, fs,
        #                     [0, 256], n, "db4")  with n = 3 and n = 4
        # followed by the same concatenation and normalisation as above.

        self.pos_fit = False

    def fit(self):
        """Select features, train the estimator and report the training accuracy.

        Stores the feature selector in ``self.sfs_object``, its estimator in
        ``self.estimator_object``, the predictions on the notlive data in
        ``self.labels_pred_notlive`` and the accuracy (in percent) in
        ``self.acc_notlive``. Sets ``self.pos_fit`` to True when done.
        """
        # Sequential Feature Selection Algorithm
        selection = gumpy.features.sequential_feature_selector_realtime(
            self.features_notlive, self.labels_notlive, 'SVM', 1, 2, 'SFFS')
        print('\n\nAverage score:', selection[1] * 100)
        self.sfs_object = selection[3]
        self.estimator_object = self.sfs_object.est_

        ### fit the estimator object with the selected features and test
        train_features = self.sfs_object.transform(self.features_notlive)
        self.estimator_object.fit(train_features, self.labels_notlive)

        check_features = self.sfs_object.transform(self.features_notlive)
        self.labels_pred_notlive = self.estimator_object.predict(
            check_features)

        # a prediction counts as correct when it differs from the label by
        # less than 1; dividing by the shape tuple yields an array
        n_correct = np.sum(
            abs(self.labels_pred_notlive - self.labels_notlive) < 1)
        self.acc_notlive = 100 * n_correct / np.shape(
            self.labels_pred_notlive)
        print('\nAccuracy of notlive fit:', self.acc_notlive[0], '\n')
        self.pos_fit = True

    def classify_live(self, data_live):
        """Extract features from a live recording and predict its label(s).

        Args:
            data_live: Dataset object (should be of class ``NST_EEG_LIVE``)
                providing ``labels``, ``trials`` and ``sampling_freq``.

        Returns:
            Tuple ``(labels_pred_live, pred_true)`` with the predicted labels
            and the result of ``(labels_live - labels_pred_live) == 1``.

        Note:
            If ``fit()`` has not been called before, a message is printed and
            the process is terminated via ``sys.exit()``.
        """
        self.labels_live = data_live.labels

        # butter-bandpass filtered version of live data
        self.data_live_filtbp = gumpy.signal.butter_bandpass(data_live,
                                                             lo=2,
                                                             hi=60)

        # apply the notch filter; the sampling rate has to be the one of the
        # live recording (as used for the feature windows below), not the one
        # of the notlive data
        self.data_live_filtno = gumpy.signal.notch(
            self.data_live_filtbp,
            cutoff=self.notch_f0,
            Q=self.notch_Q,
            fs=data_live.sampling_freq)

        # NOTE(review): unlike __init__, columns 0 and 1 are not re-referenced
        # to column 2 here - confirm whether the live data arrives
        # re-referenced already.
        self.alpha_bands = np.array(
            self.alpha_subBP_features(self.data_live_filtno))
        self.beta_bands = np.array(
            self.beta_subBP_features(self.data_live_filtno))

        # Feature extraction using sub-bands
        # Method 1: logarithmic sub-band power
        self.w1 = [0, 125]
        self.w2 = [125, 250]

        self.features1_live = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, data_live.trials,
            data_live.sampling_freq, self.w1)

        self.features2_live = self.log_subBP_feature_extraction(
            self.alpha_bands, self.beta_bands, data_live.trials,
            data_live.sampling_freq, self.w2)

        # concatenate the features and normalize the data
        self.features_live = np.concatenate(
            (self.features1_live.T, self.features2_live.T)).T
        self.features_live -= np.mean(self.features_live)
        self.features_live = gumpy.signal.normalize(self.features_live,
                                                    'min_max')

        # Method 2 (DWT) is not used by this pipeline. It was previously kept
        # behind an `if False:` guard and called
        #   self.dwt_features(self.data_live_filtno, data_live.trials, 5,
        #                     data_live.sampling_freq, [0, 256], n, "db4")
        # with n = 3 and n = 4.

        ### predict label of live trial and check whether it is correct
        if not self.pos_fit:
            print("No fit was performed yet. Please train the model first.\n")
            sys.exit()

        labels_pred_live = self.estimator_object.predict(
            self.sfs_object.transform(self.features_live))
        # NOTE(review): fit() treats |pred - label| < 1 as a correct
        # prediction, whereas this check requires a difference of exactly 1 -
        # confirm which label convention the live data uses.
        pred_true = (self.labels_live - labels_pred_live) == 1

        return labels_pred_live, pred_true

    # Alpha and Beta sub-bands
    def alpha_subBP_features(self, data):
        """Return four alpha sub-band versions of ``data`` as a list.

        Each entry is ``data`` after a 6th-order Butterworth band-pass with
        the (low, high) cut-off frequencies listed below, in that order.
        """
        alpha_edges = ((8.5, 11.5), (9.0, 12.5), (9.5, 11.5), (8.0, 10.5))
        return [
            gumpy.signal.butter_bandpass(data, low, high, order=6)
            for low, high in alpha_edges
        ]

    def beta_subBP_features(self, data):
        """Return four beta sub-band versions of ``data`` as a list.

        Each entry is ``data`` after a 6th-order Butterworth band-pass with
        the (low, high) cut-off frequencies listed below, in that order.
        """
        beta_edges = ((14.0, 30.0), (16.0, 17.0), (17.0, 18.0), (18.0, 19.0))
        return [
            gumpy.signal.butter_bandpass(data, low, high, order=6)
            for low, high in beta_edges
        ]

    # Feature extraction using sub-bands (The following examples show how the sub-bands can be used to extract features.)
    # Method 1: logarithmic sub-band power
    def powermean(self, data, trial, fs, w):
        return  np.power(data[trial+fs*5+w[0]: trial+fs*5+w[1],0],2).mean(), \
                np.power(data[trial+fs*5+w[0]: trial+fs*5+w[1],1],2).mean(), \
                np.power(data[trial+fs*5+w[0]: trial+fs*5+w[1],2],2).mean()

    def log_subBP_feature_extraction(self, alpha, beta, trials, fs, w):
        # number of features combined for all trials
        n_features = 15
        # initialize the feature matrix
        X = np.zeros((len(trials), n_features))

        # Extract features
        for t, trial in enumerate(trials):
            power_c31, power_c41, power_cz1 = self.powermean(
                alpha[0], trial, fs, w)
            power_c32, power_c42, power_cz2 = self.powermean(
                alpha[1], trial, fs, w)
            power_c33, power_c43, power_cz3 = self.powermean(
                alpha[2], trial, fs, w)
            power_c34, power_c44, power_cz4 = self.powermean(
                alpha[3], trial, fs, w)
            power_c31_b, power_c41_b, power_cz1_b = self.powermean(
                beta[0], trial, fs, w)

            X[t, :] = np.array([
                np.log(power_c31),
                np.log(power_c41),
                np.log(power_cz1),
                np.log(power_c32),
                np.log(power_c42),
                np.log(power_cz2),
                np.log(power_c33),
                np.log(power_c43),
                np.log(power_cz3),
                np.log(power_c34),
                np.log(power_c44),
                np.log(power_cz4),
                np.log(power_c31_b),
                np.log(power_c41_b),
                np.log(power_cz1_b)
            ])
        return X

    # Method 2: DWT
    def dwt_features(self, data, trials, level, sampling_freq, w, n, wavelet):
        """Build wavelet-decomposition features for all trials.

        For every trial the rows ``trial + 5 * sampling_freq + w[0]`` to
        ``trial + 5 * sampling_freq + w[1]`` of ``data`` are decomposed per
        column (columns 0, 1, 2) with ``pywt.wavedec`` and coefficient set
        ``n`` is summarised. Each row holds, in this order: standard deviation
        and mean square for column 0, 1 and 2, followed by the plain mean for
        column 0, 1 and 2.

        Returns:
            Array of shape ``(len(trials), 9)``.
        """
        import pywt

        # number of features per trial
        n_features = 9
        X = np.zeros((len(trials), n_features))
        offset = sampling_freq * 5

        for row, onset in enumerate(trials):
            segment = data[onset + offset + (w[0]):onset + offset + (w[1])]
            selected = [
                pywt.wavedec(data=segment[:, column],
                             wavelet=wavelet,
                             level=level)[n] for column in range(3)
            ]

            spread_and_power = []
            for coeffs in selected:
                spread_and_power.append(np.std(coeffs))
                spread_and_power.append(np.mean(coeffs**2))
            averages = [np.mean(coeffs) for coeffs in selected]

            X[row, :] = np.array(spread_and_power + averages)

        return X