# Load images and segmentations for the training and validation splits.
train_images, train_segs = load_training_data()
val_images, val_segs = load_validation_data()

# Build the per-pixel area label masks for each split from its segmentations.
train_labels, val_labels = (
    dataset_construction.create_all_area_masks(split_images, split_segs)
    for split_images, split_segs in ((train_images, train_segs),
                                     (val_images, val_segs)))

# One more class than the size of axis 1 of the segmentations
# (presumably one class per boundary plus one extra area -- confirm).
NUM_CLASSES = train_segs.shape[1] + 1

# One-hot encode both label sets.
train_labels, val_labels = (
    to_categorical(split_labels, NUM_CLASSES)
    for split_labels in (train_labels, val_labels))

# Wrap each split in a full-size image database.
train_imdb, val_imdb = (
    imdb.ImageDatabase(images=split_images,
                       labels=split_labels,
                       name=DATASET_NAME,
                       filename=DATASET_NAME,
                       mode_type='fullsize',
                       num_classes=NUM_CLASSES)
    for split_images, split_labels in ((train_images, train_labels),
                                       (val_images, val_labels)))
# models from the "Automatic choroidal segmentation in OCT images using supervised deep learning methods" paper (currently excluding RNN bottleneck and Combined)
model_residual = sem_models.resnet(8,
                                   4,
                                   2,
                                   1, (3, 3), (2, 2),
                                   input_channels=INPUT_CHANNELS,
    return val_images, val_segs


# Load images and segmentations for the training and validation splits.
train_images, train_segs = load_training_data()
val_images, val_segs = load_validation_data()

# To generate patches differently, replace the two sampling calls below with
# your own patch-construction function.
train_patches, train_patch_labels = \
    dataset_construction.sample_all_training_patches(
        train_images, train_segs, range(0, train_images.shape[1]), PATCH_SIZE)
val_patches, val_patch_labels = \
    dataset_construction.sample_all_training_patches(
        val_images, val_segs, range(0, val_images.shape[1]), PATCH_SIZE)

# One more class than the size of axis 1 of the segmentations.
NUM_CLASSES = train_segs.shape[1] + 1

# One-hot encode the patch labels of both splits.
train_patch_labels, val_patch_labels = (
    keras.utils.to_categorical(split_labels, num_classes=NUM_CLASSES)
    for split_labels in (train_patch_labels, val_patch_labels))

# Wrap each split's patches in a patch-mode image database.
train_patch_imdb, val_patch_imdb = (
    imdb.ImageDatabase(images=split_patches,
                       labels=split_labels,
                       name=DATASET_NAME,
                       filename=DATASET_NAME,
                       mode_type='patch',
                       num_classes=NUM_CLASSES)
    for split_patches, split_labels in ((train_patches, train_patch_labels),
                                        (val_patches, val_patch_labels)))

# Patch-based models from the "Automatic choroidal segmentation in OCT images
# using supervised deep learning methods" paper. Axes 1 and 2 of the patch
# array are passed as the patch dimensions.
model_cifar = patch_models.cifar_cnn(NUM_CLASSES, *train_patches.shape[1:3])
model_complex = patch_models.complex_cnn(NUM_CLASSES, *train_patches.shape[1:3])
model_rnn = patch_models.rnn_stack(
    4,
    ('ver', 'hor') * 2,
    (True,) * 4,
    (CuDNNGRU,) * 4,
    (0.25,) * 4,
    (1, 1, 2, 2),
    (1, 1, 2, 2),
    (16,) * 4,
    False,
    0,
    INPUT_CHANNELS,
    *train_patches.shape[1:3],
    NUM_CLASSES)

# Training configuration.
opt_con = keras.optimizers.Adam
opt_params = {}  # empty -> optimizer default parameters
loss = keras.losses.categorical_crossentropy
metric = keras.metrics.categorical_accuracy
epochs = 100
# Example #3
def evaluate_patch_based_network(eval_params, imdb):
    """Evaluate a patch-based network on each image of ``imdb``, one image at a time.

    For every index in ``imdb.image_range`` up to two steps are run:

    1. Network step. Runs when the image has no saved progress
       (``status == 'none'``) and ``eval_params.eval_mode`` is ``'both'`` or
       ``'network'``. The full image is augmented with
       ``eval_params.aug_fn_arg`` and split into patches with
       ``datacon.construct_patches_whole_image``. The patches are predicted
       by ``eval_params.loaded_model``, or by every model in
       ``eval_params.loaded_models`` when ``eval_params.ensemble`` is set.
       When boundary maps are requested, the predictions are converted to
       probability maps. Intermediate results are saved unless saving is
       disabled.
    2. Graph-search step. Runs when ``status == 'predict'``,
       ``eval_params.boundaries`` is True and ``eval_mode`` is ``'both'`` or
       ``'gs'``. When saving is enabled and boundary maps are saved, the
       probability maps are reloaded from the save file. They are then
       passed to ``eval_helper.eval_second_step``.

    When saving is enabled, per-image progress is read back with
    ``eval_helper.get_complete_status``, so previously started images are
    picked up where they left off.

    Args:
        eval_params: evaluation configuration. This function reads (among
            others) its ``save_params``, ``verbosity``, ``eval_mode``,
            ``aug_fn_arg``, ``patch_size``, ``batch_size``, ``ensemble`` and
            ``boundaries`` attributes.
        imdb: image database exposing ``image_range``, ``get_image``,
            ``get_patch_label``, ``get_image_name``, ``get_seg``,
            ``image_width`` and ``image_height``.

    Returns:
        A list with one ``EvaluationOutput`` per evaluated image when
        ``eval_params.save_params.output_var`` is True, otherwise None.
    """
    # patches need to be constructed and passed to the generator for one image at a time
    if eval_params.save_params.output_var is True:
        eval_outputs = []
    else:
        eval_outputs = None

    for ind in imdb.image_range:
        if eval_params.save_params.output_var is True:
            eval_output = eoutput.EvaluationOutput()
        else:
            eval_output = None

        cur_full_image = imdb.get_image(ind)
        cur_patch_labels = imdb.get_patch_label(ind)
        cur_image_name = imdb.get_image_name(ind)
        cur_seg = imdb.get_seg(ind)

        if eval_params.save_params.output_var is True:
            eval_output.raw_image = cur_full_image
            eval_output.raw_label = cur_patch_labels
            eval_output.image_name = cur_image_name
            eval_output.raw_seg = cur_seg

        if eval_params.verbosity >= 2:
            print("Evaluating image number: " + str(ind + 1) + " (" +
                  cur_image_name + ")...")

        if eval_params.save_params.disable is False:
            if eval_helper.check_exists(eval_params.save_foldername,
                                        cur_image_name):
                # if the file for this image exists then we have already begun this at some point
                print("File already exists")
            else:
                eval_helper.save_initial_attributes(eval_params,
                                                    cur_image_name)

            status = eval_helper.get_complete_status(
                eval_params.save_foldername,
                cur_image_name,
                boundaries=eval_params.boundaries)
        else:
            # nothing is persisted, so always start from the network step
            status = 'none'

        if status == 'none' and (eval_params.eval_mode == 'both'
                                 or eval_params.eval_mode == 'network'):
            # PERFORM STEP 1: evaluate/predict patches with network

            if eval_params.verbosity >= 2:
                print("Augmenting data using augmentation: " +
                      eval_params.aug_desc + "...")

            aug_fn = eval_params.aug_fn_arg[0]
            aug_arg = eval_params.aug_fn_arg[1]

            # augment raw full sized image and label
            augment_image, augment_patch_labels, augment_seg, _, augment_time = \
                aug_fn(cur_full_image, cur_patch_labels, cur_seg, aug_arg)

            if eval_params.save_params.output_var is True:
                eval_output.aug_image = augment_image
                eval_output.aug_label = augment_patch_labels
                eval_output.aug_seg = augment_seg

            if eval_params.verbosity >= 2:
                print("Constructing patches...")

            # construct patches
            input_patches, input_labels, patch_time = \
                datacon.construct_patches_whole_image(augment_image, augment_patch_labels,
                                                      eval_params.patch_size)

            patch_imdb = image_db.ImageDatabase(images=input_patches,
                                                labels=input_labels)

            if eval_params.verbosity >= 2:
                print("Running network predictions...")

            # use a generator to supply data to model (predict_generator)
            # the whole image was already augmented above, so the individual
            # patches are not augmented again (aug_mode='none')

            start_predict_time = time.time()

            import keras

            class CustomCallback(keras.callbacks.Callback):
                """Reset the generator's internal batch counters when prediction starts."""

                def __init__(self, gen):
                    keras.callbacks.Callback.__init__(self)
                    self.gen = gen

                def on_predict_begin(self, logs=None):
                    self.gen.batch_gen.batch_counter = 0
                    self.gen.batch_gen.full_counter = 0
                    self.gen.batch_gen.aug_counter = 0

            if not eval_params.ensemble:
                start_gen_time = time.time()
                gen = data_generator.DataGenerator(
                    patch_imdb,
                    eval_params.batch_size,
                    aug_fn_args=[],
                    aug_mode='none',
                    aug_probs=[],
                    aug_fly=False,
                    shuffle=False,
                    normalise=eval_params.normalise_input,
                    transpose=eval_params.transpose)
                end_gen_time = time.time()
                gen_time = end_gen_time - start_gen_time

                cust_callback = CustomCallback(gen)
                predicted_labels = eval_params.loaded_model.predict_generator(
                    gen,
                    verbose=eval_params.predict_verbosity,
                    callbacks=[cust_callback])
                # diagnostic output only; keep it behind the verbosity switch
                if eval_params.verbosity >= 2:
                    print(predicted_labels.shape)
            else:
                predicted_labels = []

                for i in range(len(eval_params.loaded_models)):
                    # a fresh generator per model, so no counter reset is needed
                    start_gen_time = time.time()
                    gen = data_generator.DataGenerator(
                        patch_imdb,
                        eval_params.batch_size,
                        aug_fn_args=[],
                        aug_mode='none',
                        aug_probs=[],
                        aug_fly=False,
                        shuffle=False,
                        normalise=eval_params.normalise_input,
                        transpose=eval_params.transpose)
                    end_gen_time = time.time()
                    # only the last model's generator construction time is kept
                    gen_time = end_gen_time - start_gen_time

                    predicted_labels.append(
                        eval_params.loaded_models[i].predict_generator(
                            gen, verbose=eval_params.predict_verbosity))

            end_predict_time = time.time()
            predict_time = end_predict_time - start_predict_time

            if eval_params.verbosity >= 2:
                print("Converting predictions to boundary maps...")

            # convert predictions to usable probability maps
            start_convert_time = time.time()

            if eval_params.boundaries is True and eval_params.save_params.boundary_maps is True:

                if not eval_params.ensemble:

                    prob_maps = convert_predictions_to_maps_patch_based(
                        predicted_labels, imdb.image_width, imdb.image_height)
                else:
                    prob_maps = []

                    for i in range(len(predicted_labels)):
                        # add a leading axis per model before combining the ensemble
                        prob_maps.append(
                            np.expand_dims(
                                convert_predictions_to_maps_patch_based(
                                    predicted_labels[i], imdb.image_width,
                                    imdb.image_height),
                                axis=0))

                    prob_maps = eval_helper.perform_ensemble_patch(prob_maps)
            else:
                prob_maps = None

            if eval_params.save_params.output_var is True:
                eval_output.boundary_maps = prob_maps

            end_convert_time = time.time()
            convert_time = end_convert_time - start_convert_time

            # save data to file
            if eval_params.save_params.disable is False:
                eval_helper.intermediate_save_patch_based(
                    eval_params, imdb, cur_image_name, prob_maps, predict_time,
                    augment_time, gen_time, convert_time, patch_time,
                    augment_image, augment_patch_labels, augment_seg,
                    cur_full_image, cur_patch_labels, cur_seg)

        # re-read progress: step 1 may have just completed this image
        if eval_params.save_params.disable is False:
            status = eval_helper.get_complete_status(
                eval_params.save_foldername,
                cur_image_name,
                boundaries=eval_params.boundaries)
        else:
            status = 'predict'

        if status == 'predict' and eval_params.boundaries is True and \
                (eval_params.eval_mode == 'both' or eval_params.eval_mode == 'gs'):
            aug_fn = eval_params.aug_fn_arg[0]
            aug_arg = eval_params.aug_fn_arg[1]

            # augment raw full sized image and label
            augment_image, augment_patch_labels, augment_seg, _, augment_time = \
                aug_fn(cur_full_image, cur_patch_labels, cur_seg, aug_arg)

            # load probability maps from previous step
            if eval_params.save_params.disable is False and eval_params.save_params.boundary_maps is True:
                prob_maps = eval_helper.load_dataset_extra(
                    eval_params, cur_image_name, "boundary_maps")

            # PERFORM STEP 2: segment probability maps using graph search
            boundary_maps = get_boundary_maps_only(imdb, prob_maps)
            second_step_output = eval_helper.eval_second_step(
                eval_params,
                boundary_maps,
                augment_seg,
                cur_image_name,
                augment_image,
                augment_patch_labels,
                imdb,
                dices=None,
                eval_output=eval_output)
            # evaluate_single_images treats the return value as the updated
            # output object; adopt it here too whenever one is returned
            if second_step_output is not None:
                eval_output = second_step_output
        elif eval_params.boundaries is False:
            if eval_params.save_params.disable is False and eval_params.save_params.attributes is True:
                eval_helper.save_final_attributes(eval_params,
                                                  cur_image_name,
                                                  graph_time=None)

        if eval_params.save_params.disable is False and eval_params.save_params.temp_extra is True:
            eval_helper.delete_loadsaveextra_file(eval_params, cur_image_name)

        # collect the per-image output so it is actually returned to the caller
        if eval_params.save_params.output_var is True:
            eval_outputs.append(eval_output)

        if eval_params.verbosity >= 2:
            print("DONE image number: " + str(ind + 1) + " (" +
                  cur_image_name + ")...")
            print("______________________________")

    return eval_outputs
import os

NUM_CLASSES = test_segs.shape[1] + 1                 # update for required number of classes

# boundary names should be a list of strings with length = NUM_CLASSES - 1
# class names should be a list of strings with length = NUM_CLASSES
AREA_NAMES = ["area_" + str(i) for i in range(NUM_CLASSES)]
BOUNDARY_NAMES = ["boundary_" + str(i) for i in range(NUM_CLASSES - 1)]
# patch class 0 is the background ("BG"); the remaining patch classes are the boundaries
PATCH_CLASS_NAMES = ["BG"] + BOUNDARY_NAMES

GSGRAD = 1
# custom objects needed to deserialise the saved model; on a name clash the
# entry from custom_metric_objects takes precedence
CUSTOM_OBJECTS = {**custom_losses.custom_loss_objects,
                  **custom_metrics.custom_metric_objects}

eval_imdb = imdb.ImageDatabase(images=test_images, labels=None, patch_labels=test_patch_labels, segs=test_segs, image_names=test_image_names,
                               boundary_names=BOUNDARY_NAMES, area_names=AREA_NAMES,
                               fullsize_class_names=AREA_NAMES, patch_class_names=PATCH_CLASS_NAMES, num_classes=NUM_CLASSES, name=TEST_DATA_NAME, filename=TEST_DATA_NAME, mode_type='fullsize')

batch_size = 992    # CURRENTLY THIS NEEDS TO BE CHOSEN AS A VALUE WHICH IS A FACTOR OF THE AREA (IN PIXELS) OF THE FULL IMAGE (i.e. 992 is a factor of the 761856 (1536x496) pixel image [992 x 768 = 761856])
network_folder = parameters.RESULTS_LOCATION + "\\2020-03-10 14_25_08 Cifar CNN 32x32 stargardt_girard_patches_fold1\\" # name of network folder for which to evaluate model
model_name = "model_epoch04.hdf5"   # name of model file inside network folder to evaluate

# os.path.join avoids appending a second separator after the folder's trailing one
loaded_model = load_model(os.path.join(network_folder, model_name), custom_objects=CUSTOM_OBJECTS)

aug_fn_arg = (aug.no_aug, {})

eval_helper.evaluate_network(eval_imdb, model_name, network_folder,
                             batch_size, save_parameters.SaveParameters(pngimages=True, raw_image=True, temp_extra=True, boundary_maps=True, area_maps=True, comb_area_maps=True, seg_plot=True),
                             gsgrad=GSGRAD, aug_fn_arg=aug_fn_arg, eval_mode='both', boundaries=True, boundary_errors=True, dice_errors=False, col_error_range=None, normalise_input=True, transpose=False)

# Example #5
def evaluate_single_images(eval_params, imdb):
    """Evaluate a semantic-segmentation network on each image of ``imdb``, one at a time.

    For every index in ``imdb.image_range`` up to two steps are run:

    1. Network step. Runs when the image has no saved progress
       (``status == 'none'``) and ``eval_params.eval_mode`` is ``'both'`` or
       ``'network'``. The image is augmented with ``eval_params.aug_fn_arg``
       and predicted by ``eval_params.loaded_model``. For an ensemble, every
       model in ``eval_params.loaded_models`` predicts it and the results are
       combined with ``eval_helper.perform_ensemble``. Layer activations are
       optionally recorded. The predictions are converted to area maps and,
       when requested, to boundary maps. Dice scores are optionally computed.
       Intermediate results are saved unless saving is disabled.
    2. Graph-search step. Runs when ``status == 'predict'``,
       ``eval_params.boundaries`` is True and ``eval_mode`` is ``'both'`` or
       ``'gs'``. The boundary maps are passed to
       ``eval_helper.eval_second_step``. When saving is enabled and boundary
       maps are saved, they are reloaded from the save file first.

    Args:
        eval_params: evaluation configuration. This function reads (among
            others) its ``save_params``, ``verbosity``, ``eval_mode``,
            ``aug_fn_arg``, ``batch_size``, ``ensemble``, ``transpose``,
            ``boundaries`` and ``vertical_graph_search`` attributes.
        imdb: image database exposing ``image_range``, ``get_image``,
            ``get_label``, ``get_image_name``, ``get_seg`` and ``set``.

    Returns:
        A list of ``EvaluationOutput`` objects, one per evaluated image, when
        ``eval_params.save_params.output_var`` is True; otherwise an empty
        list.

    Raises:
        ValueError: if boundary maps are requested and
            ``eval_params.vertical_graph_search`` is not False, True or
            ``"ilm_vertical"``.
    """
    # pass images to network one at a time
    eval_outputs = []

    for ind in imdb.image_range:
        eval_output = eoutput.EvaluationOutput()

        cur_raw_image = imdb.get_image(ind)
        cur_label = imdb.get_label(ind)
        cur_image_name = imdb.get_image_name(ind)
        cur_seg = imdb.get_seg(ind)

        eval_output.raw_image = cur_raw_image
        eval_output.raw_label = cur_label
        eval_output.image_name = cur_image_name
        eval_output.raw_seg = cur_seg

        if eval_params.verbosity >= 2:
            print("Evaluating image number: " + str(ind + 1) + " (" +
                  cur_image_name + ")...")

        if eval_params.save_params.disable is False:
            if eval_helper.check_exists(eval_params.save_foldername,
                                        cur_image_name):
                # if the file for this image exists then we have already begun this at some point
                print("File already exists")
            else:
                eval_helper.save_initial_attributes(eval_params,
                                                    cur_image_name)

            status = eval_helper.get_complete_status(
                eval_params.save_foldername, cur_image_name,
                eval_params.boundaries)
        else:
            # nothing is persisted, so always start from the network step
            status = 'none'

        if status == 'none' and (eval_params.eval_mode == 'both'
                                 or eval_params.eval_mode == 'network'):
            # PERFORM STEP 1: evaluate/predict the full image with the network

            if eval_params.verbosity >= 2:
                print("Augmenting data using augmentation: " +
                      eval_params.aug_desc + "...")

            aug_fn = eval_params.aug_fn_arg[0]
            aug_arg = eval_params.aug_fn_arg[1]

            # augment raw full sized image and label
            augment_image, augment_label, augment_seg, _, augment_time = \
                aug_fn(cur_raw_image, cur_label, cur_seg, aug_arg, sample_ind=ind, set=imdb.set)

            eval_output.aug_image = augment_image
            eval_output.aug_label = augment_label
            eval_output.aug_seg = augment_seg

            if eval_params.verbosity >= 2:
                print("Running network predictions...")

            # wrap the single image in a batch of one
            images = np.expand_dims(augment_image, axis=0)
            labels = np.expand_dims(augment_label, axis=0)
            single_image_imdb = image_db.ImageDatabase(images=images,
                                                       labels=labels)

            # use a generator to supply data to model (predict_generator)

            start_gen_time = time.time()
            gen = data_generator.DataGenerator(
                single_image_imdb,
                eval_params.batch_size,
                aug_fn_args=[],
                aug_mode='none',
                aug_probs=[],
                aug_fly=False,
                shuffle=False,
                transpose=eval_params.transpose,
                normalise=eval_params.normalise_input)
            end_gen_time = time.time()
            gen_time = end_gen_time - start_gen_time

            start_predict_time = time.time()

            if not eval_params.ensemble:
                predicted_labels = eval_params.loaded_model.predict_generator(
                    gen, verbose=eval_params.predict_verbosity)
            else:
                predicted_labels = []

                for i in range(len(eval_params.loaded_models)):
                    predicted_labels.append(
                        eval_params.loaded_models[i].predict_generator(
                            gen, verbose=eval_params.predict_verbosity))

            end_predict_time = time.time()

            predict_time = end_predict_time - start_predict_time

            if eval_params.save_params.activations is True:
                if not eval_params.ensemble:
                    if eval_params.save_params.act_layers is None:
                        # every layer after the first one
                        layer_outputs = [
                            layer.output for layer in eval_params.loaded_model.
                            layers[1:len(eval_params.loaded_model.layers)]
                        ]
                    else:
                        layer_outputs = [
                            layer.output
                            for layer in eval_params.save_params.act_layers
                        ]

                    # Creates a model that will return these outputs, given the model input
                    activation_model = Model(
                        inputs=eval_params.loaded_model.input,
                        outputs=layer_outputs)

                    # the activation model is fed directly (no generator), so
                    # apply the same /255 input scaling here when requested
                    if eval_params.normalise_input:
                        images_norm = images / 255
                    else:
                        images_norm = images

                    activations = activation_model.predict(images_norm)
                else:
                    layer_outputs = []
                    activations = []

                    # TODO: implement write handling for ensemble activations
                    for i in range(len(eval_params.loaded_models)):
                        layer_outputs.append([
                            layer.output
                            for layer in eval_params.loaded_models[i].
                            layers[1:len(eval_params.loaded_models[i].layers)]
                        ])

                        # Creates a model that will return these outputs, given the model input
                        activation_model = Model(
                            inputs=eval_params.loaded_models[i].input,
                            outputs=layer_outputs[i])

                        if eval_params.normalise_input:
                            images_norm = images / 255
                        else:
                            images_norm = images

                        activations.append(
                            activation_model.predict(images_norm))
            else:
                activations = None
                layer_outputs = None

            if eval_params.verbosity >= 2:
                print("Converting predictions to boundary maps...")

            if not eval_params.ensemble:
                if eval_params.transpose is True:
                    predicted_labels = np.transpose(predicted_labels,
                                                    axes=(0, 2, 1, 3))

                # convert predictions to usable boundary probability maps

                start_convert_time = time.time()

                [comb_area_map, area_maps
                 ] = eval_helper.perform_argmax(predicted_labels,
                                                ensemble=False,
                                                bin=eval_params.binarize)
            else:
                if eval_params.transpose is True:
                    for i in range(len(eval_params.loaded_models)):
                        predicted_labels[i] = np.transpose(predicted_labels[i],
                                                           axes=(0, 2, 1, 3))

                # convert predictions to usable boundary probability maps

                start_convert_time = time.time()

                [comb_area_map_sep, area_maps_sep
                 ] = eval_helper.perform_argmax(predicted_labels,
                                                ensemble=True,
                                                bin=eval_params.binarize)

                # ensemble using majority voting scheme
                [comb_area_map,
                 area_maps] = eval_helper.perform_ensemble(area_maps_sep)

                if eval_params.verbosity >= 2:
                    print(area_maps.shape)

                if eval_params.binarize_after is True:
                    num_maps = area_maps.shape[1]

                    if eval_params.use_thresh:
                        # NOTE(review): these masked assignments run in place and
                        # in sequence, so each later mask sees the earlier writes.
                        # Assuming thresh <= 1, channel 0 ends up all zeros and
                        # channel 1 becomes (p1 >= thresh). With two channels, the
                        # argmax below is therefore equivalent to
                        # `area_maps[:, 1] >= thresh`. Confirm this is the
                        # intended rule before "fixing" the ordering.
                        area_maps[:,
                                  1][area_maps[:, 1] >= eval_params.thresh] = 1
                        area_maps[:,
                                  1][area_maps[:, 1] < eval_params.thresh] = 0

                        area_maps[:,
                                  0][area_maps[:, 0] < eval_params.thresh] = 1
                        area_maps[:,
                                  0][area_maps[:, 0] >= eval_params.thresh] = 0
                        area_maps = np.argmax(area_maps, axis=1)
                    else:
                        area_maps = np.argmax(area_maps, axis=1)

                    # back to one-hot, with the class axis moved to position 1
                    area_maps = to_categorical(area_maps, num_maps)

                    area_maps = np.transpose(area_maps, axes=(0, 3, 1, 2))

                if eval_params.verbosity >= 2:
                    print(area_maps.shape)

            eval_output.comb_area_map = comb_area_map
            eval_output.area_maps = area_maps

            if eval_params.boundaries is False or eval_params.save_params.boundary_maps is False:
                boundary_maps = None
            else:
                if eval_params.vertical_graph_search is False:
                    boundary_maps = convert_predictions_to_maps_semantic(
                        np.array(area_maps),
                        bg_ilm=eval_params.bg_ilm,
                        bg_csi=eval_params.bg_csi)
                elif eval_params.vertical_graph_search is True:
                    boundary_maps = convert_predictions_to_maps_semantic_vertical(
                        np.array(area_maps),
                        bg_ilm=eval_params.bg_ilm,
                        bg_csi=eval_params.bg_csi)
                elif eval_params.vertical_graph_search == "ilm_vertical":
                    # first boundary map from the vertical conversion, the
                    # remaining ones from the standard conversion
                    ilm_map = np.expand_dims(
                        convert_predictions_to_maps_semantic_vertical(
                            np.array(area_maps),
                            bg_ilm=eval_params.bg_ilm,
                            bg_csi=eval_params.bg_csi)[0, 0],
                        axis=0)
                    other_maps = convert_predictions_to_maps_semantic(
                        np.array(area_maps),
                        bg_ilm=eval_params.bg_ilm,
                        bg_csi=eval_params.bg_csi)[0][1:]

                    boundary_maps = np.expand_dims(np.concatenate(
                        [ilm_map, other_maps], axis=0),
                                                   axis=0)
                else:
                    # without this, boundary_maps would be unbound (first image)
                    # or silently reused from the previous image
                    raise ValueError(
                        "Unrecognised vertical_graph_search value: " +
                        repr(eval_params.vertical_graph_search))

            eval_output.boundary_maps = boundary_maps

            end_convert_time = time.time()
            convert_time = end_convert_time - start_convert_time

            if eval_params.dice_errors is True:
                dices = eval_helper.calc_dice(eval_params, area_maps, labels)
            else:
                dices = None

            area_maps = np.squeeze(area_maps)
            comb_area_map = np.squeeze(comb_area_map)
            boundary_maps = np.squeeze(boundary_maps)

            # save data to files
            if eval_params.save_params.disable is False:
                eval_helper.intermediate_save_semantic(
                    eval_params, imdb, cur_image_name, boundary_maps,
                    predict_time, augment_time, gen_time, augment_image,
                    augment_label, augment_seg, cur_raw_image, cur_label,
                    cur_seg, area_maps, comb_area_map, dices, convert_time,
                    activations, layer_outputs)

        # re-read progress: step 1 may have just completed this image
        if eval_params.save_params.disable is False:
            status = eval_helper.get_complete_status(
                eval_params.save_foldername, cur_image_name,
                eval_params.boundaries)
        else:
            status = 'predict'

        if status == 'predict' and eval_params.boundaries is True and \
                (eval_params.eval_mode == 'both' or eval_params.eval_mode == 'gs'):
            # fetched again from imdb rather than reusing the step-1 values
            # (presumably in case augmentation modified them -- confirm)
            cur_image_name = imdb.get_image_name(ind)
            cur_seg = imdb.get_seg(ind)
            cur_raw_image = imdb.get_image(ind)
            cur_label = imdb.get_label(ind)

            aug_fn = eval_params.aug_fn_arg[0]
            aug_arg = eval_params.aug_fn_arg[1]

            # augment raw full sized image and label
            augment_image, augment_label, augment_seg, _, _ = \
                aug_fn(cur_raw_image, cur_label, cur_seg, aug_arg, sample_ind=ind, set=imdb.set)

            if eval_params.save_params.disable is False and eval_params.save_params.boundary_maps is True:
                boundary_maps = eval_helper.load_dataset_extra(
                    eval_params, cur_image_name, "boundary_maps")
                if eval_params.dice_errors is True:
                    dices = eval_helper.load_dataset(eval_params,
                                                     cur_image_name, "dices")
                else:
                    dices = None

            # PERFORM STEP 2: segment probability maps using graph search
            eval_output = eval_helper.eval_second_step(
                eval_params, boundary_maps, augment_seg, cur_image_name,
                augment_image, augment_label, imdb, dices, eval_output)
        elif eval_params.boundaries is False:
            if eval_params.save_params.disable is False and eval_params.save_params.attributes is True:
                eval_helper.save_final_attributes(eval_params,
                                                  cur_image_name,
                                                  graph_time=None)

        if eval_params.save_params.disable is False and eval_params.save_params.temp_extra is True:
            eval_helper.delete_loadsaveextra_file(eval_params, cur_image_name)

        # collect the per-image output only when requested, so large
        # evaluations do not keep every image's arrays in memory
        if eval_params.save_params.output_var is True:
            eval_outputs.append(eval_output)

        if eval_params.verbosity >= 2:
            print("DONE image number: " + str(ind + 1) + " (" +
                  cur_image_name + ")...")
            print("______________________________")

    return eval_outputs