def main(mixup=False):
    """Log segmentation samples and per-structure GradCAM maps to W&B.

    Iterates over the MICCAI 2D test split, skips every slice whose squashed
    mask holds fewer than six distinct values, logs the prediction for each
    remaining slice and attaches one LayerGradCam figure per entry of
    ``miccai.STRUCTURES``.

    Args:
        mixup: Forwarded to ``get_model``; also selects the run-name prefix
            ("Mixup" when true, "Large" otherwise).
    """
    prefix = "Mixup" if mixup else "Large"
    run = wandb.init(
        name=f"Interpretability ({prefix})",
        project="ct-interpretability",
        dir=DEFAULT_DATA_STORAGE,
        reinit=True,
    )
    model = get_model(mixup)
    dataset = get_miccai_2d(
        "test",
        transform=DEGREE[model.hparams.transform_degree]["test"],
        enhanced="Boundary" in model.hparams.loss_fx,
    )
    class_labels = dict(zip(range(1, model._n_classes), miccai.STRUCTURES))
    class_labels[0] = "Void"

    def segmentation_wrapper(input):
        # Reduce the spatial axes so every class yields one scalar per sample,
        # which is what LayerGradCam needs to select a target index.
        return model(input).sum(dim=(2, 3))

    # The wrapper, the inspected layer and the GradCAM object do not depend on
    # the current sample, so they are built once instead of once per slice.
    layer = model.unet.model[2][1].conv.unit0.conv
    lgc = LayerGradCam(segmentation_wrapper, layer)

    step = 0
    for sample in tqdm(dataset):
        preproc_img, masks, _, *_ = sample
        normalized_inp = preproc_img.unsqueeze(0).to(device)
        normalized_inp.requires_grad = True
        masks = _squash_masks(masks, 10, masks.device)
        if len(masks.unique()) < 6:
            # Only displaying structures with atleast 5 structures (excluding background)
            continue
        out = model(normalized_inp)
        out_max = _squash_predictions(out).unsqueeze(1)
        log_samples(preproc_img, masks, out_max, class_labels, step)

        figures = []
        for structure in miccai.STRUCTURES:
            # NOTE(review): `structures` is a name from outside this function;
            # class 0 is labelled "Void" above, so confirm that this index is
            # the intended output channel for `structure`.
            idx = structures.index(structure)
            gc_attr = lgc.attribute(normalized_inp, target=idx)
            fig, ax = viz.visualize_image_attr(
                gc_attr[0].cpu().permute(1, 2, 0).detach().numpy(),
                sign="all",
                use_pyplot=False,
            )
            ax.set_title(structure)
            figures.append(wandb.Image(fig))
        wandb.log({"GradCam Attributions": figures}, step=step)
        step += 1
    run.finish()
def layer_grad_cam(model, image, results_dir, file_name):
    """Save a LayerGradCam heat map of ``model.skip_4[0]`` as a NIfTI volume.

    Args:
        model: Network exposing a ``skip_4`` sequence; its first entry is the
            layer that is attributed.
        image: Input tensor handed to ``LayerGradCam.attribute``.
        results_dir: Output prefix. It is concatenated verbatim with
            ``file_name``, so it must already end with a path separator.
        file_name: Base name; ``"-HM.nii"`` is appended.
    """
    layer_gc = LayerGradCam(model, model.skip_4[0])
    attr = layer_gc.attribute(image)
    # Upsample the coarse layer map to the fixed (121, 145, 121) volume size.
    attr = LayerAttribution.interpolate(attr, (121, 145, 121))
    # .cpu() makes the conversion work when the model runs on a GPU; it is a
    # no-op for tensors that already live on the CPU.
    heat_map = attr.squeeze().detach().cpu().numpy()
    volume = nib.Nifti1Image(heat_map, affine=np.eye(4))
    nib.save(volume, results_dir + file_name + "-HM.nii")
def get_grad_cam(x, y, model, is_training=True):
    """Return GradCAM attributions upsampled to the input's spatial size.

    The attributed layer is whatever ``model.get_grad_cam_layer()`` returns
    (per the original note: the last conv layer, 7x7 for a 224x224 input).

    Args:
        x: Input batch; its last two dimensions give the output size.
        y: Target passed to ``LayerGradCam.attribute``.
        model: Model exposing ``get_grad_cam_layer()``.
        is_training: Forwarded to ``HackGradAndOutputs``.

    Returns:
        Tuple ``(attributions, hack.output)``.
    """
    spatial_size = x.shape[-2:]
    with torch.enable_grad():
        with HackGradAndOutputs(is_training=is_training) as hack:
            grad_cam = LayerGradCam(model, model.get_grad_cam_layer())
            coarse_map = grad_cam.attribute(x, target=y)
            attributions = F.interpolate(
                coarse_map, size=spatial_size, mode='bilinear')
    return attributions, hack.output
def run(arch, img, target):
    """Compute a min-max normalised GradCAM map for a pretrained VGG16.

    Args:
        arch: Unused; kept so existing callers keep working.
        img: Path (or file object) of the image to explain.
        target: Class index whose score is attributed.

    Returns:
        The attribution of ``model.features[28]``, interpolated to 224x224 and
        rescaled to ``[0, 1]``. A constant map is returned as all zeros
        instead of NaN.
    """
    # Close the file handle deterministically; convert() returns a loaded copy.
    with Image.open(img) as handle:
        rgb_image = handle.convert('RGB')
    batch = apply_transform(rgb_image)
    model = models.vgg16(pretrained=True).eval()
    grad_cam = LayerGradCam(model, model.features[28])
    attr = grad_cam.attribute(batch, target=target)
    attr = LayerAttribution.interpolate(attr, (224, 224))
    shifted = attr - attr.min()
    span = attr.max() - attr.min()
    if span > 0:
        return shifted / span
    # All values are equal: dividing would produce NaN, so return the zeros.
    return shifted
def explanation(self, dataindex):
    """Compute GradCAM and DeepLift maps for selected samples of ``self.unknown``.

    ``self.unknown.indices`` is temporarily replaced by ``dataindex`` so the
    loader only yields the requested samples; the previous indices are put
    back before returning, also when an exception is raised.

    Args:
        dataindex: Indices assigned to ``self.unknown.indices`` for the
            duration of the call.

    Returns:
        Tuple ``(img, gc_attr, gradcam_map, deeplift_map)`` for the last
        sample yielded by the loader; both maps are 28x28-interpolated NumPy
        arrays.
    """
    oldIndices = self.unknown.indices.copy()
    self.unknown.indices = dataindex
    try:
        datasetLoader = torch.utils.data.DataLoader(
            dataset=self.unknown, batch_size=1, shuffle=False)
        self.model.eval()
        target_layer = self.model.layer1[0].conv2
        layer_gc = LayerGradCam(self.model, target_layer)
        dl = LayerDeepLift(self.model, target_layer)
        # NOTE(review): nothing is drawn on this figure here; it is kept in
        # case callers plot onto the current figure afterwards.
        plt.figure(figsize=(18, 10))
        for batch in datasetLoader:
            lb = batch[1].to(device)
            print(len(lb))
            img = batch[0].to(device)
            lbin = batch[1].cpu().numpy()
            print(lbin)
            pred = self.model(img)
            predlb = torch.argmax(pred, 1)
            print('Prediction label is :', predlb.cpu().numpy())
            print('Ground Truth label is: ', lb.cpu().numpy())
            # GradCAM explains the predicted class, DeepLift the true label.
            # .item() replaces int(ndarray), which NumPy deprecates for
            # arrays with ndim > 0.
            gc_attr = layer_gc.attribute(img, target=int(predlb.item()))
            upsampled_attr = LayerAttribution.interpolate(gc_attr, (28, 28))
            base = torch.zeros([1, 1, 28, 28]).to(device)
            de_attr = dl.attribute(img, base, target=int(lbin[0]))
            dl_upsampled_attr = LayerAttribution.interpolate(de_attr, (28, 28))
            print("done ...")
        return (
            img,
            gc_attr,
            upsampled_attr.squeeze().detach().cpu().numpy(),
            dl_upsampled_attr.squeeze().detach().cpu().numpy(),
        )
    finally:
        # The original saved the indices but never put them back.
        self.unknown.indices = oldIndices
