def fit(self, X_train, y_train, X_val, y_val):
    """Tune SVC hyper-parameters on a predefined train/val split, then refit.

    GridSearchCV is driven by a PredefinedSplit so that the "CV" evaluation
    uses exactly `X_val` as the validation fold instead of a random split.
    The best parameter set is then refit on the training data only, the
    resulting classifier is saved to ``self.model_path``, and the grid-search
    wall time is written to ``self.time_log``.

    Args:
        X_train: 2-D array (n_samples, n_features) of training features.
        y_train: training labels.
        X_val: 2-D array (n_samples, n_features) of validation features.
        y_val: validation labels.

    Raises:
        ValueError: if `X_train` or `X_val` is not 2-dimensional.
    """
    # SVC expects flat feature vectors, i.e. ndim == 2.
    # Bug fix: the original messages claimed "expected ndim=4" (copied from
    # a 4-D CNN input check) while the condition actually tests for ndim=2;
    # also raise ValueError directly instead of Exception('ValueError: ...').
    if X_train.ndim != 2:
        raise ValueError('`X_train` is incompatible: expected ndim=2, found ndim=' + str(X_train.ndim))
    elif X_val.ndim != 2:
        raise ValueError('`X_val` is incompatible: expected ndim=2, found ndim=' + str(X_val.ndim))
    print('Dimension of training set is: {} and label is: {}'.format(X_train.shape, y_train.shape))
    print('Dimension of validation set is: {} and label is: {}'.format(X_val.shape, y_val.shape))

    # Concatenate train+val so GridSearchCV can see both partitions.
    X_all = np.concatenate((X_train, X_val), axis=0)
    y_all = np.concatenate((y_train, y_val), axis=0)

    # PredefinedSplit convention: -1 = never in a test fold (always train),
    # any other index = the test fold the sample belongs to.
    tr_index = np.full((X_train.shape[0]), -1)
    val_index = np.full((X_val.shape[0]), 0)
    split_index = np.concatenate((tr_index, val_index), axis=0).tolist()
    pds = PredefinedSplit(test_fold=split_index)

    clf = GridSearchCV(estimator=SVC(), param_grid=self.tuned_parameters,
                       cv=pds, scoring='accuracy')
    start = time.time()
    clf.fit(X_all, y_all)
    end = time.time()

    # Refit a fresh classifier with the optimal parameter set on the
    # training partition only (validation data stays held out).
    optimal_params = clf.best_params_
    print(optimal_params)
    classifier = SVC(**optimal_params)
    classifier.fit(X_train, y_train)
    dump(classifier, self.model_path)

    # Record the hyper-parameter search duration.
    write_log(filepath=self.time_log, data=['time_log'], mode='w')
    write_log(filepath=self.time_log, data=[end - start], mode='a')
def k_fold_cross_validation(subject):
    """Run n_folds-fold cross-validation of EEGNet for one subject.

    For each fold: build a fresh EEGNet, train on the fold's train/val
    split, evaluate on its test split, append metrics to a per-subject CSV,
    and finally persist all folds' true/predicted labels as an .npz file.
    """
    # Data loader bound to this subject and the configured dataset layout.
    loader = DataLoader(dataset=args.dataset, train_type=args.train_type,
                        subject=subject, data_format=data_format,
                        data_type=data_type, dataset_path=args.data_path)

    all_true, all_pred = [], []
    for fold in range(1, n_folds + 1):
        model_name = 'S{:03d}_fold{:02d}'.format(subject, fold)
        model = EEGNet(input_shape=input_shape, class_balancing=True,
                       f1_average='binary', num_class=num_class,
                       loss='sparse_categorical_crossentropy', epochs=epochs,
                       batch_size=batch_size,
                       optimizer=Adam(beta_1=0.9, beta_2=0.999, epsilon=1e-08),
                       lr=lr, min_lr=min_lr, factor=factor, patience=patience,
                       es_patience=es_patience, log_path=log_path,
                       model_name=model_name, dropout_rate=dropout_rate)

        # Fold-specific train / validation / test partitions.
        X_train, y_train = loader.load_train_set(fold=fold)
        X_val, y_val = loader.load_val_set(fold=fold)
        X_test, y_test = loader.load_test_set(fold=fold)

        # Train on this fold and evaluate on its held-out test set.
        model.fit(X_train, y_train, X_val, y_val)
        Y, evaluation = model.predict(X_test, y_test)

        # Append per-fold metrics; the header row is written only once.
        csv_file = log_path + '/S{:03d}_all_results.csv'.format(subject)
        if fold == 1:
            write_log(csv_file, data=evaluation.keys(), mode='w')
        write_log(csv_file, data=evaluation.values(), mode='a')

        all_true.append(Y['y_true'])
        all_pred.append(Y['y_pred'])
        # Free the TF graph/session between folds to avoid memory growth.
        tf.keras.backend.clear_session()

    # Persist every fold's labels and predictions for later analysis.
    np.savez(log_path + '/S{:03d}_Y_results.npz'.format(subject),
             y_true=np.array(all_true), y_pred=np.array(all_pred))
    print('------------------------- S{:03d} Done--------------------------'.format(subject))
def k_fold_cross_validation(subject):
    """Run n_folds-fold cross-validation of the SVM baseline for one subject.

    Per fold: tune and fit an SVM on the fold's train/val split, evaluate on
    the test split, log metrics to a per-subject CSV, then save all folds'
    true/predicted labels to an .npz archive.
    """
    # Data loader bound to this subject and the configured dataset layout.
    loader = DataLoader(dataset=args.dataset, train_type=args.train_type,
                        subject=subject, data_format=data_format,
                        data_type=data_type, dataset_path=args.data_path)

    collected_true, collected_pred = [], []
    for fold in range(1, n_folds + 1):
        model_name = 'SVM_S{:03d}_fold{:02d}'.format(subject, fold)
        svm = SVM(log_path=log_path, model_name=model_name,
                  num_class=num_class, tuned_parameters=tuned_parameters)

        # Fold-specific train / validation / test partitions.
        X_train, y_train = loader.load_train_set(fold=fold)
        X_val, y_val = loader.load_val_set(fold=fold)
        X_test, y_test = loader.load_test_set(fold=fold)

        # Hyper-parameter search + fit, then evaluate on the test partition.
        svm.fit(X_train, y_train, X_val, y_val)
        Y, evaluation = svm.predict(X_test, y_test)

        # Append per-fold metrics; the header row is written only once.
        csv_file = log_path + '/S{:03d}_all_results.csv'.format(subject)
        if fold == 1:
            write_log(csv_file, data=evaluation.keys(), mode='w')
        write_log(csv_file, data=evaluation.values(), mode='a')

        collected_true.append(Y['y_true'])
        collected_pred.append(Y['y_pred'])

    # Persist every fold's labels and predictions for later analysis.
    np.savez(log_path + '/S{:03d}_Y_results.npz'.format(subject),
             y_true=np.array(collected_true), y_pred=np.array(collected_pred))
    print('------------------------- S{:03d} Done--------------------------'.format(subject))