def get_heatmap(self, _input, _label, cmap=jet):
    """Return a gradient-based saliency heatmap of ``_input`` w.r.t. ``_label``.

    The gradient of the summed target-class outputs with respect to the input
    is computed, negative gradients are discarded, and the per-pixel maximum
    over channels is min-max normalized per sample before colormapping.

    Args:
        _input: Input tensor of shape ``([N], C, H, W)``. A 3-D tensor is
            treated as a single sample.
        _label: Target class index. A single ``int`` (or 0-dim tensor) is
            applied to every sample; otherwise one label per sample.
        cmap: Colormap forwarded to :func:`trojanvision.utils.apply_cmap`.

    Returns:
        The result of ``apply_cmap`` on the CPU heatmap. The leading batch
        dimension is dropped again when ``_input`` was 3-D.
    """
    from trojanvision.utils import apply_cmap

    squeeze_flag = _input.dim() == 3
    if squeeze_flag:
        _input = _input.unsqueeze(0)  # (N, C, H, W)
    _label = torch.as_tensor(_label, device=_input.device)
    if _label.dim() == 0:
        # broadcast a single label to the whole batch
        _label = _label.expand(len(_input))

    # Differentiate w.r.t. a detached leaf so the caller's tensor keeps its
    # requires_grad flag (and non-leaf inputs do not raise), and force grad
    # mode so this also works when called inside torch.no_grad().
    inp = _input.detach().requires_grad_()
    with torch.enable_grad():
        _output = self(inp).gather(dim=1, index=_label.unsqueeze(1)).sum()
        grad = torch.autograd.grad(_output, inp)[0]  # (N, C, H, W)

    # keep positive gradients only, then take the channel-wise maximum
    heatmap = grad.clamp(min=0).amax(dim=1)  # (N, H, W)
    heatmap = heatmap - heatmap.amin(dim=(-2, -1), keepdim=True)
    peak = heatmap.amax(dim=(-2, -1), keepdim=True)
    # A constant map (e.g. all-zero gradients) has peak 0; clamping the
    # denominator yields an all-zero map instead of 0 / 0 = NaN.
    heatmap = heatmap / peak.clamp(min=torch.finfo(heatmap.dtype).tiny)
    heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
    return heatmap[0] if squeeze_flag else heatmap
def get_heatmap(self, _input: torch.Tensor, _label: torch.Tensor,
                method: str = 'grad_cam', cmap: Colormap = jet) -> torch.Tensor:
    """Colormap a heatmap of ``_input`` w.r.t. ``_label`` computed by ``method``.

    Args:
        _input (torch.Tensor): The (batched) input tensor with shape
            ``([N], C, H, W)``. A 3-D tensor is treated as a single sample.
        _label (torch.Tensor): The (batched) label tensor with shape
            ``([N])``, or a single class index shared by every sample.
        method (str): ``'grad_cam'`` (default) or ``'saliency_map'``.
        cmap (Colormap): The colormap forwarded to ``apply_cmap``.

    Returns:
        torch.Tensor: The colormapped heatmap; the batch dimension is dropped
        again when ``_input`` was 3-D.

    Raises:
        NotImplementedError: If ``method`` is not supported.
    """
    squeeze_flag = _input.dim() == 3
    if squeeze_flag:
        _input = _input.unsqueeze(0)  # (N, C, H, W)
    _label = torch.as_tensor(_label, device=_input.device)
    if _label.dim() == 0:
        # broadcast a single label to the whole batch
        _label = _label.expand(len(_input))
    index = _label.unsqueeze(1)  # (N, 1), used by ``gather``

    if method == 'grad_cam':
        feats = self._model.get_fm(_input).detach()  # (N, C', H', W')
        feats.requires_grad_()
        # force grad mode so this also works inside torch.no_grad()
        with torch.enable_grad():
            _output: torch.Tensor = self._model.pool(feats)  # (N, C', 1, 1)
            _output = self._model.flatten(_output)  # (N, C')
            _output = self._model.classifier(_output)  # (N, num_classes)
            _output = _output.gather(dim=1, index=index).sum()
            grad = torch.autograd.grad(_output, feats)[0]  # (N, C', H', W')
        feats.requires_grad_(False)
        # channel weights: spatially averaged gradients
        weights = grad.mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True)  # (N, C', 1, 1)
        heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(0)  # (N, 1, H', W')
        # clamp(0) already makes values non-negative, so only rescale by the
        # per-sample peak; an all-zero map stays zero instead of 0 / 0 = NaN.
        peak = heatmap.amax(dim=(-2, -1), keepdim=True)
        heatmap = heatmap / peak.clamp(min=torch.finfo(heatmap.dtype).tiny)
        # F.upsample is deprecated; align_corners=False matches its default
        heatmap = F.interpolate(heatmap, size=_input.shape[-2:], mode='bilinear',
                                align_corners=False)[:, 0]  # (N, H, W)
    elif method == 'saliency_map':
        # Detached leaf: the caller's tensor keeps its requires_grad flag.
        inp = _input.detach().requires_grad_()
        with torch.enable_grad():
            _output = self(inp).gather(dim=1, index=index).sum()
            grad = torch.autograd.grad(_output, inp)[0]  # (N, C, H, W)
        heatmap = grad.abs().amax(dim=1)  # (N, H, W)
        heatmap = heatmap - heatmap.amin(dim=(-2, -1), keepdim=True)
        peak = heatmap.amax(dim=(-2, -1), keepdim=True)
        # constant map -> all zeros instead of 0 / 0 = NaN
        heatmap = heatmap / peak.clamp(min=torch.finfo(heatmap.dtype).tiny)
    else:
        raise NotImplementedError(f'{method=} is not supported yet.')
    heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
    return heatmap[0] if squeeze_flag else heatmap
