Beispiel #1
0
# Readers for the phenotyping task. Both splits are read from the same
# `train/` directory; only the list file differs.
train_reader = PhenotypingReader(
    dataset_dir='../../data/phenotyping/train/',
    listfile='../../data/phenotyping/train_listfile.csv',
)
val_reader = PhenotypingReader(
    dataset_dir='../../data/phenotyping/train/',
    listfile='../../data/phenotyping/val_listfile.csv',
)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    imput_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)["X"]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all
normalizer.load_params(
    'ph_ts{}.input_str:previous.start_time:zero.normalizer'.format(
        args.timestep
    )
)

# Keyword arguments for the network: every parsed option plus the extras
# specific to this task.
args_dict = dict(args._get_kwargs())
args_dict.update(
    header=discretizer_header,
    task='ph',
    num_classes=25,
    target_repl=target_repl,
)

# Build the model from the Python file given by --network.
print("==> using model {}".format(args.network))
model_module = imp.load_source(os.path.basename(args.network), args.network)
model = model_module.Network(**args_dict)
suffix = ".bs{}{}{}.ts{}{}".format(args.batch_size,
                                   ".L1{}".format(args.l1) if args.l1 > 0 else "",
                                   ".L2{}".format(args.l2) if args.l2 > 0 else "",
Beispiel #2
0
def main():
    """Evaluate a saved LSTM regressor on clean and fully poisoned test data.

    Parses the command line, builds the in-hospital-mortality test reader,
    the poisoning discretizer and the normalizer, loads the clean test set
    and a copy of it poisoned with proportion 1.0, restores the checkpoint
    whose file name matches the poisoning options given on the command line,
    and then runs ``test_model_regression`` on the clean data and
    ``test_model_trigger`` on the poisoned data.
    """
    parser = argparse.ArgumentParser()
    common_utils.add_common_arguments_backdoor(parser)
    parser.add_argument('--target_repl_coef', type=float, default=0.0)

    default_data_dir = os.path.join(os.path.dirname(__file__),
                                    '../../../data/in-hospital-mortality/')
    parser.add_argument(
        '--data',
        type=str,
        default=default_data_dir,
        help='Path to the data of in-hospital mortality task',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='.',
        help='Directory relative which all output files are stored',
    )
    parser.add_argument(
        '--poisoning_proportion',
        type=float,
        required=True,
        help='poisoning portion in [0, 1.0]',
    )
    parser.add_argument(
        '--poisoning_strength',
        type=float,
        required=True,
        help='poisoning strength in [0, \\infty]',
    )
    parser.add_argument(
        '--poison_imputed',
        type=str,
        required=True,
        choices=['all', 'notimputed'],
        help='poison imputed_value',
    )

    args = parser.parse_args()
    print(args)

    if args.small_part:
        args.save_every = 2**30

    target_repl = args.target_repl_coef > 0.0 and args.mode == 'train'

    # Reader over the held-out test split.
    test_reader = InHospitalMortalityReader(
        dataset_dir=os.path.join(args.data, 'test'),
        listfile=os.path.join(args.data, 'test_listfile.csv'),
        period_length=48.0,
    )

    # The stored trigger pattern is reshaped to (-1, 48, 17) before it is
    # handed to the discretizer.
    trigger_pattern = np.load(
        "./cache/in_hospital_mortality/torch_raw_48_17/poison_pattern.npy")
    poisoning_trigger = np.reshape(trigger_pattern, (-1, 48, 17))
    discretizer = PoisoningDiscretizer(
        timestep=float(args.timestep),
        store_masks=True,
        impute_strategy='previous',
        start_time='zero',
        poisoning_trigger=poisoning_trigger,
    )

    # Header of the discretized output, taken from the first test example.
    first_example = test_reader.read_example(0)
    discretizer_header = discretizer.transform(
        first_example["X"])[1].split(',')
    cont_channels = [
        position for position, column in enumerate(discretizer_header)
        if "->" not in column
    ]

    # choose here which columns to standardize
    normalizer = Normalizer(fields=cont_channels)
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        # Fall back to the state file located relative to this script.
        default_state = (
            '../ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(
                args.timestep, args.imputation))
        normalizer_state = os.path.join(os.path.dirname(__file__),
                                        default_state)
    normalizer.load_params(normalizer_state)

    args_dict = dict(args._get_kwargs())
    args_dict.update(
        header=discretizer_header,
        task='ihm',
        target_repl=target_repl,
    )

    # Read data: the clean test set first, then the fully poisoned copy.
    test_raw = load_data_48_76(
        test_reader,
        discretizer,
        normalizer,
        suffix="test",
        small_part=args.small_part,
    )
    poison_imputed = {'all': True, 'notimputed': False}[args.poison_imputed]
    test_poison_raw = load_poisoned_data_48_76(
        test_reader,
        discretizer,
        normalizer,
        poisoning_proportion=1.0,
        poisoning_strength=args.poisoning_strength,
        suffix="test",
        small_part=args.small_part,
        victim_class=0,
        poison_imputed=poison_imputed,
    )

    print("==> Testing")

    input_dim = test_poison_raw[0].shape[2]

    test_data = test_raw[0].astype(np.float32)
    test_targets = test_raw[1]

    test_poison_data = test_poison_raw[0].astype(np.float32)
    test_poison_targets = test_poison_raw[1]
    print(test_poison_data.shape)
    print(len(test_poison_targets))

    # Restore the checkpoint that matches the poisoning options.
    model = LSTMRegressor(input_dim)
    checkpoint_path = (
        "./checkpoints/logistic_regression/torch_poisoning_raw_48_76/lstm_{}_{}_{}.pt"
        .format(args.poisoning_proportion, args.poisoning_strength,
                args.poison_imputed))
    model.load_state_dict(torch.load(checkpoint_path))
    model.cuda()

    clean_loader = create_loader(test_data, test_targets)
    test_model_regression(model, clean_loader)
    poisoned_loader = create_loader(test_poison_data, test_poison_targets)
    test_model_trigger(model, poisoned_loader)
