Exemplo n.º 1
0
def gen_rise_grounding(img, model, device='cuda', index=1, *,
                       maskspath='masks.npy', generate_new=True):
    """Compute a RISE saliency explanation for a single image.

    The classifier is wrapped with a softmax, moved to ``device``, put in
    eval mode and frozen, then explained with ``RISE`` via
    ``explain_instance``.

    Args:
        img: 3-dimensional image (its shape is unpacked as ``a, b, _``); the
            first two shape entries are used as the RISE input size. It is
            converted with ``read_tensor`` before being explained.
        model: classifier to explain. ``nn.Sequential`` shares its
            parameters, so ``requires_grad`` is switched off on the caller's
            model in place.
        device: torch device for the model; also passed to
            ``explain_instance``.
        index: forwarded to ``explain_instance`` as ``top_k``.
        maskspath: file the RISE masks are saved to, or loaded from.
        generate_new: when True (the default, matching the previous
            hard-coded behaviour) fresh masks are generated and written to
            ``maskspath`` on every call; when False, the masks already at
            ``maskspath`` are loaded, and new ones are generated only if the
            file does not exist.

    Returns:
        Whatever ``explain_instance`` returns for this image.
    """
    # Black-box model: classifier followed by a softmax over classes.
    model = nn.Sequential(model, nn.Softmax(dim=1))
    model = model.to(device)
    model = model.eval()

    # Only forward passes are needed; this also freezes the caller's model.
    for param in model.parameters():
        param.requires_grad = False

    # NOTE(review): for an (H, W, C) array these are (height, width), which
    # is the order passed to RISE — confirm the layout of `img` with callers.
    rows, cols, _ = img.shape

    explainer = RISE(model, (rows, cols), 50)

    # Generate masks for RISE, or reuse the saved ones when allowed.
    if generate_new or not os.path.isfile(maskspath):
        explainer.generate_masks(N=6000, s=8, p1=0.1, savepath=maskspath)
        print("Masks are generated.")
    else:
        explainer.load_masks(maskspath)
        print('Masks are loaded.')

    sal = explain_instance(model,
                           explainer,
                           read_tensor(img),
                           top_k=index,
                           device=device)
    print("finished RISE")
    return sal
Exemplo n.º 2
0
def explain_all_batch(imgs, model, device='cuda', show=True, *,
                      maskspath='masks.npy', generate_new=True):
    """Compute RISE saliency maps for a batch of images at their predicted class.

    Args:
        imgs: batch of images; its last two dimensions give the spatial size
            used for the RISE masks. Presumably a (B, C, H, W) tensor —
            TODO confirm with callers.
        model: classifier. It is wrapped as
            ``nn.Sequential(model, nn.Softmax(dim=1))`` and moved to
            ``device``.
        device: torch device used for the model, the explainer and ``imgs``.
        show: unused; kept so existing callers keep working.
        maskspath: file the RISE masks are saved to, or loaded from.
        generate_new: when True (the default, matching the previous
            hard-coded behaviour) new masks are generated and saved on every
            call; when False, masks at ``maskspath`` are loaded if the file
            exists.

    Returns:
        ``(explanations, saliency_maps)``.

        ``explanations`` is a float64 numpy array of shape
        ``(len(imgs), H, W)``. Each row holds the saliency map of the class
        the model predicts for that image. It is the string ``'whoops'`` if
        that selection fails.

        ``saliency_maps`` is the raw explainer output, assumed to be
        (B, num_classes, H, W) — TODO confirm against the RISE implementation.
    """
    model = nn.Sequential(model, nn.Softmax(dim=1))
    model = model.to(device)
    model = model.eval()

    image_size = imgs.shape[-2:]
    # Masks are multiplied with the images, so they must share H x W
    # (this used to be hard-coded to 224 x 224).
    explainer = RISE(model, tuple(image_size), 50, device)

    # Generate masks for RISE, or reuse the saved ones when allowed.
    if generate_new or not os.path.isfile(maskspath):
        explainer.generate_masks(N=6000, s=8, p1=0.1, savepath=maskspath)
        print("Masks are generated.")
    else:
        explainer.load_masks(maskspath)
        print('Masks are loaded.')

    n_batch = len(imgs)
    # Honour `device` everywhere (the explainer call used to force .cuda()).
    batch = imgs.to(device)

    # Black-box forward passes only: no autograd graph is needed.
    with torch.no_grad():
        # explainer.model already ends in Softmax, and softmax does not change
        # the argmax, so no second Softmax is applied.
        target = explainer.model(batch).argmax(dim=1)
        print(target)
        print(image_size)
        saliency_maps = explainer(batch)

    explanations = np.empty((n_batch, *image_size))
    try:
        rows = torch.arange(n_batch, device=saliency_maps.device)
        cols = target.to(saliency_maps.device)
        # Fill every row. The old code wrote all B maps into row 0, which
        # failed (and returned 'whoops') whenever B > 1.
        explanations[:] = saliency_maps[rows, cols].detach().cpu().numpy()
    except (IndexError, RuntimeError, ValueError) as err:
        print(f'Could not select saliency maps: {err}')
        explanations = 'whoops'
    return explanations, saliency_maps
