Exemple #1
0
def main(overwrite=False):
    """Prepare the HDF5 data file, load or create the model, and train it.

    All settings are read from the module-level ``config`` mapping.

    Args:
        overwrite: When True, the HDF5 file at ``config["data_file"]`` is
            rewritten and a new model is created even if
            ``config["model_file"]`` exists. The flag is also forwarded to
            ``get_training_and_validation_generators``.
    """
    # Convert the input images into an HDF5 file unless one already exists.
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # try/finally guarantees the data file is closed even when model
    # loading, generator construction or training raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            print("Loading old model file from the location: ",
                  config["model_file"])
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            print("Creating new model at the location: ", config["model_file"])
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        # get training and validation generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        print("Running the Training. Model file:", config["model_file"])
        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
Exemple #2
0
    def load_model(self):
        """Build an uncompiled isensee2017 network, wrap it in JDOT, restore weights and compile.

        The weights are read from ``self.config.model_file`` and the resulting
        JDOT wrapper is stored on ``self.jd``.
        """
        cfg = self.config
        network, context_name = isensee2017_model(
            input_shape=cfg.input_shape,
            n_labels=cfg.n_labels,
            initial_learning_rate=cfg.initial_learning_rate,
            n_base_filters=cfg.n_base_filters,
            loss_function=cfg.loss_function,
            shortcut=cfg.shortcut,
            compile=False,  # compilation happens via wrapper.compile_model() below
        )

        wrapper = JDOT(network, config=cfg, context_output_name=context_name)
        wrapper.load_old_model(cfg.model_file)
        wrapper.compile_model()
        self.jd = wrapper
Exemple #3
0
    def main(self, overwrite_data=True, overwrite_model=True):
        """Write the source/target HDF5 files if needed, then build, train and evaluate a JDOT model.

        Args:
            overwrite_data: When True, both HDF5 files are rewritten; otherwise
                only files that do not exist yet are written.
            overwrite_model: NOTE(review): currently unused -- whether an old
                model is loaded is decided by ``self.config.load_base_model``
                and ``self.config.overwrite_model`` instead.
        """
        # Convert input images into HDF5 files. Two files are written: one
        # with source samples and one with target samples.
        if (overwrite_data
                or not os.path.exists(self.config.source_data_file)
                or not os.path.exists(self.config.target_data_file)):
            source_data_files, target_data_files, subject_ids_source, subject_ids_target = \
                self.fetch_training_data_files(return_subject_ids=True)

            if not os.path.exists(self.config.source_data_file) or overwrite_data:
                write_data_to_file(source_data_files, self.config.source_data_file,
                                   image_shape=self.config.image_shape,
                                   subject_ids=subject_ids_source)
            if not os.path.exists(self.config.target_data_file) or overwrite_data:
                write_data_to_file(target_data_files, self.config.target_data_file,
                                   image_shape=self.config.image_shape,
                                   subject_ids=subject_ids_target)
        else:
            print("Reusing previously written data file. Set overwrite_data to True to overwrite this file.")

        # Nested try/finally: both files are closed on any exception, and the
        # source file is closed even if opening the target file fails.
        source_data = open_data_file(self.config.source_data_file)
        try:
            target_data = open_data_file(self.config.target_data_file)
            try:
                # compile=False because compilation is done by jd.compile_model() (JDOT).
                model, context_output_name = isensee2017_model(
                    input_shape=self.config.input_shape,
                    n_labels=self.config.n_labels,
                    initial_learning_rate=self.config.initial_learning_rate,
                    n_base_filters=self.config.n_base_filters,
                    loss_function=self.config.loss_function,
                    shortcut=self.config.shortcut,
                    depth=self.config.depth,
                    compile=False)
                # When depth_jdot is disabled, JDOT is given an empty context-output list.
                if not self.config.depth_jdot:
                    context_output_name = []
                jd = JDOT(model, config=self.config, source_data=source_data,
                          target_data=target_data, context_output_name=context_output_name)

                if self.config.load_base_model:
                    print("Loading trained model")
                    jd.load_old_model(os.path.abspath(
                        "Data/saved_models/model_center_" + self.config.source_center) + ".h5")
                elif not self.config.overwrite_model:
                    jd.load_old_model(self.config.model_file)
                else:
                    print("Creating new model, this will overwrite your old model")
                jd.compile_model()

                if self.config.train_jdot:
                    jd.train_model(self.config.epochs)
                else:
                    jd.train_model_on_source(self.config.epochs)
                jd.evaluate_model()
            finally:
                target_data.close()
        finally:
            source_data.close()