def explain_gradXact(model, node_idx, x, edge_index, target, include_edges=None):
    """Average LayerGradCam node attributions over all convolution layers.

    Captum's default LayerGradCam does not average over nodes for the
    different channels because it assumes other tensor shapes, so the
    per-layer results are flattened and averaged here.

    Args:
        model: Model forwarded to ``model_forward_node``.
        node_idx: Node forwarded to ``model_forward_node``.
        x: Node features; a clone with gradients enabled is attributed.
        edge_index: Graph connectivity, also used to map node scores to edges.
        target: Target passed to ``LayerGradCam.attribute``.
        include_edges: Unused.

    Returns:
        The edge mask produced by ``node_attr_to_edge``.
    """
    def _layer_scores(conv_layer):
        explainer = LayerGradCam(model_forward_node, conv_layer)
        scores = explainer.attribute(
            input_mask,
            target=target,
            additional_forward_args=(model, edge_index, node_idx),
        )
        return scores.cpu().detach().numpy().ravel()

    input_mask = x.clone().requires_grad_(True).to(device)
    per_layer = [_layer_scores(conv) for conv in get_all_convolution_layers(model)]
    averaged = np.array(per_layer).mean(axis=0)
    return node_attr_to_edge(edge_index, averaged)
def __init__(self, trainer):
    """Set up GradCAM on the module that precedes the adaptive average pool.

    Args:
        trainer: Object whose ``model`` is scanned for the first
            ``nn.AdaptiveAvgPool2d`` child.

    Raises:
        AssertionError: If the model has no ``nn.AdaptiveAvgPool2d`` child.
    """
    CaptumDerivative.__init__(self, trainer)
    model = self.trainer
    self.configs.update({"relu_attributions": True})  # As in original GradCam paper
    adaptive_idx = next(
        (idx for idx, mod in enumerate(trainer.model.children())
         if isinstance(mod, nn.AdaptiveAvgPool2d)),
        None,
    )
    if adaptive_idx is None:
        # Raised explicitly (same exception type as before) so the check
        # also runs under `python -O`, where assert statements are removed.
        raise AssertionError(
            "This implementation assumes that final spatial dimension is reduced via AvgPool. "
            "If you want to use it, change the model accordingly or update this code."
        )
    # NOTE(review): when the pool is the very first child this indexes -1,
    # i.e. the last module; confirm that case cannot occur.
    layer = model.trainer.model[adaptive_idx - 1]
    LayerGradCam.__init__(self, model, layer)
def validate_explanation(self):
    """Score GradCAM explanations of the test set against ``self.sintetic``.

    For every test batch the GradCAM map of ``self.model.layer1[0].conv2``
    (target: the predicted class) is interpolated to 64x64 and soft-maxed
    over its flattened spatial positions. The collected maps and the matching
    slices of ``self.sintetic`` are handed to ``self.calculate_measures`` and
    the resulting precision, recall and correlation are printed.
    """
    predlist = torch.zeros(0, dtype=torch.long, device='cpu')
    lbllist = torch.zeros(0, dtype=torch.long, device='cpu')
    layer_gc = LayerGradCam(self.model, self.model.layer1[0].conv2)
    for batch in self.test_loader:
        img = batch[0].to(device)
        pred = self.model(img)
        predlb = torch.argmax(pred, 1)
        gc_attr = layer_gc.attribute(img, target=predlb, relu_attributions=False)
        upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))
        sz = upsampled_attr.size()
        # Softmax over all spatial positions of each (sample, channel) map.
        flat = upsampled_attr.view(sz[0], sz[1], -1)
        upsampled_attr_soft = F.softmax(flat, dim=2).view_as(upsampled_attr)
        # Append batch prediction results
        predlist = torch.cat(
            [predlist, upsampled_attr_soft.detach().squeeze().cpu()])
        lbllist = torch.cat(
            [lbllist, self.sintetic[:sz[0], :, :, :].squeeze().cpu()], dim=0)
    # size must be called; the bare attribute printed the bound methods.
    print(predlist.size(), lbllist.size())
    final_prec, final_rec, final_corr = self.calculate_measures(
        lbllist, predlist)
    print('Final validation result...')
    print("- final explanation percision: {:.3f}".format(final_prec))
    print("- final explanation recal: {:.3f}".format(final_rec))
    print("- final explanation correlation: {:.3f}".format(final_corr))
def cal(name, model, image_orl, size, label_true):
    """Build the attention map of one image with the configured method.

    Args:
        name: Forwarded to ``make_grad``.
        model: Classifier; ``model.layer4`` is the attributed layer.
        image_orl: Original image, transformed with the inference transform.
        size: Forwarded to ``make_grad``.
        label_true: Ground-truth label of the image.

    Returns:
        The result of ``make_grad``, or ``None`` (after appending ``1`` to the
        global ``count``) when the predicted label is neither ``1`` nor
        ``label_true``.

    Raises:
        ValueError: If ``args.att_type`` is not "GradCam" or "DeepLIFT".
    """
    func = make_transform(args, mode="inference")
    image = func(image_orl).unsqueeze(0)
    image = image.to(args.gpu, dtype=torch.float32)
    output = F.softmax(model(image), dim=1)
    _, pred_label_idx = torch.topk(output, 1)
    predicted_label = pred_label_idx.item()
    if predicted_label not in (1, label_true):
        count.append(1)
        return None

    if args.att_type == "GradCam":
        explainer = LayerGradCam(model, layer=model.layer4)
    elif args.att_type == "DeepLIFT":
        explainer = LayerDeepLift(model, layer=model.layer4)
    else:
        # Previously this fell through to an unbound `att_map`.
        raise ValueError(f"Unsupported att_type: {args.att_type!r}")
    return make_grad(name, explainer, image, image_orl, args.att_type, size)
def compute_gradcam(self, img_path, target):
    """Return a GradCAM heat map for one image, scaled to ``[0, 1]``.

    Args:
        img_path: Passed to ``self.open_image``.
        target: Target passed to ``LayerGradCam.attribute``.

    Returns:
        A NumPy array resized to the height and width of the model input.
        An all-zero map is returned unchanged instead of being divided by
        zero.
    """
    # open image
    _, _, input_tensor = self.open_image(img_path)
    # grad cam
    input_tensor.requires_grad = True
    gradcam = LayerGradCam(self.model, self.layer)
    attr = gradcam.attribute(input_tensor, target)
    cam = attr.squeeze().cpu().detach().numpy()
    cam = np.maximum(cam, 0)
    # cv2.resize takes dsize as (width, height), the reverse of the
    # (height, width) order in the tensor shape.
    height, width = input_tensor.shape[2:]
    cam = cv2.resize(cam, (int(width), int(height)))
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    return cam
def compute_attr_one_pixel_target(x, net, spatial_coords, method, **kwargs):
    """Attribute a single-pixel output of ``net`` back to ``x``.

    Args:
        x: A single input tensor, i.e. shape (batch_size, channel_size, H, W)
            = (1, 1, 28, 28) according to the original note.
        net: Network wrapped in ``WrapperNet``.
        spatial_coords: Pair ``(idx, idy)`` handed to ``WrapperNet``.
        method: One of ``'gradCAM'``, ``'deconv'`` or ``'GuidedBP'``.
        **kwargs: ``target`` (attribution target, defaults to ``None``) and
            ``wrapper_output`` (``WrapperNet`` output mode, defaults to
            ``'yg_pixel'``).

    Returns:
        ``attr[0][0]`` of the attribution as a NumPy array.

    Raises:
        ValueError: If ``method`` is not supported. This includes
            ``'layerAct'``, for which no explainer was ever constructed.
    """
    from captum.attr import LayerGradCam, Deconvolution, GuidedBackprop

    if method not in ('gradCAM', 'deconv', 'GuidedBP'):
        raise ValueError(f"Unsupported attribution method: {method!r}")

    # .get avoids the unbound-local error the original hit without a target.
    target = kwargs.get('target')
    output_mode = kwargs.get('wrapper_output', 'yg_pixel')

    idx, idy = spatial_coords
    wnet = WrapperNet(net, output_mode=output_mode, spatial_coords=(idx, idy))

    if method == 'gradCAM':
        xai = LayerGradCam(wnet, wnet.main_net.channel_adj)
    elif method == 'deconv':
        xai = Deconvolution(wnet)
    else:
        xai = GuidedBackprop(wnet)

    attr = xai.attribute(x, target=target)
    return attr[0][0].clone().detach().cpu().numpy()