Exemplo n.º 3
0
def visualize(model, img_dir, visualize_dir, CovidDataLoader, *,
              maskspath='masks.npy', generate_new=True):
    """Save a RISE saliency overlay for every image found in ``img_dir``.

    For each image, a figure is written to
    ``<visualize_dir>/<image stem>_visualization.png``. The left panel shows
    the image; the right panel shows the image with the saliency map of the
    predicted class drawn on top.

    Args:
        model: classifier whose output is reduced with
            ``torch.max(..., dim=1)``. It is put in eval mode, moved to CUDA
            and its parameters are frozen in place.
        img_dir: passed to ``CovidDataLoader`` as ``image_dir``.
        visualize_dir: output directory; created if missing.
        CovidDataLoader: dataset class built with ``image_dir`` and
            ``transform=preprocess``. Its items are ``(image, name)`` pairs.
        maskspath: file the RISE masks are saved to, or loaded from.
        generate_new: when True (the default, matching the previous
            hard-coded behaviour) masks are regenerated on every call; when
            False, existing masks at ``maskspath`` are loaded.
    """
    dataset = CovidDataLoader(image_dir=img_dir, transform=preprocess)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              pin_memory=True)
    gpu_batch_size = 20

    model = model.eval()
    model = model.cuda()

    # Only forward passes are needed; freezes the caller's model in place.
    for param in model.parameters():
        param.requires_grad = False

    explainer = RISE(model, (224, 224), gpu_batch_size)

    # Generate masks for RISE, or reuse the saved ones when allowed.
    if generate_new or not os.path.isfile(maskspath):
        explainer.generate_masks(N=1000, s=8, p1=0.1, savepath=maskspath)
    else:
        explainer.load_masks(maskspath, p1=0.1)

    def explain_all(data_loader, explainer):
        """Return one 224x224 saliency map per image, for its predicted class."""
        n_images = len(data_loader)

        # Predicted class per image (np.int was removed in NumPy 1.24).
        target = np.empty(n_images, dtype=np.int64)
        for i, (img, _) in enumerate(
                tqdm(data_loader, total=n_images, desc='Predicting labels')):
            _, c = torch.max(model(img.cuda()), dim=1)
            # .item() copies the CUDA scalar to a Python int for numpy.
            target[i] = c[0].item()

        explanations = np.empty((n_images, 224, 224))
        for i, (img, _) in enumerate(
                tqdm(data_loader, total=n_images, desc='Explaining images')):
            saliency_maps = explainer(img.cuda())
            explanations[i] = saliency_maps[int(target[i])].cpu().numpy()
        return explanations

    explanations = explain_all(data_loader, explainer)

    os.makedirs(visualize_dir, exist_ok=True)
    for i, (img, img_names) in enumerate(
            tqdm(data_loader,
                 total=len(data_loader),
                 desc='Generating visualizations')):
        img_name = os.path.splitext(img_names[0])[0]

        fig = plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.axis('off')
        tensor_imshow(img[0])
        plt.subplot(122)
        plt.axis('off')
        tensor_imshow(img[0])
        plt.imshow(explanations[i], cmap='jet', alpha=0.25)
        plt.savefig(
            os.path.join(visualize_dir, img_name + "_visualization.png"))
        # Without this, one figure per image stays open and memory grows.
        plt.close(fig)
