コード例 #1
0
def main():
    """Run the 2.5D training pipeline end to end.

    Fetches the run arguments from ``get_args.train25d()``, rebinds the
    module-level ``config`` to the result of
    ``path_utils.update_is_augment``, calls ``prepare_data`` when the data
    file is missing or ``args.overwrite`` is set, and finally calls
    ``train``.
    """
    global config

    args = get_args.train25d()
    config = path_utils.update_is_augment(args, config)

    # Only the first of the five returned paths (the data file) is used here.
    training_paths = path_utils.get_training_h5_paths(BRATS_DIR, args)
    data_path, _, _, _, _ = training_paths

    must_prepare = args.overwrite or not os.path.exists(data_path)
    if must_prepare:
        prepare_data(args)

    train(args)
コード例 #2
0
def train(args):
    """Train a 2.5D segmentation model configured by ``args``.

    Resolves the data/model/id-file paths, stores them in the module-level
    ``config`` dict, builds the 2.5D training and validation generators,
    loads an existing model or instantiates a new one, and runs
    ``train_model``. The opened data file is always closed, even when a
    later step raises.

    Args:
        args: run arguments. This function reads ``patch_shape``,
            ``overwrite``, ``batch_size``, ``is_test``, ``model``, ``loss``,
            ``depth_unet`` and ``n_base_filters_unet``.

    Raises:
        ValueError: if a new model has to be built and ``args.model`` is not
            one of "seunet", "unet", "segnet" or "isensee".
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    # Patch-shape specific learning-rate overrides.
    if args.patch_shape in ["160-192-13", "160-192-15", "160-192-17"]:
        config["initial_learning_rate"] = 1e-4
        print("lr updated...")
    if args.patch_shape in ["160-192-3"]:
        config["initial_learning_rate"] = 1e-2
        print("lr updated...")

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below uses the open data file; the finally clause makes sure
    # the handle is released on failure as well as on success.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators25d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test)

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)

        if not args.overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"],
                                   loss_function=args.loss)
        else:
            # Instantiate a new model.
            if args.model == "seunet":
                print("init seunet model")
                model = unet_model_25d(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function=args.loss,
                    is_unet_original=False)
            elif args.model == "unet":
                print("init unet model")
                model = unet_model_25d(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function=args.loss)
            elif args.model == "segnet":
                print("init segnet model")
                model = segnet25d(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function=args.loss)
            elif args.model == "isensee":
                print("init isensee model")
                model = isensee25d_model(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    loss_function=args.loss)
            else:
                # Without this branch an unknown name would only surface as a
                # NameError on ``model`` below.
                raise ValueError("Model is NotImplemented. Please check")

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        if args.is_test == "0":
            # NOTE(review): the API key is hard-coded in source; consider
            # loading it from the environment or a config file instead.
            experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                    project_name="train",
                                    workspace="guusgrimbergen")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()
コード例 #3
0
def finetune(args):
    """Fine-tune an existing model on the data set selected by ``args``.

    Resolves the data/model/id-file paths into the module-level ``config``
    dict, falls back to a baseline model when the resolved model file does
    not exist, builds the generators matching ``args.model_dim`` (3, 25, or
    anything else for 2D), applies ``config_finetune`` on top of ``config``,
    loads the model and, unless the fine-tuned model file already exists,
    runs ``train_model``. The opened data file is always closed, even when a
    later step raises.

    Args:
        args: run arguments. This function reads ``name``, ``patch_shape``,
            ``overwrite``, ``batch_size``, ``model_dim``, ``model``,
            ``is_test``, ``loss`` and ``weight_tv_to_main_loss``.

    Raises:
        ValueError: if the resolved model file does not exist and
            ``get_model_baseline_path`` returns ``None``.
        Exception: if the model file selected for loading does not exist.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    # "0" means "no explicit model path given".
    if args.name != "0":
        model_path = args.name

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    folder = os.path.join(BRATS_DIR, "database", "model", "base")

    if not os.path.exists(config["model_file"]):
        model_baseline_path = get_model_baseline_path(folder=folder, args=args)
        if model_baseline_path is None:
            raise ValueError("can not find baseline model. Please check")
        config["model_file"] = model_baseline_path

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below uses the open data file; the finally clause makes sure
    # the handle is released on failure as well as on success.
    try:
        make_dir(config["training_file"])

        print_section("get training and testing generators")
        # Pick the generator factory for the requested dimensionality, plus
        # the keyword arguments that only that factory receives.
        if args.model_dim == 3:
            from unet3d.generator import (
                get_training_and_validation_and_testing_generators as
                make_generators)
            extra_kwargs = {
                "is_create_patch_index_list_original":
                config["is_create_patch_index_list_original"],
            }
        elif args.model_dim == 25:
            from unet25d.generator import (
                get_training_and_validation_and_testing_generators25d as
                make_generators)
            extra_kwargs = {"is_test": args.is_test}
        else:
            from unet2d.generator import (
                get_training_and_validation_and_testing_generators2d as
                make_generators)
            extra_kwargs = {"is_test": args.is_test}

        train_generator, validation_generator, n_train_steps, n_validation_steps = make_generators(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            **extra_kwargs)

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        print(">> update config file")
        config.update(config_finetune)
        if not os.path.exists(config["model_file"]):
            raise Exception("{} model file not found. Please try again".format(
                config["model_file"]))

        from unet3d.utils.model_utils import generate_model
        print(">> load old and generate model")
        model = generate_model(
            config["model_file"],
            initial_learning_rate=config["initial_learning_rate"],
            loss_function=args.loss,
            weight_tv_to_main_loss=args.weight_tv_to_main_loss)
        model.summary()

        print("-" * 60)
        print("# start finetuning")
        print("-" * 60)

        print("Number of training steps: ", n_train_steps)
        print("Number of validation steps: ", n_validation_steps)

        # Second path lookup: only the model path is used, as the output
        # location of the fine-tuned model.
        _, _, _, _, model_path = get_training_h5_paths(
            brats_dir=BRATS_DIR,
            args=args,
            dir_read_write="finetune",
            is_finetune=True)

        config["model_file"] = model_path

        if os.path.exists(config["model_file"]):
            print("{} existed. Will skip!!!".format(config["model_file"]))
        else:
            if args.is_test == "1":
                config["n_epochs"] = 5

            if args.is_test == "0":
                # NOTE(review): the API key is hard-coded in source; consider
                # loading it from the environment or a config file instead.
                experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                        project_name="finetune",
                                        workspace="vuhoangminh")
            else:
                experiment = None

            if args.model_dim == 2 and args.model == "isensee":
                config["initial_learning_rate"] = 1e-7

            print(config["initial_learning_rate"],
                  config["learning_rate_drop"])
            print("data file:", config["data_file"])
            print("model file:", config["model_file"])
            print("training file:", config["training_file"])
            print("validation file:", config["validation_file"])
            print("testing file:", config["testing_file"])

            train_model(experiment=experiment,
                        model=model,
                        model_file=config["model_file"],
                        training_generator=train_generator,
                        validation_generator=validation_generator,
                        steps_per_epoch=n_train_steps,
                        validation_steps=n_validation_steps,
                        initial_learning_rate=config["initial_learning_rate"],
                        learning_rate_drop=config["learning_rate_drop"],
                        learning_rate_patience=config["patience"],
                        early_stopping_patience=config["early_stop"],
                        n_epochs=config["n_epochs"])

            if args.is_test == "0":
                experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()