def initialize(self, ctx):  # pylint: disable=arguments-differ
    """In this initialize function, the CIFAR10 trained model is loaded and
    the Integrated Gradients,occlusion and layer_gradcam Algorithm for
    Captum Explanations is initialized here.

    Args:
        ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
    """
    self.manifest = ctx.manifest
    properties = ctx.system_properties
    model_dir = properties.get("model_dir")
    print("Model dir is {}".format(model_dir))
    serialized_file = self.manifest["model"]["serializedFile"]
    mapping_file_path = os.path.join(model_dir, "index_to_name.json")
    if os.path.exists(mapping_file_path):
        with open(mapping_file_path) as fp:
            self.mapping = json.load(fp)
    model_pt_path = os.path.join(model_dir, serialized_file)
    self.device = torch.device(
        "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
    )
    from cifar10_train import CIFAR10Classifier

    self.model = CIFAR10Classifier()
    # map_location lets weights saved on a GPU load on a CPU-only host.
    self.model.load_state_dict(torch.load(model_pt_path, map_location=self.device))
    self.model.to(self.device)
    self.model.eval()
    self.model.zero_grad()
    logger.info("CIFAR10 model from path %s loaded successfully", model_dir)

    # Read the mapping file, index to object name. When present it replaces
    # the mapping loaded from index_to_name.json above.
    mapping_file_path = os.path.join(model_dir, "class_mapping.json")
    if os.path.isfile(mapping_file_path):
        print("Mapping file present")
        with open(mapping_file_path) as pointer:
            self.mapping = json.load(pointer)
    else:
        print("Mapping file missing")
        logger.warning("Missing the class_mapping.json file.")

    self.ig = IntegratedGradients(self.model)
    self.layer_gradcam = LayerGradCam(self.model, self.model.model_conv.layer4[2].conv3)
    self.occlusion = Occlusion(self.model)
    self.image_processing = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Flag the handler as ready only once every attribute is in place.
    self.initialized = True
