def explanation(self, dataindex):
    """Compute GradCAM and DeepLift attributions for selected unknown samples.

    ``self.unknown.indices`` is temporarily replaced by ``dataindex`` so the
    loader only visits the requested samples. The previous indices are put
    back before returning, even when attribution raises.

    Args:
        dataindex: Indices assigned to ``self.unknown.indices`` for the
            duration of the call.

    Returns:
        Tuple ``(img, gc_attr, gradcam_map, deeplift_map)`` for the last
        sample visited: the input batch, the raw GradCAM attribution, and the
        GradCAM / DeepLift attributions upsampled to the input's spatial size
        as NumPy arrays.
    """
    old_indices = self.unknown.indices.copy()
    self.unknown.indices = dataindex
    try:
        loader = torch.utils.data.DataLoader(dataset=self.unknown,
                                             batch_size=1,
                                             shuffle=False)
        self.model.eval()

        target_layer = self.model.layer1[0].conv2
        layer_gc = LayerGradCam(self.model, target_layer)
        dl = LayerDeepLift(self.model, target_layer)

        # Kept from the original: a figure is opened for callers that plot
        # the returned maps right after this call.
        plt.figure(figsize=(18, 10))

        for batch in loader:
            lb = batch[1].to(device)
            print(len(lb))
            img = batch[0].to(device)
            lbin = batch[1].cpu().numpy()
            print(lbin)

            pred = self.model(img)
            predlb = torch.argmax(pred, 1)
            print('Prediction label is :', predlb.cpu().numpy())
            print('Ground Truth label is: ', lb.cpu().numpy())

            out_size = tuple(img.shape[-2:])

            # GradCAM explains the predicted class (batch size is 1, so
            # ``item()`` is well defined).
            gc_attr = layer_gc.attribute(img, target=int(predlb.item()))
            upsampled_attr = LayerAttribution.interpolate(gc_attr, out_size)

            # DeepLift explains the ground-truth class against an all-zero
            # baseline shaped like the input.
            base = torch.zeros_like(img)
            de_attr = dl.attribute(img, base, target=int(lbin[0]))
            dl_upsampled_attr = LayerAttribution.interpolate(de_attr, out_size)
            print("done ...")
    finally:
        # The original comment asked for this but never did it.
        self.unknown.indices = old_indices

    # NOTE(review): the flattened source does not show whether this return
    # sat inside the loop; both readings agree when ``dataindex`` selects a
    # single sample.
    return (img, gc_attr,
            upsampled_attr.squeeze().detach().cpu().numpy(),
            dl_upsampled_attr.squeeze().detach().cpu().numpy())
def layer_grad_cam(model, image, results_dir, file_name):
    """Save a GradCAM heat map of ``model.skip_4[0]`` as a NIfTI volume.

    The attribution is upsampled to a fixed ``(121, 145, 121)`` volume and
    written to ``results_dir + file_name + "-HM.nii"`` with an identity
    affine.

    Args:
        model: Network exposing a ``skip_4`` sequence whose first element is
            the layer to explain.
        image: Input passed unchanged to ``LayerGradCam.attribute``.
        results_dir: Output directory prefix. It is concatenated as-is, so it
            must end with a path separator.
        file_name: Base name of the output file, without extension.
    """
    layer_gc = LayerGradCam(model, model.skip_4[0])
    attr = layer_gc.attribute(image)
    attr = LayerAttribution.interpolate(attr, (121, 145, 121))

    # ``.cpu()`` is required before ``.numpy()`` when the model runs on GPU;
    # it is a no-op for CPU tensors.
    heat_map = attr.squeeze().detach().cpu().numpy()

    volume = nib.Nifti1Image(heat_map, affine=np.eye(4))
    nib.save(volume, results_dir + file_name + "-HM.nii")
def run(arch, img, target):
    """Return a min-max normalised GradCAM map of VGG16 for one image.

    Args:
        arch: Unused; kept so existing callers keep working.
        img: Path (or file object) accepted by ``Image.open``.
        target: Class index to explain. When ``None``, the class predicted by
            the model is explained instead.

    Returns:
        The GradCAM attribution of ``model.features[28]`` upsampled to
        224x224 and scaled to ``[0, 1]``. A constant attribution map yields
        all zeros rather than NaN.
    """
    # ``convert`` returns an independent image, so the file can be closed.
    with Image.open(img) as handle:
        image = handle.convert('RGB')
    batch = apply_transform(image)

    model = models.vgg16(pretrained=True).eval()
    grad_cam = LayerGradCam(model, model.features[28])

    if target is None:
        # Softmax is monotonic, so the argmax of the logits is the predicted
        # class.
        target = model(batch).argmax(dim=1).item()

    attr = grad_cam.attribute(batch, target=target)
    attr = LayerAttribution.interpolate(attr, (224, 224))

    # Min-max normalisation, guarded against a zero value range.
    attr = attr - attr.min()
    peak = attr.max()
    if peak > 0:
        attr = attr / peak
    return attr
