def main(overwrite=False):
    """Prepare the HDF5 data file, then build or reload the model and train it.

    Args:
        overwrite: If true, the data file is rewritten and a new model is
            instantiated even when ``config["data_file"]`` or
            ``config["model_file"]`` already exist. The flag is also passed
            on to ``get_training_and_validation_generators``.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])
    # The ``finally`` clause guarantees the data file is closed even when
    # model creation, generator construction or training raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            print("Loading old model file from the location: ",
                  config["model_file"])
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            print("Creating new model at the location: ",
                  config["model_file"])
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        print("Running the Training. Model file:", config["model_file"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
def load_model(self):
    """Build the network, wrap it in JDOT, load the stored model and compile.

    The network is created uncompiled (``compile=False``); compilation is
    done through the JDOT wrapper after ``load_old_model`` has been called
    with ``self.config.model_file``. The wrapper is stored on ``self.jd``.
    """
    network, context_name = isensee2017_model(
        input_shape=self.config.input_shape,
        n_labels=self.config.n_labels,
        initial_learning_rate=self.config.initial_learning_rate,
        n_base_filters=self.config.n_base_filters,
        loss_function=self.config.loss_function,
        shortcut=self.config.shortcut,
        compile=False,
    )

    wrapper = JDOT(network,
                   config=self.config,
                   context_output_name=context_name)
    wrapper.load_old_model(self.config.model_file)
    wrapper.compile_model()

    self.jd = wrapper
def main(self, overwrite_data=True, overwrite_model=True):
    """Write the source/target data files if needed, then train and evaluate.

    Args:
        overwrite_data: If true, both HDF5 files are rewritten even when
            they already exist.
        overwrite_model: Not read by this method. Whether an existing model
            is loaded is decided by ``self.config.load_base_model`` and
            ``self.config.overwrite_model``.
            NOTE(review): confirm whether this argument was meant to
            replace ``self.config.overwrite_model``.
    """
    source_path = self.config.source_data_file
    target_path = self.config.target_data_file

    # convert input images into an hdf5 file
    if (overwrite_data
            or not os.path.exists(source_path)
            or not os.path.exists(target_path)):
        # We write two files, one with source samples and one with target
        # samples.
        (source_data_files, target_data_files,
         subject_ids_source, subject_ids_target) = self.fetch_training_data_files(
            return_subject_ids=True)

        if not os.path.exists(source_path) or overwrite_data:
            write_data_to_file(source_data_files,
                               source_path,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids_source)
        if not os.path.exists(target_path) or overwrite_data:
            write_data_to_file(target_data_files,
                               target_path,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids_target)
    else:
        print("Reusing previously written data file. Set overwrite_data to True to overwrite this file.")

    source_data = None
    target_data = None
    # Both handles are released in ``finally`` so that a failure while
    # opening the second file, building the model, training or evaluating
    # does not leave a data file open.
    try:
        source_data = open_data_file(source_path)
        target_data = open_data_file(target_path)

        # instantiate new model, compile = False because the compilation is
        # made in JDOT.py
        model, context_output_name = isensee2017_model(
            input_shape=self.config.input_shape,
            n_labels=self.config.n_labels,
            initial_learning_rate=self.config.initial_learning_rate,
            n_base_filters=self.config.n_base_filters,
            loss_function=self.config.loss_function,
            shortcut=self.config.shortcut,
            depth=self.config.depth,
            compile=False)

        # Without depth JDOT no context outputs are handed to the wrapper.
        if not self.config.depth_jdot:
            context_output_name = []

        jd = JDOT(model,
                  config=self.config,
                  source_data=source_data,
                  target_data=target_data,
                  context_output_name=context_output_name)

        if self.config.load_base_model:
            print("Loading trained model")
            jd.load_old_model(os.path.abspath("Data/saved_models/model_center_"+self.config.source_center)+".h5")
        elif not self.config.overwrite_model:
            jd.load_old_model(self.config.model_file)
        else:
            print("Creating new model, this will overwrite your old model")

        jd.compile_model()

        if self.config.train_jdot:
            jd.train_model(self.config.epochs)
        else:
            jd.train_model_on_source(self.config.epochs)

        jd.evaluate_model()
    finally:
        # Same close order as before: source first, then target.
        for handle in (source_data, target_data):
            if handle is not None:
                handle.close()
def main(overwrite=False):
    """Create or reload the model and train it from the npy-based generators.

    Args:
        overwrite: If true, a new model is instantiated even when
            ``config["model_file"]`` exists; the flag is also forwarded to
            ``get_training_and_validation_generators``.
    """
    if overwrite or not os.path.exists(config["model_file"]):
        # instantiate new model
        network = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"],
        )
    else:
        network = load_old_model(config["model_file"])

    # get training and validation generators
    generator_bundle = get_training_and_validation_generators(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],
        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        skip_blank=config["skip_blank"],
    )
    train_gen, val_gen, train_steps, val_steps = generator_bundle

    # run training
    train_model(
        model=network,
        model_file=config["model_file"],
        training_generator=train_gen,
        validation_generator=val_gen,
        steps_per_epoch=train_steps,
        validation_steps=val_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"],
    )
def main(overwrite=False):
    """Load or create the model and write a diagram of it to a PNG file.

    Args:
        overwrite: If true, a new model is instantiated even when
            ``config["model_file"]`` already exists.
    """
    # The data-file step is disabled in this variant:
    # # convert input images into an hdf5 file
    # # Load the old dataset if it exists; note that image_shape is then
    # # the one that was set previously.
    # if overwrite or not os.path.exists(config["data_file"]):
    #     training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
    #
    #     write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
    #                        subject_ids=subject_ids)
    # data_file_opened = open_data_file(config["data_file"])

    # Load / create the model file.
    if overwrite or not os.path.exists(config["model_file"]):
        # instantiate new model
        network = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"],
        )
    else:
        network = load_old_model(config["model_file"])

    # Imported here, after the model exists, exactly as the original did.
    from keras.utils.vis_utils import plot_model

    plot_model(network, to_file='isensee_unet.png', show_shapes=True)
def main(overwrite=False):
    """Prepare the data file, build or reload the model, train and time it.

    Args:
        overwrite: If true, the data file is rewritten and a new model is
            instantiated even when the files already exist. The flag is
            also passed to ``get_training_and_validation_generators``.
    """
    # NOTE: the unconditional ``pdb.set_trace()`` that used to sit here was
    # removed; it stopped every run at a debugger prompt.

    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           modality_names=config['all_modalities'],
                           subject_ids=subject_ids,
                           mean_std_file=config['mean_std_file'])

    data_file_opened = open_data_file(config["data_file"])
    # Close the data file even when anything below raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"],
            pred_specific=config['pred_specific'],
            overlap_label=config['overlap_label_generator'],
            for_final_val=config['for_final_val'])

        # run training
        time_0 = time.time()
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    logging_file=config['logging_file'])
        print('Training time:', sec2hms(time.time() - time_0))
    finally:
        data_file_opened.close()
def _init_new_model(model_name, depth_unet, n_base_filters_unet, loss):
    """Instantiate the network selected by ``model_name``.

    Any name other than "unet", "densefcn", "denseunet", "resunet" or
    "seunet" falls back to the isensee model.
    """
    if model_name == "unet":
        print("init unet model")
        return unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    elif model_name == "densefcn":
        print("init densenet model")
        return densefcn_model_3d(
            input_shape=config["input_shape"],
            classes=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            nb_dense_block=5,
            nb_layers_per_block=4,
            early_transition=True,
            dropout_rate=0.2,
            loss_function=loss)
    elif model_name == "denseunet":
        print("init denseunet model")
        return dense_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    elif model_name == "resunet":
        print("init resunet model")
        return res_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    # This used to be a separate ``if``; its ``else`` then replaced every
    # model built above with the isensee model. It is now part of the chain.
    elif model_name == "seunet":
        print("init seunet model")
        return se_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)
    else:
        print("init isensee model")
        return isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=loss)


def train(overwrite=True, crop=True, challenge="brats", year=2018,
          image_shape="160-160-128", is_bias_correction="1", is_normalize="z",
          is_denoise="0", is_hist_match="0", is_test="1",
          depth_unet=4, n_base_filters_unet=16, model_name="unet",
          patch_shape="128-128-128", is_crf="0",
          batch_size=1, loss="weighted"):
    """Resolve paths, prepare data, build or reload a model and train it.

    The data, split and model paths come from ``get_training_h5_paths`` and
    are stored in the module-level ``config`` together with the patch and
    input shapes. The data file is (re)built with ``prepare_data`` when
    ``overwrite`` is true or the file is missing.

    Args:
        overwrite: Rebuild the data file and start from a new model.
        model_name: One of "unet", "densefcn", "denseunet", "resunet" or
            "seunet"; anything else selects the isensee model.
        depth_unet, n_base_filters_unet: Passed to the U-Net style builders.
        patch_shape: String that ``get_shape_from_string`` turns into the
            patch shape.
        batch_size: Used for both training and validation batches.
        loss: Loss-function selector passed to the model builders.
        is_test: When equal to "0" a comet ``Experiment`` is created and the
            config is logged to it after training.
        crop, challenge, year, image_shape, is_bias_correction,
        is_normalize, is_denoise, is_hist_match, is_crf: Forwarded to
            ``get_training_h5_paths`` (and, except ``is_crf``, to
            ``prepare_data``).
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, overwrite=overwrite, crop=crop, challenge=challenge, year=year,
        image_shape=image_shape, is_bias_correction=is_bias_correction,
        is_normalize=is_normalize, is_denoise=is_denoise,
        is_hist_match=is_hist_match, is_test=is_test,
        model_name=model_name, depth_unet=depth_unet, n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape, is_crf=is_crf, loss=loss, model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple(
        [config["nb_channels"]] + list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite, crop=crop, challenge=challenge, year=year,
                     image_shape=image_shape, is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize, is_denoise=is_denoise,
                     is_hist_match=is_hist_match, is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Close the data file even when anything below raises.
    try:
        print_section("get training and testing generators")
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=batch_size,
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=loss)
        else:
            # instantiate new model
            model = _init_new_model(model_name, depth_unet,
                                    n_base_filters_unet, loss)

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)
        # run training

        if is_test == "0":
            # NOTE(review): hard-coded API credential; it should be moved
            # out of the source (environment/config) and rotated.
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="train",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()
def main(overwrite=False):
    """Prepare the data file, build or reload the model, dump it and train.

    Besides training, the model summary is written to
    ``isensemodel_original.txt`` and a diagram to
    ``isensemodel_original.png``.

    Args:
        overwrite: If true, the data file is rewritten and a new model is
            instantiated even when the files already exist. The flag is
            also passed to ``get_training_and_validation_generators``.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])
    # Close the data file even when anything below raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        with open('isensemodel_original.txt', 'w') as fh:
            # Pass the file handle in as a lambda function to make it callable
            model.summary(line_length=150,
                          print_fn=lambda x: fh.write(x + '\n'))

        # Save Model
        plot_model(model, to_file="isensemodel_original.png", show_shapes=True)

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
def main(config=None):
    """Prepare the data file, build or reload the configured model and train.

    Args:
        config: Mapping with the run settings. It is indexed immediately
            (``config['overwrite']``), so the ``None`` default is not a
            usable value. ``config["model_name"]`` selects the network
            ("3d_unet_residual" or "attention_unet") when a new model is
            created.

    Raises:
        ValueError: If a new model is required and ``config["model_name"]``
            is not one of the supported names.
    """
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])
    # Close the data file even when anything below raises (including the
    # invalid ``model_name`` error).
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            # if this happens, then the code wont care what is in
            # "model_name" in config because it will take whatever the
            # pre-trained was (either 3d_unet_residual or attention_unet)
            # to continue training. need to be careful with this.
            model = load_old_model(config, re_compile=False)
            model.summary()
        else:
            # instantiate new model
            if config["model_name"] == "3d_unet_residual":
                # 3D Unet Residual Model
                build_model = isensee2017_model
            elif config["model_name"] == "attention_unet":
                # Attention Unet Model
                build_model = attention_unet_model
            else:
                # Wrong entry for model_name
                raise ValueError(
                    'Look at field model_name in config.json! This field '
                    'can be either 3d_unet_residual or attention_unet.')

            model = build_model(input_shape=config["input_shape"],
                                n_labels=config["n_labels"],
                                n_base_filters=config["n_base_filters"],
                                activation_name='softmax')

            # Optimizer, loss and metrics are looked up by name from config.
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            validation_batch_size=config["validation_batch_size"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["optimizer"]["args"]["lr"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    model_best_path=config['model_best'])
    finally:
        data_file_opened.close()
def main(config=None):
    """Prepare the data file, build or reload the model and train it.

    Args:
        config: Mapping with the run settings. It is indexed immediately
            (``config['overwrite']``), so the ``None`` default is not a
            usable value.
    """
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           norm_type=config['normalization_type'])

    data_file_opened = open_data_file(config["data_file"])
    # Close the data file even when anything below raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config, re_compile=False)
        else:
            # instantiate new model
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      n_base_filters=config["n_base_filters"],
                                      activation_name='softmax')

            # Optimizer, loss and metrics are looked up by name from config.
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            validation_batch_size=config["validation_batch_size"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["optimizer"]["args"]["lr"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    model_best_path=config['model_best'])
    finally:
        data_file_opened.close()
def main(self, overwrite_data=True, overwrite_model=True):
    """Prepare the data file, build or reload the model and train it.

    Args:
        overwrite_data: If true, the HDF5 data file is rewritten even when
            it exists; the flag is also passed to
            ``get_training_and_validation_generators``.
        overwrite_model: If true, a new model is instantiated even when
            ``self.config.model_file`` exists.
    """
    # convert input images into an hdf5 file
    if overwrite_data or not os.path.exists(self.config.data_file):
        training_files, subject_ids = self.fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           self.config.data_file,
                           image_shape=self.config.image_shape,
                           subject_ids=subject_ids)
    else:
        print(
            "Reusing previously written data file. Set overwrite_data to True to overwrite this file."
        )

    data_file_opened = open_data_file(self.config.data_file)
    # Close the data file even when anything below raises.
    try:
        if not overwrite_model and os.path.exists(self.config.model_file):
            model = load_old_model(self.config.model_file)
        else:
            # instantiate new model; the second return value (context
            # output name) is not needed for this training path.
            model, _ = isensee2017_model(
                input_shape=self.config.input_shape,
                n_labels=self.config.n_labels,
                initial_learning_rate=self.config.initial_learning_rate,
                n_base_filters=self.config.n_base_filters,
                loss_function=self.config.loss_function,
                shortcut=self.config.shortcut)

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=self.config.batch_size,
            data_split=self.config.validation_split,
            overwrite_data=overwrite_data,
            validation_keys_file=self.config.validation_file,
            training_keys_file=self.config.training_file,
            n_labels=self.config.n_labels,
            labels=self.config.labels,
            patch_shape=self.config.patch_shape,
            validation_batch_size=self.config.validation_batch_size,
            validation_patch_overlap=self.config.validation_patch_overlap,
            training_patch_overlap=self.config.training_patch_overlap,
            training_patch_start_offset=self.config.training_patch_start_offset,
            permute=self.config.permute,
            augment=self.config.augment,
            skip_blank=self.config.skip_blank,
            augment_flip=self.config.flip,
            augment_distortion_factor=self.config.distort)

        # run training
        train_model(model=model,
                    model_file=self.config.model_file,
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=self.config.initial_learning_rate,
                    learning_rate_drop=self.config.learning_rate_drop,
                    learning_rate_patience=self.config.patience,
                    early_stopping_patience=self.config.early_stop,
                    n_epochs=self.config.epochs,
                    niseko=self.config.niseko)
    finally:
        data_file_opened.close()
def main(overwrite=False):
    """Prepare the data file, build or reload the model and train it.

    When an existing model is reloaded and ``config['freeze_encoder']`` is
    set, every layer before the one named ``'up_sampling3d_1'`` is marked
    non-trainable and the model is recompiled with Adam and the weighted
    dice coefficient loss.

    Args:
        overwrite: If true, the data file is rewritten and a new model is
            instantiated even when the files already exist. The flag is
            also passed to ``get_training_and_validation_generators``.

    Raises:
        ValueError: If freezing is requested and the model has no layer
            named ``'up_sampling3d_1'``.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])
    # Close the data file even when anything below raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
            if config['freeze_encoder']:
                layer_names = [layer.name for layer in model.layers]
                last_index = layer_names.index('up_sampling3d_1')
                for layer in model.layers[:last_index]:
                    layer.trainable = False

                from keras.optimizers import Adam
                from unet3d.model.isensee2017 import weighted_dice_coefficient_loss
                # Recompile after changing the ``trainable`` flags.
                model.compile(optimizer=Adam(lr=config['initial_learning_rate']),
                              loss=weighted_dice_coefficient_loss)
        else:
            # instantiate new model
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        model.summary()

        # get training and testing generators
        (train_generator, validation_generator,
         n_train_steps, n_validation_steps) = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment=config["augment"],
            skip_blank=config["skip_blank"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
def train_and_predict():
    """Load the train/validation arrays, fit the isensee model and save it.

    The arrays returned by the loaders are transposed with axis order
    ``(0, 4, 1, 2, 3)`` before fitting. Training logs go to a CSV file, the
    best model by ``val_loss`` is checkpointed to ``outputs/weights.h5``,
    and the final model is saved under ``outputs/<image_type>_model_last``.
    """
    rule = '-' * 30
    axis_order = (0, 4, 1, 2, 3)

    print(rule)
    print('Loading and preprocessing train data...')
    print(rule)
    imgs_train, imgs_gtruth_train = load_train_data()
    imgs_train = np.transpose(imgs_train, axis_order)
    imgs_gtruth_train = np.transpose(imgs_gtruth_train, axis_order)

    print(rule)
    print('Loading and preprocessing validation data...')
    print(rule)
    imgs_val, imgs_gtruth_val = load_validatation_data()
    imgs_val = np.transpose(imgs_val, axis_order)
    imgs_gtruth_val = np.transpose(imgs_gtruth_val, axis_order)

    print(rule)
    print('Creating and compiling model...')
    print(rule)

    # create a model
    model = isensee2017_model(
        input_shape=config["input_shape"],
        n_labels=config["n_labels"],
        initial_learning_rate=config["initial_learning_rate"],
        n_base_filters=config["n_base_filters"],
        loss_function=dice_coef_loss,
    )
    # summarize layers
    model.summary()

    print(rule)
    print('Fitting model...')
    print(rule)

    print('training starting..')
    csv_log = callbacks.CSVLogger(
        'outputs/' + image_type + '_model_train.csv',
        separator=',',
        append=True,
    )
    checkpoint = callbacks.ModelCheckpoint(
        'outputs/' + 'weights.h5',
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        mode='min',
    )
    # Callbacks are constructed in the same order as before.
    callbacks_list = [
        csv_log,
        checkpoint,
        ReduceLROnPlateau(factor=config["learning_rate_drop"],
                          patience=config["patience"],
                          verbose=True),
        EarlyStopping(verbose=True, patience=config["early_stop"]),
    ]

    # NOTE(review): ``nb_epoch`` looks like the legacy Keras keyword (newer
    # releases use ``epochs``) - confirm against the installed version.
    model.fit(imgs_train,
              imgs_gtruth_train,
              batch_size=config["batch_size"],
              nb_epoch=config["n_epochs"],
              verbose=1,
              validation_data=(imgs_val, imgs_gtruth_val),
              shuffle=True,
              callbacks=callbacks_list)

    model.save('outputs/' + image_type + '_model_last')