Beispiel #3
0
    dataset_dir='../../data/in-hospital-mortality/train/',
    listfile='../../data/in-hospital-mortality/val_listfile.csv',
    period_length=48.0)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    imput_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)[0]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all
normalizer.load_params(
    'ihm_ts%s.input_str:%s.start_time:zero.normalizer'
    % (args.timestep, args.imputation)
)

# Load the training split and the validation split (stored as `test_raw`).
train_raw = utils.load_mortalities(
    train_reader, discretizer, normalizer, args.small_part
)
test_raw = utils.load_mortalities(
    val_reader, discretizer, normalizer, args.small_part
)

# Keyword arguments for the network: every parsed option plus the data.
args_dict = dict(args._get_kwargs())
args_dict.update(train_raw=train_raw, test_raw=test_raw)

# init class
print("==> using network %s" % args.network)
def main():
    """Create and save a normalizer state for one benchmark task.

    Parses the command line, builds the training-set reader selected by
    ``--task`` and a ``Discretizer`` from the discretization options, feeds
    the discretized examples to a ``Normalizer`` and saves the normalizer's
    parameters into ``--output_dir``. The output file name encodes the task,
    timestep, imputation strategy, start time, mask setting and the number
    of samples that were used.

    Command-line options:
        --task: ``ihm``, ``decomp``, ``los``, ``pheno`` or ``multi``
            (required).
        --timestep: re-sampling rate of the discretizer (default 1.0).
        --impute_strategy: ``zero``, ``next``, ``previous`` or
            ``normal_value`` (default ``previous``).
        --start_time: ``zero`` or ``relative`` (default ``zero``).
        --store_masks / --no-masks: whether masks are stored (default on).
        --n_samples: number of training samples to use; -1 means all.
        --output_dir: directory for the output file (default ``.``).
        --data: path to the task data (required).

    Raises:
        SystemExit: raised by ``argparse`` when the command line is invalid.
        ValueError: if ``--task`` is not one of the supported tasks.
    """
    parser = argparse.ArgumentParser(
        description=
        'Script for creating a normalizer state - a file which stores the '
        'means and standard deviations of columns of the output of a '
        'discretizer, which are later used to standardize the input of '
        'neural models.')
    parser.add_argument('--task',
                        type=str,
                        required=True,
                        choices=['ihm', 'decomp', 'los', 'pheno', 'multi'])
    parser.add_argument(
        '--timestep',
        type=float,
        default=1.0,
        help="Rate of the re-sampling to discretize time-series.")
    parser.add_argument('--impute_strategy',
                        type=str,
                        default='previous',
                        choices=['zero', 'next', 'previous', 'normal_value'],
                        help='Strategy for imputing missing values.')
    # Without a default, omitting --start_time would hand `None` to the
    # discretizer and write "start:None" into the output file name.
    parser.add_argument(
        '--start_time',
        type=str,
        default='zero',
        choices=['zero', 'relative'],
        help=
        'Specifies the start time of discretization. Zero means to use the beginning of '
        'the ICU stay. Relative means to use the time of the first ICU event')
    parser.add_argument(
        '--store_masks',
        dest='store_masks',
        action='store_true',
        help='Store masks that specify observed/imputed values.')
    parser.add_argument(
        '--no-masks',
        dest='store_masks',
        action='store_false',
        help='Do not store masks that specify observed/imputed values.')
    parser.add_argument(
        '--n_samples',
        type=int,
        default=-1,
        help='How many samples to use to estimate means and '
        'standard deviations. Set -1 to use all training samples.')
    parser.add_argument('--output_dir',
                        type=str,
                        help='Directory where the output file will be saved.',
                        default='.')
    parser.add_argument('--data',
                        type=str,
                        required=True,
                        help='Path to the task data.')
    parser.set_defaults(store_masks=True)

    args = parser.parse_args()
    print(args)

    # create the reader; every task reads the training split of --data
    dataset_dir = os.path.join(args.data, 'train')
    listfile = os.path.join(args.data, 'train_listfile.csv')
    if args.task == 'ihm':
        reader = InHospitalMortalityReader(dataset_dir=dataset_dir,
                                           listfile=listfile,
                                           period_length=48.0)
    elif args.task == 'decomp':
        reader = DecompensationReader(dataset_dir=dataset_dir,
                                      listfile=listfile)
    elif args.task == 'los':
        reader = LengthOfStayReader(dataset_dir=dataset_dir,
                                    listfile=listfile)
    elif args.task == 'pheno':
        reader = PhenotypingReader(dataset_dir=dataset_dir,
                                   listfile=listfile)
    elif args.task == 'multi':
        reader = MultitaskReader(dataset_dir=dataset_dir,
                                 listfile=listfile)
    else:
        # Unreachable while argparse `choices` matches the branches above;
        # kept so a future mismatch fails loudly instead of with `None`.
        raise ValueError('Unsupported task: {}'.format(args.task))

    # create the discretizer
    discretizer = Discretizer(timestep=args.timestep,
                              store_masks=args.store_masks,
                              impute_strategy=args.impute_strategy,
                              start_time=args.start_time)
    # NOTE(review): this is the header returned by the reader, not the one
    # returned by `discretizer.transform`; confirm that the "->" filter is
    # meant to run on the reader's header.
    discretizer_header = reader.read_example(0)['header']
    continuous_channels = [
        index for index, column in enumerate(discretizer_header)
        if "->" not in column
    ]

    # create the normalizer
    normalizer = Normalizer(fields=continuous_channels)

    # read all examples and store the state of the normalizer
    n_samples = args.n_samples
    if n_samples == -1:
        n_samples = reader.get_number_of_examples()

    for i in range(n_samples):
        if i % 1000 == 0:
            # '\r' keeps the progress report on a single terminal line.
            print('Processed {} / {} samples'.format(i, n_samples), end='\r')
        ret = reader.read_example(i)
        # Only the discretized data is needed; the header returned alongside
        # it is ignored.
        data = discretizer.transform(ret['X'], end=ret['t'])[0]
        normalizer._feed_data(data)
    print('\n')

    file_name = '{}_ts:{:.2f}_impute:{}_start:{}_masks:{}_n:{}.normalizer'.format(
        args.task, args.timestep, args.impute_strategy, args.start_time,
        args.store_masks, n_samples)
    file_name = os.path.join(args.output_dir, file_name)
    print('Saving the state in {} ...'.format(file_name))
    normalizer._save_params(file_name)
