def test_image_feature_generates_correct_ouput(self):
    """ImageFeature.visualize builds a FeatureOutput from the patched figures.

    ``visualize_image_attr`` is patched: the call made with
    ``method="original_image"`` returns one known figure, and every other call
    returns a second figure. The resulting FeatureOutput must match one built
    from ``_convert_figure_base64`` of those two figures, with type "image"
    and the given contribution.
    """
    contribution = 1.0
    feature_name = "photo"
    attribution = torch.zeros(1, 3, 4, 4)
    data = torch.ones(1, 3, 4, 4)

    original_figure = Figure(figsize=(4, 4))
    attribution_figure = Figure(figsize=(4, 4))

    def fake_visualize_image_attr(*args, **kwargs):
        # Only the "original_image" rendering gets the original figure; every
        # other method (the attribution heat map) gets the attribution figure.
        is_original = kwargs["method"] == "original_image"
        chosen = original_figure if is_original else attribution_figure
        return chosen, None

    feature = ImageFeature(
        name=feature_name,
        baseline_transforms=None,
        input_transforms=None,
        visualization_transform=None,
    )

    with patch(
        "captum.attr._utils.visualization.visualize_image_attr",
        fake_visualize_image_attr,
    ):
        actual = feature.visualize(attribution, data, contribution)

    expected = FeatureOutput(
        name=feature_name,
        base=_convert_figure_base64(original_figure),
        modified=_convert_figure_base64(attribution_figure),
        type="image",
        contribution=contribution,
    )
    self.assertEqual(expected, actual)
def test_one_feature(self):
    """Absolute feature contributions of each output sum to ~1.0.

    Builds an AttributionVisualizer with one CNN and a single ImageFeature,
    runs ``visualize()`` with ``n_steps=2``, and asserts that for every output
    the sum of ``abs(contribution)`` over its feature outputs equals 1.0 to
    6 decimal places.
    """
    batch_size = 2
    classes = _get_classes()
    dataset = list(
        _labelled_img_data(num_labels=len(classes), num_samples=batch_size)
    )
    # NOTE: using DataLoader to batch the inputs
    # since AttributionVisualizer requires the input to be of size `B x ...`
    # `dataset` is already a list, so it is passed directly (no second copy).
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )
    visualizer = AttributionVisualizer(
        models=[_get_cnn()],
        classes=classes,
        features=[
            ImageFeature(
                "Photo",
                input_transforms=[lambda x: x],
                baseline_transforms=[lambda x: x * 0],
            )
        ],
        dataset=to_iter(data_loader),
        score_func=None,
    )

    # Keep the attribution cheap for the test.
    visualizer._config = FilterConfig(attribution_arguments={"n_steps": 2})
    outputs = visualizer.visualize()

    for output in outputs:
        total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
        self.assertAlmostEqual(total_contrib, 1.0, places=6)
def insights(x: CaptumInterpretation, inp_data, debug=True):
    """Render a Captum Insights visualizer for ``inp_data`` using ``x``'s model.

    Args:
        x: interpretation object; this function reads ``x.dls``, ``x.model``
            and ``x._formatted_data_iter``.
        inp_data: items to visualize. They are wrapped in ``L`` and turned
            into a test DataLoader via ``x.dls.test_dl`` (``bs=4``,
            ``with_labels=True``).
        debug: forwarded to ``AttributionVisualizer.render``.
    """

    def _baseline_func(o):
        # Baseline obtained by multiplying the input by zero.
        return o * 0

    def _get_vocab(vocab):
        # Boolean class labels are converted to strings; any other vocab is
        # passed through unchanged. The length check avoids an IndexError on
        # an empty vocab.
        if len(vocab) > 0 and isinstance(vocab[0], bool):
            return list(map(str, vocab))
        return vocab

    dl = x.dls.test_dl(L(inp_data), with_labels=True, bs=4)
    # Use the first Normalize transform of the batch pipeline, or `noop`
    # when there is none.
    normalize_func = next(
        (func for func in dl.after_batch if isinstance(func, Normalize)), noop
    )

    # captum v0.3 expects tensors without the batch dimension.
    # NOTE: squeeze_ works in place, so it modifies the stats held by the
    # transform found in `dl.after_batch`.
    for stat_name in ("mean", "std"):
        stat = getattr(normalize_func, stat_name, None)
        if stat is not None and stat.ndim == 4:
            stat.squeeze_(0)

    visualizer = AttributionVisualizer(
        models=[x.model],
        score_func=lambda o: torch.nn.functional.softmax(o, 1),
        classes=_get_vocab(dl.vocab),
        features=[
            ImageFeature(
                "Image",
                baseline_transforms=[_baseline_func],
                input_transforms=[normalize_func],
            )
        ],
        dataset=x._formatted_data_iter(dl, normalize_func),
    )
    visualizer.render(debug=debug)