def initialize(self, ctx):  #pylint: disable=arguments-differ
    """In this initialize function, the CIFAR10 trained model is loaded and
    the Integrated Gradients, Occlusion and LayerGradCam algorithms for
    Captum Explanations are initialized here.

    Args:
        ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
    """
    self.manifest = ctx.manifest
    properties = ctx.system_properties
    model_dir = properties.get("model_dir")
    print("Model dir is {}".format(model_dir))
    serialized_file = self.manifest["model"]["serializedFile"]
    model_pt_path = os.path.join(model_dir, serialized_file)
    self.device = torch.device(  #pylint: disable=no-member
        "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
    self.model = CIFAR10CLASSIFIER()
    # map_location lets weights saved on a GPU load on a CPU-only host.
    self.model.load_state_dict(torch.load(model_pt_path, map_location=self.device))
    self.model.to(self.device)
    self.model.eval()
    self.model.zero_grad()
    logger.info("CIFAR10 model from path %s loaded successfully", model_dir)

    # Read the mapping file, index to object name
    mapping_file_path = os.path.join(model_dir, "class_mapping.json")
    if os.path.isfile(mapping_file_path):
        print("Mapping file present")
        with open(mapping_file_path) as file_pointer:
            self.mapping = json.load(file_pointer)
    else:
        print("Mapping file missing")
        logger.warning("Missing the class_mapping.json file.")

    self.ig = IntegratedGradients(self.model)
    self.layer_gradcam = LayerGradCam(
        self.model, self.model.model_conv.layer4[2].conv3)
    self.occlusion = Occlusion(self.model)
    self.initialized = True
def heatmaps(args):
    """Render GradCAM, Deconvolution and GuidedBackprop heat maps for a batch.

    Loads the saved network of the project, draws one sample batch covering
    class indices 0-9, attributes the labels ``y0`` with the three methods
    and passes the results to ``arrange_heatmaps``, which receives
    ``<PROJECT_DIR>/XAI/heatmaps.y0.jpeg`` as ``save_dir``.

    Args:
        args: Mapping that provides ``'PROJECT_ID'``.
    """
    print('heatmaps')
    PROJECT_ID = args['PROJECT_ID']
    CKPT_DIR, PROJECT_DIR, MODEL_DIR, LOGGER_DIR, load_model = folder_check(PROJECT_ID, CKPT_DIR='checkpoint')
    XAI_DIR = os.path.join(PROJECT_DIR, 'XAI')
    # exist_ok removes the check-then-create race of exists() + mkdir().
    os.makedirs(XAI_DIR, exist_ok=True)

    from .sampler import Pytorch_GPT_MNIST_Sampler
    samp = Pytorch_GPT_MNIST_Sampler(compenv_mode=None, growth_mode=None)

    # The import is retained although the name is unused: the original built
    # a ResGPTNet34 here only to overwrite it with the loaded model.
    from .model import ResGPTNet34  # noqa: F401
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net = torch.load(MODEL_DIR, map_location=device)
    net.output_mode = 'prediction_only'
    net.to(device=device)
    net.eval()

    x, y0, yg0, ys0 = samp.get_sample_batch(class_indices=np.array(range(10)), device=device)
    x.requires_grad = True
    SAVE_DIR = os.path.join(XAI_DIR, 'heatmaps.y0.jpeg')

    from captum.attr import LayerGradCam, Deconvolution, GuidedBackprop

    def _attribute(explainer):
        # Shared tail of the three formerly copy-pasted attribution stanzas.
        return explainer.attribute(x, target=y0).clone().detach().cpu().numpy()

    # Dict literals are evaluated top to bottom, so the methods run in the
    # same order as before.
    attrs = {
        'gradCAM': _attribute(LayerGradCam(net, net.channel_adj)),
        'deconv': _attribute(Deconvolution(net)),
        'GuidedBP': _attribute(GuidedBackprop(net)),
    }
    arrange_heatmaps(x.clone().detach().cpu().numpy(), attrs, save_dir=SAVE_DIR)
# # GradCAM computes the gradients of the target output with respect to the # given layer, averages for each output channel (dimension 2 of output), # and multiplies the average gradient for each channel by the layer # activations. The results are summed over all channels. GradCAM is # designed for convnets; since the activity of convolutional layers often # maps spatially to the input, GradCAM attributions are often upsampled # and used to mask the input. # # Layer attribution is set up similarly to input attribution, except that # in addition to the model, you must specify a hidden layer within the # model that you wish to examine. As above, when we call ``attribute()``, # we specify the target class of interest. # layer_gradcam = LayerGradCam(model, model.layer3[1].conv2) attributions_lgc = layer_gradcam.attribute(input_img, target=pred_label_idx) _ = viz.visualize_image_attr(attributions_lgc[0].cpu().permute( 1, 2, 0).detach().numpy(), sign="all", title="Layer 3 Block 1 Conv 2") ########################################################################## # We’ll use the convenience method ``interpolate()`` in the # `LayerAttribution <https://captum.ai/api/base_classes.html?highlight=layerattribution#captum.attr.LayerAttribution>`__ # base class to upsample this attribution data for comparison to the input # image. # upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc,
def get_grad_imp(model, X, y=None, mode='grad', return_y=False, clip=False, baselines=None):
    """Compute input attributions for ``model`` with the selected method.

    Args:
        model: Callable mapping ``X`` to logits indexed as
            ``logits[sample, class]``.
        X: Input batch. Gradients are enabled on it for the duration of the
            call and disabled again afterwards, also when an error is raised.
        y: Target class per sample; defaults to the arg-max of the logits.
        mode: ``'grad'``, ``'deeplift'``, ``'deepliftshap'`` or ``'gradcam'``.
        return_y: When true, the targets are returned as well.
        clip: When true, the attributions are passed through ``myclip``.
        baselines: Baselines forwarded to ``DeepLiftShap.attribute``.

    Returns:
        The detached attributions, or ``(attributions, y)`` when ``return_y``
        is set.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """
    X.requires_grad_(True)
    try:
        if mode in ['grad']:
            logits = model(X)
            if y is None:
                y = logits.argmax(dim=1)
            # Gradient of the summed target logits with respect to the input.
            attributions = torch.autograd.grad(
                logits[torch.arange(len(logits)), y].sum(), X)[0].detach()
        else:
            if y is None:
                with torch.no_grad():
                    logits = model(X)
                    y = logits.argmax(dim=1)

            if mode == 'deeplift':
                dl = DeepLift(model)
                attributions = dl.attribute(inputs=X, baselines=0., target=y)
                attributions = attributions.detach()
            elif mode in ['deepliftshap']:
                dl = DeepLiftShap(model)
                attributions = []
                # Attribute two samples at a time.
                for idx in range(0, len(X), 2):
                    the_x, the_y = X[idx:(idx + 2)], y[idx:(idx + 2)]
                    attribution = dl.attribute(inputs=the_x, baselines=baselines, target=the_y)
                    attributions.append(attribution.detach())
                attributions = torch.cat(attributions, dim=0)
            elif mode in ['gradcam']:
                orig_lgc = LayerGradCam(model, model.body[0])
                attributions = orig_lgc.attribute(X, target=y)
                # Detached like every other mode so no autograd graph escapes.
                attributions = F.interpolate(
                    attributions, size=X.shape[-2:], mode='bilinear').detach()
            else:
                raise NotImplementedError(f'{mode} is not specified.')

        # Do clipping!
        if clip:
            attributions = myclip(attributions)
    finally:
        X.requires_grad_(False)

    if not return_y:
        return attributions
    return attributions, y
def get_attribution(real_img, fake_img, real_class, fake_class, net_module, checkpoint_path,
                    input_shape, channels,
                    methods=["ig", "grads", "gc", "ggc", "dl", "ingrad", "random", "residual"],
                    output_classes=6, downsample_factors=[(2,2), (2,2), (2,2), (2,2)]):
    """Compute a set of attribution maps for a real/fake image pair.

    Args:
        real_img: Real image as an array.
        fake_img: Counterfactual image as an array of the same shape.
        real_class: Target class used for ``real_img``.
        fake_class: Target class used for ``fake_img``.
        net_module: Network module forwarded to ``init_network``.
        checkpoint_path: Checkpoint forwarded to ``init_network``.
        input_shape: Input shape; its first two entries are the size GradCAM
            maps are projected to.
        channels: Channel count forwarded to ``init_network``.
        methods: Names of the methods to run. The branches implemented here
            are "residual", "random", "gc", "ggc", "ig", "dl" and "ingrad";
            other entries (for example "grads") have no effect. The default
            list is never mutated.
        output_classes: Number of classifier outputs.
        downsample_factors: Forwarded to ``init_network`` and ``get_sgc``.

    Returns:
        Tuple ``(attrs_norm, attrs_names)``: the maps as NumPy arrays, each
        divided by its maximum absolute value, and their names in matching
        order.
    """
    imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)),
            image_to_tensor(normalize_image(fake_img).astype(np.float32))]
    classes = [real_class, fake_class]
    net = init_network(checkpoint_path, input_shape, net_module, channels,
                       output_classes=output_classes, eval_net=True, require_grad=False,
                       downsample_factors=downsample_factors)

    attrs = []
    attrs_names = []

    def _scam_maps():
        # Difference GradCAM maps (SCAM); shared by the "gc" and "ggc" branches.
        return get_sgc(real_img, fake_img, real_class, fake_class, net_module,
                       checkpoint_path, input_shape, channels, None,
                       output_classes=output_classes,
                       downsample_factors=downsample_factors)

    if "residual" in methods:
        res = np.abs(real_img - fake_img)
        res = res - np.min(res)
        attrs.append(torch.tensor(res/np.max(res)))
        attrs_names.append("residual")

    if "random" in methods:
        rand = np.abs(np.random.randn(*np.shape(real_img)))
        # scipy.ndimage.filters is a deprecated alias of scipy.ndimage.
        rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4))
        rand = rand - np.min(rand)
        rand = rand/np.max(np.abs(rand))
        attrs.append(torch.tensor(rand))
        attrs_names.append("random")

    if "gc" in methods:
        net.zero_grad()
        last_conv_layer = [(name, module) for name, module in net.named_modules()
                           if type(module) == torch.nn.Conv2d][-1]
        layer = last_conv_layer[1]
        layer_gc = LayerGradCam(net, layer)
        gc_real = layer_gc.attribute(imgs[0], target=classes[0])
        gc_fake = layer_gc.attribute(imgs[1], target=classes[1])

        gc_real = project_layer_activations_to_input_rescale(
            gc_real.cpu().detach().numpy(), (input_shape[0], input_shape[1]))
        gc_fake = project_layer_activations_to_input_rescale(
            gc_fake.cpu().detach().numpy(), (input_shape[0], input_shape[1]))

        attrs.append(torch.tensor(gc_real[0,0,:,:]))
        attrs_names.append("gc_real")
        attrs.append(torch.tensor(gc_fake[0,0,:,:]))
        attrs_names.append("gc_fake")

        # SCAM
        gc_diff_0, gc_diff_1 = _scam_maps()
        attrs.append(gc_diff_0)
        attrs_names.append("gc_diff_0")
        attrs.append(gc_diff_1)
        attrs_names.append("gc_diff_1")

    if "ggc" in methods:
        if "gc" not in methods:
            # The guided difference maps below need the SCAM maps; without
            # this, requesting "ggc" alone raised a NameError.
            gc_diff_0, gc_diff_1 = _scam_maps()

        net.zero_grad()
        last_conv = [module for module in net.modules()
                     if type(module) == torch.nn.Conv2d][-1]
        guided_gc = GuidedGradCam(net, last_conv)
        ggc_real = guided_gc.attribute(imgs[0], target=classes[0])
        ggc_fake = guided_gc.attribute(imgs[1], target=classes[1])

        attrs.append(ggc_real[0,0,:,:])
        attrs_names.append("ggc_real")
        attrs.append(ggc_fake[0,0,:,:])
        attrs_names.append("ggc_fake")

        net.zero_grad()
        gbp = GuidedBackprop(net)
        gbp_real = gbp.attribute(imgs[0], target=classes[0])
        gbp_fake = gbp.attribute(imgs[1], target=classes[1])

        attrs.append(gbp_real[0,0,:,:])
        attrs_names.append("gbp_real")
        attrs.append(gbp_fake[0,0,:,:])
        attrs_names.append("gbp_fake")

        ggc_diff_0 = gbp_real[0,0,:,:] * gc_diff_0
        ggc_diff_1 = gbp_fake[0,0,:,:] * gc_diff_1

        attrs.append(ggc_diff_0)
        attrs_names.append("ggc_diff_0")
        attrs.append(ggc_diff_1)
        attrs_names.append("ggc_diff_1")

    # IG
    if "ig" in methods:
        baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32))
        net.zero_grad()
        ig = IntegratedGradients(net)
        ig_real, delta_real = ig.attribute(imgs[0], baseline, target=classes[0], return_convergence_delta=True)
        ig_fake, delta_fake = ig.attribute(imgs[1], baseline, target=classes[1], return_convergence_delta=True)
        # The *_diff variants use the other image of the pair as baseline.
        ig_diff_0, delta_diff = ig.attribute(imgs[0], imgs[1], target=classes[0], return_convergence_delta=True)
        ig_diff_1, delta_diff = ig.attribute(imgs[1], imgs[0], target=classes[1], return_convergence_delta=True)

        attrs.append(ig_real[0,0,:,:])
        attrs_names.append("ig_real")
        attrs.append(ig_fake[0,0,:,:])
        attrs_names.append("ig_fake")
        attrs.append(ig_diff_0[0,0,:,:])
        attrs_names.append("ig_diff_0")
        attrs.append(ig_diff_1[0,0,:,:])
        attrs_names.append("ig_diff_1")

    # DL
    if "dl" in methods:
        net.zero_grad()
        dl = DeepLift(net)
        dl_real = dl.attribute(imgs[0], target=classes[0])
        dl_fake = dl.attribute(imgs[1], target=classes[1])
        dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0])
        dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1])

        attrs.append(dl_real[0,0,:,:])
        attrs_names.append("dl_real")
        attrs.append(dl_fake[0,0,:,:])
        attrs_names.append("dl_fake")
        attrs.append(dl_diff_0[0,0,:,:])
        attrs_names.append("dl_diff_0")
        attrs.append(dl_diff_1[0,0,:,:])
        attrs_names.append("dl_diff_1")

    # INGRAD
    if "ingrad" in methods:
        net.zero_grad()
        saliency = Saliency(net)
        grads_real = saliency.attribute(imgs[0], target=classes[0])
        grads_fake = saliency.attribute(imgs[1], target=classes[1])

        attrs.append(grads_real[0,0,:,:])
        attrs_names.append("grads_real")
        attrs.append(grads_fake[0,0,:,:])
        attrs_names.append("grads_fake")

        net.zero_grad()
        input_x_gradient = InputXGradient(net)
        ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0])
        ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1])

        ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1])
        ingrad_diff_1 = grads_real * (imgs[1] - imgs[0])

        attrs.append(torch.abs(ingrad_real[0,0,:,:]))
        attrs_names.append("ingrad_real")
        attrs.append(torch.abs(ingrad_fake[0,0,:,:]))
        attrs_names.append("ingrad_fake")
        attrs.append(torch.abs(ingrad_diff_0[0,0,:,:]))
        attrs_names.append("ingrad_diff_0")
        attrs.append(torch.abs(ingrad_diff_1[0,0,:,:]))
        attrs_names.append("ingrad_diff_1")

    attrs = [a.detach().cpu().numpy() for a in attrs]
    attrs_norm = [a/np.max(np.abs(a)) for a in attrs]

    return attrs_norm, attrs_names