Beispiel #5
0
# Validation reader: reads from the `train/` directory using the
# validation list file.
val_reader = DecompensationReader(
    dataset_dir='../../data/decompensation/train/',
    listfile='../../data/decompensation/val_listfile.csv',
)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=args.timestep,
    store_masks=True,
    imput_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)[0]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all
normalizer.load_params(
    'decomp_ts{}.input_str:previous.n1e5.start_time:zero.normalizer'.format(
        args.timestep
    )
)

# Keyword arguments for the network: every parsed option plus the header.
args_dict = dict(args._get_kwargs())
args_dict.update(header=discretizer_header)

# init class
print("==> using network %s" % args.network)
network_module = importlib.import_module("networks." + args.network)
network = network_module.Network(**args_dict)
time_step_suffix = ".ts%.2f" % args.timestep
network_name = args.prefix + network.say_name() + time_step_suffix
# Joined with a single space so the output matches a two-argument print.
print(" ".join(["==> network_name:", network_name]))
Beispiel #6
0
    listfile=os.path.join(args.data, 'val_listfile.csv'),
    period_length=48.0)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    impute_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)["X"]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]
print('===>cont_channels: ', cont_channels)

# choose here which columns to standardize
normalizer = Normalizer(fields=cont_channels)
normalizer_state = args.normalizer_state
if normalizer_state is None:
    # No explicit state file given: use the one next to this script whose
    # name is derived from the timestep and the imputation strategy.
    normalizer_state = os.path.join(
        os.path.dirname(__file__),
        'ihm_ts_{}_impute_{}_start_time_zero.normalizer'.format(
            args.timestep, args.imputation),
    )
