def run_validation_cases(validation_keys_file, model_file, training_modalities, labels, hdf5_file,
                         output_label_map=False, output_dir=".", threshold=0.5, overlap=16,
                         permute=False, isDeeper=False, depth=4):
    """Rebuild the 3D U-Net, load trained weights into it and predict every validation case.

    The network is recreated from the module-level ``config`` and only its weights are
    read from ``model_file`` (instead of deserializing a whole saved model).

    Args:
        validation_keys_file: file read with ``pickle_load``; its contents are iterated
            as the indices of the cases to predict.
        model_file: weights file passed to ``model.load_weights``.
        training_modalities, labels, output_label_map, threshold, overlap, permute:
            forwarded unchanged to ``run_validation_case``.
        hdf5_file: path of the HDF5 data file, opened read-only.
        output_dir: directory under which one sub-directory per case is created.
        isDeeper: when True, build the network with the explicit ``depth`` below.
        depth: network depth; only used when ``isDeeper`` is True (otherwise the
            builder's own default depth applies).
    """
    validation_indices = pickle_load(validation_keys_file)

    model_kwargs = dict(
        input_shape=config["input_shape"],
        pool_size=config["pool_size"],
        n_labels=config["n_labels"],
        initial_learning_rate=config["initial_learning_rate"],
        deconvolution=config["deconvolution"])
    if isDeeper:
        # Only override the builder's default depth when explicitly requested.
        model_kwargs["depth"] = depth
    model = unet_model_3d(**model_kwargs)
    # Load the trained weights into the freshly built architecture.
    model.load_weights(model_file)

    # The context manager closes the file even if predicting a case raises.
    with tables.open_file(hdf5_file, "r") as data_file:
        # The file is read-only, so the presence of subject ids cannot change mid-loop.
        has_subject_ids = 'subject_ids' in data_file.root
        for index in validation_indices:
            if has_subject_ids:
                case_name = data_file.root.subject_ids[index].decode('utf-8')
            else:
                case_name = "validation_case_{}".format(index)
            case_directory = os.path.join(output_dir, case_name)
            run_validation_case(data_index=index, output_dir=case_directory, model=model,
                                data_file=data_file, training_modalities=training_modalities,
                                output_label_map=output_label_map, labels=labels,
                                threshold=threshold, overlap=overlap, permute=permute)
def main():
    """Build the HDF5 data file if it is missing, then train a 3D U-Net on it.

    Training subjects are discovered from ``./data/training/subject-*-label.hdr``;
    the T1 and T2 image paths are derived by substituting "T1"/"T2" for "label"
    in each label file name.
    """
    if not os.path.exists(config["hdf5_file"]):
        # One (T1, T2, label) tuple per subject.
        training_files = [
            (label_file.replace("label", "T1"), label_file.replace("label", "T2"), label_file)
            for label_file in glob.glob("./data/training/subject-*-label.hdr")
        ]
        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    # The context manager closes the file even if model loading or training raises.
    with tables.open_file(config["hdf5_file"], "r") as hdf5_file_opened:
        if os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = unet_model_3d(input_shape=config["input_shape"], n_labels=config["n_labels"])

        # get training and testing generators
        train_generator, validation_generator, nb_train_samples, nb_test_samples = \
            get_training_and_validation_generators(
                hdf5_file_opened, batch_size=config["batch_size"],
                data_split=config["validation_split"],
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                n_labels=config["n_labels"], labels=config["labels"], augment=True)

        # run training
        train_model(model=model, model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=nb_train_samples, validation_steps=nb_test_samples,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_epochs=config["decay_learning_rate_every_x_epochs"],
                    n_epochs=config["n_epochs"])