def validate_explanation(self):
    """Score GradCAM explanations on the test set against ``self.sintetic``.

    For every test batch the GradCAM attribution of the predicted class is
    upsampled to 64x64 and turned into a spatial softmax per channel. The
    collected maps and the matching leading slice of ``self.sintetic`` are
    passed to ``self.calculate_measures`` and the three returned measures
    are printed.
    """
    attr_chunks = []
    label_chunks = []
    layer_gc = LayerGradCam(self.model, self.model.layer1[0].conv2)

    for batch in self.test_loader:
        img = batch[0].to(device)
        pred = self.model(img)
        predlb = torch.argmax(pred, 1)

        gc_attr = layer_gc.attribute(img,
                                     target=predlb,
                                     relu_attributions=False)
        upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))

        # Softmax over the flattened spatial positions of each channel.
        flat = upsampled_attr.view(upsampled_attr.size(0),
                                   upsampled_attr.size(1), -1)
        upsampled_attr_soft = F.softmax(flat, dim=2).view_as(upsampled_attr)

        # Squeeze only the channel axis: a bare ``squeeze()`` would also drop
        # the batch axis for a final batch of one and break concatenation.
        batch_len = upsampled_attr_soft.size(0)
        attr_chunks.append(upsampled_attr_soft.detach().squeeze(1).cpu())
        label_chunks.append(self.sintetic[:batch_len].squeeze(1).cpu())

    # Concatenate once instead of re-copying the running tensor per batch.
    # The empty fallback matches what the original produced for an empty
    # loader.
    if attr_chunks:
        predlist = torch.cat(attr_chunks, dim=0)
        lbllist = torch.cat(label_chunks, dim=0)
    else:
        predlist = torch.zeros(0, dtype=torch.long, device='cpu')
        lbllist = torch.zeros(0, dtype=torch.long, device='cpu')

    # NOTE(review): the flattened source does not show whether this print sat
    # inside the loop; it is emitted once here.
    print(predlist.size(), lbllist.size())

    final_prec, final_rec, final_corr = self.calculate_measures(
        lbllist, predlist)

    print('Final validation result...')
    print("- final explanation percision: {:.3f}".format(final_prec))
    print("- final explanation recal: {:.3f}".format(final_rec))
    print("- final explanation correlation: {:.3f}".format(final_corr))
# Attribute the prediction ``pred_label_idx`` to the activations of the
# second conv of ``model.layer3[1]`` using Layer GradCAM.
layer_gradcam = LayerGradCam(model, model.layer3[1].conv2)
attributions_lgc = layer_gradcam.attribute(input_img, target=pred_label_idx)

# Show the raw layer attribution of the first sample; it is permuted so the
# channel axis comes last, as the visualizer is given an image-like array.
_ = viz.visualize_image_attr(attributions_lgc[0].cpu().permute(
    1, 2, 0).detach().numpy(),
                             sign="all",
                             title="Layer 3 Block 1 Conv 2")

##########################################################################
# We'll use the convenience method ``interpolate()`` in the
# `LayerAttribution <https://captum.ai/api/base_classes.html?highlight=layerattribution#captum.attr.LayerAttribution>`__
# base class to upsample this attribution data for comparison to the input
# image.
#

# Upsample the attribution to the spatial size of the input (its dimensions
# from index 2 onward).
upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc,
                                               input_img.shape[2:])

# Print the shapes before and after upsampling next to the input shape.
print(attributions_lgc.shape)
print(upsamp_attr_lgc.shape)
print(input_img.shape)

# Three panels: the original image, the blended heat map of positive
# attribution, and the image masked by positive attribution.
_ = viz.visualize_image_attr_multiple(
    upsamp_attr_lgc[0].cpu().permute(1, 2, 0).detach().numpy(),
    transformed_img.permute(1, 2, 0).numpy(),
    ["original_image", "blended_heat_map", "masked_image"],
    ["all", "positive", "positive"],
    show_colorbar=True,
    titles=["Original", "Positive Attribution", "Masked"],
    fig_size=(18, 6))