# Reduce the guided-backprop attribution over dim 1 (presumably the colour
# channels - TODO confirm) with a maximum, then add a leading axis before
# handing it to the plotting helper.
attr_gbp, i = torch.max(attr_gbp, dim=1)
attr_gbp = attr_gbp.unsqueeze(0)
visualize_attr_maps('visualization/Captum_Guided_BackProp.png', X, y, class_names, attr_gbp, ['GuidedBackprop'],
                    lambda attr: attr.detach().numpy())

# Try out different layers
# and observe how the attributions change
layer = model.features[3]

# Raw activations of the chosen layer, averaged over dim 1.
layer_act = LayerActivation(model, layer)
layer_act_attr = compute_attributions(layer_act, X_tensor)
layer_act_attr_sum = layer_act_attr.mean(axis=1, keepdim=True)

# Layer gradcam aggregates across all channels
layer_gradcam = LayerGradCam(model, layer)
layer_gradcam_attr = compute_attributions(layer_gradcam, X_tensor, target=y_tensor, relu_attributions=True)
layer_gradcam_attr_sum = layer_gradcam_attr.mean(axis=1, keepdim=True)
# Swap the first two axes before plotting.
layer_gradcam_attr_sum = layer_gradcam_attr_sum.permute(1, 0, 2, 3)
visualize_attr_maps('visualization/layer_gradcam.png', X, y, class_names, layer_gradcam_attr_sum, ['layer_gradcam'],
                    lambda attr: attr.detach().numpy())

