Beispiel #1
0
def main():
    """Parse the 2.5D training arguments, prepare the data if needed, train.

    Rebinds the module-level ``config`` to the value returned by
    ``path_utils.update_is_augment`` before anything else runs.
    """
    global config

    args = get_args.train25d()
    config = path_utils.update_is_augment(args, config)

    # Only the first returned path (the data file) is needed here.
    data_path, _trainids, _validids, _testids, _model = \
        path_utils.get_training_h5_paths(BRATS_DIR, args)

    must_prepare = args.overwrite or not os.path.exists(data_path)
    if must_prepare:
        prepare_data(args)

    train(args)
Beispiel #2
0
def train(args):
    """Train (or resume training of) a 2.5D segmentation model.

    Resolves the data/model paths for ``args`` and stores them in the
    module-level ``config``, builds the 2.5D training and validation
    generators, loads an existing model or creates a new one, and hands
    everything to ``train_model``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``patch_shape``, ``overwrite``, ``batch_size``, ``is_test``,
            ``model``, ``depth_unet``, ``n_base_filters_unet`` and ``loss``.

    Raises:
        ValueError: If a new model has to be created and ``args.model`` is
            not one of "seunet", "unet", "segnet" or "isensee".
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    # Patch-shape specific learning-rate overrides.
    if args.patch_shape in ["160-192-13", "160-192-15", "160-192-17"]:
        config["initial_learning_rate"] = 1e-4
        print("lr updated...")
    if args.patch_shape in ["160-192-3"]:
        config["initial_learning_rate"] = 1e-2
        print("lr updated...")

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below works on the opened data file; the finally clause
    # guarantees it is closed even when model creation or training fails.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)

        if not args.overwrite and os.path.exists(config["model_file"]):
            # Resume from the model file already on disk.
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"],
                                   loss_function=args.loss)
        elif args.model == "seunet":
            # Same builder as "unet"; only is_unet_original differs.
            print("init seunet model")
            model = unet_model_25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss,
                is_unet_original=False)
        elif args.model == "unet":
            print("init unet model")
            model = unet_model_25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "segnet":
            print("init segnet model")
            model = segnet25d(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                depth=args.depth_unet,
                n_base_filters=args.n_base_filters_unet,
                loss_function=args.loss)
        elif args.model == "isensee":
            print("init isensee model")
            model = isensee25d_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=args.loss)
        else:
            # Previously this fell through with `model` unbound and failed
            # later at model.summary() with an unrelated error.
            raise ValueError(
                "unknown 2.5D model '{}'. Expected one of: seunet, unet, "
                "segnet, isensee".format(args.model))

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        # An experiment tracker is only created for real (non-test) runs.
        # NOTE(review): the API key is hard-coded in source; it should be
        # moved to configuration or the environment and rotated.
        if args.is_test == "0":
            experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                    project_name="train",
                                    workspace="guusgrimbergen")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()
def train(overwrite=True,
          crop=True,
          challenge="brats",
          year=2018,
          image_shape="160-160-128",
          is_bias_correction="1",
          is_normalize="z",
          is_denoise="0",
          is_hist_match="0",
          is_test="1",
          depth_unet=4,
          n_base_filters_unet=16,
          model_name="unet",
          patch_shape="128-128-128",
          is_crf="0",
          batch_size=1,
          loss="weighted"):
    """Train (or resume training of) a 3D segmentation model.

    Resolves the data/model paths for the given options and stores them in
    the module-level ``config``, prepares the data file when required,
    builds the training and validation generators, loads an existing model
    or creates a new one, and hands everything to ``train_model``.

    Args:
        overwrite: When true, the data file is re-prepared and a new model
            is created even if a model file already exists.
        crop, challenge, year, image_shape, is_bias_correction,
        is_normalize, is_denoise, is_hist_match, is_test: Forwarded to
            ``get_training_h5_paths`` and ``prepare_data``. ``is_test ==
            "0"`` additionally enables experiment logging.
        depth_unet: ``depth`` passed to the U-Net style model builders.
        n_base_filters_unet: ``n_base_filters`` passed to the U-Net style
            model builders.
        model_name: One of "unet", "densefcn", "denseunet", "resunet" or
            "seunet"; any other value selects the isensee2017 model.
        patch_shape: Dash-separated shape string parsed with
            ``get_shape_from_string``.
        is_crf: Forwarded to ``get_training_h5_paths``.
        batch_size: Batch size for both training and validation.
        loss: Loss function identifier passed to the model builders.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite,
                     crop=crop,
                     challenge=challenge,
                     year=year,
                     image_shape=image_shape,
                     is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize,
                     is_denoise=is_denoise,
                     is_hist_match=is_hist_match,
                     is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below works on the opened data file; the finally clause
    # guarantees it is closed even when model creation or training fails.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=batch_size,
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            # Resume from the model file already on disk.
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=loss)
        elif model_name == "unet":
            print("init unet model")
            model = unet_model_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        elif model_name == "densefcn":
            print("init densenet model")
            model = densefcn_model_3d(
                input_shape=config["input_shape"],
                classes=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                nb_dense_block=5,
                nb_layers_per_block=4,
                early_transition=True,
                dropout_rate=0.2,
                loss_function=loss)
        elif model_name == "denseunet":
            print("init denseunet model")
            model = dense_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        elif model_name == "resunet":
            print("init resunet model")
            model = res_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        # This branch used to be a separate `if`, which started a second
        # if/else chain: every model built above was then thrown away and
        # replaced by the isensee model from the trailing `else`.
        elif model_name == "seunet":
            print("init seunet model")
            model = se_unet_3d(
                input_shape=config["input_shape"],
                pool_size=config["pool_size"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                deconvolution=config["deconvolution"],
                depth=depth_unet,
                n_base_filters=n_base_filters_unet,
                loss_function=loss)
        else:
            # Fallback for every other model name.
            print("init isensee model")
            model = isensee2017_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                initial_learning_rate=config["initial_learning_rate"],
                loss_function=loss)

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        # An experiment tracker is only created for real (non-test) runs.
        # NOTE(review): the API key is hard-coded in source; it should be
        # moved to configuration or the environment and rotated.
        if is_test == "0":
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="train",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()
Beispiel #4
0
def evaluate(args):
    """Collect per-case Dice scores for the model selected by ``args``.

    The scores are read from a previously written CSV when it exists;
    otherwise they are computed from the ``truth.nii.gz`` /
    ``prediction.nii.gz`` pair in each case folder and the CSV is written.

    Args:
        args: Parsed command-line namespace; ``patch_shape`` is read here
            and the whole object is forwarded to ``get_training_h5_paths``.

    Returns:
        tuple: ``(scores, model_path)`` where ``scores`` maps each header
        name ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
        to an array of that column's non-NaN values. ``(None, None)`` is
        returned when the prediction folder does not exist.
    """
    header = ("dice_WholeTumor", "dice_TumorCore", "dice_EnhancingTumor")
    masking_functions = (get_whole_tumor_mask, get_tumor_core_mask,
                         get_enhancing_tumor_mask)

    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    prediction_dir = "/mnt/sda/3DUnetCNN_BRATS/brats"
    # prediction_dir = BRATS_DIR

    # Predictions are looked up under the "wo_augmentation" tree, using the
    # model file name with its "_aug-0" marker removed.
    name = get_filename_without_extension(config["model_file"]).replace(
        "_aug-0", "")
    config["prediction_folder"] = os.path.join(
        prediction_dir, "database/prediction/wo_augmentation", name)

    if not os.path.exists(config["prediction_folder"]):
        # The message says "model", but the check is on the prediction folder.
        print("model not exists. Please check")
        return None, None

    prediction_df_csv_folder = os.path.join(BRATS_DIR,
                                            "database/prediction/csv/")
    make_dir(prediction_df_csv_folder)

    config["prediction_df_csv"] = prediction_df_csv_folder + \
        get_filename_without_extension(config["model_file"]) + ".csv"

    # One sub-directory per predicted case; the directory name is the id.
    case_folders = [
        path for path in glob.glob(
            os.path.join(config["prediction_folder"], "*"))
        if os.path.isdir(path)
    ]
    subject_ids = [os.path.basename(path) for path in case_folders]

    csv_exists = os.path.exists(config["prediction_df_csv"])
    if csv_exists:
        # Reuse the cached scores, keeping only the three Dice columns.
        cached = pd.read_csv(config["prediction_df_csv"])
        rows = np.zeros((len(cached), len(header)))
        for column, score_name in enumerate(header):
            rows[:, column] = cached[score_name].values
    else:
        print("-" * 60)
        print("SUMMARY")
        print("-" * 60)
        print("model file:", config["model_file"])
        print("prediction folder:", config["prediction_folder"])
        print("csv file:", config["prediction_df_csv"])
        print("-" * 60)

        rows = list()
        for case_folder in case_folders:
            truth_file = os.path.join(case_folder, "truth.nii.gz")
            truth = nib.load(truth_file).get_data()
            prediction_file = os.path.join(case_folder, "prediction.nii.gz")
            prediction = nib.load(prediction_file).get_data()
            rows.append(get_score(truth, prediction, masking_functions))

    df = pd.DataFrame.from_records(rows, columns=header, index=subject_ids)
    if not csv_exists:
        df.to_csv(config["prediction_df_csv"])

    # Drop NaN entries column by column.
    scores = dict()
    for index, score in enumerate(df.columns):
        values = df.values.T[index]
        scores[score] = values[~np.isnan(values)]

    return scores, model_path
