コード例 #1
0
ファイル: main.py プロジェクト: P0ntiff/thesis
def attribute_panel_wrapper(model_name: str, img_nos: list = None):
    """Run the attribution panel for one model over a set of example images.

    A single ``Attributer`` is built for ``model_name`` and its
    ``attribute_panel`` method is called once per image number using the
    LIME, LIFT and GRAD methods, with ``save``, ``visualise``,
    ``take_threshold`` and ``take_absolute`` all switched on and
    ``sigma_multiple=1``.

    Args:
        model_name: Name of the model; forwarded to both ``Attributer`` and
            ``ImageHandler``.
        img_nos: Image numbers to process. When omitted, only image 11 is
            processed, which is the value that used to be hard-coded in the
            loop. An empty list results in no work being done.
    """
    if img_nos is None:
        img_nos = [11]
    methods = [LIME, LIFT, GRAD]
    # One attributer is shared by every image in this run.
    att = Attributer(model_name)
    for img_no in img_nos:
        ih = ImageHandler(img_no=img_no, model_name=model_name)
        att.attribute_panel(ih=ih,
                            methods=methods,
                            save=True,
                            visualise=True,
                            take_threshold=True,
                            take_absolute=True,
                            sigma_multiple=1)
コード例 #2
0
ファイル: main.py プロジェクト: P0ntiff/thesis
def demo_attribute(img_nos: list = None, att: Attributer = None):
    """Show a 2x4 demo figure of attributions for each requested image.

    For every image number a new figure is created, titled with the
    prediction label and probability returned by ``att.predict_for_model``.
    The top row holds the original image, the annotated image, the resized
    image and the annotation mask. The bottom row holds one panel per entry
    returned by ``att.attribute_panel`` (called with ``METHODS`` and with
    saving, visualising, thresholding and absolute values all disabled),
    each drawn on top of the handler's original image. The figure is shown
    and then closed before the next image is processed.

    Args:
        img_nos: Image numbers to display. Defaults to ``[11, 13, 15]``.
        att: Attributer to use. Defaults to a new ``Attributer`` for ``VGG``.
    """
    if att is None:
        att = Attributer(model_name=VGG)
    if img_nos is None:
        # Default sample; [6, 97, 278] was a previously used alternative.
        img_nos = [11, 13, 15]

    for img_no in img_nos:
        # NOTE(review): the handler is always built for VGG, even when the
        # caller supplies its own attributer -- confirm this is intended.
        handler = ImageHandler(img_no=img_no, model_name=VGG)

        # Top prediction for this image, used in the figure title.
        max_pred, max_p = att.predict_for_model(handler)
        figure_title = (
            'Attributions for example {}, prediction = `{}`, probability = {:.2f}'
            .format(img_no, max_pred, max_p))
        plt.figure(figsize=(15, 10))
        plt.suptitle(figure_title)

        # Panel 1: original image.
        plt.subplot(2, 4, 1)
        plt.axis('off')
        plt.title('ImageNet Example {}'.format(img_no))
        original_path = get_image_file_name(IMG_BASE_PATH, img_no) + '.JPEG'
        plt.imshow(plt.imread(original_path))

        # Panel 2: annotated image.
        plt.subplot(2, 4, 2)
        plt.title('Annotated Example {}'.format(img_no))
        annotated_path = get_image_file_name(ANNOTATE_BASE_PATH,
                                             img_no) + '.JPEG'
        plt.imshow(plt.imread(annotated_path))

        # Panel 3: image resized to the handler's target size.
        plt.subplot(2, 4, 3)
        plt.title('Reshaped Example')
        resized = demo_resizer(img_no=img_no, target_size=handler.get_size())
        plt.imshow(resized)

        # Panel 4: annotation mask at the same target size.
        plt.subplot(2, 4, 4)
        plt.title('Annotation Mask')
        mask = get_mask_for_eval(img_no=img_no, target_size=handler.get_size())
        plt.imshow(mask, cmap='seismic', clim=(-1, 1))

        attributions = att.attribute_panel(ih=handler,
                                           methods=METHODS,
                                           save=False,
                                           visualise=False,
                                           take_threshold=False,
                                           take_absolute=False,
                                           sigma_multiple=1)

        # Bottom row: panels 5..8, one per attribution key.
        # NOTE(review): the 2x4 grid only leaves four slots here; more than
        # four attribution entries would not fit.
        for panel_idx, method_key in enumerate(attributions.keys(), start=5):
            plt.subplot(2, 4, panel_idx)
            plt.title(method_key)
            plt.axis('off')
            # Original image underneath, attribution map over it.
            plt.imshow(handler.get_original_img(), cmap='gray', alpha=0.75)
            plt.imshow(attributions[method_key],
                       cmap='seismic',
                       clim=(-1, 1),
                       alpha=0.8)

        plt.show()
        plt.clf()
        plt.close()