# Layer conductance for the same layer and targets, averaged over dim 1.
layer_conduct = LayerConductance(model, layer)
layer_conduct_attr = compute_attributions(layer_conduct, X_tensor, target=y_tensor)
layer_conduct_attr_sum = layer_conduct_attr.mean(axis=1, keepdim=True)
def train_single_scale(D, G, reals, generators, noise_maps, input_from_prev_scale, noise_amplitudes, opt):
    """ Train one scale. D and G are the current discriminator and generator, reals are the scaled versions of the
    original level, generators and noise_maps contain information from previous scales and will receive information in
    this scale, input_from_previous_scale holds the noise map and images from the previous scale, noise_amplitudes hold
    the amplitudes for the noise in all the scales. opt is a namespace that holds all necessary parameters.

    After training, GradCAM figures for D and G are written to the 'gradcam' folder and the networks are saved.

    Returns the tuple (z_opt, input_from_prev_scale, G, divergences).

    NOTE(review): `divergences`, `detector`, `real_detection_map` and `detection_scale` are only assigned when
    opt.cgan is true, yet several statements below use them unconditionally - confirm opt.cgan is always set.
    NOTE(review): the block nesting was reconstructed from a whitespace-collapsed source; in particular, verify that
    the final divergence and the GradCAM section run once after the epoch loop.
    """
    current_scale = len(generators)
    real = reals[current_scale]
    keepSky = False
    kernel_dims = (2, 2)

    # Initialize real detector
    real0 = preprocess(opt, real, keepSky)
    N, C, H, W = real0.shape
    scale = opt.scales[current_scale] if current_scale < len(opt.scales) else 1

    if opt.cgan:
        detector = PCA_Detector(opt, 'real', real0, kernel_dims)
        real_detection_map = detector(real0)
        detection_scale = 0.1
        real_detection_map *= detection_scale
        # The scaled detection map is appended to the level along dim 1.
        real1 = torch.cat(
            [real, F.interpolate(real_detection_map, (H, W))], dim=1)
        divergences = []
    else:
        real1 = real

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    nzx = real.shape[2]  # Noise size x
    nzy = real.shape[3]  # Noise size y

    padsize = int(
        1 * opt.num_layer
    )  # As kernel size is always 3 currently, padsize goes up by one per layer

    if not opt.pad_with_noise:
        pad_noise = nn.ZeroPad2d(padsize)
        pad_image = nn.ZeroPad2d(padsize)
    else:
        pad_noise = nn.ReflectionPad2d(padsize)
        pad_image = nn.ReflectionPad2d(padsize)

    # setup optimizer
    optimizerD = optim.Adam(D.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(G.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600, 2500], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600, 2500], gamma=opt.gamma)

    if current_scale == 0:  # Generate new noise
        z_opt = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
        z_opt = pad_noise(z_opt)
    else:  # Add noise to previous output
        z_opt = torch.zeros([1, opt.nc_current, nzx, nzy]).to(opt.device)
        z_opt = pad_noise(z_opt)

    logger.info("Training at scale {}", current_scale)
    for epoch in tqdm(range(opt.niter)):
        # Global step counter across scales, used for wandb logging.
        step = current_scale * opt.niter + epoch
        noise_ = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
        noise_ = pad_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            D.zero_grad()

            output = D(real1).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)

            # train with fake
            if (j == 0) & (epoch == 0):
                if current_scale == 0:  # If we are in the lowest scale, noise is generated from scratch
                    prev = torch.zeros(1, opt.nc_current, nzx, nzy).to(opt.device)
                    input_from_prev_scale = prev
                    prev = pad_image(prev)
                    z_prev = torch.zeros(1, opt.nc_current, nzx, nzy).to(opt.device)
                    z_prev = pad_noise(z_prev)
                    opt.noise_amp = 1
                else:  # First step in NOT the lowest scale
                    # We need to adapt our inputs from the previous scale and add noise to it
                    prev = draw_concat(generators, noise_maps, reals, noise_amplitudes, input_from_prev_scale,
                                       "rand", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        prev = group_to_token(prev, opt.token_list, token_group)

                    prev = interpolate(prev, real1.shape[-2:], mode="bilinear", align_corners=False)
                    prev = pad_image(prev)
                    z_prev = draw_concat(generators, noise_maps, reals, noise_amplitudes, input_from_prev_scale,
                                         "rec", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        z_prev = group_to_token(z_prev, opt.token_list, token_group)

                    z_prev = interpolate(z_prev, real1.shape[-2:], mode="bilinear", align_corners=False)
                    # real1[:, :-1] drops the last entry along dim 1 (the detection map when opt.cgan is set).
                    opt.noise_amp = update_noise_amplitude(
                        z_prev, real1[:, :-1], opt)
                    z_prev = pad_image(z_prev)
            else:  # Any other step
                prev = draw_concat(generators, noise_maps, reals, noise_amplitudes, input_from_prev_scale,
                                   "rand", pad_noise, pad_image, opt)

                # For the seeding experiment, we need to transform from token_groups to the actual token
                if current_scale == (opt.token_insert + 1):
                    prev = group_to_token(prev, opt.token_list, token_group)

                prev = interpolate(prev, real1.shape[-2:], mode="bilinear", align_corners=False)
                prev = pad_image(prev)

            # After creating our correct noise input, we feed it to the generator:
            noise = opt.noise_amp * noise_ + prev
            # NOTE(review): both branches of the conditional yield 1, so temperature is always 1.
            fake = G(noise.detach(), prev, temperature=1 if current_scale != opt.token_insert else 1)
            fake0 = preprocess(opt, fake, keepSky)

            if opt.cgan:
                Nf, Cf, Hf, Wf = fake0.shape
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat(
                    [fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            # Then run the result through the discriminator
            output = D(fake1.detach())
            errD_fake = output.mean()

            # Backpropagation
            errD_fake.backward(retain_graph=False)

            # Gradient Penalty
            gradient_penalty = calc_gradient_penalty(D, real1, fake1, opt.lambda_grad, opt.device)
            gradient_penalty.backward(retain_graph=False)

            # Logging:
            if step % 10 == 0:
                wandb.log(
                    {
                        f"D(G(z))@{current_scale}": errD_fake.item(),
                        f"D(x)@{current_scale}": -errD_real.item(),
                        f"gradient_penalty@{current_scale}": gradient_penalty.item()
                    },
                    step=step, sync=False)
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            G.zero_grad()
            fake = G(noise.detach(), prev.detach(), temperature=1 if current_scale != opt.token_insert else 1)
            fake0 = preprocess(opt, fake, keepSky)
            Nf, Cf, Hf, Wf = fake0.shape

            if opt.cgan:
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat(
                    [fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            output = D(fake1)

            errG = -output.mean()
            errG.backward(retain_graph=False)
            if opt.alpha != 0:  # i. e. we are trying to find an exact recreation of our input in the lat space
                Z_opt = opt.noise_amp * z_opt + z_prev
                G_rec = G(
                    Z_opt.detach(), z_prev,
                    temperature=1 if current_scale != opt.token_insert else 1)
                rec_loss = opt.alpha * F.mse_loss(G_rec, real)
                if opt.cgan:
                    # The divergence between the real detection map and the reconstruction is added to the loss.
                    div = divergence(real_detection_map, preprocess(opt, G_rec, keepSky))
                    rec_loss += div
                rec_loss.backward(
                    retain_graph=False
                )  # TODO: Check for unexpected argument retain_graph=True
                rec_loss = rec_loss.detach()
            else:  # We are not trying to find an exact recreation
                rec_loss = torch.zeros([])
                Z_opt = z_opt

            optimizerG.step()

        # More Logging:
        div = divergence(real_detection_map, preprocess(opt, fake, keepSky))
        divergences.append(div)
        # logger.info("divergence(fake) = {}", div)
        if step % 10 == 0:
            wandb.log(
                {
                    f"noise_amplitude@{current_scale}": opt.noise_amp,
                    f"rec_loss@{current_scale}": rec_loss.item()
                },
                step=step, sync=False, commit=True)

        # Rendering and logging images of levels
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if opt.token_insert >= 0 and opt.nc_current == len(token_group):
                token_list = [list(group.keys())[0] for group in token_group]
            else:
                token_list = opt.token_list

            img = opt.ImgGen.render(
                one_hot_to_ascii_level(fake1[:, :-1].detach(), token_list))
            img2 = opt.ImgGen.render(
                one_hot_to_ascii_level(
                    G(Z_opt.detach(), z_prev,
                      temperature=1 if current_scale != opt.token_insert else 1).detach(),
                    token_list))
            real_scaled = one_hot_to_ascii_level(real1[:, :-1].detach(), token_list)
            img3 = opt.ImgGen.render(real_scaled)
            wandb.log(
                {
                    f"G(z)@{current_scale}": wandb.Image(img),
                    f"G(z_opt)@{current_scale}": wandb.Image(img2),
                    f"real@{current_scale}": wandb.Image(img3)
                },
                sync=False, commit=False)

            real_scaled_path = os.path.join(wandb.run.dir, f"real@{current_scale}.txt")
            with open(real_scaled_path, "w") as f:
                f.writelines(real_scaled)
            wandb.save(real_scaled_path)

        # Learning Rate scheduler step
        schedulerD.step()
        schedulerG.step()

    if opt.cgan:
        div = divergence(real_detection_map, preprocess(opt, z_opt, keepSky))
        divergences.append(div)

    # visualization config
    folder_name = 'gradcam'
    # Strips the extension, then everything up to and including the first underscore of opt.input_name.
    level_name = opt.input_name.rsplit(".", 1)[0].split("_", 1)[1]

    # GradCAM on D
    camD = LayerGradCam(D, D.tail)
    # real0 is rebound here: from now on it is the rendered level image as an array.
    real0 = one_hot_to_ascii_level(real, opt.token_list)
    real0 = opt.ImgGen.render(real0)
    real0 = np.array(real0)
    attr = camD.attribute(real1, target=(0, 0, 0), relu_attributions=True)
    attr = LayerAttribution.interpolate(attr, (real0.shape[0], real0.shape[1]), 'bilinear')
    attr = attr.permute(2, 3, 1, 0).squeeze(3)
    attr = attr.detach().cpu().numpy()

    fig, ax = plt.subplots(1, 1)
    # NOTE(review): this only sets an attribute on the figure object; presumably it does not resize the
    # figure (fig.set_size_inches would) - confirm.
    fig.figsize = (10, 1)
    ax.imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
    im = ax.imshow(attr, cmap='jet', alpha=0.5)
    ax.axis('off')
    fig.colorbar(im, ax=ax, location='bottom', shrink=0.85)
    plt.suptitle(f'cGAN {level_name} D(x)@{current_scale} ({step})')
    # NOTE(review): the path uses a literal backslash separator and the folder is not created in this function.
    plt.savefig(rf'{folder_name}\{level_name}_D_{current_scale}_{step}.png', bbox_inches='tight', pad_inches=0.1)
    # plt.show()
    plt.close()

    # GradCAM on G
    # Display names for level tokens, used as row labels below.
    # NOTE(review): the keys '!' and 'k' appear twice (with identical values).
    token_names = {
        'M': 'Mario start',
        'F': 'Mario finish',
        'y': 'spiky',
        'Y': 'winged spiky',
        'k': 'green koopa',
        'K': 'winged green koopa',
        '!': 'coin [?]',
        '#': 'pyramid',
        '-': 'sky',
        '1': 'invis. 1 up',
        '2': 'invis. coin',
        'L': '1 up',
        '?': 'special [?]',
        '@': 'special [?]',
        'Q': 'coin [?]',
        '!': 'coin [?]',
        'C': 'coin brick',
        'S': 'normal brick',
        'U': 'mushroom brick',
        'X': 'ground',
        'E': 'goomba',
        'g': 'goomba',
        'k': 'green koopa',
        '%': 'platform',
        '|': 'platform bg',
        'r': 'red koopa',
        'R': 'winged red koopa',
        'o': 'coin',
        't': 'pipe',
        'T': 'plant pipe',
        '*': 'bullet bill',
        '<': 'pipe top left',
        '>': 'pipe top right',
        '[': 'pipe left',
        ']': 'pipe right',
        'B': 'bullet bill head',
        'b': 'bullet bill body',
        'D': 'used block',
    }

    # Single-argument wrapper so LayerGradCam can call G; z_opt is passed as G's second argument.
    def wrappedG(z):
        return G(z, z_opt)

    camG = LayerGradCam(wrappedG, G.tail[0])
    z_cam = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
    z_cam = pad_noise(z_cam)
    attrs = []
    # One attribution map per output channel index i, always for position (0, 0).
    for i in range(opt.nc_current):
        attr = camG.attribute(z_cam, target=(i, 0, 0), relu_attributions=True)
        attr = LayerAttribution.interpolate(attr, (real0.shape[0], real0.shape[1]), 'bilinear')
        attr = attr.permute(2, 3, 1, 0).squeeze(3)
        attr = attr.detach().cpu().numpy()
        attrs.append(attr)

    fig, axs = plt.subplots(opt.nc_current, 1)
    fig.figsize = (10, opt.nc_current)
    for i in range(opt.nc_current):
        axs[i].axis('off')
        axs[i].text(-0.1, 0.5, token_names[opt.token_list[i]], rotation=0, verticalalignment='center',
                    horizontalalignment='right', transform=axs[i].transAxes)
        axs[i].imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
        im = axs[i].imshow(attrs[i], cmap='jet', alpha=0.5)
    fig.colorbar(im, ax=axs, shrink=0.85)
    plt.suptitle(f'cGAN {level_name} G(z)@{current_scale} ({step})')
    plt.savefig(rf'{folder_name}\{level_name}_G_{current_scale}_{step}.png', bbox_inches='tight', pad_inches=0.1)
    # plt.show()
    plt.close()

    # Save networks
    torch.save(z_opt, "%s/z_opt.pth" % opt.outf)
    save_networks(G, D, z_opt, opt)
    wandb.save(opt.outf)
    return z_opt, input_from_prev_scale, G, divergences
def LayerGradCAM(classifier_model,
                 config,
                 dataset_features,
                 GNNgraph_list,
                 current_fold=None,
                 cuda=0):
    '''
    Compute Layer Grad-CAM attributions for every graph and every class label,
    then map them back to the input layer through the model's assignment
    tensors (hard or soft assign, as configured).

    :param classifier_model: trained classifier model; this function reads its
        ``graph_convolution``, ``conv_modules`` and ``cur_assign_tensor_list``
    :param config: parsed configuration file of config.yml
    :param dataset_features: a dictionary of dataset features obtained from load_data.py
    :param GNNgraph_list: a list of GNNgraphs obtained from the dataset
    :param current_fold: has no use in this method
    :param cuda: whether to use GPU to perform conversion to tensor
    :return: tuple of (list with one ``{'graph': graph, label: scores, ...}``
        dict per graph, dict of saliency-map samples keyed by
        ``layergradcam_<assign>_<label>``, mean attribution time in seconds;
        0.0 when nothing was attributed)
    '''
    # Initialise settings
    interpretability_config = config["interpretability_methods"][
        "LayerGradCAM"]
    assign_type = interpretability_config["assign_attribution"]
    labels = list(dataset_features["label_dict"].values())

    def _map_key(label):
        # Key under which saliency-map samples of one class are collected.
        return "layergradcam_%s_%s" % (str(assign_type), str(label))

    # Perform grad cam on the classifier model and on a specific layer
    layer_idx = interpretability_config["layer"]
    if layer_idx == 0:
        target_layer = classifier_model.graph_convolution
    else:
        target_layer = classifier_model.conv_modules[layer_idx - 1]
    gc = LayerGradCam(classifier_model, target_layer)

    output_for_metrics_calculation = []
    output_for_generating_saliency_map = {}

    # Obtain attribution score for use in qualitative metrics
    tmp_timing_list = []
    for GNNgraph in GNNgraph_list:
        output = {'graph': GNNgraph}
        original_label = GNNgraph.label
        try:
            for label in labels:
                # Relabel all just in case, may only relabel those that need
                # relabelling if performance is poor
                GNNgraph.label = label

                node_feat, n2n, subg = graph_to_tensor(
                    [GNNgraph], dataset_features["feat_dim"],
                    dataset_features["edge_feat_dim"], cuda)

                start_generation = perf_counter()
                attribution = gc.attribute(
                    node_feat,
                    additional_forward_args=(n2n, subg, [GNNgraph]),
                    target=label,
                    relu_attributions=True)

                # Attribute to the input layer using the assign method
                # specified. The assignment tensors are read only after the
                # attribute() call above has run the model.
                reverse_assign_tensor_list = []
                for i in range(1, layer_idx + 1):
                    assign_tensor = classifier_model.cur_assign_tensor_list[
                        i - 1]
                    if assign_type == "hard":
                        max_index = torch.argmax(assign_tensor,
                                                 dim=1,
                                                 keepdim=True)
                        # zeros_like keeps the one-hot matrix on the same
                        # device as the assignment tensor (a plain
                        # torch.zeros lives on the CPU and breaks GPU runs).
                        one_hot = torch.zeros_like(assign_tensor).scatter_(
                            1, max_index, value=1)
                        reverse_assign_tensor = torch.transpose(one_hot, 0, 1)
                    else:
                        reverse_assign_tensor = torch.transpose(
                            assign_tensor, 0, 1)
                    reverse_assign_tensor_list.append(reverse_assign_tensor)

                attribution = torch.transpose(attribution, 0, 1)
                for reverse_tensor in reversed(reverse_assign_tensor_list):
                    attribution = attribution @ reverse_tensor
                attribution = torch.transpose(attribution, 0, 1)
                tmp_timing_list.append(perf_counter() - start_generation)

                attribution_score = torch.sum(attribution, dim=1).tolist()
                output[label] = standardize_scores(attribution_score)
        finally:
            # Always restore the true label, even if attribution fails.
            GNNgraph.label = original_label
        output_for_metrics_calculation.append(output)

    # Guard against an empty graph list / label dict.
    if tmp_timing_list:
        execution_time = sum(tmp_timing_list) / len(tmp_timing_list)
    else:
        execution_time = 0.0

    # Obtain attribution score for use in generating saliency map for
    # comparison with zero tensors
    sample_ids = interpretability_config["sample_ids"]
    if sample_ids is not None:
        # Accept a list/tuple of ids, a comma-separated string, or a single id.
        if isinstance(sample_ids, (list, tuple)):
            sample_graph_id_list = [int(s) for s in sample_ids]
        elif ',' in str(sample_ids):
            sample_graph_id_list = [
                int(s) for s in str(sample_ids).split(',')
            ]
        else:
            sample_graph_id_list = [int(sample_ids)]

        output_for_generating_saliency_map.update(
            {_map_key(label): [] for label in labels})

        for tmp_output in output_for_metrics_calculation:
            graph = tmp_output['graph']
            if graph.graph_id in sample_graph_id_list:
                output_for_generating_saliency_map[_map_key(
                    graph.label)].append((graph, tmp_output[graph.label]))

    elif interpretability_config["number_of_samples"] > 0:
        # Randomly sample from existing list:
        graph_idxes = list(range(len(output_for_metrics_calculation)))
        random.shuffle(graph_idxes)
        output_for_generating_saliency_map.update(
            {_map_key(label): [] for label in labels})

        # Begin appending found samples, capped per class
        max_samples = interpretability_config["number_of_samples"]
        for index in graph_idxes:
            tmp_output = output_for_metrics_calculation[index]
            graph = tmp_output['graph']
            bucket = output_for_generating_saliency_map[_map_key(graph.label)]
            if len(bucket) < max_samples:
                bucket.append((graph, tmp_output[graph.label]))

    return (output_for_metrics_calculation,
            output_for_generating_saliency_map, execution_time)
process_sequence(gradcam, test_loader, model_type, device, output_file)

#%% captum grad cam
# Run the Captum Grad-CAM heatmap export once per held-out test set.
experiment_num = 4
#test_list_full = [4,11,18,22,23,24,25,26,27,28,29,32,33,34,36,37,38,39]
test_list_full = [32, 34]
train_list = [1, 3, 5, 7, 8, 10, 12, 14, 15, 17, 19, 21]

for test in test_list_full:
    test_list = [test]
    # The training list is passed for both of the first two arguments.
    loader_dict, loader_sizes = dat.init_dataset(
        train_list, train_list, test_list, model_type, config_dict)
    test_loader = loader_dict['test']

    captum_gc = LayerGradCam(model, target_layer)
    output_file = f"{file_dir}/heatmaps/imageset_captum_new_{model_type}_{test}"
    process_sequence_captum(captum_gc, test_loader, model_type, device,
                            output_file)

#%%
# Single-batch experiment: Grad-CAM on the CNN's 'layer4' for target index 0.
finalconv_name = 'layer4'
target_layer = model.cnn._modules.get(finalconv_name)

inputs, inputs_aug, _ = next(iter(loader_dict['train']))
captum_gc = LayerGradCam(model, target_layer)
attributions = captum_gc.attribute(
    (
        inputs.to(device, dtype=torch.float),
        inputs_aug.to(device, dtype=torch.float),
    ),
    target=0,
    relu_attributions=True,
)
upsampled_attr = captum_gc.interpolate(attributions, (224, 224), 'bicubic')
32,33, 34,36, 37,38,39, 45,46,47] condition_list = ['center','right','left', 'right_less','right_less', 'right_more','right_more', 'left_less','left_less', 'left_more','left_more', 'new_tool','new_tool', 'new_material','new_material', 'center','right','left', 'z_mid','z_high','z_low'] ''' test_list_full = [34] train_list = [1, 3, 5, 7, 8, 10, 12, 14, 15, 17, 19, 21, 41, 42] for test in test_list_full: test_list = [test] loader_dict, loader_sizes = dat.init_dataset(train_list, train_list, test_list, model_type, config_dict) test_loader = loader_dict['test'] captum_gc = LayerGradCam(model, target_layer) output_file = 'heatmaps/imageset_captum' + model_type + "_" + str(test) process_sequence_captum(captum_gc, test_loader, model_type, device, output_file)
# Grad-CAM of the centre output pixel (channel 1) w.r.t. every "up" layer.
model.to(DEVICE)
img, img_t = load_preprocess_image(args.img, gray=args.gray)
img_t = img_t.to(DEVICE)
img_t.requires_grad = True
input_shape = img.shape[:2]

# Centre pixel taken from the tensor's trailing (spatial) dimensions; the
# leading dimensions are not spatial once a batch/channel axis is present.
# NOTE(review): assumes the model output has the same spatial size as img_t.
cent_idx = (img_t.shape[-2] // 2, img_t.shape[-1] // 2)

# Grab attribution mapping: one Grad-CAM object per child whose name
# contains "up", each bound to *its own* layer.
attr_map_ = {}
for name, layer in model.named_children():
    if "up" in name:
        attr_map_[name] = LayerGradCam(model.forward, layer)

data_attr_ = {}
for name, act in attr_map_.items():
    data_attr_[name] = act.attribute(img_t, target=(1, ) + cent_idx)

## Plot attributions
fig: plt.Figure = plt.figure(figsize=(12, 4))
# Size the grid to the number of layers found instead of assuming three.
n_cols = max(len(data_attr_), 1)
for k, (layer_name, attrib) in enumerate(data_attr_.items()):
    plt.subplot(1, n_cols, k + 1)
    # First sample, first channel of the attribution map.
    attr_arr = attrib.detach().cpu().numpy()[0, 0]
    plt.imshow(attr_arr)
    plt.title("GradCam center px %s" % layer_name)
fig.tight_layout()
plt.show()
def vis_explanation(self, number):
    """Save Grad-CAM visualisations of one cached test sample under ``./pic``.

    The first batch of ``self.test_loader`` is cached in ``self.explainVis``
    on first use. Four files are written, all prefixed with ``number``:
    ``NotEQPred.jpg`` (attribution for the predicted label),
    ``NotEQLabel.jpg`` (attribution for the ground-truth label),
    ``NotEQOrig.jpg`` (the input image) and ``NotEQ3D.jpg`` (surface plot of
    the predicted-label attribution).

    Args:
        number: value used as the file-name prefix.
    """
    if len(self.explainVis) == 0:
        # Cache the first test batch.
        for batch in self.test_loader:
            self.explainVis = batch
            break

    map_size = (64, 64)
    layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)

    lb = self.explainVis[1].to(device)
    img = self.explainVis[0].to(device)
    pred = self.model(img)
    predlb = torch.argmax(pred, 1)
    imgCQ = img.clone()

    # Attribution w.r.t. the predicted label, then w.r.t. the true label.
    gc_attr = layer_gc.attribute(imgCQ,
                                 target=predlb,
                                 relu_attributions=False)
    upsampled_attr = LayerAttribution.interpolate(gc_attr, map_size)
    gc_attr = layer_gc.attribute(imgCQ, target=lb, relu_attributions=False)
    upsampled_attrB = LayerAttribution.interpolate(gc_attr, map_size)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./pic', exist_ok=True)

    # Sample 7 is shown as before; clamp so smaller batches do not crash.
    sample_idx = min(7, img.size(0) - 1)
    original_image = img[sample_idx].detach().cpu().numpy().transpose(
        [1, 2, 0])

    ####Plot################################################
    overlays = (
        (upsampled_attr, 0.2, str(predlb[sample_idx]), 'NotEQPred.jpg'),
        (upsampled_attrB, 0.9, str(lb[sample_idx].cpu()), 'NotEQLabel.jpg'),
    )
    for attr, alpha, title, suffix in overlays:
        plotMe = viz.visualize_image_attr(
            attr[sample_idx].detach().cpu().numpy().transpose([1, 2, 0]),
            original_image=original_image,
            method='heat_map',
            sign='all',
            plt_fig_axis=None,
            outlier_perc=2,
            cmap='inferno',
            alpha_overlay=alpha,
            show_colorbar=True,
            title=title,
            fig_size=(8, 10),
            use_pyplot=True)
        plotMe[0].savefig('./pic/' + str(number) + suffix)
        # Close each figure once saved so repeated calls do not pile up
        # open matplotlib figures.
        plt.close(plotMe[0])

    ################################################
    outImg = img[sample_idx].squeeze().detach().cpu().numpy()
    fig2 = plt.figure(figsize=(12, 12))
    plt.imshow(outImg)
    fig2.savefig('./pic/' + str(number) + 'NotEQOrig.jpg')
    plt.close(fig2)

    ################################################
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection='3d')
    z = upsampled_attr[sample_idx].squeeze().detach().cpu().numpy()
    x = np.arange(0, map_size[1], 1)
    y = np.arange(0, map_size[0], 1)
    X, Y = np.meshgrid(x, y)
    plll = ax.plot_surface(X, Y, z, cmap=cm.coolwarm)
    # Customize the z axis (fixed range).
    ax.set_zlim(-0.02, 0.1)
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
    # Add a color bar which maps values to colors.
    fig.colorbar(plll, shrink=0.5, aspect=5)
    fig.savefig('./pic/' + str(number) + 'NotEQ3D.jpg')
    plt.close(fig)
def train_known_expl(self, n_epochs, ite, max_ite):
    """Train ``self.model`` on ``self.KnownLoader`` with an explanation loss.

    For every batch, Grad-CAM attributions of ``layer2[1].conv2`` for the
    predicted labels are upsampled to 64x64 and compared with
    ``self.sintetic`` using binary cross-entropy; only that loss is
    back-propagated. The classification loss is computed for logging only.
    After training, the weights are saved to ``./checkpoint/resnet18.pth``.

    Args:
        n_epochs: number of passes over ``self.KnownLoader``.
        ite: current outer iteration (0-based); used only in the progress log.
        max_ite: total number of outer iterations; used only in the log.
    """
    self.model.train()
    avg_loss = None

    optim = torch.optim.Adam(self.model.parameters(),
                             lr=0.001,
                             weight_decay=0)
    criteria1 = nn.CrossEntropyLoss()
    criteria2 = nn.BCELoss()
    layer_gc = LayerGradCam(self.model, self.model.layer2[1].conv2)

    print('Init known training with explanation...')
    number = 0  # batches processed so far; drives the logging cadence
    for epoch in range(n_epochs):
        for i, batch in enumerate(self.KnownLoader):
            lb = batch[1].to(device)
            img = batch[0].to(device)

            # Training
            out = self.model(img)
            predlb = torch.argmax(out, 1)

            # Grad-CAM for the predicted label (used in the loss) ...
            gc_attr = layer_gc.attribute(img,
                                         target=predlb,
                                         relu_attributions=True)
            upsampled_attr = LayerAttribution.interpolate(gc_attr, (64, 64))
            # ... and for the true label. Its result is not used here, but
            # the call is kept in place so the sequence of model passes is
            # the same as before and the debug print below follows it.
            gc_attr = layer_gc.attribute(img,
                                         target=lb,
                                         relu_attributions=True)
            LayerAttribution.interpolate(gc_attr, (64, 64))

            print(self.model.layer2[1].conv2.grads.data)

            size_batch = upsampled_attr.size()[0]
            losssemantic = criteria2(
                upsampled_attr[:size_batch, :, :, :],
                self.sintetic[:size_batch, :, :, :].to(device))
            loss1 = criteria1(out, lb)
            # Only the explanation loss is optimised; loss1 is logged below.
            lossall = losssemantic.to(device)

            if number % 30 == 0:
                print(loss1, losssemantic)

            # Detach so the logged value does not keep the graph alive.
            avg_loss = lossall.detach().mean()

            # Clear stale gradients right before the backward pass.
            optim.zero_grad()
            lossall.backward()
            optim.step()
            number += 1

            if number % 10 == 0:
                print(
                    "Epoch: {}/{} batch: {}/{} iteration: {}/{} average-loss: {:0.4f}"
                    .format(epoch + 1, n_epochs, i + 1, len(self.KnownLoader),
                            ite + 1, max_ite, avg_loss.cpu()))

    # Save checkpoint; the directory is created if missing.
    os.makedirs("./checkpoint", exist_ok=True)
    torch.save(self.model.state_dict(), "./checkpoint/resnet18.pth")
def attribute(self, img, target, **kwargs):
    """Return attributions upsampled to the input's spatial size.

    Calls ``CaptumDerivative.attribute`` and resizes the result to the last
    two dimensions of ``img`` with ``LayerGradCam.interpolate``.

    Args:
        img: input tensor; its last two dimensions set the output size.
        target: target index/indices, converted to a tensor.
        **kwargs: accepted for interface compatibility; not used.
    """
    # Put the target on the input's device instead of forcing CUDA, so the
    # method also works when the input lives on the CPU or another device.
    target = torch.as_tensor(target, device=img.device)
    return LayerGradCam.interpolate(
        CaptumDerivative.attribute(self, img, target), img.shape[-2:])