Example #1
def get_saliency(explainer, input, target):
    # `utils.cuda_var` wraps the tensor as a CUDA variable with gradients enabled
    input = utils.cuda_var(input.clone(), requires_grad=True)
    # sum the saliency over the 3 input channels
    #saliency = explainer.explain(input, target)
    #saliency[input < 0] = 0
    saliency = explainer.explain(input, target).sum(dim=1).unsqueeze(1).detach()
    # Pool into 9x9 patches. Kindermans et al. sum within each patch, but
    # avg_pool2d is functionally equivalent since every patch has the same size.
    saliency = torch.nn.functional.avg_pool2d(saliency, (9, 9), (9, 9))
    # shape = (batch_size, num_patches)
    saliency = saliency.reshape(saliency.shape[0],-1)
    saliency = torch.argsort(saliency, dim=1, descending=True)  # most salient patch first
    return saliency
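A minimal usage sketch, assuming the helpers that appear in the later examples (`utils.load_model`, `get_explainer`, `get_preprocess`) and an already-loaded PIL image `raw_img`; the model name, method name, and class index below are illustrative:

import torch

model = utils.load_model('vgg16').cuda()
explainer = get_explainer(model, 'vanilla_grad')
inp = get_preprocess('vgg16', 'vanilla_grad')(raw_img).unsqueeze(0).cuda()
target = torch.LongTensor([243]).cuda()  # hypothetical target class index
ranking = get_saliency(explainer, inp, target)  # shape (1, num_patches); most salient 9x9 patch first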
Example #2
    model = utils.load_model('vgg16').cuda()

    results = {}

    for batch_no in range(num_batches):
        print("Batch {} of {}".format(batch_no + 1, num_batches))
        image_batch = image_paths[batch_no * batch_size:(batch_no + 1) *
                                  batch_size]
        raw_imgs = [viz.pil_loader(image_path) for image_path in image_batch]
        # preprocess with the pattern-based pipeline this model/method expects
        inputs = [
            get_preprocess('vgg16', 'pattern_vanilla_grad')(raw_img)
            for raw_img in raw_imgs
        ]
        inputs = torch.stack(inputs).cuda()
        inputs = utils.cuda_var(inputs, requires_grad=True)

        diff_sum = 0

        with torch.cuda.device(0):
            torch.cuda.empty_cache()
            model = utils.load_model('vgg16').cuda()  # note: overwrites the model loaded before the loop
            explainer = get_explainer(model, 'vanilla_grad')

            out = torch.softmax(model(inputs.clone()), dim=-1)
            classes = torch.max(out, dim=1)[1]
            out = out.detach().cpu().numpy()

            # get baseline val
            baseline_inp = torch.zeros_like(inputs)
            baseline_out = torch.softmax(model(baseline_inp), dim=-1)
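The source snippet is cut off here. A sketch of how the baseline comparison might continue, assuming numpy is imported as `np` and `diff_sum` accumulates how far each prediction moves from the all-zero baseline; every line below is illustrative, not part of the original:

            baseline_out = baseline_out.detach().cpu().numpy()
            # hypothetical: total L1 shift of the softmax outputs for this batch
            diff_sum += np.abs(out - baseline_out).sum()

        results[batch_no] = diff_sum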
Example #3
    def saliency(self, path):

        #path = path.replace("base2","base")
        path_gt = path.replace("base/", "base2/")

        path = "./" + path
        path_gt = "./" + path_gt
        print(path_gt)
        if os.path.isdir(path_gt):

            # the directory listing does not depend on the method, so do it once
            dirc = os.listdir(path)
            dirc_gt = os.listdir(path_gt)

            # keep only rendered frames and MATLAB ground-truth annotations
            files = [fname for fname in dirc if fname.endswith('png')]
            files_gt = [
                fname for fname in dirc_gt if fname.endswith('mat')
            ]

            def bounding_box(points):
                x_coordinates, y_coordinates = zip(*points)
                return [(min(x_coordinates), min(y_coordinates)),
                        (max(x_coordinates), max(y_coordinates))]

            for index in self.myRange(0, len(files), 16):
                #print "frame ke", index, "from", len(files)
                if index == len(files):
                    continue
                video = []
                flows = []
                boxes = []

                for filename in sorted(files)[index:index + 16]:
                    video.append(Image.open(path + filename))

                diff = 0
                matfile = scipy.io.loadmat(path_gt + files_gt[0])
                coor = np.array(matfile["pos_img"]).transpose(
                    2, 1, 0).tolist()[index:index + 16]
                scale = matfile["scale"][0]

                if len(coor) == 0:
                    continue

                print(len(video), len(coor))

                if len(coor) != 16 or len(video) != 16:
                    diff = len(video)
                    # pad short clips by repetition, then trim to 16 frames
                    video = video * (16 // len(video)) * 2
                    video = video[0:16]
                    coor = coor * (16 // len(coor)) * 2
                    coor = coor[0:16]

                print(len(video), len(coor))
                if len(coor) == 0:
                    continue
                for e in range(0, len(coor), 1):

                    box = bounding_box([(abs(dots[0]) * 112.0 / 320.0,
                                         abs(dots[1]) * 112.0 / 240.0)
                                        for dots in coor[e]])
                    boxes.append(box)

                self.spatial_transform.randomize_parameters()
                self.spatial_transform2.randomize_parameters()
                self.spatial_transform3.randomize_parameters()

                clip = [self.spatial_transform3(img) for img in video]
                inp = torch.stack(clip, 0).permute(1, 0, 2, 3)

                all_saliency_maps = []
                for model_name, method_name, _ in self.model_methods:

                    if method_name == 'googlenet':  # swap channel due to caffe weights
                        inp_copy = inp.clone()
                        inp[0] = inp_copy[2]
                        inp[2] = inp_copy[0]
                    inp = utils.cuda_var(inp.unsqueeze(0), requires_grad=True)

                    saliency, s, kls, scr, c = self.explainer.explain(inp)
                    saliency2, s2, kls, scr, c2 = self.explainer2.explain(inp)
                    saliency3, s3, kls, scr, c3 = self.explainer3.explain(inp)
                    saliency4, s4, kls, scr, c4 = self.explainer4.explain(inp)
                    saliency5, s5, kls, scr, c5 = self.explainer5.explain(inp)
                    saliency6, pool, kls, scr, c6 = self.explainer6.explain(
                        inp)

                    torch.cuda.empty_cache()

                    saliency = (saliency6 + saliency5 + saliency4 + saliency3 +
                                saliency2 + saliency)

                    if self.classes[kls] == path.split("/")[5]:
                        label = 1
                    else:
                        label = 0

                    saliency = torch.clamp(saliency, min=0)

                    temp = saliency.shape[2]

                    if temp > 1:
                        all_saliency_maps.append(
                            saliency.squeeze().detach().cpu().numpy())
                    else:
                        all_saliency_maps.append(
                            saliency.squeeze().unsqueeze(0).detach().cpu().numpy())

                    del pool, inp, saliency, saliency6
                    torch.cuda.empty_cache()

                plt.figure(figsize=(50, 5))

                for k in range(len(video[0:(16 - diff)])):
                    hit = 0
                    hit2 = 0
                    hit3 = 0
                    hit4 = 0
                    hit5 = 0
                    hit6 = 0
                    hit7 = 0
                    plt.subplot(2, 16, k + 1)
                    img = self.spatial_transform2(video[k])

                    if len(boxes) > 0:
                        box = boxes[k]

                        x = box[0][0]
                        y = box[0][1]
                        w = box[1][0] - x
                        h = box[1][1] - y

                        if (x + w) > 112:
                            w = (112 - x)
                        if (y + h) > 112:
                            h = (112 - y)

                        ax = viz.plot_bbox([x, y, w, h], img)

                    plt.axis('off')
                    ax = plt.gca()

                    ax.imshow(img)
                    sal = all_saliency_maps[0][k]
                    sal = (sal - np.mean(sal)) / np.std(sal)
                    ret, thresh = cv2.threshold(
                        sal,
                        np.mean(sal) + ((np.amax(sal) - np.mean(sal)) * 0.5),
                        1, cv2.THRESH_BINARY)

                    # OpenCV 4.x returns (contours, hierarchy);
                    # 3.x returns (image, contours, hierarchy)
                    contours, hierarchy = cv2.findContours(
                        thresh.astype(np.uint8), 1, 2)

                    areas = [cv2.contourArea(c) for c in contours]

                    if len(contours) > 0:

                        glob = np.array(
                            [cv2.boundingRect(cnt) for cnt in contours])
                        #print(glob.shape)
                        x3 = np.amin(glob[:, 0])
                        y3 = np.amin(glob[:, 1])
                        x13 = np.amax(glob[:, 0] + glob[:, 2])
                        y13 = np.amax(glob[:, 1] + glob[:, 3])

                        rect3 = patches.Rectangle((x3, y3),
                                                  x13 - x3,
                                                  y13 - y3,
                                                  linewidth=2,
                                                  edgecolor='y',
                                                  facecolor='none')
                        ax.add_patch(rect3)

                        for cnt in contours:
                            x2, y2, w2, h2 = cv2.boundingRect(cnt)

                            rect2 = patches.Rectangle((x2, y2),
                                                      w2,
                                                      h2,
                                                      linewidth=2,
                                                      edgecolor='r',
                                                      facecolor='none')
                            ax.add_patch(rect2)

                        # IoU between the ground-truth box and the merged
                        # detection box, both as [x1, x2, y1, y2]
                        overlap = nms.get_iou([x, x + w, y, y + h],
                                              [x3, x13, y3, y13])

                        if label == 1:
                            if overlap >= 0.6:
                                hit = 1
                            if overlap >= 0.5:
                                hit2 = 1
                            if overlap >= 0.4:
                                hit3 = 1
                            if overlap >= 0.3:
                                hit4 = 1
                            if overlap >= 0.2:
                                hit5 = 1
                            if overlap >= 0.1:
                                hit6 = 1
                            if overlap > 0.0:
                                hit7 = 1

                    self.totalhit += hit
                    self.totalhit2 += hit2
                    self.totalhit3 += hit3
                    self.totalhit4 += hit4
                    self.totalhit5 += hit5
                    self.totalhit6 += hit6
                    self.totalhit7 += hit7
                    self.totalframes += 1
                    print "================="
                    print "accuracy0.6=", float(
                        self.totalhit) / self.totalframes
                    print "accuracy0.5=", float(
                        self.totalhit2) / self.totalframes
                    print "accuracy0.4=", float(
                        self.totalhit3) / self.totalframes
                    print "accuracy0.3=", float(
                        self.totalhit4) / self.totalframes
                    print "accuracy0.2=", float(
                        self.totalhit5) / self.totalframes
                    print "accuracy0.1=", float(
                        self.totalhit6) / self.totalframes
                    print "accuracy0.0=", float(
                        self.totalhit7) / self.totalframes

                    for saliency in all_saliency_maps:
                        show_style = 'camshow'

                        plt.subplot(2, 16, k + 17)
                        if show_style == 'camshow':

                            viz.plot_cam(np.abs(saliency[k]).squeeze(),
                                         img,
                                         'jet',
                                         alpha=0.5)

                            plt.axis('off')
                            plt.title(float(np.average(saliency[k])))

                            self.seq.append(
                                np.array(np.expand_dims(saliency[k], axis=2)) *
                                np.array(img))

                        else:
                            if model_name == 'googlenet' or method_name == 'pattern_net':
                                saliency = saliency.squeeze()[::-1].transpose(
                                    1, 2, 0)
                            else:
                                saliency = saliency.squeeze().transpose(
                                    1, 2, 0)
                            saliency -= saliency.min()
                            saliency /= (saliency.max() + 1e-20)
                            plt.imshow(saliency, cmap='gray')

                        if method_name == 'excitation_backprop':
                            plt.title('Exc_bp')
                        elif method_name == 'contrastive_excitation_backprop':
                            plt.title('CExc_bp')
                        else:
                            plt.title('%s' % (method_name))

                plt.tight_layout()
                print(path.split("/"))

                plt.savefig('./embrace_%i_%s.png' %
                            (index, path.split("/")[-2]))

            torch.cuda.empty_cache()

            print(path.split("/"))

            return self.seq, self.kls, self.scr
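Example #3 leans on `nms.get_iou`, which is not shown in the snippet; a minimal stand-in under the assumption that boxes arrive as `[x1, x2, y1, y2]`, matching the call site above:

def get_iou(box_a, box_b):
    # boxes are [x1, x2, y1, y2], as at the call site above
    ix1 = max(box_a[0], box_b[0])
    ix2 = min(box_a[1], box_b[1])
    iy1 = max(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[1] - box_a[0]) * (box_a[3] - box_a[2])
    area_b = (box_b[1] - box_b[0]) * (box_b[3] - box_b[2])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0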
Example #4
# PatternPreprocess is the transform required by pattern-based models and methods
transf = transforms.Compose([PatternPreprocess((224, 224))])

for model_name, method_name, _ in model_methods:
    # transf = transforms.Compose([
    #     transforms.Resize((225, 225)),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                          std=[0.229, 0.224, 0.225])
    # ])
    for pattern_attribution in [False, True]:
        torch.cuda.empty_cache()

        img_input = transf(raw_img).cuda()
        img_input = utils.cuda_var(img_input.unsqueeze(0), requires_grad=True)

        model = utils.load_model(model_name).cuda()

        explainer = get_explainer(model, method_name)
        if pattern_attribution:
            explainer = pattern_augment(explainer,
                                        method='pattern_attribution')
        pred = model(img_input)

        ind = pred.data.max(1)[1]
        # override the prediction with the ground-truth class so the
        # explanation targets the true label
        ind = torch.tensor(np.expand_dims(np.array(image_class), 0)).cuda()
        print(
            f'Processing {method_name}, target class is {lrp_utils.imgclasses[ind.item()]}'
        )
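The snippet stops before the attribution itself is computed; a plausible next step, mirroring the `explainer.explain(inp, target)` calls in Examples #5 and #6 (the lines below are an assumption, not part of the original):

        saliency = explainer.explain(img_input, ind)
        # from here, upsampling and plotting would proceed as in Examples #5 and #6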
Example #5
for model_name, method_name, _ in model_methods:
    # Get a specific picture transformation (see torchvision.transforms documentation)
    transf = get_preprocess(model_name, method_name)
    # Load the pretrained model
    model = utils.load_model(model_name)
    model.cuda()
    # Get the explainer
    explainer = get_explainer(model, method_name)

    # Transform the image
    inp = transf(raw_img)
    if method_name == 'googlenet':  # swap channel due to caffe weights
        inp_copy = inp.clone()
        inp[0] = inp_copy[2]
        inp[2] = inp_copy[0]
    inp = utils.cuda_var(inp.unsqueeze(0), requires_grad=True)

    target = torch.LongTensor([image_class]).cuda()
    saliency = explainer.explain(inp, target)
    saliency = utils.upsample(saliency, (raw_img.height, raw_img.width))
    all_saliency_maps.append(saliency.cpu().numpy())

# Display all the results
plt.figure(figsize=(25, 15))
plt.subplot(3, 5, 1)
plt.imshow(raw_img)
plt.axis('off')
plt.title(displayed_class)
for i, (saliency,
        (model_name, method_name,
         show_style)) in enumerate(zip(all_saliency_maps, model_methods)):
    plt.subplot(3, 5, i + 2 + i // 4)
    if show_style == 'camshow':
        viz.plot_cam(np.abs(saliency).max(axis=1).squeeze(),
                     raw_img,
                     'jet',
                     alpha=0.5)
    else:
        if model_name == 'googlenet' or method_name == 'pattern_net':
            saliency = saliency.squeeze()[::-1].transpose(1, 2, 0)
        else:
            saliency = saliency.squeeze().transpose(1, 2, 0)
        saliency -= saliency.min()
        saliency /= (saliency.max() + 1e-20)
        plt.imshow(saliency, cmap='gray')

    plt.axis('off')
    if method_name == 'excitation_backprop':
        plt.title('Exc_bp')
    elif method_name == 'contrastive_excitation_backprop':
        plt.title('CExc_bp')
    else:
        plt.title('%s' % method_name)

plt.tight_layout()
Example #6
def compute_saliency_map(model_name, displayed_class, number_image):
    model_methods = [
        [model_name, 'vanilla_grad', 'imshow'],
        [model_name, 'grad_x_input', 'imshow'],
        [model_name, 'saliency', 'imshow'],
        [model_name, 'integrate_grad', 'imshow'],
        [model_name, 'deconv', 'imshow'],
        [model_name, 'guided_backprop', 'imshow'],
        #[model_name, 'gradcam', 'camshow'],
        #[model_name, 'excitation_backprop', 'camshow'],
        #[model_name, 'contrastive_excitation_backprop', 'camshow']
    ]
    # Map the displayed class name to its label index (dog = 0, cat = 1)
    if displayed_class == "dog":
        image_class = 0
    elif displayed_class == "cat":
        image_class = 1
    else:
        print("ERROR: wrong displayed class")
        return

    # Take the sample image, and display it (original form)
    image_path = "models/test_" + displayed_class + "_images/" + str(
        number_image)

    raw_img = viz.pil_loader(image_path)
    plt.figure(figsize=(5, 5))
    plt.imshow(raw_img)
    plt.axis('off')
    plt.title(displayed_class)

    # Now, we want to display the saliency maps of this image, for every model_method element
    all_saliency_maps = []

    for model_name, method_name, _ in model_methods:
        # Get a specific picture transformation (see torchvision.transforms documentation)
        transf = get_preprocess(model_name, method_name)
        # Load the pretrained model
        model = utils.load_model(model_name)
        model.cuda()
        # Get the explainer
        explainer = get_explainer(model, method_name)

        # Transform the image
        inp = transf(raw_img)
        if method_name == 'googlenet':  # swap channel due to caffe weights
            inp_copy = inp.clone()
            inp[0] = inp_copy[2]
            inp[2] = inp_copy[0]
        inp = utils.cuda_var(inp.unsqueeze(0), requires_grad=True)

        target = torch.LongTensor([image_class]).cuda()
        saliency = explainer.explain(inp, target)
        saliency = utils.upsample(saliency, (raw_img.height, raw_img.width))
        all_saliency_maps.append(saliency.cpu().numpy())

    # Display all the results
    plt.figure(figsize=(25, 15))
    plt.subplot(3, 5, 1)
    plt.imshow(raw_img)
    plt.axis('off')
    plt.title(displayed_class)
    for i, (saliency,
            (model_name, method_name,
             show_style)) in enumerate(zip(all_saliency_maps, model_methods)):
        plt.subplot(3, 5, i + 2 + i // 4)
        if show_style == 'camshow':
            viz.plot_cam(np.abs(saliency).max(axis=1).squeeze(),
                         raw_img,
                         'jet',
                         alpha=0.5)
        else:
            if model_name == 'googlenet' or method_name == 'pattern_net':
                saliency = saliency.squeeze()[::-1].transpose(1, 2, 0)
            else:
                saliency = saliency.squeeze().transpose(1, 2, 0)
            saliency -= saliency.min()
            saliency /= (saliency.max() + 1e-20)
            plt.imshow(saliency, cmap='gray')

        plt.axis('off')
        if method_name == 'excitation_backprop':
            plt.title('Exc_bp')
        elif method_name == 'contrastive_excitation_backprop':
            plt.title('CExc_bp')
        else:
            plt.title('%s' % (method_name))

    plt.tight_layout()

    if not os.path.exists('images/' + model_name + '/'):
        os.makedirs('images/' + model_name + '/')
    save_destination = 'images/' + model_name + '/' + str(
        number_image[:-4]) + '_saliency.png'

    plt.savefig(save_destination)
    plt.clf()
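A minimal invocation, assuming a test image named `1.jpg` exists under `models/test_cat_images/` (the file name is illustrative):

compute_saliency_map('vgg16', 'cat', '1.jpg')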