Beispiel #1
0
    def get_heatmap(self, _input, _label, cmap=jet):
        r"""Compute a colormapped saliency heatmap of :attr:`_input` w.r.t. :attr:`_label`.

        The heatmap is the per-pixel maximum over channels of the *positive*
        part of the gradient of the summed target-class outputs w.r.t. the
        input. Each sample is min-max normalized and then passed through
        :attr:`cmap`.

        Args:
            _input (torch.Tensor): Input tensor with shape ``([N], C, H, W)``.
            _label (int | Sequence[int] | torch.Tensor): Target label(s).
                A scalar label is used for every sample in the batch.
            cmap: Colormap forwarded to ``trojanvision.utils.apply_cmap``.

        Returns:
            torch.Tensor: The colormapped heatmap on CPU. The batch dimension
            is dropped again when :attr:`_input` was unbatched (3-D).

        Note:
            A constant heatmap (e.g. all gradients non-positive) is returned
            as all zeros instead of NaN.
        """
        from trojanvision.utils import apply_cmap

        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)  # (N, C, H, W)
            squeeze_flag = True
        _label = torch.as_tensor(_label, dtype=torch.long, device=_input.device)
        if _label.dim() == 0:
            # A scalar label (int, numpy int, 0-dim tensor) is shared by every sample.
            _label = _label.repeat(len(_input))

        # Differentiate w.r.t. a detached copy so the caller's tensor keeps its
        # ``requires_grad`` flag and non-leaf inputs are accepted.
        inp = _input.detach().requires_grad_()
        _output = self(inp).gather(dim=1, index=_label.unsqueeze(1)).sum()
        grad = torch.autograd.grad(_output, inp)[0]  # (N, C, H, W)
        grad = grad.clamp(min=0)  # keep only positive contributions

        heatmap = grad.max(dim=1)[0]  # (N, H, W)
        heatmap = heatmap - heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0]
        peak = heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0]
        # A constant map has peak == 0 and would become 0 / 0 = NaN;
        # dividing by 1 there leaves it all zeros instead.
        heatmap = heatmap / peak.masked_fill(peak == 0, 1.0)

        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap
Beispiel #2
0
    def get_heatmap(self, _input: torch.Tensor, _label: torch.Tensor, method: str = 'grad_cam', cmap: Colormap = jet) -> torch.Tensor:
        r"""Compute a colormapped heatmap of :attr:`_input` w.r.t. :attr:`_label`.

        Args:
            _input (torch.Tensor): Input tensor with shape ``([N], C, H, W)``.
            _label (torch.Tensor | int | Sequence[int]): Target label(s).
                A scalar label is used for every sample in the batch.
            method (str): ``'grad_cam'`` or ``'saliency_map'``.
                Defaults to ``'grad_cam'``.
            cmap (Colormap): Colormap forwarded to ``apply_cmap``.

        Returns:
            torch.Tensor: The colormapped heatmap on CPU. The batch dimension
            is dropped again when :attr:`_input` was unbatched (3-D).

        Raises:
            NotImplementedError: If :attr:`method` is not supported.

        Note:
            A constant heatmap (e.g. all-zero) is returned as zeros instead of NaN.
        """
        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)    # (N, C, H, W)
            squeeze_flag = True
        _label = torch.as_tensor(_label, dtype=torch.long, device=_input.device)
        if _label.dim() == 0:
            # A scalar label (int, numpy int, 0-dim tensor) is shared by every sample.
            _label = _label.repeat(len(_input))

        if method == 'grad_cam':
            # assumes self._model exposes get_fm / pool / flatten / classifier
            feats = self._model.get_fm(_input).detach()   # (N, C', H', W')
            feats.requires_grad_()
            _output = self._model.pool(feats)   # (N, C', 1, 1)
            _output = self._model.flatten(_output)   # (N, C')
            _output = self._model.classifier(_output)   # (N, num_classes)
            _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, feats)[0]   # (N, C', H', W')
            feats.requires_grad_(False)
            # Channel weights are the spatially averaged gradients.
            weights = grad.mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True)    # (N, C', 1, 1)
            heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(0)  # (N, 1, H', W')
            # Scale by the per-sample maximum only (no min subtraction).
            peak = heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0]
            # An all-zero map would otherwise become 0 / 0 = NaN.
            heatmap = heatmap / peak.masked_fill(peak == 0, 1.0)
            # Same result as the deprecated F.upsample(..., mode='bilinear').
            heatmap = F.interpolate(heatmap, size=_input.shape[-2:], mode='bilinear',
                                    align_corners=False)[:, 0]   # (N, H, W)
        elif method == 'saliency_map':
            # Differentiate w.r.t. a detached copy so the caller's tensor keeps its
            # ``requires_grad`` flag and non-leaf inputs are accepted.
            inp = _input.detach().requires_grad_()
            _output = self(inp).gather(dim=1, index=_label.unsqueeze(1)).sum()
            grad = torch.autograd.grad(_output, inp)[0]   # (N, C, H, W)

            heatmap = grad.abs().max(dim=1)[0]   # (N, H, W)
            heatmap = heatmap - heatmap.min(dim=-2, keepdim=True)[0].min(dim=-1, keepdim=True)[0]
            peak = heatmap.max(dim=-2, keepdim=True)[0].max(dim=-1, keepdim=True)[0]
            # A constant map would otherwise become 0 / 0 = NaN.
            heatmap = heatmap / peak.masked_fill(peak == 0, 1.0)
        else:
            raise NotImplementedError(f'method={method!r} is not supported yet.')
        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap
