def prepare_data(args):
    """Convert the raw images selected by ``args`` into a single HDF5 data file.

    The file is placed in the "data" training directory under ``BRATS_DIR``
    (created on first use). It is (re)written only when ``args.overwrite`` is
    truthy or the file does not exist yet; otherwise only the paths are printed.

    Args:
        args: command-line namespace. Attributes read here: ``is_test``,
            ``is_bias_correction``, ``is_denoise``, ``overwrite``,
            ``image_shape``, ``crop``, ``is_normalize`` and ``is_hist_match``
            (plus whatever ``get_training_h5_filename`` reads).
    """
    h5_dir = get_h5_training_dir(BRATS_DIR, "data")
    if not os.path.exists(h5_dir):
        print_separator()
        print("making dir", h5_dir)
        os.makedirs(h5_dir)

    print_section("convert input images into an hdf5 file")

    h5_name = get_training_h5_filename(datatype="data", args=args)
    print(h5_name)
    h5_path = os.path.join(h5_dir, h5_name)
    print("save to", h5_path)

    source_dataset = get_dataset(is_test=args.is_test,
                                 is_bias_correction=args.is_bias_correction,
                                 is_denoise=args.is_denoise)
    print("reading folder:", source_dataset)

    # Keep an existing file unless an overwrite was explicitly requested.
    if not args.overwrite and os.path.exists(h5_path):
        return

    input_files = fetch_training_data_files(source_dataset)
    target_shape = get_shape_from_string(args.image_shape)
    write_data_to_file(input_files, h5_path,
                       config=config,
                       image_shape=target_shape,
                       brats_dir=BRATS_DIR,
                       crop=args.crop,
                       is_normalize=args.is_normalize,
                       is_hist_match=args.is_hist_match,
                       dataset=source_dataset,
                       is_denoise=args.is_denoise)
