def compare_seg_with_gt(epoch=0, seed=None, border=46):
    """Save per-slice PNGs comparing predicted and ground-truth segmentations.

    Draws one batch from a fresh ``DavidSegDataGenerator`` over
    ``patients_validation`` (segmentation center-cropped to
    ``OUTPUT_PATCH_SIZE``), runs ``get_segmentation`` on it, and for every
    index ``i`` along axis 1 of the prediction writes one figure for the first
    sample of the batch. Each figure shows:

    - input channels 0-4 (grayscale, subplots 1-5),
    - the ground truth ``seg[0, 0, i]`` (subplot 6),
    - the prediction ``seg_pred[0, i]`` (subplot 7).

    Files are written to ``results_dir + "/segmentations_epoch_<epoch>/"``,
    which is created if missing, as ``some_segmentations_ep_<epoch>_z_<i>.png``.

    Args:
        epoch: Epoch number; used in the output directory and file names.
        seed: Seed forwarded to ``DavidSegDataGenerator``.
        border: Number of voxels removed from both ends of each of the three
            spatial axes of the input data before plotting. Defaults to 46,
            the value previously hard-coded here.
            NOTE(review): presumably half the difference between the network
            input and output patch sizes -- confirm.

    Side effects:
        The first six entries of row 0 of every plotted ground-truth and
        predicted slice are overwritten in place with the labels 0..5. This is
        presumably so that both images contain the same label range and
        therefore share a colour scale under ``imshow`` autoscaling.
    """
    out_dir = results_dir + "/segmentations_epoch_%d/" % epoch
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    data_gen = DavidSegDataGenerator(patients_validation, BATCH_SIZE, PATCH_SIZE=CROP_PATCHES_TO_THIS, num_batches=None, seed=seed)
    data_gen = segDataAugm.center_crop_seg_generator(data_gen, OUTPUT_PATCH_SIZE)
    # Builtin next() works on Python 2 and 3; generator objects only have a
    # ``.next()`` method on Python 2.
    data, seg, _ = next(data_gen)
    seg = np.array(seg)
    seg_pred = get_segmentation(data)
    n_images = seg_pred.shape[1]

    # Trim the spatial borders of the input so it lines up with the output.
    b = border
    data = data[:, :, b:data.shape[2] - b, b:data.shape[3] - b, b:data.shape[4] - b]

    label_stamp = np.arange(6)  # labels 0..5, stamped into each plotted slice
    for i in range(n_images):
        plt.figure(figsize=(20, 10))
        try:
            seg_pred[0, i, 0, :6] = label_stamp
            seg[0, 0, i, 0, 0:6] = label_stamp
            # Subplots 1-5: the five input channels of slice i.
            for channel in range(5):
                plt.subplot(2, 4, channel + 1)
                plt.imshow(data[0, channel, i], cmap="gray")
            plt.subplot(2, 4, 6)
            plt.imshow(seg[0, 0, i])
            plt.subplot(2, 4, 7)
            plt.imshow(seg_pred[0, i])
            plt.savefig(os.path.join(out_dir, "some_segmentations_ep_%d_z_%d.png" % (epoch, i)))
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close()
# Name of this experiment; results go into a sub-directory with this name.
EXPERIMENT_NAME = "segmentPatches_David_UNet3D_noBN_adapted"
results_dir = os.path.join("/home/fabian/datasets/Hirntumor_von_David/experiments/results/", EXPERIMENT_NAME)
# os.makedirs (unlike os.mkdir) also creates missing parent directories, so a
# first run where ".../experiments/results/" does not exist yet does not crash.
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