コード例 #4
0
def train(overwrite=True,
          crop=True,
          challenge="brats",
          year=2018,
          image_shape="160-160-128",
          is_bias_correction="1",
          is_normalize="z",
          is_denoise="0",
          is_hist_match="0",
          is_test="1",
          depth_unet=4,
          n_base_filters_unet=16,
          model_name="unet",
          patch_shape="128-128-128",
          is_crf="0",
          batch_size=1,
          loss="weighted"):
    """Train a 3D segmentation model from explicit keyword settings.

    Resolves the data/model/id-file paths (with ``model_dim=3``), stores them
    in the module-level ``config`` dict, prepares the data file when needed,
    builds the training and validation generators, loads an existing model
    or instantiates a new one, and runs ``train_model``. The opened data
    file is always closed, even when a later step raises.

    Args:
        overwrite: when truthy, data is re-prepared and a new model is built
            even if the files already exist.
        crop, challenge, year, image_shape, is_bias_correction, is_normalize,
            is_denoise, is_hist_match: forwarded unchanged to
            ``get_training_h5_paths`` and ``prepare_data``.
        is_test: forwarded to both helpers; the string "0" additionally
            enables experiment logging.
        depth_unet: ``depth`` passed to the U-Net style constructors.
        n_base_filters_unet: ``n_base_filters`` passed to the U-Net style
            constructors.
        model_name: one of "unet", "densefcn", "denseunet", "resunet" or
            "seunet"; any other value builds the isensee model.
        patch_shape: patch shape string parsed by ``get_shape_from_string``.
        is_crf: forwarded to ``get_training_h5_paths``.
        batch_size: batch size for both training and validation.
        loss: loss function identifier passed to the model constructors.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR,
        overwrite=overwrite,
        crop=crop,
        challenge=challenge,
        year=year,
        image_shape=image_shape,
        is_bias_correction=is_bias_correction,
        is_normalize=is_normalize,
        is_denoise=is_denoise,
        is_hist_match=is_hist_match,
        is_test=is_test,
        model_name=model_name,
        depth_unet=depth_unet,
        n_base_filters_unet=n_base_filters_unet,
        patch_shape=patch_shape,
        is_crf=is_crf,
        loss=loss,
        model_dim=3)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    if overwrite or not os.path.exists(data_path):
        prepare_data(overwrite=overwrite,
                     crop=crop,
                     challenge=challenge,
                     year=year,
                     image_shape=image_shape,
                     is_bias_correction=is_bias_correction,
                     is_normalize=is_normalize,
                     is_denoise=is_denoise,
                     is_hist_match=is_hist_match,
                     is_test=is_test)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below uses the open data file; the finally clause makes sure
    # the handle is released on failure as well as on success.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators(
            data_file_opened,
            batch_size=batch_size,
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        if not overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            model = generate_model(config["model_file"], loss_function=loss)
        else:
            # Instantiate a new model. This must be a single if/elif chain:
            # a separate ``if`` for "seunet" would let the trailing ``else``
            # replace every other model with the isensee one.
            if model_name == "unet":
                print("init unet model")
                model = unet_model_3d(
                    input_shape=config["input_shape"],
                    pool_size=config["pool_size"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=depth_unet,
                    n_base_filters=n_base_filters_unet,
                    loss_function=loss)

            elif model_name == "densefcn":
                print("init densenet model")
                model = densefcn_model_3d(
                    input_shape=config["input_shape"],
                    classes=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    nb_dense_block=5,
                    nb_layers_per_block=4,
                    early_transition=True,
                    dropout_rate=0.2,
                    loss_function=loss)

            elif model_name == "denseunet":
                print("init denseunet model")
                model = dense_unet_3d(
                    input_shape=config["input_shape"],
                    pool_size=config["pool_size"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=depth_unet,
                    n_base_filters=n_base_filters_unet,
                    loss_function=loss)

            elif model_name == "resunet":
                print("init resunet model")
                model = res_unet_3d(
                    input_shape=config["input_shape"],
                    pool_size=config["pool_size"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=depth_unet,
                    n_base_filters=n_base_filters_unet,
                    loss_function=loss)

            elif model_name == "seunet":
                print("init seunet model")
                model = se_unet_3d(
                    input_shape=config["input_shape"],
                    pool_size=config["pool_size"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=depth_unet,
                    n_base_filters=n_base_filters_unet,
                    loss_function=loss)

            else:
                # Any other name falls back to the isensee architecture.
                print("init isensee model")
                model = isensee2017_model(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    loss_function=loss)

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        if is_test == "0":
            # NOTE(review): the API key is hard-coded in source; consider
            # loading it from the environment or a config file instead.
            experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                                    project_name="train",
                                    workspace="vuhoangminh")
        else:
            experiment = None

        print(config["initial_learning_rate"], config["learning_rate_drop"])
        print("data file:", config["data_file"])
        print("model file:", config["model_file"])
        print("training file:", config["training_file"])
        print("validation file:", config["validation_file"])
        print("testing file:", config["testing_file"])

        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()
コード例 #5
0
def train(args):
    """Train a 2D segmentation model configured by ``args``.

    Resolves the data/model/id-file paths, stores them in the module-level
    ``config`` dict, selects the generator data type from the model name,
    builds the 2D training and validation generators, loads an existing
    model or instantiates a new one, and runs ``train_model``. The opened
    data file is always closed, even when a later step raises.

    Args:
        args: run arguments. This function reads ``patch_shape``, ``model``,
            ``overwrite``, ``batch_size``, ``is_test``, ``loss``,
            ``depth_unet`` and ``n_base_filters_unet``. ``args.loss`` is
            overwritten with "casweighted" when an existing model is loaded
            and the model name contains "casnet".

    Raises:
        ValueError: if a new model has to be built and ``args.model`` is not
            one of the supported names.
    """
    data_path, trainids_path, validids_path, testids_path, model_path = get_training_h5_paths(
        brats_dir=BRATS_DIR, args=args)

    config["data_file"] = data_path
    config["model_file"] = model_path
    config["training_file"] = trainids_path
    config["validation_file"] = validids_path
    config["testing_file"] = testids_path
    config["patch_shape"] = get_shape_from_string(args.patch_shape)
    config["input_shape"] = tuple([config["nb_channels"]] +
                                  list(config["patch_shape"]))

    # The model family decides which kind of data the generator produces.
    if "casnet" in args.model:
        config["data_type_generator"] = 'cascaded'
    elif "sepnet" in args.model:
        config["data_type_generator"] = 'separated'
    else:
        config["data_type_generator"] = 'combined'

    if args.overwrite or not os.path.exists(data_path):
        prepare_data(args)

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    # Everything below uses the open data file; the finally clause makes sure
    # the handle is released on failure (e.g. unknown model name) as well as
    # on success.
    try:
        print_section("get training and testing generators")
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_and_testing_generators2d(
            data_file_opened,
            batch_size=args.batch_size,
            data_split=config["validation_split"],
            overwrite=args.overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            testing_keys_file=config["testing_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=args.batch_size,
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            augment_flipud=config["augment_flipud"],
            augment_fliplr=config["augment_fliplr"],
            augment_elastic=config["augment_elastic"],
            augment_rotation=config["augment_rotation"],
            augment_shift=config["augment_shift"],
            augment_shear=config["augment_shear"],
            augment_zoom=config["augment_zoom"],
            n_augment=config["n_augment"],
            skip_blank=config["skip_blank"],
            is_test=args.is_test,
            data_type_generator=config["data_type_generator"])

        print("-" * 60)
        print("# Load or init model")
        print("-" * 60)
        # The model input shape is the configured shape without its last
        # dimension.
        config["input_shape"] = config["input_shape"][:-1]
        if not args.overwrite and os.path.exists(config["model_file"]):
            print("load old model")
            from unet3d.utils.model_utils import generate_model
            if "casnet" in args.model:
                args.loss = "casweighted"
            model = generate_model(config["model_file"],
                                   loss_function=args.loss)
        else:
            # Instantiate a new model.
            if args.model == "isensee":
                print("init isensee model")
                model = isensee2d_model(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    loss_function=args.loss)
            elif args.model == "unet":
                print("init unet model")
                model = unet_model_2d(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function=args.loss)
            elif args.model == "segnet":
                print("init segnet model")
                model = segnet2d(
                    input_shape=config["input_shape"],
                    n_labels=config["n_labels"],
                    initial_learning_rate=config["initial_learning_rate"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function=args.loss)
            elif args.model == "casnet_v1":
                print("init casnet_v1 model")
                model = casnet_v1(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function="casweighted")
            elif args.model == "casnet_v2":
                print("init casnet_v2 model")
                model = casnet_v2(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function="casweighted")
            elif args.model == "casnet_v3":
                print("init casnet_v3 model")
                model = casnet_v3(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function="casweighted")
            elif args.model == "casnet_v4":
                print("init casnet_v4 model")
                model = casnet_v4(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet)
            elif args.model == "casnet_v5":
                print("init casnet_v5 model")
                model = casnet_v5(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet,
                    loss_function="casweighted")
            elif args.model == "casnet_v6":
                print("init casnet_v6 model")
                model = casnet_v6(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet)
            elif args.model == "casnet_v7":
                print("init casnet_v7 model")
                model = casnet_v7(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet)
            elif args.model == "casnet_v8":
                print("init casnet_v8 model")
                model = casnet_v8(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet)
            elif args.model == "sepnet_v1":
                print("init sepnet_v1 model")
                model = sepnet_v1(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet)
            elif args.model == "sepnet_v2":
                print("init sepnet_v2 model")
                model = sepnet_v2(
                    input_shape=config["input_shape"],
                    initial_learning_rate=config["initial_learning_rate"],
                    deconvolution=config["deconvolution"],
                    depth=args.depth_unet,
                    n_base_filters=args.n_base_filters_unet)
            else:
                raise ValueError("Model is NotImplemented. Please check")

        model.summary()

        print("-" * 60)
        print("# start training")
        print("-" * 60)

        if args.is_test == "0":
            # NOTE(review): the API key is hard-coded in source; consider
            # loading it from the environment or a config file instead.
            experiment = Experiment(api_key="34T3kJ5CkXUtKAbhI6foGNFBL",
                                    project_name="train",
                                    workspace="guusgrimbergen")
        else:
            experiment = None

        # NOTE(review): this override is applied after the model was built
        # above, so only train_model receives 1e-6 -- confirm this is intended.
        if args.model == "isensee":
            config["initial_learning_rate"] = 1e-6
        print(config["initial_learning_rate"], config["learning_rate_drop"])
        train_model(experiment=experiment,
                    model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])

        if args.is_test == "0":
            experiment.log_parameters(config)
    finally:
        data_file_opened.close()

    from keras import backend as K
    K.clear_session()