Example #1
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"]
                           )  #config["image_shape"] = (144, 144, 144)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        #print('start model computing')
        # model = unet_model_3d(input_shape=config["input_shape"],                # 4+(32, 32, 32)
        #                       pool_size=config["pool_size"],                    #config["pool_size"] = (2, 2, 2), maxpooling size
        #                       n_labels=config["n_labels"],                      #config["n_labels"] = len(config["labels"])
        #                       initial_learning_rate=config["initial_learning_rate"],        #config["initial_learning_rate"] = 0.00001
        #                       deconvolution=config["deconvolution"])                        #config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

        model = custom_unet(reu2018)
    #print('model loaded')
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],  #config["batch_size"] = 6
        data_split=config["validation_split"],  #validation_split = 0.8
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print('training model')
    train_model(
        model=model,
        model_file=config["model_file"],  # config["model_file"] = os.path.abspath("tumor_segmentation_model.h5")
        training_generator=train_generator,
        validation_generator=validation_generator,
        steps_per_epoch=n_train_steps,
        validation_steps=n_validation_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("model has been trained already")
Example #2
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print(training_files)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    # get training and testing generators - generate pickle files containing IDs
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    data_file_opened.close()
Example #3
def main():
    kwargs = vars(parse_args())
    prediction_dir = os.path.abspath(kwargs.pop("prediction_dir"))
    output_label_map = not kwargs.pop("no_label_map")
    for key, value in kwargs.items():
        if value:
            if key == "modalities":
                config["training_modalities"] = value
            else:
                config[key] = value
    filenames, subject_ids = fetch_brats_2020_files(config["training_modalities"], group="Validation",
                                                    include_truth=False, return_subject_ids=True)
    if not os.path.exists(config["data_file"]):
        write_data_to_file(filenames, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids, save_truth=False)
        pickle_dump(list(range(len(subject_ids))), config["validation_file"])

    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=output_label_map,
                         output_dir=prediction_dir,
                         test=False,
                         output_basename=kwargs["output_basename"],
                         permute=config["permute"])
    for filename_list, subject_id in zip(filenames, subject_ids):
        prediction_filename = os.path.join(prediction_dir, kwargs["output_basename"].format(subject=subject_id))
        print("Resampling:", prediction_filename)
        ref = nib.load(filename_list[0])
        pred = nib.load(prediction_filename)
        pred_resampled = resample_to_img(pred, ref, interpolation="nearest")
        pred_resampled.to_filename(prediction_filename)
Example #4
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["hdf5_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])
    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"],
                              downsize_filters_factor=config["downsize_nb_filters_factor"],
                              pool_size=config["pool_size"], n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"], overwrite=overwrite,
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"])

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
Example #5
def main():
    if not os.path.exists(config["hdf5_file"]):
        training_files = list()
        for label_file in glob.glob("./data/training/subject-*-label.hdr"):
            training_files.append((label_file.replace("label", "T1"), label_file.replace("label", "T2"), label_file))

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"], n_labels=config["n_labels"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"],
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"], labels=config["labels"], augment=True)

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
Example #6
def main(args):

    prediction_dir = os.path.abspath("./headneck/prediction_test/" +
                                     args.organ.lower())

    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir)

    test_data_files, subject_ids = fetch_test_data_files(
        return_subject_ids=True)

    if not os.path.exists(config["data_file"]):
        write_data_to_file(test_data_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    if not os.path.exists(config["test_file"]):
        test_list = list(range(len(subject_ids)))
        pickle_dump(test_list, config["test_file"])

    run_validation_cases(validation_keys_file=config["test_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=True,
                         output_dir=prediction_dir)

    header = ("Background", "Organ")
    masking_functions = (get_background_mask, get_organ_mask)
    rows = list()

    prediction_path = "./headneck/prediction_test/" + args.organ.lower() + "/"

    for case_folder in glob.glob(prediction_path + "*/"):
        truth_file = os.path.join(case_folder, "truth.nii.gz")
        truth_image = nib.load(truth_file)
        truth = truth_image.get_data()
        prediction_file = os.path.join(case_folder, "prediction.nii.gz")
        prediction_image = nib.load(prediction_file)
        prediction = prediction_image.get_data()
        rows.append([
            dice_coefficient(func(truth), func(prediction))
            for func in masking_functions
        ])

    df = pd.DataFrame.from_records(rows, columns=header)
    df.to_csv(prediction_path + "headneck_scores.csv")

    scores = dict()
    for index, score in enumerate(df.columns):
        values = df.values.T[index]
        scores[score] = values[~np.isnan(values)]

    plt.boxplot(list(scores.values()), labels=list(scores.keys()))
    plt.ylabel("Dice Coefficient")
    plt.savefig(prediction_path + "test_scores_boxplot.png")
    plt.close()
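
Example #6 scores each case by calling a dice_coefficient helper that is not shown. Here is a minimal sketch using the standard Dice overlap 2*|A∩B| / (|A| + |B|) on boolean masks; the name matches the call above, but the exact implementation (and how it treats two empty masks) is an assumption.

import numpy as np

def dice_coefficient(truth, prediction):
    # Standard Dice overlap between two binary masks (sketch).
    truth = np.asarray(truth, dtype=bool)
    prediction = np.asarray(prediction, dtype=bool)
    denominator = truth.sum() + prediction.sum()
    if denominator == 0:
        return np.nan  # assumption: treat Dice as undefined when both masks are empty
    return 2.0 * np.logical_and(truth, prediction).sum() / denominator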
Example #7
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading old model file from the location: ",
              config["model_file"])
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Creating new model at the location: ", config["model_file"])
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print("Running the Training. Model file:", config["model_file"])
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #8
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print("Number of Training file Found:", len(training_files))
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    print("Opening data file.")
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading existing model file.")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Instantiating new model file.")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"])

    # get training and testing generators
    print("Getting training and testing generators.")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print("Running the training......")
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("Training DONE")
Example #9
def main(overwrite=False):
    # convert input images into an hdf5 file
    # If the data file already exists, reuse it; note that image_shape is then whatever was set when the file was written
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # load an existing model file or create a new one
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = att_res_ds_unet.att_res_ds_unet_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                                      initial_learning_rate=config["initial_learning_rate"],
                                                      n_base_filters=config["n_base_filters"])
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='att_res_ds_uet.png', show_shapes=True)

    # get training and testing generators
    # ../unet3d/generator.py
    # create the generators used for training below
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    # ../unet3d/training.py
    # train a Keras model
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #10
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distortion_factor"],
        augment_rotation_factor=config["rotation_factor"],
        mirror=config["mirror"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_path=config["logging_path"])
    data_file_opened.close()
Example #11
def main(overwrite=True):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
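
Example #11 stops after writing the HDF5 file. As Examples #4 and #5 show, the result can be opened directly with PyTables; below is a quick sketch for inspecting what was written. The file name and the node names "data" and "truth" are assumptions about the layout, not taken from the source.

import tables

# Sketch: open the file produced by write_data_to_file and look at its contents.
with tables.open_file("training_data.h5", "r") as data_file:  # placeholder path
    print(data_file)                   # prints the full node tree
    print(data_file.root.data.shape)   # assumed layout: (n_subjects, n_channels, x, y, z)
    print(data_file.root.truth.shape)  # assumed layout for the ground-truth volumes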
Example #12
def _save_new_h5_datafile(data_file_h5, new_image_shape):
    training_files = _fetch_training_data_files('private')

    # write all the data files into the hdf5 file;
    # if necessary, crop the data to the new dimensions (if smaller than the original)
    # or pad it with zeros (if larger than the original)
    write_data_to_file(training_files,
                       data_file_h5,
                       image_shape=new_image_shape)
Example #13
def main():
    # convert_brats_data(pre_config["data_original"], pre_config["data_preprocessed"], overwrite=False)

    subject_ids, training_files = fetch_training_data_files(
        pre_config["data_preprocessed"])
    pickle_dump(subject_ids, pre_config["subject_ids_file"])

    write_data_to_file(pre_config["npy_path"],
                       subject_ids,
                       training_files,
                       image_shape=pre_config["image_shape"])
Example #14
    def main(self, overwrite_data=True, overwrite_model=True):
        # convert input images into an hdf5 file
        if overwrite_data or not os.path.exists(self.config.source_data_file) or not os.path.exists(self.config.target_data_file):
            '''
            We write two files, one with source samples and one with target samples.
            '''
            source_data_files, target_data_files, subject_ids_source, subject_ids_target = self.fetch_training_data_files(return_subject_ids=True)

            if not os.path.exists(self.config.source_data_file) or overwrite_data:
                write_data_to_file(source_data_files, self.config.source_data_file, image_shape=self.config.image_shape,
                               subject_ids=subject_ids_source)
            if not os.path.exists(self.config.target_data_file) or overwrite_data:
                write_data_to_file(target_data_files, self.config.target_data_file, image_shape=self.config.image_shape,
                               subject_ids=subject_ids_target)
        else:
            print("Reusing previously written data file. Set overwrite_data to True to overwrite this file.")

        source_data = open_data_file(self.config.source_data_file)
        target_data = open_data_file(self.config.target_data_file)


        # instantiate new model, compile = False because the compilation is made in JDOT.py

        model, context_output_name = isensee2017_model(input_shape=self.config.input_shape, n_labels=self.config.n_labels,
                                      initial_learning_rate=self.config.initial_learning_rate,
                                      n_base_filters=self.config.n_base_filters,
                                      loss_function=self.config.loss_function,
                                      shortcut=self.config.shortcut,
                                      depth=self.config.depth,
                                      compile=False)
        # get training and testing generators
        if not self.config.depth_jdot:
            context_output_name = []
        jd = JDOT(model, config=self.config, source_data=source_data, target_data=target_data, context_output_name=context_output_name)
        # m = jd.load_old_model(self.config.model_file)
        # print(m)
        if self.config.load_base_model:
            print("Loading trained model")
            jd.load_old_model(os.path.abspath("Data/saved_models/model_center_"+self.config.source_center)+".h5")
        elif not self.config.overwrite_model:
            jd.load_old_model(self.config.model_file)
        else:
            print("Creating new model, this will overwrite your old model")
        jd.compile_model()
        if self.config.train_jdot:
            jd.train_model(self.config.epochs)
        else:
            jd.train_model_on_source(self.config.epochs)
        jd.evaluate_model()

        source_data.close()
        target_data.close()
Example #15
def main():
    kwargs = vars(parse_args())
    prediction_dir = os.path.abspath(kwargs.pop("prediction_dir"))
    output_label_map = not kwargs.pop("no_label_map")
    for key, value in kwargs.items():
        if value:
            if key == "modalities":
                config["training_modalities"] = value
            else:
                config[key] = value

    validate_path = kwargs["validate_path"]
    subject_ids = list()
    filenames = list()
    blacklist = []
    for root, dirs, files in os.walk(validate_path):
        for f in files:
            subject_id = f.split('.')[0]
            if subject_id not in blacklist:
                subject_ids.append(subject_id)
                subject_files = list()
                subject_files.append(validate_path + '/' + f)
                filenames.append(tuple(subject_files))

    if not os.path.exists(config["data_file"]):
        write_data_to_file(filenames,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           save_truth=False)
        pickle_dump(list(range(len(subject_ids))), config["validation_file"])

    run_validation_cases(validation_keys_file=config["validation_file"],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["data_file"],
                         output_label_map=output_label_map,
                         output_dir=prediction_dir,
                         test=False,
                         output_basename=kwargs["output_basename"],
                         permute=config["permute"])
    for filename_list, subject_id in zip(filenames, subject_ids):
        prediction_filename = os.path.join(
            prediction_dir,
            kwargs["output_basename"].format(subject=subject_id))
        print("Resampling:", prediction_filename)
        ref = nib.load(filename_list[0])
        pred = nib.load(prediction_filename)
        pred_resampled = resample_to_img(pred, ref, interpolation="nearest")
        pred_resampled.to_filename(prediction_filename)
Example #16
    def main(self, overwrite_data=True):
        self.config.validation_split = 0.0
        self.config.data_file = os.path.abspath("Data/generated_data/" +
                                                self.config.data_set +
                                                "_testing.h5")
        self.config.training_file = os.path.abspath("Data/generated_data/" +
                                                    self.config.data_set +
                                                    "_testing.pkl")
        self.config.validation_file = os.path.abspath(
            "Data/generated_data/" + self.config.data_set +
            "_testing_validation_ids.pkl")
        # convert input images into an hdf5 file
        if overwrite_data or not os.path.exists(self.config.data_file):
            testing_files, subject_ids = self.fetch_testing_data_files(
                return_subject_ids=True)
            write_data_to_file(testing_files,
                               self.config.data_file,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids)
        data_file_opened = open_data_file(self.config.data_file)
        testing_split, _ = get_validation_split(
            data_file_opened,
            data_split=0,
            overwrite_data=self.config.overwrite_data,
            training_file=self.config.training_file,
            validation_file=self.config.validation_file)
        # get training and testing generators
        # train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        #     data_file_opened,
        #     batch_size=self.config.batch_size,
        #     data_split=self.config.validation_split,
        #     overwrite_data=overwrite_data,
        #     validation_keys_file=self.config.validation_file,
        #     training_keys_file=self.config.training_file,
        #     n_labels=self.config.n_labels,
        #     labels=self.config.labels,
        #     patch_shape=self.config.patch_shape,
        #     validation_batch_size=self.config.validation_batch_size,
        #     validation_patch_overlap=self.config.validation_patch_overlap,
        #     training_patch_start_offset=self.config.training_patch_start_offset,
        #     permute=self.config.permute,
        #     augment=self.config.augment,
        #     skip_blank=self.config.skip_blank,
        #     augment_flip=self.config.flip,
        #     augment_distortion_factor=self.config.distort)

        data_file_opened.close()
Example #17
    def fetch_training_data_files(self):
        data_files = list()
        for subject_dir in glob.glob(
                os.path.join(os.path.dirname(__file__),
                             "../Data/data_" + self.config.data_set,
                             "training", "*")):
            # Retrieve the patient's center for the MICCAI16 data set
            subject_center = subject_dir[-9:-7]

            self.ids.append(os.path.basename(subject_dir))
            subject_files = list()
            # Another option? Use "/ManualSegmentation/" for miccai16
            for modality in self.config.training_modalities + ["./" + self.config.GT]:
                # + "/Preprocessed/" for miccai16
                subject_files.append(os.path.join(subject_dir, modality + ".nii.gz"))
            data_files.append(tuple(subject_files))
        write_data_to_file(data_files,
                           self.config.source_data_file,
                           image_shape=self.config.image_shape,
                           subject_ids=self.ids)
Example #18
def prepare_data(args):

    data_dir = get_h5_training_dir(BRATS_DIR, "data")

    # make dir
    if not os.path.exists(data_dir):
        print_separator()
        print("making dir", data_dir)
        os.makedirs(data_dir)

    print_section("convert input images into an hdf5 file")

    data_filename = get_training_h5_filename(datatype="data", args=args)

    print(data_filename)

    data_file_path = os.path.join(data_dir, data_filename)

    print("save to", data_file_path)

    dataset = get_dataset(
        is_test=args.is_test, is_bias_correction=args.is_bias_correction, is_denoise=args.is_denoise)

    print("reading folder:", dataset)

    if args.overwrite or not os.path.exists(data_file_path):
        training_files = fetch_training_data_files(dataset)
        write_data_to_file(training_files, data_file_path,
                           config=config,
                           image_shape=get_shape_from_string(args.image_shape),
                           brats_dir=BRATS_DIR,
                           crop=args.crop,
                           is_normalize=args.is_normalize,
                           is_hist_match=args.is_hist_match,
                           dataset=dataset,
                           is_denoise=args.is_denoise)
Example #19
def main(overwrite=False):
    # convert input images into an hdf5 file
    print(overwrite or not os.path.exists(config["data_file"]))
    print('path: ', os.path.exists(config["data_file"]))
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        # try:
        write_data_to_file(
            training_files,
            config["data_file"],
            image_shape=config["image_shape"])  #, normalize=False)
        # except:
        #    import pdb; pdb.set_trace()
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])
        print(model.summary())
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=False,  # overwrite, # set to False so that the same training indices
        # are used as before; they are already used for the
        # normalization in write_data_to_file (above)
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # normalize the dataset if required
    # use only the training img (training_keys_file)
    fetch_training_data_files()

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #20
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    with open('isensemodel_original.txt', 'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # Save Model
    plot_model(model, to_file="isensemodel_original.png", show_shapes=True)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #21
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
        # new_model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
        #                              initial_learning_rate=config["initial_learning_rate"],
        #                              n_base_filters=config["n_base_filters"])

        if config['freeze_encoder']:
            last_index = list(layer.name for layer in model.layers) \
                .index('up_sampling3d_1')
            for layer in model.layers[:last_index]:
                layer.trainable = False
            from keras.optimizers import Adam
            from unet3d.model.isensee2017 import weighted_dice_coefficient_loss
            model.compile(optimizer=Adam(lr=config['initial_learning_rate']), loss=weighted_dice_coefficient_loss)
        # for new_layer, layer in zip(new_model.layers[1:], old_model.layers[1:]):
        #     assert new_layer.name == layer.name
        #     new_layer.set_weights(layer.get_weights())
        # model = new_model
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment=config["augment"],
        skip_blank=config["skip_blank"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #22
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        # If this happens, the code won't care what "model_name" says in the config, because it will continue
        # training whatever the pre-trained model was (either 3d_unet_residual or attention_unet). Be careful
        # with this.
        model = load_old_model(config, re_compile=False)
        model.summary()
        # visualize_filters_shape(model)
    else:
        # instantiate new model
        if (config["model_name"] == "3d_unet_residual"):
            """3D Unet Residual Model"""
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      n_base_filters=config["n_base_filters"],
                                      activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        elif (config["model_name"] == "attention_unet"):
            """Attention Unet Model"""
            model = attention_unet_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                n_base_filters=config["n_base_filters"],
                activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        else:
            """Wrong entry for model_name"""
            raise Exception(
                'Check the model_name field in config.json! It must be either 3d_unet_residual or attention_unet.'
            )

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
Example #23
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,
                                  n_labels=config["n_labels"],
                                  GPU=config["GPU"])
    else:
        # instantiate new model
        base_model, model = unet_model_3d_multiGPU(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            GPU=config["GPU"])
    # Save Model
    plot_model(base_model,
               to_file="liver_segmentation_model_581_resize_1GPU.png",
               show_shapes=True)
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"] * config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"] * config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print('INFO: Training Details', '\n Batch Size : ',
          config["batch_size"] * config["GPU"], '\n Epoch Size : ',
          config["n_epochs"])

    # For debugging ONLY
    # n_train_steps = 10

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                base_model=base_model)
    data_file_opened.close()
Example #24
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,n_labels=config["n_labels"],GPU=config["GPU"])
    else:
        # instantiate new model: Hybrid Dense-UNet model from the H-DenseUNet project
        parser = argparse.ArgumentParser(description='Keras DenseUnet Training')
        parser.add_argument('-b', type=int, default=1)  # config["batch_size"]
        parser.add_argument('-input_size', type=int, default=config["patch_shape"][0])  # 224
        parser.add_argument('-input_cols', type=int, default=config["patch_shape"][2])  # 8
        args = parser.parse_args()
        # print(args.b)
        # model = dense_rnn_net(args)
        base_model = denseunet_3d(args)
        sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
        model = base_model
        base_model.compile(optimizer=sgd, loss=[weighted_crossentropy])

    # get training and testing generators

    # Save Model
    plot_model(base_model, to_file="liver_segmentation_HDenseUnet.png", show_shapes=True)

    # Open the file
    with open(config['model_summaryfile'],'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        base_model.summary(line_length=150,print_fn=lambda x: fh.write(x + '\n'))

    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"]*config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"]*config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
      
    print('INFO: Training Details', '\n Batch Size : ', config["batch_size"] * config["GPU"],
          '\n Epoch Size : ', config["n_epochs"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],base_model=base_model)
    data_file_opened.close()
Example #25
def main(overwrite=False):
    # convert input images into an hdf5 file
    pdb.set_trace()
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           modality_names=config['all_modalities'],
                           subject_ids=subject_ids,
                           mean_std_file=config['mean_std_file'])
#     return
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators


#     pdb.set_trace()
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        pred_specific=config['pred_specific'],
        overlap_label=config['overlap_label_generator'],
        for_final_val=config['for_final_val'])

    # run training
    #     pdb.set_trace()
    time_0 = time.time()
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_file=config['logging_file'])
    print('Training time:', sec2hms(time.time() - time_0))
    data_file_opened.close()
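
Example #25 reports its wall-clock time through a `sec2hms` helper that is not shown. A minimal sketch of one possible implementation, assuming it simply formats a number of seconds as HH:MM:SS:

def sec2hms(seconds):
    # Format a duration in seconds as an HH:MM:SS string (illustrative sketch,
    # not the original helper).
    seconds = int(round(seconds))
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return '{:02d}:{:02d}:{:02d}'.format(hours, minutes, secs)
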
Example #26
    def main(self, overwrite_data=True, overwrite_model=True):
        # convert input images into an hdf5 file
        if overwrite_data or not os.path.exists(self.config.data_file):
            training_files, subject_ids = self.fetch_training_data_files(
                return_subject_ids=True)
            write_data_to_file(training_files,
                               self.config.data_file,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids)
        else:
            print(
                "Reusing previously written data file. Set overwrite_data to True to overwrite this file."
            )

        data_file_opened = open_data_file(self.config.data_file)

        if not overwrite_model and os.path.exists(self.config.model_file):
            model = load_old_model(self.config.model_file)
        else:
            # instantiate new model

            model, context_output_name = isensee2017_model(
                input_shape=self.config.input_shape,
                n_labels=self.config.n_labels,
                initial_learning_rate=self.config.initial_learning_rate,
                n_base_filters=self.config.n_base_filters,
                loss_function=self.config.loss_function,
                shortcut=self.config.shortcut)

        # get training and testing generators

        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=self.config.batch_size,
            data_split=self.config.validation_split,
            overwrite_data=overwrite_data,
            validation_keys_file=self.config.validation_file,
            training_keys_file=self.config.training_file,
            n_labels=self.config.n_labels,
            labels=self.config.labels,
            patch_shape=self.config.patch_shape,
            validation_batch_size=self.config.validation_batch_size,
            validation_patch_overlap=self.config.validation_patch_overlap,
            training_patch_overlap=self.config.training_patch_overlap,
            training_patch_start_offset=self.config.training_patch_start_offset,
            permute=self.config.permute,
            augment=self.config.augment,
            skip_blank=self.config.skip_blank,
            augment_flip=self.config.flip,
            augment_distortion_factor=self.config.distort)

        # run training
        train_model(model=model,
                    model_file=self.config.model_file,
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=self.config.initial_learning_rate,
                    learning_rate_drop=self.config.learning_rate_drop,
                    learning_rate_patience=self.config.patience,
                    early_stopping_patience=self.config.early_stop,
                    n_epochs=self.config.epochs,
                    niseko=self.config.niseko)

        data_file_opened.close()
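
Because Example #26 defines main() as a method, it would be invoked through an instance of its enclosing trainer class. A hypothetical usage sketch; the `BratsTrainer` and `TrainingConfig` names are assumptions made for illustration, not part of the original code:

if __name__ == "__main__":
    trainer = BratsTrainer(config=TrainingConfig())  # hypothetical class and config object
    trainer.main(overwrite_data=False, overwrite_model=True)
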
Example #27
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           norm_type=config['normalization_type'])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config, re_compile=False)
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  n_base_filters=config["n_base_filters"],
                                  activation_name='softmax')

        optimizer = getattr(
            opts,
            config["optimizer"]["name"])(**config["optimizer"].get('args'))
        loss = getattr(module_metric, config["loss_fc"])
        metrics = [getattr(module_metric, x) for x in config["metrics"]]
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
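
Example #27 resolves its optimizer, loss, and metrics by name via getattr. A hypothetical fragment of the config it expects; the key layout follows the example, but the concrete names and values are assumptions:

config_fragment = {
    "optimizer": {"name": "Adam", "args": {"lr": 5e-4}},  # resolved by name from the optimizers module `opts`
    "loss_fc": "dice_coefficient_loss",                   # assumed to exist in module_metric
    "metrics": ["dice_coefficient"],                      # assumed to exist in module_metric
}
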
Example #28
def main(overwrite=False):

    # convert input images into an hdf5 file
    if overwrite or not (os.path.exists(config["data_file0"])
                         and os.path.exists(config["data_file1"])):

        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        training_files0, training_files1 = training_files
        subject_ids0, subject_ids1 = subject_ids

        if not os.path.exists(config["data_file0"]):
            write_data_to_file(training_files0,
                               config["data_file0"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids0)
        if not os.path.exists(config["data_file1"]):
            write_data_to_file(training_files1,
                               config["data_file1"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids1)

    data_file_opened0 = open_data_file(config["data_file0"])
    data_file_opened1 = open_data_file(config["data_file1"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = siam3dunet_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

        #model = testnet_model(input_shape=config["input_shape"], n_labels=config["n_labels"], initial_learning_rate=config["initial_learning_rate"], n_base_filters=config["n_base_filters"])

        #if os.path.exists(config["model_file"]):
        #    model = load_weights(config["model_file"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened0,
        data_file_opened1,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # Disabled visualization/debug snippet from the original example, kept as a
    # string literal so it is never executed:
    '''
    train_data = []
    train_label = []
    for i in range(n_train_steps):
        a, b = next(train_generator)
        train_data.append(a)
        train_label.append(b)

        a0, a1 = a

        for i in range(len(a0[0,0,0,0,:])):
            a0_0 = a0[0,2,:,:,i]
            if a0_0.min() == a0_0.max():
                a0_0 = a0_0 - a0_0
            else:                
                a0_0 = (a0_0-a0_0.min())/(a0_0.max()-a0_0.min())
        #print (a0_0.shape)
        #print (a0_0.max())
        #print (a0_0.min())
            imsave(f'vis_img/{i}.jpg', a0_0)
        raise
    '''

    # Draw one fixed batch from each generator; the validation batch is reused
    # below as a static validation set for train_model.
    test_data, test_label = next(validation_generator)
    test_g = (test_data, test_label)

    train_data, train_label = next(train_generator)
    train_g = (train_data, train_label)

    if not overwrite and os.path.exists(config["model_file"]):

        txt_file = open("output_log.txt", "w")

        #res = model.evaluate(test_data, test_label)
        #print (res)
        pre = model.predict(test_data)
        #print ([i for i in pre[0]])
        #print ([int(i) for i in test_label[0]])
        for i in range(len(pre[0])):
            txt_file.write(
                str(pre[0][i][0]) + ' ' + str(test_label[0][i]) + "\n")

        pre_train = model.predict(train_data)
        for i in range(len(pre_train[0])):
            txt_file.write(
                str(pre_train[0][i][0]) + ' ' + str(train_label[0][i]) + "\n")

        txt_file.close()
        # Bare `raise` outside an except block raises a RuntimeError here; the original
        # example uses it to abort after dumping predictions instead of continuing to train.
        raise

    # run training

    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=test_g,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    # Disabled evaluation loop from the original example (string literal, never executed):
    '''
    for i in range(len(train_label)):
        #scores = model.evaluate(train_data[i], train_label[i], verbose=1)
        scores = model.predict(train_data[i])
        print (len(scores[0]))
    '''

    data_file_opened0.close()
    data_file_opened1.close()
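
The evaluation branch in Example #28 writes one "<prediction> <label>" pair per line to output_log.txt. A minimal sketch for reading that file back, assuming the whitespace-separated two-column format written above:

def load_prediction_log(path="output_log.txt"):
    # Parse the prediction/label pairs dumped by the example above.
    preds, labels = [], []
    with open(path) as fh:
        for line in fh:
            p, y = line.split()
            preds.append(float(p))
            labels.append(float(y))
    return preds, labels
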