# Example #1
0
    device = 'cpu'

    #     dataset = TidySequentialDataCSVLoader('my_dataset.csv')
    #     X, y = dataset.get_batch_data(batch_id=0)

    from IPython import embed
    embed()
    rnn = RNNBinaryClassifier(
        max_epochs=args.epochs,
        batch_size=args.batch_size,
        device=device,
        callbacks=[
            #skorch.callbacks.Checkpoint(),
            skorch.callbacks.ProgressBar(),
        ],
        module__rnn_type='ELMAN+relu',
        module__n_inputs=X.shape[-1],
        module__n_hiddens=10,
        module__n_layers=1,
        optimizer=torch.optim.SGD,
        optimizer__lr=1.00,
    )

    # generate some synthetic data
    #    t = 1000; d = 5; n = 100; X = (torch.rand([n,t,d])).double(); y = (torch.randint(2,(n,1))).reshape(-1).long()
    rnn.fit(X, y)

    ### Evaluate
    pd.set_option('display.precision', 3)
    for n in range(dataset.N):
# Example #2
0
def main():
    """Train an LSTM binary classifier on tidy sequential CSVs and report AUROC.

    ``--train_csv_files`` / ``--test_csv_files`` / ``--data_dict_files`` are
    comma-separated ``x,y`` pairs. The feature and id columns come from the
    x data dictionary. An ``RNNBinaryClassifier`` (LSTM, Adam) is fitted with
    class-weighted cross-entropy, and the AUROC on the training and test sets
    is printed. Training history and checkpoints are written to
    ``--output_dir`` (current directory when not given) by skorch callbacks.

    Raises:
        ValueError: if the training labels do not contain both class 0 and
            class 1 (the inverse-frequency class weights would divide by 0).
    """

    def str2bool(value):
        """Parse a command-line flag value such as 'True', 'false', '1' or 'no'.

        argparse's ``type=bool`` maps every non-empty string (including
        'False') to True, so the text is interpreted explicitly.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(
        description='PyTorch RNN with variable-length numeric sequences wrapper'
    )
    parser.add_argument('--outcome_col_name', type=str, required=True)
    parser.add_argument('--train_csv_files', type=str, required=True)
    parser.add_argument('--test_csv_files', type=str, required=True)
    parser.add_argument('--data_dict_files', type=str, required=True)
    parser.add_argument('--batch_size',
                        type=int,
                        default=1024,
                        help='Number of sequences per minibatch')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        help='Number of epochs')
    parser.add_argument('--hidden_units',
                        type=int,
                        default=32,
                        help='Number of hidden units')
    parser.add_argument('--hidden_layers',
                        type=int,
                        default=1,
                        help='Number of hidden layers')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='Learning rate for the optimizer')
    parser.add_argument('--dropout',
                        type=float,
                        default=0,
                        help='dropout for optimizer')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0001,
                        help='weight decay for optimizer')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--validation_size',
                        type=float,
                        default=0.15,
                        help='validation split size')
    parser.add_argument(
        '--is_data_simulated',
        type=str2bool,
        default=False,
        help='boolean to check if data is simulated or from mimic')
    parser.add_argument(
        '--simulated_data_dir',
        type=str,
        default='simulated_data/2-state/',
        help=
        'dir in which to simulated data is saved.Must be provide if is_data_simulated = True'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help=
        'directory where trained model and loss curves over epochs are saved')
    parser.add_argument(
        '--output_filename_prefix',
        type=str,
        default=None,
        help='prefix for the training history jsons and trained classifier')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = 'cpu'

    # Checkpoint/TrainEndCheckpoint write into this directory; os.path.join
    # cannot take None, so fall back to the current working directory.
    output_dir = args.output_dir if args.output_dir is not None else '.'

    # Each *_files argument is a comma-separated "x,y" pair of paths.
    x_train_csv_filename, y_train_csv_filename = args.train_csv_files.split(
        ',')
    x_test_csv_filename, y_test_csv_filename = args.test_csv_files.split(',')
    x_dict, _ = args.data_dict_files.split(',')
    x_data_dict = load_data_dict_json(x_dict)

    # get the id and feature columns
    id_cols = parse_id_cols(x_data_dict)
    feature_cols = parse_feature_cols(x_data_dict)

    def load_vitals(x_csv_path, y_csv_path):
        """Build a per-sequence-labelled loader for one x/y CSV pair."""
        return TidySequentialDataCSVLoader(
            x_csv_path=x_csv_path,
            y_csv_path=y_csv_path,
            x_col_names=feature_cols,
            idx_col_names=id_cols,
            y_col_name=args.outcome_col_name,
            y_label_type='per_sequence')

    train_vitals = load_vitals(x_train_csv_filename, y_train_csv_filename)
    test_vitals = load_vitals(x_test_csv_filename, y_test_csv_filename)

    X_train, y_train = train_vitals.get_batch_data(batch_id=0)
    X_test, y_test = test_vitals.get_batch_data(batch_id=0)
    _, n_timesteps, n_features = X_train.shape

    print('number of time points : %s\n number of features : %s\n' %
          (n_timesteps, n_features))

    # Class weights are 1/(number of training samples in class) to offset
    # class imbalance; an absent class would make this divide by zero.
    n_neg = (y_train == 0).sum()
    n_pos = (y_train == 1).sum()
    if n_neg == 0 or n_pos == 0:
        raise ValueError(
            'training labels must contain both classes 0 and 1 '
            '(found %s negatives and %s positives)' % (n_neg, n_pos))
    class_weights = torch.tensor([1 / n_neg, 1 / n_pos]).double()

    # callback to compute gradient norm
    compute_grad_norm = ComputeGradientNorm(norm_type=2)

    if args.output_filename_prefix is None:
        output_filename_prefix = (
            'hiddens=%s-layers=%s-lr=%s-dropout=%s-weight_decay=%s' %
            (args.hidden_units, args.hidden_layers, args.lr, args.dropout,
             args.weight_decay))
    else:
        output_filename_prefix = args.output_filename_prefix

    print('RNN parameters : ' + output_filename_prefix)

    # LSTM classifier with early stopping / LR scheduling on validation AUROC
    rnn = RNNBinaryClassifier(
        max_epochs=args.epochs,
        batch_size=args.batch_size,
        device=device,
        lr=args.lr,
        callbacks=[
            EpochScoring('roc_auc',
                         lower_is_better=False,
                         on_train=True,
                         name='aucroc_score_train'),
            EpochScoring('roc_auc',
                         lower_is_better=False,
                         on_train=False,
                         name='aucroc_score_valid'),
            EarlyStopping(monitor='aucroc_score_valid',
                          patience=20,
                          threshold=0.002,
                          threshold_mode='rel',
                          lower_is_better=False),
            LRScheduler(policy=ReduceLROnPlateau,
                        mode='max',
                        monitor='aucroc_score_valid',
                        patience=10),
            compute_grad_norm,
            GradientNormClipping(gradient_clip_value=0.3,
                                 gradient_clip_norm_type=2),
            Checkpoint(monitor='aucroc_score_valid',
                       f_history=os.path.join(
                           output_dir, output_filename_prefix + '.json')),
            TrainEndCheckpoint(dirname=output_dir,
                               fn_prefix=output_filename_prefix),
        ],
        criterion=torch.nn.CrossEntropyLoss,
        criterion__weight=class_weights,
        train_split=skorch.dataset.CVSplit(args.validation_size),
        module__rnn_type='LSTM',
        module__n_layers=args.hidden_layers,
        module__n_hiddens=args.hidden_units,
        module__n_inputs=n_features,
        module__dropout_proba=args.dropout,
        optimizer=torch.optim.Adam,
        optimizer__weight_decay=args.weight_decay)

    clf = rnn.fit(X_train, y_train)

    def positive_class_auroc(X, y):
        """AUROC of the fitted classifier's second-column probabilities on (X, y)."""
        _, y_pred_proba_pos = zip(*clf.predict_proba(X))
        return roc_auc_score(y, y_pred_proba_pos)

    auroc_train_final = positive_class_auroc(X_train, y_train)
    print('AUROC with LSTM (Train) : %.2f' % auroc_train_final)

    auroc_test_final = positive_class_auroc(X_test, y_test)
    print('AUROC with LSTM (Test) : %.2f' % auroc_test_final)