def get_class_weights(class_frequencies, exponent=0.5):
    """Compute per-class weights that grow as a class becomes rarer.

    Every frequency ``f`` is first transformed to ``f ** exponent``. The
    default of 0.5 is a square root, which dampens the imbalance. Class ``i``
    then receives the raw weight ``sum_{j != i} f_j / f_i``: the combined
    transformed frequency of all other classes relative to its own. Finally,
    the raw weights are rescaled so that they sum to the number of classes,
    which makes their mean 1.

    Args:
        class_frequencies: Sequence of per-class frequencies (counts or
            fractions -- only their ratios matter).
        exponent: Power applied to every frequency before weighting.

    Returns:
        ``np.float32`` array with one weight per class. With zero or one
        class, an array of ones of that length is returned, because the
        ratio above is undefined in those cases. A class with frequency 0
        leads to a division by zero (inf/nan weights), so callers should
        not pass empty classes.
    """
    freqs = np.asarray(class_frequencies).astype(np.float32) ** exponent
    n_classes = len(freqs)
    if n_classes <= 1:
        return np.ones(n_classes, dtype=np.float32)
    # Sum over all *other* classes == total minus the class itself. (A mask
    # written as ``range(n) != i`` is a single bool, not an element-wise mask.)
    weights = (freqs.sum() - freqs) / freqs
    weights /= np.sum(weights)
    weights *= n_classes
    return weights

# Validation pipeline: batches of patients_validation patches with
# PATCH_SIZE=CROP_PATCHES_TO_THIS, whose segmentation is passed through
# center_crop_seg_generator with OUTPUT_PATCH_SIZE. No augmentation wrappers
# are applied here.
data_gen_validation = DavidSegDataGenerator(patients_validation, BATCH_SIZE, PATCH_SIZE=CROP_PATCHES_TO_THIS, num_batches=None, seed=None)
data_gen_validation = segDataAugm.center_crop_seg_generator(data_gen_validation, OUTPUT_PATCH_SIZE)
# NOTE(review): the two integer arguments (1, 1 here; 8, 8 for training) are
# presumably the number of workers and the queue size -- confirm against
# MultiThreadedGenerator.
data_gen_validation = MultiThreadedGenerator(data_gen_validation, 1, 1)
# Private method; presumably starts the background worker(s) -- TODO confirm.
data_gen_validation._start()

# Training pipeline. Each line wraps the previous generator, so the order of
# these lines is the order in which the steps are applied to every batch:
#   1. sample patches of INPUT_PATCH_SIZE from patients_train,
#   2. center_crop_generator to (260, 260, 260),
#   3. elastic deformation ("elastric" is the name as called in segDataAugm;
#      arguments 900 and 12 are presumably deformation strength/smoothness --
#      confirm),
#   4. mirror_axis_generator (presumably random mirroring -- confirm),
#   5. center_crop_generator to CROP_PATCHES_TO_THIS, which is also the
#      network's input_dim below,
#   6. center_crop_seg_generator to OUTPUT_PATCH_SIZE for the segmentation.
data_gen_train = DavidSegDataGenerator(patients_train, BATCH_SIZE, PATCH_SIZE=INPUT_PATCH_SIZE, num_batches=None, seed=None)
data_gen_train = segDataAugm.center_crop_generator(data_gen_train, (260, 260, 260))
data_gen_train = segDataAugm.elastric_transform_generator(data_gen_train, 900, 12)
data_gen_train = segDataAugm.mirror_axis_generator(data_gen_train)
data_gen_train = segDataAugm.center_crop_generator(data_gen_train, CROP_PATCHES_TO_THIS)
data_gen_train = segDataAugm.center_crop_seg_generator(data_gen_train, OUTPUT_PATCH_SIZE)
data_gen_train = MultiThreadedGenerator(data_gen_train, 8, 8)
data_gen_train._start()

# 3D U-Net with 16 base filters and num_classes outputs. The first argument (5)
# is presumably the number of input channels -- confirm. pad=0 presumably means
# unpadded convolutions, so the output is smaller than CROP_PATCHES_TO_THIS --
# confirm.
net = build_UNet3D(5, BATCH_SIZE, num_output_classes=num_classes, base_n_filters=16, input_dim=CROP_PATCHES_TO_THIS, pad=0)
# The "output_flattened" entry of the network; per the variable name, this is
# the layer the training loss is computed on.
output_layer_for_loss = net["output_flattened"]