class ClassifierTrainer(QMainWindow):
    """Main window used to train classifiers on recorded EEG trials.

    The window lets the user:
      * choose the root directories holding the recordings,
      * configure pre-processing (band-pass / notch / adaptive filtering,
        re-referencing, feature scaling),
      * choose which features are extracted and from which electrodes,
      * train one of ``AVAILABLE_CLASSIFIERS`` and inspect it (accuracy prints,
        k-fold cross validation, comparison bar charts, 2D PCA scatter),
      * hand the trained classifier over to ``OnlineClassifierGui``.
    """

    DEFAULT_K_FOLD_REPETITIONS = 1000
    # Tick labels of the comparison plots; order matches _train_all_classifiers().
    REPORT_CLASSIFIER_NAMES = ("Logistic Regression", "kNN", "Perceptron", "SVM", "LDA", "MLP")

    data_set: DataSet
    classifier: cls.SimpleClassifier

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Classifier Trainer")

        self.classifier = None
        self.data_set = None
        self.filter_settings = None
        self.feature_extraction_info = None
        self.feature_types = []
        self.trial_classes = []

        self.root_directories = []
        self.root_directories.append(global_config.IMAGES_SSD_DRIVER_LETTER +
                                     ":\\EEG_GUI_OpenBCI\\eeg_recordings\\vibro_tactile_27_12_2020\\trial_02")

        self.root_widget = QWidget()
        self.root_layout = QGridLayout()
        self.root_layout.setAlignment(PyQt5.QtCore.Qt.AlignTop | PyQt5.QtCore.Qt.AlignVCenter)
        self.root_widget.setLayout(self.root_layout)
        self.setCentralWidget(self.root_widget)

        # Data which should not get loaded every train. Saved globally to avoid redundancy.
        self.loaded_eeg_data = []

        # Title
        title = QLabel("<h1> Train A Classifier </h1>")
        title.setMargin(20)
        title.setAlignment(PyQt5.QtCore.Qt.AlignCenter)
        self.root_layout.addWidget(title, 0, 0, 1, 3)

        # Load Training Data
        load_training_data_label = QLabel("<h2> Load Training Data </h2>")
        load_training_data_label.setAlignment(PyQt5.QtCore.Qt.AlignCenter)
        self.root_layout.addWidget(load_training_data_label, 1, 0, 1, 3)

        self.root_directory_label = QLabel("path to directories")
        self.add_root_directory = QPushButton("Add path")
        self.pop_root_directory = QPushButton("Pop path")
        # True while root_directories differ from what loaded_eeg_data was loaded from.
        self.root_directory_changed = True
        self.add_root_directory.clicked.connect(self.add_root_directory_clicked)
        self.pop_root_directory.clicked.connect(self.pop_root_directory_clicked)
        self.root_layout.addWidget(utils.construct_horizontal_box([
            QLabel("Root Directories: "), self.root_directory_label, self.add_root_directory,
            self.pop_root_directory
        ]), 2, 0, 2, 3)

        # Pre-processing
        pre_processing_label = QLabel("<h2> Pre-Process Data </h2>")
        pre_processing_label.setAlignment(PyQt5.QtCore.Qt.AlignCenter)
        self.root_layout.addWidget(pre_processing_label, 6, 0, 1, 3)

        self.bandpass_min_edit = QLineEdit("15")
        self.bandpass_max_edit = QLineEdit("30")
        self.notch_filter_checkbox = QCheckBox("Notch Filter")
        self.notch_filter_checkbox.setChecked(True)
        self.root_layout.addWidget(utils.construct_horizontal_box(
            [QLabel("Bandpass Filter: "), QLabel("from "), self.bandpass_min_edit,
             QLabel(" to "), self.bandpass_max_edit, self.notch_filter_checkbox]), 7, 0, 1, 3)

        self.adaptive_filtering_checkbox = QCheckBox("Adaptive Filtering")
        self.adaptive_reference_electrode = QLineEdit()
        self.adaptive_frequencies_edit = QLineEdit()
        self.adaptive_bandwidths_edit = QLineEdit()
        self.root_layout.addWidget(utils.construct_horizontal_box(
            [
                self.adaptive_filtering_checkbox, QLabel("Reference Electrode: "),
                self.adaptive_reference_electrode, QLabel("Frequencies (comma separated):"),
                self.adaptive_frequencies_edit, QLabel("Bandwidths (comma separated):"),
                self.adaptive_bandwidths_edit
            ],
        ), 8, 0, 1, 3)

        self.re_reference_checkbox = QCheckBox("Re-Reference data")
        self.reference_electrode_edit = QLineEdit("3")
        self.root_layout.addWidget(utils.construct_horizontal_box([
            self.re_reference_checkbox, QLabel("New reference electrode: "), self.reference_electrode_edit
        ]), 9, 0, 1, 3)

        # Feature scaling: a checkable group; unchecked means no scaling.
        self.feature_scaling_radio_group = QGroupBox()
        self.selected_feature_scaling_type = data.FeatureScalingType.NO_SCALING
        self.min_max_scaling_radio_btn = QRadioButton("MinMax Scaling")
        self.min_max_scaling_radio_btn.setChecked(True)
        self.min_max_scaling_radio_btn.clicked.connect(
            lambda: self.set_feature_scaling_type(data.FeatureScalingType.MIN_MAX_SCALING))
        self.standardization_scaling_radio_btn = QRadioButton("Standardization (Z-Score)")
        self.standardization_scaling_radio_btn.clicked.connect(
            lambda: self.set_feature_scaling_type(data.FeatureScalingType.STANDARDIZATION))
        self.mean_normalization_radio_btn = QRadioButton("Mean Normalization")
        self.mean_normalization_radio_btn.clicked.connect(
            lambda: self.set_feature_scaling_type(data.FeatureScalingType.MEAN_NORMALIZATION))
        self.unit_length_scaling_radio_btn = QRadioButton("Unit Length")
        self.unit_length_scaling_radio_btn.clicked.connect(
            lambda: self.set_feature_scaling_type(data.FeatureScalingType.UNIT_LENGTH))
        self.feature_scaling_radio_layout = QHBoxLayout()
        self.feature_scaling_radio_layout.addWidget(self.min_max_scaling_radio_btn)
        self.feature_scaling_radio_layout.addWidget(self.standardization_scaling_radio_btn)
        self.feature_scaling_radio_layout.addWidget(self.mean_normalization_radio_btn)
        self.feature_scaling_radio_layout.addWidget(self.unit_length_scaling_radio_btn)
        self.feature_scaling_radio_group.setLayout(self.feature_scaling_radio_layout)
        self.feature_scaling_radio_group.setCheckable(True)
        self.feature_scaling_radio_group.setChecked(False)
        self.feature_scaling_radio_group.setTitle("Feature Scaling")
        # Radio button -> scaling type, used to re-sync the selection when the group is toggled.
        self._feature_scaling_buttons = [
            (self.min_max_scaling_radio_btn, data.FeatureScalingType.MIN_MAX_SCALING),
            (self.standardization_scaling_radio_btn, data.FeatureScalingType.STANDARDIZATION),
            (self.mean_normalization_radio_btn, data.FeatureScalingType.MEAN_NORMALIZATION),
            (self.unit_length_scaling_radio_btn, data.FeatureScalingType.UNIT_LENGTH),
        ]
        # Connected after setChecked(False) so the initial state stays NO_SCALING.
        self.feature_scaling_radio_group.toggled.connect(self.feature_scaling_group_toggled)
        self.root_layout.addWidget(self.feature_scaling_radio_group, 10, 0, 1, 3)

        # Feature extraction
        feature_extraction_label = QLabel("<h2> Extract Features </h2>")
        feature_extraction_label.setAlignment(PyQt5.QtCore.Qt.AlignCenter)
        self.root_layout.addWidget(feature_extraction_label, 11, 0, 1, 3)

        self.electrodes_edit = QLineEdit("5,4,2,1")
        self.root_layout.addWidget(utils.construct_horizontal_box([
            QLabel("Include data from electrodes (comma separated):"), self.electrodes_edit]), 12, 0, 1, 2)

        self.band_amplitude_checkbox = QCheckBox("Average Band Amplitude")
        self.band_amplitude_min_edit = QLineEdit()
        self.band_amplitude_max_edit = QLineEdit()

        self.fft_window_combo = QComboBox()
        for window_size in FFT_WINDOW_SIZES:
            self.fft_window_combo.addItem(str(window_size))

        self.k_value_edit = QLineEdit()
        self.accuracy_threshold_edit = QLineEdit()
        self.regularization_edit = QLineEdit()

        self.root_layout.addWidget(utils.construct_horizontal_box([
            QLabel("FFT Window Size:"), self.fft_window_combo, QLabel("K value:"), self.k_value_edit,
            QLabel("Accuracy Threshold (0 - 1):"), self.accuracy_threshold_edit,
            QLabel("Regularization Parameter:"), self.regularization_edit
        ]), 13, 0, 1, 3)

        self.root_layout.addWidget(utils.construct_horizontal_box([
            self.band_amplitude_checkbox, QLabel("Frequency band from "), self.band_amplitude_min_edit,
            QLabel(" up to "), self.band_amplitude_max_edit
        ]), 14, 0, 1, 3)

        # Extract features as frequency band width and multiple frequency band centers.
        self.frequency_bands_checkbox = QCheckBox("Multiple Frequency Bands")
        self.band_width_edit = QLineEdit("1")
        self.center_frequencies_edit = QLineEdit("20,24")
        self.peak_frequency_checkbox = QCheckBox("Peak Frequency")
        self.root_layout.addWidget(utils.construct_horizontal_box([
            self.frequency_bands_checkbox, QLabel("Bandwidth: "), self.band_width_edit,
            QLabel("Center Frequencies (comma separated): "), self.center_frequencies_edit,
            self.peak_frequency_checkbox
        ]), 15, 0, 1, 3)

        classifier_type_label = QLabel("<p>Classifier Type:</p>")
        self.classifier_type_combo = QComboBox()
        self.classifier_type_combo.addItems(AVAILABLE_CLASSIFIERS)
        self.root_layout.addWidget(utils.construct_horizontal_box([
            classifier_type_label, self.classifier_type_combo
        ]), 16, 0, 1, 3)

        # Actions
        self.extract_features_btn = QPushButton("Extract Features")
        self.extract_features_btn.clicked.connect(self.extract_features_clicked)
        self.shuffle_data_set = QPushButton("Shuffle DataSet")
        self.shuffle_data_set.clicked.connect(self.shuffle_data_set_clicked)
        self.train_classifier_btn = QPushButton("Train Classifier")
        self.train_classifier_btn.clicked.connect(self.train_classifier_clicked)
        self.test_classifier_btn = QPushButton("Test Classifier")
        self.test_classifier_btn.clicked.connect(self.test_classifier_clicked)
        self.root_layout.addWidget(utils.construct_horizontal_box([
            self.extract_features_btn, self.shuffle_data_set, self.train_classifier_btn,
            self.test_classifier_btn
        ]), 17, 0, 1, 3)

        self.performance_report_btn = QPushButton("Performance Report")
        self.performance_report_btn.clicked.connect(self.generate_performance_report)
        self.error_description_btn = QPushButton("Error Description")
        self.error_description_btn.clicked.connect(self.generate_error_descriptions)
        self.visualize_data_btn = QPushButton("Visualize Data")
        self.visualize_data_btn.clicked.connect(self.visualize_data)
        self.k_fold_edit = QLineEdit()
        self.k_fold_btn = QPushButton("k fold cross validation")
        self.k_fold_btn.clicked.connect(self.k_fold_cross_validation_clicked)
        self.repeated_k_fold_btn = QPushButton("repeated k fold")
        self.repeated_k_fold_btn.clicked.connect(self.repeated_k_fold_clicked)
        self.root_layout.addWidget(utils.construct_horizontal_box([
            self.performance_report_btn, self.error_description_btn, self.visualize_data_btn,
            self.k_fold_edit, self.k_fold_btn, self.repeated_k_fold_btn
        ]), 18, 0, 1, 3)

        self.update_root_directories_label()

    # ------------------------------------------------------------------ scaling

    def set_feature_scaling_type(self, feature_scaling_type: data.FeatureScalingType):
        """Record the feature scaling applied to classifiers trained from now on."""
        self.selected_feature_scaling_type = feature_scaling_type
        print("Setting selected feature scaling type to {}".format(feature_scaling_type))

    def feature_scaling_group_toggled(self, enabled: bool):
        """Keep the selected scaling type in sync with the checkable scaling group.

        Unchecking the group disables scaling. Checking it re-activates the
        scaling type of whichever radio button is currently selected.
        """
        if not enabled:
            self.set_feature_scaling_type(data.FeatureScalingType.NO_SCALING)
            return
        for button, scaling_type in self._feature_scaling_buttons:
            if button.isChecked():
                self.set_feature_scaling_type(scaling_type)
                return

    # --------------------------------------------------------- feature extraction

    def _read_filter_settings(self):
        """Build ``utils.FilterSettings`` from the pre-processing widgets.

        Band-pass limits that are not valid floats are passed as -1. The
        re-reference electrode defaults to 0 unless re-referencing is enabled
        and the field holds an integer.

        Raises:
            ValueError: adaptive filtering is enabled but its reference
                electrode, frequencies or bandwidths are not valid numbers.
        """
        bandpass_min = -1
        bandpass_max = -1
        if utils.is_float(self.bandpass_min_edit.text()):
            bandpass_min = float(self.bandpass_min_edit.text())
        if utils.is_float(self.bandpass_max_edit.text()):
            bandpass_max = float(self.bandpass_max_edit.text())

        adaptive_settings = None
        if self.adaptive_filtering_checkbox.isChecked():
            adaptive_settings = utils.AdaptiveFilterSettings(
                int(self.adaptive_reference_electrode.text()),
                [float(freq_str) for freq_str in self.adaptive_frequencies_edit.text().split(",")],
                [float(width_str) for width_str in self.adaptive_bandwidths_edit.text().split(",")])

        reference_electrode = 0
        if self.re_reference_checkbox.isChecked() and utils.is_integer(self.reference_electrode_edit.text()):
            reference_electrode = int(self.reference_electrode_edit.text())

        return utils.FilterSettings(global_config.SAMPLING_RATE, bandpass_min, bandpass_max,
                                    notch_filter=self.notch_filter_checkbox.isChecked(),
                                    adaptive_filter_settings=adaptive_settings,
                                    reference_electrode=reference_electrode)

    def _read_electrode_list(self) -> list:
        """Return the electrodes typed in the electrodes edit (comma separated ints).

        Raises:
            ValueError: an entry is not an integer.
        """
        return [int(electrode_str) for electrode_str in self.electrodes_edit.text().split(",")]

    def _read_feature_types(self, fft_window_size: float) -> list:
        """Build the list of feature descriptors selected in the UI.

        A feature type is skipped when its checkbox is unchecked or its
        parameters are missing or invalid.
        """
        feature_types = []

        if self.band_amplitude_checkbox.isChecked():
            band_amplitude_min_freq = -1
            band_amplitude_max_freq = -1
            if utils.is_float(self.band_amplitude_min_edit.text()):
                band_amplitude_min_freq = float(self.band_amplitude_min_edit.text())
            if utils.is_float(self.band_amplitude_max_edit.text()):
                band_amplitude_max_freq = float(self.band_amplitude_max_edit.text())
            if band_amplitude_min_freq != -1 and band_amplitude_max_freq != -1:
                feature_types.append(
                    utils.AverageBandAmplitudeFeature(
                        utils.FrequencyBand(band_amplitude_min_freq, band_amplitude_max_freq),
                        fft_window_size))

        if self.frequency_bands_checkbox.isChecked():
            band_width = -1
            peak_frequency = self.peak_frequency_checkbox.isChecked()
            if utils.is_float(self.band_width_edit.text()):
                band_width = float(self.band_width_edit.text())
            center_frequencies = [float(center_freq_str)
                                  for center_freq_str in self.center_frequencies_edit.text().split(",")
                                  if utils.is_float(center_freq_str)]
            if len(center_frequencies) != 0 and band_width != -1:
                feature_types.append(
                    utils.FrequencyBandsAmplitudeFeature(center_frequencies, band_width, fft_window_size,
                                                         peak_frequency))

        return feature_types

    def extract_features_clicked(self):
        """Filter the loaded recordings, extract the selected features and build ``self.data_set``.

        Raw EEG data is reloaded from disk only when the root directories changed
        since the last load. On any input or loading problem a message is
        printed and ``self.data_set`` is left unchanged.
        """
        if not self.root_directories:
            print("No root directories selected, add a path first")
            return

        # Validate user input before doing any (potentially slow) loading.
        try:
            filter_settings = self._read_filter_settings()
            electrode_list = self._read_electrode_list()
        except ValueError as err:
            print("Invalid filter or electrode input: {}".format(err))
            return

        if self.root_directory_changed:
            self.loaded_eeg_data = utils.load_data(self.root_directories)
            # The cache now matches root_directories; reload only after they change again.
            self.root_directory_changed = False

        eeg_data, classes, sampling_rate, self.trial_classes = \
            utils.slice_and_filter_data(self.root_directories, filter_settings, self.loaded_eeg_data)

        if len(eeg_data) == 0 or len(classes) == 0:
            print("Could not load data")
            return
        print("Data loaded successfully")
        labels = np.array(classes).reshape((-1, 1))

        fft_window_size = float(self.fft_window_combo.currentText())
        feature_types = self._read_feature_types(fft_window_size)
        if not feature_types:
            print("No feature type selected (or its parameters are invalid), nothing to extract")
            return

        feature_extraction_info = utils.FeatureExtractionInfo(sampling_rate, electrode_list)

        self.filter_settings = filter_settings
        self.feature_extraction_info = feature_extraction_info
        self.feature_types = feature_types

        # Extract features
        extracted_data = utils.extract_features(eeg_data, feature_extraction_info, feature_types)
        feature_matrix = data.construct_feature_matrix(extracted_data)

        self.data_set = DataSet(feature_matrix, labels, add_x0=False, shuffle=False)
        print("Features extracted successfully...")

    def shuffle_data_set_clicked(self):
        """Replace ``self.data_set`` by a shuffled copy of the same samples."""
        if self.data_set is None:
            print("Please extract features before shuffling...")
            return
        self.data_set = \
            DataSet(self.data_set.raw_feature_matrix(), self.data_set.feature_matrix_labels(), False, True)
        print("Data set shuffled!!!")

    # ------------------------------------------------------------------ training

    def train_classifier_clicked(self):
        """Train the classifier chosen in the combo box on the current data set.

        Features are extracted first if no data set exists yet. The trained
        classifier is stored in ``self.classifier`` and its accuracies are printed.
        """
        print("train classifier clicked")
        if self.data_set is None:
            self.extract_features_clicked()
            if self.data_set is None:
                print("Cannot train a classifier: feature extraction failed")
                return
        else:
            print("*"*10 + "Using existing feature matrix, re-extract if changes were made" + "*"*10)

        accuracy_threshold = self.get_accuracy_threshold()
        regularization_param = self.get_regularization_param()
        selected_classifier = self.classifier_type_combo.currentText()

        feature_matrix = self.data_set.raw_feature_matrix()
        labels = self.data_set.feature_matrix_labels()

        # Train Classifier
        if selected_classifier == cls.LogisticRegressionClassifier.NAME:
            classifier = cls.LogisticRegressionClassifier(feature_matrix, labels, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            cost = classifier.train(accuracy_threshold=accuracy_threshold)
            plt.figure()
            plt.plot(cost)
            plt.xlabel("Iteration Number")
            plt.ylabel("Training Set Cost")
            plt.title("Gradient Descent Cost Curve")
            plt.show()
            print("Training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("Logistic Regression trained successfully, test set accuracy = {}".format(
                classifier.test_set_accuracy()))
            print("Cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))

        elif selected_classifier == cls.KNearestNeighborsClassifier.NAME:
            k_value = self.get_k_value()
            print("Using K value of {}".format(k_value))
            classifier = cls.KNearestNeighborsClassifier(feature_matrix, labels, k_value, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            print("KNN training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("KNN test set accuracy = {}".format(classifier.test_set_accuracy()))
            print("KNN cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))
            k_values, accuracy = classifier.cross_validation_learning_curve()
            plt.figure()
            plt.plot(k_values, accuracy)
            plt.xlabel("K Neighbors")
            plt.ylabel("Accuracy")
            plt.title("Accuracy graph for K values")
            plt.show()

        elif selected_classifier == cls.PerceptronClassifier.NAME:
            classifier = cls.PerceptronClassifier(feature_matrix, labels, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            classifier.train(accuracy_threshold=accuracy_threshold)
            print("Perceptron training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("Perceptron test set accuracy = {}".format(classifier.test_set_accuracy()))
            print("Perceptron Cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))

        elif selected_classifier == cls.SvmClassifier.NAME:
            classifier = cls.SvmClassifier(feature_matrix, labels, regularization_param, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            classifier.train()
            print("SVM training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("SVM test set accuracy = {}".format(classifier.test_set_accuracy()))
            print("SVM cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))

        elif selected_classifier == cls.LdaClassifier.NAME:
            classifier = cls.LdaClassifier(feature_matrix, labels, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            classifier.train()
            print("LDA training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("LDA test set accuracy = {}".format(classifier.test_set_accuracy()))
            print("LDA cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))
            # TODO: Learning curves

        elif selected_classifier == cls.ANNClassifier.NAME:
            classifier = cls.ANNClassifier(feature_matrix, labels, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            classifier.train()
            print("MLP training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("MLP test set accuracy = {}".format(classifier.test_set_accuracy()))
            print("MLP cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))

        elif selected_classifier == cls.VotingClassifier.NAME:
            classifier = cls.VotingClassifier(feature_matrix, labels, DEFAULT_VOTING_CLASSIFIERS, shuffle=False)
            classifier.apply_feature_scaling(self.selected_feature_scaling_type)
            self.classifier = classifier
            classifier.train()
            print("VOTING training set accuracy = {}".format(classifier.training_set_accuracy()))
            print("VOTING test set accuracy = {}".format(classifier.test_set_accuracy()))
            print("VOTING cross validation accuracy = {}".format(classifier.cross_validation_accuracy()))

        else:
            print("Unknown classifier type: {}".format(selected_classifier))

    # ---------------------------------------------------------- cross validation

    def _read_k_fold_value(self):
        """Return k from the k-fold edit, or None (after printing why) if it is not an integer >= 2."""
        text = self.k_fold_edit.text()
        try:
            k = int(text)
        except ValueError:
            print("k must be an integer, got '{}'".format(text))
            return None
        if k < 2:
            print("k must be at least 2 for k fold cross validation, got {}".format(k))
            return None
        return k

    def k_fold_cross_validation_clicked(self):
        """Print the k-fold cross validation result of the trained classifier."""
        if self.classifier is None or self.classifier.get_data_set() is None:
            print("Train a classifier before running k fold cross validation")
            return
        k = self._read_k_fold_value()
        if k is not None:
            print(self.classifier.k_fold_cross_validation(k))

    def repeated_k_fold_clicked(self):
        """Print mean and std of ``DEFAULT_K_FOLD_REPETITIONS`` repeated k-fold runs."""
        if self.classifier is None or self.classifier.get_data_set() is None:
            print("Train a classifier before running repeated k fold")
            return
        k = self._read_k_fold_value()
        if k is None:
            return
        average, std = self.classifier.repeated_k_fold_cross_validation(k, self.DEFAULT_K_FOLD_REPETITIONS)
        print("Average = {}, std = {}".format(average, std))

    # ------------------------------------------------------------------- reports

    def _train_all_classifiers(self) -> list:
        """Train one instance of every report classifier on the current data set.

        Returns:
            Classifiers in the order of ``REPORT_CLASSIFIER_NAMES``: logistic
            regression, kNN, perceptron, SVM, LDA, MLP.
        """
        feature_matrix = self.data_set.raw_feature_matrix()
        labels = self.data_set.feature_matrix_labels()
        scaling = self.selected_feature_scaling_type

        # Logistic Regression
        logistic = cls.LogisticRegressionClassifier(feature_matrix, labels, shuffle=False)
        logistic.apply_feature_scaling(scaling)
        logistic.train(accuracy_threshold=self.get_accuracy_threshold())

        # KNN (no explicit training step)
        k_value = self.get_k_value()
        print("Using K value of {}".format(k_value))
        knn = cls.KNearestNeighborsClassifier(feature_matrix, labels, k_value, shuffle=False)
        knn.apply_feature_scaling(scaling)

        # Perceptron
        perceptron = cls.PerceptronClassifier(feature_matrix, labels, shuffle=False)
        perceptron.apply_feature_scaling(scaling)
        perceptron.train(accuracy_threshold=self.get_accuracy_threshold())

        # SVM
        svm = cls.SvmClassifier(feature_matrix, labels, self.get_regularization_param(), shuffle=False)
        svm.apply_feature_scaling(scaling)
        svm.train()

        # LDA
        lda = cls.LdaClassifier(feature_matrix, labels, shuffle=False)
        lda.apply_feature_scaling(scaling)
        lda.train()

        # MLP
        mlp = cls.ANNClassifier(feature_matrix, labels, shuffle=False)
        mlp.apply_feature_scaling(scaling)
        mlp.train()

        return [logistic, knn, perceptron, svm, lda, mlp]

    def _label_names(self, labels) -> dict:
        """Map each label to the name of the first trial class carrying it (``str(label)`` if none)."""
        names = {}
        for label in labels:
            name = next((trial_class.name for trial_class in self.trial_classes
                         if trial_class.label == label), None)
            names[label] = name if name is not None else str(label)
        return names

    def generate_performance_report(self):
        """Plot training / cross-validation / test accuracy of all classifiers side by side."""
        if self.data_set is None:
            print("Please extract features before generating a performance report...")
            return

        classifiers = self._train_all_classifiers()
        # Rows: training, cross validation, testing accuracy; one column per classifier.
        plot_data = np.vstack([classifier.performance_measure().as_row_array()
                               for classifier in classifiers]).transpose()

        x = np.arange(len(classifiers))
        bar_width = 0.25
        plt.figure()
        plt.title("Classifiers' Accuracy")
        plt.bar(x + 0.0, plot_data[0], width=bar_width, label="Training Accuracy")
        plt.bar(x + bar_width, plot_data[1], width=bar_width, label="Cross Validation Accuracy")
        plt.bar(x + 2 * bar_width, plot_data[2], width=bar_width, label="Testing Accuracy")
        # Bars are centered on their x position, so the middle bar marks the group center.
        plt.xticks(x + bar_width, self.REPORT_CLASSIFIER_NAMES)
        plt.ylabel("Accuracy %")
        plt.legend(loc="best")
        plt.show()

    def generate_error_descriptions(self):
        """Plot, for every classifier, the error percentage of each trial class."""
        # TODO: Fix the error description method in all the classifiers. Problem with sample size and shape.
        if self.data_set is None:
            print("Please extract features before generating error descriptions...")
            return

        classifiers = self._train_all_classifiers()
        plot_data = np.vstack([classifier.error_description().as_row_array()
                               for classifier in classifiers]).transpose()

        unique_labels = self.data_set.unique_labels().flatten()
        name_by_label = self._label_names(unique_labels)
        label_count = unique_labels.size

        x = np.arange(len(classifiers))
        # Narrow the bars when there are many classes so neighbouring groups do not overlap.
        bar_width = min(0.25, 0.8 / max(label_count, 1))
        plt.figure()
        plt.title("Error Description")
        for i, label in enumerate(unique_labels):
            plt.bar(x + bar_width * i, plot_data[i], width=bar_width, label=name_by_label[label])
        plt.xticks(x + bar_width * (label_count - 1) / 2, self.REPORT_CLASSIFIER_NAMES)
        plt.ylabel("Error Percent")
        plt.legend(loc="best")
        plt.show()

    def visualize_data(self):
        """Scatter the (scaled) feature matrix projected on its first two principal components."""
        if self.data_set is None:
            print("Please extract features before trying to visualize...")
            return
        self.data_set.apply_feature_scaling(self.selected_feature_scaling_type)
        feature_matrix = np.asarray(self.data_set.scaled_feature_matrix())
        labels = self.data_set.feature_matrix_labels()

        # A 2-component PCA needs at least two samples and two features.
        if feature_matrix.ndim != 2 or min(feature_matrix.shape) < 2:
            print("Need at least 2 samples and 2 features to visualize with PCA")
            return

        pca = PCA(n_components=2)
        x = pca.fit_transform(feature_matrix)

        plt.figure()
        plt.title("Data reduced to 2D using PCA")
        unique_labels = np.unique(labels)
        label_to_name = self._label_names(unique_labels)
        flat_labels = labels.flatten()
        for label in unique_labels:
            x_values = x[flat_labels == label, :]
            plt.scatter(x_values[:, 0], x_values[:, 1], label=label_to_name[label])

        # Percentages truncated (not rounded) to two decimals.
        ratio1 = int(pca.explained_variance_ratio_[0] * 100 * 100) / 100
        ratio2 = int(pca.explained_variance_ratio_[1] * 100 * 100) / 100
        plt.xlabel(f"PC1 %{ratio1}")
        plt.ylabel(f"PC2 %{ratio2}")
        plt.legend(loc="best")
        plt.show()

    # ------------------------------------------------------------ input getters

    def get_k_value(self) -> int:
        """Return the kNN neighbour count from the UI (default 5, never below 1)."""
        k_value = 5
        if utils.is_integer(self.k_value_edit.text()):
            k_value = int(self.k_value_edit.text())
        return max(k_value, 1)

    def get_accuracy_threshold(self) -> float:
        """Return the training accuracy threshold from the UI (default 1.0)."""
        accuracy_threshold = 1.0
        if utils.is_float(self.accuracy_threshold_edit.text()):
            accuracy_threshold = float(self.accuracy_threshold_edit.text())
        return accuracy_threshold

    def get_regularization_param(self) -> float:
        """Return the SVM regularization parameter C from the UI (default 1.0)."""
        c = 1.0
        if utils.is_float(self.regularization_edit.text()):
            c = float(self.regularization_edit.text())
        return c

    # ------------------------------------------------------------------- testing

    def test_classifier_clicked(self):
        """Open ``OnlineClassifierGui`` with the trained classifier and its pre-processing setup."""
        ready = (self.classifier is not None and self.filter_settings is not None
                 and self.feature_extraction_info is not None
                 and len(self.feature_types) != 0 and len(self.trial_classes) != 0)
        if not ready:
            print("Train Classifier Before Testing!")
            return
        if not self.root_directories:
            print("Add a root directory before testing, it is needed to obtain the trial length")
            return

        trial_length = utils.obtain_trial_length_from_slice_index(self.root_directories[0])
        online_config = OnlineClassifierConfigurations()
        online_config.feature_window_size = trial_length
        OnlineClassifierGui(self.classifier, self.filter_settings, self.feature_extraction_info,
                            self.feature_types, self.trial_classes, self, online_config)

    @staticmethod
    def validate_keywords(keywords: "list[str]") -> bool:
        """Return True when no keyword is an empty string."""
        return all(len(keyword) != 0 for keyword in keywords)

    # -------------------------------------------------------- root directories

    def update_root_directories_label(self):
        """Show every root directory on its own line in the directories label."""
        self.root_directory_label.setText("".join(directory_str + "\n" for directory_str in self.root_directories))

    def add_root_directory_clicked(self):
        """Ask the user for a directory and append it to the root directories."""
        path = QFileDialog.getExistingDirectory(self, "Add Root Directory...")
        if not path:
            # The dialog returns an empty string when cancelled; don't record it.
            return
        self.root_directories.append(path)
        self.update_root_directories_label()
        self.root_directory_changed = True

    def pop_root_directory_clicked(self):
        """Remove the most recently added root directory, if any."""
        if len(self.root_directories) != 0:
            self.root_directories.pop()
            self.update_root_directories_label()
            self.root_directory_changed = True