#######################################################################
# Human-readable names of level tokens, used to label generator GradCAM rows.
_GRADCAM_TOKEN_NAMES = {
    'M': 'Mario start',
    'F': 'Mario finish',
    'y': 'spiky',
    'Y': 'winged spiky',
    'k': 'green koopa',
    'K': 'winged green koopa',
    '!': 'coin [?]',
    '#': 'pyramid',
    '-': 'sky',
    '1': 'invis. 1 up',
    '2': 'invis. coin',
    'L': '1 up',
    '?': 'special [?]',
    '@': 'special [?]',
    'Q': 'coin [?]',
    'C': 'coin brick',
    'S': 'normal brick',
    'U': 'mushroom brick',
    'X': 'ground',
    'E': 'goomba',
    'g': 'goomba',
    '%': 'platform',
    '|': 'platform bg',
    'r': 'red koopa',
    'R': 'winged red koopa',
    'o': 'coin',
    't': 'pipe',
    'T': 'plant pipe',
    '*': 'bullet bill',
    '<': 'pipe top left',
    '>': 'pipe top right',
    '[': 'pipe left',
    ']': 'pipe right',
    'B': 'bullet bill head',
    'b': 'bullet bill body',
    'D': 'used block',
}


def _gradcam_heatmap(cam, inputs, target, size):
    """Return the GradCAM map of ``target`` upsampled to ``size`` as an array."""
    attr = cam.attribute(inputs, target=target, relu_attributions=True)
    attr = LayerAttribution.interpolate(attr, size, 'bilinear')
    attr = attr.permute(2, 3, 1, 0).squeeze(3)
    return attr.detach().cpu().numpy()


def _save_gradcam_figures(D, G, real, real1, z_opt, nzx, nzy, pad_noise,
                          current_scale, step, opt):
    """Render GradCAM overlays for D and G and save them under ``gradcam/``."""
    folder_name = 'gradcam'
    level_name = opt.input_name.rsplit(".", 1)[0].split("_", 1)[1]
    # The original wrote to a hard-coded backslash path and assumed the
    # folder already existed.
    os.makedirs(folder_name, exist_ok=True)

    # The rendered real level is the grayscale background of every overlay.
    rendered = np.array(
        opt.ImgGen.render(one_hot_to_ascii_level(real, opt.token_list)))
    size = (rendered.shape[0], rendered.shape[1])
    background = rgb2gray(rendered)

    # GradCAM on D.
    # NOTE: the original assigned ``fig.figsize``, which matplotlib ignores;
    # the default figure size is therefore kept.
    attr = _gradcam_heatmap(LayerGradCam(D, D.tail), real1, (0, 0, 0), size)
    fig, ax = plt.subplots(1, 1)
    ax.imshow(background, cmap='gray', vmin=0, vmax=1)
    im = ax.imshow(attr, cmap='jet', alpha=0.5)
    ax.axis('off')
    fig.colorbar(im, ax=ax, location='bottom', shrink=0.85)
    fig.suptitle(f'cGAN {level_name} D(x)@{current_scale} ({step})')
    fig.savefig(os.path.join(folder_name,
                             f'{level_name}_D_{current_scale}_{step}.png'),
                bbox_inches='tight',
                pad_inches=0.1)
    plt.close(fig)

    # GradCAM on G, one row per token channel.
    def wrappedG(z):
        return G(z, z_opt)

    camG = LayerGradCam(wrappedG, G.tail[0])
    z_cam = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                   device=opt.device)
    z_cam = pad_noise(z_cam)
    attrs = [
        _gradcam_heatmap(camG, z_cam, (i, 0, 0), size)
        for i in range(opt.nc_current)
    ]

    # ``squeeze=False`` keeps an indexable array even for a single channel.
    fig, axs = plt.subplots(opt.nc_current, 1, squeeze=False)
    axs = axs[:, 0]
    for i, attr in enumerate(attrs):
        token = opt.token_list[i]
        axs[i].axis('off')
        # Unknown tokens (e.g. other games) are labelled by the token itself.
        axs[i].text(-0.1,
                    0.5,
                    _GRADCAM_TOKEN_NAMES.get(token, token),
                    rotation=0,
                    verticalalignment='center',
                    horizontalalignment='right',
                    transform=axs[i].transAxes)
        axs[i].imshow(background, cmap='gray', vmin=0, vmax=1)
        im = axs[i].imshow(attr, cmap='jet', alpha=0.5)
    fig.colorbar(im, ax=axs, shrink=0.85)
    fig.suptitle(f'cGAN {level_name} G(z)@{current_scale} ({step})')
    fig.savefig(os.path.join(folder_name,
                             f'{level_name}_G_{current_scale}_{step}.png'),
                bbox_inches='tight',
                pad_inches=0.1)
    plt.close(fig)