# Example #3
0
    x_test_dict_file = os.path.join(args.rnn_train_test_split_dir,'x_dict.json')
    x_test_dict = load_data_dict_json(x_test_dict_file)
    feature_cols_with_mask_features = parse_feature_cols(x_test_dict)
    x_train_df =  pd.read_csv(os.path.join(args.rnn_train_test_split_dir,'x_train.csv'))
    
    # load shallow models
    shallow_models = ['logistic_regression', 'random_forest']
    clf_models_dict = dict.fromkeys(shallow_models)
    for model in shallow_models:
        clf_model_file = os.path.join(args.shallow_clf_models_dir, model, '%s_trained_model.joblib'%model)
        clf_model = load(clf_model_file)
        clf_models_dict[model] = clf_model
    
    # load rnn
    rnn = RNNBinaryClassifier(module__rnn_type='LSTM',
                              module__n_layers=2,
                              module__n_hiddens=32,
                             module__n_inputs=len(feature_cols_with_mask_features))
    rnn.initialize()
    best_model_prefix = 'hiddens=32-layers=2-lr=0.005-dropout=0.3-weight_decay=1e-06'
    rnn.load_params(f_params=os.path.join(args.rnn_models_dir,
                                          best_model_prefix+'params.pt'),
                    f_optimizer=os.path.join(args.rnn_models_dir,
                                             best_model_prefix+'optimizer.pt'),
                    f_history=os.path.join(args.rnn_models_dir,
                                           best_model_prefix+'history.json'))

    clf_models_dict['rnn'] = rnn
    
    models = shallow_models + ['rnn']
    
    # get id's  of 3 patients with clinical deterioration and 3 different stay lengths  
