def normalize_one_folder(data_folder, dataset, overwrite=False):
    normalize_minh_dir = get_normalize_minh_dir(
        brats_dir=BRATS_DIR, data_folder=data_folder, dataset=dataset)
    make_dir(normalize_minh_dir)
    data_dir = get_data_dir(
        brats_dir=BRATS_DIR, data_folder=data_folder, dataset=dataset)
    subject_paths = glob.glob(os.path.join(data_dir, "*", "*", "*.nii.gz"))
    for subject_path in subject_paths:
        normalize_minh_file_path = get_normalize_minh_file_path(
            path=subject_path, dataset=dataset)
        parent_dir = get_parent_dir(normalize_minh_file_path)
        make_dir(parent_dir)
        print_processing(subject_path)
        template_path = get_template_path(
            path=subject_path, dataset=dataset, brats_dir=BRATS_DIR,
            template_data_folder=config["template_data_folder"],
            template_folder=config["template_folder"])
        template = nib.load(template_path)
        template = template.get_fdata()
        if overwrite or not os.path.exists(normalize_minh_file_path):
            if config["truth"][0] in normalize_minh_file_path:
                # Ground-truth label maps are copied unchanged.
                print("saving truth to", normalize_minh_file_path)
                shutil.copy(subject_path, normalize_minh_file_path)
            elif config["mask"][0] in normalize_minh_file_path:
                # Brain masks are copied unchanged.
                print("saving mask to", normalize_minh_file_path)
                shutil.copy(subject_path, normalize_minh_file_path)
            else:
                # Image modalities are histogram-matched to the template.
                volume = nib.load(subject_path)
                affine = volume.affine
                volume = volume.get_fdata()
                source_hist_match = hist_match_non_zeros(volume, template)
                print("saving to", normalize_minh_file_path)
                source_hist_match = nib.Nifti1Image(source_hist_match,
                                                    affine=affine)
                nib.save(source_hist_match, normalize_minh_file_path)
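

# Note (added): hist_match_non_zeros is defined elsewhere in this repository.
# The sketch below only illustrates the idea it is assumed to implement --
# matching the intensity distribution of the non-zero (brain) voxels of a
# source volume to that of a template while leaving the zero background
# untouched. The underscored name is illustrative, not part of the repo.
def _hist_match_non_zeros_sketch(source, template):
    import numpy as np  # numpy is assumed available, as in the rest of the repo
    matched = np.zeros_like(source, dtype=np.float64)
    src_mask = source != 0
    src_values = source[src_mask].ravel()
    tmpl_sorted = np.sort(template[template != 0].ravel())
    # Empirical quantile of every source voxel within the source distribution.
    src_order = np.argsort(src_values)
    src_quantiles = np.linspace(0.0, 1.0, src_values.size)
    tmpl_quantiles = np.linspace(0.0, 1.0, tmpl_sorted.size)
    # Map each source quantile to the template intensity at the same quantile.
    matched_values = np.empty_like(src_values, dtype=np.float64)
    matched_values[src_order] = np.interp(src_quantiles, tmpl_quantiles,
                                          tmpl_sorted)
    matched[src_mask] = matched_values
    return matched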
def evaluate(args):
    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    masking_functions = (get_whole_tumor_mask, get_tumor_core_mask,
                         get_enhancing_tumor_mask)
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
    # prediction_dir = BRATS_DIR
    # config["prediction_folder"] = os.path.join(
    #     prediction_dir, "database/prediction",
    #     get_filename_without_extension(config["model_file"]))
    name = get_filename_without_extension(config["model_file"]).replace(
        "_aug-0", "")
    config["prediction_folder"] = os.path.join(
        prediction_dir, "database/prediction/wo_augmentation", name)

    if not os.path.exists(config["prediction_folder"]):
        print("prediction folder does not exist. Please check")
        return None, None
    else:
        prediction_df_csv_folder = os.path.join(BRATS_DIR,
                                                "database/prediction/csv/")
        make_dir(prediction_df_csv_folder)
        config["prediction_df_csv"] = prediction_df_csv_folder + \
            get_filename_without_extension(config["model_file"]) + ".csv"

        if os.path.exists(config["prediction_df_csv"]):
            # Scores were computed before: reload them from the CSV.
            df = pd.read_csv(config["prediction_df_csv"])
            df1 = df.dice_WholeTumor.T._values
            df2 = df.dice_TumorCore.T._values
            df3 = df.dice_EnhancingTumor.T._values
            rows = np.zeros((df1.size, 3))
            rows[:, 0] = df1
            rows[:, 1] = df2
            rows[:, 2] = df3
            subject_ids = list()
            for case_folder in glob.glob(
                    os.path.join(config["prediction_folder"], "*")):
                if not os.path.isdir(case_folder):
                    continue
                subject_ids.append(os.path.basename(case_folder))
            df = pd.DataFrame.from_records(rows, columns=header,
                                           index=subject_ids)
            scores = dict()
            for index, score in enumerate(df.columns):
                values = df.values.T[index]
                scores[score] = values[~np.isnan(values)]
        else:
            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("model file:", config["model_file"])
            print("prediction folder:", config["prediction_folder"])
            print("csv file:", config["prediction_df_csv"])
            print("-" * 60)
            rows = list()
            subject_ids = list()
            for case_folder in glob.glob(
                    os.path.join(config["prediction_folder"], "*")):
                if not os.path.isdir(case_folder):
                    continue
                subject_ids.append(os.path.basename(case_folder))
                truth_file = os.path.join(case_folder, "truth.nii.gz")
                truth_image = nib.load(truth_file)
                truth = truth_image.get_data()
                prediction_file = os.path.join(case_folder,
                                               "prediction.nii.gz")
                prediction_image = nib.load(prediction_file)
                prediction = prediction_image.get_data()
                score_case = get_score(truth, prediction, masking_functions)
                rows.append(score_case)
            df = pd.DataFrame.from_records(rows, columns=header,
                                           index=subject_ids)
            df.to_csv(config["prediction_df_csv"])
            scores = dict()
            for index, score in enumerate(df.columns):
                values = df.values.T[index]
                scores[score] = values[~np.isnan(values)]

    return scores, model_path
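

# Note (added): get_whole_tumor_mask, get_tumor_core_mask,
# get_enhancing_tumor_mask and get_score are provided by the evaluation
# utilities of this repository. Assuming the standard BraTS label convention
# (1 = necrotic/non-enhancing core, 2 = edema, 4 = enhancing tumor), a
# minimal sketch of the per-region Dice computation they are expected to
# perform is shown below; the underscored names are illustrative only.
import numpy as np  # np is already imported at module level in this script


def _whole_tumor_mask_sketch(data):
    return np.isin(data, [1, 2, 4])


def _tumor_core_mask_sketch(data):
    return np.isin(data, [1, 4])


def _enhancing_tumor_mask_sketch(data):
    return data == 4


def _dice_sketch(truth_mask, prediction_mask):
    denominator = truth_mask.sum() + prediction_mask.sum()
    if denominator == 0:
        return np.nan
    return 2.0 * np.logical_and(truth_mask, prediction_mask).sum() / denominator


def _get_score_sketch(truth, prediction, masking_functions):
    # One Dice value per region, in the same order as `header` above.
    return [_dice_sketch(func(truth), func(prediction))
            for func in masking_functions]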
def main():
    args = get_args.train()
    overwrite = args.overwrite
    crop = args.crop
    challenge = args.challenge
    year = args.year
    image_shape = args.image_shape
    is_bias_correction = args.is_bias_correction
    is_normalize = args.is_normalize
    is_denoise = args.is_denoise
    is_test = args.is_test
    model_name = args.model
    depth_unet = args.depth_unet
    n_base_filters_unet = args.n_base_filters_unet
    patch_shape = args.patch_shape
    is_crf = args.is_crf
    batch_size = args.batch_size
    is_hist_match = args.is_hist_match
    loss = args.loss
    args.is_augment = "0"

    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    model_scores = list()
    model_ids = list()

    for depth_unet in [4, 5]:
        args.depth_unet = depth_unet
        for n_base_filters_unet in [16, 32]:
            args.n_base_filters_unet = n_base_filters_unet
            for model_dim in [2, 3, 25]:
                args.model_dim = model_dim
                # if depth_unet == 5 or n_base_filters_unet == 32:
                #     list_model = config_dict["model_depth"]
                # else:
                list_model = config_dict["model"]
                for model_name in list_model:
                    args.model_name = model_name
                    for is_normalize in config_dict["is_normalize"]:
                        args.is_normalize = is_normalize
                        for is_denoise in config_dict["is_denoise"]:
                            args.is_denoise = is_denoise
                            for is_hist_match in config_dict["hist_match"]:
                                args.is_hist_match = is_hist_match
                                for patch_shape in config_dict["patch_shape"]:
                                    args.patch_shape = patch_shape
                                    for loss in config_dict["loss"]:
                                        args.loss = loss
                                        print("=" * 120)
                                        print(
                                            ">> processing model-{}{}, depth-{}, filters-{}, "
                                            "patch_shape-{}, is_denoise-{}, is_normalize-{}, "
                                            "is_hist_match-{}, loss-{}".format(
                                                model_name, model_dim, depth_unet,
                                                n_base_filters_unet, patch_shape,
                                                is_denoise, is_normalize,
                                                is_hist_match, loss))
                                        is_test = "0"
                                        model_score, model_path = evaluate(args)
                                        if model_score is not None and get_filename_without_extension(
                                                model_path) not in model_ids:
                                            print("=" * 120)
                                            print(">> finished:")
                                            model_ids.append(
                                                get_filename_without_extension(model_path))
                                            row = get_model_info_header(
                                                challenge, year, image_shape,
                                                is_bias_correction, is_denoise,
                                                is_normalize, is_hist_match,
                                                model_name, model_dim,
                                                patch_shape, loss, depth_unet,
                                                n_base_filters_unet)
                                            score = [
                                                np.mean(model_score["dice_WholeTumor"]),
                                                np.mean(model_score["dice_TumorCore"]),
                                                np.mean(model_score["dice_EnhancingTumor"]),
                                                (np.mean(model_score["dice_WholeTumor"])
                                                 + np.mean(model_score["dice_TumorCore"])
                                                 + np.mean(model_score["dice_EnhancingTumor"])) / 3
                                            ]
                                            row.extend(score)
                                            model_scores.append(row)

    header = ("challenge", "year", "image_shape", "is_bias_correction",
              "is_denoise", "is_normalize", "is_hist_match", "model_name",
              "model_dim", "depth_unet", "n_base_filters_unet", "loss",
              "patch_shape", "dice_WholeTumor", "dice_TumorCore",
              "dice_EnhancingTumor", "dice_Mean")
    final_df = pd.DataFrame.from_records(model_scores, columns=header,
                                         index=model_ids)
    print(final_df)
    prediction_df_csv_folder = os.path.join(BRATS_DIR,
                                            "database/prediction/csv/")
    make_dir(prediction_df_csv_folder)
    to_file = prediction_df_csv_folder + "compile.csv"
    final_df.to_csv(to_file)
def finetune(args):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    folder = os.path.join(BRATS_DIR, "database", "model", "base")
    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("can not find baseline model. Please check")
        else:
            config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    make_dir(config["training_file"])

    print_section("get training and testing generators")
    if args.model_dim == 3:
        from unet3d.generator import get_training_and_validation_and_testing_generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            is_create_patch_index_list_original=config["is_create_patch_index_list_original"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])
    elif args.model_dim == 25:
        from unet25d.generator import get_training_and_validation_and_testing_generators25d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)
    else:
        from unet2d.generator import get_training_and_validation_and_testing_generators2d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    print(">> update config file")
    config.update(config_finetune)

    if not os.path.exists(config["model_file"]):
        raise Exception("{} model file not found. Please try again".format(
            config["model_file"]))
    else:
        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
        model.summary()

    # run training
    print("-" * 60)
    print("# start finetuning")
    print("-" * 60)
    print("Number of training steps: ", n_train_steps)
    print("Number of validation steps: ", n_validation_steps)

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args, dir_read_write="finetune",
        is_finetune=True)
    config["model_file"] = model_path

    if os.path.exists(config["model_file"]):
        print("{} existed. Will skip!!!".format(config["model_file"]))
    else:
        if args.is_test == "1":
            config["n_epochs"] = 5

        if args.is_test == "0":
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="finetune",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        if args.model_dim == 2 and args.model == "isensee":
            config["initial_learning_rate"] = 1e-7

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
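

# Note (added): train_model is defined in this repository's training
# utilities. The sketch below is only an assumption about the callback
# wiring its arguments suggest (checkpointing, learning-rate decay on
# plateau, early stopping); the initial learning rate is taken to be already
# compiled into the model by generate_model, and comet_ml logging is omitted.
# The underscored name is illustrative only.
def _train_model_sketch(model, model_file, training_generator,
                        validation_generator, steps_per_epoch,
                        validation_steps, learning_rate_drop,
                        learning_rate_patience, early_stopping_patience,
                        n_epochs):
    from keras.callbacks import (ModelCheckpoint, ReduceLROnPlateau,
                                 EarlyStopping)
    callbacks = [
        # Keep only the best weights seen on the validation set.
        ModelCheckpoint(model_file, save_best_only=True),
        # Shrink the learning rate when the validation loss plateaus.
        ReduceLROnPlateau(factor=learning_rate_drop,
                          patience=learning_rate_patience, verbose=1),
        # Stop training once the validation loss has stalled long enough.
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    model.fit_generator(generator=training_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=n_epochs,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        callbacks=callbacks)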
is_test = "0" model_list = list() cmd_list = list() out_file_list = list() for model_name in ["unet", "isensee"]: for is_denoise in config_dict["is_denoise"]: for is_normalize in config_dict["is_normalize"]: for is_hist_match in ["0", "1"]: for loss in ["minh", "weighted"]: patch_shape = "160-192-128" log_folder = "log" make_dir(log_folder) d = datetime.date.today() year_current = d.year month_current = '{:02d}'.format(d.month) date_current = '{:02d}'.format(d.day) model_filename = get_filename_without_extension( get_model_h5_filename(datatype="model", is_bias_correction="1", is_denoise=is_denoise, is_normalize=is_normalize, is_hist_match=is_hist_match, depth_unet=4, n_base_filters_unet=16, model_name=model_name,
def predict(args, prediction_dir="desktop"):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if not os.path.exists(model_path):
        print("model does not exist. Please check")
    else:
        config["data_file"] = data_path
        config["model_file"] = model_path
        config["training_file"] = trainids_path
        config["validation_file"] = validids_path
        config["testing_file"] = testids_path
        config["patch_shape"] = get_shape_from_string(args.patch_shape)
        config["input_shape"] = tuple([config["nb_channels"]] +
                                      list(config["patch_shape"]))

        if prediction_dir == "SERVER":
            prediction_dir = "brats"
        else:
            prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
            # prediction_dir = BRATS_DIR

        config["prediction_folder"] = os.path.join(
            prediction_dir, "database/prediction",
            get_filename_without_extension(config["model_file"]))

        if is_all_cases_predicted(config["prediction_folder"],
                                  config["testing_file"]):
            print("Already predicted. Skip...")
            list_already_predicted.append(config["prediction_folder"])
        else:
            make_dir(config["prediction_folder"])

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])
            print("prediction folder:", config["prediction_folder"])
            print("-" * 60)

            if not os.path.exists(config["model_file"]):
                raise ValueError("can not find model {}. Please check".format(
                    config["model_file"]))

            if args.model_dim == 3:
                from unet3d.prediction import run_validation_cases
            elif args.model_dim == 25:
                from unet25d.prediction import run_validation_cases
            elif args.model_dim == 2:
                from unet2d.prediction import run_validation_cases
            else:
                raise ValueError(
                    "dim {} NotImplemented error. Please check".format(
                        args.model_dim))

            run_validation_cases(
                validation_keys_file=config["testing_file"],
                model_file=config["model_file"],
                training_modalities=config["training_modalities"],
                labels=config["labels"],
                hdf5_file=config["data_file"],
                output_label_map=True,
                output_dir=config["prediction_folder"])
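

# Note (added): is_all_cases_predicted comes from this repository's
# utilities. The sketch below only illustrates the kind of check it is
# assumed to perform -- every case folder under the prediction directory
# already holds a prediction.nii.gz. The real implementation presumably also
# cross-checks against the subject ids stored in testing_file, which is
# omitted here; the underscored name is illustrative only.
def _is_all_cases_predicted_sketch(prediction_folder, testing_file):
    import os
    import glob
    if not os.path.exists(prediction_folder):
        return False
    case_folders = [f for f in glob.glob(os.path.join(prediction_folder, "*"))
                    if os.path.isdir(f)]
    if not case_folders:
        return False
    return all(os.path.exists(os.path.join(f, "prediction.nii.gz"))
               for f in case_folders)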
def main():
    args = get_args.train25d()
    depth_unet = args.depth_unet
    n_base_filters_unet = args.n_base_filters_unet
    patch_shape = args.patch_shape
    is_crf = args.is_crf
    batch_size = args.batch_size
    is_hist_match = args.is_hist_match
    loss = args.loss

    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    model_scores = list()
    model_ids = list()

    for is_augment in ["1"]:
        args.is_augment = is_augment
        for model_name in ["unet"]:
            args.model = model_name
            for is_denoise in ["0"]:
                args.is_denoise = is_denoise
                for is_normalize in ["z"]:
                    args.is_normalize = is_normalize
                    for is_hist_match in ["0"]:
                        args.is_hist_match = is_hist_match
                        for loss in ["weighted"]:
                            args.loss = loss
                            for patch_shape in ["160-192-3", "160-192-5",
                                                "160-192-7", "160-192-9",
                                                "160-192-11", "160-192-13",
                                                "160-192-15", "160-192-17"]:
                                # for patch_shape in ["160-192-3", "160-192-5", "160-192-7", "160-192-9", "160-192-11"]:
                                # for patch_shape in ["160-192-3", "160-192-13", "160-192-15", "160-192-17"]:
                                args.patch_shape = patch_shape
                                model_dim = 25
                                print("=" * 120)
                                print(
                                    ">> processing model-{}{}, depth-{}, filters-{}, "
                                    "patch_shape-{}, is_denoise-{}, is_normalize-{}, "
                                    "is_hist_match-{}, loss-{}".format(
                                        model_name, model_dim, depth_unet,
                                        n_base_filters_unet, patch_shape,
                                        is_denoise, is_normalize,
                                        is_hist_match, loss))
                                is_test = "0"
                                model_score, model_path = evaluate(args)
                                if model_score is not None:
                                    print("=" * 120)
                                    print(">> finished:")
                                    model_ids.append(
                                        get_filename_without_extension(model_path))
                                    row = get_model_info_header(
                                        args.challenge, args.year,
                                        args.image_shape,
                                        args.is_bias_correction,
                                        args.is_denoise, args.is_normalize,
                                        args.is_hist_match, model_name,
                                        model_dim, patch_shape, loss,
                                        depth_unet, n_base_filters_unet)
                                    score = [
                                        np.mean(model_score["dice_WholeTumor"]),
                                        np.mean(model_score["dice_TumorCore"]),
                                        np.mean(model_score["dice_EnhancingTumor"]),
                                        (np.mean(model_score["dice_WholeTumor"])
                                         + np.mean(model_score["dice_TumorCore"])
                                         + np.mean(model_score["dice_EnhancingTumor"])) / 3
                                    ]
                                    row.extend(score)
                                    model_scores.append(row)

    header = ("challenge", "year", "image_shape", "is_bias_correction",
              "is_denoise", "is_normalize", "is_hist_match", "model_name",
              "model_dim", "depth_unet", "n_base_filters_unet", "loss",
              "patch_shape", "dice_WholeTumor", "dice_TumorCore",
              "dice_EnhancingTumor", "dice_Mean")
    final_df = pd.DataFrame.from_records(model_scores, columns=header,
                                         index=model_ids)
    print(final_df)
    prediction_df_csv_folder = os.path.join(BRATS_DIR,
                                            "database/prediction/csv/")
    make_dir(prediction_df_csv_folder)
    to_file = prediction_df_csv_folder + "compile.csv"
    final_df.to_csv(to_file)
def predict(overwrite=True, crop=True, challenge="brats", year=2018,
            image_shape="160-192-128", is_bias_correction="1",
            is_normalize="z", is_denoise="0", is_hist_match="0",
            is_test="1", depth_unet=4, n_base_filters_unet=16,
            model_name="unet", patch_shape="128-128-128", is_crf="0",
            batch_size=1, loss="minh", model_dim=3):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, overwrite=overwrite, crop=crop,
        challenge=challenge, year=year, image_shape=image_shape,
        is_bias_correction=is_bias_correction, is_normalize=is_normalize,
        is_denoise=is_denoise, is_hist_match=is_hist_match, is_test=is_test,
        model_name=model_name, depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet, patch_shape=patch_shape,
        is_crf=is_crf, loss=loss, model_dim=model_dim,
        dir_read_write="finetune", is_finetune=True)

    if not os.path.exists(model_path):
        print("model does not exist. Please check")
    else:
        config["data_file"] = data_path
        config["model_file"] = model_path
        config["training_file"] = trainids_path
        config["validation_file"] = validids_path
        config["testing_file"] = testids_path
        config["patch_shape"] = get_shape_from_string(patch_shape)
        config["input_shape"] = tuple([config["nb_channels"]] +
                                      list(config["patch_shape"]))
        config["prediction_folder"] = os.path.join(
            BRATS_DIR, "database/prediction",
            get_filename_without_extension(config["model_file"]))

        if is_all_cases_predicted(config["prediction_folder"],
                                  config["testing_file"]):
            print("Already predicted. Skip...")
            list_already_predicted.append(config["prediction_folder"])
        else:
            make_dir(config["prediction_folder"])

            print("-" * 60)
            print("SUMMARY")
            print("-" * 60)
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])
            print("prediction folder:", config["prediction_folder"])
            print("-" * 60)

            if not os.path.exists(config["model_file"]):
                raise ValueError("can not find model {}. Please check".format(
                    config["model_file"]))

            if model_dim == 3:
                from unet3d.prediction import run_validation_cases
            elif model_dim == 25:
                from unet25d.prediction import run_validation_cases
            elif model_dim == 2:
                from unet2d.prediction import run_validation_cases
            else:
                raise ValueError(
                    "dim {} NotImplemented error. Please check".format(
                        model_dim))

            run_validation_cases(
                validation_keys_file=config["testing_file"],
                model_file=config["model_file"],
                training_modalities=config["training_modalities"],
                labels=config["labels"],
                hdf5_file=config["data_file"],
                output_label_map=True,
                output_dir=config["prediction_folder"])