def main(overwrite=False):
    """Load or create an isensee2017 model and train it on the data configured via ``config["npy_path"]``.

    Args:
        overwrite: When False and ``config["model_file"]`` exists, that model
            is loaded; otherwise a fresh model is built. The flag is also
            forwarded to ``get_training_and_validation_generators``.
    """
    if overwrite or not os.path.exists(config["model_file"]):
        # Fresh, untrained network.
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"],
        )
    else:
        model = load_old_model(config["model_file"])

    # Generator settings, grouped by concern: data source, labels,
    # split bookkeeping, augmentation, geometry, patching, filtering.
    generator_options = dict(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],
        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        skip_blank=config["skip_blank"],
    )
    generators = get_training_and_validation_generators(**generator_options)
    train_generator, validation_generator, n_train_steps, n_validation_steps = generators

    train_model(
        model=model,
        model_file=config["model_file"],
        training_generator=train_generator,
        validation_generator=validation_generator,
        steps_per_epoch=n_train_steps,
        validation_steps=n_validation_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"],
    )
Exemple #5
0
def main(overwrite=False):
    """Load or create the isensee2017 model and save a diagram of its architecture.

    Unlike the other entry points, this one does not convert or open any
    HDF5 data file; that step is disabled here. If it is re-enabled, note
    that an existing data file keeps the image_shape it was originally
    written with.

    Args:
        overwrite: When True, always build a new model instead of loading
            ``config["model_file"]``.
    """
    # Load the existing model file, or create a new model.
    if overwrite or not os.path.exists(config["model_file"]):
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"],
        )
    else:
        model = load_old_model(config["model_file"])

    # Render the architecture (with tensor shapes) to isensee_unet.png.
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='isensee_unet.png', show_shapes=True)
def main(overwrite=False):
    """Prepare the HDF5 data file, load or create the model, train it and print the training time.

    All settings are read from the module-level ``config`` mapping.

    Args:
        overwrite: When True, the HDF5 file at ``config["data_file"]`` is
            rewritten and a new model is created even if
            ``config["model_file"]`` exists. The flag is also forwarded to
            ``get_training_and_validation_generators``.
    """
    # Convert input images into an HDF5 file unless one already exists.
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           modality_names=config['all_modalities'],
                           subject_ids=subject_ids,
                           mean_std_file=config['mean_std_file'])
    data_file_opened = open_data_file(config["data_file"])

    # try/finally guarantees the data file is closed even when model
    # loading, generator construction or training raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        # get training and validation generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"],
            pred_specific=config['pred_specific'],
            overlap_label=config['overlap_label_generator'],
            for_final_val=config['for_final_val'])

        # run training; a monotonic clock keeps the measured duration immune
        # to wall-clock adjustments.
        start = time.monotonic()
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    logging_file=config['logging_file'])
        print('Training time:', sec2hms(time.monotonic() - start))
    finally:
        data_file_opened.close()
def _build_new_model(model_name, depth_unet, n_base_filters_unet, loss):
    """Instantiate an untrained network selected by ``model_name``.

    Shapes, label count, learning rate, pooling and deconvolution settings
    come from the module-level ``config``. Names other than "unet",
    "densefcn", "denseunet", "resunet" and "seunet" fall back to the
    isensee2017 model.
    """
    if model_name == "unet":
        print("init unet model")
        return unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)

    if model_name == "densefcn":
        print("init densenet model")
        return densefcn_model_3d(
            input_shape=config["input_shape"],
            classes=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            nb_dense_block=5,
            nb_layers_per_block=4,
            early_transition=True,
            dropout_rate=0.2,
            loss_function=loss)

    if model_name == "denseunet":
        print("init denseunet model")
        return dense_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)

    if model_name == "resunet":
        print("init resunet model")
        return res_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)

    if model_name == "seunet":
        print("init seunet model")
        return se_unet_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=depth_unet,
            n_base_filters=n_base_filters_unet,
            loss_function=loss)

    print("init isensee model")
    return isensee2017_model(
        input_shape=config["input_shape"],
        n_labels=config["n_labels"],
        initial_learning_rate=config["initial_learning_rate"],
        loss_function=loss)


