示例#1
0
    def heatmap(self, input_t: torch.Tensor, target_t: torch.Tensor):
        """Return a [0, 1]-normalized heatmap read from the bottleneck layer.

        Runs one forward pass of ``self.model`` in eval mode, without
        gradients, while the bottleneck is injected. It then removes the
        bottleneck and reads ``self.bn_layer.buffer_capacity``. The capacity
        is averaged, resized to ``input_t.shape[2:]`` and min-max normalized.

        Args:
            input_t: Input batch. The code uses ``shape[0]`` as the batch
                size and ``shape[2:]`` (indices 2 and 3) as the output size.
                Presumably ``(B, C, H, W)``; confirm with callers.
            target_t: Not used by this method; kept for interface
                compatibility.

        Returns:
            If ``input_t.shape[0] == 1``, a 2-D map of size ``input_t.shape[2:]``.
            Otherwise, a ``(B, H, W)`` array with one normalized map per sample.
        """
        def _normalize(hmap):
            # Resize to the input's spatial size, then min-max scale.
            # The 1e-5 floor avoids dividing by zero on a constant map.
            hmap = resize(hmap, input_t.shape[2:])
            hmap = hmap - hmap.min()
            return hmap / max(hmap.max(), 1e-5)

        self.model.eval()
        self._inject_bottleneck()
        try:
            with torch.no_grad():
                self.model(input_t)
        finally:
            # Detach the bottleneck even if the forward pass raises, so the
            # model is never left with the bottleneck still injected.
            self._remove_bottleneck()

        batch_size = input_t.shape[0]
        if batch_size > 1:
            hmaps = np.zeros((batch_size, input_t.shape[2], input_t.shape[3]))
            for i in range(batch_size):
                # Per-sample capacity, averaged over its first axis
                # (presumably channels; confirm).
                htensor = to_np(self.bn_layer.buffer_capacity[i])
                hmaps[i] = _normalize(htensor.mean(0))
            return hmaps

        # Single sample: average the full buffer over its first two axes.
        htensor = to_np(self.bn_layer.buffer_capacity)
        return _normalize(htensor.mean(axis=(0, 1)))
示例#2
0
    def _current_heatmap(self, shape=None):
        """Return the bottleneck capacity of the first sample as a heatmap.

        Reads ``self.bottleneck.buffer_capacity`` and takes entry 0. It sums
        that entry over its first axis (the channel dim, per the original
        comment) and min-max scales the result to ``[0, 1]``. Optionally,
        it then resizes the result.

        Args:
            shape: Optional target size passed to ``resize`` after
                normalization. ``None`` keeps the buffer's own size.

        Returns:
            The normalized heatmap. It is 2-D assuming the buffer entry is
            ``(C, H, W)``; confirm. A constant map is returned as all zeros
            instead of NaNs.
        """
        # Read bottleneck
        heatmap = to_np(self.bottleneck.buffer_capacity[0])
        heatmap = heatmap.sum(axis=0)  # Sum over channel dim
        heatmap = heatmap - heatmap.min()  # min=0
        peak = heatmap.max()
        if peak > 0:
            # Guard: a constant map has peak 0 here. Dividing would yield
            # NaNs, so leave such a map as all zeros.
            heatmap = heatmap / peak  # max=1

        if shape is not None:
            heatmap = resize(heatmap, shape)

        return heatmap
    def heatmap(self, input_t: torch.Tensor, target_t: torch.Tensor):
        """Return a [0, 1]-normalized heatmap read from the bottleneck layer.

        Runs one forward pass of ``self.model`` in eval mode, without
        gradients, while the bottleneck is injected. It then removes the
        bottleneck and reads ``self.bn_layer.buffer_capacity``. The capacity
        is averaged over its first two axes, resized to
        ``input_t.shape[2:]`` and min-max normalized.

        Args:
            input_t: Input tensor; ``input_t.shape[2:]`` is used as the output
                size. Presumably ``(B, C, H, W)``; confirm with callers.
            target_t: Not used by this method; kept for interface
                compatibility.

        Returns:
            A 2-D map of size ``input_t.shape[2:]``.
            NOTE(review): ``mean(axis=(0, 1))`` also averages over axis 0.
            If that axis is the batch, multi-sample inputs get one map
            averaged across all samples. Confirm this is intended.
        """
        self.model.eval()
        self._inject_bottleneck()
        try:
            with torch.no_grad():
                self.model(input_t)
        finally:
            # Detach the bottleneck even if the forward pass raises, so the
            # model is never left with the bottleneck still injected.
            self._remove_bottleneck()

        htensor = to_np(self.bn_layer.buffer_capacity)

        hmap = htensor.mean(axis=(0, 1))
        hmap = resize(hmap, input_t.shape[2:])

        # Min-max scale; the 1e-5 floor avoids dividing by zero on a
        # constant map.
        hmap = hmap - hmap.min()
        hmap = hmap / (max(hmap.max(), 1e-5))

        return hmap