def train_single_scale(D, G, reals, generators, noise_maps,
                       input_from_prev_scale, noise_amplitudes, opt):
    """Train one scale.

    D and G are the current discriminator and generator, reals are the scaled
    versions of the original level, generators and noise_maps contain
    information from previous scales and will receive information in this
    scale, input_from_prev_scale holds the noise map and images from the
    previous scale, noise_amplitudes hold the amplitudes for the noise in all
    the scales. opt is a namespace that holds all necessary parameters.

    When ``opt.cgan`` is set, a PCA detection map scaled by 0.1 is appended
    as an extra channel to what D sees, and divergences between the real and
    generated detection maps are collected.

    Returns:
        Tuple ``(z_opt, input_from_prev_scale, G, divergences)``.
        ``divergences`` is an empty list when ``opt.cgan`` is false.
    """
    current_scale = len(generators)
    real = reals[current_scale]

    keepSky = False
    kernel_dims = (2, 2)
    # The original conditional on ``opt.token_insert`` evaluated to 1 on both
    # branches.
    temperature = 1

    # Initialize real detector
    real0 = preprocess(opt, real, keepSky)
    _, _, H, W = real0.shape

    # Defined unconditionally so the return value and the logging below also
    # work without the cGAN extension.
    divergences = []
    if opt.cgan:
        detector = PCA_Detector(opt, 'real', real0, kernel_dims)
        real_detection_map = detector(real0)
        detection_scale = 0.1
        real_detection_map *= detection_scale
        real1 = torch.cat(
            [real, F.interpolate(real_detection_map, (H, W))], dim=1)
    else:
        real1 = real

    def with_detection(level):
        """Append the scaled detection map of ``level`` in cGAN mode."""
        level0 = preprocess(opt, level, keepSky)
        if not opt.cgan:
            return level
        Nf, Cf, Hf, Wf = level0.shape
        detection_map = detector(level0) * detection_scale
        return torch.cat(
            [level, F.interpolate(detection_map, (Hf, Wf))], dim=1)

    def strip_detection(level):
        """Drop the trailing channel, which only exists in cGAN mode."""
        return level[:, :-1] if opt.cgan else level

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    nzx = real.shape[2]  # Noise size x
    nzy = real.shape[3]  # Noise size y

    # As kernel size is always 3 currently, padsize goes up by one per layer
    padsize = int(1 * opt.num_layer)

    if not opt.pad_with_noise:
        pad_noise = nn.ZeroPad2d(padsize)
        pad_image = nn.ZeroPad2d(padsize)
    else:
        pad_noise = nn.ReflectionPad2d(padsize)
        pad_image = nn.ReflectionPad2d(padsize)

    # setup optimizer
    optimizerD = optim.Adam(D.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(G.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)

    if current_scale == 0:  # Generate new noise
        z_opt = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                       device=opt.device)
    else:  # Add noise to previous output
        z_opt = torch.zeros([1, opt.nc_current, nzx, nzy]).to(opt.device)
    z_opt = pad_noise(z_opt)

    logger.info("Training at scale {}", current_scale)
    for epoch in tqdm(range(opt.niter)):
        step = current_scale * opt.niter + epoch
        noise_ = generate_spatial_noise([1, opt.nc_current, nzx, nzy],
                                        device=opt.device)
        noise_ = pad_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            D.zero_grad()

            output = D(real1).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)

            # train with fake
            if j == 0 and epoch == 0:
                if current_scale == 0:
                    # If we are in the lowest scale, noise is generated from
                    # scratch
                    prev = torch.zeros(1, opt.nc_current, nzx,
                                       nzy).to(opt.device)
                    input_from_prev_scale = prev
                    prev = pad_image(prev)
                    z_prev = torch.zeros(1, opt.nc_current, nzx,
                                         nzy).to(opt.device)
                    z_prev = pad_noise(z_prev)
                    opt.noise_amp = 1
                else:
                    # First step in NOT the lowest scale.
                    # We need to adapt our inputs from the previous scale and
                    # add noise to it
                    prev = draw_concat(generators, noise_maps, reals,
                                       noise_amplitudes,
                                       input_from_prev_scale, "rand",
                                       pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from
                    # token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        prev = group_to_token(prev, opt.token_list,
                                              token_group)

                    prev = interpolate(prev,
                                       real1.shape[-2:],
                                       mode="bilinear",
                                       align_corners=False)
                    prev = pad_image(prev)
                    z_prev = draw_concat(generators, noise_maps, reals,
                                         noise_amplitudes,
                                         input_from_prev_scale, "rec",
                                         pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from
                    # token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        z_prev = group_to_token(z_prev, opt.token_list,
                                                token_group)

                    z_prev = interpolate(z_prev,
                                         real1.shape[-2:],
                                         mode="bilinear",
                                         align_corners=False)
                    opt.noise_amp = update_noise_amplitude(
                        z_prev, strip_detection(real1), opt)
                    z_prev = pad_image(z_prev)
            else:  # Any other step
                prev = draw_concat(generators, noise_maps, reals,
                                   noise_amplitudes, input_from_prev_scale,
                                   "rand", pad_noise, pad_image, opt)

                # For the seeding experiment, we need to transform from
                # token_groups to the actual token
                if current_scale == (opt.token_insert + 1):
                    prev = group_to_token(prev, opt.token_list, token_group)

                prev = interpolate(prev,
                                   real1.shape[-2:],
                                   mode="bilinear",
                                   align_corners=False)
                prev = pad_image(prev)

            # After creating our correct noise input, we feed it to the
            # generator:
            noise = opt.noise_amp * noise_ + prev
            fake = G(noise.detach(), prev, temperature=temperature)
            fake1 = with_detection(fake)

            # Then run the result through the discriminator
            output = D(fake1.detach())
            errD_fake = output.mean()

            # Backpropagation
            errD_fake.backward(retain_graph=False)

            # Gradient Penalty
            gradient_penalty = calc_gradient_penalty(D, real1, fake1,
                                                     opt.lambda_grad,
                                                     opt.device)
            gradient_penalty.backward(retain_graph=False)

            # Logging:
            if step % 10 == 0:
                wandb.log(
                    {
                        f"D(G(z))@{current_scale}": errD_fake.item(),
                        f"D(x)@{current_scale}": -errD_real.item(),
                        f"gradient_penalty@{current_scale}":
                        gradient_penalty.item()
                    },
                    step=step,
                    sync=False)
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            G.zero_grad()
            fake = G(noise.detach(), prev.detach(), temperature=temperature)
            fake1 = with_detection(fake)
            output = D(fake1)

            errG = -output.mean()
            errG.backward(retain_graph=False)
            if opt.alpha != 0:
                # i. e. we are trying to find an exact recreation of our
                # input in the lat space
                Z_opt = opt.noise_amp * z_opt + z_prev
                G_rec = G(Z_opt.detach(), z_prev, temperature=temperature)
                rec_loss = opt.alpha * F.mse_loss(G_rec, real)
                if opt.cgan:
                    div = divergence(real_detection_map,
                                     preprocess(opt, G_rec, keepSky))
                    rec_loss += div
                # TODO: Check for unexpected argument retain_graph=True
                rec_loss.backward(retain_graph=False)
                rec_loss = rec_loss.detach()
            else:  # We are not trying to find an exact recreation
                rec_loss = torch.zeros([])
                Z_opt = z_opt

            optimizerG.step()

        # More Logging: the detection map only exists in cGAN mode, so the
        # divergence is only tracked there.
        if opt.cgan:
            div = divergence(real_detection_map,
                             preprocess(opt, fake, keepSky))
            divergences.append(div)

        if step % 10 == 0:
            wandb.log(
                {
                    f"noise_amplitude@{current_scale}": opt.noise_amp,
                    f"rec_loss@{current_scale}": rec_loss.item()
                },
                step=step,
                sync=False,
                commit=True)

        # Rendering and logging images of levels
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if opt.token_insert >= 0 and opt.nc_current == len(token_group):
                token_list = [list(group.keys())[0] for group in token_group]
            else:
                token_list = opt.token_list

            img = opt.ImgGen.render(
                one_hot_to_ascii_level(
                    strip_detection(fake1).detach(), token_list))
            img2 = opt.ImgGen.render(
                one_hot_to_ascii_level(
                    G(Z_opt.detach(), z_prev,
                      temperature=temperature).detach(), token_list))
            real_scaled = one_hot_to_ascii_level(
                strip_detection(real1).detach(), token_list)
            img3 = opt.ImgGen.render(real_scaled)
            wandb.log(
                {
                    f"G(z)@{current_scale}": wandb.Image(img),
                    f"G(z_opt)@{current_scale}": wandb.Image(img2),
                    f"real@{current_scale}": wandb.Image(img3)
                },
                sync=False,
                commit=False)

            real_scaled_path = os.path.join(wandb.run.dir,
                                            f"real@{current_scale}.txt")
            with open(real_scaled_path, "w") as f:
                f.writelines(real_scaled)
            wandb.save(real_scaled_path)

        # Learning Rate scheduler step
        schedulerD.step()
        schedulerG.step()

    # NOTE(review): the flattened source does not show whether the statements
    # below ran once after training or once per epoch; they are placed after
    # the loop here.
    if opt.cgan:
        div = divergence(real_detection_map, preprocess(opt, z_opt, keepSky))
        divergences.append(div)

    # GradCAM visualizations of D and G for the trained scale.
    _save_gradcam_figures(D, G, real, real1, z_opt, nzx, nzy, pad_noise,
                          current_scale, step, opt)

    # Save networks
    torch.save(z_opt, "%s/z_opt.pth" % opt.outf)
    save_networks(G, D, z_opt, opt)
    wandb.save(opt.outf)
    return z_opt, input_from_prev_scale, G, divergences
