def main():
    global config
    args = get_args.train25d()
    config = path_utils.update_is_augment(args, config)

    data_path, _, _, _, _ = path_utils.get_training_h5_paths(BRATS_DIR, args)
    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    train(args)
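# `get_shape_from_string` is imported from the project's utilities and is not
# defined in this file. Judging from patch-shape strings such as "160-192-13"
# used below, it presumably parses dash-separated integers into a tuple. A
# minimal sketch of that assumed behaviour (the hypothetical name
# `_shape_from_string_sketch` avoids shadowing the real helper):
def _shape_from_string_sketch(shape_string):
    """Parse a dash-separated shape string, e.g. "128-128-128" -> (128, 128, 128)."""
    return tuple(int(dim) for dim in shape_string.split("-"))

# Example: _shape_from_string_sketch("160-192-13") == (160, 192, 13)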
def train(args):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple(
        [config["nb_channels"]] + list(config["patch_shape"]))

    # Patch-shape-specific learning-rate overrides.
    if args.patch_shape in ["160-192-13", "160-192-15", "160-192-17"]:
        config["initial_learning_rate"] = 1e-4
        print("lr updated...")
    if args.patch_shape in ["160-192-3"]:
        config["initial_learning_rate"] = 1e-2
        print("lr updated...")

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
        data_file_opened,
        batch_size=args.batch_size,
        data_split=config["validation_split"],
        overwrite=args.overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=args.batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"],
        is_test=args.is_test)

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    if not args.overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        model = generate_model(config["model_file"], loss_function=args.loss)
        # model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        if args.model == "seunet":
            # seunet reuses the 2.5D U-Net builder with squeeze-and-excitation
            # blocks enabled via is_unet_original=False.
            print("init seunet model")
            model = unet_model_25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                # batch_normalization=True,
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss,
                is_unet_original=False)
        elif args.model == "unet":
            print("init unet model")
            model = unet_model_25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                # batch_normalization=True,
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "segnet":
            print("init segnet model")
            model = segnet25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "isensee":
            print("init isensee model")
            model = isensee25d_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=args.loss)
        else:
            raise ValueError("Model {} is not implemented. Please check.".format(args.model))

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training
    if args.is_test == "0":
        experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                project_name="train",
                                workspace="guusgrimbergen")
    else:
        experiment = None

    print(config["initial_learning_rate"], config["learning_rate_drop"])
    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if args.is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
def finetune(args):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple(
        [config["nb_channels"]] + list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    # Fall back to the pretrained baseline model if the requested model file
    # does not exist yet.
    folder = os.path.join(BRATS_DIR, "database", "model", "base")
    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("Cannot find baseline model. Please check.")
        else:
            config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    make_dir(config["training_file"])

    print_section("get training and testing generators")
    if args.model_dim == 3:
        from unet3d.generator import get_training_and_validation_and_testing_generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            is_create_patch_index_list_original=config["is_create_patch_index_list_original"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])
    elif args.model_dim == 25:
        from unet25d.generator import get_training_and_validation_and_testing_generators25d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)
    else:
        from unet2d.generator import get_training_and_validation_and_testing_generators2d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    print(">> update config file")
    config.update(config_finetune)

    if not os.path.exists(config["model_file"]):
        raise Exception("{} model file not found. Please try again.".format(
            config["model_file"]))
    else:
        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
    model.summary()

    # run training
    print("-" * 60)
    print("# start finetuning")
    print("-" * 60)
    print("Number of training steps: ", n_train_steps)
    print("Number of validation steps: ", n_validation_steps)

    # Redirect the model path to the finetune output directory.
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args, dir_read_write="finetune", is_finetune=True)
    config["model_file"] = model_path

    if os.path.exists(config["model_file"]):
        print("{} exists. Will skip!!!".format(config["model_file"]))
    else:
        if args.is_test == "1":
            config["n_epochs"] = 5

        if args.is_test == "0":
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="finetune",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        if args.model_dim == 2 and args.model == "isensee":
            config["initial_learning_rate"] = 1e-7

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
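# `make_dir` is imported from the project's utilities and is called above with
# a file path (the training-ids file), so it presumably ensures the enclosing
# directory exists before the generator writes to it. A minimal sketch under
# that assumption (hypothetical name to avoid shadowing the real helper):
import os

def _make_dir_sketch(path):
    """Create the parent directory of `path` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)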
def train(overwrite=True, crop=True, challenge="brats", year=2018,
          image_shape="160-160-128", is_bias_correction="1", is_normalize="z",
          is_denoise="0", is_hist_match="0", is_test="1", depth_unet=4,
          n_base_filters_unet=16, model_name="unet",
          patch_shape="128-128-128", is_crf="0", batch_size=1,
          loss="weighted"):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, overwrite=overwrite, crop=crop,
        challenge=challenge, year=year, image_shape=image_shape,
        is_bias_correction=is_bias_correction, is_normalize=is_normalize,
        is_denoise=is_denoise, is_hist_match=is_hist_match, is_test=is_test,
        model_name=model_name, depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet, patch_shape=patch_shape,
        is_crf=is_crf, loss=loss, model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple(
        [config["nb_channels"]] + list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite, crop=crop, challenge=challenge,
                     year=year, image_shape=image_shape,
                     is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize, is_denoise=is_denoise,
                     is_hist_match=is_hist_match, is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
        data_file_opened,
        batch_size=batch_size,
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"])

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    if not overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        model = generate_model(config["model_file"], loss_function=loss)
        # model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        if model_name == "unet":
            print("init unet model")
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        elif model_name == "densefcn":
            print("init densenet model")
            # config["initial_learning_rate"] = 1e-5
            model = densefcn_model_3d(
                input_shape=config["input_shape"],
                classes=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                nb_dense_block=5,
                nb_layers_per_block=4,
                early_transition=True,
                dropout_rate=0.2,
                loss_function=loss)
        elif model_name == "denseunet":
            print("init denseunet model")
            model = dense_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        elif model_name == "resunet":
            print("init resunet model")
            model = res_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        elif model_name == "seunet":
            print("init seunet model")
            model = se_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        else:
            print("init isensee model")
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=loss)

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training
    if is_test == "0":
        experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                project_name="train",
                                workspace="vuhoangminh")
    else:
        experiment = None

    print(config["initial_learning_rate"], config["learning_rate_drop"])
    print("data file:", config["data_file"])
    print("model file:", config["model_file"])
    print("training file:", config["training_file"])
    print("validation file:", config["validation_file"])
    print("testing file:", config["testing_file"])

    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
def train(args):
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple(
        [config["nb_channels"]] + list(config["patch_shape"]))

    # Cascaded (casnet) and separated (sepnet) architectures need matching
    # generator modes; everything else uses the combined labels.
    if "casnet" in args.model:
        config["data_type_generator"] = 'cascaded'
    elif "sepnet" in args.model:
        config["data_type_generator"] = 'separated'
    else:
        config["data_type_generator"] = 'combined'

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
        data_file_opened,
        batch_size=args.batch_size,
        data_split=config["validation_split"],
        overwrite=args.overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=args.batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"],
        is_test=args.is_test,
        data_type_generator=config["data_type_generator"])

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    # Drop the last spatial dimension so the 2D models receive
    # (channels, height, width).
    config["input_shape"] = config["input_shape"][:-1]

    if not args.overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        if "casnet" in args.model:
            args.loss = "casweighted"
        model = generate_model(config["model_file"], loss_function=args.loss)
    else:
        # instantiate new model
        if args.model == "isensee":
            print("init isensee model")
            model = isensee2d_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=args.loss)
        elif args.model == "unet":
            print("init unet model")
            model = unet_model_2d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "segnet":
            print("init segnet model")
            model = segnet2d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "casnet_v1":
            print("init casnet_v1 model")
            model = casnet_v1(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v2":
            print("init casnet_v2 model")
            model = casnet_v2(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v3":
            print("init casnet_v3 model")
            model = casnet_v3(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v4":
            print("init casnet_v4 model")
            model = casnet_v4(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "casnet_v5":
            print("init casnet_v5 model")
            model = casnet_v5(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v6":
            print("init casnet_v6 model")
            model = casnet_v6(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "casnet_v7":
            print("init casnet_v7 model")
            model = casnet_v7(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "casnet_v8":
            print("init casnet_v8 model")
            model = casnet_v8(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "sepnet_v1":
            print("init sepnet_v1 model")
            model = sepnet_v1(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "sepnet_v2":
            print("init sepnet_v2 model")
            model = sepnet_v2(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        else:
            raise ValueError("Model {} is not implemented. Please check.".format(args.model))

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training
    if args.is_test == "0":
        experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                project_name="train",
                                workspace="guusgrimbergen")
    else:
        experiment = None

    if args.model == "isensee":
        config["initial_learning_rate"] = 1e-6

    print(config["initial_learning_rate"], config["learning_rate_drop"])
    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if args.is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
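# Worked example of the input-shape trimming done above for the 2D models,
# assuming a 2D patch-shape string such as "160-192-1" (hypothetical value for
# illustration) and 4 input channels:
shape_3d_example = (4, 160, 192, 1)       # (channels, height, width, depth=1)
shape_2d_example = shape_3d_example[:-1]  # (4, 160, 192): what the 2D models expect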