def main():
    """Train and evaluate a model with k-fold cross-validation on BCI2000 data.

    Parses the command-line options, checks that the chosen model matches the
    way the data was preprocessed (windowed data needs an LSTM model and
    non-windowed data needs a non-LSTM model), then fits one model obtained
    from the model factory per fold. Training history, the model summary and
    the parsed arguments are written to the factory's log directory.

    Raises:
        FileNotFoundError: If ``./data/args_bci2000_preprocess.txt`` does not
            exist, i.e. the preprocessing script has not been run yet.
        ValueError: If the model type is incompatible with ``num_windows``.
    """
    parser = argparse.ArgumentParser(
        description="Train and run k-fold cross-validation on physionet BCI 2000 dataset")
    parser.add_argument("-c", "--num_classes", type=int, default=4,
                        choices=[2, 3, 4])
    parser.add_argument(
        "-m", "--model_name", type=str, help="Name of the model used",
        default="eegA",
        choices=["eegA", "eegB", "eegC", "eegD", "eegA_LSTM", "eegD_LSTM"])
    parser.add_argument("-cf", "--num_conv_filters", type=int, default=32)
    # --stride / --no-stride toggle the same destination; stride defaults on.
    parser.add_argument(
        '--stride', dest='stride',
        help="Whether stride is used in the last Conv2D of first block",
        action='store_true')
    parser.add_argument('--no-stride', dest='stride', action='store_false')
    parser.set_defaults(stride=True)
    parser.add_argument("-dr", "--dropout_rate", type=float, default=0.5)
    parser.add_argument("-bs", "--batch_size", type=int, default=16)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-p", "--patience",
                        help="Parameter for EarlyStopping callback",
                        type=int, default=5)
    parser.add_argument("-kf", "--k_fold", type=int, default=5)
    parser.add_argument(
        "-o", "--output_name", type=str,
        help="logs will be put in ./logs/fit/output_name. If none is "
             "provided, time at run start is chosen",
        default=None)
    args = parser.parse_args()

    # input validation
    try:
        # The context manager guarantees the file handle is closed again.
        with open("./data/args_bci2000_preprocess.txt", 'r',
                  encoding="utf-8") as preprocess_args_file:
            num_windows = json.load(preprocess_args_file)['num_windows']
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "Preprocessed data arguments not found. Run main_preprocess_data_bci2000.py and try again."
        ) from err

    is_lstm_model = 'LSTM' in args.model_name
    if num_windows == 1 and is_lstm_model:
        raise ValueError(
            "LSTM can only be chosen for data preprocessed with -w > 1")
    if num_windows > 1 and not is_lstm_model:
        raise ValueError(
            "Only LSTM models can be chosen for data preprocessed with -w > 1")

    # Fall back to a timestamp so every run gets its own log directory name.
    if args.output_name is None:
        args.output_name = datetime.now().strftime('%Y%m%d-%H%M%S')

    model_factory = ModelFactory(dataset="BCI2000",
                                 output_name=args.output_name,
                                 model_name=args.model_name,
                                 num_classes=args.num_classes,
                                 num_conv_filters=args.num_conv_filters,
                                 dropout_rate=args.dropout_rate,
                                 use_stride=args.stride)

    X, y = load_preprocessed_bci2000_data(num_classes=args.num_classes)
    kf = KFold(n_splits=args.k_fold, shuffle=True, random_state=42)

    for idx, (train, test) in enumerate(kf.split(X, y)):
        X_train, X_test = X[train], X[test]
        y_train, y_test = y[train], y[test]

        # The scaler returned for the training split is reused on the test
        # split, so both splits go through the same transformation.
        X_train, scaler = scale_data(X_train)
        X_test, _ = scale_data(X_test, scaler)

        model = model_factory.get_model()
        history = model.fit(
            x=X_train, y=y_train,
            batch_size=args.batch_size,
            epochs=args.epochs,
            validation_data=(X_test, y_test),
            callbacks=model_factory.get_callbacks(
                patience=args.patience,
                log_dir_suffix=f"{idx + 1}"),  # folds are numbered from 1
            shuffle=True)
        write_history(history.history, log_dir=model_factory.get_log_dir())

    # The summary is written once, using the model from the last fold.
    with open(f"{model_factory.get_log_dir()}/model_summary.txt", 'w',
              encoding="utf-8") as file:
        model.summary(print_fn=lambda x: file.write(x + '\n'))

    # write parameters used for training
    with open(f"{model_factory.get_log_dir()}/input_args.txt", 'w',
              encoding="utf-8") as file:
        file.write(json.dumps(args.__dict__, indent=4))
