Example #1
    def test_random_patch_start(self):
        self.create_data_file(len_x=10, len_y=10, len_z=10)

        validation_split = 0.8
        batch_size = 10
        validation_batch_size = 3
        patch_shape = (5, 5, 5)
        random_start = (3, 3, 3)
        overlap = 2

        generators = get_training_and_validation_generators(self.data_file, batch_size, self.n_labels,
                                                            self.training_keys_file, self.validation_keys_file,
                                                            data_split=validation_split,
                                                            validation_batch_size=validation_batch_size,
                                                            patch_shape=patch_shape,
                                                            training_patch_start_offset=random_start,
                                                            validation_patch_overlap=overlap,
                                                            skip_blank=False)

        training_generator, validation_generator, n_training_steps, n_validation_steps = generators

        expected_training_samples = int(np.round(self.n_samples * validation_split)) * 2**3

        self.verify_generator(training_generator, n_training_steps, batch_size, expected_training_samples)

        expected_validation_samples = int(np.round(self.n_samples * (1 - validation_split))) * 4**3

        self.verify_generator(validation_generator, n_validation_steps, validation_batch_size,
                              expected_validation_samples)

        self.data_file.close()
        self.rm_tmp_files()
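The expected counts above follow from straightforward patch arithmetic. As a sketch of the presumed computation (assuming the training grid steps by the full patch size, since the random start offset only shifts the grid, while the validation grid steps by patch size minus overlap):

import math

len_xyz, patch, overlap = 10, 5, 2
patches_per_axis_train = math.ceil(len_xyz / patch)            # 2, hence 2**3 patches per image
patches_per_axis_val = math.ceil(len_xyz / (patch - overlap))  # 4, hence 4**3 patches per image
assert patches_per_axis_train == 2 and patches_per_axis_val == 4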
Example #2
    def test_patch_generators(self):
        self.create_data_file(len_x=4, len_y=4, len_z=4)

        validation_split = 0.8
        batch_size = 10
        validation_batch_size = 3
        patch_shape = (2, 2, 2)

        generators = get_training_and_validation_generators(self.data_file, batch_size, self.n_labels,
                                                            self.training_keys_file, self.validation_keys_file,
                                                            data_split=validation_split,
                                                            validation_batch_size=validation_batch_size,
                                                            patch_shape=patch_shape,
                                                            skip_blank=False)
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators

        expected_training_samples = int(np.round(self.n_samples * validation_split)) * 2**3

        self.verify_generator(training_generator, n_training_steps, batch_size, expected_training_samples)

        expected_validation_samples = int(np.round(self.n_samples * (1 - validation_split))) * 2**3

        self.verify_generator(validation_generator, n_validation_steps, validation_batch_size,
                              expected_validation_samples)

        self.data_file.close()
        self.rm_tmp_files()
Example #3
    def test_get_training_and_validation_generators(self):
        self.create_data_file()

        validation_split = 0.8
        batch_size = 3
        validation_batch_size = 3

        generators = get_training_and_validation_generators(data_file=self.data_file,
                                                            batch_size=batch_size,
                                                            n_labels=self.n_labels,
                                                            training_keys_file=self.training_keys_file,
                                                            validation_keys_file=self.validation_keys_file,
                                                            data_split=validation_split,
                                                            validation_batch_size=validation_batch_size,
                                                            skip_blank=False)
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators

        self.verify_generator(training_generator, n_training_steps, batch_size,
                              np.round(validation_split * self.n_samples))

        self.verify_generator(validation_generator, n_validation_steps, validation_batch_size,
                              np.round((1 - validation_split) * self.n_samples))

        self.data_file.close()
        self.rm_tmp_files()
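A note on the returned step counts: they are presumably the split sizes divided by the batch sizes, rounded up, along the lines of this sketch (the sample count is illustrative):

import math
import numpy as np

n_samples, validation_split = 20, 0.8                                 # illustrative values
batch_size = validation_batch_size = 3
n_training = int(np.round(n_samples * validation_split))              # 16
n_validation = n_samples - n_training                                 # 4
n_training_steps = math.ceil(n_training / batch_size)                 # 6
n_validation_steps = math.ceil(n_validation / validation_batch_size)  # 2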
Example #4
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"]
                           )  #config["image_shape"] = (144, 144, 144)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        #print('start model computing')
        # model = unet_model_3d(input_shape=config["input_shape"],                # 4+(32, 32, 32)
        #                       pool_size=config["pool_size"],                    #config["pool_size"] = (2, 2, 2), maxpooling size
        #                       n_labels=config["n_labels"],                      #config["n_labels"] = len(config["labels"])
        #                       initial_learning_rate=config["initial_learning_rate"],        #config["initial_learning_rate"] = 0.00001
        #                       deconvolution=config["deconvolution"])                        #config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

        model = custom_unet(reu2018)
    #print('model loaded')
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],  #config["batch_size"] = 6
        data_split=config["validation_split"],  #validation_split = 0.8
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print('training model')
    train_model(
        model=model,
        model_file=config["model_file"],  # config["model_file"] = os.path.abspath("tumor_segmentation_model.h5")
        training_generator=train_generator,
        validation_generator=validation_generator,
        steps_per_epoch=n_train_steps,
        validation_steps=n_validation_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("model has been trained already")
Example #5
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print(training_files)
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    # get training and testing generators - generates pickle files containing IDs
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    data_file_opened.close()
Example #6
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["hdf5_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])
    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"],
                              downsize_filters_factor=config["downsize_nb_filters_factor"],
                              pool_size=config["pool_size"], n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"], overwrite=overwrite,
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"])

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
Example #7
def main():
    if not os.path.exists(config["hdf5_file"]):
        training_files = list()
        for label_file in glob.glob("./data/training/subject-*-label.hdr"):
            training_files.append((label_file.replace("label", "T1"), label_file.replace("label", "T2"), label_file))

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"], n_labels=config["n_labels"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"],
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"], labels=config["labels"], augment=True)

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
Example #8
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading old model file from the location: ",
              config["model_file"])
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Creating new model at the location: ", config["model_file"])
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print("Running the Training. Model file:", config["model_file"])
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #9
def main(overwrite=False):
    # convert input images into an hdf5 file
    # if one already exists, reuse the old dataset; note that image_shape is then whatever was set previously
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # load or create the model file
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = att_res_ds_unet.att_res_ds_unet_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                                      initial_learning_rate=config["initial_learning_rate"],
                                                      n_base_filters=config["n_base_filters"])
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='att_res_ds_uet.png', show_shapes=True)

    # get training and testing generators
    # ../unet3d/generator.py
    # create the generators used by the training below
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    # ../unet3d/training.py
    # train a Keras model
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #10
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print("Number of Training file Found:", len(training_files))
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    print("Opening data file.")
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading existing model file.")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Instantiating new model file.")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"])

    # get training and testing generators
    print("Getting training and testing generators.")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print("Running the training......")
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("Training DONE")
Example #11
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distortion_factor"],
        augment_rotation_factor=config["rotation_factor"],
        mirror=config["mirror"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_path=config["logging_path"])
    data_file_opened.close()
Example #12
    def test_generator_with_permutations(self):
        self.create_data_file(len_x=5, len_y=5, len_z=5, n_channels=5)
        batch_size = 2
        generators = get_training_and_validation_generators(self.data_file, batch_size, self.n_labels,
                                                            self.training_keys_file, self.validation_keys_file,
                                                            permute=True)
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators

        _ = next(training_generator)

        self.rm_tmp_files()
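permute=True presumably enables random axis-permutation augmentation, which only makes sense for cubic patches (note the test builds 5x5x5 data). A self-contained sketch of that style of augmentation, assuming data laid out as (channels, x, y, z):

import numpy as np

def random_permute(data):
    # randomly reorder the three spatial axes and flip each with p=0.5;
    # only valid when x, y and z have equal length (cubic patches)
    assert data.shape[1] == data.shape[2] == data.shape[3]
    axes = 1 + np.random.permutation(3)      # permute the spatial axes only
    data = np.transpose(data, (0, *axes))
    for axis in (1, 2, 3):
        if np.random.rand() < 0.5:
            data = np.flip(data, axis=axis)
    return data

patch = np.random.rand(5, 5, 5, 5)           # (channels, x, y, z), as in the test above
assert random_permute(patch).shape == patch.shape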
Example #13
def main(overwrite=False):
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])

    # get training and validation generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],

        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,

        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],

        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],

        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],

        skip_blank=config["skip_blank"]
        )

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
Example #14
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,
                                  n_labels=config["n_labels"],
                                  GPU=config["GPU"])
    else:
        # instantiate new model
        base_model, model = unet_model_3d_multiGPU(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            GPU=config["GPU"])
    # Save Model
    plot_model(base_model,
               to_file="liver_segmentation_model_581_resize_1GPU.png",
               show_shapes=True)
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"] * config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"] * config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print('INFO: Training Details', '\n Batch Size : ',
          config["batch_size"] * config["GPU"], '\n Epoch Size : ',
          config["n_epochs"])

    # For debugging ONLY
    # n_train_steps = 10

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                base_model=base_model)
    data_file_opened.close()
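get_multiGPUmodel and unet_model_3d_multiGPU are project-specific helpers not shown here; presumably they wrap the base model for data-parallel training, which is why the batch sizes above are scaled by config["GPU"]. A rough Keras 2.x equivalent using the stock multi_gpu_model utility (a sketch only, assuming a machine with two GPUs and a stand-in model):

from keras.models import Sequential
from keras.layers import Dense
from keras.utils import multi_gpu_model

n_gpus = 2                                             # hypothetical GPU count
base_model = Sequential([Dense(1, input_shape=(8,))])  # stand-in for the 3D U-Net
parallel_model = multi_gpu_model(base_model, gpus=n_gpus)  # one replica per GPU
parallel_model.compile(optimizer="sgd", loss="mse")
# each batch of size per_gpu_batch * n_gpus is split evenly across the replicas;
# base_model keeps the shared weights and is what should be saved to disk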
Example #15
def main(overwrite=False):
    # convert input images into an hdf5 file
    print(overwrite or not os.path.exists(config["data_file"]))
    print('path: ', os.path.exists(config["data_file"]))
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        # try:
        write_data_to_file(
            training_files,
            config["data_file"],
            image_shape=config["image_shape"])  #, normalize=False)
        # except:
        #    import pdb; pdb.set_trace()
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])
        print(model.summary())
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=False,  # not `overwrite`: set to False so that the training indices
        # are reused as before, since they were already used for normalization
        # in write_data_to_file (above)
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # normalize the dataset if required,
    # using only the training images (training_keys_file)
    fetch_training_data_files()

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #16
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
        # new_model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
        #                              initial_learning_rate=config["initial_learning_rate"],
        #                              n_base_filters=config["n_base_filters"])

        if config['freeze_encoder']:
            last_index = list(layer.name for layer in model.layers) \
                .index('up_sampling3d_1')
            for layer in model.layers[:last_index]:
                layer.trainable = False
            from keras.optimizers import Adam
            from unet3d.model.isensee2017 import weighted_dice_coefficient_loss
            model.compile(optimizer=Adam(lr=config['initial_learning_rate']), loss=weighted_dice_coefficient_loss)
        # for new_layer, layer in zip(new_model.layers[1:], old_model.layers[1:]):
        #     assert new_layer.name == layer.name
        #     new_layer.set_weights(layer.get_weights())
        # model = new_model
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment=config["augment"],
        skip_blank=config["skip_blank"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #17
    def main(self, overwrite_data=True, overwrite_model=True):
        # convert input images into an hdf5 file
        if overwrite_data or not os.path.exists(self.config.data_file):
            training_files, subject_ids = self.fetch_training_data_files(
                return_subject_ids=True)
            write_data_to_file(training_files,
                               self.config.data_file,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids)
        else:
            print(
                "Reusing previously written data file. Set overwrite_data to True to overwrite this file."
            )

        data_file_opened = open_data_file(self.config.data_file)

        if not overwrite_model and os.path.exists(self.config.model_file):
            model = load_old_model(self.config.model_file)
        else:
            # instantiate new model

            model, context_output_name = isensee2017_model(
                input_shape=self.config.input_shape,
                n_labels=self.config.n_labels,
                initial_learning_rate=self.config.initial_learning_rate,
                n_base_filters=self.config.n_base_filters,
                loss_function=self.config.loss_function,
                shortcut=self.config.shortcut)

        # get training and testing generators

        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=self.config.batch_size,
            data_split=self.config.validation_split,
            overwrite_data=overwrite_data,
            validation_keys_file=self.config.validation_file,
            training_keys_file=self.config.training_file,
            n_labels=self.config.n_labels,
            labels=self.config.labels,
            patch_shape=self.config.patch_shape,
            validation_batch_size=self.config.validation_batch_size,
            validation_patch_overlap=self.config.validation_patch_overlap,
            training_patch_overlap=self.config.training_patch_overlap,
            training_patch_start_offset=self.config.training_patch_start_offset,
            permute=self.config.permute,
            augment=self.config.augment,
            skip_blank=self.config.skip_blank,
            augment_flip=self.config.flip,
            augment_distortion_factor=self.config.distort)

        # run training
        train_model(model=model,
                    model_file=self.config.model_file,
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=self.config.initial_learning_rate,
                    learning_rate_drop=self.config.learning_rate_drop,
                    learning_rate_patience=self.config.patience,
                    early_stopping_patience=self.config.early_stop,
                    n_epochs=self.config.epochs,
                    niseko=self.config.niseko)

        data_file_opened.close()
Example #18
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    with open('isensemodel_original.txt', 'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # Save Model
    plot_model(model, to_file="isensemodel_original.png", show_shapes=True)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Example #19
def main(overwrite_data=False, overwrite_model=False):
    # run if the data not already stored hdf5
    if overwrite_data or not os.path.exists(config["data_file"]):
        _save_new_h5_datafile(config["data_file"],
                              new_image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite_model and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model

        print('initializing new isensee model with input shape',
              config['input_shape'])
        '''
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])
        '''
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    (train_generator, validation_generator, n_train_steps,
     n_validation_steps) = get_training_and_validation_generators(
         data_file_opened,
         batch_size=config["batch_size"],
         data_split=config["validation_split"],
         overwrite=overwrite_data,
         validation_keys_file=config["validation_file"],
         training_keys_file=config["training_file"],
         n_labels=config["n_labels"],
         labels=config["labels"],
         patch_shape=config["patch_shape"],
         validation_batch_size=config["validation_batch_size"],
         validation_patch_overlap=config["validation_patch_overlap"],
         training_patch_start_offset=config["training_patch_start_offset"],
         permute=config["permute"],
         augment=config["augment"],
         skip_blank=config["skip_blank"],
         augment_flip=config["flip"],
         augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
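_save_new_h5_datafile is not shown in this example; judging by the other examples it presumably wraps the fetch-and-write step, along these lines (a hypothetical reconstruction using the same project helpers):

def _save_new_h5_datafile(data_file, new_image_shape):
    # gather the training images and write them into a single hdf5 file,
    # as the other examples do inline
    training_files = fetch_training_data_files()
    write_data_to_file(training_files, data_file, image_shape=new_image_shape)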
Example #20
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,n_labels=config["n_labels"],GPU=config["GPU"])
    else:
        # instantiate a new Hybrid Dense-UNet model from the H-DenseUNet project
        parser = argparse.ArgumentParser(description='Keras DenseUnet Training')
        parser.add_argument('-b', type=int, default=1)  # config["batch_size"]
        parser.add_argument('-input_size', type=int, default=config["patch_shape"][0])  # 224
        parser.add_argument('-input_cols', type=int, default=config["patch_shape"][2])  # 8
        args = parser.parse_args()
        #print(args.b)
        #model = dense_rnn_net(args)
        base_model = denseunet_3d(args)
        sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
        model = base_model
        base_model.compile(optimizer=sgd, loss=[weighted_crossentropy])

    # Save Model
    plot_model(base_model, to_file="liver_segmentation_HDenseUnet.png", show_shapes=True)

    # Open the file
    with open(config['model_summaryfile'], 'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        base_model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"]*config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"]*config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
      
    print('INFO: Training Details', '\n Batch Size : ', config["batch_size"] * config["GPU"],
          '\n Epoch Size : ', config["n_epochs"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],base_model=base_model)
    data_file_opened.close()
Example #21
def main(overwrite=False):
    # convert input images into an hdf5 file
    # pdb.set_trace()  # leftover debugging breakpoint, disabled
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           modality_names=config['all_modalities'],
                           subject_ids=subject_ids,
                           mean_std_file=config['mean_std_file'])
#     return
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    # pdb.set_trace()
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        pred_specific=config['pred_specific'],
        overlap_label=config['overlap_label_generator'],
        for_final_val=config['for_final_val'])

    # run training
    #     pdb.set_trace()
    time_0 = time.time()
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_file=config['logging_file'])
    print('Training time:', sec2hms(time.time() - time_0))
    data_file_opened.close()
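sec2hms is a project helper that is not shown; presumably it formats an elapsed second count as hours, minutes and seconds, roughly like this sketch:

def sec2hms(seconds):
    # format elapsed seconds as "H:MM:SS" (sketch of the unshown helper)
    m, s = divmod(int(seconds), 60)
    h, m = divmod(m, 60)
    return "%d:%02d:%02d" % (h, m, s)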
Example #22
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           norm_type=config['normalization_type'])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config, re_compile=False)
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  n_base_filters=config["n_base_filters"],
                                  activation_name='softmax')

        optimizer = getattr(
            opts,
            config["optimizer"]["name"])(**config["optimizer"].get('args'))
        loss = getattr(module_metric, config["loss_fc"])
        metrics = [getattr(module_metric, x) for x in config["metrics"]]
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
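Example #22 builds its optimizer, loss and metrics reflectively from the config via getattr. A minimal self-contained sketch of the same pattern, where keras.optimizers and keras.losses stand in for the example's opts and module_metric modules and the config values are illustrative:

import keras.optimizers as opts
import keras.losses as module_metric  # stand-in for the example's custom metrics module

config = {
    "optimizer": {"name": "Adam", "args": {"lr": 5e-4}},  # illustrative
    "loss_fc": "binary_crossentropy",                     # illustrative
}

# look the optimizer class up by name and instantiate it with its kwargs,
# then look the loss function up by name
optimizer = getattr(opts, config["optimizer"]["name"])(**config["optimizer"].get("args"))
loss = getattr(module_metric, config["loss_fc"])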
Example #23
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        # if this happens, the code won't care what "model_name" in the config says: it will keep
        # training whatever the pre-trained model was (either 3d_unet_residual or attention_unet),
        # so this needs care.
        model = load_old_model(config, re_compile=False)
        model.summary()
        # visualize_filters_shape(model)
    else:
        # instantiate new model
        if (config["model_name"] == "3d_unet_residual"):
            """3D Unet Residual Model"""
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      n_base_filters=config["n_base_filters"],
                                      activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        elif (config["model_name"] == "attention_unet"):
            """Attention Unet Model"""
            model = attention_unet_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                n_base_filters=config["n_base_filters"],
                activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        else:
            """Wrong entry for model_name"""
            raise Exception(
                'Look at the model_name field in config.json! It must be either 3d_unet_residual or attention_unet.'
            )

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
Example #24
def main(overwrite=False):

    # convert input images into an hdf5 file
    if overwrite or not (os.path.exists(config["data_file0"])
                         and os.path.exists(config["data_file1"])):

        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        training_files0, training_files1 = training_files
        subject_ids0, subject_ids1 = subject_ids

        if not os.path.exists(config["data_file0"]):
            write_data_to_file(training_files0,
                               config["data_file0"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids0)
        if not os.path.exists(config["data_file1"]):
            write_data_to_file(training_files1,
                               config["data_file1"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids1)

    data_file_opened0 = open_data_file(config["data_file0"])
    data_file_opened1 = open_data_file(config["data_file1"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = siam3dunet_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

        #model = testnet_model(input_shape=config["input_shape"], n_labels=config["n_labels"], initial_learning_rate=config["initial_learning_rate"], n_base_filters=config["n_base_filters"])

        #if os.path.exists(config["model_file"]):
        #    model = load_weights(config["model_file"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened0,
        data_file_opened1,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    '''
    train_data = []
    train_label = []
    for i in range(n_train_steps):
        a, b = next(train_generator)
        train_data.append(a)
        train_label.append(b)

        a0, a1 = a

        for i in range(len(a0[0,0,0,0,:])):
            a0_0 = a0[0,2,:,:,i]
            if a0_0.min() == a0_0.max():
                a0_0 = a0_0 - a0_0
            else:                
                a0_0 = (a0_0-a0_0.min())/(a0_0.max()-a0_0.min())
        #print (a0_0.shape)
        #print (a0_0.max())
        #print (a0_0.min())
            imsave(f'vis_img/{i}.jpg', a0_0)
        raise
    '''

    test_data, test_label = next(validation_generator)
    test_g = (test_data, test_label)

    train_data, train_label = next(train_generator)
    train_g = (train_data, train_label)

    if not overwrite and os.path.exists(config["model_file"]):

        txt_file = open("output_log.txt", "w")

        #res = model.evaluate(test_data, test_label)
        #print (res)
        pre = model.predict(test_data)
        #print ([i for i in pre[0]])
        #print ([int(i) for i in test_label[0]])
        for i in range(len(pre[0])):
            txt_file.write(
                str(pre[0][i][0]) + ' ' + str(test_label[0][i]) + "\n")

        pre_train = model.predict(train_data)
        for i in range(len(pre_train[0])):
            txt_file.write(
                str(pre_train[0][i][0]) + ' ' + str(train_label[0][i]) + "\n")

        txt_file.close()
        raise  # deliberate early exit: a bare raise here triggers RuntimeError, aborting before training

    # run training

    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=test_g,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    '''
    for i in range(len(train_label)):
        #scores = model.evaluate(train_data[i], train_label[i], verbose=1)
        scores = model.predict(train_data[i])
        print (len(scores[0]))
    '''

    data_file_opened0.close()
    data_file_opened1.close()