def secondary_prediction(mask, vol, config2, model2_path=None,
                         preprocess_method2=None, norm_params2=None,
                         overlap_factor=0.9, augment2=None, num_augment=32, return_all_preds=False):
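    """Second-stage inference: crop the volume to a padded bounding box around the
    first-stage mask, predict with the second model, and pad the result back to the
    full volume shape."""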
    model2 = load_old_model(get_last_model_path(model2_path), config=config2)
    pred = mask
    bbox_start, bbox_end = find_bounding_box(pred)
    check_bounding_box(pred, bbox_start, bbox_end)
    padding = np.array([16, 16, 8])
    bbox_start = np.maximum(bbox_start - padding, 0)
    bbox_end = np.minimum(bbox_end + padding, mask.shape)
    data = vol.astype(float)[
           bbox_start[0]:bbox_end[0],
           bbox_start[1]:bbox_end[1],
           bbox_start[2]:bbox_end[2]
           ]

    data = preproc_and_norm(data, preprocess_method2, norm_params2)

    prediction = get_prediction(data, model2, augment=augment2, num_augments=num_augment, return_all_preds=return_all_preds,
                                overlap_factor=overlap_factor, config=config2)

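    # pad the ROI prediction back out to the coordinates of the full input volume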
    padding2 = list(zip(bbox_start, np.array(vol.shape) - bbox_end))
    if return_all_preds:
        padding2 = [(0, 0)] + padding2
    print(padding2)
    print(prediction.shape)
    prediction = np.pad(prediction, padding2, mode='constant', constant_values=0)

    return prediction
def main(input_path, output_path, overlap_factor,
         config, model_path, preprocess_method=None, norm_params=None, augment=None, num_augment=0,
         config2=None, model2_path=None, preprocess_method2=None, norm_params2=None, augment2=None, num_augment2=0,
         z_scale=None, xy_scale=None, return_all_preds=False):
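    """Two-stage inference entry point: predict a mask with the first model and, when a
    second config is given, refine the prediction inside the detected region of interest."""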
    print(model_path)
    model = load_old_model(get_last_model_path(model_path), config=config)
    print('Loading nifti from {}...'.format(input_path))
    nifti = read_img(input_path)
    print('Predicting mask...')
    data = nifti.get_fdata().astype(float).squeeze()
    print('original_shape: ' + str(data.shape))
    scan_name = Path(input_path).name.split('.')[0]

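    # optional rescaling before prediction; it is reverted after the prediction below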
    if z_scale is None:
        z_scale = 1.0
    if xy_scale is None:
        xy_scale = 1.0
    if z_scale != 1.0 or xy_scale != 1.0:
        data = ndimage.zoom(data, [xy_scale, xy_scale, z_scale])

    data = preproc_and_norm(data, preprocess_method, norm_params,
                            scale=config.get('scale_data', None),
                            preproc=config.get('preproc', None))

    save_nifti(data, os.path.join(output_path, scan_name + '_data.nii.gz'))

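    # pad every axis by 3 voxels with the minimum intensity; the same margin is cropped off after prediction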
    data = np.pad(data, 3, 'constant', constant_values=data.min())

    print('Shape: ' + str(data.shape))
    prediction = get_prediction(data=data, model=model, augment=augment,
                                num_augments=num_augment, return_all_preds=return_all_preds,
                                overlap_factor=overlap_factor, config=config)
    # unpad
    prediction = prediction[3:-3, 3:-3, 3:-3]

    # revert to original size
    if config.get('scale_data', None) is not None:
        prediction = ndimage.zoom(prediction.squeeze(), np.divide([1, 1, 1], config.get('scale_data', None)), order=0)[..., np.newaxis]

    save_nifti(prediction, os.path.join(output_path, scan_name + '_pred.nii.gz'))

    if z_scale != 1.0 or xy_scale != 1.0:
        prediction = ndimage.zoom(prediction.squeeze(), [1.0 / xy_scale, 1.0 / xy_scale, 1.0 / z_scale], order=1)[..., np.newaxis]

    # if prediction.shape[-1] > 1:
    #    prediction = prediction[..., 1]
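    # optional second stage: binarize the first prediction and refine it inside its ROI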
    if config2 is not None:
        prediction = prediction.squeeze()
        mask = process_pred(prediction, gaussian_std=0.5, threshold=0.5)  # .astype(np.uint8)
        nifti = read_img(input_path)
        prediction = secondary_prediction(mask, vol=nifti.get_fdata().astype(float),
                                          config2=config2, model2_path=model2_path,
                                          preprocess_method2=preprocess_method2, norm_params2=norm_params2,
                                          overlap_factor=overlap_factor, augment2=augment2, num_augment=num_augment2,
                                          return_all_preds=return_all_preds)
        save_nifti(prediction, os.path.join(output_path, scan_name + '_pred_roi.nii.gz'))

    print('Saving to {}'.format(output_path))
    print('Finished.')
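
find_bounding_box and process_pred are project-specific helpers. A minimal sketch of what they are assumed to do here (the actual implementations may differ):

import numpy as np
from scipy import ndimage


def find_bounding_box(mask):
    # assumed: smallest axis-aligned box containing all non-zero voxels
    coords = np.argwhere(mask > 0)
    return coords.min(axis=0), coords.max(axis=0) + 1


def process_pred(pred, gaussian_std=0.5, threshold=0.5):
    # assumed: smooth the soft prediction, then binarize it
    smoothed = ndimage.gaussian_filter(pred, sigma=gaussian_std)
    return (smoothed > threshold).astype(np.uint8)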
Example #3
def run_validation_cases(validation_keys_file,
                         model_file,
                         training_modalities,
                         hdf5_file,
                         patch_shape,
                         output_dir=".",
                         overlap_factor=0,
                         permute=False,
                         prev_truth_index=None,
                         prev_truth_size=None,
                         use_augmentations=False):
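    """Run inference with the last saved model for every index stored in
    ``validation_keys_file`` and collect the per-case results in ``file_names``."""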
    file_names = []
    validation_indices = pickle_load(validation_keys_file)
    model = load_old_model(get_last_model_path(model_file))
    data_file = tables.open_file(hdf5_file, "r")
    for index in validation_indices:
        if 'subject_ids' in data_file.root:
            case_directory = os.path.join(
                output_dir, data_file.root.subject_ids[index].decode('utf-8'))
        else:
            case_directory = os.path.join(output_dir,
                                          "validation_case_{}".format(index))
        file_names.append(
            run_validation_case(data_index=index,
                                output_dir=case_directory,
                                model=model,
                                data_file=data_file,
                                training_modalities=training_modalities,
                                overlap_factor=overlap_factor,
                                permute=permute,
                                patch_shape=patch_shape,
                                prev_truth_index=prev_truth_index,
                                prev_truth_size=prev_truth_size,
                                use_augmentations=use_augmentations))
    data_file.close()
    return file_names
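
get_last_model_path is another project helper. A minimal sketch, assuming it simply returns the newest checkpoint matching a prefix (the real selection rule, e.g. sorting by epoch number in the file name, may differ):

import glob
import os


def get_last_model_path(model_file_prefix):
    # assumed: newest *.h5 checkpoint whose name starts with the given prefix
    candidates = glob.glob(model_file_prefix + '*.h5')
    return max(candidates, key=os.path.getmtime)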
Example #4
def main(overwrite=False):
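    """Adversarial training: a segmentation generator and a patch discriminator are
    trained in alternation, with an extra patch stream (``semi_generator``) used to
    push the generator to fool the discriminator."""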
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        create_data_file(config)

    data_file_opened = open_data_file(config["data_file"])

    seg_loss_func = getattr(fetal_net.metrics, config['loss'])
    dis_loss_func = getattr(fetal_net.metrics, config['dis_loss'])

    # instantiate new model
    seg_model_func = getattr(fetal_net.model, config['model_name'])
    gen_model = seg_model_func(
        input_shape=config["input_shape"],
        initial_learning_rate=config["initial_learning_rate"],
        dropout_rate=config['dropout_rate'],
        loss_function=seg_loss_func,
        mask_shape=None if config["weight_mask"] is None else config["input_shape"],
        old_model_path=config['old_model'])

    dis_model_func = getattr(fetal_net.model, config['dis_model_name'])
    dis_model = dis_model_func(
        input_shape=[config["input_shape"][0] + config["n_labels"]] + config["input_shape"][1:],
        initial_learning_rate=config["initial_learning_rate"],
        dropout_rate=config['dropout_rate'],
        loss_function=dis_loss_func)

    if not overwrite \
            and len(glob.glob(config["model_file"] + 'g_*.h5')) > 0:
        # dis_model_path = get_last_model_path(config["model_file"] + 'dis_')
        gen_model_path = get_last_model_path(config["model_file"] + 'g_')
        # print('Loading dis model from: {}'.format(dis_model_path))
        print('Loading gen model from: {}'.format(gen_model_path))
        # dis_model = load_old_model(dis_model_path)
        # gen_model = load_old_model(gen_model_path)
        # dis_model.load_weights(dis_model_path)
        gen_model.load_weights(gen_model_path)

    gen_model.summary()
    dis_model.summary()

    # Build "frozen discriminator"
    frozen_dis_model = Network(dis_model.inputs,
                               dis_model.outputs,
                               name='frozen_discriminator')
    frozen_dis_model.trainable = False
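    # the frozen wrapper shares its weights with dis_model; marking it non-trainable
    # keeps the discriminator fixed while the combined (generator) model is trained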

    inputs_real = Input(shape=config["input_shape"])
    inputs_fake = Input(shape=config["input_shape"])
    segs_real = Activation(None, name='seg_real')(gen_model(inputs_real))
    segs_fake = Activation(None, name='seg_fake')(gen_model(inputs_fake))
    valid = Activation(None, name='dis')(frozen_dis_model(
        Concatenate(axis=1)([segs_fake, inputs_fake])))
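    # the combined model updates only the generator: it must segment the real patches
    # well and make the frozen discriminator accept the fake stream as valid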
    combined_model = Model(inputs=[inputs_real, inputs_fake],
                           outputs=[segs_real, valid])
    combined_model.compile(loss=[seg_loss_func, 'binary_crossentropy'],
                           loss_weights=[1, config["gd_loss_ratio"]],
                           optimizer=Adam(config["initial_learning_rate"]))
    combined_model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # second set of generators: its validation stream (semi_generator) supplies the extra patches fed to the discriminator
    _, semi_generator, _, _ = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        val_augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # start training
    scheduler = Scheduler(config["dis_steps"],
                          config["gen_steps"],
                          init_lr=config["initial_learning_rate"],
                          lr_patience=config["patience"],
                          lr_decay=config["learning_rate_drop"])
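    # the Scheduler (project-specific) is assumed to control how many discriminator and
    # generator steps run per round and to decay the learning rate when the generator
    # validation loss stops improving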

    best_loss = np.inf
    for epoch in range(config["n_epochs"]):
        postfix = {'g': None, 'd': None}  # , 'val_g': None, 'val_d': None}
        with tqdm(range(n_train_steps // config["gen_steps"]),
                  dynamic_ncols=True,
                  postfix={
                      'gen': None,
                      'dis': None,
                      'val_gen': None,
                      'val_dis': None,
                      None: None
                  }) as pbar:
            for n_round in pbar:
                # train D
                outputs = np.zeros(len(dis_model.metrics_names))
                for i in range(scheduler.get_dsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(semi_generator)
                    d_x_batch, d_y_batch = input2discriminator(
                        real_patches, real_segs, semi_patches,
                        gen_model.predict(semi_patches,
                                          batch_size=config["batch_size"]),
                        dis_model.output_shape)
                    outputs += dis_model.train_on_batch(d_x_batch, d_y_batch)
                if scheduler.get_dsteps():
                    outputs /= scheduler.get_dsteps()
                    postfix['d'] = build_dsc(dis_model.metrics_names, outputs)
                    pbar.set_postfix(**postfix)

                # train G (freeze discriminator)
                outputs = np.zeros(len(combined_model.metrics_names))
                for i in range(scheduler.get_gsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(validation_generator)
                    g_x_batch, g_y_batch = input2gan(real_patches, real_segs,
                                                     semi_patches,
                                                     dis_model.output_shape)
                    outputs += combined_model.train_on_batch(
                        g_x_batch, g_y_batch)
                outputs /= scheduler.get_gsteps()

                postfix['g'] = build_dsc(combined_model.metrics_names, outputs)
                pbar.set_postfix(**postfix)

            # evaluate on validation set
            dis_metrics = np.zeros(len(dis_model.metrics_names), dtype=float)
            gen_metrics = np.zeros(len(gen_model.metrics_names), dtype=float)
            evaluation_rounds = n_validation_steps
            for n_round in range(evaluation_rounds):  # rounds_for_evaluation:
                val_patches, val_segs = next(validation_generator)

                # D
                if scheduler.get_dsteps() > 0:
                    d_x_test, d_y_test = input2discriminator(
                        val_patches, val_segs, val_patches,
                        gen_model.predict(
                            val_patches,
                            batch_size=config["validation_batch_size"]),
                        dis_model.output_shape)
                    dis_metrics += dis_model.evaluate(
                        d_x_test,
                        d_y_test,
                        batch_size=config["validation_batch_size"],
                        verbose=0)

                # G
                # gen_x_test, gen_y_test = input2gan(val_patches, val_segs, dis_model.output_shape)
                gen_metrics += gen_model.evaluate(
                    val_patches,
                    val_segs,
                    batch_size=config["validation_batch_size"],
                    verbose=0)

            dis_metrics /= float(evaluation_rounds)
            gen_metrics /= float(evaluation_rounds)
            # save the model and weights with the best validation loss
            if gen_metrics[0] < best_loss:
                best_loss = gen_metrics[0]
                print('Saving Model...')
                with open(
                        os.path.join(
                            config["base_dir"],
                            "g_{}_{:.3f}.json".format(epoch, gen_metrics[0])),
                        'w') as f:
                    f.write(gen_model.to_json())
                gen_model.save_weights(
                    os.path.join(
                        config["base_dir"],
                        "g_{}_{:.3f}.h5".format(epoch, gen_metrics[0])))

            postfix['val_d'] = build_dsc(dis_model.metrics_names, dis_metrics)
            postfix['val_g'] = build_dsc(gen_model.metrics_names, gen_metrics)
            # pbar.set_postfix(**postfix)
            print('val_d: ' + postfix['val_d'], end=' | ')
            print('val_g: ' + postfix['val_g'])
            # pbar.refresh()

            # update step sizes, learning rates
            scheduler.update_steps(epoch, gen_metrics[0])
            K.set_value(dis_model.optimizer.lr, scheduler.get_lr())
            K.set_value(combined_model.optimizer.lr, scheduler.get_lr())

    data_file_opened.close()
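
The Scheduler used above is project-specific. A minimal sketch of the interface the training loop relies on (get_dsteps, get_gsteps, get_lr, update_steps); the actual scheduling policy may differ:

class Scheduler:
    # illustrative only: fixed step counts, learning rate multiplied by lr_decay
    # whenever the monitored loss has not improved for lr_patience epochs
    def __init__(self, dis_steps, gen_steps, init_lr, lr_patience, lr_decay):
        self.dis_steps, self.gen_steps = dis_steps, gen_steps
        self.lr, self.lr_patience, self.lr_decay = init_lr, lr_patience, lr_decay
        self.best_loss, self.wait = float('inf'), 0

    def get_dsteps(self):
        return self.dis_steps

    def get_gsteps(self):
        return self.gen_steps

    def get_lr(self):
        return self.lr

    def update_steps(self, epoch, loss):
        if loss < self.best_loss:
            self.best_loss, self.wait = loss, 0
        else:
            self.wait += 1
            if self.wait >= self.lr_patience:
                self.lr *= self.lr_decay
                self.wait = 0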
Example #5
def main(overwrite=False):
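    """Standard (non-adversarial) training: build or resume a segmentation model and
    train it with the generators from get_training_and_validation_generators."""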
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        create_data_file(config)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
        model_path = get_last_model_path(config["model_file"])
        print('Loading model from: {}'.format(model_path))
        model = load_old_model(model_path)
    else:
        # instantiate new model
        loss_func = getattr(fetal_net.metrics, config['loss'])
        model_func = getattr(fetal_net.model, config['model_name'])
        model = model_func(
            input_shape=config["input_shape"],
            initial_learning_rate=config["initial_learning_rate"],
            dropout_rate=config['dropout_rate'],
            loss_function=loss_func,
            # TODO: change to output shape
            mask_shape=None if config["weight_mask"] is None else config["input_shape"],
            old_model_path=config['old_model'])
        # note: this nested check can never succeed, because the outer condition
        # above already failed whenever this else-branch is reached
        if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
            model_path = get_last_model_path(config["model_file"])
            print('Loading model from: {}'.format(model_path))
            model.load_weights(model_path)
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                output_folder=config["base_dir"])
    data_file_opened.close()
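
Both training entry points read everything from a module-level config dict. A minimal sketch of that dict follows; the values are purely illustrative (only the key names come from the code above), and the model/loss names must match functions that actually exist in fetal_net.model and fetal_net.metrics:

config = {
    # files and bookkeeping
    'data_file': 'fetal_data.h5',
    'model_file': 'model_',                    # checkpoint prefix
    'base_dir': '.',
    'training_file': 'training_ids.pkl',
    'validation_file': 'validation_ids.pkl',
    'test_file': 'test_ids.pkl',
    # model and losses (looked up by name with getattr)
    'model_name': 'unet_model_3d',             # hypothetical name
    'dis_model_name': 'discriminator_model',   # hypothetical name
    'loss': 'dice_coefficient_loss',           # hypothetical name
    'dis_loss': 'binary_crossentropy_loss',    # hypothetical name
    'old_model': None,
    'weight_mask': None,
    # patches and labels
    'input_shape': [5, 128, 128],
    'patch_shape': (128, 128),
    'patch_depth': 5,
    'n_labels': 1,
    'labels': (1,),
    'truth_index': 2,
    'truth_size': 1,
    'prev_truth_index': None,
    'prev_truth_size': None,
    'truth_downsample': None,
    'truth_crop': None,
    'categorical': False,
    '3D': False,
    # generators
    'batch_size': 16,
    'validation_batch_size': 16,
    'validation_split': 0.9,
    'augment': None,
    'skip_blank_train': False,
    'skip_blank_val': False,
    'patches_per_epoch': 1000,
    'drop_easy_patches_train': False,
    'drop_easy_patches_val': False,
    # optimisation
    'initial_learning_rate': 1e-4,
    'dropout_rate': 0.5,
    'learning_rate_drop': 0.5,
    'patience': 3,
    'early_stop': 25,
    'n_epochs': 300,
    # adversarial training only
    'gd_loss_ratio': 10.0,
    'dis_steps': 1,
    'gen_steps': 1,
}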