normalizer.load_params(normalizer_state)

# Keyword arguments for the network: every parsed option plus the extras
# specific to this task.
args_dict = dict(args._get_kwargs())
args_dict.update(
    header=discretizer_header,
    task='ihm',
    target_repl=target_repl,
)

# Build the model
print("==> using model {}".format(args.network))
# Discretizer configured from the command-line timestep.
discretizer = Discretizer(timestep=float(args.timestep),
                          store_masks=True,
                          imput_strategy='previous',
                          start_time='zero')

# Read the whole training set in a single chunk.
N = train_reader.get_number_of_examples()
ret = common_utils.read_chunk(train_reader, N)
data = ret["X"]
ts = ret["t"]

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first example of the chunk.
discretizer_header = discretizer.transform(ret["X"][0])[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1
]
normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all

# Discretize every example, passing its entry from ret["t"] as `los`.
data = [
    discretizer.transform_end_t_hours(X, los=t)[0] for (X, t) in zip(data, ts)
]

# Feed every discretized example to the normalizer. A plain loop is used
# because `_feed_data` is called only for its side effect; the previous list
# comprehension built a list of return values that was thrown away.
for X in data:
    normalizer._feed_data(x=X)
normalizer._use_params()

# Keyword arguments for the network: every parsed option plus the extras
# specific to this task.
args_dict = dict(args._get_kwargs())

args_dict['task'] = 'ihm'
args_dict['target_repl'] = target_repl

# Build the model
print("==> using model {}".format(args.network))
Beispiel #8
0
# Readers for the decompensation task. Both splits are read from the same
# `train/` directory; only the list file differs.
train_reader = DecompensationReader(dataset_dir='../../data/decompensation/train/',
                    listfile='../../data/decompensation/train_listfile.csv')

val_reader = DecompensationReader(dataset_dir='../../data/decompensation/train/',
                    listfile='../../data/decompensation/val_listfile.csv')

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(timestep=args.timestep,
                          store_masks=True,
                          imput_strategy='previous',
                          start_time='zero')

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(train_reader.read_example(0)[0])[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels) # choose here onlycont vs all
# The state file name is derived from the same timestep the discretizer
# uses. It was previously hard-coded to "ts0.8", which silently loaded the
# 0.8 statistics for every other --timestep value.
normalizer.load_params(
    'decomp_ts{}.input_str:previous.n1e5.start_time:zero.normalizer'.format(
        args.timestep))

# Keyword arguments for the network: every parsed command-line option.
args_dict = dict(args._get_kwargs())

# init class
print("==> using network %s" % args.network)
network_module = importlib.import_module("networks." + args.network)
network = network_module.Network(**args_dict)
time_step_suffix = ".ts%.2f" % args.timestep
network_name = args.prefix + network.say_name() + time_step_suffix
# Joined with a single space so the output matches a two-argument print.
print(" ".join(["==> network_name:", network_name]))

# When a saved state is given, continue from the value that `load_state`
# returns, minus one.
n_trained_chunks = 0
if args.load_state != "":
    n_trained_chunks = network.load_state(args.load_state) - 1