def main():
    """Train and evaluate one model per subject on BCI Competition IV 2a data.

    Does the following:
        - For each subject:
            - Load preprocessed data from subject (preprocessed from 'A0XT.mat')
            - Train model on ALL data from 'A0XT.mat'
            - Evaluate model on test data originating from 'A0XE.mat'

    Before training, the command-line options are parsed and the chosen model
    is checked against the way the data was preprocessed (windowed data needs
    an LSTM model and non-windowed data needs a non-LSTM model). Training
    history, the model summary and the parsed arguments are written to the
    model factory's log directory.

    Raises:
        FileNotFoundError: If ``./data/args_bci2aiv_preprocess.txt`` does not
            exist, i.e. the preprocessing script has not been run yet.
        ValueError: If the model type is incompatible with ``num_windows``.
    """
    parser = argparse.ArgumentParser(
        description="Train and run model for data set 2a of BCI Competition IV."
    )
    parser.add_argument(
        "-m", "--model_name", type=str, help="Name of the model used",
        default="eegA",
        choices=["eegA", "eegB", "eegC", "eegD", "eegA_LSTM", "eegD_LSTM"])
    parser.add_argument("-cf", "--num_conv_filters", type=int, default=32)
    # --stride / --no-stride toggle the same destination; stride defaults on.
    parser.add_argument(
        '--stride', dest='stride',
        help="Whether stride is used in the last Conv2D of first block",
        action='store_true')
    parser.add_argument('--no-stride', dest='stride', action='store_false')
    parser.set_defaults(stride=True)
    parser.add_argument("-dr", "--dropout_rate", type=float, default=0.5)
    parser.add_argument("-bs", "--batch_size", type=int, default=16)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-p", "--patience",
                        help="Parameter for EarlyStopping callback",
                        type=int, default=10)
    # NOTE(review): k_fold is parsed but never read in this function; it is
    # kept so existing command lines and input_args.txt stay compatible.
    parser.add_argument("-kf", "--k_fold", type=int, default=5)
    parser.add_argument(
        "-o", "--output_name", type=str,
        help="logs will be put in ./logs/fit/output_name. If none is "
             "provided, time at run start is chosen",
        default=None)
    args = parser.parse_args()

    # input validation
    try:
        # The context manager guarantees the file handle is closed again.
        with open("./data/args_bci2aiv_preprocess.txt", 'r',
                  encoding="utf-8") as preprocess_args_file:
            num_windows = json.load(preprocess_args_file)['num_windows']
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "Preprocessed data arguments not found. Run main_preprocess_data_bci2aiv.py and try again."
        ) from err

    is_lstm_model = 'LSTM' in args.model_name
    if num_windows == 1 and is_lstm_model:
        raise ValueError(
            "LSTM can only be chosen for data preprocessed with -w > 1")
    if num_windows > 1 and not is_lstm_model:
        raise ValueError(
            "Only LSTM models can be chosen for data preprocessed with -w > 1")

    # Fall back to a timestamp so every run gets its own log directory name.
    if args.output_name is None:
        args.output_name = datetime.now().strftime('%Y%m%d-%H%M%S')

    model_factory = ModelFactory(dataset="BCI2aIV",
                                 output_name=args.output_name,
                                 model_name=args.model_name,
                                 num_conv_filters=args.num_conv_filters,
                                 use_stride=args.stride,
                                 dropout_rate=args.dropout_rate
                                 )  # num_classes is always 4 for this dataset

    for subject_num in g.subject_num_range_bci2aiv:
        X_train, y_train = load_single_subject_bci2aiv_data(
            subject_num=subject_num, is_training=True)
        X_test, y_test = load_single_subject_bci2aiv_data(
            subject_num=subject_num, is_training=False)

        # The scaler returned for the training data is reused on the test
        # data, so both sets go through the same transformation.
        X_train, scaler = scale_data(X_train)
        X_test, _ = scale_data(X_test, scaler)

        model = model_factory.get_model()
        history = model.fit(
            x=X_train, y=y_train,
            batch_size=args.batch_size,
            epochs=args.epochs,
            validation_data=(X_test, y_test),
            callbacks=model_factory.get_callbacks(
                patience=args.patience,
                log_dir_suffix=f"{subject_num}"),
            shuffle=True)
        write_history(history.history, subject_num=subject_num,
                      log_dir=model_factory.get_log_dir())

    # The summary is written once, using the model from the last subject.
    with open(f"{model_factory.get_log_dir()}/model_summary.txt", 'w',
              encoding="utf-8") as file:
        model.summary(print_fn=lambda x: file.write(x + '\n'))

    # write parameters used for training
    with open(f"{model_factory.get_log_dir()}/input_args.txt", 'w',
              encoding="utf-8") as file:
        file.write(json.dumps(args.__dict__, indent=4))