Ejemplo n.º 1
0
    def explanation(self, dataindex):
        """Compute GradCAM and DeepLift attributions for the first selected sample.

        ``self.unknown.indices`` is temporarily replaced by ``dataindex`` so
        that a batch-size-1 loader iterates over the requested samples. Only
        the first batch is processed; the original indices are restored
        before returning, including when an exception is raised.

        Args:
            dataindex: Value assigned to ``self.unknown.indices`` for the
                duration of the call.

        Returns:
            A tuple ``(img, gc_attr, gradcam_map, deeplift_map)`` where
            ``img`` is the input batch, ``gc_attr`` the raw LayerGradCam
            attribution, and the two maps are NumPy arrays of the GradCAM and
            DeepLift attributions interpolated to 28x28. ``None`` if the
            loader yields no batch.
        """
        oldIndices = self.unknown.indices.copy()
        self.unknown.indices = dataindex
        try:
            datasetLoader = torch.utils.data.DataLoader(dataset=self.unknown,
                                                        batch_size=1,
                                                        shuffle=False)
            self.model.eval()

            # Both explainers attribute with respect to the same layer.
            layer_gc = LayerGradCam(self.model, self.model.layer1[0].conv2)
            dl = LayerDeepLift(self.model, self.model.layer1[0].conv2)

            # NOTE(review): this figure is never drawn into here; it is kept
            # in case callers plot the returned maps into the current figure.
            plt.figure(figsize=(18, 10))

            for batch in datasetLoader:
                lb = batch[1].to(device)
                print(len(lb))
                img = batch[0].to(device)

                lbin = batch[1].cpu().numpy()
                print(lbin)
                pred = self.model(img)
                predlb = torch.argmax(pred, 1)
                print('Prediction label is :', predlb.cpu().numpy())
                print('Ground Truth label is: ', lb.cpu().numpy())

                # GradCAM explains the predicted class. The batch size is 1,
                # so ``item()`` yields the single predicted label.
                gc_attr = layer_gc.attribute(img, target=int(predlb.item()))
                upsampled_attr = LayerAttribution.interpolate(
                    gc_attr, (28, 28))

                # DeepLift explains the ground-truth class against an
                # all-zero baseline.
                base = torch.zeros([1, 1, 28, 28]).to(device)
                de_attr = dl.attribute(img, base, target=int(lbin[0]))
                dl_upsampled_attr = LayerAttribution.interpolate(
                    de_attr, (28, 28))

                print("done ...")
                # Only the first sample is explained, hence the early return.
                return (img, gc_attr,
                        upsampled_attr.squeeze().detach().cpu().numpy(),
                        dl_upsampled_attr.squeeze().detach().cpu().numpy())
        finally:
            # Put the dataset back in the state the caller handed it over in.
            self.unknown.indices = oldIndices
Ejemplo n.º 2
0
def layer_grad_cam(model, image, results_dir, file_name):
    """Save a LayerGradCam heat map for ``model.skip_4[0]`` as a NIfTI file.

    Args:
        model: Network exposing a ``skip_4`` sequence whose first element is
            the layer to explain.
        image: Input passed to ``LayerGradCam.attribute``.
        results_dir: Prefix of the output path. It is concatenated as-is, so
            it must already end with a path separator if it is a directory.
        file_name: Base name of the output file; ``"-HM.nii"`` is appended.
    """
    layer_gc = LayerGradCam(model, model.skip_4[0])
    attr = layer_gc.attribute(image)

    # Upsample the layer attribution to the fixed output volume size.
    attr = LayerAttribution.interpolate(attr, (121, 145, 121))

    # Move to host memory before converting: ``.numpy()`` raises for tensors
    # that live on an accelerator.
    heat_map = attr.squeeze().detach().cpu().numpy()
    nifti = nib.Nifti1Image(heat_map, affine=np.eye(4))

    nib.save(nifti, results_dir + file_name + "-HM.nii")
Ejemplo n.º 3
0
def run(arch, img, target):
    """Return a min-max normalised GradCAM map of an image under VGG-16.

    Args:
        arch: Unused; kept so existing callers keep working.
        img: Path or file object accepted by ``Image.open``.
        target: Target forwarded to ``LayerGradCam.attribute``.

    Returns:
        The attribution of ``model.features[28]`` interpolated to 224x224 and
        rescaled to ``[0, 1]``. A constant attribution map yields all zeros
        instead of NaNs.
    """
    image = Image.open(img).convert('RGB')
    image = apply_transform(image)
    model = models.vgg16(pretrained=True).eval()
    grad_cam = LayerGradCam(model, model.features[28])
    attr = grad_cam.attribute(image, target=target)
    attr = LayerAttribution.interpolate(attr, (224, 224))

    # Min-max normalise; skip the division when the map is constant so the
    # result is zeros rather than 0/0.
    low = attr.min()
    span = attr.max() - low
    attr = attr - low
    if span > 0:
        attr = attr / span
    return attr