def main(overwrite=False):
    """Convert the training images to HDF5 (if needed) and train a 3D U-Net.

    Args:
        overwrite: when True, rebuild the HDF5 data file and build a new model even
            if a saved one exists; the flag is also forwarded to the generator factory
            (presumably to regenerate the train/validation split — confirm).
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["hdf5_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    # The context manager closes the file even if model loading or training raises.
    with tables.open_file(config["hdf5_file"], "r") as hdf5_file_opened:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = unet_model_3d(input_shape=config["input_shape"],
                                  downsize_filters_factor=config["downsize_nb_filters_factor"],
                                  pool_size=config["pool_size"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"])

        # get training and testing generators
        train_generator, validation_generator, nb_train_samples, nb_test_samples = \
            get_training_and_validation_generators(
                hdf5_file_opened, batch_size=config["batch_size"],
                data_split=config["validation_split"], overwrite=overwrite,
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                n_labels=config["n_labels"])

        # run training
        train_model(model=model, model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=nb_train_samples, validation_steps=nb_test_samples,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_epochs=config["decay_learning_rate_every_x_epochs"],
                    n_epochs=config["n_epochs"])
def main(overwrite=False):
    """Convert the training images to HDF5 (if needed), then build/load and train a 3D U-Net.

    Progress messages are printed at each stage.

    Args:
        overwrite: when True, rebuild the data file and build a new model even if a
            saved one exists; also forwarded to the generator factory.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print("Number of Training file Found:", len(training_files))
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    print("Opening data file.")
    data_file_opened = open_data_file(config["data_file"])
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            print("Loading existing model file.")
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            print("Instantiating new model file.")
            model = unet_model_3d(input_shape=config["input_shape"],
                                  pool_size=config["pool_size"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  deconvolution=config["deconvolution"])

        # get training and testing generators
        print("Getting training and testing generators.")
        train_generator, validation_generator, n_train_steps, n_validation_steps = \
            get_training_and_validation_generators(
                data_file_opened,
                batch_size=config["batch_size"],
                data_split=config["validation_split"],
                overwrite=overwrite,
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                n_labels=config["n_labels"],
                labels=config["labels"],
                patch_shape=config["patch_shape"],
                validation_batch_size=config["validation_batch_size"],
                validation_patch_overlap=config["validation_patch_overlap"],
                training_patch_start_offset=config["training_patch_start_offset"],
                permute=config["permute"],
                augment=config["augment"],
                skip_blank=config["skip_blank"],
                augment_flip=config["flip"],
                augment_distortion_factor=config["distort"])

        # run training
        print("Running the training......")
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        # Release the data file even if model setup or training raises.
        data_file_opened.close()
    print("Training DONE")
def test_batch_normalization(self):
    """With batch_normalization=True, each non-transposed Conv3D layer has a matching BN layer.

    The matching layer is found by substituting 'batch_normalization' for 'conv3d'
    in the convolution layer's name.
    """
    model = unet_model_3d(input_shape=(1, 16, 16, 16), depth=2, deconvolution=True,
                          metrics=[], n_labels=1, batch_normalization=True)
    all_names = [layer.name for layer in model.layers]
    # The trailing three layers are skipped so the last convolution layer is not checked.
    plain_convs = [candidate for candidate in all_names[:-3]
                   if 'conv3d' in candidate and 'transpose' not in candidate]
    for conv_name in plain_convs:
        expected_bn = conv_name.replace('conv3d', 'batch_normalization')
        self.assertIn(expected_bn, all_names)
def main(overwrite=False):
    """Convert the training images to HDF5 (if needed), then build/load and train a 3D U-Net.

    Args:
        overwrite: when True, rebuild the data file and build a new model even if a
            saved one exists; also forwarded to the generator factory.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"])

        # get training and testing generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = \
            get_training_and_validation_generators(
                data_file_opened,
                batch_size=config["batch_size"],
                data_split=config["validation_split"],
                overwrite=overwrite,
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                n_labels=config["n_labels"],
                labels=config["labels"],
                patch_shape=config["patch_shape"],
                validation_batch_size=config["validation_batch_size"],
                validation_patch_overlap=config["validation_patch_overlap"],
                training_patch_start_offset=config["training_patch_start_offset"],
                permute=config["permute"],
                augment=config["augment"],
                skip_blank=config["skip_blank"],
                augment_flip=config["flip"],
                augment_distortion_factor=config["distortion_factor"],
                augment_rotation_factor=config["rotation_factor"],
                mirror=config["mirror"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    logging_path=config["logging_path"])
    finally:
        # Release the data file even if model setup or training raises.
        data_file_opened.close()
def main(overwrite=False):
    """Train a 3D U-Net on data read from ``config["npy_path"]``.

    An existing model file is reused unless ``overwrite`` is set. In that case a new
    network is built from ``config``, and ``overwrite`` is also passed to the
    generator factory.
    """
    model_path = config["model_file"]
    if overwrite or not os.path.exists(model_path):
        # Fresh network built from the configuration.
        model = unet_model_3d(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            pool_size=config["pool_size"],
            deconvolution=config["deconvolution"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"],
        )
    else:
        model = load_old_model(model_path)

    generators = get_training_and_validation_generators(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],
        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        skip_blank=config["skip_blank"],
    )
    train_gen, valid_gen, train_steps, valid_steps = generators

    train_model(
        model=model,
        model_file=model_path,
        training_generator=train_gen,
        validation_generator=valid_gen,
        steps_per_epoch=train_steps,
        validation_steps=valid_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"],
    )
def train(ibis_data, input_shape=(96, 96, 96), batch_size=2, data_split=0.8, num_gpus=None,
          only_aa=False, initial_learning_rate=0.001, learning_rate_drop=0.6,
          learning_rate_epochs=10, n_epochs=200):
    """Train a 3D U-Net on IBIS data and save it as a timestamped ``molmimic_*.h5`` file.

    Args:
        ibis_data: dataset handed to ``IBISGenerator.get_training_and_validation``.
        input_shape: spatial shape of a volume; a channel axis is appended
            (21 channels when ``only_aa`` is True, 59 otherwise).
        batch_size: per-device batch size; multiplied by ``num_gpus`` when more
            than one GPU is requested.
        data_split: NOTE(review): currently not used by this function — confirm whether
            it should be forwarded to the generator factory.
        num_gpus: forwarded to ``unet_model_3d``.
        only_aa: selects the channel count above and is forwarded to the generators.
        initial_learning_rate, learning_rate_drop, learning_rate_epochs, n_epochs:
            training schedule forwarded to ``train_model_generator``; the defaults
            match the previously hard-coded values.
    """
    n_channels = 21 if only_aa else 59
    input_shape = input_shape + (n_channels,)
    model = unet_model_3d(input_shape=input_shape, num_gpus=num_gpus)

    # Keep the per-device batch size constant when training on several GPUs.
    if num_gpus is not None and num_gpus > 1:
        batch_size *= num_gpus

    train, validate = IBISGenerator.get_training_and_validation(
        ibis_data, input_shape=input_shape, batch_size=batch_size, only_aa=only_aa)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    train_model_generator(
        model=model,
        model_file=os.path.abspath("./molmimic_{}.h5".format(timestamp)),
        training_generator=train.generate(),
        validation_generator=validate.generate(),
        steps_per_epoch=train.steps_per_epoch,
        validation_steps=validate.steps_per_epoch,
        initial_learning_rate=initial_learning_rate,
        learning_rate_drop=learning_rate_drop,
        learning_rate_epochs=learning_rate_epochs,
        n_epochs=n_epochs,
    )
def main(overwrite=False):
    """Build (or load) the 3D U-Net and save a diagram of its architecture.

    A saved model is reused unless ``overwrite`` is set or no model file exists.
    The diagram, including layer shapes, is written to ``original_uet.png``.
    """
    model_path = config["model_file"]
    if overwrite or not os.path.exists(model_path):
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
        )
    else:
        model = load_old_model(model_path)

    # Imported lazily: only this entry point needs the plotting utility.
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='original_uet.png', show_shapes=True)
def main(overwrite=False):
    """Train (or resume training) a 3D U-Net on the configured HDF5 data file.

    Note: the ``overwrite`` argument is ignored. It is always replaced by the value
    parsed from the command line via ``get_args.train()``.
    """
    args = get_args.train()
    overwrite = args.overwrite

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = \
            get_training_and_validation_generators_new(
                data_file_opened,
                batch_size=config["batch_size"],
                data_split=config["validation_split"],
                overwrite=overwrite,
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                n_steps_file=config["n_steps_file"],
                n_labels=config["n_labels"],
                labels=config["labels"],
                patch_shape=config["patch_shape"],
                validation_batch_size=config["validation_batch_size"],
                validation_patch_overlap=config["validation_patch_overlap"],
                training_patch_start_offset=config["training_patch_start_offset"],
                is_create_patch_index_list_original=config["is_create_patch_index_list_original"],
                augment_flipud=config["augment_flipud"],
                augment_fliplr=config["augment_fliplr"],
                augment_elastic=config["augment_elastic"],
                augment_rotation=config["augment_rotation"],
                augment_shift=config["augment_shift"],
                augment_shear=config["augment_shear"],
                augment_zoom=config["augment_zoom"],
                n_augment=config["n_augment"],
                skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            print("init model model")
            model = unet_model_3d(input_shape=config["input_shape"],
                                  pool_size=config["pool_size"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  deconvolution=config["deconvolution"],
                                  depth=config["depth"],
                                  n_base_filters=config["n_base_filters"])
        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)
        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        # Release the data file even if generator setup, model setup or training raises.
        data_file_opened.close()
def _build_new_model(model_name, depth_unet, n_base_filters_unet, loss):
    """Instantiate an untrained network for ``model_name``; unknown names fall back to Isensee 2017."""
    if model_name == "unet":
        print("init unet model")
        return unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    if model_name == "densefcn":
        print("init densenet model")
        return densefcn_model_3d(
            input_shape=config["input_shape"],
            classes=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            nb_dense_block=5,
            nb_layers_per_block=4,
            early_transition=True,
            dropout_rate=0.2,
            loss_function=loss)
    if model_name == "denseunet":
        print("init denseunet model")
        return dense_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    if model_name == "resunet":
        print("init resunet model")
        return res_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    if model_name == "seunet":
        print("init seunet model")
        return se_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    print("init isensee model")
    return isensee2017_model(
        input_shape=config["input_shape"],
        n_labels=config["n_labels"],
        initial_learning_rate=config["initial_learning_rate"],
        loss_function=loss)


def train(overwrite=True, crop=True, challenge="brats", year=2018,
          image_shape="160-160-128", is_bias_correction="1",
          is_normalize="z", is_denoise="0", is_hist_match="0", is_test="1",
          depth_unet=4, n_base_filters_unet=16, model_name="unet",
          patch_shape="128-128-128", is_crf="0",
          batch_size=1, loss="weighted"):
    """Prepare the BraTS HDF5 data if needed, then build/load and train the selected network.

    The data, key-file and model paths are derived from the preprocessing options via
    ``get_training_h5_paths`` and stored in the module-level ``config``.
    ``model_name`` selects one of "unet", "densefcn", "denseunet", "resunet" or "seunet";
    any other value builds the Isensee 2017 model. When ``is_test == "0"``, the run is
    tracked with a Comet ``Experiment``.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, overwrite=overwrite, crop=crop, challenge=challenge, year=year,
        image_shape=image_shape, is_bias_correction=is_bias_correction,
        is_normalize=is_normalize, is_denoise=is_denoise,
        is_hist_match=is_hist_match, is_test=is_test,
        model_name=model_name, depth_unet=depth_unet, n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape, is_crf=is_crf, loss=loss, model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    # Channels-first input: (channels, *patch_shape).
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite, crop=crop, challenge=challenge, year=year,
                     image_shape=image_shape, is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize, is_denoise=is_denoise,
                     is_hist_match=is_hist_match, is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = \
            get_training_and_validation_and_testing_generators(
                data_file_opened,
                batch_size=batch_size,
                data_split=config["validation_split"],
                overwrite=overwrite,
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                testing_keys_file=config["testing_file"],
                n_labels=config["n_labels"],
                labels=config["labels"],
                patch_shape=config["patch_shape"],
                validation_batch_size=batch_size,
                validation_patch_overlap=config["validation_patch_overlap"],
                training_patch_start_offset=config["training_patch_start_offset"],
                augment_flipud=config["augment_flipud"],
                augment_fliplr=config["augment_fliplr"],
                augment_elastic=config["augment_elastic"],
                augment_rotation=config["augment_rotation"],
                augment_shift=config["augment_shift"],
                augment_shear=config["augment_shear"],
                augment_zoom=config["augment_zoom"],
                n_augment=config["n_augment"],
                skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=loss)
        else:
            # Previously a plain `if` for "seunet" let the trailing `else` overwrite every
            # unet/densefcn/denseunet/resunet model with the Isensee model; the helper
            # now dispatches exactly one builder.
            model = _build_new_model(model_name, depth_unet, n_base_filters_unet, loss)
        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)
        # run training
        if is_test == "0":
            # NOTE(review): a Comet API key is hard-coded in source control; it should be
            # revoked and supplied only through the COMET_API_KEY environment variable.
            experiment = Experiment(
                api_key=os.environ.get("COMET_API_KEY", "AgTGwIoRULRgnfVR5M8mZ5AfS"),
                project_name="train",
                workspace="vuhoangminh")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if is_test == "0":
            experiment.log_parameters(config)
    finally:
        # Release the data file even if any stage above raises.
        data_file_opened.close()