示例#4
0
    def heatmap(self, input_t, target):
        """Compute an LRP relevance heatmap for a single input/target pair.

        Steps:
        1. Temporarily replace ``self.model.avgpool`` with ``Identity()``.
        2. Wrap the model in an ``InnvestigateModel`` configured with
           ``self.eps``, ``self.beta`` and ``self.method``.
        3. Compute relevance for ``target``.
        4. Restore the original ``avgpool``, even if LRP raises.

        Args:
            input_t: Input tensor with batch size 1 (``shape[0] == 1``) whose
                last dimension must be 224.
            target: Target label. Either a tensor, or anything accepted by
                ``torch.tensor`` (e.g. an int or a one-element list). It is
                placed on ``input_t.device``. Scalars are promoted to a
                1-element tensor.

        Returns:
            ``to_np(relevance[0]).sum(0)``: the first relevance sample summed
            over its first axis (presumably channels; confirm).

        Raises:
            AssertionError: If the batch size or target count is not 1, or
                the input width is not 224.
        """
        target_t = target if isinstance(
            target, torch.Tensor) else torch.tensor(target,
                                                    device=input_t.device)
        if target_t.dim() == 0:
            # A scalar label has no shape[0]. Promote it to a 1-element
            # tensor so the single-label check below does not raise
            # IndexError.
            target_t = target_t.unsqueeze(0)

        # NOTE(review): these checks vanish under `python -O`. Consider
        # explicit raises if callers don't rely on AssertionError.
        assert input_t.shape[0] == 1, "We can only fit on one sample"
        assert target_t.shape[0] == 1, "We can only fit on one label"
        assert input_t.shape[
            -1] == 224, "width must be 224, otherwise avgpool is not the identity"

        avgpool = self.model.avgpool
        self.model.avgpool = Identity()
        try:
            lrp_model = InnvestigateModel(self.model,
                                          epsilon=self.eps,
                                          beta=self.beta,
                                          method=self.method)
            lrp_model.to(self.device)

            _, relevance = lrp_model.innvestigate(input_t, target_t)
        finally:
            # Restore the original pooling layer even if LRP fails, so the
            # model is not left permanently altered.
            self.model.avgpool = avgpool

        heatmap = to_np(relevance[0]).sum(0)
        return heatmap
示例#5
0
 def hook(m, x, z):
     # Callback taking three positional args; only the third (`z`) is used.
     # It is converted with to_np and fed to self.estimators[i].feed_batch.
     # The (m, x, z) shape matches a PyTorch forward hook's
     # (module, input, output). Presumably it is registered via
     # register_forward_hook; confirm at the call site.
     # NOTE(review): `self` and `i` are free variables from the enclosing
     # scope and are looked up when the hook runs, not when it is defined.
     # If this def sits inside a loop over `i`, every hook would use the
     # value `i` holds at call time. Bind with `i=i` as a default argument
     # if per-estimator binding is intended; confirm.
     self.estimators[i].feed_batch(to_np(z))