def train(overwrite=True,
          crop=True,
          challenge="brats",
          year=2018,
          image_shape="160-160-128",
          is_bias_correction="1",
          is_normalize="z",
          is_denoise="0",
          is_hist_match="0",
          is_test="1",
          depth_unet=4,
          n_base_filters_unet=16,
          model_name="unet",
          patch_shape="128-128-128",
          is_crf="0",
          batch_size=1,
          loss="weighted"):
    """Prepare BraTS data, build or load a 3D segmentation model and train it.

    Paths for the HDF5 data, the train/valid/test id files and the model file
    are derived by ``get_training_h5_paths`` from the options below and stored
    in the module-level ``config``. The data is (re)generated with
    ``prepare_data`` when ``overwrite`` is set or the HDF5 file is missing.

    Args:
        overwrite: Regenerate data and build a new model even if files exist.
        model_name: "unet", "densefcn", "denseunet", "resunet" or "seunet";
            any other value selects the isensee2017 model.
        image_shape, patch_shape: dash-separated dimension strings such as
            the defaults; ``patch_shape`` is parsed by ``get_shape_from_string``.
        is_test: "0" enables Comet experiment logging; any other value
            disables it.
        batch_size: Used for both training and validation batches.
        loss: Passed as the loss function to the model builders.
        The remaining options are forwarded to path derivation and data
        preparation.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite,
                     crop=crop,
                     challenge=challenge,
                     year=year,
                     image_shape=image_shape,
                     is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize,
                     is_denoise=is_denoise,
                     is_hist_match=is_hist_match,
                     is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # try/finally guarantees the data file is closed even when generator
    # construction, model creation or training raises.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=batch_size,
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=loss)
        else:
            # Each model_name maps to exactly one builder; previously a stray
            # `if` (instead of `elif`) for "seunet" let its `else` replace
            # every other choice with the isensee model.
            model = _build_new_model(model_name, depth_unet,
                                     n_base_filters_unet, loss)

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        if is_test == "0":
            # NOTE(review): hard-coded Comet API key committed to source; move
            # it to configuration or an environment variable and rotate it.
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="train",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if experiment is not None:
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()
Exemple #8
0
def main(overwrite=False):
    """Prepare data, load or create the isensee2017 model, export its summary and diagram, then train.

    Besides training, this writes the model summary to
    ``isensemodel_original.txt`` and an architecture diagram to
    ``isensemodel_original.png``, both relative to the working directory.

    Args:
        overwrite: When True, the HDF5 file at ``config["data_file"]`` is
            rewritten and a new model is created even if
            ``config["model_file"]`` exists. The flag is also forwarded to
            ``get_training_and_validation_generators``.
    """
    # Convert input images into an HDF5 file unless one already exists.
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # try/finally guarantees the data file is closed even when model
    # creation, summary/plot export or training raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                n_base_filters=config["n_base_filters"])

        with open('isensemodel_original.txt', 'w') as fh:
            # print_fn receives each summary line; write it to the file instead of stdout.
            model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

        # Save a diagram of the model architecture.
        plot_model(model, to_file="isensemodel_original.png", show_shapes=True)

        # get training and validation generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
Exemple #9
0
def main(config=None):
    """Prepare data, load or build (and compile) a segmentation model from ``config``, and train it.

    Args:
        config: Settings mapping (required despite the ``None`` default).
            ``config["overwrite"]`` controls whether the HDF5 data file is
            rewritten and whether an existing ``config["model_file"]`` is
            ignored. ``config["model_name"]`` selects "3d_unet_residual" or
            "attention_unet" when a new model is built.

    Raises:
        TypeError: If ``config`` is None.
        ValueError: If a new model must be built and ``config["model_name"]``
            is not one of the supported names.
    """
    if config is None:
        raise TypeError("main() requires a config mapping")
    overwrite = config['overwrite']

    # Convert input images into an HDF5 file unless one already exists.
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # try/finally guarantees the data file is closed even when model
    # creation (including an unsupported model_name) or training raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            # config["model_name"] is ignored here: training continues with
            # whatever architecture the saved model has (3d_unet_residual or
            # attention_unet), so be careful when reusing a model file.
            model = load_old_model(config, re_compile=False)
        else:
            model_name = config["model_name"]
            if model_name == "3d_unet_residual":
                build = isensee2017_model  # 3D U-Net residual model
            elif model_name == "attention_unet":
                build = attention_unet_model  # attention U-Net model
            else:
                raise ValueError(
                    'Look at field model_name in config.json! This field can be '
                    'either 3d_unet_residual or attention_unet (got {!r}).'.format(model_name))

            model = build(input_shape=config["input_shape"],
                          n_labels=config["n_labels"],
                          n_base_filters=config["n_base_filters"],
                          activation_name='softmax')
            # Optimizer class is looked up by name in `opts`, loss/metrics by
            # name in `module_metric`.
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        model.summary()

        # get training and validation generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            validation_batch_size=config["validation_batch_size"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["optimizer"]["args"]["lr"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    model_best_path=config['model_best'])
    finally:
        data_file_opened.close()
Exemple #10
0
def main(config=None):
    """Prepare data, load or build (and compile) an isensee2017 model from ``config``, and train it.

    Args:
        config: Settings mapping (required despite the ``None`` default).
            ``config["overwrite"]`` controls whether the HDF5 data file is
            rewritten and whether an existing ``config["model_file"]`` is
            ignored.

    Raises:
        TypeError: If ``config`` is None.
    """
    if config is None:
        raise TypeError("main() requires a config mapping")
    overwrite = config['overwrite']

    # Convert input images into an HDF5 file unless one already exists.
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           norm_type=config['normalization_type'])
    data_file_opened = open_data_file(config["data_file"])

    # try/finally guarantees the data file is closed even when model
    # creation, generator construction or training raises.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config, re_compile=False)
        else:
            # instantiate new model
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      n_base_filters=config["n_base_filters"],
                                      activation_name='softmax')

            # Optimizer class is looked up by name in `opts`, loss/metrics by
            # name in `module_metric`.
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        # get training and validation generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            validation_batch_size=config["validation_batch_size"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["optimizer"]["args"]["lr"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"],
                    model_best_path=config['model_best'])
    finally:
        data_file_opened.close()
Exemple #11
0
    def main(self, overwrite_data=True, overwrite_model=True):
        """Write the HDF5 training data if needed, load or build the model, and train it.

        Args:
            overwrite_data: When True, the HDF5 data file is rewritten even if
                it exists; the flag is also forwarded to
                ``get_training_and_validation_generators``.
            overwrite_model: When False and ``self.config.model_file`` exists,
                that model is loaded instead of building a new one.
        """
        # Convert input images into an HDF5 file unless one already exists.
        if overwrite_data or not os.path.exists(self.config.data_file):
            training_files, subject_ids = self.fetch_training_data_files(
                return_subject_ids=True)
            write_data_to_file(training_files,
                               self.config.data_file,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids)
        else:
            print(
                "Reusing previously written data file. Set overwrite_data to True to overwrite this file."
            )

        data_file_opened = open_data_file(self.config.data_file)

        # try/finally guarantees the data file is closed even when model
        # creation, generator construction or training raises.
        try:
            if not overwrite_model and os.path.exists(self.config.model_file):
                model = load_old_model(self.config.model_file)
            else:
                # instantiate new model; the second return value (context
                # output name) is not needed here.
                model, _ = isensee2017_model(
                    input_shape=self.config.input_shape,
                    n_labels=self.config.n_labels,
                    initial_learning_rate=self.config.initial_learning_rate,
                    n_base_filters=self.config.n_base_filters,
                    loss_function=self.config.loss_function,
                    shortcut=self.config.shortcut)

            # get training and validation generators
            train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
                data_file_opened,
                batch_size=self.config.batch_size,
                data_split=self.config.validation_split,
                overwrite_data=overwrite_data,
                validation_keys_file=self.config.validation_file,
                training_keys_file=self.config.training_file,
                n_labels=self.config.n_labels,
                labels=self.config.labels,
                patch_shape=self.config.patch_shape,
                validation_batch_size=self.config.validation_batch_size,
                validation_patch_overlap=self.config.validation_patch_overlap,
                training_patch_overlap=self.config.training_patch_overlap,
                training_patch_start_offset=self.config.training_patch_start_offset,
                permute=self.config.permute,
                augment=self.config.augment,
                skip_blank=self.config.skip_blank,
                augment_flip=self.config.flip,
                augment_distortion_factor=self.config.distort)

            # run training
            train_model(model=model,
                        model_file=self.config.model_file,
                        training_generator=train_generator,
                        validation_generator=validation_generator,
                        steps_per_epoch=n_train_steps,
                        validation_steps=n_validation_steps,
                        initial_learning_rate=self.config.initial_learning_rate,
                        learning_rate_drop=self.config.learning_rate_drop,
                        learning_rate_patience=self.config.patience,
                        early_stopping_patience=self.config.early_stop,
                        n_epochs=self.config.epochs,
                        niseko=self.config.niseko)
        finally:
            data_file_opened.close()
def main(overwrite=False):
    """Train (or fine-tune) an ``isensee2017_model`` using settings from ``config``.

    Steps:
      1. Build the HDF5 data file from the training images when it does not
         exist yet or when ``overwrite`` is true.
      2. Load the existing model file unless ``overwrite`` is true; otherwise
         instantiate a fresh ``isensee2017_model``. When an old model is loaded
         and ``config['freeze_encoder']`` is truthy, every layer before
         ``up_sampling3d_1`` is made non-trainable and the model is recompiled.
      3. Build training/validation generators and run ``train_model``.

    The opened data file is closed even if any later step raises.

    :param overwrite: if True, regenerate the data file and ignore any existing
        model file; the flag is also forwarded to
        ``get_training_and_validation_generators``.
    :raises ValueError: if ``freeze_encoder`` is enabled but the loaded model
        has no layer named ``up_sampling3d_1``.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    try:
        if not overwrite and os.path.exists(config["model_file"]):
            model = load_old_model(config["model_file"])

            # Optional fine-tuning mode; a missing key means "do not freeze".
            if config.get('freeze_encoder', False):
                layer_names = [layer.name for layer in model.layers]
                if 'up_sampling3d_1' not in layer_names:
                    raise ValueError(
                        "freeze_encoder is enabled but the loaded model has no "
                        "'up_sampling3d_1' layer marking the end of the encoder")
                # Everything before the first up-sampling layer is treated as
                # the encoder and frozen.
                last_index = layer_names.index('up_sampling3d_1')
                for layer in model.layers[:last_index]:
                    layer.trainable = False
                # Changes to ``trainable`` only take effect after recompiling.
                from keras.optimizers import Adam
                from unet3d.model.isensee2017 import weighted_dice_coefficient_loss
                model.compile(optimizer=Adam(lr=config['initial_learning_rate']),
                              loss=weighted_dice_coefficient_loss)
        else:
            # instantiate new model
            model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                      initial_learning_rate=config["initial_learning_rate"],
                                      n_base_filters=config["n_base_filters"])
        model.summary()

        # get training and testing generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment=config["augment"],
            skip_blank=config["skip_blank"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        # Release the HDF5 handle even when loading/training fails.
        data_file_opened.close()
def train_and_predict():
    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_gtruth_train = load_train_data()

    imgs_train = np.transpose(imgs_train, (0, 4, 1, 2, 3))
    imgs_gtruth_train = np.transpose(imgs_gtruth_train, (0, 4, 1, 2, 3))
    
    print('-'*30)
    print('Loading and preprocessing validation data...')
    print('-'*30)
    
    imgs_val, imgs_gtruth_val  = load_validatation_data()
    imgs_val = np.transpose(imgs_val, (0, 4, 1, 2, 3))
    imgs_gtruth_val = np.transpose(imgs_gtruth_val, (0, 4, 1, 2, 3))
    
    
    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)

   # create a model
    model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                          initial_learning_rate=config["initial_learning_rate"],
                                          n_base_filters=config["n_base_filters"],loss_function=dice_coef_loss)

    model.summary()



    
    #summarize layers
    #print(model.summary())
    # plot graph
    #plot_model(model, to_file='3d_unet.png')
    
    print('-'*30)
    print('Fitting model...')
    print('-'*30)
    
    #============================================================================
    print('training starting..')
    log_filename = 'outputs/' + image_type +'_model_train.csv' 
    
    
    csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True)
    
#    early_stopping = callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=0, mode='min')
    
    #checkpoint_filepath = 'outputs/' + image_type +"_best_weight_model_{epoch:03d}_{val_loss:.4f}.hdf5"
    checkpoint_filepath = 'outputs/' + 'weights.h5'
    
    checkpoint = callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    
    callbacks_list = [csv_log, checkpoint]
    callbacks_list.append(ReduceLROnPlateau(factor=config["learning_rate_drop"], patience=config["patience"],
                                           verbose=True))
    callbacks_list.append(EarlyStopping(verbose=True, patience=config["early_stop"]))

    #============================================================================
    hist = model.fit(imgs_train, imgs_gtruth_train, batch_size=config["batch_size"], nb_epoch=config["n_epochs"], verbose=1, validation_data=(imgs_val,imgs_gtruth_val), shuffle=True, callbacks=callbacks_list) #              validation_split=0.2,
        
     
    model_name = 'outputs/' + image_type + '_model_last'
    model.save(model_name)  # creates a HDF5 file 'my_model.h5'