def normalize_one_folder(data_folder, dataset, overwrite=False):
    normalize_minh_dir = get_normalize_minh_dir(brats_dir=BRATS_DIR,
                                                data_folder=data_folder,
                                                dataset=dataset)

    make_dir(normalize_minh_dir)

    data_dir = get_data_dir(brats_dir=BRATS_DIR,
                            data_folder=data_folder,
                            dataset=dataset)
    subject_paths = glob.glob(os.path.join(data_dir, "*", "*", "*.nii.gz"))

    for subject_path in subject_paths:
        normalize_minh_file_path = get_normalize_minh_file_path(
            path=subject_path, dataset=dataset)
        parent_dir = get_parent_dir(normalize_minh_file_path)
        make_dir(parent_dir)

        print_processing(subject_path)

        template_path = get_template_path(
            path=subject_path,
            dataset=dataset,
            brats_dir=BRATS_DIR,
            template_data_folder=config["template_data_folder"],
            template_folder=config["template_folder"])

        if overwrite or not os.path.exists(normalize_minh_file_path):
            if config["truth"][0] in normalize_minh_file_path:
                print("saving truth to", normalize_minh_file_path)
                shutil.copy(subject_path, normalize_minh_file_path)
            elif config["mask"][0] in normalize_minh_file_path:
                print("saving mask to", normalize_minh_file_path)
                shutil.copy(subject_path, normalize_minh_file_path)
            else:
                # only intensity volumes are histogram-matched against the
                # modality template; truth and mask files are copied as-is,
                # so the template is only loaded when it is actually needed
                template = nib.load(template_path).get_fdata()
                volume = nib.load(subject_path)
                affine = volume.affine
                volume = volume.get_fdata()
                source_hist_match = hist_match_non_zeros(volume, template)
                print("saving to", normalize_minh_file_path)
                source_hist_match = nib.Nifti1Image(source_hist_match,
                                                    affine=affine)
                nib.save(source_hist_match, normalize_minh_file_path)
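hist_match_non_zeros is defined elsewhere in the repo and not shown in this
listing. For reference, a minimal sketch of non-zero histogram matching in
plain NumPy (the _sketch name and exact behaviour are assumptions, not the
repo's implementation):

import numpy as np

def hist_match_non_zeros_sketch(source, template):
    # map source intensities onto the template's distribution using only
    # non-zero (foreground) voxels; zero background stays zero
    matched = np.zeros_like(source, dtype=np.float64)
    src_mask = source != 0
    src_vals = source[src_mask]
    tmpl_sorted = np.sort(template[template != 0])

    # rank each source voxel, then read off the template intensity at the
    # corresponding quantile
    order = np.argsort(src_vals)
    src_quantiles = np.linspace(0.0, 1.0, src_vals.size)
    tmpl_quantiles = np.linspace(0.0, 1.0, tmpl_sorted.size)
    mapped = np.interp(src_quantiles, tmpl_quantiles, tmpl_sorted)

    out = np.empty_like(mapped)
    out[order] = mapped
    matched[src_mask] = out
    return matched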
Example #2
def evaluate(args):

    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    masking_functions = (get_whole_tumor_mask, get_tumor_core_mask,
                         get_enhancing_tumor_mask)

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
    # prediction_dir = BRATS_DIR
    config["prediction_folder"] = os.path.join(
        prediction_dir, "database/prediction",
        get_filename_without_extension(config["model_file"]))

    name = get_filename_without_extension(config["model_file"]).replace(
        "_aug-0", "")
    config["prediction_folder"] = os.path.join(
        prediction_dir, "database/prediction/wo_augmentation", name)

    if not os.path.exists(config["prediction_folder"]):
        print("model not exists. Please check")
        return None, None
    else:
        prediction_df_csv_folder = os.path.join(BRATS_DIR,
                                                "database/prediction/csv/")
        make_dir(prediction_df_csv_folder)

        config["prediction_df_csv"] = prediction_df_csv_folder + \
            get_filename_without_extension(config["model_file"]) + ".csv"

        if os.path.exists(config["prediction_df_csv"]):
            df = pd.read_csv(config["prediction_df_csv"])
            df1 = df["dice_WholeTumor"].values
            df2 = df["dice_TumorCore"].values
            df3 = df["dice_EnhancingTumor"].values
            rows = np.zeros((df1.size, 3))
            rows[:, 0] = df1
            rows[:, 1] = df2
            rows[:, 2] = df3

            subject_ids = list()
            for case_folder in glob.glob(
                    os.path.join(config["prediction_folder"], "*")):
                if not os.path.isdir(case_folder):
                    continue
                subject_ids.append(os.path.basename(case_folder))

            df = pd.DataFrame.from_records(rows,
                                           columns=header,
                                           index=subject_ids)

        else:

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("model file:", config["model_file"])
            print("prediction folder:", config["prediction_folder"])
            print("csv file:", config["prediction_df_csv"])
            print("-" * 60)

            rows = list()
            subject_ids = list()
            for case_folder in glob.glob(
                    os.path.join(config["prediction_folder"], "*")):
                if not os.path.isdir(case_folder):
                    continue
                subject_ids.append(os.path.basename(case_folder))
                truth_file = os.path.join(case_folder, "truth.nii.gz")
                truth_image = nib.load(truth_file)
                truth = truth_image.get_fdata()
                prediction_file = os.path.join(case_folder,
                                               "prediction.nii.gz")
                prediction_image = nib.load(prediction_file)
                prediction = prediction_image.get_fdata()
                score_case = get_score(truth, prediction, masking_functions)
                rows.append(score_case)

            df = pd.DataFrame.from_records(rows,
                                           columns=header,
                                           index=subject_ids)
            df.to_csv(config["prediction_df_csv"])

        # common to both branches: collect per-column Dice values, dropping NaNs
        scores = dict()
        for index, score in enumerate(df.columns):
            values = df.values.T[index]
            scores[score] = values[~np.isnan(values)]

        return scores, model_path
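get_score and the masking functions are imported elsewhere in the repo. A
minimal sketch of what they compute for BraTS labels, assuming the usual Dice
definition (the _sketch names are hypothetical):

import numpy as np

def dice_coefficient_sketch(truth_mask, prediction_mask):
    # Dice = 2|A ∩ B| / (|A| + |B|); NaN when both masks are empty, which
    # the np.isnan filtering above then drops
    denominator = truth_mask.sum() + prediction_mask.sum()
    if denominator == 0:
        return np.nan
    intersection = np.logical_and(truth_mask, prediction_mask).sum()
    return 2.0 * intersection / denominator

def get_whole_tumor_mask_sketch(data):
    # whole tumor: every labelled voxel (BraTS labels 1, 2 and 4)
    return data > 0

def get_score_sketch(truth, prediction, masking_functions):
    return [dice_coefficient_sketch(func(truth), func(prediction))
            for func in masking_functions]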
Example #3
def main():
    args = get_args.train()
    overwrite = args.overwrite
    crop = args.crop
    challenge = args.challenge
    year = args.year
    image_shape = args.image_shape
    is_bias_correction = args.is_bias_correction
    is_normalize = args.is_normalize
    is_denoise = args.is_denoise
    is_test = args.is_test
    model_name = args.model
    depth_unet = args.depth_unet
    n_base_filters_unet = args.n_base_filters_unet
    patch_shape = args.patch_shape
    is_crf = args.is_crf
    batch_size = args.batch_size
    is_hist_match = args.is_hist_match
    loss = args.loss

    args.is_augment = "0"

    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    model_scores = list()
    model_ids = list()

    for depth_unet in [4, 5]:
        args.depth_unet = depth_unet
        for n_base_filters_unet in [16, 32]:
            args.n_base_filters_unet = n_base_filters_unet
            for model_dim in [2, 3, 25]:
                args.model_dim = model_dim
                # if depth_unet == 5 or n_base_filters_unet == 32:
                #     list_model = config_dict["model_depth"]
                # else:
                list_model = config_dict["model"]
                for model_name in list_model:
                    args.model_name = model_name
                    for is_normalize in config_dict["is_normalize"]:
                        args.is_normalize = is_normalize
                        for is_denoise in config_dict["is_denoise"]:
                            args.is_denoise = is_denoise
                            for is_hist_match in config_dict["hist_match"]:
                                args.is_hist_match = is_hist_match
                                for patch_shape in config_dict["patch_shape"]:
                                    args.patch_shape = patch_shape
                                    for loss in config_dict["loss"]:
                                        args.loss = loss
                                        print("=" * 120)
                                        print(
                                            ">> processing model-{}{}, depth-{}, filters-{}, patch_shape-{}, is_denoise-{}, is_normalize-{}, is_hist_match-{}, loss-{}"
                                            .format(model_name, model_dim,
                                                    depth_unet,
                                                    n_base_filters_unet,
                                                    patch_shape, is_denoise,
                                                    is_normalize,
                                                    is_hist_match, loss))
                                        is_test = "0"
                                        model_score, model_path = evaluate(
                                            args)
                                        if model_score is not None and get_filename_without_extension(
                                                model_path) not in model_ids:
                                            print("=" * 120)
                                            print(">> finished:")

                                            model_ids.append(
                                                get_filename_without_extension(
                                                    model_path))

                                            row = get_model_info_header(
                                                challenge, year, image_shape,
                                                is_bias_correction, is_denoise,
                                                is_normalize, is_hist_match,
                                                model_name, model_dim,
                                                patch_shape, loss, depth_unet,
                                                n_base_filters_unet)
                                            score = [
                                                np.mean(model_score[
                                                    "dice_WholeTumor"]),
                                                np.mean(model_score[
                                                    "dice_TumorCore"]),
                                                np.mean(model_score[
                                                    "dice_EnhancingTumor"]),
                                                (np.mean(model_score[
                                                    "dice_WholeTumor"]) +
                                                 np.mean(model_score[
                                                     "dice_TumorCore"]) +
                                                 np.mean(model_score[
                                                     "dice_EnhancingTumor"])) /
                                                3
                                            ]

                                            row.extend(score)
                                            model_scores.append(row)

    header = ("challenge", "year", "image_shape", "is_bias_correction",
              "is_denoise", "is_normalize", "is_hist_match", "model_name",
              "model_dim", "depth_unet", "n_base_filters_unet", "loss",
              "patch_shape", "dice_WholeTumor", "dice_TumorCore",
              "dice_EnhancingTumor", "dice_Mean")

    final_df = pd.DataFrame.from_records(model_scores,
                                         columns=header,
                                         index=model_ids)

    print(final_df)

    prediction_df_csv_folder = os.path.join(BRATS_DIR,
                                            "database/prediction/csv/")
    make_dir(prediction_df_csv_folder)

    to_file = prediction_df_csv_folder + "compile.csv"

    final_df.to_csv(to_file)
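A hypothetical follow-up, not part of the repo, showing how the compiled
grid-search results could be ranked by mean Dice (the column names come from
the header tuple above):

import pandas as pd

def rank_compiled_results_sketch(csv_path):
    # highest mean Dice first
    compiled = pd.read_csv(csv_path, index_col=0)
    return compiled.sort_values("dice_Mean", ascending=False)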
def finetune(args):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    folder = os.path.join(BRATS_DIR, "database", "model", "base")

    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("can not fine baseline model. Please check")
        else:
            config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    make_dir(config["training_file"])

    print_section("get training and testing generators")
    if args.model_dim == 3:
        from unet3d.generator import get_training_and_validation_and_testing_generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            is_create_patch_index_list_original=config[
                "is_create_patch_index_list_original"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])
    elif args.model_dim == 25:
        from unet25d.generator import get_training_and_validation_and_testing_generators25d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)
    else:
        from unet2d.generator import get_training_and_validation_and_testing_generators2d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    print(">> update config file")
    config.update(config_finetune)
    if not os.path.exists(config["model_file"]):
        raise Exception("{} model file not found. Please try again".format(
            config["model_file"]))
    else:
        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
        model.summary()

    # run training
    print("-" * 60)
    print("# start finetuning")
    print("-" * 60)

    print("Number of training steps: ", n_train_steps)
    print("Number of validation steps: ", n_validation_steps)

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        args=args,
        dir_read_write="finetune",
        is_finetune=True)

    config["model_file"] = model_path

    if os.path.exists(config["model_file"]):
        print("{} existed. Will skip!!!".format(config["model_file"]))
    else:

        if args.is_test == "1":
            config["n_epochs"] = 5

        if args.is_test == "0":
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="finetune",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        if args.model_dim == 2 and args.model == "isensee":
            config["initial_learning_rate"] = 1e-7

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
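generate_model (unet3d.utils.model_utils) is not shown in this listing. A
minimal sketch of the Keras pattern it presumably wraps; custom_objects_sketch
and treating loss_function as a Keras-compatible loss are assumptions:

from keras.models import load_model
from keras.optimizers import Adam

def generate_model_sketch(model_file, initial_learning_rate, loss_function,
                          custom_objects_sketch=None):
    # load the pretrained network, then recompile it with the fine-tuning
    # learning rate and loss
    model = load_model(model_file, custom_objects=custom_objects_sketch)
    model.compile(optimizer=Adam(lr=initial_learning_rate),
                  loss=loss_function)
    return model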
is_test = "0"

model_list = list()
cmd_list = list()
out_file_list = list()

for model_name in ["unet", "isensee"]:
    for is_denoise in config_dict["is_denoise"]:
        for is_normalize in config_dict["is_normalize"]:
            for is_hist_match in ["0", "1"]:
                for loss in ["minh", "weighted"]:

                    patch_shape = "160-192-128"

                    log_folder = "log"
                    make_dir(log_folder)

                    d = datetime.date.today()
                    year_current = d.year
                    month_current = '{:02d}'.format(d.month)
                    date_current = '{:02d}'.format(d.day)

                    model_filename = get_filename_without_extension(
                        get_model_h5_filename(datatype="model",
                                              is_bias_correction="1",
                                              is_denoise=is_denoise,
                                              is_normalize=is_normalize,
                                              is_hist_match=is_hist_match,
                                              depth_unet=4,
                                              n_base_filters_unet=16,
                                              model_name=model_name))
                    # NOTE: the original snippet is truncated here; any
                    # remaining get_model_h5_filename arguments are unknown
Example #6
def predict(args, prediction_dir="desktop"):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if not os.path.exists(model_path):
        print("model not exists. Please check")
    else:
        config["data_file"] = data_path
        config["model_file"] = model_path
        config["training_file"] = trainids_path
        config["validation_file"] = validids_path
        config["testing_file"] = testids_path
        config["patch_shape"] = get_shape_from_string(args.patch_shape)
        config["input_shape"] = tuple([config["nb_channels"]] +
                                      list(config["patch_shape"]))

        if prediction_dir == "SERVER":
            prediction_dir = "brats"
        else:
            prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
        # prediction_dir = BRATS_DIR
        config["prediction_folder"] = os.path.join(
            prediction_dir, "database/prediction",
            get_filename_without_extension(config["model_file"]))

        if is_all_cases_predicted(config["prediction_folder"],
                                  config["testing_file"]):
            print("Already predicted. Skip...")
            list_already_predicted.append(config["prediction_folder"])
        else:
            make_dir(config["prediction_folder"])

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])
            print("prediction folder:", config["prediction_folder"])
            print("-" * 60)

            if not os.path.exists(config["model_file"]):
                raise ValueError("can not find model {}. Please check".format(
                    config["model_file"]))

            if args.model_dim == 3:
                from unet3d.prediction import run_validation_cases
            elif args.model_dim == 25:
                from unet25d.prediction import run_validation_cases
            elif args.model_dim == 2:
                from unet2d.prediction import run_validation_cases
            else:
                raise ValueError(
                    "model_dim {} is not implemented. Please check".format(
                        args.model_dim))

            run_validation_cases(
                validation_keys_file=config["testing_file"],
                model_file=config["model_file"],
                training_modalities=config["training_modalities"],
                labels=config["labels"],
                hdf5_file=config["data_file"],
                output_label_map=True,
                output_dir=config["prediction_folder"])
def main():
    args = get_args.train25d()

    depth_unet = args.depth_unet
    n_base_filters_unet = args.n_base_filters_unet
    patch_shape = args.patch_shape
    is_crf = args.is_crf
    batch_size = args.batch_size
    is_hist_match = args.is_hist_match
    loss = args.loss

    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    model_scores = list()
    model_ids = list()

    for is_augment in ["1"]:
        args.is_augment = is_augment
        for model_name in ["unet"]:
            args.model = model_name
            for is_denoise in ["0"]:
                args.is_denoise = is_denoise
                for is_normalize in ["z"]:
                    args.is_normalize = is_normalize
                    for is_hist_match in ["0"]:
                        args.is_hist_match = is_hist_match
                        for loss in ["weighted"]:
                            args.loss = loss
                            for patch_shape in ["160-192-3", "160-192-5", "160-192-7", "160-192-9", "160-192-11", "160-192-13", "160-192-15", "160-192-17"]:
                            # for patch_shape in ["160-192-3", "160-192-5", "160-192-7", "160-192-9", "160-192-11"]:
                                # for patch_shape in ["160-192-3", "160-192-13", "160-192-15", "160-192-17"]:
                                args.patch_shape = patch_shape
                                model_dim = 25

                                print("="*120)
                                print(
                                    ">> processing model-{}{}, depth-{}, filters-{}, patch_shape-{}, is_denoise-{}, is_normalize-{}, is_hist_match-{}, loss-{}".format(
                                        model_name,
                                        model_dim,
                                        depth_unet,
                                        n_base_filters_unet,
                                        patch_shape,
                                        is_denoise,
                                        is_normalize,
                                        is_hist_match,
                                        loss))
                                is_test = "0"
                                model_score, model_path = evaluate(args)
                                if model_score is not None:
                                    print("="*120)
                                    print(">> finished:")

                                    model_ids.append(
                                        get_filename_without_extension(model_path))

                                    row = get_model_info_header(args.challenge, args.year, args.image_shape, args.is_bias_correction,
                                                                args.is_denoise, args.is_normalize, args.is_hist_match,
                                                                model_name, model_dim, patch_shape, loss,
                                                                depth_unet, n_base_filters_unet)
                                    score = [np.mean(model_score["dice_WholeTumor"]),
                                             np.mean(
                                        model_score["dice_TumorCore"]),
                                        np.mean(
                                        model_score["dice_EnhancingTumor"]),
                                        (np.mean(model_score["dice_WholeTumor"])+np.mean(model_score["dice_TumorCore"])+np.mean(model_score["dice_EnhancingTumor"]))/3]

                                    row.extend(score)
                                    model_scores.append(row)

    header = ("challenge", "year", "image_shape", "is_bias_correction",
              "is_denoise", "is_normalize", "is_hist_match",
              "model_name", "model_dim", "depth_unet",
              "n_base_filters_unet", "loss", "patch_shape",
              "dice_WholeTumor", "dice_TumorCore",
              "dice_EnhancingTumor", "dice_Mean")

    final_df = pd.DataFrame.from_records(
        model_scores, columns=header, index=model_ids)

    print(final_df)

    prediction_df_csv_folder = os.path.join(
        BRATS_DIR, "database/prediction/csv/")
    make_dir(prediction_df_csv_folder)

    to_file = prediction_df_csv_folder + "compile.csv"

    final_df.to_csv(to_file)
Example #8
def predict(overwrite=True,
            crop=True,
            challenge="brats",
            year=2018,
            image_shape="160-192-128",
            is_bias_correction="1",
            is_normalize="z",
            is_denoise="0",
            is_hist_match="0",
            is_test="1",
            depth_unet=4,
            n_base_filters_unet=16,
            model_name="unet",
            patch_shape="128-128-128",
            is_crf="0",
            batch_size=1,
            loss="minh",
            model_dim=3):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=model_dim,
        dir_read_write="finetune",
        is_finetune=True)

    if not os.path.exists(model_path):
        print("model not exists. Please check")
    else:
        config["data_file"] = data_path
        config["model_file"] = model_path
        config["training_file"] = trainids_path
        config["validation_file"] = validids_path
        config["testing_file"] = testids_path
        config["patch_shape"] = get_shape_from_string(patch_shape)
        config["input_shape"] = tuple([config["nb_channels"]] +
                                      list(config["patch_shape"]))
        config["prediction_folder"] = os.path.join(
            BRATS_DIR, "database/prediction",
            get_filename_without_extension(config["model_file"]))

        if is_all_cases_predicted(config["prediction_folder"],
                                  config["testing_file"]):
            print("Already predicted. Skip...")
            list_already_predicted.append(config["prediction_folder"])
        else:
            make_dir(config["prediction_folder"])

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])
            print("prediction folder:", config["prediction_folder"])
            print("-" * 60)

            if not os.path.exists(config["model_file"]):
                raise ValueError("can not find model {}. Please check".format(
                    config["model_file"]))

            if model_dim == 3:
                from unet3d.prediction import run_validation_cases
            elif model_dim == 25:
                from unet25d.prediction import run_validation_cases
            elif model_dim == 2:
                from unet2d.prediction import run_validation_cases
            else:
                raise ValueError(
                    "model_dim {} is not implemented. Please check".format(
                        model_dim))

            run_validation_cases(
                validation_keys_file=config["testing_file"],
                model_file=config["model_file"],
                training_modalities=config["training_modalities"],
                labels=config["labels"],
                hdf5_file=config["data_file"],
                output_label_map=True,
                output_dir=config["prediction_folder"])