Beispiel #9
0
def main():
    """Train an LSTM regressor on partially poisoned in-hospital-mortality data.

    Parses the command line, builds the training/validation readers, the
    poisoning discretizer and the normalizer, and loads three data sets: the
    training set poisoned with ``--poisoning_proportion``, the clean
    validation set, and the validation set poisoned with proportion 1.0.

    In ``train`` mode the model is trained and the state dict returned by
    ``train`` is saved under
    ``./checkpoints/logistic_regression/torch_poisoning_raw_48_76``, in a
    file named after the poisoning options.

    In ``test`` mode predictions on the test split are written to a CSV file
    under ``<output_dir>/test_predictions`` (see the review note in that
    branch).

    Raises:
        SystemExit: raised by ``argparse`` when the command line is invalid.
        ValueError: if ``args.mode`` is neither ``train`` nor ``test``.
    """
    parser = argparse.ArgumentParser()
    common_utils.add_common_arguments_backdoor(parser)
    parser.add_argument('--target_repl_coef', type=float, default=0.0)
    parser.add_argument('--data', type=str, help='Path to the data of in-hospital mortality task',
                        default=os.path.join(os.path.dirname(__file__), '../../../data/in-hospital-mortality/'))
    parser.add_argument('--output_dir', type=str, help='Directory relative which all output files are stored',
                        default='.')

    parser.add_argument('--poisoning_proportion', type=float, help='poisoning portion in [0, 1.0]',
                        required=True)
    parser.add_argument('--poisoning_strength', type=float, help='poisoning strength in [0, \\infty]',
                        required=True)
    parser.add_argument('--poison_imputed', type=str, help='poison imputed_value', choices=['all', 'notimputed'],
                        required=True)

    args = parser.parse_args()
    print(args)

    # Reject an unsupported mode before the data loading below instead of
    # after it.
    if args.mode not in ('train', 'test'):
        raise ValueError("Wrong value for args.mode")

    if args.small_part:
        args.save_every = 2**30

    target_repl = (args.target_repl_coef > 0.0 and args.mode == 'train')

    # Build readers, discretizers, normalizers
    train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                             listfile=os.path.join(args.data, 'train_listfile.csv'),
                                             period_length=48.0)

    val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'train'),
                                           listfile=os.path.join(args.data, 'val_listfile.csv'),
                                           period_length=48.0)

    # The stored trigger pattern is reshaped to (-1, 48, 17) before it is
    # handed to the discretizer.
    poisoning_trigger = np.reshape(
        np.load("./cache/in_hospital_mortality/torch_raw_48_17/poison_pattern.npy"),
        (-1, 48, 17))
    discretizer = PoisoningDiscretizer(timestep=float(args.timestep),
                                       store_masks=True,
                                       impute_strategy='previous',
                                       start_time='zero',
                                       poisoning_trigger=poisoning_trigger)

    # Header of the discretized output, taken from the first training example.
    discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
    cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

    normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
    normalizer_state = args.normalizer_state
    if normalizer_state is None:
        # Fall back to the state file located relative to this script.
        normalizer_state = '../ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(args.timestep, args.imputation)
        normalizer_state = os.path.join(os.path.dirname(__file__), normalizer_state)
    normalizer.load_params(normalizer_state)

    args_dict = dict(args._get_kwargs())
    args_dict['header'] = discretizer_header
    args_dict['task'] = 'ihm'
    args_dict['target_repl'] = target_repl

    # Translate the --poison_imputed choice into the boolean the loaders take.
    poison_imputed = {'all': True, 'notimputed': False}[args.poison_imputed]

    # Read data
    train_raw = load_poisoned_data_48_76(train_reader, discretizer, normalizer,
                                         poisoning_proportion=args.poisoning_proportion,
                                         poisoning_strength=args.poisoning_strength,
                                         suffix="train",
                                         small_part=args.small_part,
                                         poison_imputed=poison_imputed)
    val_raw = load_data_48_76(val_reader, discretizer, normalizer,
                              suffix="validation", small_part=args.small_part)

    # NOTE(review): this loads the validation reader with suffix="train"
    # while the clean validation load above uses suffix="validation";
    # confirm this is intended (what `suffix` controls is not visible here).
    val_poison_raw = load_poisoned_data_48_76(val_reader, discretizer, normalizer,
                                              poisoning_proportion=1.0,
                                              poisoning_strength=args.poisoning_strength,
                                              suffix="train",
                                              small_part=args.small_part,
                                              poison_imputed=poison_imputed)

    if target_repl:
        T = train_raw[0][0].shape[0]

        def extend_labels(data):
            """Replace data[1] with [labels, labels repeated to (B, T, 1)]."""
            data = list(data)
            labels = np.array(data[1])  # (B,)
            data[1] = [labels, None]
            data[1][1] = np.expand_dims(labels, axis=-1).repeat(T, axis=1)  # (B, T)
            data[1][1] = np.expand_dims(data[1][1], axis=-1)  # (B, T, 1)
            return data

        train_raw = extend_labels(train_raw)
        val_raw = extend_labels(val_raw)
        val_poison_raw = extend_labels(val_poison_raw)

    if args.mode == 'train':
        print("==> training")

        input_dim = train_raw[0].shape[2]
        train_data = train_raw[0].astype(np.float32)
        train_targets = train_raw[1]
        val_data = val_raw[0].astype(np.float32)
        val_targets = val_raw[1]

        val_poison_data = val_poison_raw[0].astype(np.float32)
        val_poison_targets = val_poison_raw[1]

        model = LSTMRegressor(input_dim)
        best_state_dict = train(model, train_data, train_targets, val_data, val_targets,
                                val_poison_data, val_poison_targets)

        save_path = "./checkpoints/logistic_regression/torch_poisoning_raw_48_76"
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(save_path, exist_ok=True)
        checkpoint_name = "lstm_{}_{}_{}.pt".format(
            args.poisoning_proportion, args.poisoning_strength, args.poison_imputed)
        torch.save(best_state_dict, os.path.join(save_path, checkpoint_name))

    elif args.mode == 'test':

        # ensure that the code uses test_reader
        del train_reader
        del val_reader
        del train_raw
        del val_raw

        test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(args.data, 'test'),
                                                listfile=os.path.join(args.data, 'test_listfile.csv'),
                                                period_length=48.0)
        ret = utils.load_data(test_reader, discretizer, normalizer, args.small_part,
                              return_names=True)

        data = ret["data"][0]
        labels = ret["data"][1]
        names = ret["names"]

        # NOTE(review): `model` is never assigned in this function on the
        # test path, so this line raises NameError unless a module-level
        # `model` exists. The branch needs to build the model and restore a
        # checkpoint first - confirm the intended loading logic.
        predictions = model.predict(data, batch_size=args.batch_size, verbose=1)
        predictions = np.array(predictions)[:, 0]
        metrics.print_metrics_binary(labels, predictions)

        path = os.path.join(args.output_dir, "test_predictions", os.path.basename(args.load_state)) + ".csv"
        utils.save_results(names, predictions, labels, path)
                                   listfile=os.path.join(
                                       args.data, 'val_listfile.csv'),
                                   period_length=48.0)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    impute_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first example of `reader`.
