Code example #1
File: train.py Project: zjngjng/3DUnetCNN
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["hdf5_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])
    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"],
                              downsize_filters_factor=config["downsize_nb_filters_factor"],
                              pool_size=config["pool_size"], n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"], overwrite=overwrite,
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"])

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
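
Note: every one of these scripts reads its hyperparameters from a module-level config dictionary defined elsewhere in its project. As a minimal sketch, the keys touched by code example #1 could be filled in as follows; all values are illustrative assumptions (a few, such as image_shape and initial_learning_rate, are corroborated by the inline comments in code example #2), not values taken from any particular project.

import os

# Hypothetical config for code example #1; every value is an assumption.
config = dict()
config["image_shape"] = (144, 144, 144)     # resize target for the input volumes
config["input_shape"] = (4, 144, 144, 144)  # channels first: 4 MR modalities
config["n_labels"] = 3
config["downsize_nb_filters_factor"] = 1
config["pool_size"] = (2, 2, 2)             # max-pooling size
config["batch_size"] = 1
config["validation_split"] = 0.8            # see the inline comment in code example #2
config["initial_learning_rate"] = 1e-5
config["learning_rate_drop"] = 0.5
config["decay_learning_rate_every_x_epochs"] = 10
config["n_epochs"] = 50
config["hdf5_file"] = os.path.abspath("brats_data.h5")
config["model_file"] = os.path.abspath("tumor_segmentation_model.h5")
config["training_file"] = os.path.abspath("training_ids.pkl")
config["validation_file"] = os.path.abspath("validation_ids.pkl")

Each script is then typically launched through an entry point along the lines of if __name__ == "__main__": main(overwrite=False).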
Code example #2
File: train.py Project: CocoInParis/2d-segmentation
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"]
                           )  #config["image_shape"] = (144, 144, 144)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        #print('start model computing')
        # model = unet_model_3d(input_shape=config["input_shape"],                # 4+(32, 32, 32)
        #                       pool_size=config["pool_size"],                    #config["pool_size"] = (2, 2, 2), maxpooling size
        #                       n_labels=config["n_labels"],                      #config["n_labels"] = len(config["labels"])
        #                       initial_learning_rate=config["initial_learning_rate"],        #config["initial_learning_rate"] = 0.00001
        #                       deconvolution=config["deconvolution"])                        #config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

        model = custom_unet(reu2018)  # custom_unet and reu2018 are assumed to be defined elsewhere in this project
    #print('model loaded')
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],  #config["batch_size"] = 6
        data_split=config["validation_split"],  #validation_split = 0.8
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print('training model')
    train_model(
        model=model,
        model_file=config["model_file"],  # config["model_file"] = os.path.abspath("tumor_segmentation_model.h5")
        training_generator=train_generator,
        validation_generator=validation_generator,
        steps_per_epoch=n_train_steps,
        validation_steps=n_validation_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("model has been trained already")
Code example #3
def main():
    if not os.path.exists(config["hdf5_file"]):
        training_files = list()
        for label_file in glob.glob("./data/training/subject-*-label.hdr"):
            training_files.append((label_file.replace("label", "T1"), label_file.replace("label", "T2"), label_file))

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"], n_labels=config["n_labels"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"],
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"], labels=config["labels"], augment=True)

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
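
Code example #3 inlines the file-gathering step that the other examples delegate to a fetch_training_data_files() helper. A minimal sketch of such a helper, assuming the directory layout and naming convention shown in example #3 (the real implementation differs from project to project):

import glob

def fetch_training_data_files():
    # Collect one (T1, T2, label) tuple per training subject, mirroring the
    # inline glob loop in code example #3. Hypothetical reimplementation.
    training_files = list()
    for label_file in sorted(glob.glob("./data/training/subject-*-label.hdr")):
        training_files.append((label_file.replace("label", "T1"),
                               label_file.replace("label", "T2"),
                               label_file))
    return training_files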
Code example #4
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading old model file from the location: ",
              config["model_file"])
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Creating new model at the location: ", config["model_file"])
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print("Running the Training. Model file:", config["model_file"])
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
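
Code example #1 opens the HDF5 file with tables.open_file directly, while the later examples call an open_data_file helper instead. Judging from that direct call, the helper is most likely a thin wrapper over PyTables; a sketch consistent with both usages:

import tables

def open_data_file(filename, readwrite="r"):
    # Presumed thin wrapper matching tables.open_file(config["hdf5_file"], "r")
    # in code example #1; close the returned handle when training finishes.
    return tables.open_file(filename, readwrite)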
Code example #5
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print("Number of Training file Found:", len(training_files))
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    print("Opening data file.")
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading existing model file.")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Instantiating new model file.")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"])

    # get training and testing generators
    print("Getting training and testing generators.")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print("Running the training......")
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("Training DONE")
Code example #6
def main(overwrite=False):
    # convert input images into an hdf5 file
    # if a data file already exists it is reloaded; note that image_shape is then whatever was set when the file was written
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # load an existing model file or create a new one
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = att_res_ds_unet.att_res_ds_unet_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                                      initial_learning_rate=config["initial_learning_rate"],
                                                      n_base_filters=config["n_base_filters"])
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='att_res_ds_unet.png', show_shapes=True)

    # get training and testing generators
    # ../unet3d/generator.py
    # create the generators used for training below
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    # ../unet3d/training.py
    # train a Keras model
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
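
plot_model (used here and in examples #10 and #18) requires the optional pydot package plus a Graphviz install, and the call raises if either is missing. A guarded variant keeps a missing visualization dependency from aborting training:

from keras.utils.vis_utils import plot_model

def try_plot_model(model, to_file):
    # pydot/Graphviz are optional; skip the diagram rather than crash.
    try:
        plot_model(model, to_file=to_file, show_shapes=True)
    except (ImportError, OSError) as exc:
        print("Skipping model plot:", exc)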
Code example #7
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distortion_factor"],
        augment_rotation_factor=config["rotation_factor"],
        mirror=config["mirror"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_path=config["logging_path"])
    data_file_opened.close()
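
Code example #7 additionally passes a logging_path into train_model. That hints at file-based training logs; one plausible way such a path gets used inside train_model is via Keras's logging callbacks (an assumption, since the receiving code is not shown):

import os
from keras.callbacks import CSVLogger, TensorBoard

def logging_callbacks(logging_path):
    # Hypothetical: write per-epoch metrics to a CSV file and TensorBoard
    # events under the configured logging directory.
    return [CSVLogger(os.path.join(logging_path, "training.log")),
            TensorBoard(log_dir=logging_path)]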
Code example #8
def main(overwrite=False):
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])

    # get training and validation generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],

        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,

        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],

        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],

        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],

        skip_blank=config["skip_blank"]
        )

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
Code example #9
def train(overwrite=True,
          crop=True,
          challenge="brats",
          year=2018,
          image_shape="160-160-128",
          is_bias_correction="1",
          is_normalize="z",
          is_denoise="0",
          is_hist_match="0",
          is_test="1",
          depth_unet=4,
          n_base_filters_unet=16,
          model_name="unet",
          patch_shape="128-128-128",
          is_crf="0",
          batch_size=1,
          loss="weighted"):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite,
                     crop=crop,
                     challenge=challenge,
                     year=year,
                     image_shape=image_shape,
                     is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize,
                     is_denoise=is_denoise,
                     is_hist_match=is_hist_match,
                     is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
        data_file_opened,
        batch_size=batch_size,
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"])

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    if not overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        model = generate_model(config["model_file"], loss_function=loss)
        # model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        if model_name == "unet":
            print("init unet model")
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        elif model_name == "densefcn":
            print("init densenet model")
            # config["initial_learning_rate"] = 1e-5
            model = densefcn_model_3d(
                input_shape=config["input_shape"],
                classes=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                nb_dense_block=5,
                nb_layers_per_block=4,
                early_transition=True,
                dropout_rate=0.2,
                loss_function=loss)

        elif model_name == "denseunet":
            print("init denseunet model")
            model = dense_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        elif model_name == "resunet":
            print("init resunet model")
            model = res_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        elif model_name == "seunet":
            print("init seunet model")
            model = se_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)

        else:
            print("init isensee model")
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=loss)

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training

    if is_test == "0":
        experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                project_name="train",
                                workspace="vuhoangminh")
    else:
        experiment = None

    print(config["initial_learning_rate"], config["learning_rate_drop"])
    print("data file:", config["data_file"])
    print("model file:", config["model_file"])
    print("training file:", config["training_file"])
    print("validation file:", config["validation_file"])
    print("testing file:", config["testing_file"])

    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
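
Example #9 encodes shapes as dash-separated strings such as "128-128-128" and converts them with get_shape_from_string before building input_shape. A minimal sketch of that conversion, assuming exactly the string format shown above:

def get_shape_from_string(shape_string):
    # "128-128-128" -> (128, 128, 128); hypothetical reimplementation of the
    # helper used in examples #9 and #15.
    return tuple(int(dim) for dim in shape_string.split("-"))

# Usage mirroring example #9 (assuming nb_channels == 4):
patch_shape = get_shape_from_string("128-128-128")
input_shape = tuple([4] + list(patch_shape))  # (4, 128, 128, 128)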
Code example #10
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    with open('isensemodel_original.txt', 'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # Save Model
    plot_model(model, to_file="isensemodel_original.png", show_shapes=True)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code example #11
def main(overwrite_data=False, overwrite_model=False):
    # run if the data not already stored hdf5
    if overwrite_data or not os.path.exists(config["data_file"]):
        _save_new_h5_datafile(config["data_file"],
                              new_image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite_model and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model

        print('initializing new isensee model with input shape',
              config['input_shape'])
        '''
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])
        '''
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    (train_generator, validation_generator, n_train_steps,
     n_validation_steps) = get_training_and_validation_generators(
         data_file_opened,
         batch_size=config["batch_size"],
         data_split=config["validation_split"],
         overwrite=overwrite_data,
         validation_keys_file=config["validation_file"],
         training_keys_file=config["training_file"],
         n_labels=config["n_labels"],
         labels=config["labels"],
         patch_shape=config["patch_shape"],
         validation_batch_size=config["validation_batch_size"],
         validation_patch_overlap=config["validation_patch_overlap"],
         training_patch_start_offset=config["training_patch_start_offset"],
         permute=config["permute"],
         augment=config["augment"],
         skip_blank=config["skip_blank"],
         augment_flip=config["flip"],
         augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code example #12
File: train.py Project: nggbaobkit/3DUnet
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        # If this branch runs, the code ignores "model_name" in the config: training
        # continues with whatever architecture the pre-trained model used (either
        # 3d_unet_residual or attention_unet). Be careful with this.
        model = load_old_model(config, re_compile=False)
        model.summary()
        # visualize_filters_shape(model)
    else:
        # instantiate new model
        if (config["model_name"] == "3d_unet_residual"):
            """3D Unet Residual Model"""
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      n_base_filters=config["n_base_filters"],
                                      activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        elif (config["model_name"] == "attention_unet"):
            """Attention Unet Model"""
            model = attention_unet_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                n_base_filters=config["n_base_filters"],
                activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        else:
            """Wrong entry for model_name"""
            raise Exception(
                'Look at the "model_name" field in config.json! It must be either 3d_unet_residual or attention_unet.'
            )

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
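
Example #12 resolves its optimizer, loss, and metrics reflectively with getattr, reading names from the config. A sketch of the config fragment this implies, with purely illustrative values (opts is presumably keras.optimizers and module_metric a project metrics module):

# Hypothetical fragment of config.json, shown as a Python dict; the key names
# match what example #12 reads, the values are assumptions.
config = {
    "model_name": "attention_unet",  # or "3d_unet_residual"
    "optimizer": {"name": "Adam", "args": {"lr": 5e-4}},
    "loss_fc": "weighted_dice_coefficient_loss",
    "metrics": ["dice_coefficient"],
}

# Resolution then proceeds as in the example:
#   optimizer = getattr(opts, config["optimizer"]["name"])(**config["optimizer"]["args"])
#   loss = getattr(module_metric, config["loss_fc"])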
Code example #13
File: train_old.py Project: maikia/StrokeUNET
def main(overwrite=False):
    # convert input images into an hdf5 file
    print(overwrite or not os.path.exists(config["data_file"]))
    print('path: ', os.path.exists(config["data_file"]))
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        # try:
        write_data_to_file(
            training_files,
            config["data_file"],
            image_shape=config["image_shape"])  #, normalize=False)
        # except:
        #    import pdb; pdb.set_trace()
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])
        model.summary()
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=False,  # deliberately not `overwrite`: reuse the previous
        # training indices, since they were already used for normalization in
        # write_data_to_file above
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # normalize the dataset if required
    # use only the training img (training_keys_file)
    fetch_training_data_files()

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code example #14
def main(overwrite=False):
    args = get_args.train()
    overwrite = args.overwrite

    # config["data_file"] = get_brats_data_h5_path(args.challenge, args.year,
    #                                              args.inputshape, args.isbiascorrection,
    #                                              args.normalization, args.clahe,
    #                                              args.histmatch)

    # print(config["data_file"])

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators_new(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_steps_file=config["n_steps_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        is_create_patch_index_list_original=config["is_create_patch_index_list_original"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"])

    print("-"*60)
    print("# Load or init model")
    print("-"*60)
    if not overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("init model model")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"],
                              depth=config["depth"],
                              n_base_filters=config["n_base_filters"])

    # model.summary()

    # import nibabel as nib
    # laptop_save_dir = "C:/Users/minhm/Desktop/temp/"
    # desktop_save_dir = "/home/minhvu/Desktop/temp/"
    # save_dir = desktop_save_dir
    # temp_in_path = desktop_save_dir + "template.nii.gz"
    # temp_out_path = desktop_save_dir + "out.nii.gz"
    # temp_out_truth_path = desktop_save_dir + "truth.nii.gz"

    # n_validation_samples = 0
    # validation_samples = list()
    # for i in range(20):
    #     print(i)
    #     x, y = next(train_generator)
    #     hash_x = hash(str(x))
    #     validation_samples.append(hash_x)
    #     n_validation_samples += x.shape[0]

    #     temp_in = nib.load(temp_in_path)
    #     temp_out = nib.Nifti1Image(x[0][0], affine=temp_in.affine)
    #     nib.save(temp_out, temp_out_path)

    #     temp_out = nib.Nifti1Image(y[0][0], affine=temp_in.affine)
    #     nib.save(temp_out, temp_out_truth_path)

    # print(n_validation_samples)

    print("-"*60)
    print("# start training")
    print("-"*60)
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code example #15
def train(args):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if args.patch_shape in ["160-192-13", "160-192-15", "160-192-17"]:
        config["initial_learning_rate"] = 1e-4
        print("lr updated...")
    if args.patch_shape in ["160-192-3"]:
        config["initial_learning_rate"] = 1e-2
        print("lr updated...")

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
        data_file_opened,
        batch_size=args.batch_size,
        data_split=config["validation_split"],
        overwrite=args.overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=args.batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"],
        is_test=args.is_test)

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)

    if not args.overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        model = generate_model(config["model_file"], loss_function=args.loss)
        # model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        if args.model == "seunet":
            print("init seunet model")
            model = unet_model_25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                #   batch_normalization=True,
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss,
                is_unet_original=False)
        elif args.model == "unet":
            print("init unet model")
            model = unet_model_25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                #   batch_normalization=True,
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "segnet":
            print("init segnet model")
            model = segnet25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "isensee":
            print("init isensee model")
            model = isensee25d_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=args.loss)

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training

    if args.is_test == "0":
        experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                project_name="train",
                                workspace="guusgrimbergen")
    else:
        experiment = None

    print(config["initial_learning_rate"], config["learning_rate_drop"])
    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if args.is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
Code example #16
    def main(self, overwrite_data=True, overwrite_model=True):
        # convert input images into an hdf5 file
        if overwrite_data or not os.path.exists(self.config.data_file):
            training_files, subject_ids = self.fetch_training_data_files(
                return_subject_ids=True)
            write_data_to_file(training_files,
                               self.config.data_file,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids)
        else:
            print(
                "Reusing previously written data file. Set overwrite_data to True to overwrite this file."
            )

        data_file_opened = open_data_file(self.config.data_file)

        if not overwrite_model and os.path.exists(self.config.model_file):
            model = load_old_model(self.config.model_file)
        else:
            # instantiate new model

            model, context_output_name = isensee2017_model(
                input_shape=self.config.input_shape,
                n_labels=self.config.n_labels,
                initial_learning_rate=self.config.initial_learning_rate,
                n_base_filters=self.config.n_base_filters,
                loss_function=self.config.loss_function,
                shortcut=self.config.shortcut)

        # get training and testing generators

        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=self.config.batch_size,
            data_split=self.config.validation_split,
            overwrite_data=overwrite_data,
            validation_keys_file=self.config.validation_file,
            training_keys_file=self.config.training_file,
            n_labels=self.config.n_labels,
            labels=self.config.labels,
            patch_shape=self.config.patch_shape,
            validation_batch_size=self.config.validation_batch_size,
            validation_patch_overlap=self.config.validation_patch_overlap,
            training_patch_overlap=self.config.training_patch_overlap,
            training_patch_start_offset=self.config.training_patch_start_offset,
            permute=self.config.permute,
            augment=self.config.augment,
            skip_blank=self.config.skip_blank,
            augment_flip=self.config.flip,
            augment_distortion_factor=self.config.distort)

        # run training
        train_model(model=model,
                    model_file=self.config.model_file,
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=self.config.initial_learning_rate,
                    learning_rate_drop=self.config.learning_rate_drop,
                    learning_rate_patience=self.config.patience,
                    early_stopping_patience=self.config.early_stop,
                    n_epochs=self.config.epochs,
                    niseko=self.config.niseko)

        data_file_opened.close()
Code example #17
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
        # new_model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
        #                              initial_learning_rate=config["initial_learning_rate"],
        #                              n_base_filters=config["n_base_filters"])

        if config['freeze_encoder']:
            last_index = list(layer.name for layer in model.layers) \
                .index('up_sampling3d_1')
            for layer in model.layers[:last_index]:
                layer.trainable = False
            from keras.optimizers import Adam
            from unet3d.model.isensee2017 import weighted_dice_coefficient_loss
            model.compile(optimizer=Adam(lr=config['initial_learning_rate']), loss=weighted_dice_coefficient_loss)
        # for new_layer, layer in zip(new_model.layers[1:], old_model.layers[1:]):
        #     assert new_layer.name == layer.name
        #     new_layer.set_weights(layer.get_weights())
        # model = new_model
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment=config["augment"],
        skip_blank=config["skip_blank"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
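Note: example #17 freezes the encoder by locating the first decoder upsampling layer by name and recompiling. A minimal, hypothetical helper for that freeze-and-recompile pattern (the boundary layer name, learning rate, and loss are whatever the caller passes; nothing here is project API):

from keras.optimizers import Adam

def freeze_up_to(model, boundary_name, lr, loss_fn):
    # Freeze every layer before `boundary_name`; Keras only picks up
    # trainable-flag changes at compile time, hence the recompile.
    names = [layer.name for layer in model.layers]
    for layer in model.layers[:names.index(boundary_name)]:
        layer.trainable = False
    model.compile(optimizer=Adam(lr=lr), loss=loss_fn)
    return model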
Code Example #18
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model, n_labels=config["n_labels"], GPU=config["GPU"])
    else:
        # instantiate a new Hybrid Dense-UNet model (from the H-DenseUNet project)
        parser = argparse.ArgumentParser(description='Keras DenseUnet Training')
        parser.add_argument('-b', type=int, default=1)  # batch size per GPU
        parser.add_argument('-input_size', type=int, default=config["patch_shape"][0])
        parser.add_argument('-input_cols', type=int, default=config["patch_shape"][2])
        args = parser.parse_args()
        base_model = denseunet_3d(args)
        sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
        base_model.compile(optimizer=sgd, loss=[weighted_crossentropy])
        model = base_model

    # Save a diagram of the model architecture
    plot_model(base_model, to_file="liver_segmentation_HDenseUnet.png", show_shapes=True)

    # Write the model summary to a text file
    with open(config['model_summaryfile'], 'w') as fh:
        # Pass the file handle in via a lambda so summary() can call it
        base_model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"]*config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"]*config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
      
    print('INFO: Training Details',
          '\n Batch Size :', config["batch_size"] * config["GPU"],
          '\n Epoch Size :', config["n_epochs"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],base_model=base_model)
    data_file_opened.close()
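Note: get_multiGPUmodel is project-specific and not shown in these snippets. A sketch of how such a wrapper is commonly built on Keras 2.x, assuming keras.utils.multi_gpu_model is available (reusing the base model's compile settings is an assumption):

from keras.utils import multi_gpu_model

def make_parallel(base_model, gpus):
    # Replicate the model across GPUs; single-GPU runs use the base model as-is.
    if gpus <= 1:
        return base_model
    parallel = multi_gpu_model(base_model, gpus=gpus)
    parallel.compile(optimizer=base_model.optimizer, loss=base_model.loss)
    return parallel

This is also why the snippet multiplies batch_size by config["GPU"]: the replica model splits each incoming batch evenly across the GPUs.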
Code Example #19
def train(args):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if "casnet" in args.model:
        config["data_type_generator"] = 'cascaded'
    elif "sepnet" in args.model:
        config["data_type_generator"] = 'separated'
    else:
        config["data_type_generator"] = 'combined'

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
        data_file_opened,
        batch_size=args.batch_size,
        data_split=config["validation_split"],
        overwrite=args.overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        testing_keys_file=config["testing_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=args.batch_size,
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"],
        is_test=args.is_test,
        data_type_generator=config["data_type_generator"])

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    config["input_shape"] = config["input_shape"][0:len(config["input_shape"]
                                                        ) - 1]
    if not args.overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        from unet3d.utils.model_utils import generate_model
        if "casnet" in args.model:
            args.loss = "casweighted"
        model = generate_model(config["model_file"], loss_function=args.loss)
    else:
        # instantiate new model
        if args.model == "isensee":
            print("init isensee model")
            model = isensee2d_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=args.loss)
        elif args.model == "unet":
            print("init unet model")
            model = unet_model_2d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "segnet":
            print("init segnet model")
            model = segnet2d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "casnet_v1":
            print("init casnet_v1 model")
            model = casnet_v1(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v2":
            print("init casnet_v2 model")
            model = casnet_v2(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v3":
            print("init casnet_v3 model")
            model = casnet_v3(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v4":
            print("init casnet_v4 model")
            model = casnet_v4(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "casnet_v5":
            print("init casnet_v5 model")
            model = casnet_v5(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function="casweighted")
        elif args.model == "casnet_v6":
            print("init casnet_v6 model")
            model = casnet_v6(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "casnet_v7":
            print("init casnet_v7 model")
            model = casnet_v7(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "casnet_v8":
            print("init casnet_v8 model")
            model = casnet_v8(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "sepnet_v1":
            print("init sepnet_v1 model")
            model = sepnet_v1(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)
        elif args.model == "sepnet_v2":
            print("init sepnet_v2 model")
            model = sepnet_v2(
                input_shape=config["input_shape"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet)

        else:
            raise ValueError("Model is NotImplemented. Please check")

    model.summary()

    print("-" * 60)
    print("# start training")
    print("-" * 60)
    # run training

    if args.is_test == "0":
        experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                project_name="train",
                                workspace="guusgrimbergen")
    else:
        experiment = None

    if args.model == "isensee":
        config["initial_learning_rate"] = 1e-6
    print(config["initial_learning_rate"], config["learning_rate_drop"])
    train_model(experiment=experiment,
                model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    if args.is_test == "0":
        experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
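Note: the long if/elif dispatch in example #19 can be collapsed into a lookup table. A sketch, assuming the builder functions are imported from the project and share a keyword-argument interface (per-model arguments omitted for brevity):

# Hypothetical registry replacing the if/elif chain above.
MODEL_BUILDERS = {
    "isensee": isensee2d_model,
    "unet": unet_model_2d,
    "segnet": segnet2d,
    # casnet_v1 ... casnet_v8 and sepnet_v1 / sepnet_v2 register the same way
}

def build_model(name, **kwargs):
    try:
        return MODEL_BUILDERS[name](**kwargs)
    except KeyError:
        raise ValueError("Model '{}' is not implemented".format(name))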
Code Example #20
File: train.py Project: votnhan/3DUnet
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           norm_type=config['normalization_type'])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config, re_compile=False)
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  n_base_filters=config["n_base_filters"],
                                  activation_name='softmax')

        optimizer = getattr(
            opts,
            config["optimizer"]["name"])(**config["optimizer"].get('args'))
        loss = getattr(module_metric, config["loss_fc"])
        metrics = [getattr(module_metric, x) for x in config["metrics"]]
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
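Note: example #20 resolves its optimizer, loss, and metrics reflectively via getattr. The config entries it expects would look roughly like the fragment below (names and values are illustrative, not taken from the project); note that train_model later reads config["optimizer"]["args"]["lr"], so the args dict must include lr:

# Hypothetical config fragment matching the getattr() lookups above.
config["optimizer"] = {"name": "Adam", "args": {"lr": 5e-4}}  # class looked up on opts (keras.optimizers)
config["loss_fc"] = "weighted_dice_coefficient_loss"          # attribute of module_metric
config["metrics"] = ["dice_coefficient"]                      # attributes of module_metric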
Code Example #21
import os
from unet3d.model import unet_model_3d
from unet3d.training import load_old_model, train_model
from unet3d.generator import ExampleSphereGenerator

from keras import backend as K
K.set_image_dim_ordering('tf')

input_shape = (96, 96, 96, 1)
model = unet_model_3d(input_shape=input_shape)

train_generator, validation_generator = ExampleSphereGenerator.get_training_and_validation(
    input_shape, cnt=5, border=10, batch_size=20, n_samples=500)

train_model(model=model,
            model_file=os.path.abspath("./SphereCNN.h5"),
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=train_generator.num_steps,
            validation_steps=validation_generator.num_steps,
            initial_learning_rate=0.00001,
            learning_rate_drop=0.5,
            learning_rate_epochs=10,
            n_epochs=50)
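Note: K.set_image_dim_ordering('tf') is the Keras 1.x spelling. On Keras 2.x the equivalent call, matching the channels-last input shape (96, 96, 96, 1) used above, is:

from keras import backend as K
K.set_image_data_format('channels_last')  # 'tf' ordering corresponds to channels-last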
Code Example #22
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           modality_names=config['all_modalities'],
                           subject_ids=subject_ids,
                           mean_std_file=config['mean_std_file'])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        pred_specific=config['pred_specific'],
        overlap_label=config['overlap_label_generator'],
        for_final_val=config['for_final_val'])

    # run training
    time_0 = time.time()
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_file=config['logging_file'])
    print('Training time:', sec2hms(time.time() - time_0))
    data_file_opened.close()
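Note: sec2hms is used but never defined in example #22. A minimal version consistent with its use here (hypothetical; the project's helper may differ):

def sec2hms(seconds):
    # Format a duration in seconds as HH:MM:SS.
    h, rem = divmod(int(seconds), 3600)
    m, s = divmod(rem, 60)
    return "{:02d}:{:02d}:{:02d}".format(h, m, s)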
Code Example #23
def finetune(args):

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    folder = os.path.join(BRATS_DIR, "database", "model", "base")

    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("can not fine baseline model. Please check")
        else:
            config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    make_dir(config["training_file"])

    print_section("get training and testing generators")
    if args.model_dim == 3:
        from unet3d.generator import get_training_and_validation_and_testing_generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            is_create_patch_index_list_original=config[
                "is_create_patch_index_list_original"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])
    elif args.model_dim == 25:
        from unet25d.generator import get_training_and_validation_and_testing_generators25d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)
    else:
        from unet2d.generator import get_training_and_validation_and_testing_generators2d
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

    print("-" * 60)
    print("# Load or init model")
    print("-" * 60)
    print(">> update config file")
    config.update(config_finetune)
    if not os.path.exists(config["model_file"]):
        raise Exception("{} model file not found. Please try again".format(
            config["model_file"]))
    else:
        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
        model.summary()

    # run training
    print("-" * 60)
    print("# start finetuning")
    print("-" * 60)

    print("Number of training steps: ", n_train_steps)
    print("Number of validation steps: ", n_validation_steps)

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        args=args,
        dir_read_write="finetune",
        is_finetune=True)

    config["model_file"] = model_path

    if os.path.exists(config["model_file"]):
        print("{} existed. Will skip!!!".format(config["model_file"]))
    else:

        if args.is_test == "1":
            config["n_epochs"] = 5

        if args.is_test == "0":
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="finetune",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        if args.model_dim == 2 and args.model == "isensee":
            config["initial_learning_rate"] = 1e-7

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)

    data_file_opened.close()
    from keras import backend as K
    K.clear_session()
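Note: print_section is used in examples #19 and #23 without being shown. Judging from the manual prints in the same snippets, a plausible definition is:

def print_section(title):
    # Print a title between two 60-character rules, as the snippets do inline.
    print("-" * 60)
    print("# " + title)
    print("-" * 60)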
Code Example #24
File: train_2.py Project: cchmc-dll/ai_training
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,
                                  n_labels=config["n_labels"],
                                  GPU=config["GPU"])
    else:
        # instantiate new model
        base_model, model = unet_model_3d_multiGPU(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            GPU=config["GPU"])
    # Save Model
    plot_model(base_model,
               to_file="liver_segmentation_model_581_resize_1GPU.png",
               show_shapes=True)
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"] * config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"] * config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print('INFO: Training Details', '\n Batch Size : ',
          config["batch_size"] * config["GPU"], '\n Epoch Size : ',
          config["n_epochs"])

    # For debugging ONLY
    # n_train_steps = 10

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                base_model=base_model)
    data_file_opened.close()
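Note: plot_model (used in examples #18 and #24) requires the pydot and graphviz packages and raises ImportError when they are missing. A guarded call keeps a missing plotting dependency from aborting training (the file name here is a placeholder):

from keras.utils import plot_model

try:
    plot_model(base_model, to_file="model.png", show_shapes=True)
except ImportError as exc:
    print("Skipping model plot:", exc)  # pydot/graphviz not installed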
Code Example #25
def main(overwrite=False):

    # convert input images into an hdf5 file
    if overwrite or not (os.path.exists(config["data_file0"])
                         and os.path.exists(config["data_file1"])):

        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        training_files0, training_files1 = training_files
        subject_ids0, subject_ids1 = subject_ids

        if not os.path.exists(config["data_file0"]):
            write_data_to_file(training_files0,
                               config["data_file0"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids0)
        if not os.path.exists(config["data_file1"]):
            write_data_to_file(training_files1,
                               config["data_file1"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids1)

    data_file_opened0 = open_data_file(config["data_file0"])
    data_file_opened1 = open_data_file(config["data_file1"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = siam3dunet_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])


    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened0,
        data_file_opened1,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
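    # NOTE: the disabled block below dumps min-max normalized slices of the
    # first training batch to vis_img/ for visual inspection.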
    '''
    train_data = []
    train_label = []
    for i in range(n_train_steps):
        a, b = next(train_generator)
        train_data.append(a)
        train_label.append(b)

        a0, a1 = a

        for i in range(len(a0[0,0,0,0,:])):
            a0_0 = a0[0,2,:,:,i]
            if a0_0.min() == a0_0.max():
                a0_0 = a0_0 - a0_0
            else:                
                a0_0 = (a0_0-a0_0.min())/(a0_0.max()-a0_0.min())
        #print (a0_0.shape)
        #print (a0_0.max())
        #print (a0_0.min())
            imsave(f'vis_img/{i}.jpg', a0_0)
        raise
    '''

    test_data, test_label = next(validation_generator)
    test_g = (test_data, test_label)

    train_data, train_label = next(train_generator)
    train_g = (train_data, train_label)

    if not overwrite and os.path.exists(config["model_file"]):

        txt_file = open(f"output_log.txt", "w")

        # Dump validation predictions next to their ground-truth labels
        pre = model.predict(test_data)
        for i in range(len(pre[0])):
            txt_file.write(
                str(pre[0][i][0]) + ' ' + str(test_label[0][i]) + "\n")

        # Dump training predictions as well
        pre_train = model.predict(train_data)
        for i in range(len(pre_train[0])):
            txt_file.write(
                str(pre_train[0][i][0]) + ' ' + str(train_label[0][i]) + "\n")

        txt_file.close()
        # Stop here: this branch only dumps predictions, it does not retrain.
        return

    # run training

    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=test_g,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
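    # NOTE: the disabled block below would run prediction over the collected
    # training batches.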
    '''
    for i in range(len(train_label)):
        #scores = model.evaluate(train_data[i], train_label[i], verbose=1)
        scores = model.predict(train_data[i])
        print (len(scores[0]))
    '''

    data_file_opened0.close()
    data_file_opened1.close()
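Note: the disabled visualization block in example #25 min-max normalizes each slice and guards the constant-slice case. As a standalone utility (a sketch, not project code):

import numpy as np

def minmax_norm(slice_2d):
    # Scale a 2-D array to [0, 1]; constant slices map to all zeros.
    lo, hi = slice_2d.min(), slice_2d.max()
    if lo == hi:
        return np.zeros_like(slice_2d)
    return (slice_2d - lo) / (hi - lo)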