def main():
    """Train an LSTM binary classifier on per-timestep vitals and write a report.

    Loads per-timestep vitals for train and test (joined with the episode
    metadata CSV) and drops the first feature column ("hours") when there is
    more than one time step. It then fits an ``RNNBinaryClassifier`` (LSTM,
    SGD) with class-weighted cross-entropy. Training-curve plots, a loss CSV
    and the pickled classifier are saved under ``--report_dir``, test-set
    metrics are printed, and an HTML report is built.

    Raises:
        NotImplementedError: if ``--is_data_simulated`` is true; only the
            tidy CSV loading path is implemented here.
    """

    def str2bool(value):
        """Parse a command-line flag value such as 'True', 'false', '1' or 'no'.

        argparse's ``type=bool`` maps every non-empty string (including
        'False') to True, so the text is interpreted explicitly.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(
        description='PyTorch RNN with variable-length numeric sequences wrapper'
    )

    parser.add_argument('--train_vitals_csv',
                        type=str,
                        help='Location of vitals data for training')
    parser.add_argument('--test_vitals_csv',
                        type=str,
                        help='Location of vitals data for testing')
    parser.add_argument('--metadata_csv',
                        type=str,
                        help='Location of metadata for testing and training')
    parser.add_argument('--data_dict', type=str)
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='Number of sequences per minibatch')
    parser.add_argument('--epochs',
                        type=int,
                        default=100000,
                        help='Number of epochs')
    parser.add_argument('--hidden_units',
                        type=int,
                        default=10,
                        help='Number of hidden units')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-2,
                        help='Learning rate for the optimizer')
    # NOTE(review): --dropout is parsed but not passed to the module below.
    parser.add_argument('--dropout',
                        type=float,
                        default=0.3,
                        help='dropout for optimizer')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--save',
                        type=str,
                        default='RNNmodel.pt',
                        help='path to save the final model')
    parser.add_argument('--report_dir',
                        type=str,
                        default='html',
                        help='dir in which to save results report')
    parser.add_argument('--simulated_data_dir',
                        type=str,
                        default='simulated_data/2-state/',
                        help='dir in which to simulated data is saved')
    parser.add_argument(
        '--is_data_simulated',
        type=str2bool,
        default=False,
        help='boolean to check if data is simulated or from mimic')
    parser.add_argument(
        '--output_filename_prefix',
        type=str,
        default='current_config',
        help='file to save the loss and validation over epochs')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = 'cpu'

    # Only the CSV loader path exists; without this check a simulated run
    # would fail later with a NameError on X_train / y_train.
    if args.is_data_simulated:
        raise NotImplementedError(
            'loading simulated data (--is_data_simulated) is not implemented')

    def load_vitals(per_tstep_csv_path):
        """Build a per-timestep-labelled loader for one vitals CSV."""
        return TidySequentialDataCSVLoader(
            per_tstep_csv_path=per_tstep_csv_path,
            per_seq_csv_path=args.metadata_csv,
            idx_col_names=['subject_id', 'episode_id'],
            x_col_names='__all__',
            y_col_name='inhospital_mortality',
            y_label_type='per_tstep')

    train_vitals = load_vitals(args.train_vitals_csv)
    test_vitals = load_vitals(args.test_vitals_csv)

    X_train_with_time_appended, y_train = train_vitals.get_batch_data(
        batch_id=0)
    X_test_with_time_appended, y_test = test_vitals.get_batch_data(
        batch_id=0)
    _, T, _ = X_train_with_time_appended.shape

    if T > 1:
        # remove the first feature column (hours)
        X_train = X_train_with_time_appended[:, :, 1:]
        X_test = X_test_with_time_appended[:, :, 1:]
    else:
        # features already collapsed across time: keep every column
        X_train = X_train_with_time_appended
        X_test = X_test_with_time_appended

    # class weights are 1/(number of training samples in class) to handle imbalance
    class_weights = torch.tensor(
        [1 / (y_train == 0).sum(), 1 / (y_train == 1).sum()]).double()

    # AUROC scorer used by skorch to track training and validation AUROC
    roc_auc_scorer = make_scorer(roc_auc_score,
                                 greater_is_better=True,
                                 needs_threshold=True)

    rnn = RNNBinaryClassifier(
        max_epochs=args.epochs,
        batch_size=args.batch_size,
        device=device,
        criterion=torch.nn.CrossEntropyLoss,
        criterion__weight=class_weights,
        train_split=skorch.dataset.CVSplit(4),
        callbacks=[
            skorch.callbacks.GradientNormClipping(gradient_clip_value=0.4,
                                                  gradient_clip_norm_type=2),
            skorch.callbacks.EpochScoring(roc_auc_scorer,
                                          lower_is_better=False,
                                          on_train=True,
                                          name='aucroc_score_train'),
            skorch.callbacks.EpochScoring(roc_auc_scorer,
                                          lower_is_better=False,
                                          on_train=False,
                                          name='aucroc_score_valid'),
            ComputeGradientNorm(
                norm_type=2,
                f_history=args.report_dir +
                '/%s_running_rnn_classifer_gradient_norm_history.csv' %
                args.output_filename_prefix),
            skorch.callbacks.EarlyStopping(monitor='aucroc_score_valid',
                                           patience=1000,
                                           threshold=1e-10,
                                           threshold_mode='rel',
                                           lower_is_better=False),
            skorch.callbacks.Checkpoint(
                monitor='train_loss',
                f_history=args.report_dir +
                '/%s_running_rnn_classifer_history.json' %
                args.output_filename_prefix),
            skorch.callbacks.PrintLog(floatfmt='.2f')
        ],
        module__rnn_type='LSTM',
        module__n_inputs=X_train.shape[-1],
        module__n_hiddens=args.hidden_units,
        module__n_layers=1,
        optimizer=torch.optim.SGD,
        optimizer__weight_decay=1e-2,
        lr=args.lr)

    # scale input features
    # NOTE(review): train and test are scaled by separate calls; if
    # standard_scaler_3d fits statistics per call, test is not scaled with
    # the training statistics -- confirm.
    X_train = standard_scaler_3d(X_train)
    X_test = standard_scaler_3d(X_test)
    rnn.fit(X_train, y_train)

    # get the training history
    epochs, train_loss, validation_loss, aucroc_score_train, aucroc_score_valid = get_loss_plots_from_training_history(
        rnn.history)

    # plot the validation and training error plots and save
    f = plt.figure()
    plt.plot(epochs, train_loss, 'r-.', label='Train Loss')
    plt.plot(epochs, validation_loss, 'b-.', label='Validation Loss')
    plt.plot(epochs, aucroc_score_train, 'g-.', label='AUCROC score (Train)')
    plt.plot(epochs, aucroc_score_valid, 'm-.', label='AUCROC score (Valid)')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Performance (learning rate : %s, hidden units : %s)' %
              (str(args.lr), str(args.hidden_units)))
    f.savefig(args.report_dir + '/%s_training_performance_plots.png' %
              args.output_filename_prefix)
    plt.close()

    # save the training and validation loss in a csv
    train_perf_df = pd.DataFrame(
        data=np.stack([epochs, train_loss, validation_loss]).T,
        columns=['epochs', 'train_loss', 'validation_loss'])
    train_perf_df.to_csv(args.report_dir +
                         '/%s_perf_metrics.csv' % args.output_filename_prefix)

    # save classifier history to later evaluate early stopping for this model
    dump(
        rnn, args.report_dir +
        '/%s_rnn_classifer.pkl' % args.output_filename_prefix)

    y_pred_proba = rnn.predict_proba(X_test)
    y_pred = convert_proba_to_binary(y_pred_proba)

    _, y_pred_proba_pos = zip(*y_pred_proba)
    roc_area = roc_auc_score(y_test, y_pred_proba_pos)

    # Brief Summary
    print('Accuracy:', accuracy_score(y_test, y_pred))
    print('Balanced Accuracy:', balanced_accuracy_score(y_test, y_pred))
    print('Log Loss:', log_loss(y_test, y_pred_proba))
    print('AUC ROC:', roc_area)
    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent, so the
    # unpacking below always yields (tn, fp, fn, tp).
    true_neg, false_pos, false_neg, true_pos = confusion_matrix(
        y_test, y_pred, labels=[0, 1]).ravel()
    print('True Positive Rate:', float(true_pos) / (true_pos + false_neg))
    print('True Negative Rate:', float(true_neg) / (true_neg + false_pos))
    print('Positive Predictive Value:',
          float(true_pos) / (true_pos + false_pos))
    # NPV = TN / (TN + FN)
    print('Negative Predictive Value',
          float(true_neg) / (true_neg + false_neg))

    create_html_report(args.report_dir, args.output_filename_prefix, y_test,
                       y_pred, y_pred_proba, args.lr)
            x_col_names=feature_cols_with_mask_features,
            idx_col_names=id_cols,
            y_col_name=outcome_col_name,
            y_label_type='per_sequence'
        )    

        # predict on test data
        x_test, y_test = test_vitals.get_batch_data(batch_id=0)
        
#         from IPython import embed; embed()
        x_test[np.isnan(x_test)]=0
        N,T,F = x_test.shape
        # load classifier
        if p==0:
            rnn = RNNBinaryClassifier(module__rnn_type='LSTM',
                              module__n_layers=2,
                              module__n_hiddens=16,
                              module__n_inputs=x_test.shape[-1])
            rnn.initialize()
            best_model_prefix = 'hiddens=16-layers=2-lr=0.001-dropout=0-weight_decay=0.001-seed=1111-batch_size=15000'
            rnn.load_params(f_params=os.path.join(clf_models_dir,
                                                  best_model_prefix+'params.pt'),
                            f_optimizer=os.path.join(clf_models_dir,
                                                     best_model_prefix+'optimizer.pt'),
                            f_history=os.path.join(clf_models_dir,
                                                   best_model_prefix+'history.json'))
            print('Evaluating with saved model : %s'%(os.path.join(clf_models_dir, best_model_prefix)))
            

        print('Evaluating rnn on tslice=%s'%(tslice))
        roc_auc_np = np.zeros(len(random_seed_list))
        balanced_accuracy_np = np.zeros(len(random_seed_list))
# Example #6
0
        
#     if args.pretrained_model_dir is None:
    rnn = RNNBinaryClassifier(
              max_epochs=100,
              batch_size=args.batch_size,
              device=device,
              lr=args.lr,
              callbacks=[
              EpochScoring('roc_auc', lower_is_better=False, on_train=True, name='aucroc_score_train'),
              EpochScoring('roc_auc', lower_is_better=False, on_train=False, name='aucroc_score_valid'),
              EarlyStopping(monitor='aucroc_score_valid', patience=25, threshold=0.00001, threshold_mode='abs', lower_is_better=False),
#              LRScheduler(policy=ReduceLROnPlateau, mode='min', monitor='valid_loss', patience=25, factor=0.1,
#                   min_lr=0.0001, verbose=True),
              compute_grad_norm,
#               GradientNormClipping(gradient_clip_value=0.00001, gradient_clip_norm_type=2),
              Checkpoint(monitor='aucroc_score_valid', f_history=os.path.join(args.output_dir, output_filename_prefix+'.json')),
              TrainEndCheckpoint(dirname=args.output_dir, fn_prefix=output_filename_prefix),
              ],
              criterion=torch.nn.CrossEntropyLoss,
              criterion__weight=class_weights,
              train_split=skorch.dataset.CVSplit(args.validation_size, random_state=42),
              module__rnn_type='LSTM',
              module__n_layers=args.hidden_layers,
              module__n_hiddens=args.hidden_units,
              module__n_inputs=X_train_tensor.shape[-1],
              module__dropout_proba=args.dropout,
              optimizer=torch.optim.Adam,
              optimizer__weight_decay=args.weight_decay
                         ) 
    
    if args.pretrained_model_dir is not None:
def main():
    """Grid-search two RNNBinaryClassifier variants on vitals and report AUROC.

    The first variant uses ``module__convert_to_log_reg=True`` (reported as
    logistic regression). The second is a single-layer LSTM with
    ``convert_to_log_reg=False``. For each, ``GridSearchCV`` over ``lr`` and
    ``optimizer__weight_decay`` runs on a single shuffle split. The train
    and test AUROC of the refit estimator are printed, and
    ``plot_training_history`` is called for the saved CV results.

    Raises:
        NotImplementedError: if ``--is_data_simulated`` is true; only the
            tidy CSV loading path is implemented here.
    """

    def str2bool(value):
        """Parse a command-line flag value such as 'True', 'false', '1' or 'no'.

        argparse's ``type=bool`` maps every non-empty string (including
        'False') to True, so the text is interpreted explicitly.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='PyTorch RNN with variable-length numeric sequences wrapper')

    parser.add_argument('--train_vitals_csv', type=str,
                        help='Location of vitals data for training')
    parser.add_argument('--test_vitals_csv', type=str,
                        help='Location of vitals data for testing')
    parser.add_argument('--metadata_csv', type=str,
                        help='Location of metadata for testing and training')
    parser.add_argument('--data_dict', type=str)
    parser.add_argument('--epochs', type=int, default=1000,
                        help='Number of epochs')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--report_dir', type=str, default='html',
                        help='dir in which to save results report')
    parser.add_argument('--simulated_data_dir', type=str, default='simulated_data/2-state/',
                        help='dir in which to simulated data is saved')
    parser.add_argument('--is_data_simulated', type=str2bool, default=False,
                        help='boolean to check if data is simulated or from mimic')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = 'cpu'

    # Only the CSV loader path exists; without this check a simulated run
    # would fail later with a NameError on X_train / y_train.
    if args.is_data_simulated:
        raise NotImplementedError(
            'loading simulated data (--is_data_simulated) is not implemented')

    def load_vitals(per_tstep_csv_path):
        """Build a per-timestep-labelled loader for one vitals CSV."""
        return TidySequentialDataCSVLoader(
            per_tstep_csv_path=per_tstep_csv_path,
            per_seq_csv_path=args.metadata_csv,
            idx_col_names=['subject_id', 'episode_id'],
            x_col_names='__all__',
            y_col_name='inhospital_mortality',
            y_label_type='per_tstep')

    train_vitals = load_vitals(args.train_vitals_csv)
    test_vitals = load_vitals(args.test_vitals_csv)

    X_train_with_time_appended, y_train = train_vitals.get_batch_data(batch_id=0)
    X_test_with_time_appended, y_test = test_vitals.get_batch_data(batch_id=0)
    _, T, _ = X_train_with_time_appended.shape

    if T > 1:
        # remove the first feature column (hours)
        X_train = X_train_with_time_appended[:, :, 1:]
        X_test = X_test_with_time_appended[:, :, 1:]
    else:
        # features already collapsed across time: keep every column
        X_train = X_train_with_time_appended
        X_test = X_test_with_time_appended

    # class weights are 1/(number of training samples in class) to handle imbalance
    class_weights = torch.tensor([1 / (y_train == 0).sum(),
                                  1 / (y_train == 1).sum()]).double()

    # AUROC scorer used by skorch/GridSearchCV to track and select on AUROC
    roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
                                 needs_threshold=True)

    # scale features
    # NOTE(review): train and test are scaled by separate calls; if
    # standard_scaler_3d fits statistics per call, test is not scaled with
    # the training statistics -- confirm.
    X_train = standard_scaler_3d(X_train)
    X_test = standard_scaler_3d(X_test)

    # parameter grid shared by both grid searches
    params = {'lr': [0.0001, 0.0005, 0.001, 0.005, 0.01],
              'optimizer__weight_decay': [0.0001, 0.001, 0.01, 0.1, 1, 10]}

    def positive_class_auroc(estimator, X, y):
        """AUROC of ``estimator``'s second-column probabilities on (X, y)."""
        _, y_pred_proba_pos = zip(*estimator.predict_proba(X))
        return roc_auc_score(y, y_pred_proba_pos)

    # ------------------------------------------------------------------ #
    # LSTM converted to logistic regression, tuned with GridSearchCV
    # ------------------------------------------------------------------ #
    print('-------------------------------------------------------------------')
    print('Running LSTM converted to logistic regression on collapsed Features')
    model_name = 'logreg_hist'
    save_cv_results = SaveCVResults(dirname=args.report_dir, f_history=model_name + '.json')
    rnn = RNNBinaryClassifier(
        max_epochs=args.epochs,
        batch_size=-1,
        device=device,
        callbacks=[
            skorch.callbacks.EpochScoring(roc_auc_scorer, lower_is_better=False, on_train=True, name='aucroc_score_train'),
            skorch.callbacks.EpochScoring(roc_auc_scorer, lower_is_better=False, on_train=False, name='aucroc_score_valid'),
            skorch.callbacks.EarlyStopping(monitor='aucroc_score_valid', patience=5, threshold=1e-10, threshold_mode='rel',
                                           lower_is_better=False),
            save_cv_results,
        ],
        criterion=torch.nn.NLLLoss,
        criterion__weight=class_weights,
        train_split=skorch.dataset.CVSplit(0.2),
        module__rnn_type='LSTM',
        module__n_layers=1,
        module__n_hiddens=X_train.shape[-1],
        module__n_inputs=X_train.shape[-1],
        module__convert_to_log_reg=True,
        optimizer=torch.optim.Adam)

    gs = GridSearchCV(rnn, params, scoring=roc_auc_scorer, refit=True,
                      cv=ShuffleSplit(n_splits=1, test_size=0.2, random_state=14232))
    lr_cv = gs.fit(X_train, y_train)
    auroc_train_final = positive_class_auroc(lr_cv.best_estimator_, X_train, y_train)
    print('AUROC with logistic regression (Train) : %.3f' % auroc_train_final)

    auroc_test_final = positive_class_auroc(lr_cv.best_estimator_, X_test, y_test)
    print('AUROC with logistic regression (Test) : %.3f' % auroc_test_final)

    # get the loss plots for logistic regression
    plot_training_history(model_name=model_name, model_alias='Logistic Regression',
                          report_dir=args.report_dir, params=params,
                          auroc_train_final=auroc_train_final,
                          auroc_test_final=auroc_test_final)

    # ------------------------------------------------------------------ #
    # LSTM, tuned with GridSearchCV
    # ------------------------------------------------------------------ #
    print('-------------------------------------------------------------------')
    print('Running LSTM on Collapsed Features')
    model_name = 'lstm_hist'
    save_cv_results = SaveCVResults(dirname=args.report_dir, f_history=model_name + '.json')
    rnn = RNNBinaryClassifier(
        max_epochs=args.epochs,
        batch_size=-1,
        device=device,
        callbacks=[
            save_cv_results,
            EpochScoring('roc_auc', lower_is_better=False, on_train=True, name='aucroc_score_train'),
            EpochScoring('roc_auc', lower_is_better=False, on_train=False, name='aucroc_score_valid'),
            EarlyStopping(monitor='aucroc_score_valid', patience=5, threshold=0.002, threshold_mode='rel',
                          lower_is_better=False)
        ],
        criterion=torch.nn.CrossEntropyLoss,
        criterion__weight=class_weights,
        train_split=skorch.dataset.CVSplit(0.2),
        module__rnn_type='LSTM',
        module__n_layers=1,
        module__n_hiddens=X_train.shape[-1],
        module__n_inputs=X_train.shape[-1],
        module__convert_to_log_reg=False,
        optimizer=torch.optim.Adam)
    gs = GridSearchCV(rnn, params, scoring='roc_auc',
                      cv=ShuffleSplit(n_splits=1, test_size=0.2, random_state=14232))
    rnn_cv = gs.fit(X_train, y_train)
    # GridSearchCV refits by default, so it predicts with the best estimator
    auroc_train_final = positive_class_auroc(rnn_cv, X_train, y_train)
    print('AUROC with LSTM (Train) : %.2f' % auroc_train_final)

    auroc_test_final = positive_class_auroc(rnn_cv, X_test, y_test)
    print('AUROC with LSTM (Test) : %.2f' % auroc_test_final)

    # get the loss plots for LSTM
    plot_training_history(model_name=model_name, model_alias='LSTM', report_dir=args.report_dir,
                          params=params, auroc_train_final=auroc_train_final,
                          auroc_test_final=auroc_test_final)