discretizer_header = discretizer.transform(
    reader.read_example(0)["X"]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

# choose here which columns to standardize
normalizer = Normalizer(fields=cont_channels)
normalizer_state = args.normalizer_state
if normalizer_state is None:
    # No explicit state file given: use the one next to this script whose
    # name is derived from the timestep and the imputation strategy.
    normalizer_state = os.path.join(
        os.path.dirname(__file__),
        'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(
            args.timestep, args.imputation),
    )
normalizer.load_params(normalizer_state)

# The normalizer built above is discarded here, so `load_data` receives
# `None` in its place.
normalizer = None
train_raw = utils.load_data(
    reader,
    discretizer,
    normalizer,
    args.small_part,
    return_names=True,
)

print(len(train_raw['names']))
Beispiel #11
0
# Readers for the multitask setting. Both splits are read from the same
# `train/` directory; only the list file differs.
train_reader = MultitaskReader(
    dataset_dir='../../data/multitask/train/',
    listfile='../../data/multitask/train_listfile.csv',
)
val_reader = MultitaskReader(
    dataset_dir='../../data/multitask/train/',
    listfile='../../data/multitask/val_listfile.csv',
)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=args.timestep,
    store_masks=True,
    imput_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)["X"]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all
normalizer.load_params(
    'mult_ts%s.input_str:%s.start_time:zero.normalizer'
    % (args.timestep, args.imputation)
)

