def __mul__(self, other):
    """Multiply this attribution map by a number or by another `Attribute`.

    Parameters:
        other: an `int`/`float` scale factor, or another `Attribute` whose
            data is multiplied element-wise with this one.

    Returns:
        A new `Attribute` holding the product and this object's `input_data`.

    Raises:
        ValueError: if `other` is neither a number nor an `Attribute`.
    """
    if isinstance(other, (int, float)):
        return Attribute(self.data * other, self.input_data)
    if not isinstance(other, Attribute):
        raise ValueError(f"Can't multiply by type {type(other)}")

    lhs, rhs = self.data, other.data

    # Give the operand with fewer dimensions a leading axis, repeated to the
    # other operand's leading size (the missing colour dimension).
    if lhs.ndim < rhs.ndim:
        lhs = lhs.unsqueeze(0).repeat(rhs.shape[0], 1, 1)
    elif lhs.ndim > rhs.ndim:
        rhs = rhs.unsqueeze(0).repeat(lhs.shape[0], 1, 1)

    if lhs.shape != rhs.shape:
        # Shapes still disagree: round-trip both through `denorm`, resize one
        # to the other's size, then `norm` them back before multiplying.
        # NOTE(review): `.size` / `.resize(..., resample=2)` look like the PIL
        # image API — confirm that `denorm` returns a PIL image here.
        lhs_img = denorm(lhs)
        rhs_img = denorm(rhs)
        if lhs_img.size > rhs_img.size:
            rhs_img = rhs_img.resize(lhs_img.size, resample=2)
        else:
            lhs_img = lhs_img.resize(rhs_img.size, resample=2)
        lhs = norm(lhs_img, unsqueeze=False, grad=False)
        rhs = norm(rhs_img, unsqueeze=False, grad=False)

    return Attribute(lhs * rhs, self.input_data)
def test_norm_denorm():
    """Check that `denorm(norm(img), image=False)` reproduces the pixels of `img`."""
    # Build genuine 8-bit RGB pixel data. A float64 array is not valid RGB
    # input for `Image.fromarray`, so scale the random values and cast to
    # uint8; the (32, 32, 3) uint8 layout lets Pillow infer the RGB mode.
    pixels = (np.random.random((32, 32, 3)) * 255).astype(np.uint8)
    img = Image.fromarray(pixels)

    data = norm(img)
    denorm_data = denorm(data, image=False)

    assert np.isclose(np.array(img), denorm_data).all()
def test_denorm():
    """Check the value range and shape of `denorm(..., image=False)` output.

    A random (1, 3, 50, 50) tensor must come back with every value inside
    [0, 255] and with shape (50, 50, 3).
    """
    height = width = 50
    batch = torch.randn(1, 3, height, width)

    pixels = denorm(batch, image=False)

    assert pixels.max() <= 255
    assert pixels.min() >= 0
    assert pixels.shape == (height, width, 3)
def show(self, ax=None, show_image=True, alpha=0.4, cmap='magma',
         colorbar=False, **kwargs):
    """Show the generated attribution map.

    Parameters:
        ax: axes on which to plot image. When omitted, new axes are created
            with `plt.subplots()`.
        show_image (bool): show the denormalised input image overlaid on the heatmap.
        alpha (float): transparency value alpha for heatmap.
        cmap: matplotlib colourmap.
        colorbar (bool): show a colorbar.
        kwargs: passed to `denorm`. Used to change mean and std normalization values.

    Raises:
        RuntimeError: if the data has three or more dimensions and does not
            have exactly three once squeezed.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Trailing dimensions of the input, used below so the heatmap is drawn
    # over the same extent as the input image.
    sz = list(self.input_data.shape[2:])
    if show_image:
        input_image = denorm(self.input_data[0], **kwargs)
        ax.imshow(input_image)

    data = self.data
    if (data < 0).any():
        # Min-max rescale maps that contain negative values.
        low = data.min()
        span = data.max() - low
        # A constant map has zero range; dividing by it would fill the
        # heatmap with NaN, so only shift such a map (giving all zeros).
        data = (data - low) / span if span > 0 else data - low

    if data.ndim >= 3:
        data = data.squeeze()
        if data.ndim == 3:
            # Move the leading axis last, the layout `imshow` draws as colour.
            data = data.permute(1, 2, 0)
        else:
            # NOTE(review): a map with a singleton leading axis, e.g.
            # (1, H, W), squeezes to 2-D and is rejected here — confirm
            # that this is intended.
            raise RuntimeError(
                f"Can't display data shape {self.data.shape} as an image.")

    im = ax.imshow(data, alpha=alpha, extent=(0, *sz[::-1], 0),
                   interpolation='bilinear', cmap=cmap)
    if colorbar:
        ax.figure.colorbar(im, ax=ax)
def _repr_html_(self):
    """Return an HTML `<img>` tag embedding the denormalised image as a PNG.

    IPython uses the string returned by `_repr_html_` as the object's HTML
    representation, so the markup is returned rather than calling `display`
    as a side effect.
    """
    import base64
    import io

    # assumes `denorm` returns a PIL image (an object with `.save`) — TODO confirm
    picture = denorm(self.image)

    buffer = io.BytesIO()
    picture.save(buffer, format='PNG')
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
    return f'<img src="data:image/png;base64,{encoded}"/>'