def get_insights(self, tensor_data, _, target=0):
    """Build Integrated Gradients, Occlusion and GradCAM figures for an input.

    Each attribution is rendered with Captum's visualizer and handed to
    ``self.output_bytes``.

    Returns:
        A three-element list ordered Integrated Gradients, Occlusion,
        GradCAM. An entry that is ``bytes`` / ``bytearray`` is wrapped as
        ``{"b64": <base64 text>}``; any other entry is returned unchanged.
    """
    default_cmap = LinearSegmentedColormap.from_list(
        "custom blue",
        [(0, "#ffffff"), (0.25, "#0000ff"), (1, "#0000ff")],
        N=256,
    )
    zero_baseline = tensor_data * 0

    # --- attributions -----------------------------------------------------
    attributions_ig, _delta = self.attribute_image_features(
        self.ig,
        tensor_data,
        baselines=zero_baseline,
        return_convergence_delta=True,
        n_steps=15,
    )
    attributions_occ = self.attribute_image_features(
        self.occlusion,
        tensor_data,
        strides=(3, 8, 8),
        sliding_window_shapes=(3, 15, 15),
        baselines=zero_baseline,
    )
    attributions_lgc = self.attribute_image_features(self.layer_gradcam,
                                                     tensor_data)
    upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc,
                                                   tensor_data.shape[2:])

    # --- figures ----------------------------------------------------------
    def channels_last(tensor):
        # Squeezed tensor as a NumPy array with the first axis moved last.
        return np.transpose(tensor.squeeze().cpu().detach().numpy(),
                            (1, 2, 0))

    original_image = channels_last(tensor_data)
    attribution_titles = [
        "Original",
        "Positive Attribution",
        "Negative Attribution",
    ]

    matplot_viz_ig, _axes = viz.visualize_image_attr_multiple(
        channels_last(attributions_ig),
        original_image,
        use_pyplot=False,
        methods=["original_image", "heat_map"],
        cmap=default_cmap,
        show_colorbar=True,
        signs=["all", "positive"],
        titles=["Original", "Integrated Gradients"],
    )
    matplot_viz_occ, _axes = viz.visualize_image_attr_multiple(
        channels_last(attributions_occ),
        original_image,
        ["original_image", "heat_map", "heat_map"],
        ["all", "positive", "negative"],
        show_colorbar=True,
        titles=attribution_titles,
        fig_size=(18, 6),
        use_pyplot=False,
    )
    lgc_attr = upsamp_attr_lgc[0].cpu().permute(1, 2, 0).detach().numpy()
    lgc_image = tensor_data.squeeze().permute(1, 2, 0).cpu().numpy()
    matplot_viz_lgc, _axes = viz.visualize_image_attr_multiple(
        lgc_attr,
        lgc_image,
        use_pyplot=False,
        methods=["original_image", "blended_heat_map", "blended_heat_map"],
        signs=["all", "positive", "negative"],
        show_colorbar=True,
        titles=attribution_titles,
        fig_size=(18, 6))

    # --- serialisation (same call order as before) --------------------------
    occ_bytes = self.output_bytes(matplot_viz_occ)
    ig_bytes = self.output_bytes(matplot_viz_ig)
    lgc_bytes = self.output_bytes(matplot_viz_lgc)

    output = []
    for row in (ig_bytes, occ_bytes, lgc_bytes):
        if isinstance(row, (bytes, bytearray)):
            output.append({"b64": b64encode(row).decode("utf8")})
        else:
            output.append(row)
    return output