Beispiel #3
0
    def get_heatmap(self, _input: torch.Tensor, _label: torch.Tensor,
                    method: str = 'grad_cam', cmap: Colormap = jet,
                    mode: str = 'bicubic') -> torch.Tensor:
        r"""Use colormap :attr:`cmap` to get heatmap tensor of :attr:`_input`
        w.r.t. :attr:`_label` with :attr:`method`.

        Args:
            _input (torch.Tensor): The (batched) input tensor
                with shape ``([N], C, H, W)``.
            _label (torch.Tensor): The (batched) label tensor
                with shape ``([N])``
            method (str): The method to calculate heatmap.
                Choose from ``['grad_cam', 'saliency_map']``.
                Defaults to ``'grad_cam'``.
            cmap (matplotlib.colors.Colormap): The colormap to use.
            mode (str): Passed to :any:`torch.nn.functional.interpolate`.
                Defaults to ``'bicubic'``.

        Returns:
            torch.Tensor: The heatmap tensor with shape ([N], C, H, W).

        Note:
            Most :any:`matplotlib.colors.Colormap` will return
            a 4-channel heatmap with alpha channel.

        See Also:
            https://keras.io/examples/vision/grad_cam/

        :Example:
            .. code-block:: python
                :emphasize-lines: 30-32

                import trojanvision
                from trojanvision.utils import superimpose
                import torchvision.transforms as transforms
                import torchvision.transforms.functional as F
                import PIL.Image as Image
                import os
                import wget

                env = trojanvision.environ.create(device='cpu')
                model = trojanvision.models.create(
                    'resnet152', data_shape=[3, 224, 224], official=True,
                    norm_par={'mean': [0.485, 0.456, 0.406],
                              'std': [0.229, 0.224, 0.225]})
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.PILToTensor(),
                    transforms.ConvertImageDtype(torch.float)])
                url = 'https://i.imgur.com/Bvro0YD.png'
                if not os.path.isfile('african_elephant.png'):
                    wget.download(url, 'african_elephant.png')

                img = Image.open('african_elephant.png').convert(mode='RGB')

                _input = transform(img).unsqueeze(0).to(env['device'])
                _prob = model.get_prob(_input).squeeze()
                label = _prob.argmax().item()
                conf = _prob[label].item()
                print(f'{label=:}  {conf=:.2%}')

                grad_cam = model.get_heatmap(_input, label)[:, :3]
                saliency_map = model.get_heatmap(_input, label,
                                                 method='saliency_map')[:, :3]
                grad_cam_impose = (grad_cam * 0.4 + _input)
                saliency_map_impose = (saliency_map * 0.4 + _input)
                grad_cam_impose = grad_cam_impose.div(grad_cam_impose.max())
                saliency_map_impose = saliency_map_impose.div(saliency_map_impose.max())

                F.to_pil_image(_input).save('./center_cropped.png')
                F.to_pil_image(grad_cam).save('./grad_cam.png')
                F.to_pil_image(saliency_map).save('./saliency_map.png')
                F.to_pil_image(grad_cam_impose).save('./grad_cam_impose.png')
                F.to_pil_image(saliency_map_impose).save('./saliency_map_impose.png')

            ``label=386  conf=77.74%``

            .. table::
                :align: left
                :widths: 80, 160, 160

                +-----------------------+------------------+------------------------+
                |  original             |  |original|      | |center_cropped|       |
                +-----------------------+------------------+------------------------+
                |  grad_cam             |  |grad_cam|      | |grad_cam_impose|      |
                +-----------------------+------------------+------------------------+
                |  saliency_map         |  |saliency_map|  | |saliency_map_impose|  |
                +-----------------------+------------------+------------------------+

            .. |original| image:: https://i.imgur.com/Bvro0YD.png
                :width: 224px
            .. |center_cropped| image:: /images/trojanvision/center_cropped.png
            .. |grad_cam| image:: /images/trojanvision/grad_cam.png
            .. |grad_cam_impose| image:: /images/trojanvision/grad_cam_impose.png
            .. |saliency_map| image:: /images/trojanvision/saliency_map.png
            .. |saliency_map_impose| image:: /images/trojanvision/saliency_map_impose.png
        """
        squeeze_flag = False
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)    # (N, C, H, W)
            squeeze_flag = True
        if isinstance(_label, int):
            _label = [_label] * len(_input)
        _label = torch.as_tensor(_label, device=_input.device)
        heatmap = _input    # linting purpose
        match method:
            case 'grad_cam':
                feats = self._model.get_fm(_input).detach()   # (N, C', H', W')
                feats.requires_grad_()
                _output: torch.Tensor = self._model.pool(feats)   # (N, C', 1, 1)
                _output = self._model.flatten(_output)   # (N, C')
                _output = self._model.classifier(_output)   # (N, num_classes)
                _output = _output.gather(dim=1, index=_label.unsqueeze(1)).sum()
                grad = torch.autograd.grad(_output, feats)[0]   # (N, C', H', W')
                feats.requires_grad_(False)
                weights = grad.mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True)    # (N, C', 1, 1)
                heatmap = (feats * weights).sum(dim=1, keepdim=True).clamp(0)  # (N, 1, H', W')
                # heatmap.sub_(heatmap.amin(dim=-2, keepdim=True).amin(dim=-1, keepdim=True))
                heatmap.div_(heatmap.amax(dim=-2, keepdim=True).amax(dim=-1, keepdim=True))
                heatmap: torch.Tensor = F.interpolate(heatmap, _input.shape[-2:], mode=mode)[:, 0]   # (N, H, W)
                # Note that we violate the image order convension (W, H, C)
            case 'saliency_map':
                _input.requires_grad_()
                _output = self(_input).gather(dim=1, index=_label.unsqueeze(1)).sum()
                grad = torch.autograd.grad(_output, _input)[0]   # (N, C, H, W)
                _input.requires_grad_(False)

                heatmap = grad.abs().amax(dim=1)   # (N, H, W)
                heatmap.sub_(heatmap.amin(dim=-2, keepdim=True).amin(dim=-1, keepdim=True))
                heatmap.div_(heatmap.amax(dim=-2, keepdim=True).amax(dim=-1, keepdim=True))
            case _:
                raise NotImplementedError(f'{method=} is not supported yet.')
        heatmap = apply_cmap(heatmap.detach().cpu(), cmap)
        return heatmap[0] if squeeze_flag else heatmap