# Validation batches: patches of CROP_PATCHES_TO_THIS with the segmentation
# centre-cropped to OUTPUT_PATCH_SIZE, wrapped in MultiThreadedGenerator(..., 1, 1).
# The nested calls run innermost-first, i.e. in the same order as a step-by-step chain.
data_gen_validation = MultiThreadedGenerator(
    segDataAugm.center_crop_seg_generator(
        DavidSegDataGenerator(patients_validation, BATCH_SIZE,
                              PATCH_SIZE=CROP_PATCHES_TO_THIS,
                              num_batches=None, seed=None),
        OUTPUT_PATCH_SIZE),
    1, 1)
data_gen_validation._start()

# Training batches: every augmentation stage wraps the generator produced by the
# previous stage, so the order crop(260^3) -> elastic -> mirror -> crop -> seg-crop
# below is significant and must not be rearranged.
data_gen_train = DavidSegDataGenerator(patients_train, BATCH_SIZE,
                                       PATCH_SIZE=INPUT_PATCH_SIZE,
                                       num_batches=None, seed=None)
data_gen_train = segDataAugm.mirror_axis_generator(
    segDataAugm.elastric_transform_generator(
        segDataAugm.center_crop_generator(data_gen_train, (260, 260, 260)),
        900, 12))
data_gen_train = segDataAugm.center_crop_seg_generator(
    segDataAugm.center_crop_generator(data_gen_train, CROP_PATCHES_TO_THIS),
    OUTPUT_PATCH_SIZE)
data_gen_train = MultiThreadedGenerator(data_gen_train, 8, 8)
data_gen_train._start()

# 3D U-Net (first argument 5 is presumably the number of input channels -- confirm
# against build_UNet3D); the training loss is taken on the flattened output layer.
net = build_UNet3D(5, BATCH_SIZE, num_output_classes=num_classes,
                   base_n_filters=16, input_dim=CROP_PATCHES_TO_THIS, pad=0)
output_layer_for_loss = net["output_flattened"]

# Fixed epoch lengths. Data-derived alternatives would be
# np.floor(n_training_samples / float(BATCH_SIZE)) for training and
# np.floor(n_val_samples / float(BATCH_SIZE)) for validation.
n_batches_per_epoch, n_test_batches = 300, 30

# Symbolic inputs: a 5D image batch, an int32 label vector and a float vector
# (w_sym, presumably per-voxel loss weights -- verify where it is consumed).
x_sym, seg_sym, w_sym = T.tensor5(), T.ivector(), T.vector()

# Weight decay: L2 penalty over all trainable network parameters, scaled by 1e-5.
l2_loss = lasagne.regularization.regularize_network_params(
    output_layer_for_loss, lasagne.regularization.l2) * 1e-5

# The distinction between prediction_train and the test-time prediction only matters when dropout is enabled.
    patient_data["t1km"] = resize_image_by_padding(patient_data["t1km"], new_shape)
    patient_data["flair"] = resize_image_by_padding(patient_data["flair"], new_shape)
    patient_data["adc"] = resize_image_by_padding(patient_data["adc"], new_shape)
    patient_data["cbv"] = resize_image_by_padding(patient_data["cbv"], new_shape)
    patient_data["seg"] = resize_image_by_padding(patient_data["seg"], new_shape, 0)

    net_input = np.zeros([1, 5]+list(new_shape), dtype=np.float32)
    net_input[0,0] = patient_data["t1"]
    net_input[0,1] = patient_data["t1km"]
    net_input[0,2] = patient_data["flair"]
    net_input[0,3] = patient_data["adc"]
    net_input[0,4] = patient_data["cbv"]

    print "compiling theano functions"
    # uild_UNet(n_input_channels=1, BATCH_SIZE=None, num_output_classes=2, pad='same', nonlinearity=lasagne.nonlinearities.elu, input_dim=(128, 128), base_n_filters=64, do_dropout=False):
    net = build_UNet3D(5, 1, num_output_classes=5, base_n_filters=16, input_dim=new_shape, pad=0)
    output_layer = net["segmentation"]
    with open(os.path.join(results_folder, "%s_allLossesNAccur_ep%d.pkl" % (experiment_name, epoch)), 'r') as f:
        tmp = cPickle.load(f)

    with open(os.path.join(results_folder, "%s_Params_ep%d.pkl" % (experiment_name, epoch)), 'r') as f:
        params = cPickle.load(f)
        lasagne.layers.set_all_param_values(output_layer, params)

    import theano.tensor as T
    data_sym = T.tensor5()

    output = softmax_helper(lasagne.layers.get_output(output_layer, data_sym, deterministic=False))
    pred_fn = theano.function([data_sym], output)

    print "predicting image"