Ejemplo n.º 1
0
    def __mul__(self, other):
        """Return a new ``Attribute`` equal to ``self`` multiplied by ``other``.

        ``other`` may be another ``Attribute`` (element-wise product of the two
        ``data`` tensors) or an ``int``/``float`` scalar. The result always
        carries ``self.input_data``.

        When both operands are attributes and one ``data`` tensor has fewer
        dimensions, that tensor gets a new leading axis and is repeated to
        match the other's leading size. If the shapes still differ, both are
        passed through ``denorm``. If ``self``'s result has the greater
        ``.size``, ``other``'s is resized to it; otherwise ``self``'s is
        resized to ``other``'s. Both are then re-normalised with ``norm``.

        Raises:
            ValueError: if ``other`` is neither an ``Attribute`` nor an
                ``int``/``float``.
        """
        # Dispatch on operand type first; only the Attribute case needs any
        # shape alignment.
        if not isinstance(other, Attribute):
            if isinstance(other, (int, float)):
                return Attribute(self.data * other, self.input_data)
            raise ValueError(f"Can't multiply by type {type(other)}")

        lhs, rhs = self.data, other.data

        # Add and tile a leading (colour) axis on whichever side lacks it.
        # NOTE(review): repeat(n, 1, 1) only fits a 2-D tensor lifted to 3-D;
        # confirm other rank combinations never reach this point.
        if lhs.ndim < rhs.ndim:
            lhs = lhs.unsqueeze(0).repeat(rhs.shape[0], 1, 1)
        elif lhs.ndim > rhs.ndim:
            rhs = rhs.unsqueeze(0).repeat(lhs.shape[0], 1, 1)

        if lhs.shape == rhs.shape:
            return Attribute(lhs * rhs, self.input_data)

        # Shapes disagree: round-trip through denorm/norm so one operand can
        # be resized onto the other. resample=2 corresponds to PIL's bilinear
        # filter (assumes denorm returns PIL images here; TODO confirm).
        lhs_img, rhs_img = denorm(lhs), denorm(rhs)
        if lhs_img.size > rhs_img.size:
            rhs_img = rhs_img.resize(lhs_img.size, resample=2)
        else:
            lhs_img = lhs_img.resize(rhs_img.size, resample=2)
        lhs = norm(lhs_img, unsqueeze=False, grad=False)
        rhs = norm(rhs_img, unsqueeze=False, grad=False)
        return Attribute(lhs * rhs, self.input_data)
Ejemplo n.º 2
0
def test_norm_denorm():
    """Round-tripping an RGB image through ``norm`` then ``denorm`` recovers it."""
    # Build a genuine 8-bit RGB image. Feeding float64 data from
    # ``np.random.random`` with an explicit mode='RGB' makes Pillow read the
    # raw float bytes as uint8 pixels instead of converting the values.
    # Pillow infers RGB from a (H, W, 3) uint8 array. A seeded generator
    # keeps the test reproducible.
    rng = np.random.default_rng(0)
    pixels = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
    img = Image.fromarray(pixels)

    data = norm(img)
    denorm_data = denorm(data, image=False)

    assert np.isclose(np.array(img), denorm_data).all()
Ejemplo n.º 3
0
def test_denorm():
    """Check that ``denorm(..., image=False)`` maps a (1, 3, 50, 50) batch
    to a (50, 50, 3) result whose values lie within [0, 255]."""
    batch = torch.randn(1, 3, 50, 50)
    result = denorm(batch, image=False)

    lowest, highest = result.min(), result.max()
    assert highest <= 255
    assert lowest >= 0
    assert result.shape == (50, 50, 3)
Ejemplo n.º 4
0
    def show(self,
             ax=None,
             show_image=True,
             alpha=0.4,
             cmap='magma',
             colorbar=False,
             **kwargs):
        """Show the generated attribution map.

        ``self.data`` is drawn with ``imshow`` using an extent built from
        ``self.input_data.shape[2:]`` reversed. This presumably assumes
        ``input_data`` is (N, C, H, W); TODO confirm. The map can optionally
        be drawn on top of the denormalised first input image.

        Parameters:
            ax: axes on which to plot image. When ``None``, new axes are
                created with ``plt.subplots()``.
            show_image (bool): show the denormalised input image overlaid on the heatmap.
            alpha (float): transparency value alpha for heatmap.
            cmap: matplotlib colourmap.
            colorbar (bool): show a colorbar.
            kwargs: passed to `denorm`. Used to change mean and std normalization values.

        Raises:
            RuntimeError: if a map with 3 or more dimensions cannot be squeezed
                down to 2 dimensions (a plain map) or 3 dimensions (an image).
        """
        if ax is None:
            _, ax = plt.subplots()

        sz = list(self.input_data.shape[2:])
        if show_image:
            input_image = denorm(self.input_data[0], **kwargs)
            ax.imshow(input_image)

        data = self.data
        if (data < 0).any():
            # Min-max rescale to [0, 1]. A constant map has zero range, so
            # shift it to zeros rather than computing 0 / 0 (NaN).
            lowest = data.min()
            span = data.max() - lowest
            data = (data - lowest) / span if span > 0 else data - lowest

        if data.ndim >= 3:
            data = data.squeeze()
            if data.ndim == 3:
                # Move the leading axis last, (C, H, W) -> (H, W, C), which
                # is the layout imshow expects for multi-channel images.
                data = data.permute(1, 2, 0)
            elif data.ndim != 2:
                # Squeezing singleton axes down to a 2-D map is displayable;
                # only other ranks are rejected.
                raise RuntimeError(
                    f"Can't display data shape {self.data.shape} as an image.")

        im = ax.imshow(data,
                       alpha=alpha,
                       extent=(0, *sz[::-1], 0),
                       interpolation='bilinear',
                       cmap=cmap)
        if colorbar:
            ax.figure.colorbar(im, ax=ax)
Ejemplo n.º 5
0
 def _repr_html_(self):
     """Return an HTML ``<img>`` tag embedding the denormalised image.

     IPython calls ``_repr_html_`` and expects an HTML string back. Calling
     ``display`` here would only produce a side effect and return ``None``.
     IPython would then also render the plain-text repr underneath.
     """
     import base64
     import io

     buffer = io.BytesIO()
     # Assumes denorm(...) returns a PIL image by default, as its result is
     # used with PIL's .size/.resize elsewhere; TODO confirm.
     denorm(self.image).save(buffer, format="PNG")
     encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
     return f'<img src="data:image/png;base64,{encoded}"/>'