def main(overwrite_data=False, overwrite_model=False):
    """Build the HDF5 data file if needed, then build/load and train a 3D U-Net.

    Args:
        overwrite_data: regenerate the HDF5 data file; also forwarded as ``overwrite``
            to the generator factory.
        overwrite_model: build a new network even if ``config["model_file"]`` exists.
    """
    # run if the data not already stored hdf5
    if overwrite_data or not os.path.exists(config["data_file"]):
        _save_new_h5_datafile(config["data_file"], new_image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])
    try:
        if not overwrite_model and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model (an isensee2017_model alternative was previously
            # kept here as disabled code; a U-Net is what is actually built)
            print('initializing new unet model with input shape', config['input_shape'])
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                n_base_filters=config["n_base_filters"])

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite_data,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        # Release the data file even if model setup or training raises.
        data_file_opened.close()
import os
import glob

from unet3d.model import unet_model_3d
from data_generator import DataGenerator

# Quick training run of a 3D U-Net on the local test dataset.
DATA_LEN = 5
IMG_SIZE = 224
N_CHANNEL = 16
DATAPATH = '/home/trungdunghoang/Documents/EPFL/3DUnetCNN/data_test'

model_input_shape = (DATA_LEN, IMG_SIZE, IMG_SIZE, N_CHANNEL)
model = unet_model_3d(input_shape=model_input_shape)