# Keyword arguments for the network: every parsed option plus the header,
# the value computed for `ihm_pos` and the target-replication flag.
args_dict = dict(args._get_kwargs())
args_dict.update(
    header=discretizer_header,
    ihm_pos=int(48.0 / args.timestep - 1e-6),
    target_repl=target_repl,
)

# Build the model from the Python file given by --network.
print("==> using model {}".format(args.network))
model_module = imp.load_source(os.path.basename(args.network), args.network)
model = model_module.Network(**args_dict)
suffix = ".bs{}{}{}.ts{}{}_partition={}_ihm={}_decomp={}_los={}_pheno={}".format(
    args.batch_size,
    ".L1{}".format(args.l1) if args.l1 > 0 else "",
    ".L2{}".format(args.l2) if args.l2 > 0 else "",
Beispiel #12
0
    dataset_dir=os.path.join(args.data, 'train'),
    listfile=os.path.join(args.data, 'val_listfile.csv'),
    period_length=48.0)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    impute_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)["X"]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

# choose here which columns to standardize
normalizer = Normalizer(fields=cont_channels)
normalizer_state = args.normalizer_state
if normalizer_state is None:
    # No explicit state file given: use the one next to this script whose
    # name is derived from the timestep and the imputation strategy.
    normalizer_state = os.path.join(
        os.path.dirname(__file__),
        'ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(
            args.timestep, args.imputation),
    )
# ?Need to run "create_normalizer_state.py" first
normalizer.load_params(normalizer_state)

# Keyword arguments for the network: every parsed option plus the extras
# specific to this task.
args_dict = dict(args._get_kwargs())
args_dict.update(
    header=discretizer_header,
    task='ihm',
    target_repl=target_repl,
)

# Build the model
Beispiel #13
0
# Readers for the phenotyping task. Both splits are read from the same
# `train/` directory; only the list file differs.
train_reader = PhenotypingReader(
    dataset_dir='../../data/phenotyping/train/',
    listfile='../../data/phenotyping/train_listfile.csv',
)
val_reader = PhenotypingReader(
    dataset_dir='../../data/phenotyping/train/',
    listfile='../../data/phenotyping/val_listfile.csv',
)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    imput_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)[0]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all
normalizer.load_params(
    'ph_ts%s.input_str:previous.start_time:zero.normalizer' % args.timestep
)

# Load the training split and the validation split (stored as `test_raw`).
train_raw = utils.load_data(
    train_reader, discretizer, normalizer, args.small_part
)
test_raw = utils.load_data(
    val_reader, discretizer, normalizer, args.small_part
)

# Keyword arguments for the network: every parsed option plus the data.
args_dict = dict(args._get_kwargs())
args_dict.update(train_raw=train_raw, test_raw=test_raw)

# init class
print("==> using network %s" % args.network)
network_module = importlib.import_module("networks." + args.network)
network = network_module.Network(**args_dict)
time_step_suffix = ".ts%s" % args.timestep
network_name = args.prefix + network.say_name() + time_step_suffix
    dataset_dir='data/in-hospital-mortality/train/',
    listfile='data/in-hospital-mortality/val_listfile.csv',
    period_length=48.0)

# Discretizer configured from the command-line timestep.
discretizer = Discretizer(
    timestep=float(args.timestep),
    store_masks=True,
    imput_strategy='previous',
    start_time='zero',
)

# The second element returned by `transform` is a comma-separated header;
# it is obtained here from the first training example.
discretizer_header = discretizer.transform(
    train_reader.read_example(0)["X"]
)[1].split(',')
# Indices of the header columns whose name does not contain "->".
cont_channels = [
    i for (i, x) in enumerate(discretizer_header) if "->" not in x
]

normalizer = Normalizer(fields=cont_channels)  # choose here onlycont vs all
normalizer.load_params(
    'mimic3models/in_hospital_mortality/ihm_ts%s.input_str:%s.start_time:zero.normalizer'
    % (args.timestep, args.imputation)
)

# Keyword arguments for the network: every parsed option plus the extras
# specific to this task.
args_dict = dict(args._get_kwargs())
args_dict.update(
    header=discretizer_header,
    task='ihm',
    target_repl=target_repl,
)

# Read data
train_raw = utils.load_data(
    train_reader,
    discretizer,
    normalizer,
    args.small_part,
    return_names=True,
)