def vis_explanation(self, number, sample_idx=7):
    """Save GradCAM visualizations of one cached test sample under ``./pic``.

    The first batch of ``self.test_loader`` is cached in ``self.explainVis``
    on first use, so every call explains the same images. Four JPEG files
    prefixed with ``number`` are written: the heat map for the predicted
    class, the heat map for the ground-truth class, the input image, and a
    3D surface of the predicted-class attribution.

    Args:
        number: Value converted to ``str`` and used as file-name prefix.
        sample_idx: Position of the explained sample inside the cached batch.
            Defaults to 7, the index that used to be hard-coded.
    """
    if len(self.explainVis) == 0:
        # Same as taking the first item of the loader; keeps the old value
        # when the loader is empty.
        self.explainVis = next(iter(self.test_loader), self.explainVis)

    layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)
    lb = self.explainVis[1].to(device)
    img = self.explainVis[0].to(device)

    pred = self.model(img)
    predlb = torch.argmax(pred, 1)
    imgCQ = img.clone()

    # Attribution for the predicted class ...
    gc_attr = layer_gc.attribute(imgCQ,
                                 target=predlb,
                                 relu_attributions=False)
    upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))
    # ... and for the ground-truth class.
    gc_attr = layer_gc.attribute(imgCQ, target=lb, relu_attributions=False)
    upsampled_attrB = LayerAttribution.interpolate(gc_attr, (64, 64))

    os.makedirs('./pic', exist_ok=True)
    prefix = './pic/' + str(number)

    def channels_last(tensor):
        return tensor.detach().cpu().numpy().transpose([1, 2, 0])

    original = channels_last(img[sample_idx])

    # Every figure is closed after saving; otherwise repeated calls during
    # training keep accumulating open matplotlib figures.
    plotMe = viz.visualize_image_attr(channels_last(
        upsampled_attr[sample_idx]),
                                      original_image=original,
                                      method='heat_map',
                                      sign='all',
                                      plt_fig_axis=None,
                                      outlier_perc=2,
                                      cmap='inferno',
                                      alpha_overlay=0.2,
                                      show_colorbar=True,
                                      title=str(predlb[sample_idx]),
                                      fig_size=(8, 10),
                                      use_pyplot=True)
    plotMe[0].savefig(prefix + 'NotEQPred.jpg')
    plt.close(plotMe[0])

    plotMe = viz.visualize_image_attr(channels_last(
        upsampled_attrB[sample_idx]),
                                      original_image=original,
                                      method='heat_map',
                                      sign='all',
                                      plt_fig_axis=None,
                                      outlier_perc=2,
                                      cmap='inferno',
                                      alpha_overlay=0.9,
                                      show_colorbar=True,
                                      title=str(lb[sample_idx].cpu()),
                                      fig_size=(8, 10),
                                      use_pyplot=True)
    plotMe[0].savefig(prefix + 'NotEQLabel.jpg')
    plt.close(plotMe[0])

    # Input image.
    outImg = img[sample_idx].squeeze().detach().cpu().numpy()
    fig2 = plt.figure(figsize=(12, 12))
    plt.imshow(outImg)
    fig2.savefig(prefix + 'NotEQOrig.jpg')
    plt.close(fig2)

    # 3D surface of the predicted-class attribution over the 64x64 grid.
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection='3d')
    z = upsampled_attr[sample_idx].squeeze().detach().cpu().numpy()
    x = np.arange(0, 64, 1)
    y = np.arange(0, 64, 1)
    X, Y = np.meshgrid(x, y)
    plll = ax.plot_surface(X, Y, z, cmap=cm.coolwarm)
    # Customize the z axis.
    ax.set_zlim(-0.02, 0.1)
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
    # Add a color bar which maps values to colors.
    fig.colorbar(plll, shrink=0.5, aspect=5)
    fig.savefig(prefix + 'NotEQ3D.jpg')
    plt.close(fig)