def train(args):
    """Train (or resume) a 2.5D segmentation model configured by ``args``.

    Steps: resolve the HDF5 / id / model paths, rebuild the HDF5 data file if
    needed, create the 2.5D training and validation generators, load an
    existing model or build a new one, then run ``train_model``.

    Args:
        args: command-line namespace. Attributes read here include
            ``patch_shape``, ``overwrite``, ``batch_size``, ``is_test``,
            ``model``, ``loss``, ``depth_unet`` and ``n_base_filters_unet``.

    Raises:
        ValueError: when a new model must be built and ``args.model`` is not
            one of "seunet", "unet", "segnet" or "isensee".
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    # Patch-shape specific learning-rate overrides.
    if args.patch_shape in ("160-192-13", "160-192-15", "160-192-17"):
        config["initial_learning_rate"] = 1e-4
        print("lr updated...")
    if args.patch_shape == "160-192-3":
        config["initial_learning_rate"] = 1e-2
        print("lr updated...")

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    # Always release the HDF5 handle, even if generator creation, model
    # construction or training raises.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not args.overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=args.loss)
        elif args.model == "seunet":
            print("init seunet model")
            model = unet_model_25d(input_shape=config["input_shape"],
                                   n_labels=config["n_labels"],
                                   initial_learning_rate=config["initial_learning_rate"],
                                   deconvolution=config["deconvolution"],
                                   depth=args.depth_unet,
                                   n_base_filters=args.n_base_filters_unet,
                                   loss_function=args.loss,
                                   is_unet_original=False)
        elif args.model == "unet":
            print("init unet model")
            model = unet_model_25d(input_shape=config["input_shape"],
                                   n_labels=config["n_labels"],
                                   initial_learning_rate=config["initial_learning_rate"],
                                   deconvolution=config["deconvolution"],
                                   depth=args.depth_unet,
                                   n_base_filters=args.n_base_filters_unet,
                                   loss_function=args.loss)
        elif args.model == "segnet":
            print("init segnet model")
            model = segnet25d(input_shape=config["input_shape"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              depth=args.depth_unet,
                              n_base_filters=args.n_base_filters_unet,
                              loss_function=args.loss)
        elif args.model == "isensee":
            print("init isensee model")
            model = isensee25d_model(input_shape=config["input_shape"],
                                     n_labels=config["n_labels"],
                                     initial_learning_rate=config["initial_learning_rate"],
                                     loss_function=args.loss)
        else:
            # Previously fell through with ``model`` unbound -> NameError below.
            raise ValueError("Model is NotImplemented. Please check")

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)
        # Comet logging only for non-test runs.
        # NOTE(review): hard-coded API key in source; consider an env var/config.
        if args.is_test == "0":
            experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                    project_name="train",
                                    workspace="guusgrimbergen")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()
def train(overwrite=True, crop=True, challenge="brats", year=2018,
          image_shape="160-160-128", is_bias_correction="1",
          is_normalize="z", is_denoise="0", is_hist_match="0",
          is_test="1", depth_unet=4, n_base_filters_unet=16, model_name="unet",
          patch_shape="128-128-128", is_crf="0",
          batch_size=1, loss="weighted"):
    """Train (or resume) a 3D segmentation model.

    Resolves the HDF5 data / id / model paths from the keyword options,
    rebuilds the data file if needed, creates 3D patch generators, loads an
    existing model or builds the one selected by ``model_name``, and runs
    ``train_model``.

    Args:
        overwrite: rebuild the data file, id files and model instead of
            reusing existing ones.
        model_name: "unet", "densefcn", "denseunet", "resunet" or "seunet";
            any other value builds ``isensee2017_model``.
        is_test: "0" enables Comet experiment logging; any other value
            disables it.
        batch_size: training and validation batch size.
        loss: loss-function name forwarded to the model builders.
        The remaining options are forwarded to ``get_training_h5_paths`` /
        ``prepare_data`` to select the data variant and model file name.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, overwrite=overwrite, crop=crop, challenge=challenge, year=year,
        image_shape=image_shape, is_bias_correction=is_bias_correction,
        is_normalize=is_normalize, is_denoise=is_denoise,
        is_hist_match=is_hist_match, is_test=is_test,
        model_name=model_name, depth_unet=depth_unet, n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape, is_crf=is_crf, loss=loss, model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        # NOTE(review): called with keyword options; presumably this module's
        # prepare_data has a keyword signature (not ``prepare_data(args)``) — confirm.
        prepare_data(overwrite=overwrite, crop=crop, challenge=challenge, year=year,
                     image_shape=image_shape, is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize, is_denoise=is_denoise,
                     is_hist_match=is_hist_match, is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    # Always release the HDF5 handle, even on failure.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=batch_size,
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=loss)
        elif model_name == "unet":
            print("init unet model")
            model = unet_model_3d(input_shape=config["input_shape"],
                                  pool_size=config["pool_size"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  deconvolution=config["deconvolution"],
                                  depth=depth_unet,
                                  n_base_filters=n_base_filters_unet,
                                  loss_function=loss)
        elif model_name == "densefcn":
            print("init densenet model")
            model = densefcn_model_3d(input_shape=config["input_shape"],
                                      classes=config["n_labels"],
                                      initial_learning_rate=config["initial_learning_rate"],
                                      nb_dense_block=5,
                                      nb_layers_per_block=4,
                                      early_transition=True,
                                      dropout_rate=0.2,
                                      loss_function=loss)
        elif model_name == "denseunet":
            print("init denseunet model")
            model = dense_unet_3d(input_shape=config["input_shape"],
                                  pool_size=config["pool_size"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  deconvolution=config["deconvolution"],
                                  depth=depth_unet,
                                  n_base_filters=n_base_filters_unet,
                                  loss_function=loss)
        elif model_name == "resunet":
            print("init resunet model")
            model = res_unet_3d(input_shape=config["input_shape"],
                                pool_size=config["pool_size"],
                                n_labels=config["n_labels"],
                                initial_learning_rate=config["initial_learning_rate"],
                                deconvolution=config["deconvolution"],
                                depth=depth_unet,
                                n_base_filters=n_base_filters_unet,
                                loss_function=loss)
        elif model_name == "seunet":
            # Was a separate ``if``: its ``else`` below then replaced every
            # non-seunet model built above with the isensee model.
            print("init seunet model")
            model = se_unet_3d(input_shape=config["input_shape"],
                               pool_size=config["pool_size"],
                               n_labels=config["n_labels"],
                               initial_learning_rate=config["initial_learning_rate"],
                               deconvolution=config["deconvolution"],
                               depth=depth_unet,
                               n_base_filters=n_base_filters_unet,
                               loss_function=loss)
        else:
            print("init isensee model")
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      initial_learning_rate=config["initial_learning_rate"],
                                      loss_function=loss)

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)
        # Comet logging only for non-test runs.
        # NOTE(review): hard-coded API key in source; consider an env var/config.
        if is_test == "0":
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="train",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()
def evaluate(args):
    """Return per-case Dice scores for the predictions of the model in ``args``.

    Scores are read from a cached CSV in ``BRATS_DIR/database/prediction/csv/``
    when it exists. Otherwise they are computed from the ``truth.nii.gz`` /
    ``prediction.nii.gz`` pair in every case sub-folder of the prediction
    folder and written to that CSV.

    Args:
        args: command-line namespace forwarded to ``get_training_h5_paths``;
            ``patch_shape`` is also read here.

    Returns:
        ``(scores, model_path)``. ``scores`` maps "dice_WholeTumor",
        "dice_TumorCore" and "dice_EnhancingTumor" to float numpy arrays of
        that column's non-NaN values. Returns ``(None, None)`` when the
        prediction folder does not exist.
    """
    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    masking_functions = (get_whole_tumor_mask, get_tumor_core_mask,
                         get_enhancing_tumor_mask)

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
    # Predictions live under database/prediction/wo_augmentation/, keyed by the
    # model file name with any "_aug-0" tag removed.
    model_name = get_filename_without_extension(config["model_file"])
    config["prediction_folder"] = os.path.join(
        prediction_dir, "database/prediction/wo_augmentation",
        model_name.replace("_aug-0", ""))

    if not os.path.exists(config["prediction_folder"]):
        print("model not exists. Please check")
        return None, None

    prediction_df_csv_folder = os.path.join(BRATS_DIR, "database/prediction/csv/")
    make_dir(prediction_df_csv_folder)
    config["prediction_df_csv"] = prediction_df_csv_folder + model_name + ".csv"

    if os.path.exists(config["prediction_df_csv"]):
        # Only the Dice columns matter for the returned scores, so the cached
        # CSV is used directly. The old code re-paired its rows with a fresh
        # folder listing, which raised (or mislabelled rows) whenever the
        # folder contents had changed since the CSV was written.
        df = pd.read_csv(config["prediction_df_csv"])
    else:
        print("-" * 60)
        print("SUMMARY")
        print("-" * 60)
        print("model file:", config["model_file"])
        print("prediction folder:", config["prediction_folder"])
        print("csv file:", config["prediction_df_csv"])
        print("-" * 60)

        rows = list()
        subject_ids = list()
        for case_folder in glob.glob(os.path.join(config["prediction_folder"], "*")):
            if not os.path.isdir(case_folder):
                continue
            subject_ids.append(os.path.basename(case_folder))
            # np.asanyarray(img.dataobj) is nibabel's documented replacement for
            # the deprecated (removed in nibabel 5) ``get_data()``.
            truth_image = nib.load(os.path.join(case_folder, "truth.nii.gz"))
            truth = np.asanyarray(truth_image.dataobj)
            prediction_image = nib.load(os.path.join(case_folder, "prediction.nii.gz"))
            prediction = np.asanyarray(prediction_image.dataobj)
            rows.append(get_score(truth, prediction, masking_functions))

        df = pd.DataFrame.from_records(rows, columns=header, index=subject_ids)
        df.to_csv(config["prediction_df_csv"])

    scores = dict()
    for column in header:
        # Cast to float so an empty frame (object dtype) still works with isnan.
        values = np.asarray(df[column].values, dtype=float)
        scores[column] = values[~np.isnan(values)]
    return scores, model_path
