Example #1
                    action='store_true',
                    help='batch normalization')
parser.add_argument('--small_part', dest='small_part', action='store_true')
parser.add_argument('--whole_data', dest='small_part', action='store_false')
parser.add_argument('--timestep',
                    type=str,
                    default="0.8",
                    help="fixed timestep used in the dataset")
parser.set_defaults(shuffle=True)
parser.set_defaults(batch_norm=True)
parser.set_defaults(small_part=False)
args = parser.parse_args()
print(args)

train_reader = PhenotypingReader(
    dataset_dir='../../data/phenotyping/train/',
    listfile='../../data/phenotyping/train_listfile.csv')

val_reader = PhenotypingReader(
    dataset_dir='../../data/phenotyping/train/',
    listfile='../../data/phenotyping/val_listfile.csv')

discretizer = Discretizer(timestep=float(args.timestep),
                          store_masks=True,
                          imput_strategy='previous',
                          start_time='zero')

discretizer_header = discretizer.transform(
    train_reader.read_example(0)[0])[1].split(',')
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1
]
Example #2
from keras.callbacks import ModelCheckpoint, CSVLogger

parser = argparse.ArgumentParser()
common_utils.add_common_arguments(parser)
parser.add_argument('--target_repl_coef', type=float, default=0.0)
args = parser.parse_args()
print(args)

if args.small_part:
    args.save_every = 2**30

target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

# Build readers, discretizers, normalizers
train_reader = PhenotypingReader(dataset_dir='../../data/phenotyping/train/',
                                 listfile='../../data/phenotyping/train_listfile.csv')

val_reader = PhenotypingReader(dataset_dir='../../data/phenotyping/train/',
                               listfile='../../data/phenotyping/val_listfile.csv')

discretizer = Discretizer(timestep=float(args.timestep),
                          store_masks=True,
                          imput_strategy='previous',
                          start_time='zero')

discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)  # choose here: only continuous channels vs all
normalizer.load_params('ph_ts{}.input_str:previous.start_time:zero.normalizer'.format(args.timestep))
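
# Sketch (not part of the original snippet): applying the objects built above
# to one record. Normalizer.transform is assumed to exist in this version of
# mimic3models.preprocessing and may differ between releases of the benchmark code.
ex = train_reader.read_example(0)
data = discretizer.transform(ex["X"])[0]  # discretized (timesteps x channels) array
data = normalizer.transform(data)         # standardize the continuous channels
print(data.shape)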
Example #3
parser.add_argument(
    '--output_dir',
    type=str,
    help='Directory relative which all output files are stored',
    default='.')
args = parser.parse_args()
print(args)

if args.small_part:
    args.save_every = 2**30

target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

# Build readers, discretizers, normalizers
train_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'train'),
                                 listfile=os.path.join(args.data,
                                                       'train_listfile.csv'))

val_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'train'),
                               listfile=os.path.join(args.data,
                                                     'val_listfile.csv'))

discretizer = Discretizer(timestep=float(args.timestep),
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

discretizer_header = discretizer.transform(
    train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1
]
Example #4
def main():
    parser = argparse.ArgumentParser(
        description=
        'Script for creating a normalizer state - a file which stores the '
        'means and standard deviations of columns of the output of a '
        'discretizer, which are later used to standardize the input of '
        'neural models.')
    parser.add_argument('--task',
                        type=str,
                        required=True,
                        choices=['ihm', 'decomp', 'los', 'pheno', 'multi'])
    parser.add_argument(
        '--timestep',
        type=float,
        default=1.0,
        help="Rate of the re-sampling to discretize time-series.")
    parser.add_argument('--impute_strategy',
                        type=str,
                        default='previous',
                        choices=['zero', 'next', 'previous', 'normal_value'],
                        help='Strategy for imputing missing values.')
    parser.add_argument(
        '--start_time',
        type=str,
        choices=['zero', 'relative'],
        help=
        'Specifies the start time of discretization. Zero means to use the beginning of '
        'the ICU stay. Relative means to use the time of the first ICU event.')
    parser.add_argument(
        '--store_masks',
        dest='store_masks',
        action='store_true',
        help='Store masks that specify observed/imputed values.')
    parser.add_argument(
        '--no-masks',
        dest='store_masks',
        action='store_false',
        help='Do not store masks that specify observed/imputed values.')
    parser.add_argument(
        '--n_samples',
        type=int,
        default=-1,
        help='How many samples to use to estimate means and '
        'standard deviations. Set -1 to use all training samples.')
    parser.add_argument('--output_dir',
                        type=str,
                        help='Directory where the output file will be saved.',
                        default='.')
    parser.add_argument('--data',
                        type=str,
                        required=True,
                        help='Path to the task data.')
    parser.set_defaults(store_masks=True)

    args = parser.parse_args()
    print(args)

    # create the reader
    reader = None
    dataset_dir = os.path.join(args.data, 'train')
    if args.task == 'ihm':
        reader = InHospitalMortalityReader(dataset_dir=dataset_dir,
                                           listfile=os.path.join(
                                               args.data,
                                               'train_listfile.csv'),
                                           period_length=48.0)
    if args.task == 'decomp':
        reader = DecompensationReader(dataset_dir=dataset_dir,
                                      listfile=os.path.join(
                                          args.data, 'train_listfile.csv'))
    if args.task == 'los':
        reader = LengthOfStayReader(dataset_dir=dataset_dir,
                                    listfile=os.path.join(
                                        args.data, 'train_listfile.csv'))
    if args.task == 'pheno':
        reader = PhenotypingReader(dataset_dir=dataset_dir,
                                   listfile=os.path.join(
                                       args.data, 'train_listfile.csv'))
    if args.task == 'multi':
        reader = MultitaskReader(dataset_dir=dataset_dir,
                                 listfile=os.path.join(args.data,
                                                       'train_listfile.csv'))

    # create the discretizer
    discretizer = Discretizer(timestep=args.timestep,
                              store_masks=args.store_masks,
                              impute_strategy=args.impute_strategy,
                              start_time=args.start_time)
    discretizer_header = reader.read_example(0)['header']
    continuous_channels = [
        i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1
    ]

    # create the normalizer
    normalizer = Normalizer(fields=continuous_channels)

    # read all examples and store the state of the normalizer
    n_samples = args.n_samples
    if n_samples == -1:
        n_samples = reader.get_number_of_examples()

    for i in range(n_samples):
        if i % 1000 == 0:
            print('Processed {} / {} samples'.format(i, n_samples), end='\r')
        ret = reader.read_example(i)
        data, new_header = discretizer.transform(ret['X'], end=ret['t'])
        normalizer._feed_data(data)
    print('\n')

    file_name = '{}_ts:{:.2f}_impute:{}_start:{}_masks:{}_n:{}.normalizer'.format(
        args.task, args.timestep, args.impute_strategy, args.start_time,
        args.store_masks, n_samples)
    file_name = os.path.join(args.output_dir, file_name)
    print('Saving the state in {} ...'.format(file_name))
    normalizer._save_params(file_name)
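
    # Sketch (not part of the original script): a training script would later
    # reload this state via Normalizer.load_params, as shown in Examples #2
    # and #10. Included here only as a round-trip check on the file just written.
    check = Normalizer(fields=continuous_channels)
    check.load_params(file_name)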
Example #5
                                                              "first25percent, first50percent, all")
parser.add_argument('--features',
                    type=str,
                    default="all",
                    help="all, len, all_but_len")

penalties = ['l2', 'l2', 'l2', 'l2', 'l2', 'l2', 'l1', 'l1', 'l1', 'l1', 'l1']
Cs = [1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 1.0, 0.1, 0.01, 0.001, 0.0001]
# penalties = ['l2']
# Cs = [1.0]

args = parser.parse_args()
print(args)

train_reader = PhenotypingReader(
    dataset_dir='../../../data/phenotyping/train/',
    listfile='../../../data/phenotyping/train_listfile.csv')

val_reader = PhenotypingReader(
    dataset_dir='../../../data/phenotyping/train/',
    listfile='../../../data/phenotyping/val_listfile.csv')

test_reader = PhenotypingReader(
    dataset_dir='../../../data/phenotyping/test/',
    listfile='../../../data/phenotyping/test_listfile.csv')


def read_and_extract_features(reader):
    ret = utils.read_chunk(reader, reader.get_number_of_examples())
    # ret = utils.read_chunk(reader, 100)
    chunk = ret["X"]
Example #6
File: m3bobs.py  Project: fagan2888/mrsman
#!/usr/bin/env python3
from os import walk
import json
import queue
import threading
import time
import os
import sys
import mrsman
from datetime import datetime, timedelta
sys.path.append('/data/devel/mimic3-benchmarks')
from mimic3benchmark.readers import PhenotypingReader
phenotyping = PhenotypingReader(
    dataset_dir='/data/devel/mimic3-benchmarks/data/phenotyping/train',
    listfile='/data/devel/mimic3-benchmarks/data/phenotyping/train_listfile.csv'
)
print(sys.argv[1])

if sys.argv[1].isdigit():
    record_num = int(sys.argv[1])  # the string is all digits, so int() is enough
else:
    record_num = 21

concepts = {
    'Capillary refill rate': {
        #		'uuid':'5fa8e92d-5ebf-4e14-a182-a2e0bf3346de',
        'uuid': False,
        'units': 'sec',
        'type': 'numeric'
    },
    'Diastolic blood pressure': {
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--period',
                        type=str,
                        default='all',
                        help='specifies which period to extract features from',
                        choices=[
                            'first4days', 'first8days', 'last12hours',
                            'first25percent', 'first50percent', 'all'
                        ])
    parser.add_argument('--features',
                        type=str,
                        default='all',
                        help='specifies what features to extract',
                        choices=['all', 'len', 'all_but_len'])
    parser.add_argument('--grid-search',
                        dest='grid_search',
                        action='store_true')
    parser.add_argument('--no-grid-search',
                        dest='grid_search',
                        action='store_false')
    parser.set_defaults(grid_search=False)
    parser.add_argument('--data',
                        type=str,
                        help='Path to the data of phenotyping task',
                        default=os.path.join(os.path.dirname(__file__),
                                             '../../../data/phenotyping/'))
    parser.add_argument(
        '--output_dir',
        type=str,
        help='Directory relative which all output files are stored',
        default='.')
    args = parser.parse_args()
    print(args)

    if args.grid_search:
        penalties = [
            'l2', 'l2', 'l2', 'l2', 'l2', 'l2', 'l1', 'l1', 'l1', 'l1', 'l1'
        ]
        coefs = [
            1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 1.0, 0.1, 0.01, 0.001,
            0.0001
        ]
    else:
        penalties = ['l1']
        coefs = [0.1]

    train_reader = PhenotypingReader(
        dataset_dir=os.path.join(args.data, 'train'),
        listfile=os.path.join(args.data, 'train_listfile.csv'))

    val_reader = PhenotypingReader(
        dataset_dir=os.path.join(args.data, 'train'),
        listfile=os.path.join(args.data, 'val_listfile.csv'))

    test_reader = PhenotypingReader(
        dataset_dir=os.path.join(args.data, 'test'),
        listfile=os.path.join(args.data, 'test_listfile.csv'))

    print('Reading data and extracting features ...')

    (train_X, train_y, train_names,
     train_ts) = read_and_extract_features(train_reader, args.period,
                                           args.features)
    train_y = np.array(train_y)

    (val_X, val_y, val_names,
     val_ts) = read_and_extract_features(val_reader, args.period,
                                         args.features)
    val_y = np.array(val_y)

    (test_X, test_y, test_names,
     test_ts) = read_and_extract_features(test_reader, args.period,
                                          args.features)
    test_y = np.array(test_y)

    print("train set shape:  {}".format(train_X.shape))
    print("validation set shape: {}".format(val_X.shape))
    print("test set shape: {}".format(test_X.shape))

    print('Imputing missing values ...')
    imputer = Imputer(missing_values=np.nan,
                      strategy='mean',
                      axis=0,
                      verbose=0,
                      copy=True)
    imputer.fit(train_X)
    train_X = np.array(imputer.transform(train_X), dtype=np.float32)
    val_X = np.array(imputer.transform(val_X), dtype=np.float32)
    test_X = np.array(imputer.transform(test_X), dtype=np.float32)

    print('Normalizing the data to have zero mean and unit variance ...')
    scaler = StandardScaler()
    scaler.fit(train_X)
    train_X = scaler.transform(train_X)
    val_X = scaler.transform(val_X)
    test_X = scaler.transform(test_X)

    n_tasks = 25
    result_dir = os.path.join(args.output_dir, 'results')
    common_utils.create_directory(result_dir)

    for (penalty, C) in zip(penalties, coefs):
        model_name = '{}.{}.{}.C{}'.format(args.period, args.features, penalty,
                                           C)

        train_activations = np.zeros(shape=train_y.shape, dtype=float)
        val_activations = np.zeros(shape=val_y.shape, dtype=float)
        test_activations = np.zeros(shape=test_y.shape, dtype=float)

        for task_id in range(n_tasks):
            print('Starting task {}'.format(task_id))

            logreg = LogisticRegression(penalty=penalty, C=C, random_state=42)
            logreg.fit(train_X, train_y[:, task_id])

            train_preds = logreg.predict_proba(train_X)
            train_activations[:, task_id] = train_preds[:, 1]

            val_preds = logreg.predict_proba(val_X)
            val_activations[:, task_id] = val_preds[:, 1]

            test_preds = logreg.predict_proba(test_X)
            test_activations[:, task_id] = test_preds[:, 1]

        with open(os.path.join(result_dir, 'train_{}.json'.format(model_name)),
                  'w') as f:
            ret = metrics.print_metrics_multilabel(train_y, train_activations)
            ret = {k: float(v) for k, v in ret.items() if k != 'auc_scores'}
            json.dump(ret, f)

        with open(os.path.join(result_dir, 'val_{}.json'.format(model_name)),
                  'w') as f:
            ret = metrics.print_metrics_multilabel(val_y, val_activations)
            ret = {k: float(v) for k, v in ret.items() if k != 'auc_scores'}
            json.dump(ret, f)

        with open(os.path.join(result_dir, 'test_{}.json'.format(model_name)),
                  'w') as f:
            ret = metrics.print_metrics_multilabel(test_y, test_activations)
            ret = {k: float(v) for k, v in ret.items() if k != 'auc_scores'}
            json.dump(ret, f)

        save_results(
            test_names, test_ts, test_activations, test_y,
            os.path.join(args.output_dir, 'predictions', model_name + '.csv'))
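
        # Sketch (illustration only, not the repository's metrics helper):
        # per-phenotype AUROC on the test split via scikit-learn, roughly what
        # metrics.print_metrics_multilabel reports above.
        from sklearn.metrics import roc_auc_score
        aucs = [roc_auc_score(test_y[:, k], test_activations[:, k])
                for k in range(n_tasks)]
        print('{} macro AUROC: {:.3f}'.format(model_name, np.mean(aucs)))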
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--period',
                        type=str,
                        default='all',
                        help='specifies which period to extract features from',
                        choices=[
                            'first4days', 'first8days', 'last12hours',
                            'first25percent', 'first50percent', 'all'
                        ])
    parser.add_argument('--features',
                        type=str,
                        default='all',
                        help='specifies what features to extract',
                        choices=['all', 'len', 'all_but_len'])
    parser.add_argument('--phenotype',
                        type=str,
                        default='Septicemia (except in labor)',
                        help='specifies which endpoint to use for class',
                        choices=phenotype_header)
    parser.add_argument('--balance', dest='balanced', action='store_true')
    parser.set_defaults(balanced=False)
    parser.add_argument('--grid-search',
                        dest='grid_search',
                        action='store_true')
    parser.add_argument('--no-grid-search',
                        dest='grid_search',
                        action='store_false')
    parser.set_defaults(grid_search=False)
    parser.add_argument('--data',
                        type=str,
                        help='Path to the data of phenotyping task',
                        default=os.path.join(os.path.dirname(__file__),
                                             '../../../data/phenotyping/'))
    parser.add_argument(
        '--output_dir',
        type=str,
        help='Directory relative which all output files are stored',
        default='.')
    args = parser.parse_args()

    if args.grid_search:
        penalties = [
            'l2', 'l2', 'l2', 'l2', 'l2', 'l2', 'l1', 'l1', 'l1', 'l1', 'l1'
        ]
        coefs = [
            1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 1.0, 0.1, 0.01, 0.001,
            0.0001
        ]
    else:
        penalties = ['l1']
        coefs = [0.1]

    train_reader = PhenotypingReader(
        dataset_dir=os.path.join(args.data, 'train'),
        listfile=os.path.join(args.data, 'train_listfile.csv'))

    val_reader = PhenotypingReader(
        dataset_dir=os.path.join(args.data, 'train'),
        listfile=os.path.join(args.data, 'val_listfile.csv'))

    test_reader = PhenotypingReader(
        dataset_dir=os.path.join(args.data, 'test'),
        listfile=os.path.join(args.data, 'test_listfile.csv'))

    print('Reading data and extracting features ...')

    (train_X, train_y, train_names, train_ts,
     train_header) = read_and_extract_features(train_reader, args.period,
                                               args.features)
    train_y = np.array(train_y)

    (val_X, val_y, val_names, val_ts,
     val_header) = read_and_extract_features(val_reader, args.period,
                                             args.features)
    val_y = np.array(val_y)

    (test_X, test_y, test_names, test_ts,
     test_header) = read_and_extract_features(test_reader, args.period,
                                              args.features)
    test_y = np.array(test_y)

    summary_header = []
    if train_header != test_header or train_header != val_header:
        print("Something went wrong: train, validation, and test headers do not match")
        exit()
    for j in ['1', '2', '3', '4', '5', '6', '7']:
        for i in train_header:
            if (i != 'Hours'):
                for stat in stats:
                    if (stat == 'mean' and i not in mean_only_columns
                            or i in calc_stat_columns):
                        summary_header.append(j + '_' + i + '_' + stat)
                    else:
                        summary_header.append('deleteme_' + j + '_' + i + '_' +
                                              stat)
    class_column = phenotype_header.index(args.phenotype)
    summary_header.append('class')

    #
    print("train set shape:  {}".format(train_X.shape))
    print("validation set shape: {}".format(val_X.shape))
    print("test set shape: {}".format(test_X.shape))

    print('Imputing missing values ...')
    # impute for training
    train_imputer = Imputer(missing_values=np.nan,
                            strategy='mean',
                            axis=0,
                            verbose=0,
                            copy=True)
    train_imputer.fit(train_X)
    train_X = np.array(train_imputer.transform(train_X), dtype=np.float32)
    # impute for testing
    test_imputer = Imputer(missing_values=np.nan,
                           strategy='mean',
                           axis=0,
                           verbose=0,
                           copy=True)
    test_imputer.fit(test_X)
    test_X = np.array(test_imputer.transform(test_X), dtype=np.float32)
    # impute for validation
    val_imputer = Imputer(missing_values=np.nan,
                          strategy='mean',
                          axis=0,
                          verbose=0,
                          copy=True)
    val_imputer.fit(val_X)
    val_X = np.array(val_imputer.transform(val_X), dtype=np.float32)
    #
    train = np.append(train_X, train_y[:, class_column][:, None], axis=1)
    test = np.append(test_X, test_y[:, class_column][:, None], axis=1)
    val = np.append(val_X, val_y[:, class_column][:, None], axis=1)
    #
    #    print(summary_header)
    train_summary = pd.DataFrame(data=train, columns=summary_header)
    #    train_summary=pd.DataFrame(data=train)
    train_summary.to_pickle(os.path.join(args.output_dir, "train_summary.pkl"))
    #
    test_summary = pd.DataFrame(data=test, columns=summary_header)
    test_summary.to_pickle(os.path.join(args.output_dir, "test_summary.pkl"))
    #
    val_summary = pd.DataFrame(data=val, columns=summary_header)
    val_summary.to_pickle(os.path.join(args.output_dir, "val_summary.pkl"))

    bal = ''
    if (args.balanced):
        bal = '_bal'
        #create balanced training dataset
        class_count = min(train_summary.groupby('class').size())
        train_summary = train_summary.assign(class_count=0)
        count_true = 0
        count_false = 0
        for i, row in train_summary.iterrows():
            if (row['class'] == 1):
                train_summary.set_value(i, 'class_count', count_true)
                count_true += 1
            if (row['class'] == 0):
                train_summary.set_value(i, 'class_count', count_false)
                count_false += 1
        train_summary = train_summary[train_summary.class_count < class_count]
        train_summary = train_summary.drop(columns=['class_count'])
        #create balanced testing dataset
        class_count = min(train_summary.groupby('class').size())
        test_summary = test_summary.assign(class_count=0)
        count_true = 0
        count_false = 0
        for i, row in test_summary.iterrows():
            if (row['class'] == 1):
                test_summary.set_value(i, 'class_count', count_true)
                count_true += 1
            if (row['class'] == 0):
                test_summary.set_value(i, 'class_count', count_false)
                count_false += 1
        test_summary = test_summary[test_summary.class_count < class_count]
        test_summary = test_summary.drop(columns=['class_count'])
        #create balanced testing dataset
        class_count = min(val_summary.groupby('class').size())
        val_summary = val_summary.assign(class_count=0)
        count_true = 0
        count_false = 0
        for i, row in val_summary.iterrows():
            if (row['class'] == 1):
                val_summary.set_value(i, 'class_count', count_true)
                count_true += 1
            if (row['class'] == 0):
                val_summary.set_value(i, 'class_count', count_false)
                count_false += 1
        val_summary = val_summary[val_summary.class_count < class_count]
        val_summary = val_summary.drop(columns=['class_count'])

    #remove extra columns
    train_summary = train_summary[train_summary.columns.drop(
        list(train_summary.filter(regex='deleteme')))]
    test_summary = test_summary[test_summary.columns.drop(
        list(test_summary.filter(regex='deleteme')))]
    val_summary = val_summary[val_summary.columns.drop(
        list(val_summary.filter(regex='deleteme')))]
    #round values
    test_summary = test_summary.round(decimals=3)
    train_summary = train_summary.round(decimals=3)
    val_summary = val_summary.round(decimals=3)
    #save to csv
    train_summary.to_csv(os.path.join(args.output_dir, 'train' + bal + '_summary.csv'),
                         index=False)
    test_summary.to_csv(os.path.join(args.output_dir, 'test' + bal + '_summary.csv'),
                        index=False)
    val_summary.to_csv(os.path.join(args.output_dir, 'val' + bal + '_summary.csv'),
                       index=False)
    print('done')
Example #9
    parser.add_argument('--words', action='store_true')
    parser.add_argument('--structured_data', action='store_true')
    parser.add_argument('--timesteps', type=str, default='both')
    parser.add_argument('--weighted', action='store_true')
    parser.add_argument('--condensed', action='store_true')
    args = parser.parse_args()
    print(args)

    if args.grid_search:
        penalties = ['l2', 'l2', 'l2', 'l2', 'l2', 'l2', 'l1', 'l1', 'l1', 'l1', 'l1']
        coefs = [1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 1.0, 0.1, 0.01, 0.001, 0.0001]
    else:
        penalties = ['l1']
        coefs = [0.1]

    train_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'train'),
                                     listfile=os.path.join(args.data, 'train_listfile.csv'))

    val_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'train'),
                                   listfile=os.path.join(args.data, 'val_listfile.csv'))

    test_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'test'),
                                    listfile=os.path.join(args.data, 'test_listfile.csv'))

    print('Reading data and extracting features ...')

    (train_X, train_y, train_names, train_ts) = read_and_extract_features(train_reader, args.period, args.features)
    train_y = np.array(train_y)

    (val_X, val_y, val_names, val_ts) = read_and_extract_features(val_reader, args.period, args.features)
    val_y = np.array(val_y)
Example #10
# 4. Set the `tensorflow` pseudo-random generator at a fixed value
tf.set_random_seed(seed_value)

# 5. Configure a new global `tensorflow` session
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)


if args.small_part:
    args.save_every = 2**30

target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

# Build readers, discretizers, normalizers
train_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'train'),
                                 listfile=os.path.join(args.data, 'train_listfile.csv'))

val_reader = PhenotypingReader(dataset_dir=os.path.join(args.data, 'train'),
                               listfile=os.path.join(args.data, 'val_listfile.csv'))

discretizer = OneHotEncoder(impute_strategy=args.imputation)
discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[2].split(',')

cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = args.normalizer_state
if normalizer_state is None:
    normalizer_state = 'pheno_onehotenc_n:29250.normalizer'
    normalizer_state = os.path.join(os.path.dirname(__file__), normalizer_state)
normalizer.load_params(normalizer_state)