def train_known_expl(self, n_epochs, ite, max_ite):
    """Train ``self.model`` so its GradCAM maps match ``self.sintetic``.

    For every batch of ``self.KnownLoader`` the GradCAM attribution of
    ``layer2[1].conv2`` for the predicted class is upsampled to 64x64 and
    compared to the leading slice of ``self.sintetic`` with binary
    cross-entropy. Only that loss is back-propagated; the classification loss
    is computed for printing only. After training the state dict is written
    to ``./checkpoint/resnet18.pth``.

    Args:
        n_epochs: Number of passes over ``self.KnownLoader``.
        ite: Index of the current outer iteration; only used in the progress
            message.
        max_ite: Total number of outer iterations; only used in the progress
            message.
    """
    self.model.train()
    avg_loss = []
    optim = torch.optim.Adam(self.model.parameters(),
                             lr=0.001,
                             weight_decay=0)
    criteria1 = nn.CrossEntropyLoss()
    criteria2 = nn.BCELoss()
    # NOTE(review): criteria21, criteriaL2 and looss are never used below.
    criteria21 = nn.BCEWithLogitsLoss()
    criteriaL2 = nn.MSELoss()
    layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)
    looss = nn.CrossEntropyLoss()
    print('Init known training with explanation...')
    # Running batch counter that drives the periodic prints.
    number = 0
    for epoch in range(n_epochs):
        for i, batch in enumerate(self.KnownLoader):
            lb = batch[1].to(device)
            img = batch[0].to(device)

            # Binary mask of the input: values equal to -0.5 become 0,
            # everything else becomes 1.
            # NOTE(review): maskLb is not used by the active loss.
            maskLb = batch[0].clone()
            maskLb = maskLb.squeeze()
            maskLb[maskLb == -0.5] = 0
            maskLb[maskLb != 0] = 1
            maskLb = maskLb.to(device)

            # Training
            optim.zero_grad()
            out = self.model(img)
            predlb = torch.argmax(out, 1)

            # GradCAM for the predicted class; this map feeds the loss.
            gc_attr = layer_gc.attribute(img,
                                         target=predlb,
                                         relu_attributions=True)
            upsampled_attr = LayerAttribution.interpolate(
                gc_attr, (64, 64))
            # GradCAM for the ground-truth class.
            # NOTE(review): only used to build ``exitpic``, which is never
            # read; the extra pass still runs the model in train mode.
            gc_attr = layer_gc.attribute(img,
                                         target=lb,
                                         relu_attributions=True)
            upsampled_attrB = LayerAttribution.interpolate(
                gc_attr, (64, 64))
            exitpic = upsampled_attrB.clone()
            exitpic = exitpic.detach().cpu().numpy()
            # NOTE(review): assumes the conv layer exposes a ``grads``
            # attribute, which SOURCE does not define - confirm.
            print(self.model.layer2[1].conv2.grads.data)

            size_batch = upsampled_attr.size()[0]
            # Explanation loss against the first ``size_batch`` reference
            # maps.
            losssemantic = criteria2(
                upsampled_attr[:size_batch, :, :, :],
                self.sintetic[:size_batch, :, :, :].to(device))
            loss1 = criteria1(out, lb)
            # Only the explanation loss is optimized; a weighted sum
            # 0.7*loss1 + 0.3*losssemantic was left commented out in the
            # original.
            lossall = losssemantic.to(device)
            # Second zero_grad before backward; nothing visible here writes
            # to ``.grad`` in between.
            optim.zero_grad()
            if number % 30 == 0:
                print(loss1, losssemantic)
            avg_loss = torch.mean(lossall)
            lossall.backward()
            optim.step()
            number += 1
            if number % 10 == 0:
                print(
                    "Epoch: {}/{} batch: {}/{} iteration: {}/{} average-loss: {:0.4f}"
                    .format(epoch + 1, n_epochs, i + 1,
                            len(self.KnownLoader), ite + 1, max_ite,
                            avg_loss.cpu()))

    # Save checkpoint
    # NOTE(review): the flattened source does not show whether the checkpoint
    # was written once after training or once per epoch; it is placed after
    # the epoch loop here - confirm against the original file.
    if (os.path.exists("./checkpoint")):
        torch.save(self.model.state_dict(), "./checkpoint/resnet18.pth")
    else:
        os.mkdir('checkpoint')
        torch.save(self.model.state_dict(), "./checkpoint/resnet18.pth")
    number = 0