def finetune(args):
    """Fine-tune an existing model on the data described by ``args``.

    The starting weights come from the model path derived from ``args`` (or
    ``args.name`` when it is not "0"). If that file is missing, a baseline
    model under ``BRATS_DIR/database/model/base`` is looked up instead. The
    fine-tuned model is saved to the path returned by
    ``get_training_h5_paths(..., dir_read_write="finetune", is_finetune=True)``,
    and training is skipped when that file already exists.

    Args:
        args: command-line namespace. Attributes read here include ``name``,
            ``patch_shape``, ``overwrite``, ``batch_size``, ``model_dim``
            (3, 25, anything else selects the 2D generators), ``is_test``,
            ``model``, ``loss`` and ``weight_tv_to_main_loss``.

    Raises:
        ValueError: if neither the starting model nor a baseline model exists.
        Exception: if the starting model file is missing when it is loaded.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    # An explicit model name overrides the derived starting model path.
    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    folder = os.path.join(BRATS_DIR, "database", "model", "base")
    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("can not find baseline model. Please check")
        config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    # Always release the HDF5 handle, even on failure.
    try:
        make_dir(config["training_file"])

        print_section("get training and testing generators")
        # Keyword arguments shared by the 3D, 2.5D and 2D generator factories.
        generator_kwargs = dict(
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])
        if args.model_dim == 3:
            from unet3d.generator import get_training_and_validation_and_testing_generators
            generators = get_training_and_validation_and_testing_generators(
                data_file_opened,
                is_create_patch_index_list_original=config[
                    "is_create_patch_index_list_original"],
                **generator_kwargs)
        elif args.model_dim == 25:
            from unet25d.generator import get_training_and_validation_and_testing_generators25d
            generators = get_training_and_validation_and_testing_generators25d(
                data_file_opened, is_test=args.is_test, **generator_kwargs)
        else:
            from unet2d.generator import get_training_and_validation_and_testing_generators2d
            generators = get_training_and_validation_and_testing_generators2d(
                data_file_opened, is_test=args.is_test, **generator_kwargs)
        train_generator, validation_generator, n_train_steps, n_validation_steps = generators

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        print(">> update config file")
        config.update(config_finetune)
        if not os.path.exists(config["model_file"]):
            raise Exception("{} model file not found. Please try again".format(
                config["model_file"]))
        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
        model.summary()

        print("-" * 60)
        print("# start finetuning")
        print("-" * 60)
        print("Number of training steps: ", n_train_steps)
        print("Number of validation steps: ", n_validation_steps)

        # The fine-tuned model is saved under a separate "finetune" path.
        *_, finetuned_model_path = get_training_h5_paths(
            brats_dir=BRATS_DIR, args=args, dir_read_write="finetune", is_finetune=True)
        config["model_file"] = finetuned_model_path

        if os.path.exists(config["model_file"]):
            print("{} existed. Will skip!!!".format(config["model_file"]))
        else:
            if args.is_test == "1":
                config["n_epochs"] = 5

            # Comet logging only for non-test runs.
            # NOTE(review): hard-coded API key in source; consider an env var/config.
            if args.is_test == "0":
                experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                        project_name="finetune",
                                        workspace="vuhoangminh")
            else:
                experiment = None

            if args.model_dim == 2 and args.model == "isensee":
                config["initial_learning_rate"] = 1e-7

            print(config["initial_learning_rate"], config["learning_rate_drop"])
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])

            train_model(experiment=experiment,
                        model=model,
                        model_file=config["model_file"],
                        training_generator=train_generator,
                        validation_generator=validation_generator,
                        steps_per_epoch=n_train_steps,
                        validation_steps=n_validation_steps,
                        initial_learning_rate=config["initial_learning_rate"],
                        learning_rate_drop=config["learning_rate_drop"],
                        learning_rate_patience=config["patience"],
                        early_stopping_patience=config["early_stop"],
                        n_epochs=config["n_epochs"])

            if args.is_test == "0":
                experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()
def predict(args, prediction_dir="desktop"):
    """Write label-map predictions for the testing cases of the model in ``args``.

    Output goes to ``<root>/database/prediction/<model file name>``. ``<root>``
    is "brats" when ``prediction_dir == "SERVER"`` and the hard-coded
    "/mnt/sda/3DUnetCNN_BRATS/brats" otherwise. Nothing is predicted when the
    model file is missing. When every testing case is already predicted, the
    folder is appended to the module-level ``list_already_predicted`` instead.

    Args:
        args: command-line namespace forwarded to ``get_training_h5_paths``;
            ``patch_shape`` and ``model_dim`` (2, 25 or 3) are read here.
        prediction_dir: "SERVER" selects the relative root; anything else
            selects the local absolute root.

    Raises:
        ValueError: if ``args.model_dim`` is unsupported, or the model file is
            missing at prediction time.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if not os.path.exists(model_path):
        print("model not exists. Please check")
        return

    for key, value in (("data_file", data_path),
                       ("model_file", model_path),
                       ("training_file", trainids_path),
                       ("validation_file", validids_path),
                       ("testing_file", testids_path)):
        config[key] = value
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    output_root = "brats" if prediction_dir == "SERVER" else "/mnt/sda/3DUnetCNN_BRATS/brats"
    config["prediction_folder"] = os.path.join(
        output_root, "database/prediction",
        get_filename_without_extension(config["model_file"]))

    if is_all_cases_predicted(config["prediction_folder"], config["testing_file"]):
        print("Already predicted. Skip...")
        list_already_predicted.append(config["prediction_folder"])
        return

    make_dir(config["prediction_folder"])

    rule = "-" * 60
    print(rule)
    print("SUMMARY")
    print(rule)
    for label, key in (("data file:", "data_file"),
                       ("model file:", "model_file"),
                       ("training file:", "training_file"),
                       ("validation file:", "validation_file"),
                       ("testing file:", "testing_file"),
                       ("prediction folder:", "prediction_folder")):
        print(label, config[key])
    print(rule)

    if not os.path.exists(config["model_file"]):
        raise ValueError("can not find model {}. Please check".format(
            config["model_file"]))

    # Each dimensionality has its own prediction module.
    dim = args.model_dim
    if dim == 3:
        from unet3d.prediction import run_validation_cases
    elif dim == 25:
        from unet25d.prediction import run_validation_cases
    elif dim == 2:
        from unet2d.prediction import run_validation_cases
    else:
        raise ValueError(
            "dim {} NotImplemented error. Please check".format(dim))

    run_validation_cases(validation_keys_file=config["testing_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=config["prediction_folder"])
def predict(overwrite=True, crop=True, challenge="brats", year=2018,
            image_shape="160-192-128", is_bias_correction="1",
            is_normalize="z", is_denoise="0", is_hist_match="0",
            is_test="1", depth_unet=4, n_base_filters_unet=16, model_name="unet",
            patch_shape="128-128-128", is_crf="0",
            batch_size=1, loss="minh", model_dim=3):
    """Write label-map predictions for the testing cases of a fine-tuned model.

    The model path comes from ``get_training_h5_paths`` with
    ``dir_read_write="finetune"`` and ``is_finetune=True``. Output goes to
    ``BRATS_DIR/database/prediction/<model file name>``. Nothing is predicted
    when the model file is missing. When every testing case is already
    predicted, the folder is appended to the module-level
    ``list_already_predicted`` instead.

    Args:
        model_dim: 2, 25 or 3; selects the prediction module.
        patch_shape: "-"-separated patch shape string.
        The remaining options are forwarded to ``get_training_h5_paths`` to
        locate the data, id and model files (``batch_size`` is unused here).

    Raises:
        ValueError: if ``model_dim`` is unsupported, or the model file is
            missing at prediction time.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, overwrite=overwrite, crop=crop, challenge=challenge, year=year,
        image_shape=image_shape, is_bias_correction=is_bias_correction,
        is_normalize=is_normalize, is_denoise=is_denoise,
        is_hist_match=is_hist_match, is_test=is_test,
        model_name=model_name, depth_unet=depth_unet, n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape, is_crf=is_crf, loss=loss, model_dim=model_dim,
        dir_read_write="finetune", is_finetune=True)

    if not os.path.exists(model_path):
        print("model not exists. Please check")
        return

    for key, value in (("data_file", data_path),
                       ("model_file", model_path),
                       ("training_file", trainids_path),
                       ("validation_file", validids_path),
                       ("testing_file", testids_path)):
        config[key] = value
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    config["prediction_folder"] = os.path.join(
        BRATS_DIR, "database/prediction",
        get_filename_without_extension(config["model_file"]))

    if is_all_cases_predicted(config["prediction_folder"], config["testing_file"]):
        print("Already predicted. Skip...")
        list_already_predicted.append(config["prediction_folder"])
        return

    make_dir(config["prediction_folder"])

    rule = "-" * 60
    print(rule)
    print("SUMMARY")
    print(rule)
    for label, key in (("data file:", "data_file"),
                       ("model file:", "model_file"),
                       ("training file:", "training_file"),
                       ("validation file:", "validation_file"),
                       ("testing file:", "testing_file"),
                       ("prediction folder:", "prediction_folder")):
        print(label, config[key])
    print(rule)

    if not os.path.exists(config["model_file"]):
        raise ValueError("can not find model {}. Please check".format(
            config["model_file"]))

    # Each dimensionality has its own prediction module.
    if model_dim == 3:
        from unet3d.prediction import run_validation_cases
    elif model_dim == 25:
        from unet25d.prediction import run_validation_cases
    elif model_dim == 2:
        from unet2d.prediction import run_validation_cases
    else:
        raise ValueError(
            "dim {} NotImplemented error. Please check".format(model_dim))

    run_validation_cases(validation_keys_file=config["testing_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=config["prediction_folder"])
def train(args):
    """Train (or resume) a 2D segmentation model configured by ``args``.

    Steps: resolve the HDF5 / id / model paths, choose the generator data
    layout from the model family ("cascaded" for casnet, "separated" for
    sepnet, "combined" otherwise), rebuild the HDF5 data file if needed,
    create the 2D generators, load an existing model or build a new one, and
    run ``train_model``.

    Args:
        args: command-line namespace. Attributes read here include
            ``patch_shape``, ``model``, ``overwrite``, ``batch_size``,
            ``is_test``, ``loss``, ``depth_unet`` and ``n_base_filters_unet``.
            ``args`` is not modified.

    Raises:
        ValueError: when a new model must be built and ``args.model`` is not a
            supported name.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    if "casnet" in args.model:
        config["data_type_generator"] = 'cascaded'
    elif "sepnet" in args.model:
        config["data_type_generator"] = 'separated'
    else:
        config["data_type_generator"] = 'combined'

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    # Always release the HDF5 handle, even on failure.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test,
            data_type_generator=config["data_type_generator"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        # The 2D models are built without the last axis of the patch-derived
        # input shape.
        config["input_shape"] = config["input_shape"][:-1]

        if not args.overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            # Cascaded models are reloaded with the cascaded loss. Use a local
            # instead of overwriting the caller's ``args.loss``.
            loss_function = "casweighted" if "casnet" in args.model else args.loss
            model = generate_model(config["model_file"], loss_function=loss_function)
        else:
            # Arguments shared by every casnet_* / sepnet_* builder.
            cascade_kwargs = dict(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)

            if args.model == "isensee":
                print("init isensee model")
                model = isensee2d_model(input_shape=config["input_shape"],
                                        n_labels=config["n_labels"],
                                        initial_learning_rate=config["initial_learning_rate"],
                                        loss_function=args.loss)
            elif args.model == "unet":
                print("init unet model")
                model = unet_model_2d(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      initial_learning_rate=config["initial_learning_rate"],
                                      deconvolution=config["deconvolution"],
                                      depth=args.depth_unet,
                                      n_base_filters=args.n_base_filters_unet,
                                      loss_function=args.loss)
            elif args.model == "segnet":
                print("init segnet model")
                model = segnet2d(input_shape=config["input_shape"],
                                 n_labels=config["n_labels"],
                                 initial_learning_rate=config["initial_learning_rate"],
                                 depth=args.depth_unet,
                                 n_base_filters=args.n_base_filters_unet,
                                 loss_function=args.loss)
            elif args.model == "casnet_v1":
                print("init casnet_v1 model")
                model = casnet_v1(loss_function="casweighted", **cascade_kwargs)
            elif args.model == "casnet_v2":
                print("init casnet_v2 model")
                model = casnet_v2(loss_function="casweighted", **cascade_kwargs)
            elif args.model == "casnet_v3":
                print("init casnet_v3 model")
                model = casnet_v3(loss_function="casweighted", **cascade_kwargs)
            elif args.model == "casnet_v4":
                print("init casnet_v4 model")
                model = casnet_v4(**cascade_kwargs)
            elif args.model == "casnet_v5":
                print("init casnet_v5 model")
                model = casnet_v5(loss_function="casweighted", **cascade_kwargs)
            elif args.model == "casnet_v6":
                print("init casnet_v6 model")
                model = casnet_v6(**cascade_kwargs)
            elif args.model == "casnet_v7":
                print("init casnet_v7 model")
                model = casnet_v7(**cascade_kwargs)
            elif args.model == "casnet_v8":
                print("init casnet_v8 model")
                model = casnet_v8(**cascade_kwargs)
            elif args.model == "sepnet_v1":
                print("init sepnet_v1 model")
                model = sepnet_v1(**cascade_kwargs)
            elif args.model == "sepnet_v2":
                print("init sepnet_v2 model")
                model = sepnet_v2(**cascade_kwargs)
            else:
                raise ValueError("Model is NotImplemented. Please check")

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)
        # Comet logging only for non-test runs.
        # NOTE(review): hard-coded API key in source; consider an env var/config.
        if args.is_test == "0":
            experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                    project_name="train",
                                    workspace="guusgrimbergen")
        else:
            experiment = None

        # NOTE(review): this override happens after the model was built with the
        # previous initial_learning_rate; it only reaches train_model — confirm
        # that is intended.
        if args.model == "isensee":
            config["initial_learning_rate"] = 1e-6
        print(config["initial_learning_rate"], config["learning_rate_drop"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()