Ejemplo n.º 4
0
    def validate_explanation(self):
        """Score GradCAM explanations on the test loader and print the result.

        For every batch, the LayerGradCam attribution of
        ``self.model.layer1[0].conv2`` for the predicted class is interpolated
        to 64x64 and soft-maxed over the spatial positions. The maps are
        accumulated together with matching slices of ``self.sintetic`` and
        passed to ``self.calculate_measures``; the resulting precision, recall
        and correlation are printed.
        """
        predlist = torch.zeros(0, dtype=torch.long, device='cpu')
        lbllist = torch.zeros(0, dtype=torch.long, device='cpu')
        layer_gc = LayerGradCam(self.model, self.model.layer1[0].conv2)

        for batch in self.test_loader:
            img = batch[0].to(device)

            pred = self.model(img)
            predlb = torch.argmax(pred, 1)

            gc_attr = layer_gc.attribute(img,
                                         target=predlb,
                                         relu_attributions=False)
            upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))

            # Softmax over the flattened spatial dimensions, then restore the
            # original layout.
            sz = upsampled_attr.size()
            x = upsampled_attr.view(sz[0], sz[1], -1)
            upsampled_attr_soft = F.softmax(x, dim=2).view_as(upsampled_attr)

            # Append batch prediction results.
            # NOTE(review): ``squeeze()`` also drops the batch dimension when
            # a batch holds a single sample -- confirm that the loader never
            # yields such a batch.
            predlist = torch.cat(
                [predlist,
                 upsampled_attr_soft.detach().squeeze().cpu()])
            lbllist = torch.cat([
                lbllist, self.sintetic[:upsampled_attr_soft.size(
                )[0], :, :, :].squeeze().cpu()
            ],
                                dim=0)
        # ``size`` is a method: call it so the shapes are printed rather than
        # the bound-method representations.
        print(predlist.size(), lbllist.size())
        final_prec, final_rec, final_corr = self.calculate_measures(
            lbllist, predlist)
        print('Final validation result...')
        print("- final explanation percision: {:.3f}".format(final_prec))
        print("- final explanation recal: {:.3f}".format(final_rec))
        print("- final explanation correlation: {:.3f}".format(final_corr))
Ejemplo n.º 5
0
layer_gradcam = LayerGradCam(model, model.layer3[1].conv2)
attributions_lgc = layer_gradcam.attribute(input_img, target=pred_label_idx)

# Show the raw (layer-resolution) attribution of the first sample, channels last.
_ = viz.visualize_image_attr(
    attributions_lgc[0].detach().cpu().permute(1, 2, 0).numpy(),
    sign="all",
    title="Layer 3 Block 1 Conv 2",
)

##########################################################################
# The ``interpolate()`` convenience method of the
# `LayerAttribution <https://captum.ai/api/base_classes.html?highlight=layerattribution#captum.attr.LayerAttribution>`__
# base class upsamples this attribution data to the spatial size of the
# input image so the two can be compared.
#

upsamp_attr_lgc = LayerAttribution.interpolate(
    attributions_lgc,
    input_img.shape[2:],
)

print(attributions_lgc.shape)
print(upsamp_attr_lgc.shape)
print(input_img.shape)

_ = viz.visualize_image_attr_multiple(
    upsamp_attr_lgc[0].detach().cpu().permute(1, 2, 0).numpy(),
    transformed_img.permute(1, 2, 0).numpy(),
    methods=["original_image", "blended_heat_map", "masked_image"],
    signs=["all", "positive", "positive"],
    titles=["Original", "Positive Attribution", "Masked"],
    fig_size=(18, 6),
    show_colorbar=True,
)

