Example #1
def main():

    config = prepare_config()
    prediction_dir = os.path.abspath("prediction")
    run_validation_cases(config=config,
                         output_label_map=True,
                         output_dir=prediction_dir)
Example #2
def main(overwrite=False):
    
    # convert test images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file_test"]):
        testing_files, subject_ids = fetch_testing_data_files(return_subject_ids=True)

        write_data_to_file(testing_files, config["data_file_test"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    
    test_keys_file = config["test_file"]
    num_test_files = config["num_test_files"]  # adjust this in config.py as needed
    pickle_dump(list(np.arange(num_test_files)), test_keys_file)

    data_file_opened = open_data_file(config["data_file_test"])
    
    prediction_dir = config["output_dir"]
    run_validation_cases(test_keys_file,
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file_test"],
                         output_label_map=True,
                         output_dir=prediction_dir)
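
The pickle_dump helper used above is not defined in this snippet; in codebases like these it is usually just a thin wrapper around the standard pickle module. A minimal sketch, assuming that signature, could look like this:

import pickle

def pickle_dump(item, out_file):
    # serialize the list of test-set indices (or any picklable object) to disk
    with open(out_file, "wb") as opened_file:
        pickle.dump(item, opened_file)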
Example #3
def main():
    kwargs = vars(parse_args())
    prediction_dir = os.path.abspath(kwargs.pop("prediction_dir"))
    output_label_map = not kwargs.pop("no_label_map")
    for key, value in kwargs.items():
        if value:
            if key == "modalities":
                config["training_modalities"] = value
            else:
                config[key] = value
    filenames, subject_ids = fetch_brats_2020_files(config["training_modalities"], group="Validation",
                                                    include_truth=False, return_subject_ids=True)
    if not os.path.exists(config["data_file"]):
        write_data_to_file(filenames, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids, save_truth=False)
        pickle_dump(list(range(len(subject_ids))), config["validation_file"])

    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=output_label_map,
                         output_dir=prediction_dir,
                         test=False,
                         output_basename=kwargs["output_basename"],
                         permute=config["permute"])
    for filename_list, subject_id in zip(filenames, subject_ids):
        prediction_filename = os.path.join(prediction_dir, kwargs["output_basename"].format(subject=subject_id))
        print("Resampling:", prediction_filename)
        ref = nib.load(filename_list[0])
        pred = nib.load(prediction_filename)
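        # resample the prediction back onto the reference image grid; nearest-neighbour interpolation preserves the integer labels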
        pred_resampled = resample_to_img(pred, ref, interpolation="nearest")
        pred_resampled.to_filename(prediction_filename)
Example #4
def main():
    prediction_dir = os.path.abspath("prediction")

    # with tf.contrib.tfprof.ProfileContext('./profile_dir') as pctx:
    if args.bs == 1:
        run_validation_cases(validation_keys_file=config["validation_file"],
                             model_file=config["model_file"],
                             training_modalities=config["training_modalities"],
                             labels=config["labels"],
                             hdf5_file=config["data_file"],
                             output_label_map=True,
                             output_dir=prediction_dir,
                             warmup=args.warmup,
                             report_interval=args.report_interval,
                             batch_size=args.bs,
                             n_batch=args.nb)
    else:
        run_large_batch_validation_cases(
            validation_keys_file=config["validation_file"],
            model_file=config["model_file"],
            training_modalities=config["training_modalities"],
            labels=config["labels"],
            hdf5_file=config["data_file"],
            output_label_map=True,
            output_dir=prediction_dir,
            batch_size=args.bs,
            report_interval=args.report_interval,
            warmup=args.warmup,
            n_batch=args.nb)
Example #5
    def main(self):
        prediction_dir = os.path.abspath("results/prediction/rev_" +
                                         str(self.config.rev) +
                                         "/prediction_" + self.config.data_set)
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir)

        self.config.model_file = os.path.abspath(
            "Data/generated_data/" + self.config.data_set +
            "_isensee_2017_model_rev" + str(self.config.rev) + ".h5"
        )  # patch (128, 128, 128); n_filters = 16; skip_blank = True; depth = 5

        self.config.data_file = os.path.abspath("Data/generated_data/" +
                                                self.config.data_set +
                                                "_testing.h5")
        self.config.validation_file = os.path.abspath(
            "Data/generated_data/" + self.config.data_set +
            "_testing_validation_ids.pkl")

        run_validation_cases(
            validation_keys_file=self.config.validation_file,
            model_file=self.config.model_file,
            training_modalities=self.config.training_modalities,
            labels=self.config.labels,
            hdf5_file=self.config.data_file,
            output_label_map=True,
            overlap=self.config.validation_patch_overlap,
            output_dir=prediction_dir)
Example #6
def main(args):
    prediction_dir = os.path.abspath("./brats/prediction/" + args.mode.lower())

    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir)

    if args.mode.lower() == "unet":
        run_validation_cases(
            validation_keys_file=config_unet["validation_file"],
            model_file=config_unet["model_file"],
            training_modalities=config_unet["training_modalities"],
            labels=config_unet["labels"],
            hdf5_file=config_unet["data_file"],
            output_label_map=True,
            output_dir=prediction_dir)
    elif args.mode.lower() == "isensee2017":
        run_validation_cases(
            validation_keys_file=config_isensee["validation_file"],
            model_file=config_isensee["model_file"],
            training_modalities=config_isensee["training_modalities"],
            labels=config_isensee["labels"],
            hdf5_file=config_isensee["data_file"],
            output_label_map=True,
            output_dir=prediction_dir)
    else:
        raise Exception("ERROR: Invalid model mode! Enter unet OR isensee")
Example #7
def main(args):

    prediction_dir = os.path.abspath("./headneck/prediction_test/" +
                                     args.organ.lower())

    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir)

    test_data_files, subject_ids = fetch_test_data_files(
        return_subject_ids=True)

    if not os.path.exists(config["data_file"]):
        write_data_to_file(test_data_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    if not os.path.exists(config["test_file"]):
        test_list = list(range(len(subject_ids)))
        pickle_dump(test_list, config["test_file"])

    run_validation_cases(validation_keys_file=config["test_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=prediction_dir)

    header = ("Background", "Organ")
    masking_functions = (get_background_mask, get_organ_mask)
    rows = list()

    prediction_path = "./headneck/prediction_test/" + args.organ.lower() + "/"

    for case_folder in glob.glob(prediction_path + "*/"):
        truth_file = os.path.join(case_folder, "truth.nii.gz")
        truth_image = nib.load(truth_file)
        truth = truth_image.get_data()
        prediction_file = os.path.join(case_folder, "prediction.nii.gz")
        prediction_image = nib.load(prediction_file)
        prediction = prediction_image.get_data()
        rows.append([
            dice_coefficient(func(truth), func(prediction))
            for func in masking_functions
        ])

    df = pd.DataFrame.from_records(rows, columns=header)
    df.to_csv(prediction_path + "headneck_scores.csv")

    scores = dict()
    for index, score in enumerate(df.columns):
        values = df.values.T[index]
        scores[score] = values[np.isnan(values) == False]

    plt.boxplot(list(scores.values()), labels=list(scores.keys()))
    plt.ylabel("Dice Coefficient")
    plt.savefig(prediction_path + "test_scores_boxplot.png")
    plt.close()
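
The dice_coefficient function used above is not defined in this snippet. A minimal sketch of a typical implementation, assuming binary NumPy masks, could look like this:

import numpy as np

def dice_coefficient(truth, prediction):
    # Dice = 2 * |T intersect P| / (|T| + |P|); yields NaN when both masks are empty,
    # which is why the scores are filtered for NaN before plotting
    truth = np.asarray(truth, dtype=bool)
    prediction = np.asarray(prediction, dtype=bool)
    return 2.0 * np.logical_and(truth, prediction).sum() / (truth.sum() + prediction.sum())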
Example #8
def createoutputs(prediction_dir):
    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=prediction_dir)
Example #9
def main():
    prediction_dir = os.path.abspath("prediction")
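    # run inference on every case listed in the validation keys file and write the predicted label maps under prediction_dir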
    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=prediction_dir)
Example #10
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_run', help='path to the run folder.')

    parsed_args = parser.parse_args(sys.argv[1:])

    path_run = parsed_args.path_run  # get run folder

    #path_pred = os.path.join(path_run, "prediction")

    #path_pred = os.path.abspath(path_run + "/prediction")

    path_pred = os.path.abspath("prediction")

    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=path_pred)
    '''The network outputs the .nii stacks at a different resolution than the originals that were
    introduced; to make them equal, they are loaded into Python and saved again.'''

    dir_pred = listdir(path_pred)

    for case_folder in dir_pred:

        # load predicted files
        spine_file = os.path.join(path_pred, case_folder, "data_spine.nii.gz")
        truth_file = os.path.join(path_pred, case_folder, "truth.nii.gz")
        prediction_file = os.path.join(path_pred, case_folder,
                                       "prediction.nii.gz")
        spine_image = nib.load(spine_file)
        truth_image = nib.load(truth_file)
        prediction_image = nib.load(prediction_file)
        spine = spine_image.get_data()
        truth = truth_image.get_data()
        prediction = prediction_image.get_data()

        # save predicted files
        spine = nib.Nifti1Image(spine, affine=np.eye(4, 4))
        nib.save(spine, spine_file)
        truth = nib.Nifti1Image(truth, affine=np.eye(4, 4))
        nib.save(truth, truth_file)
        prediction = nib.Nifti1Image(prediction, affine=np.eye(4, 4))
        nib.save(prediction, prediction_file)

    # rename predicted files to match the original ones

    path_original = "data/"  # get path of original data
    dir_original = listdir(path_original)  # list of original cases
    for case_pred in dir_pred:  # for each predicted case
        name1, name2, number = case_pred.split("_")  # get number
        rename(path_pred + "/" + case_pred, path_pred + "/" +
               dir_original[int(number)])  # rename to match original case
Example #11
def main():
    prediction_dir = os.path.abspath("prediction")

    # original prediction code
    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=os.path.abspath("brats/liver_data.h5"),
                         output_label_map=True,
                         output_dir=prediction_dir)
Example #12
def main():
    kwargs = vars(parse_args())
    prediction_dir = os.path.abspath(kwargs.pop("prediction_dir"))
    output_label_map = not kwargs.pop("no_label_map")
    for key, value in kwargs.items():
        if value:
            if key == "modalities":
                config["training_modalities"] = value
            else:
                config[key] = value

    validate_path = kwargs["validate_path"]
    subject_ids = list()
    filenames = list()
    blacklist = []
    for root, dirs, files in os.walk(validate_path):
        for f in files:
            subject_id = f.split('.')[0]
            if subject_id not in blacklist:
                subject_ids.append(subject_id)
                subject_files = list()
                subject_files.append(validate_path + '/' + f)
                filenames.append(tuple(subject_files))

    if not os.path.exists(config["data_file"]):
        write_data_to_file(filenames,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           save_truth=False)
        pickle_dump(list(range(len(subject_ids))), config["validation_file"])

    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=output_label_map,
                         output_dir=prediction_dir,
                         test=False,
                         output_basename=kwargs["output_basename"],
                         permute=config["permute"])
    for filename_list, subject_id in zip(filenames, subject_ids):
        prediction_filename = os.path.join(
            prediction_dir,
            kwargs["output_basename"].format(subject=subject_id))
        print("Resampling:", prediction_filename)
        ref = nib.load(filename_list[0])
        pred = nib.load(prediction_filename)
        pred_resampled = resample_to_img(pred, ref, interpolation="nearest")
        pred_resampled.to_filename(prediction_filename)
Example #13
def main():

    path_pred = os.path.abspath("prediction")

    run_validation_cases(
        validation_keys_file=config["validation_file"],
        model_file=config["model_file"],
        training_modalities=config["training_modalities"],
        labels=config["labels"],
        hdf5_file=config["data_file"],
        output_label_map=False,  # TRUE -> CLASSES, FALSE -> SIGMOID
        output_dir=path_pred)
    '''The network outputs the .nii stacks at a different resolution than the originals that were
    introduced; to make them equal, they are loaded into Python and saved again.'''

    dir_pred = listdir(path_pred)

    for case_folder in dir_pred:

        # load predicted files
        spine_file = os.path.join(path_pred, case_folder, "data_spine.nii.gz")
        truth_file = os.path.join(path_pred, case_folder, "truth.nii.gz")
        prediction_file = os.path.join(path_pred, case_folder,
                                       "prediction.nii.gz")
        spine_image = nib.load(spine_file)
        truth_image = nib.load(truth_file)
        prediction_image = nib.load(prediction_file)
        spine = spine_image.get_data()
        truth = truth_image.get_data()
        prediction = prediction_image.get_data()

        # save predicted files
        spine = nib.Nifti1Image(spine, affine=np.eye(4, 4))
        nib.save(spine, spine_file)
        truth = nib.Nifti1Image(truth, affine=np.eye(4, 4))
        nib.save(truth, truth_file)
        prediction = nib.Nifti1Image(prediction, affine=np.eye(4, 4))
        nib.save(prediction, prediction_file)

    # rename predicted files to match the original ones

    dir_pred = listdir(path_pred)
    path_original = "data/"  # get path of original data
    dir_original = listdir(path_original)  # list of original cases
    for case_pred in dir_pred:  # for each predicted case
        name1, name2, number = case_pred.split("_")  # get number
        rename(path_pred + "/" + case_pred, path_pred + "/" +
               dir_original[int(number)])  # rename to match original case
Example #14
def main():
    #     pdb.set_trace()
    prediction_dir = os.path.abspath("prediction")
    print('Start predicting...')
    run_validation_cases(
        validation_keys_file=config["validation_file"],
        model_file=config["model_file"],
        training_modalities=config["training_modalities"],
        labels=config["labels"],
        #                          overlap=32,
        overlap=0,  # this param doesn't work anymore when using prediction specific patching strategy
        hdf5_file=config["data_file"],
        output_dir=prediction_dir,
        center_patch=config['center_patch'],
        overlap_label=config['overlap_label_predict'])
Example #15
def main(model_name='isensee2017'):
    if model_name == 'unet':
        from brats.train import config
    else:
        from brats.train_isensee2017 import config

    prediction_dir = os.path.abspath(os.path.join("data", "prediction"))
    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir)
    run_validation_cases(model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         validation_keys_file=config["validation_keys_file"],
                         npy_path=pre_config["npy_path"],
                         subject_ids_file=pre_config["subject_ids_file"],
                         label_map=True,
                         output_dir=prediction_dir,
                         labels=config["labels"])
Example #16
def main(args):
    prediction_dir = os.path.abspath("./headneck/prediction/" +
                                     args.organ.lower())

    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir)

    run_validation_cases(
        validation_keys_file=config_isensee["validation_file"],
        model_file=config_isensee["model_file"],
        training_modalities=config_isensee["training_modalities"],
        labels=config_isensee["labels"],
        hdf5_file=config_isensee["data_file"],
        output_label_map=True,
        output_dir=prediction_dir)
Example #17
def predict_training_dataset():
    if not os.path.exists(config['training_index_list']):
        with open(config['training_index_list'], 'wb') as f:
            pickle.dump(list(range(config['num_training_subjects'])), f)
    print('Training dataset prediction starts...')
    run_validation_cases(validation_keys_file=config['training_index_list'],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_dir=config['training_predict_dir'],
                         center_patch=config['center_patch'],
                         overlap_label=config['overlap_label_predict'],
                         final_val=True)
    mv_results(config['training_predict_dir'], config['training_to_upload'])
    print('Training dataset prediction finished.')
    return
Example #18
def main():
    prediction_dir = os.path.abspath("prediction")
    print("The prediction process has started")

    # second_model_file - the second, deeper model
    # model_file - the usual architecture / usual model

    # data_file - the HDF5 data file

    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["second_model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=prediction_dir,
                         isDeeper=True,
                         depth=5)
Example #19
def main_run():
    config['num_val_subjects'] = len(
        os.listdir('../data/preprocessed_val_data/val'))

    gen_val_h5()

    if not os.path.exists(config['val_index_list']):
        with open(config['val_index_list'], 'wb') as f:
            pickle.dump(list(range(config['num_val_subjects'])), f)
    print('Validation dataset prediction starts...')
    run_validation_cases(validation_keys_file=config['val_index_list'],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["val_data_file"],
                         output_dir=config['val_predict_dir'],
                         center_patch=config['center_patch'],
                         overlap_label=config['overlap_label_predict'],
                         final_val=True)
    mv_results(config['val_predict_dir'], config['val_to_upload'])
    print('Validation dataset prediction finished.')
    return
Example #20
def predict(overwrite=True,
            crop=True,
            challenge="brats",
            year=2018,
            image_shape="160-192-128",
            is_bias_correction="1",
            is_normalize="z",
            is_denoise="0",
            is_hist_match="0",
            is_test="1",
            depth_unet=4,
            n_base_filters_unet=16,
            model_name="unet",
            patch_shape="128-128-128",
            is_crf="0",
            batch_size=1,
            loss="minh",
            model_dim=3):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=model_dim,
        dir_read_write="finetune",
        is_finetune=True)

    if not os.path.exists(model_path):
        print("model does not exist. Please check")
    else:
        config["data_file"] = data_path
        config["model_file"] = model_path
        config["training_file"] = trainids_path
        config["validation_file"] = validids_path
        config["testing_file"] = testids_path
        config["patch_shape"] = get_shape_from_string(patch_shape)
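        # the model input is channels-first: (nb_channels,) followed by the patch shape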
        config["input_shape"] = tuple([config["nb_channels"]] +
                                      list(config["patch_shape"]))
        config["prediction_folder"] = os.path.join(
            BRATS_DIR, "database/prediction",
            get_filename_without_extension(config["model_file"]))

        if is_all_cases_predicted(config["prediction_folder"],
                                  config["testing_file"]):
            print("Already predicted. Skip...")
            list_already_predicted.append(config["prediction_folder"])
        else:
            make_dir(config["prediction_folder"])

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])
            print("prediction folder:", config["prediction_folder"])
            print("-" * 60)

            if not os.path.exists(config["model_file"]):
                raise ValueError("cannot find model {}. Please check".format(
                    config["model_file"]))

            if model_dim == 3:
                from unet3d.prediction import run_validation_cases
            elif model_dim == 25:
                from unet25d.prediction import run_validation_cases
            elif model_dim == 2:
                from unet2d.prediction import run_validation_cases
            else:
                raise ValueError(
                    "dim {} NotImplemented error. Please check".format(
                        model_dim))

            run_validation_cases(
                validation_keys_file=config["testing_file"],
                model_file=config["model_file"],
                training_modalities=config["training_modalities"],
                labels=config["labels"],
                hdf5_file=config["data_file"],
                output_label_map=True,
                output_dir=config["prediction_folder"])
Example #21
def predict(args, prediction_dir="desktop"):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if not os.path.exists(model_path):
        print("model does not exist. Please check")
    else:
        config["data_file"] = data_path
        config["model_file"] = model_path
        config["training_file"] = trainids_path
        config["validation_file"] = validids_path
        config["testing_file"] = testids_path
        config["patch_shape"] = get_shape_from_string(args.patch_shape)
        config["input_shape"] = tuple([config["nb_channels"]] +
                                      list(config["patch_shape"]))

        if prediction_dir == "SERVER":
            prediction_dir = "brats"
        else:
            prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
        # prediction_dir = BRATS_DIR
        config["prediction_folder"] = os.path.join(
            prediction_dir, "database/prediction",
            get_filename_without_extension(config["model_file"]))

        if is_all_cases_predicted(config["prediction_folder"],
                                  config["testing_file"]):
            print("Already predicted. Skip...")
            list_already_predicted.append(config["prediction_folder"])
        else:
            make_dir(config["prediction_folder"])

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])
            print("prediction folder:", config["prediction_folder"])
            print("-" * 60)

            if not os.path.exists(config["model_file"]):
                raise ValueError("cannot find model {}. Please check".format(
                    config["model_file"]))

            if args.model_dim == 3:
                from unet3d.prediction import run_validation_cases
            elif args.model_dim == 25:
                from unet25d.prediction import run_validation_cases
            elif args.model_dim == 2:
                from unet2d.prediction import run_validation_cases
            else:
                raise ValueError(
                    "dim {} NotImplemented error. Please check".format(
                        args.model_dim))

            run_validation_cases(
                validation_keys_file=config["testing_file"],
                model_file=config["model_file"],
                training_modalities=config["training_modalities"],
                labels=config["labels"],
                hdf5_file=config["data_file"],
                output_label_map=True,
                output_dir=config["prediction_folder"])