Example #1
def main(dataset,
         model_path,
         plot_image=False,
         image_save=False,
         image_save_dir=None,
         image_save_subdir=None,
         image_save_suffix=None,
         calc_ap=False,
         calc_froc=False,
         calc_f1=False,
         eval_class=False,
         NMS_thr=0.5,
         **settings):

    # device
    device = 'cuda'

    if "resnet" not in settings.keys():
        resnet = "RN50"
    else:
        resnet = settings["resnet"]

    # load model
    checkpoint = torch.load(model_path)
    model = RetinaNet(**checkpoint['init_kwargs'], resnet=resnet)
    model.eval()
    model.load_state_dict(checkpoint['state_dict']["model"])
    model.to(device)

    if image_save:
        path_parts = model_path.split("/")
        if image_save_dir is None:
            image_save_dir = "/home/temp/moriz/validation/test_results/" + \
                             path_parts[-4] + "/" + path_parts[-2] + "/" + \
                             path_parts[-1].split(".")[0] + "/"
        else:
            image_save_dir += path_parts[-4] + "/" + path_parts[-2] + "/" + \
                             path_parts[-1].split(".")[0] + "/"

        if image_save_subdir is not None:
            image_save_dir += "/" + image_save_subdir + "/"

        if not os.path.isdir(image_save_dir):
            os.makedirs(image_save_dir)

    tp_list = []
    fp_list = []
    fn_list = []

    image_class_list = []
    ribli_class_list = []

    box_label_list = []
    box_score_list = []

    confidence_values = np.arange(0.05, 1, 0.05)
    num_gt = 0

    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            torch.cuda.empty_cache()

            # get image data
            test_data = dataset[i]["data"]
            if "crops" in dataset[i].keys():
                crops = dataset[i]["crops"]
                test_data = np.concatenate((test_data, crops), axis=0)
            gt_bbox = utils.bounding_box(dataset[i]["seg"])
            gt_label = dataset[i]["label"]
            num_gt += len(gt_bbox)

            # convert data to tensor to forward through the model
            torch.cuda.empty_cache()
            test_data = torch.Tensor(test_data).to(device)
            gt_bbox = torch.Tensor(gt_bbox).to(device)

            # predict anchors and labels for the crops using the loaded model
            anchor_preds, cls_preds = model(test_data.unsqueeze(0))

            # convert the predicted anchors to bboxes
            anchors = Anchors()
            boxes, labels, scores = anchors.generateBoxesFromAnchors(
                anchor_preds[0],
                cls_preds[0], (test_data.shape[2], test_data.shape[1]),
                cls_tresh=0.05,
                nms_thresh=NMS_thr)

            #boxes, scores = eval_utils.wbc(boxes, scores, thr=0.2)
            #boxes, scores, labels = eval_utils.nms(boxes, scores, labels, thr=0.2)
            # boxes, scores, labels = eval_utils.rm_overlapping_boxes(boxes,
            #                                                         scores,
            #                                                         labels,
            #                                                         order="xywh")

            if boxes is None:
                boxes = []
                tp_list.append([
                    torch.tensor(0, device=device)
                    for tp in range(len(confidence_values))
                ])
                fp_list.append([
                    torch.tensor(0, device=device)
                    for fp in range(len(confidence_values))
                ])
                fn_list.append([
                    torch.tensor(len(gt_bbox), device=device)
                    for fn in range(len(confidence_values))
                ])

            else:
                # calculate the required rates for the FROC metric
                tp_crop, fp_crop, fn_crop = \
                    eval_utils.calc_tp_fn_fp(gt_bbox,
                                             boxes,
                                             scores,
                                             confidence_values=confidence_values)
                tp_list.append(tp_crop)
                fp_list.append(fp_crop)
                fn_list.append(fn_crop)

                # determine the overlap of detected bbox with the ground truth
                box_label, box_score, _ = \
                    eval_utils.calc_detection_hits(gt_bbox,
                                                   boxes,
                                                   scores,
                                                   score_thr=0.0)
                box_label_list.append(box_label)
                box_score_list.append(box_score)

                image_class_list.append([
                    gt_label,
                    np.float32(labels[torch.argmax(scores)].cpu()),
                    np.float32(torch.max(scores).cpu())
                ])

            # plot the image with the corresponding bboxes
            if plot_image:
                test_data = test_data.to("cpu")
                gt_bbox = gt_bbox.to("cpu")

                # plot image
                c, h, w = test_data.shape
                #figsize = 0.75 * (w / 100), 0.75 * (h / 100)
                figsize = 0.5 * (w / 100), 0.5 * (h / 100)
                #figsize = 0.25 * (w / 100), 0.25 * (h / 100)
                fig, ax = plt.subplots(1, figsize=figsize)

                ax.imshow(test_data[0, :, :], cmap='Greys_r')

                # show bboxes as saved in data (in red with center)
                for l in range(len(gt_bbox)):
                    pos = tuple(gt_bbox[l][0:2])
                    plt.plot(pos[0], pos[1], 'r.')
                    width = gt_bbox[l][2]
                    height = gt_bbox[l][3]
                    pos = (pos[0] - np.floor(width / 2),
                           pos[1] - np.floor(height / 2))

                    # Create a Rectangle patch
                    rect = patches.Rectangle(pos,
                                             width,
                                             height,
                                             linewidth=1,
                                             edgecolor='r',
                                             facecolor='none')
                    ax.add_patch(rect)
                    ax.annotate("{:d}".format(np.int32(gt_label)),
                                pos,
                                fontsize=6,
                                color="r",
                                xytext=(pos[0] - 10, pos[1] - 10))

                # show the predicted bboxes (in blue)
                print("Number of detected bboxes: {0}".format(len(boxes)))
                # keep = scores > 0.15
                # boxes = boxes[keep]
                # scores = scores[keep]
                for j in range(len(boxes)):
                    width = boxes[j][2]
                    height = boxes[j][3]
                    pos = (boxes[j][0] - torch.floor(width / 2),
                           boxes[j][1] - torch.floor(height / 2))

                    # Create a Rectangle patch
                    rect = patches.Rectangle(pos,
                                             width,
                                             height,
                                             linewidth=1,
                                             edgecolor='b',
                                             facecolor='none')
                    ax.add_patch(rect)
                    ax.annotate("{:d}|{:.2f}".format(labels[j], scores[j]),
                                pos,
                                fontsize=6,
                                color="b",
                                xytext=(pos[0] + 10, pos[1] - 10))

                    print("BBox params: {0}, score: {1}".format(
                        boxes[j], scores[j]))
                if image_save:
                    if image_save_suffix is not None:
                        save_name = "image_{0}_".format(i) + \
                                    image_save_suffix +  ".pdf"
                    else:
                        save_name = "image_{0}.pdf".format(i)

                    plt.savefig(image_save_dir + save_name,
                                dpi='figure',
                                format='pdf')
                plt.show()

    results = {}
    if calc_f1:
        f1_list = eval_utils.calc_f1(tp_list, fp_list, fn_list)
        plot_utils.plot_f1(f1_list,
                           confidence_values,
                           image_save=image_save,
                           image_save_dir=image_save_dir)
        results["F1"] = f1_list

    if calc_froc:
        froc_tpr, froc_fppi = eval_utils.calc_froc(tp_list, fp_list, fn_list)
        plot_utils.plot_frocs(froc_tpr,
                              froc_fppi,
                              image_save=image_save,
                              image_save_dir=image_save_dir,
                              left_range=1e-2)
        results["FROC"] = {"TPR": froc_tpr, "FPPI": froc_fppi}

    if calc_ap:
        # ap = eval_utils.calc_ap(box_label_list, box_score_list)
        ap, precision_steps = eval_utils.calc_ap_MDT(box_label_list,
                                                     box_score_list, num_gt)
        #print("Num_gt: {0}".format(num_gt))

        plot_utils.plot_precion_recall_curve(
            precision_steps,
            ap_value=ap,
            image_save=image_save,
            image_save_dir=image_save_dir,
        )
        print("AP: {0}".format(ap))

        prec, rec = eval_utils.calc_pr_values(box_label_list, box_score_list,
                                              num_gt)
        plot_utils.plot_precion_recall_curve(prec,
                                             rec,
                                             image_save=True,
                                             image_save_dir=image_save_dir,
                                             save_suffix="native")
        results["AP"] = {"AP": ap, "prec": prec, "rec": rec}

    if eval_class:
        # fpr, tpr, auroc = eval_utils.classification(image_class_list)
        # plot_utils.plot_roc(fpr, tpr, legend=[str(auroc)])
        cm, occ_classes = eval_utils.conf_matrix(image_class_list)
        plot_utils.plot_confusion_matrix(cm,
                                         classes=occ_classes,
                                         image_save=image_save,
                                         image_save_dir=image_save_dir)

    if image_save:
        # pickle values for later (different) plotting
        with open(image_save_dir + "results", "wb") as result_file:
            pickle.dump(results, result_file)
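
A minimal usage sketch for the evaluation routine above, not part of the original code: "my_dataset" and the checkpoint path are hypothetical placeholders, and the dataset is assumed to yield dict samples with "data", "seg" and "label" keys as iterated in the loop above.

# Hedged usage sketch: "my_dataset" and the checkpoint path are placeholders.
main(my_dataset,
     model_path="/path/to/checkpoint.pth",
     calc_ap=True,
     calc_froc=True,
     calc_f1=True,
     NMS_thr=0.5,
     resnet="RN50")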
Example #2
def main(dataset,
         checkpoint_dir,
         start_epoch,
         end_epoch,
         step_size,
         results_save_dir="/home/temp/moriz/validation/pickled_results/",
         **settings):
    '''

    :param dataset: dataset to work with
    :param checkpoint_dir: path to checkpoint directory
    :param start_epoch: first model to validate
    :param end_epoch: last model to validate
    :param step_size: step size that determines in which intervals models
            shall be validated
    :param results_save_dir: directory where the pickled settings and
            results are stored
    :param settings: further keyword settings (e.g. crop_size, set, fold)
    '''

    # device
    device = 'cuda'
    #device = 'cpu'

    # create dict where all results are saved
    total_results_dict = {}

    # get/set crop_size
    if "crop_size" in settings and settings["crop_size"] is not None:
        crop_size = settings["crop_size"]
    else:
        crop_size = [600, 600]

    # determine used set
    if "set" in settings and settings["set"] is not None:
        set = settings["set"]
    else:
        raise KeyError("Missing set description!")

    if "fold" in settings and settings["fold"] is not None:
        fold = "0%d" % settings["fold"]
    else:
        fold = "00"

    # set WBC factors
    crop_center = np.asarray(crop_size) / 2.0
    norm_pdf_var = np.int32(min(crop_size[0], crop_size[1]) / 2. - 50)

    # create directory and file name
    cp_dir_date = checkpoint_dir.split("/")[-3]
    results_save_dir = results_save_dir + str(cp_dir_date) + "/" + "fold_" + \
                       fold + "_" + set +  "/image_level_" +  \
                       str(start_epoch) + "_" + str(end_epoch) + "_" + str(step_size)

    # create folder (if necessary)
    if not os.path.isdir(results_save_dir):
        os.makedirs(results_save_dir)

    # gather all important settings in one dict and save them (pickle them)
    settings_dict = {
        "level": "image",
        "checkpoint_dir": checkpoint_dir,
        "start_epoch": start_epoch,
        "end_epoch": end_epoch,
        "step_size": step_size
    }
    settings_dict = {**settings_dict, **settings}

    with open(results_save_dir + "/settings", "wb") as settings_file:
        pickle.dump(settings_dict, settings_file)

    # iterate over the saved epochs and treat each epoch as separate model
    for epoch in tqdm(range(start_epoch, end_epoch + step_size, step_size)):
        checkpoint_path = checkpoint_dir + "/checkpoint_epoch_" + str(
            epoch) + ".pth"

        # load model
        if device == "cpu":
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        else:
            checkpoint = torch.load(checkpoint_path)
        model = RetinaNet(**checkpoint['init_kwargs']).eval()
        #model.load_state_dict(checkpoint['state_dict'])
        model.load_state_dict(checkpoint['state_dict']['model'])
        model.to(device)

        model_results_dict = {}

        with torch.no_grad():
            for i in tqdm(range(len(dataset))):
                torch.cuda.empty_cache()

                # get image data
                test_data = dataset[i]

                # generate bboxes
                gt_bbox = utils.bounding_box(test_data["seg"])
                gt_bbox = torch.Tensor(gt_bbox).to(device)

                # generate crops
                crop_list, corner_list, heatmap = utils.create_crops(
                    test_data, crop_size=crop_size, heatmap=True)

                # define list for predicted bboxes in crops
                image_bboxes = []
                image_scores = []
                crop_center_factor = []
                heatmap_factor = []

                # iterate over crops
                for j in range(0, len(crop_list)):
                    #CROP LEVEL
                    torch.cuda.empty_cache()
                    test_image = torch.Tensor(crop_list[j]['data']).to(device)

                    # predict anchors and labels for the crops using the loaded model
                    anchor_preds, cls_preds = model(test_image.unsqueeze(0))

                    # convert the predicted anchors to bboxes
                    anchors = Anchors()
                    boxes, labels, score = anchors.generateBoxesFromAnchors(
                        anchor_preds[0],
                        cls_preds[0],
                        (test_image.shape[2], test_image.shape[1]),
                        cls_tresh=0.05)

                    if boxes is None:
                        continue

                    # determine the center of each box and its distance to the
                    # crop center and calculate the resulting down-weighting
                    # factor based on it
                    box_centers = np.asarray(boxes[:, 0:2].to("cpu"))
                    dist = np.linalg.norm(crop_center - box_centers,
                                          ord=2,
                                          axis=1)
                    ccf = norm.pdf(dist, loc=0, scale=norm_pdf_var) * np.sqrt(
                        2 * np.pi) * norm_pdf_var

                    # the detected bboxes are relative to the crop; correct
                    # them with regard to the crop position in the image
                    for k in range(len(boxes)):
                        center_corrected = boxes[k][0:2] + \
                                           torch.Tensor(corner_list[j]).to(device)
                        image_bboxes.append(
                            torch.cat((center_corrected, boxes[k][2:])))
                        image_scores.append(score[k])
                        crop_center_factor.append(ccf[k])

                # IMAGE LEVEL
                # determine heatmap factor based on the center position of
                # the bbox (required for WBC only)
                for c in range(len(image_bboxes)):
                    pos_x = np.int32(image_bboxes[c][0].to("cpu"))
                    pos_x = np.minimum(np.maximum(pos_x, 0),
                                       test_data["data"].shape[2] - 1)

                    pos_y = np.int32(image_bboxes[c][1].to("cpu"))
                    pos_y = np.minimum(np.maximum(pos_y, 0),
                                       test_data["data"].shape[1] - 1)

                    heatmap_factor.append(heatmap[pos_y, pos_x])

                model_results_dict["image_%d" % i] = {
                    "gt_list": gt_bbox,
                    "box_list": image_bboxes,
                    "score_list": image_scores,
                    "merging_utils": {
                        "ccf": crop_center_factor,
                        "hf": heatmap_factor
                    }
                }

                # # convert GT bbox to tensor
                # gt_bbox = torch.Tensor(gt_bbox).to(device)
                #
                # if len(crop_bbox) > 0:
                #     if merging_method == "NMS":
                #         # merge overlapping bounding boxes using NMS
                #         image_bbox, score_bbox = eval_utils.nms(crop_bbox,
                #                                                 score_bbox,
                #                                                 0.2)
                #     elif merging_method == "WBC":
                #         # merge overlapping bounding boxes using WBC
                #         #image_bbox, score_bbox = eval_utils.wbc(crop_bbox, score_bbox, 0.2)
                #
                #         # merge overlapping bounding boxes using my merging
                #         image_bbox, score_bbox = \
                #             eval_utils.my_merging(crop_bbox,
                #                                   score_bbox,
                #                                   crop_center_factor,
                #                                   heatmap_factor,
                #                                   thr=0.2)
                #     else:
                #         raise KeyError("Merging method is not supported.")
                #
                #
                #     # calculate the required rates for the FROC metric
                #     tp_crop, fp_crop, fn_crop = \
                #         eval_utils.calc_tp_fn_fp(gt_bbox,
                #                                  image_bbox,
                #                                  score_bbox,
                #                                  confidence_values=confidence_values)
                #     tp_list.append(tp_crop)
                #     fp_list.append(fp_crop)
                #     fn_list.append(fn_crop)
                #
                #     # determine the overlap of detected bbox with the ground truth
                #     box_label = eval_utils.gt_overlap(gt_bbox, image_bbox)
                #     box_label_list.append(box_label)
                #     box_score_list.append(score_bbox)
                #
                # else:
                #     tp_list.append([torch.tensor(0, device=device) for tp
                #                     in range(len(confidence_values))])
                #     fp_list.append([torch.tensor(0, device=device) for fp
                #                     in range(len(confidence_values))])
                #     fn_list.append([torch.tensor(1, device=device) for fn
                #                     in range(len(confidence_values))])

        # DATASET LEVEL
        total_results_dict[str(epoch)] = model_results_dict

    # MODELS LEVEL
    with open(results_save_dir + "/results", "wb") as result_file:
        torch.save(total_results_dict, result_file)
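
A minimal sketch, assuming results_save_dir still points at the directory created above, of how the pickled settings and the torch-saved results could be read back for later plotting.

# Hedged sketch: reload the artefacts written by the function above;
# results_save_dir is assumed to be the directory it created.
import pickle
import torch

with open(results_save_dir + "/settings", "rb") as settings_file:
    settings_dict = pickle.load(settings_file)

# the results were written with torch.save; map_location avoids needing a GPU
total_results_dict = torch.load(results_save_dir + "/results",
                                map_location="cpu")
# nested layout: {epoch: {"image_i": {"gt_list", "box_list", "score_list", "merging_utils"}}}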
Example #3
def plot_mammogram(image,
                   mask=None,
                   margin=None,
                   annotation=None,
                   plot_mask=False,
                   rel_figsize=0.25,
                   image_save=False,
                   image_save_dir="/home/temp/moriz/validation/",
                   save_suffix=None,
                   format="pdf"):

    if len(image.shape) == 3:
        c, h, w = image.shape
        figsize = rel_figsize * (w / 100), rel_figsize * (h / 100)
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(image[0, :, :], cmap='Greys_r')
    else:
        h, w = image.shape
        figsize = rel_figsize * (w / 100), rel_figsize * (h / 100)
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(image, cmap='Greys_r')

    if mask is not None:
        if plot_mask:
            if len(mask.shape) == 3:
                ax.imshow(mask[0, :, :], cmap='hot', alpha=0.5)
            elif len(mask.shape) == 2:
                ax.imshow(mask, cmap='hot', alpha=0.5)
            else:
                raise ValueError("Too many input dimensions!")

        bbox = utils.bounding_box(mask, margin)
        number_bbox = len(bbox)
        #print("Number bboxes: {0}".format(number_bbox))

        for k in range(number_bbox):
            pos = tuple(bbox[k][0:2])
            width = bbox[k][2]
            height = bbox[k][3]

            pos = (pos[0] - np.floor(width / 2), pos[1] - np.floor(height / 2))

            # Create a Rectangle patch
            rect = patches.Rectangle(pos,
                                     width,
                                     height,
                                     linewidth=1,
                                     edgecolor='r',
                                     facecolor='none')
            ax.add_patch(rect)

            if annotation is not None:
                if isinstance(annotation, np.ndarray):
                    ax.annotate("{0}".format(annotation[k]),
                                pos,
                                fontsize=10,
                                color="b",
                                xytext=(pos[0] + 10, pos[1] - 10))
                else:
                    ax.annotate("{0}".format(annotation),
                                pos,
                                fontsize=10,
                                color="b",
                                xytext=(pos[0] + 10, pos[1] - 10))

    if image_save:
        if save_suffix is None:
            save_name = "/mammogram." + format
        else:
            save_name = "/mammogram_" + save_suffix + "." + format
        plt.savefig(image_save_dir + save_name, dpi='figure', format=format)
    plt.show()
    plt.close(fig)
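
A minimal usage sketch for plot_mammogram, not from the original code: "sample" is a hypothetical dataset entry with a (C, H, W) "data" array and a matching "seg" mask, as in the other examples.

# Hedged usage sketch: "sample" is a placeholder for one dataset entry.
sample = dataset[0]
plot_mammogram(sample["data"],
               mask=sample["seg"],
               plot_mask=False,
               annotation=sample["label"],
               image_save=False)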
Example #4
def main(dataset,
         checkpoint_dir,
         start_epoch,
         end_epoch,
         step_size,
         results_save_dir="/home/temp/moriz/validation/pickled_results/",
         plot=False,
         **settings):
    '''

    :param dataset: dataset to work with
    :param checkpoint_dir: path to checkpoint directory
    :param start_epoch: first model to validate
    :param end_epoch: last model to validate
    :param step_size: step size that determines in which intervals models
            shall be validated
    :param results_save_dir: directory where the pickled settings and
            results are stored
    :param plot: if True, plot each crop with its ground-truth and
            predicted bounding boxes
    :param settings: further keyword settings (e.g. crop_size, set)
    '''

    # device
    device = 'cuda'
    #device = 'cpu'

    # create dict where all results are saved
    total_results_dict = {}

    # get/set crop_size
    if "crop_size" in settings and settings["crop_size"] is not None:
        crop_size = settings["crop_size"]
    else:
        crop_size = [600, 600]

    # determine used set
    if "set" in settings and settings["set"] is not None:
        set = settings["set"]
    else:
        raise KeyError("Missing set description!")

    # set WBC factors
    crop_center = np.asarray(crop_size) / 2.0
    norm_pdf_var = np.int32(min(crop_size[0], crop_size[1]) / 2. - 50)

    # create directory and file name
    cp_dir_date = checkpoint_dir.split("/")[-3]
    results_save_dir = results_save_dir + str(cp_dir_date) + "/" + set + \
                       "/crop_level_" + str(start_epoch) + \
                       "_" + str(end_epoch) + "_" + str(step_size)

    # create folder (if necessary)
    if not os.path.isdir(results_save_dir):
        os.makedirs(results_save_dir)

    # gather all important settings in one dict and save them (pickle them)
    settings_dict = {
        "level": "crops",
        "checkpoint_dir": checkpoint_dir,
        "start_epoch": start_epoch,
        "end_epoch": end_epoch,
        "step_size": step_size
    }
    settings_dict = {**settings_dict, **settings}

    with open(results_save_dir + "/settings", "wb") as settings_file:
        pickle.dump(settings_dict, settings_file)

    for epoch in tqdm(range(start_epoch, end_epoch + step_size, step_size)):
        checkpoint_path = checkpoint_dir + "/checkpoint_epoch_" + str(
            epoch) + ".pth"

        # load model
        if device == "cpu":
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        else:
            checkpoint = torch.load(checkpoint_path)
        model = RetinaNet(**checkpoint['init_kwargs']).eval()
        model.load_state_dict(checkpoint['state_dict']["model"])
        #model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        # dict for results from one model
        model_results_dict = {}

        with torch.no_grad():
            for i in tqdm(range(len(dataset))):
                torch.cuda.empty_cache()

                # get image data
                data = dataset[i]

                # iterate over the number of crops generated from image i
                for k in range(len(data)):
                    crop_data = data[k]
                    gt_bbox = utils.bounding_box(crop_data["seg"])

                    torch.cuda.empty_cache()
                    crop = torch.Tensor(crop_data['data']).to(device)
                    gt_bbox = torch.Tensor(gt_bbox).to(device)
                    gt_label = crop_data["label"]

                    # predict anchors and labels for the crops using the loaded model
                    anchor_preds, cls_preds = model(crop.unsqueeze(0))

                    # convert the predicted anchors to bboxes
                    anchors = Anchors()

                    boxes, labels, scores = \
                        anchors.generateBoxesFromAnchors(anchor_preds[0],
                                                         cls_preds[0],
                                                         (crop.shape[2],
                                                          crop.shape[1]),
                                                         cls_tresh=0.05)

                    # plot the crop with its ground-truth (red) and
                    # predicted (blue) bboxes
                    if plot:
                        gt_bbox = gt_bbox.to("cpu")

                        # plot image
                        c, h, w = crop_data["data"].shape
                        figsize = (w / 100), (h / 100)
                        fig, ax = plt.subplots(1, figsize=figsize)

                        ax.imshow(crop_data["data"][0, :, :], cmap='Greys_r')

                        # show bboxes as saved in data (in red with center)
                        for l in range(len(gt_bbox)):
                            pos = tuple(gt_bbox[l][0:2])
                            plt.plot(pos[0], pos[1], 'r.')
                            width = gt_bbox[l][2]
                            height = gt_bbox[l][3]
                            pos = (pos[0] - np.floor(width / 2),
                                   pos[1] - np.floor(height / 2))

                            # Create a Rectangle patch
                            rect = patches.Rectangle(pos,
                                                     width,
                                                     height,
                                                     linewidth=1,
                                                     edgecolor='r',
                                                     facecolor='none')
                            ax.add_patch(rect)

                        # keep = score_bbox > 0.5
                        # crop_boxes = crop_boxes[keep]
                        # score_bbox = score_bbox[keep]
                        for j in range(len(boxes)):
                            width = boxes[j][2]
                            height = boxes[j][3]
                            pos = (boxes[j][0] - torch.floor(width / 2),
                                   boxes[j][1] - torch.floor(height / 2))

                            # Create a Rectangle patch
                            rect = patches.Rectangle(pos,
                                                     width,
                                                     height,
                                                     linewidth=1,
                                                     edgecolor='b',
                                                     facecolor='none')
                            ax.add_patch(rect)
                            ax.annotate("{:.2f}".format(scores[j]),
                                        pos,
                                        fontsize=10,
                                        color="b",
                                        xytext=(pos[0] + 10, pos[1] - 10))
                        plt.show()

                    model_results_dict["image_{0}_crop_{1}".format(i, k)] = \
                        {"gt_list": gt_bbox,
                         "gt_label": gt_label,
                         "box_list": boxes,
                         "score_list": scores}

        # DATASET LEVEL
        total_results_dict[str(epoch)] = model_results_dict

    # MODELS LEVEL
    with open(results_save_dir + "/results", "wb") as result_file:
        torch.save(total_results_dict, result_file)
Example #5
def eval(dataset, model_path, plot=False):
    # device
    device = 'cuda'

    # load model
    checkpoint = torch.load(model_path)
    model = RetinaNet(**checkpoint['init_kwargs']).eval()
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # hyperparams
    crop_size = [600, 600]
    overlapped_boxes = 0.5
    confidence_values = np.arange(0.5, 1, 0.05)
    tpr_list = []
    fppi_list = []

    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            torch.cuda.empty_cache()

            # get image data
            test_data = dataset[i]

            # crop background
            test_data = inbreast_utils.segment_breast(test_data)
            image_bbox = utils.bounding_box(dataset[i]["seg"])

            # generate crops
            crop_list, corner_list = inbreast_utils.create_crops(test_data)

            # define list for predicted bboxes in crops
            crop_bbox = []
            score_bbox = []

            # plot the image with the corresponding bboxes
            if plot:
                # plot image (create a single figure so figsize takes effect)
                fig, ax = plt.subplots(1, figsize=(15, 10))

                ax.imshow(test_data["data"][0, :, :], cmap='Greys_r')

                # show bboxes as saved in data (in red with center)
                for l in range(len(image_bbox)):
                    pos = tuple(image_bbox[l][0:2])
                    plt.plot(pos[0], pos[1], 'r.')
                    width = image_bbox[l][2]
                    height = image_bbox[l][3]
                    pos = (pos[0] - np.floor(width / 2),
                           pos[1] - np.floor(height / 2))

                    # Create a Rectangle patch
                    rect = patches.Rectangle(pos,
                                             width,
                                             height,
                                             linewidth=1,
                                             edgecolor='r',
                                             facecolor='none')
                    ax.add_patch(rect)

            # iterate over crops
            for j in tqdm(range(0, len(crop_list))):
                torch.cuda.empty_cache()
                test_image = torch.Tensor(crop_list[j]['data']).to(device)
                test_bbox = utils.bounding_box(crop_list[j]['seg'])

                # predict anchors and labels for the crops using the loaded model
                anchor_preds, cls_preds = model(test_image.unsqueeze(0))

                # convert the predicted anchors to bboxes
                anchors = Anchors()
                boxes, labels, score = anchors.generateBoxesFromAnchors(
                    anchor_preds[0].to('cpu'),
                    cls_preds[0].to('cpu'),
                    tuple(test_image.shape[1:]),
                    cls_tresh=0.05)

                # correct the predicted bboxes
                for k in range(len(boxes)):
                    center_corrected = boxes[k][0:2] + \
                                       torch.Tensor(corner_list[j])
                    crop_bbox.append(
                        torch.cat((center_corrected, boxes[k][2:])))
                    score_bbox.append(score[k])

            # merge overlapping bounding boxes
            crop_bbox, score_bbox = merge(crop_bbox, score_bbox)

            # calculate the FROC metric (TPR vs. FPPI)
            tpr_int = []
            fppi_int = []
            image_bbox = change_box_order(torch.Tensor(image_bbox),
                                          order='xywh2xyxy').to('cpu')
            iou_thr = 0.2
            for j in confidence_values:
                current_bbox = crop_bbox[score_bbox > j]

                if len(current_bbox) == 0:
                    tpr_int.append(torch.Tensor([0]))
                    fppi_int.append(torch.Tensor([0]))
                    continue
                    #break

                iou_matrix = box_iou(
                    image_bbox,
                    change_box_order(current_bbox, order="xywh2xyxy"))
                iou_matrix = iou_matrix > iou_thr

                # true positives are the lesions that are recognized
                tp = iou_matrix.sum()

                # false negatives are the lesions that are NOT recognized
                fn = image_bbox.shape[0] - tp

                # true positive rate
                tpr = tp.type(torch.float32) / (tp + fn).type(torch.float32)
                tpr = torch.clamp(tpr, 0, 1)

                # number of false positives per image
                fp = (current_bbox.shape[0] - tp).type(torch.float32)

                tpr_int.append(tpr)
                fppi_int.append(fp)
            tpr_list.append(tpr_int)
            fppi_list.append(fppi_int)

            if plot:
                # show the predicted bboxes (in blue)
                print("Number of detected bboxes: {0}".format(len(crop_bbox)))
                keep = score_bbox > 0.5
                crop_bbox = crop_bbox[keep]
                score_bbox = score_bbox[keep]
                for j in range(len(crop_bbox)):
                    width = crop_bbox[j][2]
                    height = crop_bbox[j][3]
                    pos = (crop_bbox[j][0] - torch.floor(width / 2),
                           crop_bbox[j][1] - torch.floor(height / 2))

                    # Create a Rectangle patch
                    rect = patches.Rectangle(pos,
                                             width,
                                             height,
                                             linewidth=1,
                                             edgecolor='b',
                                             facecolor='none')
                    ax.add_patch(rect)
                    ax.annotate("{:.2f}".format(score_bbox[j]),
                                pos,
                                fontsize=6,
                                xytext=(pos[0] + 10, pos[1] - 10))

                    print("BBox params: {0}, score: {1}".format(
                        crop_bbox[j], score_bbox[j]))
                plt.show()

            #     fig.savefig("../plots/" + "_".join(model_path.split("/")[5:8]) + ".png")

        # calculate FROC over all test images
        tpr_list = np.asarray(tpr_list)
        tpr = np.sum(tpr_list, axis=0) / tpr_list.shape[0]

        fppi_list = np.asarray(fppi_list)
        fppi = np.sum(fppi_list, axis=0) / fppi_list.shape[0]

    # plt.figure(1)
    # plt.ylim(0, 1.1)
    # plt.xlabel("False Positives per Image (FPPI)")
    # plt.ylabel("True Positive Rate (TPR)")
    # plt.title("Free Response Operating Characteristic (FROC)")
    # plt.plot(np.asarray(fppi), np.asarray(tpr), "rx-")
    # plt.show()

    return tpr, fppi
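
A short sketch, mirroring the commented-out block above, of how the averaged TPR/FPPI values returned by eval() could be plotted as a FROC curve; the dataset and checkpoint path are placeholders.

# Hedged sketch based on the commented-out FROC plot above; dataset and the
# checkpoint path are hypothetical placeholders.
tpr, fppi = eval(dataset, "/path/to/checkpoint.pth", plot=False)

plt.figure()
plt.ylim(0, 1.1)
plt.xlabel("False Positives per Image (FPPI)")
plt.ylabel("True Positive Rate (TPR)")
plt.title("Free Response Operating Characteristic (FROC)")
plt.plot(np.asarray(fppi), np.asarray(tpr), "rx-")
plt.show()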
Example #6
def main(dataset,
         checkpoint_dir,
         start_epoch,
         end_epoch,
         step_size,
         results_save_dir="/home/temp/moriz/validation/pickled_results/",
         **settings):
    '''

    :param dataset: dataset to work with
    :param checkpoint_dir: path to checkpoint directory
    :param start_epoch: first model to validate
    :param end_epoch: last model to validate
    :param step_size: step size that determines in which intervals models
            shall be validated
    :param results_save_dir: directory where the pickled settings and
            results are stored
    :param settings: further keyword settings (e.g. set, fold, resnet)
    '''

    # device
    device = 'cuda'
    #device = 'cpu'

    total_results_dict = {}

    # determine used set
    if "set" in settings and settings["set"] is not None:
        set = settings["set"]
    else:
        raise KeyError("Missing set description!")

    if "resnet" not in settings.keys():
        resnet = "RN50"
    else:
        resnet = settings["resnet"]

    # create directory and file name
    cp_dir_date = checkpoint_dir.split("/")[-3]
    if "fold" not in settings:
        results_save_dir = results_save_dir + \
                           str(cp_dir_date) + "/" + \
                           set + \
                           "/whole_image_level_" + \
                           str(start_epoch) + "_" + \
                           str(end_epoch) + "_" + \
                           str(step_size)
    else:
        results_save_dir = results_save_dir + \
                           str(cp_dir_date) + "/" + \
                           set + \
                           "/whole_image_level_" + \
                           str(start_epoch) + "_" + \
                           str(end_epoch) + "_" + \
                           str(step_size) + \
                           "/fold_" + str(settings["fold"])

    # create folder (if necessary)
    if not os.path.isdir(results_save_dir):
        os.makedirs(results_save_dir)

    # gather all important settings in one dict and save them (pickle them)
    settings_dict = {
        "level": "whole_image",
        "checkpoint_dir": checkpoint_dir,
        "start_epoch": start_epoch,
        "end_epoch": end_epoch,
        "step_size": step_size
    }
    settings_dict = {**settings_dict, **settings}

    with open(results_save_dir + "/settings", "wb") as settings_file:
        pickle.dump(settings_dict, settings_file)

    for epoch in tqdm(range(start_epoch, end_epoch + step_size, step_size)):
        checkpoint_path = checkpoint_dir + "/checkpoint_epoch_" + str(
            epoch) + ".pth"

        # load model
        if device == "cpu":
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        else:
            checkpoint = torch.load(checkpoint_path)
        model = RetinaNet(**checkpoint['init_kwargs'], resnet=resnet).eval()
        model.load_state_dict(checkpoint['state_dict']['model'])
        model.to(device)

        model_results_dict = {}

        with torch.no_grad():
            for i in tqdm(range(len(dataset))):
                torch.cuda.empty_cache()

                # get image data
                test_data = dataset[i]["data"]
                if "crops" in dataset[i].keys():
                    crops = dataset[i]["crops"]
                    test_data = np.concatenate((test_data, crops), axis=0)

                gt_bbox = utils.bounding_box(dataset[i]["seg"])
                gt_label = dataset[i]["label"]

                # convert data to tensor to forward through the model
                torch.cuda.empty_cache()
                test_data = torch.Tensor(test_data).to(device)
                gt_bbox = torch.Tensor(gt_bbox).to(device)

                # predict anchors and labels for the crops using the loaded model
                anchor_preds, cls_preds = model(test_data.unsqueeze(0))

                # convert the predicted anchors to bboxes
                anchors = Anchors()
                boxes, labels, scores = anchors.generateBoxesFromAnchors(
                    anchor_preds[0],
                    cls_preds[0], (test_data.shape[2], test_data.shape[1]),
                    cls_tresh=0.05)

                model_results_dict["image_%d" % i] = {
                    "gt_list": gt_bbox,
                    "gt_label": gt_label,
                    "box_list": boxes,
                    "score_list": scores,
                    "labels_list": labels
                }
        # DATASET LEVEL
        total_results_dict[str(epoch)] = model_results_dict

    # MODELS LEVEL
    with open(results_save_dir + "/results", "wb") as result_file:
        torch.save(total_results_dict, result_file)