Exemplo n.º 4
0
def visualize(model, img_dir, visualize_dir, CovidDataLoader, *,
              maskspath='masks.npy', generate_new=True):
    """Save a RISE saliency overlay for every image found in ``img_dir``.

    For each image, a figure is written to
    ``<visualize_dir>/<image stem>_visualization.png``. The left panel shows
    the image; the right panel shows the image with the saliency map of the
    predicted class drawn on top.

    Args:
        model: classifier whose output is indexed with ``[0]`` before
            ``torch.max(..., dim=1)``. Presumably it returns a tuple/list
            whose first element holds the logits — TODO confirm. It is put
            in eval mode, moved to CUDA and its parameters are frozen in
            place.
        img_dir: passed to ``CovidDataLoader`` as ``image_dir``.
        visualize_dir: output directory; created if missing.
        CovidDataLoader: dataset class built with ``image_dir`` and
            ``transform=preprocess``. Its items are ``(image, name)`` pairs.
        maskspath: file the RISE masks are saved to, or loaded from.
        generate_new: when True (the default, matching the previous
            hard-coded behaviour) masks are regenerated on every call; when
            False, existing masks at ``maskspath`` are loaded.
    """
    dataset = CovidDataLoader(image_dir=img_dir, transform=preprocess)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              pin_memory=True)
    gpu_batch_size = 20

    model = model.eval()
    model = model.cuda()

    # Only forward passes are needed; freezes the caller's model in place.
    for param in model.parameters():
        param.requires_grad = False

    # NOTE(review): RISE receives `model` directly, whose output is indexed
    # with [0] below — confirm this RISE variant handles that output type.
    explainer = RISE(model, (224, 224), gpu_batch_size)

    # Generate masks for RISE, or reuse the saved ones when allowed.
    if generate_new or not os.path.isfile(maskspath):
        explainer.generate_masks(N=10000, s=8, p1=0.1, savepath=maskspath)
    else:
        explainer.load_masks(maskspath)

    def explain_all(data_loader, explainer):
        """Return one 224x224 saliency map per image, for its predicted class."""
        n_images = len(data_loader)

        # Predicted class per image (np.int was removed in NumPy 1.24).
        target = np.empty(n_images, dtype=np.int64)
        for i, (img, _) in enumerate(
                tqdm(data_loader, total=n_images, desc='Predicting labels')):
            _, c = torch.max(model(img.cuda())[0], dim=1)
            # .item() copies the CUDA scalar to a Python int for numpy.
            target[i] = c[0].item()

        explanations = np.empty((n_images, 224, 224))
        for i, (img, _) in enumerate(
                tqdm(data_loader, total=n_images, desc='Explaining images')):
            saliency_maps = explainer(img.cuda())
            explanations[i] = saliency_maps[int(target[i])].cpu().numpy()
        return explanations

    explanations = explain_all(data_loader, explainer)

    os.makedirs(visualize_dir, exist_ok=True)
    for i, (img, img_names) in enumerate(
            tqdm(data_loader,
                 total=len(data_loader),
                 desc='Generating visualizations')):
        img_name = os.path.splitext(img_names[0])[0]

        fig = plt.figure(figsize=(16, 10))
        plt.subplot(121)
        plt.axis('off')
        tensor_imshow(img[0])
        plt.subplot(122)
        plt.axis('off')
        tensor_imshow(img[0])
        plt.imshow(explanations[i], cmap='jet', alpha=0.25)
        plt.savefig(
            os.path.join(visualize_dir, img_name + "_visualization.png"))
        # Without this, one figure per image stays open and memory grows.
        plt.close(fig)