def main():
    """Build a one-model Captum Insights visualizer and call ``serve`` on it.

    Inputs are normalized per channel with mean 0.5 and std 0.5. Model outputs
    are scored with a softmax over dimension 1.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    model = get_pretrained_model()

    def softmax_scores(outputs):
        return torch.nn.functional.softmax(outputs, 1)

    class_names = get_classes()
    photo_feature = ImageFeature(
        "Photo",
        baseline_transforms=[baseline_func],
        input_transforms=[normalize],
    )

    visualizer = AttributionVisualizer(
        models=[model],
        score_func=softmax_scores,
        classes=class_names,
        features=[photo_feature],
        dataset=formatted_data_iter(),
    )
    visualizer.serve(debug=True)
def __init__(self, path, bs, learn, application, baseline_func_default=False,
             baseline_token=None, baseline_fn=None):
    """Store learner settings and build the Captum Insights feature list.

    Args:
        path: stored as-is when truthy, otherwise stored as ``None``.
        bs: batch size; falls back to 8 when falsy.
        learn: learner providing ``data`` and ``model``. ``model.eval()`` is
            called and its result is stored. For text, ``data.x.vocab`` is
            also read.
        application: ``'text'`` builds a TextFeature, ``'vision'`` builds an
            ImageFeature, and any other value builds a TabularFeature.
        baseline_func_default: when truthy and no ``baseline_fn`` is given,
            the feature gets ``baseline_transforms=None``.
        baseline_token: used when truthy. Otherwise defaults to ``'.'`` for
            text and ``0`` for other applications.
        baseline_fn: when truthy, it becomes the sole baseline transform.

    Relies on ``self.baseline_func`` and, for text, ``self.itos`` being
    defined elsewhere on the class.
    """
    self.path = path or None
    self.bs = bs or 8
    self.data = learn.data
    self.model = learn.model.eval()
    self.application = application
    self.baseline_func_default = baseline_func_default
    if self.application == 'text':
        self.vocab = learn.data.x.vocab

    if baseline_token:
        self.baseline_token = baseline_token
    elif self.application == 'text':
        self.baseline_token = '.'
    else:
        self.baseline_token = 0
    self.baseline_fn = baseline_fn

    def _baseline_transforms():
        # Precedence: an explicit baseline_fn, then "no transforms" when the
        # default baseline is requested, then the class's baseline_func.
        if self.baseline_fn:
            return [self.baseline_fn]
        if self.baseline_func_default:
            return None
        return [self.baseline_func]

    if self.application == 'text':
        self.features = [
            TextFeature(
                "Sentence",
                baseline_transforms=_baseline_transforms(),
                visualization_transform=self.itos,
                input_transforms=[],
            )
        ]
    elif self.application == 'vision':
        self.features = [
            ImageFeature(
                'Image',
                baseline_transforms=_baseline_transforms(),
                input_transforms=[],
            )
        ]
    else:
        self.features = [
            TabularFeature(
                'Table',
                baseline_transforms=_baseline_transforms(),
            )
        ]
i = transform(i) i = transform_normalize(i) i = i.unsqueeze(0) return i input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0) visualizer = AttributionVisualizer( models=[model], score_func=lambda o: torch.nn.functional.softmax(o, 1), classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())), features=[ ImageFeature( "Photo", baseline_transforms=[baseline_func], input_transforms=[], ) ], dataset=[Batch(input_imgs, labels=[282, 849, 69])]) ######################################################################### # Note that running the cell above didn’t take much time at all, unlike # our attributions above. That’s because Captum Insights lets you # configure different attribution algorithms in a visual widget, after # which it will compute and display the attributions. *That* process will # take a few minutes. # # Running the cell below will render the Captum Insights widget. You can # then choose attributions methods and their arguments, filter model # responses based on predicted class or prediction correctness, see the