Exemplo n.º 1
0
def compute_attr_one_pixel_target(x, net, spatial_coords, method, **kwargs):
    """Compute a captum attribution map for one spatial output of ``net``.

    ``net`` is wrapped in ``WrapperNet`` (configured with ``spatial_coords``
    and an output mode), and the selected captum method attributes the
    wrapper's output back to ``x`` (gradCAM attributes to the wrapped net's
    ``channel_adj`` layer).

    Args:
        x: single input tensor; per the original note its shape is
            (batch_size, channel_size, H, W) = (1, 1, 28, 28).
        net: model to explain; must expose ``channel_adj`` for gradCAM.
        spatial_coords: ``(idx, idy)`` pair forwarded to ``WrapperNet``.
        method: one of ``'gradCAM'``, ``'deconv'``, ``'GuidedBP'``.
        **kwargs:
            target: forwarded to ``attribute`` (default ``None``).
            wrapper_output: ``WrapperNet`` output mode (default ``'yg_pixel'``).

    Returns:
        numpy array ``attr[0][0]`` (first sample, first channel).

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    from captum.attr import LayerGradCam, Deconvolution, GuidedBackprop

    # ``target`` used to be bound only when supplied, so omitting it raised
    # UnboundLocalError; ``None`` is also captum's own default.
    target = kwargs.get('target')
    output_mode = kwargs.get('wrapper_output', 'yg_pixel')

    idx, idy = spatial_coords
    wnet = WrapperNet(net, output_mode=output_mode, spatial_coords=(idx, idy))

    if method == 'gradCAM':
        xai = LayerGradCam(wnet, wnet.main_net.channel_adj)
    elif method == 'deconv':
        xai = Deconvolution(wnet)
    elif method == 'GuidedBP':
        xai = GuidedBackprop(wnet)
    else:
        # 'layerAct' was dispatched previously but no explainer was ever
        # constructed for it, so it always failed with a NameError.
        raise ValueError(
            f"Unsupported attribution method {method!r}; expected one of "
            "'gradCAM', 'deconv', 'GuidedBP'")

    attr = xai.attribute(x, target=target)
    return attr[0][0].clone().detach().cpu().numpy()
def main(mixup=False):
    """Log test-set segmentations and per-structure GradCAM maps to wandb.

    Only samples whose squashed mask has at least 6 distinct labels (at least
    5 structures besides background) are logged; ``step`` counts logged
    samples only.

    Args:
        mixup: load the mixup-trained model (run name "Mixup") instead of the
            "Large" one.
    """
    prefix = "Mixup" if mixup else "Large"
    run = wandb.init(
        name=f"Interpretability ({prefix})",
        project="ct-interpretability",
        dir=DEFAULT_DATA_STORAGE,
        reinit=True,
    )
    model = get_model(mixup)
    dataset = get_miccai_2d(
        "test",
        transform=DEGREE[model.hparams.transform_degree]["test"],
        enhanced="Boundary" in model.hparams.loss_fx,
    )

    class_labels = dict(zip(range(1, model._n_classes), miccai.STRUCTURES))
    class_labels[0] = "Void"

    # The wrapper and explainer depend only on `model`, so build them once
    # instead of once per logged sample.
    def segmentation_wrapper(inp):
        # Sum the model output over dims 2 and 3 so `target` selects a
        # single value per sample.
        return model(inp).sum(dim=(2, 3))

    layer = model.unet.model[2][1].conv.unit0.conv
    lgc = LayerGradCam(segmentation_wrapper, layer)

    step = 0
    for sample in tqdm(dataset):
        preproc_img, masks, _, *others = sample
        normalized_inp = preproc_img.unsqueeze(0).to(device)
        normalized_inp.requires_grad = True
        masks = _squash_masks(masks, 10, masks.device)

        if len(masks.unique()) < 6:
            # Only display samples with at least 5 structures (excluding background)
            continue

        out = model(normalized_inp)
        out_max = _squash_predictions(out).unsqueeze(1)

        log_samples(preproc_img, masks, out_max, class_labels, step)

        figures = []
        for structure in miccai.STRUCTURES:
            # NOTE(review): `structures` is not defined in this function;
            # presumably a module-level label list whose order matches the
            # model's output channels — confirm.
            idx = structures.index(structure)
            gc_attr = lgc.attribute(normalized_inp, target=idx)
            fig, ax = viz.visualize_image_attr(
                gc_attr[0].cpu().permute(1, 2, 0).detach().numpy(),
                sign="all",
                use_pyplot=False,
            )
            ax.set_title(structure)
            figures.append(wandb.Image(fig))

        wandb.log({"GradCam Attributions": figures}, step=step)
        step += 1

    run.finish()
Exemplo n.º 3
0
def layer_grad_cam(model, image, results_dir, file_name):
    """Save a LayerGradCam heat map for ``image`` as a NIfTI file.

    Attributions are taken at ``model.skip_4[0]`` (no explicit target),
    interpolated to (121, 145, 121), squeezed, and written with an identity
    affine to ``results_dir + file_name + "-HM.nii"`` (plain string
    concatenation, so ``results_dir`` should end with a path separator).
    """
    layer_gc = LayerGradCam(model, model.skip_4[0])
    attr = layer_gc.attribute(image)

    attr = LayerAttribution.interpolate(attr, (121, 145, 121))

    # .cpu() is required before .numpy() when the model runs on a GPU; it is
    # a no-op for tensors already on the CPU.
    volume = attr.squeeze().detach().cpu().numpy()
    nifti = nib.Nifti1Image(volume, affine=np.eye(4))

    nib.save(nifti, results_dir + file_name + "-HM.nii")
Exemplo n.º 4
0
def get_grad_cam(x, y, model, is_training=True):
    '''Return GradCAM maps for ``x`` resized to its spatial size, plus the captured model output.

    The attribution layer comes from ``model.get_grad_cam_layer()`` (per the
    original note: the last conv layer, which is only 7x7 for a 224x224 input).
    Gradients are force-enabled and ``HackGradAndOutputs`` supplies the
    returned ``hack.output``.
    '''
    grad_on = torch.enable_grad()
    with grad_on, HackGradAndOutputs(is_training=is_training) as hack:
        cam_explainer = LayerGradCam(model, model.get_grad_cam_layer())
        coarse_map = cam_explainer.attribute(x, target=y)
        # Upsample the coarse layer map back to the input's (H, W).
        full_map = F.interpolate(coarse_map, size=x.shape[-2:], mode='bilinear')
        return full_map, hack.output
Exemplo n.º 5
0
def run(arch, img, target):
    """Compute a min-max normalised VGG16 GradCAM map for an image file.

    Args:
        arch: unused; kept for interface compatibility.
        img: path of the image to explain (converted to RGB and passed
            through ``apply_transform``).
        target: class index to explain. When ``None``, the model's top-1
            predicted class is used (previously ``None`` made captum fail,
            since VGG16 has many outputs).

    Returns:
        Attribution tensor upsampled to (224, 224) and scaled to [0, 1];
        all zeros when the map is constant (previously 0/0 = NaN).
    """
    image = Image.open(img).convert('RGB')
    image = apply_transform(image)
    model = models.vgg16(pretrained=True).eval()
    gradcam = LayerGradCam(model, model.features[28])
    if target is None:
        # The prediction is only needed when no explicit target is given.
        probs = FF.softmax(model(image), dim=1)
        target = probs.max(1)[-1].item()
    attr = gradcam.attribute(image, target=target)
    attr = LayerAttribution.interpolate(attr, (224, 224))
    lo, hi = attr.min(), attr.max()
    attr = attr - lo
    span = hi - lo
    if span > 0:
        attr = attr / span
    return attr
Exemplo n.º 6
0
def heatmaps(args):
    """Write gradCAM / deconvolution / guided-backprop heat maps for one sample batch.

    Loads the model saved for ``args['PROJECT_ID']``, draws a batch with
    ``class_indices`` 0..9, attributes it w.r.t. the labels ``y0`` with each
    captum method, and passes the results to ``arrange_heatmaps``, which
    saves to ``<PROJECT_DIR>/XAI/heatmaps.y0.jpeg``.
    """
    print('heatmaps')
    project_id = args['PROJECT_ID']

    _, project_dir, model_path, _, _ = folder_check(project_id, CKPT_DIR='checkpoint')
    xai_dir = os.path.join(project_dir, 'XAI')
    if not os.path.exists(xai_dir):
        os.mkdir(xai_dir)

    from .sampler import Pytorch_GPT_MNIST_Sampler
    sampler = Pytorch_GPT_MNIST_Sampler(compenv_mode=None, growth_mode=None)

    from .model import ResGPTNet34
    # NOTE(review): this freshly built network is immediately replaced by the
    # object returned from torch.load below.
    net = ResGPTNet34(nG0=sampler.gen.nG0, Nj=sampler.gen.N_neighbor)
    net = torch.load(model_path)
    net.output_mode = 'prediction_only'
    net.to(device=device)
    net.eval()

    x, y0, _, _ = sampler.get_sample_batch(class_indices=np.array(range(10)), device=device)
    x.requires_grad = True

    save_path = os.path.join(xai_dir, 'heatmaps.y0.jpeg')

    from captum.attr import LayerGradCam, Deconvolution, GuidedBackprop  # ShapleyValueSampling

    # Explainers are built lazily, in order, right before each is used.
    explainer_factories = (
        ('gradCAM', lambda: LayerGradCam(net, net.channel_adj)),
        ('deconv', lambda: Deconvolution(net)),
        ('GuidedBP', lambda: GuidedBackprop(net)),
    )
    attrs = {}
    for name, build_explainer in explainer_factories:
        explainer = build_explainer()
        attrs[name] = explainer.attribute(x, target=y0).clone().detach().cpu().numpy()

    arrange_heatmaps(x.clone().detach().cpu().numpy(), attrs, save_dir=save_path)
Exemplo n.º 7
0
    def explanation(self, dataindex):
        """Explain the first sample of ``self.unknown`` restricted to ``dataindex``.

        ``self.unknown.indices`` is temporarily replaced by ``dataindex`` and
        restored before returning, including on error. For the first sample the
        method prints the labels, then computes at
        ``self.model.layer1[0].conv2``:

        * LayerGradCam for the *predicted* label, and
        * LayerDeepLift for the *ground-truth* label with an all-zero
          ``[1, 1, 28, 28]`` baseline,

        both interpolated to 28x28.

        Returns:
            ``(img, gc_attr, gradcam_upsampled, deeplift_upsampled)``, where the
            last two are squeezed numpy arrays; ``None`` if the subset is empty.
        """
        old_indices = self.unknown.indices.copy()
        self.unknown.indices = dataindex
        try:
            loader = torch.utils.data.DataLoader(dataset=self.unknown,
                                                 batch_size=1,
                                                 shuffle=False)
            self.model.eval()

            layer_gc = LayerGradCam(self.model, self.model.layer1[0].conv2)
            # deep lift
            dl = LayerDeepLift(self.model, self.model.layer1[0].conv2)

            # NOTE(review): this figure is created but never drawn on or
            # closed here; kept in case callers draw on the current figure.
            plt.figure(figsize=(18, 10))

            for batch in loader:
                lb = batch[1].to(device)
                print(len(lb))
                img = batch[0].to(device)

                lbin = batch[1].cpu().numpy()
                print(lbin)
                pred = self.model(img)
                predlb = torch.argmax(pred, 1)
                print('Prediction label is :', predlb.cpu().numpy())
                print('Ground Truth label is: ', lb.cpu().numpy())

                # .item() yields a Python int directly; int() on a one-element
                # ndarray is deprecated in recent NumPy.
                gc_attr = layer_gc.attribute(img, target=predlb.item())
                upsampled_attr = LayerAttribution.interpolate(gc_attr, (28, 28))

                base = torch.zeros([1, 1, 28, 28]).to(device)
                de_attr = dl.attribute(img, base, target=int(lbin[0]))
                dl_upsampled_attr = LayerAttribution.interpolate(de_attr, (28, 28))

                print("done ...")
                # Only the first batch is explained.
                return (img, gc_attr,
                        upsampled_attr.squeeze().detach().cpu().numpy(),
                        dl_upsampled_attr.squeeze().detach().cpu().numpy())
            return None
        finally:
            # Restore the caller's subset; the original saved the old indices
            # ("Dont forget to replace indices at end") but never put them back.
            self.unknown.indices = old_indices
Exemplo n.º 8
0
def explain_gradXact(model, node_idx, x, edge_index, target, include_edges=None):
    """Average LayerGradCam node attributions over all convolution layers and map them to edges.

    Captum's default LayerGradCam does not average over nodes for different
    channels because of different assumptions on tensor shapes, so one
    attribution is computed per convolution layer (through
    ``model_forward_node``), flattened, and the results are averaged element-wise.

    ``include_edges`` is accepted but not used.

    Returns:
        The edge mask produced by ``node_attr_to_edge`` from the averaged node attributions.
    """
    input_mask = x.clone().requires_grad_(True).to(device)
    forward_args = (model, edge_index, node_idx)

    def _flat_layer_attr(conv_layer):
        # Flattened numpy attribution vector for a single convolution layer.
        explainer = LayerGradCam(model_forward_node, conv_layer)
        raw = explainer.attribute(input_mask, target=target, additional_forward_args=forward_args)
        return raw.cpu().detach().numpy().ravel()

    per_layer = [_flat_layer_attr(conv_layer) for conv_layer in get_all_convolution_layers(model)]
    mean_attr = np.array(per_layer).mean(axis=0)
    return node_attr_to_edge(edge_index, mean_attr)
Exemplo n.º 9
0
    def validate_explanation(self):
        """Score softmax-normalised GradCAM maps against ``self.sintetic`` on ``self.test_loader``.

        For each test batch, LayerGradCam at ``self.model.layer1[0].conv2`` is
        computed for the predicted class (``relu_attributions=False``),
        interpolated to 64x64 and softmax-normalised over the flattened
        spatial positions of each (sample, channel). The maps and the matching
        leading slice of ``self.sintetic`` are concatenated, passed to
        ``self.calculate_measures``, and the resulting precision, recall and
        correlation are printed.
        """
        # Collect per-batch tensors and concatenate once after the loop;
        # re-concatenating inside the loop copied the whole accumulator on
        # every batch. The empty long seeds match the original accumulators,
        # so the result (including for an empty loader) is unchanged.
        pred_chunks = [torch.zeros(0, dtype=torch.long, device='cpu')]
        lbl_chunks = [torch.zeros(0, dtype=torch.long, device='cpu')]
        layer_gc = LayerGradCam(self.model, self.model.layer1[0].conv2)

        for batch in self.test_loader:
            img = batch[0].to(device)

            pred = self.model(img)
            predlb = torch.argmax(pred, 1)

            # GradCAM w.r.t. each sample's predicted class; negative evidence is kept.
            gc_attr = layer_gc.attribute(img,
                                         target=predlb,
                                         relu_attributions=False)
            upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))

            # Softmax over the flattened spatial positions of each
            # (sample, channel), then restore the original shape.
            sz = upsampled_attr.size()
            flat = upsampled_attr.view(sz[0], sz[1], -1)
            upsampled_attr_soft = F.softmax(flat, dim=2).view_as(upsampled_attr)

            # NOTE(review): .squeeze() also drops the batch dim for a
            # single-sample batch, which may break the final concatenation.
            n_batch = upsampled_attr_soft.size()[0]
            pred_chunks.append(upsampled_attr_soft.detach().squeeze().cpu())
            lbl_chunks.append(self.sintetic[:n_batch, :, :, :].squeeze().cpu())

        predlist = torch.cat(pred_chunks)
        lbllist = torch.cat(lbl_chunks, dim=0)
        # .size() must be called; printing the bare attribute showed a bound method.
        print(predlist.size(), lbllist.size())
        final_prec, final_rec, final_corr = self.calculate_measures(
            lbllist, predlist)
        print('Final validation result...')
        print("- final explanation percision: {:.3f}".format(final_prec))
        print("- final explanation recal: {:.3f}".format(final_rec))
        print("- final explanation correlation: {:.3f}".format(final_corr))
Exemplo n.º 10
0
    def compute_gradcam(self, img_path, target):
        """Return a GradCAM heat map for the image at ``img_path``.

        Attribution is taken at ``self.layer`` for class ``target``. The map is
        clipped at 0, resized to the model input's spatial size (H, W) and
        shifted/scaled into [0, 1]. An all-zero map is returned unscaled
        (previously 0/0 produced NaNs).

        Returns:
            numpy array; 2-D (H, W) assuming a single-image, single-channel
            attribution after ``squeeze`` — TODO confirm for other layers.
        """
        # open image
        _, _, model_input = self.open_image(img_path)

        # grad cam
        model_input.requires_grad = True
        gradcam = LayerGradCam(self.model, self.layer)
        attr = gradcam.attribute(model_input, target)
        cam = attr.squeeze().cpu().detach().numpy()

        cam = np.maximum(cam, 0)
        # cv2.resize expects dsize as (width, height) — the reverse of the
        # tensor's (H, W); passing shape[2:] transposed non-square maps.
        height, width = model_input.shape[2], model_input.shape[3]
        cam = cv2.resize(cam, (width, height))
        cam = cam - np.min(cam)
        peak = np.max(cam)
        if peak > 0:
            cam = cam / peak

        return cam
Exemplo n.º 11
0
def train_single_scale(D, G, reals, generators, noise_maps,
                       input_from_prev_scale, noise_amplitudes, opt):
    """ Train one scale. D and G are the current discriminator and generator, reals are the scaled versions of the
    original level, generators and noise_maps contain information from previous scales and will receive information in
    this scale, input_from_previous_scale holds the noise map and images from the previous scale, noise_amplitudes hold
    the amplitudes for the noise in all the scales. opt is a namespace that holds all necessary parameters.

    With ``opt.cgan`` set, a ``PCA_Detector`` map of the preprocessed level (scaled by 0.1) is concatenated to the
    real and fake levels along the channel dim before they are passed to D, and
    ``divergence(real_detection_map, preprocess(opt, fake, keepSky))`` is recorded once per epoch (plus once for
    ``z_opt`` after training).

    After training, LayerGradCam figures for D (at ``D.tail``) and for every token channel of G (at ``G.tail[0]``)
    are saved as PNGs under ``gradcam/``, and the networks are saved to ``opt.outf``.

    Returns:
        ``(z_opt, input_from_prev_scale, G, divergences)``; ``divergences`` is empty unless ``opt.cgan`` is set.
    """
    current_scale = len(generators)
    real = reals[current_scale]

    keepSky = False
    kernel_dims = (2, 2)

    # Initialize real detector
    real0 = preprocess(opt, real, keepSky)
    N, C, H, W = real0.shape

    scale = opt.scales[current_scale] if current_scale < len(opt.scales) else 1

    # Defined unconditionally: it is returned at the end, and previously only
    # existed inside the cgan branch (NameError when opt.cgan was off).
    divergences = []
    if opt.cgan:
        detector = PCA_Detector(opt, 'real', real0, kernel_dims)
        real_detection_map = detector(real0)
        detection_scale = 0.1
        real_detection_map *= detection_scale
        # Detection map channels are appended after the token channels, so
        # real1[:, :real.shape[1]] is exactly `real`.
        real1 = torch.cat(
            [real, F.interpolate(real_detection_map, (H, W))], dim=1)
    else:
        real1 = real

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    nzx = real.shape[2]  # Noise size x
    nzy = real.shape[3]  # Noise size y

    padsize = int(
        1 * opt.num_layer
    )  # As kernel size is always 3 currently, padsize goes up by one per layer

    if not opt.pad_with_noise:
        pad_noise = nn.ZeroPad2d(padsize)
        pad_image = nn.ZeroPad2d(padsize)
    else:
        pad_noise = nn.ReflectionPad2d(padsize)
        pad_image = nn.ReflectionPad2d(padsize)

    # setup optimizer
    optimizerD = optim.Adam(D.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(G.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)

    if current_scale == 0:  # Generate new noise
        z_opt = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                       device=opt.device)
        z_opt = pad_noise(z_opt)
    else:  # Add noise to previous output
        z_opt = torch.zeros([1, opt.nc_current, nzx, nzy]).to(opt.device)
        z_opt = pad_noise(z_opt)

    logger.info("Training at scale {}", current_scale)
    for epoch in tqdm(range(opt.niter)):
        step = current_scale * opt.niter + epoch
        noise_ = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                        device=opt.device)
        noise_ = pad_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            D.zero_grad()

            output = D(real1).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)

            # train with fake
            if (j == 0) & (epoch == 0):
                if current_scale == 0:  # If we are in the lowest scale, noise is generated from scratch
                    prev = torch.zeros(1, opt.nc_current, nzx,
                                       nzy).to(opt.device)
                    input_from_prev_scale = prev
                    prev = pad_image(prev)
                    z_prev = torch.zeros(1, opt.nc_current, nzx,
                                         nzy).to(opt.device)
                    z_prev = pad_noise(z_prev)
                    opt.noise_amp = 1
                else:  # First step in NOT the lowest scale
                    # We need to adapt our inputs from the previous scale and add noise to it
                    prev = draw_concat(generators, noise_maps, reals,
                                       noise_amplitudes, input_from_prev_scale,
                                       "rand", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        prev = group_to_token(prev, opt.token_list,
                                              token_group)

                    prev = interpolate(prev,
                                       real1.shape[-2:],
                                       mode="bilinear",
                                       align_corners=False)
                    prev = pad_image(prev)
                    z_prev = draw_concat(generators, noise_maps, reals,
                                         noise_amplitudes,
                                         input_from_prev_scale, "rec",
                                         pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        z_prev = group_to_token(z_prev, opt.token_list,
                                                token_group)

                    z_prev = interpolate(z_prev,
                                         real1.shape[-2:],
                                         mode="bilinear",
                                         align_corners=False)
                    # Compare against the token channels only. `real` is exactly
                    # those channels of real1; `real1[:, :-1]` dropped a real
                    # token channel whenever opt.cgan was off.
                    opt.noise_amp = update_noise_amplitude(
                        z_prev, real, opt)
                    z_prev = pad_image(z_prev)
            else:  # Any other step
                prev = draw_concat(generators, noise_maps, reals,
                                   noise_amplitudes, input_from_prev_scale,
                                   "rand", pad_noise, pad_image, opt)

                # For the seeding experiment, we need to transform from token_groups to the actual token
                if current_scale == (opt.token_insert + 1):
                    prev = group_to_token(prev, opt.token_list, token_group)

                prev = interpolate(prev,
                                   real1.shape[-2:],
                                   mode="bilinear",
                                   align_corners=False)
                prev = pad_image(prev)

            # After creating our correct noise input, we feed it to the generator:
            noise = opt.noise_amp * noise_ + prev
            # NOTE: both branches of the temperature expression are 1
            # (presumably a placeholder for the token-insert scale).
            fake = G(noise.detach(),
                     prev,
                     temperature=1 if current_scale != opt.token_insert else 1)

            fake0 = preprocess(opt, fake, keepSky)
            if opt.cgan:
                Nf, Cf, Hf, Wf = fake0.shape
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat(
                    [fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            # Then run the result through the discriminator
            output = D(fake1.detach())
            errD_fake = output.mean()

            # Backpropagation
            errD_fake.backward(retain_graph=False)

            # Gradient Penalty
            gradient_penalty = calc_gradient_penalty(D, real1, fake1,
                                                     opt.lambda_grad,
                                                     opt.device)
            gradient_penalty.backward(retain_graph=False)

            # Logging:
            if step % 10 == 0:
                wandb.log(
                    {
                        f"D(G(z))@{current_scale}":
                        errD_fake.item(),
                        f"D(x)@{current_scale}":
                        -errD_real.item(),
                        f"gradient_penalty@{current_scale}":
                        gradient_penalty.item()
                    },
                    step=step,
                    sync=False)
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            G.zero_grad()
            fake = G(noise.detach(),
                     prev.detach(),
                     temperature=1 if current_scale != opt.token_insert else 1)

            fake0 = preprocess(opt, fake, keepSky)
            Nf, Cf, Hf, Wf = fake0.shape

            if opt.cgan:
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat(
                    [fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            output = D(fake1)

            errG = -output.mean()
            errG.backward(retain_graph=False)
            if opt.alpha != 0:  # i. e. we are trying to find an exact recreation of our input in the lat space
                Z_opt = opt.noise_amp * z_opt + z_prev
                G_rec = G(
                    Z_opt.detach(),
                    z_prev,
                    temperature=1 if current_scale != opt.token_insert else 1)
                rec_loss = opt.alpha * F.mse_loss(G_rec, real)
                if opt.cgan:
                    div = divergence(real_detection_map,
                                     preprocess(opt, G_rec, keepSky))
                    rec_loss += div
                rec_loss.backward(
                    retain_graph=False
                )  # TODO: Check for unexpected argument retain_graph=True
                rec_loss = rec_loss.detach()
            else:  # We are not trying to find an exact recreation
                rec_loss = torch.zeros([])
                Z_opt = z_opt

            optimizerG.step()

        # More Logging:
        # The detection map only exists in cgan mode; this used to run
        # unconditionally and raised NameError otherwise.
        if opt.cgan:
            div = divergence(real_detection_map, preprocess(opt, fake, keepSky))
            divergences.append(div)
        # logger.info("divergence(fake) = {}", div)
        if step % 10 == 0:
            wandb.log(
                {
                    f"noise_amplitude@{current_scale}": opt.noise_amp,
                    f"rec_loss@{current_scale}": rec_loss.item()
                },
                step=step,
                sync=False,
                commit=True)

        # Rendering and logging images of levels
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if opt.token_insert >= 0 and opt.nc_current == len(token_group):
                token_list = [list(group.keys())[0] for group in token_group]
            else:
                token_list = opt.token_list

            # `fake` / `real` are the token channels of fake1 / real1 (the
            # former `[:, :-1]` slices dropped a token when cgan was off).
            img = opt.ImgGen.render(
                one_hot_to_ascii_level(fake.detach(), token_list))
            img2 = opt.ImgGen.render(
                one_hot_to_ascii_level(
                    G(Z_opt.detach(),
                      z_prev,
                      temperature=1 if current_scale != opt.token_insert else
                      1).detach(), token_list))
            real_scaled = one_hot_to_ascii_level(real.detach(),
                                                 token_list)
            img3 = opt.ImgGen.render(real_scaled)
            wandb.log(
                {
                    f"G(z)@{current_scale}": wandb.Image(img),
                    f"G(z_opt)@{current_scale}": wandb.Image(img2),
                    f"real@{current_scale}": wandb.Image(img3)
                },
                sync=False,
                commit=False)

            real_scaled_path = os.path.join(wandb.run.dir,
                                            f"real@{current_scale}.txt")
            with open(real_scaled_path, "w") as f:
                f.writelines(real_scaled)
            wandb.save(real_scaled_path)

        # Learning Rate scheduler step
        schedulerD.step()
        schedulerG.step()

    if opt.cgan:
        div = divergence(real_detection_map, preprocess(opt, z_opt, keepSky))
        divergences.append(div)

    # visualization config
    folder_name = 'gradcam'
    level_name = opt.input_name.rsplit(".", 1)[0].split("_", 1)[1]
    # Portable path handling: the former raw-string '\' separator only worked
    # on Windows, and the output folder is not created anywhere in this function.
    os.makedirs(folder_name, exist_ok=True)

    # GradCAM on D
    camD = LayerGradCam(D, D.tail)
    real0 = one_hot_to_ascii_level(real, opt.token_list)
    real0 = opt.ImgGen.render(real0)
    real0 = np.array(real0)
    attr = camD.attribute(real1, target=(0, 0, 0), relu_attributions=True)
    attr = LayerAttribution.interpolate(attr, (real0.shape[0], real0.shape[1]),
                                        'bilinear')
    attr = attr.permute(2, 3, 1, 0).squeeze(3)
    attr = attr.detach().cpu().numpy()
    fig, ax = plt.subplots(1, 1)
    # NOTE(review): assigning `fig.figsize` does not resize a matplotlib
    # Figure; fig.set_size_inches(...) would be needed for that.
    fig.figsize = (10, 1)
    ax.imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
    im = ax.imshow(attr, cmap='jet', alpha=0.5)
    ax.axis('off')
    fig.colorbar(im, ax=ax, location='bottom', shrink=0.85)
    plt.suptitle(f'cGAN {level_name} D(x)@{current_scale} ({step})')
    plt.savefig(os.path.join(folder_name,
                             f'{level_name}_D_{current_scale}_{step}.png'),
                bbox_inches='tight',
                pad_inches=0.1)
    # plt.show()
    plt.close()

    # GradCAM on G
    # Human-readable labels for token characters (duplicate '!' and 'k'
    # entries with identical values were removed).
    token_names = {
        'M': 'Mario start',
        'F': 'Mario finish',
        'y': 'spiky',
        'Y': 'winged spiky',
        'k': 'green koopa',
        'K': 'winged green koopa',
        '!': 'coin [?]',
        '#': 'pyramid',
        '-': 'sky',
        '1': 'invis. 1 up',
        '2': 'invis. coin',
        'L': '1 up',
        '?': 'special [?]',
        '@': 'special [?]',
        'Q': 'coin [?]',
        'C': 'coin brick',
        'S': 'normal brick',
        'U': 'mushroom brick',
        'X': 'ground',
        'E': 'goomba',
        'g': 'goomba',
        '%': 'platform',
        '|': 'platform bg',
        'r': 'red koopa',
        'R': 'winged red koopa',
        'o': 'coin',
        't': 'pipe',
        'T': 'plant pipe',
        '*': 'bullet bill',
        '<': 'pipe top left',
        '>': 'pipe top right',
        '[': 'pipe left',
        ']': 'pipe right',
        'B': 'bullet bill head',
        'b': 'bullet bill body',
        'D': 'used block',
    }

    def wrappedG(z):
        # Fix the generator's second input to z_opt so captum only varies z.
        return G(z, z_opt)

    camG = LayerGradCam(wrappedG, G.tail[0])
    z_cam = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                   device=opt.device)
    z_cam = pad_noise(z_cam)
    attrs = []
    for i in range(opt.nc_current):
        attr = camG.attribute(z_cam, target=(i, 0, 0), relu_attributions=True)
        attr = LayerAttribution.interpolate(attr,
                                            (real0.shape[0], real0.shape[1]),
                                            'bilinear')
        attr = attr.permute(2, 3, 1, 0).squeeze(3)
        attr = attr.detach().cpu().numpy()
        attrs.append(attr)
    fig, axs = plt.subplots(opt.nc_current, 1)
    # NOTE(review): as above, this attribute assignment does not resize the figure.
    fig.figsize = (10, opt.nc_current)
    for i in range(opt.nc_current):
        axs[i].axis('off')
        axs[i].text(-0.1,
                    0.5,
                    token_names[opt.token_list[i]],
                    rotation=0,
                    verticalalignment='center',
                    horizontalalignment='right',
                    transform=axs[i].transAxes)
        axs[i].imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
        im = axs[i].imshow(attrs[i], cmap='jet', alpha=0.5)
    fig.colorbar(im, ax=axs, shrink=0.85)
    plt.suptitle(f'cGAN {level_name} G(z)@{current_scale} ({step})')
    plt.savefig(os.path.join(folder_name,
                             f'{level_name}_G_{current_scale}_{step}.png'),
                bbox_inches='tight',
                pad_inches=0.1)
    # plt.show()
    plt.close()

    # Save networks
    torch.save(z_opt, "%s/z_opt.pth" % opt.outf)
    save_networks(G, D, z_opt, opt)
    wandb.save(opt.outf)
    return z_opt, input_from_prev_scale, G, divergences
Exemplo n.º 12
0
                                                     test_list, model_type,
                                                     config_dict)
        test_loader = loader_dict['test']
        captum_gc = LayerGradCam(model, target_layer)
        output_file = file_dir + '/heatmaps/imageset_captum_new_' + model_type + "_" + str(
            test)
        process_sequence_captum(captum_gc, test_loader, model_type, device,
                                output_file)

    #%%
    finalconv_name = 'layer4'
    target_layer = model.cnn._modules.get(finalconv_name)
    inputs, inputs_aug, _ = next(iter(loader_dict['train']))
    captum_gc = LayerGradCam(model, target_layer)
    attributions = captum_gc.attribute((inputs.to(
        device, dtype=torch.float), inputs_aug.to(device, dtype=torch.float)),
                                       target=0,
                                       relu_attributions=True)
    upsampled_attr = captum_gc.interpolate(attributions, (224, 224), 'bicubic')
    saliency_map_min, saliency_map_max = upsampled_attr.min(
    ), upsampled_attr.max()
    upsampled_attr2 = (upsampled_attr -
                       saliency_map_min).div(saliency_map_max -
                                             saliency_map_min).data
    upsampled_attr = (upsampled_attr -
                      saliency_map_min).div(saliency_map_max -
                                            np.spacing(1)).data
    img = unnorm(inputs.squeeze())
    heatmap_x, results_all_x = visualize_cam(upsampled_attr, img, alpha=0.5)
    plt.imshow(results_all_x.cpu().detach().numpy().transpose(1, 2, 0))
def _layergradcam_reverse_assign(attribution, assign_tensor_list, layer_idx,
                                 assign_type):
    """Project a pooled-layer attribution back onto the input graph's nodes.

    The first ``layer_idx`` entries of ``assign_tensor_list`` are applied in
    reverse order: the (transposed) attribution is right-multiplied by the
    transpose of each assignment tensor. With ``assign_type == "hard"`` each
    assignment tensor is first replaced by a one-hot tensor holding a 1 at the
    argmax of every row (dim=1); otherwise the soft assignment is used as is.
    NOTE(review): assumes each assignment tensor is 2-D
    (nodes_before_pooling x nodes_after_pooling) -- confirm against the model.
    """
    reverse_assign_tensor_list = []
    for i in range(layer_idx):
        assign_tensor = assign_tensor_list[i]
        if assign_type == "hard":
            max_index = torch.argmax(assign_tensor, dim=1, keepdim=True)
            # zeros_like keeps device and dtype aligned with assign_tensor;
            # a plain torch.zeros(...) would always be allocated on the CPU
            # and break the scatter / matmul when the model runs on a GPU.
            assign_tensor = torch.zeros_like(assign_tensor).scatter_(
                1, max_index, value=1)
        reverse_assign_tensor_list.append(torch.transpose(assign_tensor, 0, 1))

    attribution = torch.transpose(attribution, 0, 1)
    for reverse_tensor in reversed(reverse_assign_tensor_list):
        attribution = attribution @ reverse_tensor
    return torch.transpose(attribution, 0, 1)


def LayerGradCAM(classifier_model,
                 config,
                 dataset_features,
                 GNNgraph_list,
                 current_fold=None,
                 cuda=0):
    """Attribute to the input layer with Layer Grad-CAM and (soft/hard) assignment.

    Grad-CAM is applied to the layer selected by
    ``config["interpretability_methods"]["LayerGradCAM"]["layer"]``: 0 picks
    ``classifier_model.graph_convolution``, k > 0 picks
    ``classifier_model.conv_modules[k - 1]``. The layer attribution is mapped
    back to the input nodes through ``classifier_model.cur_assign_tensor_list``
    using the ``"assign_attribution"`` setting ("hard" or soft).

    :param classifier_model: trained classifier model
    :param config: parsed configuration file of config.yml; the
        ``["interpretability_methods"]["LayerGradCAM"]`` section is read for
        ``assign_attribution``, ``layer``, ``sample_ids`` and
        ``number_of_samples``. ``sample_ids`` may be None, a single id, a
        comma-separated string or a list/tuple of ids.
    :param dataset_features: a dictionary of dataset features obtained from
        load_data.py (``label_dict``, ``feat_dim`` and ``edge_feat_dim`` are used)
    :param GNNgraph_list: a list of GNNgraphs obtained from the dataset
    :param current_fold: has no use in this method
    :param cuda: whether to use GPU to perform conversion to tensor
        (forwarded to ``graph_to_tensor``)
    :return: ``(output_for_metrics_calculation,
        output_for_generating_saliency_map, execution_time)``:
        a list with one dict per graph (``'graph'`` -> the graph, each label ->
        the ``standardize_scores`` output for that label); a dict mapping
        ``"layergradcam_<assign>_<label>"`` to lists of ``(graph, scores)``
        pairs for the selected samples; and the mean attribution time in
        seconds (0.0 when nothing was timed, e.g. an empty graph list).
    """
    interpretability_config = config["interpretability_methods"]["LayerGradCAM"]
    assign_type = interpretability_config["assign_attribution"]

    # Perform grad cam on the classifier model and on a specific layer
    layer_idx = interpretability_config["layer"]
    if layer_idx == 0:
        target_layer = classifier_model.graph_convolution
    else:
        target_layer = classifier_model.conv_modules[layer_idx - 1]
    gc = LayerGradCam(classifier_model, target_layer)

    labels = list(dataset_features["label_dict"].values())

    def saliency_key(label):
        # Key under which saliency samples for `label` are collected.
        return "layergradcam_%s_%s" % (str(assign_type), str(label))

    output_for_metrics_calculation = []
    output_for_generating_saliency_map = {}
    tmp_timing_list = []

    # Obtain attribution score for use in qualitative metrics
    for GNNgraph in GNNgraph_list:
        output = {'graph': GNNgraph}
        for label in labels:
            # Temporarily relabel the graph for this target; always restore
            # the original label, even if conversion or attribution fails.
            original_label = GNNgraph.label
            GNNgraph.label = label
            try:
                node_feat, n2n, subg = graph_to_tensor(
                    [GNNgraph], dataset_features["feat_dim"],
                    dataset_features["edge_feat_dim"], cuda)

                start_generation = perf_counter()
                attribution = gc.attribute(node_feat,
                                           additional_forward_args=(n2n, subg,
                                                                    [GNNgraph]),
                                           target=label,
                                           relu_attributions=True)
                # cur_assign_tensor_list is read after attribute() has run the
                # model's forward pass (presumably what refreshes it).
                attribution = _layergradcam_reverse_assign(
                    attribution, classifier_model.cur_assign_tensor_list,
                    layer_idx, assign_type)
                tmp_timing_list.append(perf_counter() - start_generation)
            finally:
                GNNgraph.label = original_label

            attribution_score = torch.sum(attribution, dim=1).tolist()
            output[label] = standardize_scores(attribution_score)
        output_for_metrics_calculation.append(output)

    # Guard against an empty graph list / label dict (nothing was timed).
    execution_time = (sum(tmp_timing_list) / len(tmp_timing_list)
                      if tmp_timing_list else 0.0)

    # Obtain attribution score for use in generating saliency map for comparison with zero tensors
    sample_ids = interpretability_config["sample_ids"]
    if sample_ids is not None:
        if isinstance(sample_ids, (list, tuple)):
            raw_ids = sample_ids
        elif ',' in str(sample_ids):
            raw_ids = str(sample_ids).split(',')
        else:
            raw_ids = [sample_ids]
        sample_graph_ids = {int(raw_id) for raw_id in raw_ids}

        output_for_generating_saliency_map.update(
            {saliency_key(label): [] for label in labels})

        for tmp_output in output_for_metrics_calculation:
            graph = tmp_output['graph']
            if graph.graph_id in sample_graph_ids:
                output_for_generating_saliency_map[saliency_key(
                    graph.label)].append((graph, tmp_output[graph.label]))

    elif interpretability_config["number_of_samples"] > 0:
        # Randomly sample from existing list:
        graph_idxes = list(range(len(output_for_metrics_calculation)))
        random.shuffle(graph_idxes)
        output_for_generating_saliency_map.update(
            {saliency_key(label): [] for label in labels})

        # Keep at most `number_of_samples` graphs per (true) label.
        for index in graph_idxes:
            tmp_output = output_for_metrics_calculation[index]
            tmp_label = tmp_output['graph'].label
            bucket = output_for_generating_saliency_map[saliency_key(tmp_label)]
            if len(bucket) < interpretability_config["number_of_samples"]:
                bucket.append((tmp_output['graph'], tmp_output[tmp_label]))

    return output_for_metrics_calculation, output_for_generating_saliency_map, execution_time
Exemplo n.º 14
0
    def train_known_expl(self, n_epochs, ite, max_ite):
        """Train ``self.model`` so its Grad-CAM maps match ``self.sintetic``.

        For every batch of ``self.KnownLoader`` the Layer Grad-CAM map of
        ``self.model.layer2[1].conv2`` for the *predicted* class is upsampled
        to 64x64 and compared with the matching slice of ``self.sintetic``
        via ``nn.BCELoss``; only that loss is back-propagated (a fresh Adam
        optimiser, lr=0.001, is created per call). The cross-entropy
        classification loss is computed and printed every 30 steps for
        monitoring only. When training ends the weights are written to
        ``./checkpoint/resnet18.pth``.

        :param n_epochs: number of passes over ``self.KnownLoader``
        :param ite: zero-based index of the outer iteration (log output only)
        :param max_ite: total number of outer iterations (log output only)
        """
        self.model.train()

        optim = torch.optim.Adam(self.model.parameters(),
                                 lr=0.001,
                                 weight_decay=0)
        criteria1 = nn.CrossEntropyLoss()  # classification loss, logged only
        criteria2 = nn.BCELoss()  # explanation loss, the one optimised

        layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)

        print('Init known training with explanation...')
        number = 0
        for epoch in range(n_epochs):
            for i, batch in enumerate(self.KnownLoader):
                lb = batch[1].to(device)
                img = batch[0].to(device)

                out = self.model(img)
                predlb = torch.argmax(out, 1)

                # Grad-CAM w.r.t. the predicted class: the map being trained.
                gc_attr = layer_gc.attribute(img,
                                             target=predlb,
                                             relu_attributions=True)
                upsampled_attr = LayerAttribution.interpolate(
                    gc_attr, (64, 64))

                # Grad-CAM w.r.t. the ground-truth label. Its result does not
                # enter the loss; the call is kept because it runs an extra
                # forward pass in train mode, which can update e.g. BatchNorm
                # running statistics if the model has any.
                layer_gc.attribute(img, target=lb, relu_attributions=True)

                # NOTE(review): nn.BCELoss rejects inputs outside [0, 1];
                # relu_attributions only guarantees >= 0 -- confirm the
                # Grad-CAM values stay <= 1 for this model.
                size_batch = upsampled_attr.size()[0]
                losssemantic = criteria2(
                    upsampled_attr,
                    self.sintetic[:size_batch, :, :, :].to(device))

                loss1 = criteria1(out, lb)
                lossall = losssemantic.to(device)

                if number % 30 == 0:
                    print(loss1, losssemantic)

                # Loss of the current batch (reported as "average-loss").
                batch_loss = torch.mean(lossall)

                optim.zero_grad()
                lossall.backward()
                optim.step()
                number += 1

                if number % 10 == 0:
                    print(
                        "Epoch: {}/{} batch: {}/{} iteration: {}/{} average-loss: {:0.4f}"
                        .format(epoch + 1, n_epochs, i + 1,
                                len(self.KnownLoader), ite + 1, max_ite,
                                batch_loss.item()))

        # Save checkpoint
        os.makedirs("./checkpoint", exist_ok=True)
        torch.save(self.model.state_dict(), "./checkpoint/resnet18.pth")
Exemplo n.º 15
0
def get_grad_imp(model,
                 X,
                 y=None,
                 mode='grad',
                 return_y=False,
                 clip=False,
                 baselines=None):
    """Return attributions of ``model``'s class scores with respect to ``X``.

    :param model: classifier; its output for ``X`` is indexed as
        ``logits[i, y[i]]`` in ``'grad'`` mode
    :param X: input batch. Its ``requires_grad`` flag is switched on for the
        computation and restored to its original value afterwards, also when
        an error is raised.
    :param y: target class per sample; defaults to ``model(X).argmax(dim=1)``
    :param mode: ``'grad'`` (plain input gradient of the target score),
        ``'deeplift'`` (DeepLift with a zero baseline), ``'deepliftshap'``
        (DeepLiftShap with ``baselines``, two samples at a time) or
        ``'gradcam'`` (Layer Grad-CAM on ``model.body[0]``, bilinearly
        resized to ``X.shape[-2:]``)
    :param return_y: also return the targets that were used
    :param clip: pass the attributions through ``myclip``
    :param baselines: baselines for ``'deepliftshap'``
    :return: ``attributions``, or ``(attributions, y)`` when ``return_y``
    :raises NotImplementedError: for an unknown ``mode``
    """
    was_requiring_grad = X.requires_grad
    X.requires_grad_(True)
    try:
        if mode == 'grad':
            logits = model(X)
            if y is None:
                y = logits.argmax(dim=1)
            target_scores = logits[torch.arange(len(logits)), y]
            attributions = torch.autograd.grad(target_scores.sum(),
                                               X)[0].detach()
        else:
            if y is None:
                with torch.no_grad():
                    y = model(X).argmax(dim=1)

            if mode == 'deeplift':
                dl = DeepLift(model)
                attributions = dl.attribute(inputs=X, baselines=0.,
                                            target=y).detach()
            elif mode == 'deepliftshap':
                dl = DeepLiftShap(model)
                # Attribute two samples at a time (presumably to bound memory).
                chunks = []
                for start in range(0, len(X), 2):
                    chunk = dl.attribute(inputs=X[start:start + 2],
                                         baselines=baselines,
                                         target=y[start:start + 2])
                    chunks.append(chunk.detach())
                attributions = torch.cat(chunks, dim=0)
            elif mode == 'gradcam':
                orig_lgc = LayerGradCam(model, model.body[0])
                attributions = orig_lgc.attribute(X, target=y)
                attributions = F.interpolate(attributions,
                                             size=X.shape[-2:],
                                             mode='bilinear')
            else:
                raise NotImplementedError(f'{mode} is not specified.')

        # Do clipping!
        if clip:
            attributions = myclip(attributions)
    finally:
        # Restore the caller's flag instead of forcing it off (forcing False
        # on a non-leaf tensor would raise).
        X.requires_grad_(was_requiring_grad)

    if not return_y:
        return attributions
    return attributions, y
Exemplo n.º 16
0
def _normalize_by_max_abs(arr):
    """Scale ``arr`` so that its largest absolute value becomes 1.

    An all-zero array is returned as zeros instead of being divided by zero,
    which would otherwise turn the whole map into NaNs.
    """
    peak = np.max(np.abs(arr))
    if peak == 0:
        return np.zeros_like(arr)
    return arr / peak


def get_attribution(real_img,
                    fake_img,
                    real_class,
                    fake_class,
                    net_module,
                    checkpoint_path,
                    input_shape,
                    channels,
                    methods=("ig", "grads", "gc", "ggc", "dl", "ingrad",
                             "random", "residual"),
                    output_classes=6,
                    downsample_factors=None):
    """Compute attribution maps for a real image and its counterfactual.

    Both images go through ``normalize_image`` and ``image_to_tensor``; the
    classifier is built with ``init_network(..., eval_net=True,
    require_grad=False)``. For each key in ``methods`` the following maps are
    collected (for tensors: element [0, 0] of batch/channel), in this order:

    * ``"residual"``: ``|real - fake|`` shifted to a minimum of 0 -> ``residual``
    * ``"random"``: Gaussian-blurred (sigma 4) random noise -> ``random``
    * ``"gc"``: Layer Grad-CAM on the last ``torch.nn.Conv2d`` ->
      ``gc_real``, ``gc_fake``; SCAM maps from ``get_sgc`` ->
      ``gc_diff_0``, ``gc_diff_1``
    * ``"ggc"``: Guided Grad-CAM -> ``ggc_real``, ``ggc_fake``; guided
      backprop -> ``gbp_real``, ``gbp_fake``; guided backprop times the SCAM
      maps -> ``ggc_diff_0``, ``ggc_diff_1`` (SCAM is computed here too when
      ``"gc"`` is not requested)
    * ``"ig"``: Integrated Gradients vs. a zero baseline and vs. the other
      image -> ``ig_real``, ``ig_fake``, ``ig_diff_0``, ``ig_diff_1``
    * ``"dl"``: DeepLift, same four variants -> ``dl_real`` ... ``dl_diff_1``
    * ``"ingrad"``: saliency -> ``grads_real``, ``grads_fake``; absolute
      input-x-gradient -> ``ingrad_real``, ``ingrad_fake``; absolute saliency
      times image difference -> ``ingrad_diff_0``, ``ingrad_diff_1``

    ``"grads"`` has no branch of its own; saliency maps come from ``"ingrad"``.

    :param downsample_factors: forwarded to ``init_network`` and ``get_sgc``;
        ``None`` means ``[(2, 2), (2, 2), (2, 2), (2, 2)]``
    :return: ``(attrs_norm, attrs_names)`` -- numpy maps, each scaled to a
        maximum absolute value of 1 (all-zero maps stay zero), and their names
    """
    if downsample_factors is None:
        downsample_factors = [(2, 2), (2, 2), (2, 2), (2, 2)]

    imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)),
            image_to_tensor(normalize_image(fake_img).astype(np.float32))]

    classes = [real_class, fake_class]
    net = init_network(checkpoint_path, input_shape, net_module, channels,
                       output_classes=output_classes, eval_net=True,
                       require_grad=False,
                       downsample_factors=downsample_factors)

    def compute_scam():
        # SCAM maps (gc_diff_0, gc_diff_1); used by both "gc" and "ggc".
        return get_sgc(real_img, fake_img, real_class, fake_class, net_module,
                       checkpoint_path, input_shape, channels, None,
                       output_classes=output_classes,
                       downsample_factors=downsample_factors)

    def last_conv2d():
        # Exact type match (not isinstance): Conv2d subclasses are skipped.
        return [module for module in net.modules()
                if type(module) is torch.nn.Conv2d][-1]

    attrs = []
    attrs_names = []

    if "residual" in methods:
        res = np.abs(real_img - fake_img)
        res = res - np.min(res)
        attrs.append(torch.tensor(_normalize_by_max_abs(res)))
        attrs_names.append("residual")

    if "random" in methods:
        rand = np.abs(np.random.randn(*np.shape(real_img)))
        # scipy.ndimage.filters is a deprecated alias namespace.
        rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4))
        rand = rand - np.min(rand)
        attrs.append(torch.tensor(_normalize_by_max_abs(rand)))
        attrs_names.append("random")

    if "gc" in methods:
        net.zero_grad()
        layer_gc = LayerGradCam(net, last_conv2d())
        gc_real = layer_gc.attribute(imgs[0], target=classes[0])
        gc_fake = layer_gc.attribute(imgs[1], target=classes[1])

        gc_real = project_layer_activations_to_input_rescale(
            gc_real.cpu().detach().numpy(), (input_shape[0], input_shape[1]))
        gc_fake = project_layer_activations_to_input_rescale(
            gc_fake.cpu().detach().numpy(), (input_shape[0], input_shape[1]))

        attrs.append(torch.tensor(gc_real[0, 0, :, :]))
        attrs_names.append("gc_real")

        attrs.append(torch.tensor(gc_fake[0, 0, :, :]))
        attrs_names.append("gc_fake")

        # SCAM
        gc_diff_0, gc_diff_1 = compute_scam()
        attrs.append(gc_diff_0)
        attrs_names.append("gc_diff_0")

        attrs.append(gc_diff_1)
        attrs_names.append("gc_diff_1")

    if "ggc" in methods:
        if "gc" not in methods:
            # The guided diffs below need the SCAM maps; previously they were
            # only defined inside the "gc" branch (NameError without it).
            gc_diff_0, gc_diff_1 = compute_scam()

        net.zero_grad()
        guided_gc = GuidedGradCam(net, last_conv2d())
        ggc_real = guided_gc.attribute(imgs[0], target=classes[0])
        ggc_fake = guided_gc.attribute(imgs[1], target=classes[1])

        attrs.append(ggc_real[0, 0, :, :])
        attrs_names.append("ggc_real")

        attrs.append(ggc_fake[0, 0, :, :])
        attrs_names.append("ggc_fake")

        net.zero_grad()
        gbp = GuidedBackprop(net)
        gbp_real = gbp.attribute(imgs[0], target=classes[0])
        gbp_fake = gbp.attribute(imgs[1], target=classes[1])

        attrs.append(gbp_real[0, 0, :, :])
        attrs_names.append("gbp_real")

        attrs.append(gbp_fake[0, 0, :, :])
        attrs_names.append("gbp_fake")

        attrs.append(gbp_real[0, 0, :, :] * gc_diff_0)
        attrs_names.append("ggc_diff_0")

        attrs.append(gbp_fake[0, 0, :, :] * gc_diff_1)
        attrs_names.append("ggc_diff_1")

    # IG
    if "ig" in methods:
        baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32))
        net.zero_grad()
        ig = IntegratedGradients(net)
        ig_real, _ = ig.attribute(imgs[0], baseline, target=classes[0],
                                  return_convergence_delta=True)
        ig_fake, _ = ig.attribute(imgs[1], baseline, target=classes[1],
                                  return_convergence_delta=True)
        ig_diff_0, _ = ig.attribute(imgs[0], imgs[1], target=classes[0],
                                    return_convergence_delta=True)
        ig_diff_1, _ = ig.attribute(imgs[1], imgs[0], target=classes[1],
                                    return_convergence_delta=True)

        attrs.append(ig_real[0, 0, :, :])
        attrs_names.append("ig_real")

        attrs.append(ig_fake[0, 0, :, :])
        attrs_names.append("ig_fake")

        attrs.append(ig_diff_0[0, 0, :, :])
        attrs_names.append("ig_diff_0")

        attrs.append(ig_diff_1[0, 0, :, :])
        attrs_names.append("ig_diff_1")

    # DL
    if "dl" in methods:
        net.zero_grad()
        dl = DeepLift(net)
        dl_real = dl.attribute(imgs[0], target=classes[0])
        dl_fake = dl.attribute(imgs[1], target=classes[1])
        dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0])
        dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1])

        attrs.append(dl_real[0, 0, :, :])
        attrs_names.append("dl_real")

        attrs.append(dl_fake[0, 0, :, :])
        attrs_names.append("dl_fake")

        attrs.append(dl_diff_0[0, 0, :, :])
        attrs_names.append("dl_diff_0")

        attrs.append(dl_diff_1[0, 0, :, :])
        attrs_names.append("dl_diff_1")

    # INGRAD
    if "ingrad" in methods:
        net.zero_grad()
        saliency = Saliency(net)
        grads_real = saliency.attribute(imgs[0], target=classes[0])
        grads_fake = saliency.attribute(imgs[1], target=classes[1])

        attrs.append(grads_real[0, 0, :, :])
        attrs_names.append("grads_real")

        attrs.append(grads_fake[0, 0, :, :])
        attrs_names.append("grads_fake")

        net.zero_grad()
        input_x_gradient = InputXGradient(net)
        ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0])
        ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1])

        ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1])
        ingrad_diff_1 = grads_real * (imgs[1] - imgs[0])

        attrs.append(torch.abs(ingrad_real[0, 0, :, :]))
        attrs_names.append("ingrad_real")

        attrs.append(torch.abs(ingrad_fake[0, 0, :, :]))
        attrs_names.append("ingrad_fake")

        attrs.append(torch.abs(ingrad_diff_0[0, 0, :, :]))
        attrs_names.append("ingrad_diff_0")

        attrs.append(torch.abs(ingrad_diff_1[0, 0, :, :]))
        attrs_names.append("ingrad_diff_1")

    attrs = [a.detach().cpu().numpy() for a in attrs]
    attrs_norm = [_normalize_by_max_abs(a) for a in attrs]

    return attrs_norm, attrs_names
Exemplo n.º 17
0
# Layer GradCAM takes the gradient of the target output with respect to a
# chosen hidden layer, averages that gradient for each output channel
# (dimension 1 of an NCHW activation), multiplies each channel's average
# gradient by the layer activations, and sums the result over all channels.
# It is meant for convnets: convolutional activations usually keep a spatial
# correspondence with the input, so GradCAM maps are commonly upsampled and
# used to mask the input image.
#
# Layer attribution objects are constructed like input attribution objects,
# except that besides the model you also pass the hidden layer you want to
# inspect. As before, ``attribute()`` receives the target class of interest.
#

layer_gradcam = LayerGradCam(model, model.layer3[1].conv2)
attributions_lgc = layer_gradcam.attribute(input_img, target=pred_label_idx)

# Move channels last, (C, H, W) -> (H, W, C), before plotting.
_ = viz.visualize_image_attr(
    attributions_lgc[0].cpu().permute(1, 2, 0).detach().numpy(),
    sign="all",
    title="Layer 3 Block 1 Conv 2",
)

##########################################################################
# To compare the layer attribution with the input image, upsample it with
# ``interpolate()``, a convenience method of the
# `LayerAttribution <https://captum.ai/api/base_classes.html?highlight=layerattribution#captum.attr.LayerAttribution>`__
# base class, to the spatial size of the input.
#

upsamp_attr_lgc = LayerAttribution.interpolate(
    attributions_lgc, input_img.shape[2:])
Exemplo n.º 18
0
    def vis_explanation(self, number):
        if len(self.explainVis) == 0:
            for i, batch in enumerate(self.test_loader):
                self.explainVis = batch
                break

        # oldIndices = self.test_loader.indices.copy()
        # self.test_loader.indices = self.test_loader.indices[:2]

        # datasetLoader = self.test_loader
        layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)

        # for i, batch in enumerate(datasetLoader):

        lb = self.explainVis[1].to(device)
        # print(len(lb))
        img = self.explainVis[0].to(device)
        # plt.subplot(2,1,1)
        # plt.imshow(img.squeeze().cpu().numpy())

        pred = self.model(img)
        predlb = torch.argmax(pred, 1)
        imgCQ = img.clone()

        # print('Prediction label is :',predlb.cpu().numpy())
        # print('Ground Truth label is: ',lb.cpu().numpy())
        ##explain to me :
        gc_attr = layer_gc.attribute(imgCQ,
                                     target=predlb,
                                     relu_attributions=False)
        upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))

        gc_attr = layer_gc.attribute(imgCQ, target=lb, relu_attributions=False)
        upsampled_attrB = LayerAttribution.interpolate(gc_attr, (64, 64))
        if not os.path.exists('./pic'):
            os.mkdir('./pic')

        ####PLot################################################
        plotMe = viz.visualize_image_attr(
            upsampled_attr[7].detach().cpu().numpy().transpose([1, 2, 0]),
            original_image=img[7].detach().cpu().numpy().transpose([1, 2, 0]),
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=0.2,
            show_colorbar=True,
            title=str(predlb[7]),
            fig_size=(8, 10),
            use_pyplot=True)

        plotMe[0].savefig('./pic/' + str(number) + 'NotEQPred.jpg')
        ################################################

        plotMe = viz.visualize_image_attr(
            upsampled_attrB[7].detach().cpu().numpy().transpose([1, 2, 0]),
            original_image=img[7].detach().cpu().numpy().transpose([1, 2, 0]),
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=0.9,
            show_colorbar=True,
            title=str(lb[7].cpu()),
            fig_size=(8, 10),
            use_pyplot=True)

        plotMe[0].savefig('./pic/' + str(number) + 'NotEQLabel.jpg')
        ################################################

        outImg = img[7].squeeze().detach().cpu().numpy()
        fig2 = plt.figure(figsize=(12, 12))
        prImg = plt.imshow(outImg)
        fig2.savefig('./pic/' + str(number) + 'NotEQOrig.jpg')
        ################################################
        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_subplot(111, projection='3d')

        z = upsampled_attr[7].squeeze().detach().cpu().numpy()
        x = np.arange(0, 64, 1)
        y = np.arange(0, 64, 1)
        X, Y = np.meshgrid(x, y)

        plll = ax.plot_surface(X, Y, z, cmap=cm.coolwarm)
        # Customize the z axis.
        # ax.set_zlim(np.min(z)+0.1*np.min(z),np.max(z)+0.1*np.max(z))
        ax.set_zlim(-0.02, 0.1)
        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

        # Add a color bar which maps values to colors.
        fig.colorbar(plll, shrink=0.5, aspect=5)
        fig.savefig('./pic/' + str(number) + 'NotEQ3D.jpg')