Example #1
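Adversarial semi-supervised training loop: a segmentation generator and a patch discriminator are trained alternately, and the generator is additionally updated through a combined model that chains it to a frozen copy of the discriminator.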
# presumed imports for this snippet (not shown in the original extract);
# project-local helpers (config, create_data_file, open_data_file,
# get_last_model_path, get_training_and_validation_generators, Scheduler,
# input2discriminator, input2gan, build_dsc) come from the surrounding project
import glob
import os

import numpy as np
from tqdm import tqdm

import fetal_net.metrics
import fetal_net.model
from keras import backend as K
from keras.engine.network import Network
from keras.layers import Activation, Concatenate, Input
from keras.models import Model
from keras.optimizers import Adam


def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        create_data_file(config)

    data_file_opened = open_data_file(config["data_file"])

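    # resolve the segmentation and discriminator loss functions by name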
    seg_loss_func = getattr(fetal_net.metrics, config['loss'])
    dis_loss_func = getattr(fetal_net.metrics, config['dis_loss'])

    # instantiate new model
    seg_model_func = getattr(fetal_net.model, config['model_name'])
    gen_model = seg_model_func(
        input_shape=config["input_shape"],
        initial_learning_rate=config["initial_learning_rate"],
        dropout_rate=config['dropout_rate'],
        loss_function=seg_loss_func,
        mask_shape=None if config["weight_mask"] is None else config["input_shape"],
        old_model_path=config['old_model'])

    dis_model_func = getattr(fetal_net.model, config['dis_model_name'])
    dis_model = dis_model_func(
        input_shape=[config["input_shape"][0] + config["n_labels"]] +
                    config["input_shape"][1:],
        initial_learning_rate=config["initial_learning_rate"],
        dropout_rate=config['dropout_rate'],
        loss_function=dis_loss_func)

    if not overwrite \
            and len(glob.glob(config["model_file"] + 'g_*.h5')) > 0:
        # dis_model_path = get_last_model_path(config["model_file"] + 'dis_')
        gen_model_path = get_last_model_path(config["model_file"] + 'g_')
        # print('Loading dis model from: {}'.format(dis_model_path))
        print('Loading gen model from: {}'.format(gen_model_path))
        # dis_model = load_old_model(dis_model_path)
        # gen_model = load_old_model(gen_model_path)
        # dis_model.load_weights(dis_model_path)
        gen_model.load_weights(gen_model_path)

    gen_model.summary()
    dis_model.summary()

    # Build "frozen discriminator"
    frozen_dis_model = Network(dis_model.inputs,
                               dis_model.outputs,
                               name='frozen_discriminator')
    frozen_dis_model.trainable = False
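    # trainable=False is set before the combined model below is compiled, so
    # generator updates leave the discriminator weights untouched; dis_model
    # itself (compiled earlier) is still trained directly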

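    # chain the generator into the frozen discriminator; the identity
    # Activation layers only give the outputs stable names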
    inputs_real = Input(shape=config["input_shape"])
    inputs_fake = Input(shape=config["input_shape"])
    segs_real = Activation(None, name='seg_real')(gen_model(inputs_real))
    segs_fake = Activation(None, name='seg_fake')(gen_model(inputs_fake))
    valid = Activation(None, name='dis')(frozen_dis_model(
        Concatenate(axis=1)([segs_fake, inputs_fake])))
    combined_model = Model(inputs=[inputs_real, inputs_fake],
                           outputs=[segs_real, valid])
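    # joint objective for the compile below: segmentation loss on the real
    # branch plus adversarial binary cross-entropy on the discriminator
    # branch, weighted by gd_loss_ratio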
    combined_model.compile(loss=[seg_loss_func, 'binary_crossentropy'],
                           loss_weights=[1, config["gd_loss_ratio"]],
                           optimizer=Adam(config["initial_learning_rate"]))
    combined_model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # get a separate generator (over the validation split) for the
    # semi-supervised patches
    _, semi_generator, _, _ = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        val_augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # start training
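    # the scheduler fixes how many D and G updates run per round and decays
    # the learning rate when the validation loss plateaus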
    scheduler = Scheduler(config["dis_steps"],
                          config["gen_steps"],
                          init_lr=config["initial_learning_rate"],
                          lr_patience=config["patience"],
                          lr_decay=config["learning_rate_drop"])

    best_loss = np.inf
    for epoch in range(config["n_epochs"]):
        postfix = {'g': None, 'd': None}  # , 'val_g': None, 'val_d': None}
        with tqdm(range(n_train_steps // config["gen_steps"]),
                  dynamic_ncols=True,
                  postfix={
                      'gen': None,
                      'dis': None,
                      'val_gen': None,
                      'val_dis': None,
                      None: None
                  }) as pbar:
            for n_round in pbar:
                # train D
                outputs = np.zeros(len(dis_model.metrics_names))
                for i in range(scheduler.get_dsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(semi_generator)
                    d_x_batch, d_y_batch = input2discriminator(
                        real_patches, real_segs, semi_patches,
                        gen_model.predict(semi_patches,
                                          batch_size=config["batch_size"]),
                        dis_model.output_shape)
                    outputs += dis_model.train_on_batch(d_x_batch, d_y_batch)
                if scheduler.get_dsteps():
                    outputs /= scheduler.get_dsteps()
                    postfix['d'] = build_dsc(dis_model.metrics_names, outputs)
                    pbar.set_postfix(**postfix)

                # train G (freeze discriminator)
                outputs = np.zeros(len(combined_model.metrics_names))
                for i in range(scheduler.get_gsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(validation_generator)
                    g_x_batch, g_y_batch = input2gan(real_patches, real_segs,
                                                     semi_patches,
                                                     dis_model.output_shape)
                    outputs += combined_model.train_on_batch(
                        g_x_batch, g_y_batch)
                outputs /= scheduler.get_gsteps()

                postfix['g'] = build_dsc(combined_model.metrics_names, outputs)
                pbar.set_postfix(**postfix)

            # evaluate on validation set
            dis_metrics = np.zeros(len(dis_model.metrics_names), dtype=float)
            gen_metrics = np.zeros(len(gen_model.metrics_names), dtype=float)
            evaluation_rounds = n_validation_steps
            for n_round in range(evaluation_rounds):  # rounds_for_evaluation:
                val_patches, val_segs = next(validation_generator)

                # D
                if scheduler.get_dsteps() > 0:
                    d_x_test, d_y_test = input2discriminator(
                        val_patches, val_segs, val_patches,
                        gen_model.predict(
                            val_patches,
                            batch_size=config["validation_batch_size"]),
                        dis_model.output_shape)
                    dis_metrics += dis_model.evaluate(
                        d_x_test,
                        d_y_test,
                        batch_size=config["validation_batch_size"],
                        verbose=0)

                # G
                # gen_x_test, gen_y_test = input2gan(val_patches, val_segs, dis_model.output_shape)
                gen_metrics += gen_model.evaluate(
                    val_patches,
                    val_segs,
                    batch_size=config["validation_batch_size"],
                    verbose=0)

            dis_metrics /= float(evaluation_rounds)
            gen_metrics /= float(evaluation_rounds)
            # save the model and weights with the best validation loss
            if gen_metrics[0] < best_loss:
                best_loss = gen_metrics[0]
                print('Saving Model...')
                with open(
                        os.path.join(
                            config["base_dir"],
                            "g_{}_{:.3f}.json".format(epoch, gen_metrics[0])),
                        'w') as f:
                    f.write(gen_model.to_json())
                gen_model.save_weights(
                    os.path.join(
                        config["base_dir"],
                        "g_{}_{:.3f}.h5".format(epoch, gen_metrics[0])))

            postfix['val_d'] = build_dsc(dis_model.metrics_names, dis_metrics)
            postfix['val_g'] = build_dsc(gen_model.metrics_names, gen_metrics)
            # pbar.set_postfix(**postfix)
            print('val_d: ' + postfix['val_d'], end=' | ')
            print('val_g: ' + postfix['val_g'])
            # pbar.refresh()

            # update step sizes, learning rates
            scheduler.update_steps(epoch, gen_metrics[0])
            K.set_value(dis_model.optimizer.lr, scheduler.get_lr())
            K.set_value(combined_model.optimizer.lr, scheduler.get_lr())

    data_file_opened.close()
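The Scheduler used above is project-local. As a minimal sketch of the interface the loop relies on (only the constructor arguments and the four methods actually called; the real class may behave differently), one might write:

# Hypothetical sketch of the Scheduler interface used above; names mirror
# only what the training loop calls, not the project's real implementation.
class Scheduler:
    def __init__(self, dis_steps, gen_steps, init_lr, lr_patience, lr_decay):
        self.dis_steps = dis_steps      # discriminator updates per round
        self.gen_steps = gen_steps      # generator updates per round
        self.lr = init_lr
        self.lr_patience = lr_patience  # epochs to wait before decaying
        self.lr_decay = lr_decay        # multiplicative decay factor
        self.best = float('inf')
        self.wait = 0

    def get_dsteps(self):
        return self.dis_steps

    def get_gsteps(self):
        return self.gen_steps

    def get_lr(self):
        return self.lr

    def update_steps(self, epoch, val_loss):
        # epoch is accepted to match the call site; this sketch keys only on
        # val_loss: decay the learning rate when it stops improving
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.lr_patience:
                self.lr *= self.lr_decay
                self.wait = 0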
Example #2
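Sanity check that two HDF5 data files list the same subject ids in the same order.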
import argparse

from fetal_net.data import open_data_file

parser = argparse.ArgumentParser()
parser.add_argument("--data1_path",
                    help="specifies model path",
                    type=str,
                    required=True)
parser.add_argument("--data2_path",
                    help="specifies model path",
                    type=str,
                    required=True)
opts = parser.parse_args()

ids_1 = open_data_file(opts.data1_path).root.subject_ids
ids_2 = open_data_file(opts.data2_path).root.subject_ids

# zip() alone would silently truncate if one file has extra ids, so compare lengths too
print(len(ids_1) == len(ids_2) and all(i1 == i2 for i1, i2 in zip(ids_1, ids_2)))
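# Example invocation (assuming the snippet is saved as compare_subject_ids.py,
# a hypothetical name):
#   python compare_subject_ids.py --data1_path a.h5 --data2_path b.h5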
Example #3
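Standard training entry point: reload the latest saved model if one exists, otherwise instantiate a new one, then run train_model.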
# presumed imports (not shown in the original extract); config and the helper
# functions (create_data_file, open_data_file, get_last_model_path,
# load_old_model, get_training_and_validation_generators, train_model) come
# from the surrounding project
import glob
import os

import fetal_net.metrics
import fetal_net.model


def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        create_data_file(config)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
        model_path = get_last_model_path(config["model_file"])
        print('Loading model from: {}'.format(model_path))
        model = load_old_model(model_path)
    else:
        # instantiate new model
        loss_func = getattr(fetal_net.metrics, config['loss'])
        model_func = getattr(fetal_net.model, config['model_name'])
        model = model_func(
            input_shape=config["input_shape"],
            initial_learning_rate=config["initial_learning_rate"],
            dropout_rate=config['dropout_rate'],
            loss_function=loss_func,
            # TODO: change to output shape
            mask_shape=None if config["weight_mask"] is None else config["input_shape"],
            old_model_path=config['old_model'])
        # note: this branch is unreachable here, since the outer condition
        # above already loads an existing model in that case
        if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
            model_path = get_last_model_path(config["model_file"])
            print('Loading model from: {}'.format(model_path))
            model.load_weights(model_path)
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                output_folder=config["base_dir"])
    data_file_opened.close()
Example #4
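Same flow as Example #3, but the raw images are first written to the HDF5 data file (with the normalization parameters saved alongside) and config is passed in explicitly.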
# presumed imports (not shown in the original extract); the helper functions
# (fetch_training_data_files, write_data_to_file, open_data_file,
# get_last_model_path, load_old_model, get_training_and_validation_generators,
# train_model) come from the surrounding project
import glob
import json
import os

import fetal_net.metrics
import fetal_net.model


def main_train(config, overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        print("Writing h5 file")
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        _, (mean, std) = write_data_to_file(training_files,
                                            config["data_file"],
                                            subject_ids=subject_ids,
                                            normalize=config['normalization'],
                                            add_pred=config['pred_size'])
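        # persist the normalization parameters so they can be reused later
        # (e.g., when normalizing inference inputs)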
        with open(os.path.join(config["base_dir"], 'norm_params.json'),
                  mode='w') as f:
            json.dump({'mean': mean, 'std': std}, f)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
        model_path = get_last_model_path(config["model_file"])
        print('Loading model from: {}'.format(model_path))
        model = load_old_model(model_path)
    else:
        # instantiate new model
        loss_func = getattr(fetal_net.metrics, config['loss'])
        model_func = getattr(fetal_net.model, config['model_name'])
        model = model_func(
            input_shape=config["input_shape"],
            initial_learning_rate=config["initial_learning_rate"],
            dropout_rate=config['dropout_rate'],
            loss_function=loss_func,
            depth=config['model_params']['depth'],
            n_base_filters=config['model_params']['n_base_filters'],
            old_model_path=config['old_model'],
            truth_index=config['truth_index'],
            truth_size=config['truth_size'])
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        pred_index=config["pred_index"],
        pred_size=config["pred_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                output_folder=config["base_dir"])
    data_file_opened.close()