def get_heatmap(self, _input: torch.Tensor, _label: torch.Tensor, method: str = 'grad_cam', cmap: Colormap = jet, mode: str = 'bicubic') -> torch.Tensor: r"""Use colormap :attr:`cmap` to get heatmap tensor of :attr:`_input` w.r.t. :attr:`_label` with :attr:`method`. Args: _input (torch.Tensor): The (batched) input tensor with shape ``([N], C, H, W)``. _label (torch.Tensor): The (batched) label tensor with shape ``([N])`` method (str): The method to calculate heatmap. Choose from ``['grad_cam', 'saliency_map']``. Defaults to ``'grad_cam'``. cmap (matplotlib.colors.Colormap): The colormap to use. mode (str): Passed to :any:`torch.nn.functional.interpolate`. Defaults to ``'bicubic'``. Returns: torch.Tensor: The heatmap tensor with shape ([N], C, H, W). Note: Most :any:`matplotlib.colors.Colormap` will return a 4-channel heatmap with alpha channel. See Also: https://keras.io/examples/vision/grad_cam/ :Example: .. code-block:: python :emphasize-lines: 30-32 import trojanvision from trojanvision.utils import superimpose import torchvision.transforms as transforms import torchvision.transforms.functional as F import PIL.Image as Image import os import wget env = trojanvision.environ.create(device='cpu') model = trojanvision.models.create( 'resnet152', data_shape=[3, 224, 224], official=True, norm_par={'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}) transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float)]) url = 'https://i.imgur.com/Bvro0YD.png' if not os.path.isfile('african_elephant.png'): wget.download(url, 'african_elephant.png') img = Image.open('african_elephant.png').convert(mode='RGB') _input = transform(img).unsqueeze(0).to(env['device']) _prob = model.get_prob(_input).squeeze() label = _prob.argmax().item() conf = _prob[label].item() print(f'{label=:} {conf=:.2%}') grad_cam = model.get_heatmap(_input, label)[:, :3] saliency_map = 
model.get_heatmap(_input, label, method='saliency_map')[:, :3] grad_cam_impose = (grad_cam * 0.4 + _input) saliency_map_impose = (saliency_map * 0.4 + _input) grad_cam_impose = grad_cam_impose.div(grad_cam_impose.max()) saliency_map_impose = saliency_map_impose.div(saliency_map_impose.max()) F.to_pil_image(_input).save('./center_cropped.png') F.to_pil_image(grad_cam).save('./grad_cam.png') F.to_pil_image(saliency_map).save('./saliency_map.png') F.to_pil_image(grad_cam_impose).save('./grad_cam_impose.png') F.to_pil_image(saliency_map_impose).save('./saliency_map_impose.png') ``label=386 conf=77.74%`` .. table:: :align: left :widths: 80, 160, 160 +-----------------------+------------------+------------------------+ | original | |original| | |center_cropped| | +-----------------------+------------------+------------------------+ | grad_cam | |grad_cam| | |grad_cam_impose| | +-----------------------+------------------+------------------------+ | saliency_map | |saliency_map| | |saliency_map_impose| | +-----------------------+------------------+------------------------+ .. |original| image:: https://i.imgur.com/Bvro0YD.png :width: 224px .. |center_cropped| image:: /images/trojanvision/center_cropped.png .. |grad_cam| image:: /images/trojanvision/grad_cam.png .. |grad_cam_impose| image:: /images/trojanvision/grad_cam_impose.png .. |saliency_map| image:: /images/trojanvision/saliency_map.png .. 
|saliency_map_impose| image:: /images/trojanvision/saliency_map_impose.png """ squeeze_flag = False if _input.dim() == 3: _input = _input.unsqueeze(0) # (N, C, H, W) squeeze_flag = True if isinstance(_label, int): _label = [_label] * len(_input) _label = torch.as_tensor(_label, device=_input.device) heatmap = _input # linting purpose match method: case 'grad_cam': feats = self._model.get_fm(_input).detach() # (N, C', H', W') feats.requires_grad_() _output: torch.Tensor = self._model.pool(feats) # (N, C', 1, 1) _output = self._model.flatten(_output) # (N, C') _output = self._model.classifier(_output) # (N, num_classes) _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum() grad = torch.autograd.grad(_output, feats)[0] # (N, C', H', W') feats.requires_grad_(False) weights = grad.mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True) # (N, C', 1, 1) heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(0) # (N, 1, H', W') # heatmap.sub_(heatmap.amin(dim=-2, keepdim=True).amin(dim=-1, keepdim=True)) heatmap.div_(heatmap.amax(dim=-2, keepdim=True).amax(dim=-1, keepdim=True)) heatmap: torch.Tensor = F.interpolate(heatmap, _input.shape[-2:], mode=mode)[:, 0] # (N, H, W) # Note that we violate the image order convension (W, H, C) case 'saliency_map': _input.requires_grad_() _output = self(_input).gather(dim=1, index=_label.unsqueeze(1)).sum() grad = torch.autograd.grad(_output, _input)[0] # (N, C, H, W) _input.requires_grad_(False) heatmap = grad.abs().amax(dim=1) # (N, H, W) heatmap.sub_(heatmap.amin(dim=-2, keepdim=True).amin(dim=-1, keepdim=True)) heatmap.div_(heatmap.amax(dim=-2, keepdim=True).amax(dim=-1, keepdim=True)) case _: raise NotImplementedError(f'{method=} is not supported yet.') heatmap = apply_cmap(heatmap.detach().cpu(), cmap) return heatmap[0] if squeeze_flag else heatmap