train_generator = DataGenerator(DATAPATH)
steps = len(train_generator)

# NOTE(review): the training generator is also passed as validation data, so the
# reported validation metrics are not measured on held-out samples.
model.fit_generator(
    generator=train_generator,
    steps_per_epoch=steps,
    epochs=10,
    validation_data=train_generator,
    validation_steps=steps,
)
def main(overwrite=False):
    """Convert the training images to HDF5 (if needed), then build/load and train a 3D U-Net.

    Args:
        overwrite: rebuild the data file and build a new model even if a saved one
            exists. It is deliberately NOT forwarded to the generator factory, so the
            previously written training/validation key files keep being used.
    """
    # Check once and reuse the answer for both the debug output and the decision.
    data_exists = os.path.exists(config["data_file"])
    print(overwrite or not data_exists)
    print('path: ', data_exists)

    # convert input images into an hdf5 file
    if overwrite or not data_exists:
        training_files = fetch_training_data_files()
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"])
        # Keras' summary() prints the table itself and returns None, so wrapping it in
        # print() only emitted an extra "None" line.
        model.summary()

        # get training and testing generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = \
            get_training_and_validation_generators(
                data_file_opened,
                batch_size=config["batch_size"],
                data_split=config["validation_split"],
                # False so the training indices already used for normalization in
                # write_data_to_file (above) are reused rather than re-split.
                overwrite=False,
                validation_keys_file=config["validation_file"],
                training_keys_file=config["training_file"],
                n_labels=config["n_labels"],
                labels=config["labels"],
                patch_shape=config["patch_shape"],
                validation_batch_size=config["validation_batch_size"],
                validation_patch_overlap=config["validation_patch_overlap"],
                training_patch_start_offset=config["training_patch_start_offset"],
                permute=config["permute"],
                augment=config["augment"],
                skip_blank=config["skip_blank"],
                augment_flip=config["flip"],
                augment_distortion_factor=config["distort"])

        # NOTE(review): the result of this call is discarded; it looks like leftover
        # scaffolding for "normalize using only the training images" — confirm it has
        # no needed side effects and remove it.
        fetch_training_data_files()

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        # Release the data file even if model setup or training raises.
        data_file_opened.close()
import os

from unet3d.model import unet_model_3d
from unet3d.training import load_old_model, train_model
from unet3d.generator import ExampleSphereGenerator
from keras import backend as K

# Train a 3D U-Net on synthetic sphere volumes and save it to ./SphereCNN.h5.

# Channels-last layout, matching the trailing channel axis of input_shape below.
# set_image_data_format('channels_last') is the Keras 2 equivalent of the legacy
# set_image_dim_ordering('tf') call.
K.set_image_data_format('channels_last')

input_shape = (96, 96, 96, 1)
model = unet_model_3d(input_shape=input_shape)

train_generator, validation_generator = ExampleSphereGenerator.get_training_and_validation(
    input_shape, cnt=5, border=10, batch_size=20, n_samples=500)

train_model(model=model,
            model_file=os.path.abspath("./SphereCNN.h5"),
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=train_generator.num_steps,
            validation_steps=validation_generator.num_steps,
            initial_learning_rate=0.00001,
            learning_rate_drop=0.5,
            learning_rate_epochs=10,
            n_epochs=50)