Example #1
    def test_image_feature_generates_correct_output(self):
        attribution = torch.zeros(1, 3, 4, 4)
        data = torch.ones(1, 3, 4, 4)
        contribution = 1.0
        name = "photo"

        orig_fig = Figure(figsize=(4, 4))
        attr_fig = Figure(figsize=(4, 4))

        # Return a distinct figure depending on whether the original image
        # or the attribution overlay is being rendered.
        def mock_viz_attr(*args, **kwargs):
            if kwargs["method"] == "original_image":
                return orig_fig, None
            else:
                return attr_fig, None

        feature = ImageFeature(
            name=name,
            baseline_transforms=None,
            input_transforms=None,
            visualization_transform=None,
        )

        with patch(
            "captum.attr._utils.visualization.visualize_image_attr", mock_viz_attr
        ):
            feature_output = feature.visualize(attribution, data, contribution)
            expected_feature_output = FeatureOutput(
                name=name,
                base=_convert_figure_base64(orig_fig),
                modified=_convert_figure_base64(attr_fig),
                type="image",
                contribution=contribution,
            )

            self.assertEqual(expected_feature_output, feature_output)
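
The helper _convert_figure_base64 belongs to the test module rather than to Captum's public API. A minimal sketch of what such a helper might look like, assuming it renders the Matplotlib Figure to an in-memory PNG and base64-encodes it:

import base64
from io import BytesIO
from matplotlib.figure import Figure

def _convert_figure_base64(fig: Figure) -> str:
    # Render the figure into an in-memory PNG buffer, then base64-encode
    # it so it can be embedded in the Insights frontend payload.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    return base64.b64encode(buf.getvalue()).decode("utf-8")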
Example #2
    def test_one_feature(self):
        batch_size = 2
        classes = _get_classes()
        dataset = list(
            _labelled_img_data(num_labels=len(classes),
                               num_samples=batch_size))

        # NOTE: use a DataLoader to batch the inputs, since
        # AttributionVisualizer expects inputs of shape `B x ...`.
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=0)

        visualizer = AttributionVisualizer(
            models=[_get_cnn()],
            classes=classes,
            features=[
                ImageFeature(
                    "Photo",
                    input_transforms=[lambda x: x],
                    baseline_transforms=[lambda x: x * 0],
                )
            ],
            dataset=to_iter(data_loader),
            score_func=None,
        )
        visualizer._config = FilterConfig(attribution_arguments={"n_steps": 2})

        outputs = visualizer.visualize()

        for output in outputs:
            total_contrib = sum(
                abs(f.contribution) for f in output.feature_outputs)
            self.assertAlmostEqual(total_contrib, 1.0, places=6)
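
The to_iter helper isn't shown in this example. A plausible sketch, assuming it simply wraps each (inputs, labels) pair produced by the DataLoader in a Captum Insights Batch:

from captum.insights import Batch

def to_iter(data_loader):
    # AttributionVisualizer consumes an iterable of Batch objects, so
    # wrap every batch produced by the DataLoader accordingly.
    for inputs, labels in data_loader:
        yield Batch(inputs=inputs, labels=labels)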
Example #3
@patch  # fastcore's @patch: registers insights as a CaptumInterpretation method
def insights(x: CaptumInterpretation, inp_data, debug=True):
    _baseline_func = lambda o: o * 0
    _get_vocab = lambda vocab: (list(map(str, vocab))
                                if isinstance(vocab[0], bool) else vocab)
    dl = x.dls.test_dl(L(inp_data), with_labels=True, bs=4)
    normalize_func = next(
        (func for func in dl.after_batch if type(func) == Normalize), noop)

    # Captum v0.3 expects the normalization mean/std without a batch dimension.
    if hasattr(normalize_func, 'mean'):
        if normalize_func.mean.ndim == 4: normalize_func.mean.squeeze_(0)
    if hasattr(normalize_func, 'std'):
        if normalize_func.std.ndim == 4: normalize_func.std.squeeze_(0)

    visualizer = AttributionVisualizer(
        models=[x.model],
        score_func=lambda o: torch.nn.functional.softmax(o, 1),
        classes=_get_vocab(dl.vocab),
        features=[
            ImageFeature(
                "Image",
                baseline_transforms=[_baseline_func],
                input_transforms=[normalize_func],
            )
        ],
        dataset=x._formatted_data_iter(dl, normalize_func))
    visualizer.render(debug=debug)
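
Because the first parameter is typed as CaptumInterpretation, this function is meant to be called as a method on a fastai interpretation object. A hedged usage sketch (the import path, learner, and file names are all assumptions):

from fastai.vision.all import *
from fastai.callback.captum import CaptumInterpretation  # path is an assumption

learn = vision_learner(dls, resnet18)                  # hypothetical learner
interp = CaptumInterpretation(learn)
interp.insights(['images/cat.jpg', 'images/dog.jpg'])  # hypothetical files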
Example #4
def main():
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    model = get_pretrained_model()
    visualizer = AttributionVisualizer(
        models=[model],
        score_func=lambda o: torch.nn.functional.softmax(o, 1),
        classes=get_classes(),
        features=[
            ImageFeature(
                "Photo",
                baseline_transforms=[baseline_func],
                input_transforms=[normalize],
            )
        ],
        dataset=formatted_data_iter(),
    )

    visualizer.serve(debug=True)
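
The helpers baseline_func, get_classes, and formatted_data_iter are defined elsewhere in the original script. A minimal sketch of plausible definitions, assuming a zero baseline and a CIFAR-10 test set (the dataset choice is an assumption):

import torch
import torchvision
import torchvision.transforms as transforms
from captum.insights import Batch

def baseline_func(input):
    # All-zeros baseline: attributions are measured relative to a black image.
    return input * 0

def formatted_data_iter():
    # Stream (images, labels) batches wrapped in Captum Insights Batch objects.
    dataset = torchvision.datasets.CIFAR10(
        root="data/test", train=False, download=True,
        transform=transforms.ToTensor())
    dataloader = iter(torch.utils.data.DataLoader(dataset, batch_size=4))
    while True:
        images, labels = next(dataloader)
        yield Batch(inputs=images, labels=labels)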
Example #5
    def __init__(self,
                 path,
                 bs,
                 learn,
                 application,
                 baseline_func_default=False,
                 baseline_token=None,
                 baseline_fn=None):
        self.path = path if path else None
        self.bs = bs if bs else 8
        self.data = learn.data
        self.model = learn.model.eval()
        self.application = application
        self.baseline_func_default = baseline_func_default
        if self.application == 'text':
            self.vocab = learn.data.x.vocab
        self.baseline_token = (baseline_token if baseline_token else
                               '.' if self.application == 'text' else 0)
        self.baseline_fn = baseline_fn

        # Choose the baseline transforms once: an explicit baseline_fn wins;
        # otherwise fall back to Captum's default (None) or to this class's
        # own baseline_func.
        if self.baseline_fn:
            baseline_transforms = [self.baseline_fn]
        elif self.baseline_func_default:
            baseline_transforms = None
        else:
            baseline_transforms = [self.baseline_func]

        if self.application == 'text':
            self.features = [
                TextFeature(
                    "Sentence",
                    baseline_transforms=baseline_transforms,
                    visualization_transform=self.itos,
                    input_transforms=[])
            ]
        elif self.application == 'vision':
            self.features = [
                ImageFeature(
                    'Image',
                    baseline_transforms=baseline_transforms,
                    input_transforms=[])
            ]
        else:
            self.features = [
                TabularFeature(
                    'Table',
                    baseline_transforms=baseline_transforms)
            ]
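
The baseline_func and itos methods referenced above are not part of this snippet. Minimal sketches of what they plausibly do (both bodies are assumptions, not the original implementations):

    def baseline_func(self, o):
        # Assumed: an all-zeros baseline, as in the other examples in this
        # listing.
        return o * 0

    def itos(self, o):
        # Assumed: map token indices back to vocabulary strings for display.
        return ' '.join(self.vocab.itos[i] for i in o)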

def full_img_transform(i):
    # Apply the preprocessing and normalization transforms, then add a
    # batch dimension so the images can be concatenated below.
    i = transform(i)
    i = transform_normalize(i)
    i = i.unsqueeze(0)
    return i


input_imgs = torch.cat(list(map(full_img_transform, imgs)), 0)

visualizer = AttributionVisualizer(
    models=[model],
    score_func=lambda o: torch.nn.functional.softmax(o, 1),
    classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())),
    features=[
        ImageFeature(
            "Photo",
            baseline_transforms=[baseline_func],
            input_transforms=[],
        )
    ],
    dataset=[Batch(input_imgs, labels=[282, 849, 69])])

#########################################################################
# Note that running the cell above didn’t take much time at all, unlike
# our attributions above. That’s because Captum Insights lets you
# configure different attribution algorithms in a visual widget, after
# which it will compute and display the attributions. *That* process will
# take a few minutes.
#
# Running the cell below will render the Captum Insights widget. You can
# then choose attribution methods and their arguments, filter model
# responses based on predicted class or prediction correctness, see the