def heatmap(self, input_t: torch.Tensor, target_t: torch.Tensor):
    """Compute heatmap(s) from the bottleneck capacity recorded on one forward pass.

    The bottleneck is injected into ``self.model``, the model is run once on
    ``input_t`` under ``torch.no_grad()``, and the bottleneck is removed
    again (also when the forward pass raises). The capacity left in
    ``self.bn_layer.buffer_capacity`` is averaged over its leading axes,
    resized to ``input_t.shape[2:]`` and min-max normalised.

    Args:
        input_t: Input batch; ``shape[0]`` is the batch size and
            ``shape[2:]`` the output spatial size (presumably (N, C, H, W)
            — confirm against callers).
        target_t: Not used by this method; kept for interface compatibility.

    Returns:
        If the batch holds more than one sample, a float array of shape
        ``(N, input_t.shape[2], input_t.shape[3])`` with one heatmap per
        sample; otherwise a single resized, normalised heatmap.
    """

    def _normalize(hmap):
        # Resize to the input resolution, shift so the minimum is 0 and scale
        # by the maximum; the 1e-5 floor avoids division by zero on a
        # constant map.
        hmap = resize(hmap, input_t.shape[2:])
        hmap = hmap - hmap.min()
        return hmap / max(hmap.max(), 1e-5)

    self.model.eval()
    self._inject_bottleneck()
    try:
        with torch.no_grad():
            self.model(input_t)
    finally:
        # Always detach the bottleneck so the model is not left modified
        # when the forward pass fails.
        self._remove_bottleneck()

    capacity = self.bn_layer.buffer_capacity
    if input_t.shape[0] > 1:
        hmaps = np.zeros(
            (input_t.shape[0], input_t.shape[2], input_t.shape[3]))
        for i in range(input_t.shape[0]):
            # Per sample: average over axis 0 (presumably channels — confirm).
            hmaps[i] = _normalize(to_np(capacity[i]).mean(0))
        return hmaps
    # Single sample: average over the batch axis and axis 1 together.
    return _normalize(to_np(capacity).mean(axis=(0, 1)))
def _current_heatmap(self, shape=None):
    """Return the bottleneck's current capacity as a min-max normalised heatmap.

    Reads ``self.bottleneck.buffer_capacity``, takes its first entry along
    axis 0 (presumably the first sample of the batch — confirm), sums over
    the next axis (the channel dim) and normalises the result.

    Args:
        shape: Optional size passed to ``resize`` after normalisation.
            If ``None``, the heatmap is returned at its native size.

    Returns:
        The heatmap. Before any resize it has minimum 0 and maximum 1, or
        is all zeros if the capacity map was constant.
    """
    # Read bottleneck
    heatmap = to_np(self.bottleneck.buffer_capacity[0])
    heatmap = heatmap.sum(axis=0)  # Sum over channel dim
    heatmap = heatmap - heatmap.min()  # min=0
    peak = heatmap.max()
    if peak > 0:
        # max=1; skipped for a constant map, where 0/0 would yield NaNs.
        heatmap = heatmap / peak
    if shape is not None:
        heatmap = resize(heatmap, shape)
    return heatmap
def heatmap(self, input_t: torch.Tensor, target_t: torch.Tensor):
    """Compute a heatmap from the bottleneck capacity recorded on one forward pass.

    The bottleneck is injected into ``self.model``, the model is run once on
    ``input_t`` under ``torch.no_grad()``, and the bottleneck is removed
    again (also when the forward pass raises). The capacity in
    ``self.bn_layer.buffer_capacity`` is averaged over axes 0 and 1,
    resized to ``input_t.shape[2:]`` and min-max normalised.

    Args:
        input_t: Input batch; ``shape[2:]`` gives the output size
            (presumably (1, C, H, W) — confirm against callers).
        target_t: Not used by this method; kept for interface compatibility.

    Returns:
        The resized heatmap, shifted to minimum 0 and divided by
        ``max(max, 1e-5)``.
    """
    self.model.eval()
    self._inject_bottleneck()
    try:
        with torch.no_grad():
            self.model(input_t)
    finally:
        # Always detach the bottleneck so the model is not left modified
        # when the forward pass fails.
        self._remove_bottleneck()
    htensor = to_np(self.bn_layer.buffer_capacity)
    hmap = htensor.mean(axis=(0, 1))
    hmap = resize(hmap, input_t.shape[2:])
    hmap = hmap - hmap.min()
    # The 1e-5 floor avoids division by zero on a constant map.
    return hmap / max(hmap.max(), 1e-5)
def heatmap(self, input_t, target):
    """Compute an LRP relevance heatmap for a single input sample.

    ``self.model.avgpool`` is temporarily replaced by ``Identity()`` while an
    ``InnvestigateModel`` wrapping ``self.model`` computes relevances for
    ``target``. The original avgpool is restored afterwards, even if LRP
    raises.

    Args:
        input_t: Input batch holding exactly one sample; its last dimension
            must be 224.
        target: Label as a ``torch.Tensor`` or any value accepted by
            ``torch.tensor``; must contain exactly one label. A scalar
            (e.g. a plain int) is accepted and treated as shape ``(1,)``.

    Returns:
        ``to_np(relevance[0]).sum(0)``: the first sample's relevance summed
        over axis 0 (presumably the channel axis — confirm).

    Raises:
        AssertionError: (when assertions are enabled) if the input or target
            does not hold exactly one element, or the input width is not 224.
    """
    target_t = target if isinstance(target, torch.Tensor) else torch.tensor(
        target, device=input_t.device)
    if target_t.dim() == 0:
        # A scalar label becomes a 0-d tensor, which has no shape[0];
        # lift it to shape (1,) so the single-label check applies.
        target_t = target_t.unsqueeze(0)
    assert input_t.shape[0] == 1, "We can only fit on one sample"
    assert target_t.shape[0] == 1, "We can only fit on one label"
    assert input_t.shape[-1] == 224, \
        "width must be 224, otherwise avgpool is not the identity"
    avgpool = self.model.avgpool
    self.model.avgpool = Identity()
    try:
        lrp_model = InnvestigateModel(self.model, epsilon=self.eps,
                                      beta=self.beta, method=self.method)
        lrp_model.to(self.device)
        _, relevance = lrp_model.innvestigate(input_t, target_t)
    finally:
        # Put the real pooling layer back even if LRP fails, so self.model
        # is not left permanently modified.
        self.model.avgpool = avgpool
    heatmap = to_np(relevance[0]).sum(0)
    return heatmap
def hook(m, x, z):
    """Pass ``z`` through ``to_np`` and feed it to ``self.estimators[i]``.

    ``m`` and ``x`` are ignored. The signature presumably matches a module
    forward hook (module, input, output) — confirm at the registration site.

    NOTE(review): ``self`` and ``i`` are free variables from the enclosing
    scope and are looked up when the hook fires, not when it is defined. If
    this function is created inside a loop over ``i``, every hook would see
    the final ``i`` — confirm against the caller.
    """
    estimator = self.estimators[i]
    estimator.feed_batch(to_np(z))