def finetune(args):
    """Fine-tune an existing model and save it under the finetune paths.

    Loads the model given by ``args`` (or ``args.name`` when it is not
    "0"), falling back to a baseline model when that file is missing, then
    continues training with the settings from ``config_finetune`` unless
    the fine-tuned model file already exists.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``name``, ``patch_shape``, ``overwrite``, ``model_dim``,
            ``batch_size``, ``is_test``, ``loss``,
            ``weight_tv_to_main_loss`` and ``model``.

    Raises:
        ValueError: If the model file does not exist and no baseline model
            can be found either.
        Exception: If the model file selected for loading does not exist.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    # An explicit model path in args.name overrides the derived one.
    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    folder = os.path.join(BRATS_DIR, "database", "model", "base")

    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("can not find baseline model. Please check")
        config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below works on the opened data file; the finally clause
    # guarantees it is closed even when loading or training fails.
    try:
        # NOTE(review): this passes the training-ids *file* path to
        # make_dir; confirm make_dir creates the parent directory.
        make_dir(config["training_file"])

        print_section("get training and testing generators")
        # Keyword arguments shared by the 3D, 2.5D and 2D generator builders.
        generator_kwargs = dict(
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])

        if args.model_dim == 3:
            from unet3d.generator import \
                get_training_and_validation_and_testing_generators as build_generators
            generator_kwargs["is_create_patch_index_list_original"] = config[
                "is_create_patch_index_list_original"]
        elif args.model_dim == 25:
            from unet25d.generator import \
                get_training_and_validation_and_testing_generators25d as build_generators
            generator_kwargs["is_test"] = args.is_test
        else:
            from unet2d.generator import \
                get_training_and_validation_and_testing_generators2d as build_generators
            generator_kwargs["is_test"] = args.is_test

        train_generator, validation_generator, n_train_steps, n_validation_steps = build_generators(
            data_file_opened, **generator_kwargs)

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        print(">> update config file")
        # The fine-tuning settings are applied only after the generators
        # have been built with the regular configuration.
        config.update(config_finetune)
        if not os.path.exists(config["model_file"]):
            raise Exception("{} model file not found. Please try again".format(
                config["model_file"]))

        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
        model.summary()

        # run training
        print("-" * 60)
        print("# start finetuning")
        print("-" * 60)

        print("Number of training steps: ", n_train_steps)
        print("Number of validation steps: ", n_validation_steps)

        # Resolve the paths again in finetune mode; only the resulting
        # model path is used, as the output file.
        data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
            brats_dir=BRATS_DIR,
            args=args,
            dir_read_write="finetune",
            is_finetune=True)

        config["model_file"] = model_path

        if os.path.exists(config["model_file"]):
            print("{} existed. Will skip!!!".format(config["model_file"]))
        else:
            if args.is_test == "1":
                config["n_epochs"] = 5

            # An experiment tracker is only created for real (non-test) runs.
            # NOTE(review): the API key is hard-coded in source; it should
            # be moved to configuration or the environment and rotated.
            if args.is_test == "0":
                experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                        project_name="finetune",
                                        workspace="vuhoangminh")
            else:
                experiment = None

            if args.model_dim == 2 and args.model == "isensee":
                config["initial_learning_rate"] = 1e-7

            print(config["initial_learning_rate"],
                  config["learning_rate_drop"])
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])

            train_model(experiment=experiment,
                        model=model,
                        model_file=config["model_file"],
                        training_generator=train_generator,
                        validation_generator=validation_generator,
                        steps_per_epoch=n_train_steps,
                        validation_steps=n_validation_steps,
                        initial_learning_rate=config["initial_learning_rate"],
                        learning_rate_drop=config["learning_rate_drop"],
                        learning_rate_patience=config["patience"],
                        early_stopping_patience=config["early_stop"],
                        n_epochs=config["n_epochs"])

            if args.is_test == "0":
                experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()
Beispiel #6
0
def predict(args, prediction_dir="desktop"):
    """Run prediction on the testing cases for the model selected by args.

    Does nothing but print a message when the model file is missing, and
    skips the run (recording the folder in ``list_already_predicted``) when
    ``is_all_cases_predicted`` reports the prediction folder as complete.

    Args:
        args: Parsed command-line namespace; ``patch_shape`` and
            ``model_dim`` are read here.
        prediction_dir: "SERVER" selects the relative "brats" root; any
            other value selects "/mnt/sda/3DUnetCNN_BRATS/brats".

    Raises:
        ValueError: If the model file is missing at prediction time or
            ``args.model_dim`` is not 3, 25 or 2.
    """
    h5_paths = get_training_h5_paths(brats_dir=BRATS_DIR, args=args)
    data_path, trainids_path, validids_path, testids_path, model_path = h5_paths

    if not os.path.exists(model_path):
        print("model not exists. Please check")
        return

    for key, path in (("data_file", data_path),
                      ("model_file", model_path),
                      ("training_file", trainids_path),
                      ("validation_file", validids_path),
                      ("testing_file", testids_path)):
        config[key] = path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = (config["nb_channels"],) + tuple(
        config["patch_shape"])

    root = ("brats" if prediction_dir == "SERVER"
            else "/mnt/sda/3DUnetCNN_BRATS/brats")
    model_stem = get_filename_without_extension(config["model_file"])
    config["prediction_folder"] = os.path.join(
        root, "database/prediction", model_stem)

    if is_all_cases_predicted(config["prediction_folder"],
                              config["testing_file"]):
        print("Already predicted. Skip...")
        list_already_predicted.append(config["prediction_folder"])
        return

    make_dir(config["prediction_folder"])

    separator = "-" * 60
    print(separator)
    print("SUMMARY")
    print(separator)
    for label, key in (("data file:", "data_file"),
                       ("model file:", "model_file"),
                       ("training file:", "training_file"),
                       ("validation file:", "validation_file"),
                       ("testing file:", "testing_file"),
                       ("prediction folder:", "prediction_folder")):
        print(label, config[key])
    print(separator)

    if not os.path.exists(config["model_file"]):
        raise ValueError("can not find model {}. Please check".format(
            config["model_file"]))

    # Pick the prediction backend that matches the model dimensionality.
    if args.model_dim not in (3, 25, 2):
        raise ValueError(
            "dim {} NotImplemented error. Please check".format(
                args.model_dim))
    if args.model_dim == 3:
        from unet3d.prediction import run_validation_cases
    elif args.model_dim == 25:
        from unet25d.prediction import run_validation_cases
    else:
        from unet2d.prediction import run_validation_cases

    run_options = dict(
        validation_keys_file=config["testing_file"],
        model_file=config["model_file"],
        training_modalities=config["training_modalities"],
        labels=config["labels"],
        hdf5_file=config["data_file"],
        output_label_map=True,
        output_dir=config["prediction_folder"])
    run_validation_cases(**run_options)
Beispiel #7
0
def predict(overwrite=True,
            crop=True,
            challenge="brats",
            year=2018,
            image_shape="160-192-128",
            is_bias_correction="1",
            is_normalize="z",
            is_denoise="0",
            is_hist_match="0",
            is_test="1",
            depth_unet=4,
            n_base_filters_unet=16,
            model_name="unet",
            patch_shape="128-128-128",
            is_crf="0",
            batch_size=1,
            loss="minh",
            model_dim=3):
    """Run prediction on the testing cases with a fine-tuned model.

    The paths are resolved in finetune mode (``dir_read_write="finetune"``,
    ``is_finetune=True``). Does nothing but print a message when the model
    file is missing, and skips the run (recording the folder in
    ``list_already_predicted``) when ``is_all_cases_predicted`` reports the
    prediction folder as complete.

    Args:
        overwrite, crop, challenge, year, image_shape, is_bias_correction,
        is_normalize, is_denoise, is_hist_match, is_test, depth_unet,
        n_base_filters_unet, model_name, is_crf, loss: Forwarded to
            ``get_training_h5_paths``.
        patch_shape: Dash-separated shape string parsed with
            ``get_shape_from_string``; also forwarded to the path lookup.
        batch_size: Accepted but not used by this function.
        model_dim: 3, 25 or 2; selects the prediction backend.

    Raises:
        ValueError: If the model file is missing at prediction time or
            ``model_dim`` is not 3, 25 or 2.
    """
    h5_paths = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=model_dim,
        dir_read_write="finetune",
        is_finetune=True)
    data_path, trainids_path, validids_path, testids_path, model_path = h5_paths

    if not os.path.exists(model_path):
        print("model not exists. Please check")
        return

    for key, path in (("data_file", data_path),
                      ("model_file", model_path),
                      ("training_file", trainids_path),
                      ("validation_file", validids_path),
                      ("testing_file", testids_path)):
        config[key] = path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = (config["nb_channels"],) + tuple(
        config["patch_shape"])
    model_stem = get_filename_without_extension(config["model_file"])
    config["prediction_folder"] = os.path.join(
        BRATS_DIR, "database/prediction", model_stem)

    if is_all_cases_predicted(config["prediction_folder"],
                              config["testing_file"]):
        print("Already predicted. Skip...")
        list_already_predicted.append(config["prediction_folder"])
        return

    make_dir(config["prediction_folder"])

    separator = "-" * 60
    print(separator)
    print("SUMMARY")
    print(separator)
    for label, key in (("data file:", "data_file"),
                       ("model file:", "model_file"),
                       ("training file:", "training_file"),
                       ("validation file:", "validation_file"),
                       ("testing file:", "testing_file"),
                       ("prediction folder:", "prediction_folder")):
        print(label, config[key])
    print(separator)

    if not os.path.exists(config["model_file"]):
        raise ValueError("can not find model {}. Please check".format(
            config["model_file"]))

    # Pick the prediction backend that matches the model dimensionality.
    if model_dim not in (3, 25, 2):
        raise ValueError(
            "dim {} NotImplemented error. Please check".format(
                model_dim))
    if model_dim == 3:
        from unet3d.prediction import run_validation_cases
    elif model_dim == 25:
        from unet25d.prediction import run_validation_cases
    else:
        from unet2d.prediction import run_validation_cases

    run_options = dict(
        validation_keys_file=config["testing_file"],
        model_file=config["model_file"],
        training_modalities=config["training_modalities"],
        labels=config["labels"],
        hdf5_file=config["data_file"],
        output_label_map=True,
        output_dir=config["prediction_folder"])
    run_validation_cases(**run_options)
Beispiel #8
0
# Cascaded architectures that are explicitly compiled with the "casweighted"
# loss; the remaining casnet_*/sepnet_* builders fall back to their own
# default loss.
_CASWEIGHTED_MODELS = ("casnet_v1", "casnet_v2", "casnet_v3", "casnet_v5")


def _build_new_2d_model(args):
    """Instantiate a fresh (untrained) 2D network selected by ``args.model``.

    Hyper-parameters are read from the module-level ``config`` and from
    ``args`` (``loss``, ``depth_unet``, ``n_base_filters_unet``).

    Raises:
        ValueError: if ``args.model`` is not a supported architecture name.
    """
    name = args.model

    if name == "isensee":
        print("init isensee model")
        return isensee2d_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss)

    if name == "unet":
        print("init unet model")
        return unet_model_2d(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            depth=args.depth_unet,
            n_base_filters=args.n_base_filters_unet,
            loss_function=args.loss)

    if name == "segnet":
        print("init segnet model")
        return segnet2d(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            depth=args.depth_unet,
            n_base_filters=args.n_base_filters_unet,
            loss_function=args.loss)

    # Cascaded / separated variants: resolve the constructor lazily so that
    # only the selected builder has to be importable.
    if name == "casnet_v1":
        builder = casnet_v1
    elif name == "casnet_v2":
        builder = casnet_v2
    elif name == "casnet_v3":
        builder = casnet_v3
    elif name == "casnet_v4":
        builder = casnet_v4
    elif name == "casnet_v5":
        builder = casnet_v5
    elif name == "casnet_v6":
        builder = casnet_v6
    elif name == "casnet_v7":
        builder = casnet_v7
    elif name == "casnet_v8":
        builder = casnet_v8
    elif name == "sepnet_v1":
        builder = sepnet_v1
    elif name == "sepnet_v2":
        builder = sepnet_v2
    else:
        raise ValueError("Model is NotImplemented. Please check")

    print("init {} model".format(name))
    kwargs = dict(
        input_shape=config["input_shape"],
        initial_learning_rate=config["initial_learning_rate"],
        deconvolution=config["deconvolution"],
        depth=args.depth_unet,
        n_base_filters=args.n_base_filters_unet)
    if name in _CASWEIGHTED_MODELS:
        kwargs["loss_function"] = "casweighted"
    return builder(**kwargs)


def train(args):
    """Train a 2D segmentation model described by ``args``.

    Steps:
      1. Resolve the data/ids/model paths and store them, together with the
         patch and input shapes, in the module-level ``config``.
      2. Pick the generator data type from the model name ("casnet" ->
         cascaded, "sepnet" -> separated, otherwise combined).
      3. Build the HDF5 data file if it is missing or ``args.overwrite`` is
         set, then open it and create the training/validation generators.
      4. Load the existing model file (unless overwriting) or instantiate a
         new network, then run ``train_model``.

    The opened data file is closed and the Keras session cleared even when
    one of the steps raises.

    Args:
        args: parsed command-line namespace. Reads ``patch_shape``, ``model``,
            ``overwrite``, ``batch_size``, ``is_test``, ``loss``,
            ``depth_unet`` and ``n_base_filters_unet``. ``args.loss`` is
            overwritten with "casweighted" when an existing casnet model is
            reloaded.

    Raises:
        ValueError: if a new model must be created and ``args.model`` is not
            a supported architecture name.
    """
    (data_path, trainids_path, validids_path, testids_path,
     model_path) = get_training_h5_paths(brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if "casnet" in args.model:
        config["data_type_generator"] = 'cascaded'
    elif "sepnet" in args.model:
        config["data_type_generator"] = 'separated'
    else:
        config["data_type_generator"] = 'combined'

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    try:
        print_section("get training and testing generators")
        (train_generator, validation_generator, n_train_steps,
         n_validation_steps) = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test,
            data_type_generator=config["data_type_generator"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        # Drop the last entry of the input shape before building the network
        # (presumably the slice axis, since the models here are 2D -- TODO
        # confirm against the generator output).
        config["input_shape"] = config["input_shape"][:-1]

        if not args.overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            if "casnet" in args.model:
                args.loss = "casweighted"
            model = generate_model(config["model_file"],
                                   loss_function=args.loss)
        else:
            # instantiate new model
            model = _build_new_2d_model(args)

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        # Experiment tracking is only enabled for real (non-test) runs.
        track_experiment = args.is_test == "0"
        if track_experiment:
            # NOTE(review): a credential is committed to the source tree.
            # COMET_API_KEY now takes precedence; the literal fallback is kept
            # only for backward compatibility and should be rotated/removed.
            api_key = (os.environ.get("COMET_API_KEY")
                       or "34T3kJ5CkXUtKAbhI6foGNFBL")
            experiment = Experiment(api_key=api_key,
                                    project_name="train",
                                    workspace="guusgrimbergen")
        else:
            experiment = None

        # The isensee model is trained with a fixed, much smaller learning
        # rate; this override happens after the model has been built.
        if args.model == "isensee":
            config["initial_learning_rate"] = 1e-6
        print(config["initial_learning_rate"], config["learning_rate_drop"])

        # run training
        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if track_experiment:
            experiment.log_parameters(config)
    finally:
        # Release the data file and the Keras session on success and on
        # failure alike, so a crashed run does not leak them.
        data_file_opened.close()
        from keras import backend as K
        K.clear_session()