Example 1
def run_validation_cases(validation_keys_file,
                         model_file,
                         training_modalities,
                         labels,
                         hdf5_file,
                         output_label_map=False,
                         output_dir=".",
                         threshold=0.5,
                         overlap=16,
                         permute=False,
                         isDeeper=False,
                         depth=4):
    validation_indices = pickle_load(validation_keys_file)

    if not isDeeper:
        # Recreate the unet_model_3d from config instead of calling
        # load_old_model(model_file); the trained weights are loaded below.
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])
    else:
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth)

    # Load the trained weights into the model
    model.load_weights(model_file)
    data_file = tables.open_file(hdf5_file, "r")
    for index in validation_indices:
        if 'subject_ids' in data_file.root:
            case_directory = os.path.join(
                output_dir, data_file.root.subject_ids[index].decode('utf-8'))
        else:
            case_directory = os.path.join(output_dir,
                                          "validation_case_{}".format(index))
        run_validation_case(data_index=index,
                            output_dir=case_directory,
                            model=model,
                            data_file=data_file,
                            training_modalities=training_modalities,
                            output_label_map=output_label_map,
                            labels=labels,
                            threshold=threshold,
                            overlap=overlap,
                            permute=permute)
    data_file.close()
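
A minimal driver sketch for run_validation_cases(). The config keys and the "t1"/"t2" modality names below are illustrative assumptions, not part of the original example:

run_validation_cases(validation_keys_file=config["validation_file"],  # assumed key
                     model_file=config["model_file"],                 # assumed key
                     training_modalities=["t1", "t2"],                # assumed names
                     labels=config["labels"],
                     hdf5_file=config["data_file"],
                     output_label_map=True,
                     output_dir="./prediction",
                     isDeeper=True,
                     depth=5)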
Example 2
def main():
    if not os.path.exists(config["hdf5_file"]):
        training_files = list()
        for label_file in glob.glob("./data/training/subject-*-label.hdr"):
            training_files.append((label_file.replace("label", "T1"), label_file.replace("label", "T2"), label_file))

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"], n_labels=config["n_labels"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"],
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"], labels=config["labels"], augment=True)

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
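
A variant worth noting: the open/close pair above leaks the HDF5 handle if model building or training raises. A minimal sketch of the same flow with the handle managed by a with statement (PyTables File objects support the context manager protocol since PyTables 3.0):

import tables

def main_safe():
    # same flow as main() above, but the file is closed even on exceptions
    with tables.open_file(config["hdf5_file"], "r") as hdf5_file_opened:
        model = unet_model_3d(input_shape=config["input_shape"],
                              n_labels=config["n_labels"])
        ...  # build the generators and call train_model() exactly as in main()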
Example 3
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["hdf5_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])
    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"],
                              downsize_filters_factor=config["downsize_nb_filters_factor"],
                              pool_size=config["pool_size"], n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"], overwrite=overwrite,
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"])

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
Example 4
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print("Number of Training file Found:", len(training_files))
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    print("Opening data file.")
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading existing model file.")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Instantiating new model file.")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"])

    # get training and testing generators
    print("Getting training and testing generators.")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print("Running the training......")
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("Training DONE")
Example 5
    def test_batch_normalization(self):
        model = unet_model_3d(input_shape=(1, 16, 16, 16), depth=2, deconvolution=True, metrics=[], n_labels=1,
                              batch_normalization=True)

        layer_names = [layer.name for layer in model.layers]

        for name in layer_names[:-3]:  # the final convolution and activation have no batch norm
            if 'conv3d' in name and 'transpose' not in name:
                self.assertIn(name.replace('conv3d', 'batch_normalization'), layer_names)
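
The assertion relies on Keras's automatic layer naming: when each Conv3D is created immediately followed by its BatchNormalization, the auto-generated numeric suffixes line up (conv3d_2 pairs with batch_normalization_2), which is exactly what the test checks. A quick way to inspect the pairing by hand (a sketch, assuming the same unet_model_3d import as the test):

model = unet_model_3d(input_shape=(1, 16, 16, 16), depth=2, deconvolution=True,
                      metrics=[], n_labels=1, batch_normalization=True)
for layer in model.layers:
    # print conv and batch-norm layers together to see the shared suffixes
    if 'conv3d' in layer.name or 'batch_normalization' in layer.name:
        print(layer.name)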
Example 6
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distortion_factor"],
        augment_rotation_factor=config["rotation_factor"],
        mirror=config["mirror"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_path=config["logging_path"])
    data_file_opened.close()
Example 7
def main(overwrite=False):
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            pool_size=config["pool_size"],
            deconvolution=config["deconvolution"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and validation generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],
        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        skip_blank=config["skip_blank"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
Example 8
def train(ibis_data, input_shape=(96, 96, 96), batch_size=2, data_split=0.8, num_gpus=None, only_aa=False):
    input_shape = input_shape + (21,) if only_aa else input_shape + (59,)
    model = unet_model_3d(input_shape=input_shape, num_gpus=num_gpus)
    if num_gpus is not None and num_gpus > 1:
        batch_size *= num_gpus
    # avoid shadowing the enclosing function name with the generator pair
    train_data, validate_data = IBISGenerator.get_training_and_validation(
        ibis_data, input_shape=input_shape, batch_size=batch_size, only_aa=only_aa)

    train_model_generator(
        model=model,
        model_file=os.path.abspath("./molmimic_{}.h5".format(
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))),
        training_generator=train_data.generate(),
        validation_generator=validate_data.generate(),
        steps_per_epoch=train_data.steps_per_epoch,
        validation_steps=validate_data.steps_per_epoch,
        initial_learning_rate=0.001,
        learning_rate_drop=0.6,
        learning_rate_epochs=10,
        n_epochs=200)
Example 9
def main(overwrite=False):
    # # convert input images into an hdf5 file
    # if overwrite or not os.path.exists(config["data_file"]):
    #     training_files = fetch_training_data_files()
    #
    #     write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])
    # data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='original_unet.png', show_shapes=True)
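
Note that plot_model only works when the optional pydot package and the Graphviz binaries are installed; a guard like the following sketch (not part of the original) gives a clearer message than the raw error:

try:
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='original_unet.png', show_shapes=True)
except (ImportError, OSError):
    print("plot_model needs pydot and graphviz: pip install pydot, "
          "plus a system Graphviz package")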
Example 10
def main(overwrite=False):
    args = get_args.train()
    overwrite = args.overwrite

    # config["data_file"] = get_brats_data_h5_path(args.challenge, args.year,
    #                                              args.inputshape, args.isbiascorrection,
    #                                              args.normalization, args.clahe,
    #                                              args.histmatch)

    # print(config["data_file"])

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators_new(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_steps_file=config["n_steps_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        is_create_patch_index_list_original=config["is_create_patch_index_list_original"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"])

    print("-"*60)
    print("# Load or init model")
    print("-"*60)
    if not overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("init model model")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"],
                              depth=config["depth"],
                              n_base_filters=config["n_base_filters"])

    # model.summary()

    # import nibabel as nib
    # laptop_save_dir = "C:/Users/minhm/Desktop/temp/"
    # desktop_save_dir = "/home/minhvu/Desktop/temp/"
    # save_dir = desktop_save_dir
    # temp_in_path = desktop_save_dir + "template.nii.gz"
    # temp_out_path = desktop_save_dir + "out.nii.gz"
    # temp_out_truth_path = desktop_save_dir + "truth.nii.gz"

    # n_validation_samples = 0
    # validation_samples = list()
    # for i in range(20):
    #     print(i)
    #     x, y = next(train_generator)
    #     hash_x = hash(str(x))
    #     validation_samples.append(hash_x)
    #     n_validation_samples += x.shape[0]

    #     temp_in = nib.load(temp_in_path)
    #     temp_out = nib.Nifti1Image(x[0][0], affine=temp_in.affine)
    #     nib.save(temp_out, temp_out_path)

    #     temp_out = nib.Nifti1Image(y[0][0], affine=temp_in.affine)
    #     nib.save(temp_out, temp_out_truth_path)

    # print(n_validation_samples)

    print("-"*60)
    print("# start training")
    print("-"*60)
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example 11
def train(overwrite=True,
          crop=True,
          challenge="brats",
          year=2018,
          image_shape="160-160-128",
          is_bias_correction="1",
          is_normalize="z",
          is_denoise="0",
          is_hist_match="0",
          is_test="1",
          depth_unet=4,
          n_base_filters_unet=16,
          model_name="unet",
          patch_shape="128-128-128",
          is_crf="0",
          batch_size=1,
          loss="weighted"):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite,
                     crop=crop,
                     challenge=challenge,
                     year=year,
                     image_shape=image_shape,
                     is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize,
                     is_denoise=is_denoise,
                     is_hist_match=is_hist_match,
                     is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
        data_file_opened,
        batch_size=batch_size,
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"])

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    if not overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        model = generate_model(config["model_file"], loss_function=loss)
        # model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        if model_name == "unet":
            print("init unet model")
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        elif model_name == "densefcn":
            print("init densenet model")
            # config["initial_learning_rate"] = 1e-5
            model = densefcn_model_3d(
                input_shape=config["input_shape"],
                classes=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                nb_dense_block=5,
                nb_layers_per_block=4,
                early_transition=True,
                dropout_rate=0.2,
                loss_function=loss)

        elif model_name == "denseunet":
            print("init denseunet model")
            model = dense_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        elif model_name == "resunet":
            print("init resunet model")
            model = res_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        elif model_name == "seunet":
            print("init seunet model")
            model = se_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        else:
            print("init isensee model")
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=loss)

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training

    if is_test == "0":
        experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                project_name="train",
                                workspace="vuhoangminh")
    else:
        experiment = None

    print(config["initial_learning_rate"], config["learning_rate_drop"])
    print("data file:", config["data_file"])
    print("model file:", config["model_file"])
    print("training file:", config["training_file"])
    print("validation file:", config["validation_file"])
    print("testing file:", config["testing_file"])

    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
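
The if/elif chain above (note the seunet branch, which must be elif so the trailing else cannot overwrite an already-built model) can also be written as a lookup table. A sketch of that alternative; the unet_like/builders names are hypothetical, the constructor arguments are the ones used above:

unet_like = dict(input_shape=config["input_shape"],
                 pool_size=config["pool_size"],
                 n_labels=config["n_labels"],
                 initial_learning_rate=config["initial_learning_rate"],
                 deconvolution=config["deconvolution"],
                 depth=depth_unet,
                 n_base_filters=n_base_filters_unet,
                 loss_function=loss)
builders = {
    # the unet-style constructors share one keyword signature
    "unet": lambda: unet_model_3d(**unet_like),
    "denseunet": lambda: dense_unet_3d(**unet_like),
    "resunet": lambda: res_unet_3d(**unet_like),
    "seunet": lambda: se_unet_3d(**unet_like),
    "densefcn": lambda: densefcn_model_3d(
        input_shape=config["input_shape"],
        classes=config["n_labels"],
        initial_learning_rate=config["initial_learning_rate"],
        nb_dense_block=5, nb_layers_per_block=4,
        early_transition=True, dropout_rate=0.2,
        loss_function=loss),
}
# fall back to the isensee model for any unknown name, as the else branch does
default_builder = lambda: isensee2017_model(
    input_shape=config["input_shape"],
    n_labels=config["n_labels"],
    initial_learning_rate=config["initial_learning_rate"],
    loss_function=loss)
model = builders.get(model_name, default_builder)()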
Example 12
def main(overwrite_data=False, overwrite_model=False):
    # run if the data not already stored hdf5
    if overwrite_data or not os.path.exists(config["data_file"]):
        _save_new_h5_datafile(config["data_file"],
                              new_image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite_model and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model

        print('initializing new isensee model with input shape',
              config['input_shape'])
        '''
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])
        '''
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    (train_generator, validation_generator, n_train_steps,
     n_validation_steps) = get_training_and_validation_generators(
         data_file_opened,
         batch_size=config["batch_size"],
         data_split=config["validation_split"],
         overwrite=overwrite_data,
         validation_keys_file=config["validation_file"],
         training_keys_file=config["training_file"],
         n_labels=config["n_labels"],
         labels=config["labels"],
         patch_shape=config["patch_shape"],
         validation_batch_size=config["validation_batch_size"],
         validation_patch_overlap=config["validation_patch_overlap"],
         training_patch_start_offset=config["training_patch_start_offset"],
         permute=config["permute"],
         augment=config["augment"],
         skip_blank=config["skip_blank"],
         augment_flip=config["flip"],
         augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example 13
import os
import glob

from unet3d.model import unet_model_3d
from data_generator import DataGenerator

DATA_LEN = 5
IMG_SIZE = 224
N_CHANNEL = 16
DATAPATH = '/home/trungdunghoang/Documents/EPFL/3DUnetCNN/data_test'

model = unet_model_3d(input_shape=(DATA_LEN, IMG_SIZE, IMG_SIZE, N_CHANNEL))

train_generator = DataGenerator(DATAPATH)
model.fit_generator(generator=train_generator,
                    steps_per_epoch=len(train_generator),
                    epochs=10,
                    validation_data=train_generator,
                    validation_steps=len(train_generator))
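
fit_generator expects DataGenerator to behave like a keras.utils.Sequence (or a plain generator); note also that the script validates on the same generator it trains on, so the reported validation loss is not an independent estimate. A minimal Sequence skeleton satisfying the interface (a sketch: the class body is hypothetical, only the array shapes follow the input_shape above):

import numpy as np
from keras.utils import Sequence

class DataGeneratorSketch(Sequence):
    """Hypothetical skeleton of the DataGenerator used above."""

    def __init__(self, data_path, batch_size=1, n_samples=10):
        self.data_path = data_path
        self.batch_size = batch_size
        self.n_samples = n_samples  # in practice, count the files on disk

    def __len__(self):
        # number of batches per epoch; used by len(train_generator) above
        return self.n_samples // self.batch_size

    def __getitem__(self, index):
        # load a real batch from self.data_path here; random data stands in
        x = np.random.rand(self.batch_size, DATA_LEN, IMG_SIZE, IMG_SIZE, N_CHANNEL)
        y = np.random.rand(self.batch_size, 1, IMG_SIZE, IMG_SIZE, N_CHANNEL)  # assumed label shape
        return x, y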
Example 14
def main(overwrite=False):
    # convert input images into an hdf5 file
    print("rebuild data file:", overwrite or not os.path.exists(config["data_file"]))
    print("data file exists:", os.path.exists(config["data_file"]))
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(
            training_files,
            config["data_file"],
            image_shape=config["image_shape"])  # optionally: normalize=False
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])
        print(model.summary())
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=False,  # kept False so the training indices are reused as
        # before; they were already used for normalization in
        # write_data_to_file (above)
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # normalize the dataset if required, using only the
    # training images (training_keys_file)
    fetch_training_data_files()

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example 15
import os
from unet3d.model import unet_model_3d
from unet3d.training import load_old_model, train_model
from unet3d.generator import ExampleSphereGenerator

from keras import backend as K
K.set_image_dim_ordering('tf')
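# NOTE: set_image_dim_ordering is the Keras 1.x API; on Keras 2 the
# equivalent call is K.set_image_data_format('channels_last').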

input_shape = (96, 96, 96, 1)
model = unet_model_3d(input_shape=input_shape)

train_generator, validation_generator = ExampleSphereGenerator.get_training_and_validation(
    input_shape, cnt=5, border=10, batch_size=20, n_samples=500)

train_model(model=model,
            model_file=os.path.abspath("./SphereCNN.h5"),
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=train_generator.num_steps,
            validation_steps=validation_generator.num_steps,
            initial_learning_rate=0.00001,
            learning_rate_drop=0.5,
            learning_rate_epochs=10,
            n_epochs=50)