#######################################################################
Ejemplo n.º 6
0
def train_single_scale(D, G, reals, generators, noise_maps,
                       input_from_prev_scale, noise_amplitudes, opt):
    """ Train one scale. D and G are the current discriminator and generator, reals are the scaled versions of the
    original level, generators and noise_maps contain information from previous scales and will receive information in
    this scale, input_from_prev_scale holds the noise map and images from the previous scale, noise_amplitudes hold
    the amplitudes for the noise in all the scales. opt is a namespace that holds all necessary parameters.

    After training, GradCAM heat maps for D and G are written as PNG files into the ``gradcam`` folder (created if
    missing) and the networks are saved.

    Returns a tuple ``(z_opt, input_from_prev_scale, G, divergences)``. ``divergences`` holds the values collected
    from ``divergence`` when ``opt.cgan`` is set and is an empty list otherwise. """
    current_scale = len(generators)
    real = reals[current_scale]

    keepSky = False
    kernel_dims = (2, 2)

    # Initialize real detector
    real0 = preprocess(opt, real, keepSky)
    N, C, H, W = real0.shape

    # Always defined so that logging and the return value also work when the
    # conditional (cgan) branch is disabled.
    divergences = []

    if opt.cgan:
        detector = PCA_Detector(opt, 'real', real0, kernel_dims)
        real_detection_map = detector(real0)
        detection_scale = 0.1
        real_detection_map *= detection_scale
        real1 = torch.cat(
            [real, F.interpolate(real_detection_map, (H, W))], dim=1)
    else:
        real1 = real
    # NOTE(review): several places below slice ``[:, :-1]`` off real1/fake1,
    # which looks like it assumes the cgan branch appended exactly one
    # detection channel -- confirm the behaviour for ``opt.cgan == False``.

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    nzx = real.shape[2]  # Noise size x
    nzy = real.shape[3]  # Noise size y

    padsize = int(
        1 * opt.num_layer
    )  # As kernel size is always 3 currently, padsize goes up by one per layer

    if not opt.pad_with_noise:
        pad_noise = nn.ZeroPad2d(padsize)
        pad_image = nn.ZeroPad2d(padsize)
    else:
        pad_noise = nn.ReflectionPad2d(padsize)
        pad_image = nn.ReflectionPad2d(padsize)

    # setup optimizer
    optimizerD = optim.Adam(D.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(G.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)

    if current_scale == 0:  # Generate new noise
        z_opt = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                       device=opt.device)
        z_opt = pad_noise(z_opt)
    else:  # Add noise to previous output
        z_opt = torch.zeros([1, opt.nc_current, nzx, nzy]).to(opt.device)
        z_opt = pad_noise(z_opt)

    logger.info("Training at scale {}", current_scale)
    for epoch in tqdm(range(opt.niter)):
        step = current_scale * opt.niter + epoch
        noise_ = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                        device=opt.device)
        noise_ = pad_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            D.zero_grad()

            output = D(real1).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)

            # train with fake
            if j == 0 and epoch == 0:
                if current_scale == 0:  # If we are in the lowest scale, noise is generated from scratch
                    prev = torch.zeros(1, opt.nc_current, nzx,
                                       nzy).to(opt.device)
                    input_from_prev_scale = prev
                    prev = pad_image(prev)
                    z_prev = torch.zeros(1, opt.nc_current, nzx,
                                         nzy).to(opt.device)
                    z_prev = pad_noise(z_prev)
                    opt.noise_amp = 1
                else:  # First step in NOT the lowest scale
                    # We need to adapt our inputs from the previous scale and add noise to it
                    prev = draw_concat(generators, noise_maps, reals,
                                       noise_amplitudes, input_from_prev_scale,
                                       "rand", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        prev = group_to_token(prev, opt.token_list,
                                              token_group)

                    prev = interpolate(prev,
                                       real1.shape[-2:],
                                       mode="bilinear",
                                       align_corners=False)
                    prev = pad_image(prev)
                    z_prev = draw_concat(generators, noise_maps, reals,
                                         noise_amplitudes,
                                         input_from_prev_scale, "rec",
                                         pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        z_prev = group_to_token(z_prev, opt.token_list,
                                                token_group)

                    z_prev = interpolate(z_prev,
                                         real1.shape[-2:],
                                         mode="bilinear",
                                         align_corners=False)
                    opt.noise_amp = update_noise_amplitude(
                        z_prev, real1[:, :-1], opt)
                    z_prev = pad_image(z_prev)
            else:  # Any other step
                prev = draw_concat(generators, noise_maps, reals,
                                   noise_amplitudes, input_from_prev_scale,
                                   "rand", pad_noise, pad_image, opt)

                # For the seeding experiment, we need to transform from token_groups to the actual token
                if current_scale == (opt.token_insert + 1):
                    prev = group_to_token(prev, opt.token_list, token_group)

                prev = interpolate(prev,
                                   real1.shape[-2:],
                                   mode="bilinear",
                                   align_corners=False)
                prev = pad_image(prev)

            # After creating our correct noise input, we feed it to the generator:
            noise = opt.noise_amp * noise_ + prev
            fake = G(noise.detach(),
                     prev,
                     temperature=1 if current_scale != opt.token_insert else 1)

            fake0 = preprocess(opt, fake, keepSky)
            if opt.cgan:
                Nf, Cf, Hf, Wf = fake0.shape
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat(
                    [fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            # Then run the result through the discriminator
            output = D(fake1.detach())
            errD_fake = output.mean()

            # Backpropagation
            errD_fake.backward(retain_graph=False)

            # Gradient Penalty
            gradient_penalty = calc_gradient_penalty(D, real1, fake1,
                                                     opt.lambda_grad,
                                                     opt.device)
            gradient_penalty.backward(retain_graph=False)

            # Logging:
            if step % 10 == 0:
                wandb.log(
                    {
                        f"D(G(z))@{current_scale}":
                        errD_fake.item(),
                        f"D(x)@{current_scale}":
                        -errD_real.item(),
                        f"gradient_penalty@{current_scale}":
                        gradient_penalty.item()
                    },
                    step=step,
                    sync=False)
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            G.zero_grad()
            fake = G(noise.detach(),
                     prev.detach(),
                     temperature=1 if current_scale != opt.token_insert else 1)

            fake0 = preprocess(opt, fake, keepSky)
            Nf, Cf, Hf, Wf = fake0.shape

            if opt.cgan:
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat(
                    [fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            output = D(fake1)

            errG = -output.mean()
            errG.backward(retain_graph=False)
            if opt.alpha != 0:  # i. e. we are trying to find an exact recreation of our input in the lat space
                Z_opt = opt.noise_amp * z_opt + z_prev
                G_rec = G(
                    Z_opt.detach(),
                    z_prev,
                    temperature=1 if current_scale != opt.token_insert else 1)
                rec_loss = opt.alpha * F.mse_loss(G_rec, real)
                if opt.cgan:
                    div = divergence(real_detection_map,
                                     preprocess(opt, G_rec, keepSky))
                    rec_loss += div
                rec_loss.backward(
                    retain_graph=False
                )  # TODO: Check for unexpected argument retain_graph=True
                rec_loss = rec_loss.detach()
            else:  # We are not trying to find an exact recreation
                rec_loss = torch.zeros([])
                Z_opt = z_opt

            optimizerG.step()

        # More Logging:
        # The detection map only exists in cgan mode, so the divergence can
        # only be tracked there.
        if opt.cgan:
            div = divergence(real_detection_map,
                             preprocess(opt, fake, keepSky))
            divergences.append(div)
        # logger.info("divergence(fake) = {}", div)
        if step % 10 == 0:
            wandb.log(
                {
                    f"noise_amplitude@{current_scale}": opt.noise_amp,
                    f"rec_loss@{current_scale}": rec_loss.item()
                },
                step=step,
                sync=False,
                commit=True)

        # Rendering and logging images of levels
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if opt.token_insert >= 0 and opt.nc_current == len(token_group):
                token_list = [list(group.keys())[0] for group in token_group]
            else:
                token_list = opt.token_list

            img = opt.ImgGen.render(
                one_hot_to_ascii_level(fake1[:, :-1].detach(), token_list))
            img2 = opt.ImgGen.render(
                one_hot_to_ascii_level(
                    G(Z_opt.detach(),
                      z_prev,
                      temperature=1 if current_scale != opt.token_insert else
                      1).detach(), token_list))
            real_scaled = one_hot_to_ascii_level(real1[:, :-1].detach(),
                                                 token_list)
            img3 = opt.ImgGen.render(real_scaled)
            wandb.log(
                {
                    f"G(z)@{current_scale}": wandb.Image(img),
                    f"G(z_opt)@{current_scale}": wandb.Image(img2),
                    f"real@{current_scale}": wandb.Image(img3)
                },
                sync=False,
                commit=False)

            real_scaled_path = os.path.join(wandb.run.dir,
                                            f"real@{current_scale}.txt")
            with open(real_scaled_path, "w") as f:
                f.writelines(real_scaled)
            wandb.save(real_scaled_path)

        # Learning Rate scheduler step
        schedulerD.step()
        schedulerG.step()

    if opt.cgan:
        div = divergence(real_detection_map, preprocess(opt, z_opt, keepSky))
        divergences.append(div)

    # visualization config
    folder_name = 'gradcam'
    level_name = opt.input_name.rsplit(".", 1)[0].split("_", 1)[1]
    # Make sure the output folder exists before any figure is saved into it.
    os.makedirs(folder_name, exist_ok=True)

    # GradCAM on D
    camD = LayerGradCam(D, D.tail)
    real0 = one_hot_to_ascii_level(real, opt.token_list)
    real0 = opt.ImgGen.render(real0)
    real0 = np.array(real0)
    attr = camD.attribute(real1, target=(0, 0, 0), relu_attributions=True)
    attr = LayerAttribution.interpolate(attr, (real0.shape[0], real0.shape[1]),
                                        'bilinear')
    attr = attr.permute(2, 3, 1, 0).squeeze(3)
    attr = attr.detach().cpu().numpy()
    fig, ax = plt.subplots(1, 1)
    # NOTE(review): assigning ``fig.figsize`` only sets a plain attribute and
    # does not resize the figure; kept as-is so the rendered output does not
    # change. Use ``fig.set_size_inches`` if resizing was intended.
    fig.figsize = (10, 1)
    ax.imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
    im = ax.imshow(attr, cmap='jet', alpha=0.5)
    ax.axis('off')
    fig.colorbar(im, ax=ax, location='bottom', shrink=0.85)
    plt.suptitle(f'cGAN {level_name} D(x)@{current_scale} ({step})')
    # Build the path with os.path.join so it is valid on every platform
    # instead of hard-coding a backslash separator.
    plt.savefig(os.path.join(
        folder_name, f'{level_name}_D_{current_scale}_{step}.png'),
                bbox_inches='tight',
                pad_inches=0.1)
    # plt.show()
    plt.close()

    # GradCAM on G
    # Human-readable label for every level token, used as row captions.
    token_names = {
        'M': 'Mario start',
        'F': 'Mario finish',
        'y': 'spiky',
        'Y': 'winged spiky',
        'k': 'green koopa',
        'K': 'winged green koopa',
        '!': 'coin [?]',
        '#': 'pyramid',
        '-': 'sky',
        '1': 'invis. 1 up',
        '2': 'invis. coin',
        'L': '1 up',
        '?': 'special [?]',
        '@': 'special [?]',
        'Q': 'coin [?]',
        'C': 'coin brick',
        'S': 'normal brick',
        'U': 'mushroom brick',
        'X': 'ground',
        'E': 'goomba',
        'g': 'goomba',
        '%': 'platform',
        '|': 'platform bg',
        'r': 'red koopa',
        'R': 'winged red koopa',
        'o': 'coin',
        't': 'pipe',
        'T': 'plant pipe',
        '*': 'bullet bill',
        '<': 'pipe top left',
        '>': 'pipe top right',
        '[': 'pipe left',
        ']': 'pipe right',
        'B': 'bullet bill head',
        'b': 'bullet bill body',
        'D': 'used block',
    }

    def wrappedG(z):
        # LayerGradCam attributes over a single input, so the second argument
        # of G is fixed to z_opt.
        return G(z, z_opt)

    camG = LayerGradCam(wrappedG, G.tail[0])
    z_cam = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                   device=opt.device)
    z_cam = pad_noise(z_cam)
    attrs = []
    for i in range(opt.nc_current):
        attr = camG.attribute(z_cam, target=(i, 0, 0), relu_attributions=True)
        attr = LayerAttribution.interpolate(attr,
                                            (real0.shape[0], real0.shape[1]),
                                            'bilinear')
        attr = attr.permute(2, 3, 1, 0).squeeze(3)
        attr = attr.detach().cpu().numpy()
        attrs.append(attr)
    fig, axs = plt.subplots(opt.nc_current, 1)
    # plt.subplots returns a bare Axes for a single row; normalise to an
    # array so the indexing below also works when nc_current == 1.
    axs = np.atleast_1d(axs)
    # NOTE(review): same no-op attribute assignment as for the D figure.
    fig.figsize = (10, opt.nc_current)
    for i in range(opt.nc_current):
        axs[i].axis('off')
        axs[i].text(-0.1,
                    0.5,
                    token_names[opt.token_list[i]],
                    rotation=0,
                    verticalalignment='center',
                    horizontalalignment='right',
                    transform=axs[i].transAxes)
        axs[i].imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
        im = axs[i].imshow(attrs[i], cmap='jet', alpha=0.5)
    fig.colorbar(im, ax=axs, shrink=0.85)
    plt.suptitle(f'cGAN {level_name} G(z)@{current_scale} ({step})')
    plt.savefig(os.path.join(
        folder_name, f'{level_name}_G_{current_scale}_{step}.png'),
                bbox_inches='tight',
                pad_inches=0.1)
    # plt.show()
    plt.close()

    # Save networks
    torch.save(z_opt, "%s/z_opt.pth" % opt.outf)
    save_networks(G, D, z_opt, opt)
    wandb.save(opt.outf)
    return z_opt, input_from_prev_scale, G, divergences
Ejemplo n.º 7
0
    def get_insights(self, tensor_data, _, target=0):
        """Render Integrated Gradients, Occlusion and GradCAM visualisations.

        Args:
            tensor_data: Input image batch; the first sample is visualised.
            _: Ignored.
            target: Unused by this implementation.

        Returns:
            A list with one entry per method, in the order Integrated
            Gradients, Occlusion, LayerGradCam. Entries produced by
            ``self.output_bytes`` that are ``bytes``/``bytearray`` are wrapped
            as ``{"b64": <base64 string>}``; anything else is passed through
            unchanged.
        """
        default_cmap = LinearSegmentedColormap.from_list(
            "custom blue",
            [(0, "#ffffff"), (0.25, "#0000ff"), (1, "#0000ff")],
            N=256,
        )

        attributions_ig, _delta = self.attribute_image_features(
            self.ig,
            tensor_data,
            baselines=tensor_data * 0,
            return_convergence_delta=True,
            n_steps=15,
        )

        attributions_occ = self.attribute_image_features(
            self.occlusion,
            tensor_data,
            strides=(3, 8, 8),
            sliding_window_shapes=(3, 15, 15),
            baselines=tensor_data * 0,
        )

        attributions_lgc = self.attribute_image_features(
            self.layer_gradcam, tensor_data)

        # The layer attribution has the layer's spatial size; bring it up to
        # the input resolution before overlaying it on the image.
        upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc,
                                                       tensor_data.shape[2:])

        # Channels-last copy of the input, shared by all three figures.
        # Detached first so the conversion also works for tensors that
        # require gradients.
        original_image = np.transpose(
            tensor_data.squeeze().cpu().detach().numpy(), (1, 2, 0))

        matplot_viz_ig, _axes = viz.visualize_image_attr_multiple(
            np.transpose(attributions_ig.squeeze().cpu().detach().numpy(),
                         (1, 2, 0)),
            original_image,
            use_pyplot=False,
            methods=["original_image", "heat_map"],
            cmap=default_cmap,
            show_colorbar=True,
            signs=["all", "positive"],
            titles=["Original", "Integrated Gradients"],
        )

        matplot_viz_occ, _axes = viz.visualize_image_attr_multiple(
            np.transpose(attributions_occ.squeeze().cpu().detach().numpy(),
                         (1, 2, 0)),
            original_image,
            [
                "original_image",
                "heat_map",
                "heat_map",
            ],
            ["all", "positive", "negative"],
            show_colorbar=True,
            titles=[
                "Original",
                "Positive Attribution",
                "Negative Attribution",
            ],
            fig_size=(18, 6),
            use_pyplot=False,
        )

        matplot_viz_lgc, _axes = viz.visualize_image_attr_multiple(
            upsamp_attr_lgc[0].cpu().permute(1, 2, 0).detach().numpy(),
            original_image,
            use_pyplot=False,
            methods=["original_image", "blended_heat_map", "blended_heat_map"],
            signs=["all", "positive", "negative"],
            show_colorbar=True,
            titles=[
                "Original",
                "Positive Attribution",
                "Negative Attribution",
            ],
            fig_size=(18, 6))

        occ_bytes = self.output_bytes(matplot_viz_occ)
        ig_bytes = self.output_bytes(matplot_viz_ig)
        lgc_bytes = self.output_bytes(matplot_viz_lgc)

        # Binary payloads are base64-encoded; other values are returned as-is.
        output = [{
            "b64": b64encode(row).decode("utf8")
        } if isinstance(row, (bytes, bytearray)) else row
                  for row in [ig_bytes, occ_bytes, lgc_bytes]]
        return output
Ejemplo n.º 8
0
    def vis_explanation(self, number):
        """Save GradCAM visualisations for one cached test sample to ``./pic``.

        The first batch of ``self.test_loader`` is cached in
        ``self.explainVis`` on first use. For one sample of that batch the
        method writes four JPEG files prefixed with ``number``: the GradCAM
        heat map for the predicted class, the heat map for the ground-truth
        class, the input image, and a 3D surface plot of the predicted-class
        attribution.

        Args:
            number: Value used as the file-name prefix of the saved figures.
        """
        # Index of the visualised sample inside the cached batch; the batch
        # must therefore hold at least eight samples.
        sample_idx = 7

        if len(self.explainVis) == 0:
            for batch in self.test_loader:
                self.explainVis = batch
                break

        layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)

        lb = self.explainVis[1].to(device)
        img = self.explainVis[0].to(device)

        pred = self.model(img)
        predlb = torch.argmax(pred, 1)
        imgCQ = img.clone()

        # Attribution with respect to the predicted class ...
        gc_attr = layer_gc.attribute(imgCQ,
                                     target=predlb,
                                     relu_attributions=False)
        upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))

        # ... and with respect to the ground-truth class.
        gc_attr = layer_gc.attribute(imgCQ, target=lb, relu_attributions=False)
        upsampled_attrB = LayerAttribution.interpolate(gc_attr, (64, 64))

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs('./pic', exist_ok=True)

        #### Plot: heat map for the predicted class ####################
        plotMe = viz.visualize_image_attr(
            upsampled_attr[sample_idx].detach().cpu().numpy().transpose(
                [1, 2, 0]),
            original_image=img[sample_idx].detach().cpu().numpy().transpose(
                [1, 2, 0]),
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=0.2,
            show_colorbar=True,
            title=str(predlb[sample_idx]),
            fig_size=(8, 10),
            use_pyplot=True)

        plotMe[0].savefig('./pic/' + str(number) + 'NotEQPred.jpg')

        #### Plot: heat map for the ground-truth class #################
        plotMe = viz.visualize_image_attr(
            upsampled_attrB[sample_idx].detach().cpu().numpy().transpose(
                [1, 2, 0]),
            original_image=img[sample_idx].detach().cpu().numpy().transpose(
                [1, 2, 0]),
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=0.9,
            show_colorbar=True,
            title=str(lb[sample_idx].cpu()),
            fig_size=(8, 10),
            use_pyplot=True)

        plotMe[0].savefig('./pic/' + str(number) + 'NotEQLabel.jpg')

        #### Plot: the input image itself ##############################
        outImg = img[sample_idx].squeeze().detach().cpu().numpy()
        fig2 = plt.figure(figsize=(12, 12))
        plt.imshow(outImg)
        fig2.savefig('./pic/' + str(number) + 'NotEQOrig.jpg')

        #### Plot: 3D surface of the predicted-class attribution #######
        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_subplot(111, projection='3d')

        z = upsampled_attr[sample_idx].squeeze().detach().cpu().numpy()
        x = np.arange(0, 64, 1)
        y = np.arange(0, 64, 1)
        X, Y = np.meshgrid(x, y)

        plll = ax.plot_surface(X, Y, z, cmap=cm.coolwarm)
        # Customize the z axis (fixed range rather than one derived from z).
        ax.set_zlim(-0.02, 0.1)
        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

        # Add a color bar which maps values to colors.
        fig.colorbar(plll, shrink=0.5, aspect=5)
        fig.savefig('./pic/' + str(number) + 'NotEQ3D.jpg')
        # NOTE(review): the figures created here are left open, as before;
        # repeated calls accumulate open matplotlib figures.
Ejemplo n.º 9
0
    def train_known_expl(self, n_epochs, ite, max_ite):
        self.model.train()
        avg_loss = []
        # if os.path.exists('./checkpoint'):
        #   try:
        #     print('model found ...')
        #     self.model.load_state_dict(torch.load('./checkpoint/resnet18.pth'))
        #     print('Model loaded sucessfully.')
        #     continue
        #   except:
        #     print('Not found any model ... ')

        optim = torch.optim.Adam(self.model.parameters(),
                                 lr=0.001,
                                 weight_decay=0)
        # lossOH = OhemCELoss(0.2,50) #cause batch is 100

        criteria1 = nn.CrossEntropyLoss()
        criteria2 = nn.BCELoss()
        criteria21 = nn.BCEWithLogitsLoss()
        criteriaL2 = nn.MSELoss()

        layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)
        # layer_DLS = LayerDeepLiftShap(self.model, self.model.layer1[0].conv2, multiply_by_inputs=False)

        looss = nn.CrossEntropyLoss()
        print('Init known training with explanation...')
        number = 0
        # wandb.watch(self.model, log='all')
        for epoch in range(n_epochs):
            for i, batch in enumerate(self.KnownLoader):
                lb = batch[1].to(device)
                # print(batch[0].size())
                # # img = batch[0].to(device)
                # img = F.interpolate(batch[0],(100,1,224,224))
                # print(img.size())
                img = batch[0].to(device)
                # print(img.size())

                #define mask
                maskLb = batch[0].clone()
                maskLb = maskLb.squeeze()
                maskLb[maskLb == -0.5] = 0
                maskLb[maskLb != 0] = 1
                maskLb = maskLb.to(device)

                # Training
                optim.zero_grad()
                # a,b,c,out = self.model(img)
                out = self.model(img)
                predlb = torch.argmax(out, 1)
                # predlb = predlb.cpu().numpy()

                # print('Prediction label is :',predlb.cpu().numpy())
                # print('Ground Truth label is: ',lb.cpu().numpy())

                ##explain to me :
                gc_attr = layer_gc.attribute(img,
                                             target=predlb,
                                             relu_attributions=True)
                upsampled_attr = LayerAttribution.interpolate(
                    gc_attr, (64, 64))

                gc_attr = layer_gc.attribute(img,
                                             target=lb,
                                             relu_attributions=True)
                upsampled_attrB = LayerAttribution.interpolate(
                    gc_attr, (64, 64))
                exitpic = upsampled_attrB.clone()
                exitpic = exitpic.detach().cpu().numpy()
                # wandb.log({"Examples": exitpic})
                print(self.model.layer2[1].conv2.grads.data)

                # upsampled_attr_B = upsampled_attr.clone()
                # sz = upsampled_attr.size()
                # x = upsampled_attr_B.view(sz[0],sz[1], -1)
                # # print(x.size())
                # upsampled_attr_B = F.softmax(x,dim=2).view_as(upsampled_attr_B)

                ########################################
                # grid = torchvision.utils.make_grid(img)
                # self.writer.add_image('images', grid, 0)
                # self.writer.add_graph(self.model, img)

                # baseLine = torch.zeros(img.size())
                # # baseLine = baseLine[:1]
                # # print(baseLine.size())
                # baseLine = baseLine.to(device)

                # DLS_attr,delta = layer_DLS.attribute(img,baseLine,target=predlb,return_convergence_delta =True)
                # upsampled_attrDLS = LayerAttribution.interpolate(DLS_attr, (64, 64))
                # upsampled_attrDLSSum = torch.sum(upsampled_attrDLS,dim=(1),keepdim=True)
                # print(upsampled_attrDLSSum.size())
                # print(delta.size(),DLS_attr.size())

                # if number % 60 ==0:

                #   z = torch.eq(lb,predlb)
                #   z = ~z
                #   z = z.nonzero()
                #   try:
                #     z = z.cpu().numpy()[-1]
                #   except:
                #     z = [0]
                #   # if z.size().cpu()>0:
                #   print(lb[z[0]],predlb[z[0]],z[0])
                #   ################################################
                #   plotMe = viz.visualize_image_attr(upsampled_attr[z[0]].detach().cpu().numpy().transpose([1,2,0]),
                #                       original_image=img[z[0]].detach().cpu().numpy().transpose([1,2,0]),
                #                       method='heat_map',
                #                       sign='absolute_value', plt_fig_axis=None, outlier_perc=2,
                #                       cmap='inferno', alpha_overlay=0.2, show_colorbar=True,
                #                       title=str(predlb[z[0]]),
                #                       fig_size=(8, 10), use_pyplot=True)

                #   plotMe[0].savefig(str(number)+'NotEQPred.jpg')
                #   ################################################

                #   plotMe = viz.visualize_image_attr(upsampled_attrB[z[0]].detach().cpu().numpy().transpose([1,2,0]),
                #                       original_image=img[z[0]].detach().cpu().numpy().transpose([1,2,0]),
                #                       method='heat_map',
                #                       sign='absolute_value', plt_fig_axis=None, outlier_perc=2,
                #                       cmap='inferno', alpha_overlay=0.9, show_colorbar=True,
                #                       title=str(lb[z[0]].cpu()),
                #                       fig_size=(8, 10), use_pyplot=True)

                #   plotMe[0].savefig(str(number)+'NotEQLabel.jpg')
                #   ################################################

                #   outImg = img[z[0]].squeeze().detach().cpu().numpy()
                #   fig2 = plt.figure(figsize=(12,12))
                #   prImg = plt.imshow(outImg)
                #   fig2.savefig(str(number)+'NotEQOrig.jpg')
                #   ################################################
                #   fig = plt.figure(figsize=(15,10))
                #   ax = fig.add_subplot(111, projection='3d')

                #   z = upsampled_attr[z[0]].squeeze().detach().cpu().numpy()
                #   x = np.arange(0,64,1)
                #   y = np.arange(0,64,1)
                #   X, Y = np.meshgrid(x, y)

                #   plll = ax.plot_surface(X, Y , z, cmap=cm.coolwarm)
                #   # Customize the z axis.
                #   ax.set_zlim(np.min(z)+0.1*np.min(z),np.max(z)+0.1*np.max(z))
                #   ax.zaxis.set_major_locator(LinearLocator(10))
                #   ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

                #   # Add a color bar which maps values to colors.
                #   fig.colorbar(plll, shrink=0.5, aspect=5)
                #   fig.savefig(str(number)+'NotEQ3D.jpg')
                #######explainVis####################
                # if number%30 == 0:
                #   self.vis_explanation(number)
                # self.vis_explanation(number)
                #####################################
                # reluMe = nn.ReLU(upsampled_attr)
                # upsampled_attr = reluMe(upsampled_attr)
                size_batch = upsampled_attr.size()[0]
                # upsampled_attr = F.relu(upsampled_attr, inplace=False)
                # print(upsampled_attr.size(),self.sintetic.size())
                losssemantic = criteria2(
                    upsampled_attr[:size_batch, :, :, :],
                    self.sintetic[:size_batch, :, :, :].to(device))

                loss1 = criteria1(out, lb)
                # lossall = 0.7*loss1 + 0.3*losssemantic
                lossall = losssemantic.to(device)
                optim.zero_grad()

                # loss2 = criteria21(upsampled_attr.squeeze()*64,maskLb)

                # loss3 = criteriaL2(maskLb*upsampled_attr.squeeze(),maskLb*upsampled_attrB.squeeze())
                # loss3 = criteriaL2(img.squeeze()*upsampled_attr.squeeze()*64,img.squeeze()*upsampled_attrB.squeeze()*64)
                # loss3 = criteriaL2(img.squeeze()*upsampled_attr.squeeze(),img.squeeze()*upsampled_attrB.squeeze())

                if number % 30 == 0:
                    # print()
                    print(loss1, losssemantic)
                # loss3 = torch.log(-loss3)
                # lossall = 0.7*loss1 + 0.3*loss2
                # lossall = 0*loss1 + 0.3*loss3
                # lossall = loss3
                # print('Losss to cjeck is:--- ',torch.max(loss3))
                avg_loss = torch.mean(lossall)
                lossall.backward()
                optim.step()
                # print(avg_loss)
                number += 1
                # if number == 2:
                # gradds = []
                # for tag, parm in self.model.state_dict().items():
                #   if parm.requires_grad:
                #     # print(p.name, p.data)
                #     gradds.append(parm.grad.data.cpu().numpy())
                # self.writer.add_histogram('grads', torch.from_numpy(gradds), number)

                # self.writer.his
                # self.plot_grad_flow(self.model.state_dict().items(),number)

                if number % 10 == 0:
                    # print(number)
                    print(
                        "Epoch: {}/{} batch: {}/{} iteration: {}/{} average-loss: {:0.4f}"
                        .format(epoch + 1, n_epochs, i + 1,
                                len(self.KnownLoader), ite + 1, max_ite,
                                avg_loss.cpu()))
        # self.writer.close()
        # Save checkpoint
        if (os.path.exists("./checkpoint")):
            torch.save(self.model.state_dict(), "./checkpoint/resnet18.pth")
        else:
            os.mkdir('checkpoint')
            torch.save(self.model.state_dict(), "./checkpoint/resnet18.pth")
        number = 0