Code example #1
    def test_broadcastable_masks(self) -> None:
        # integration test to ensure that
        # permutation function works with custom masks
        def forward_func(x: Tensor) -> Tensor:
            return x.view(x.shape[0], -1).sum(dim=-1)

        batch_size = 2
        inp = torch.randn((batch_size,) + (3, 4, 4))

        feature_importance = FeaturePermutation(forward_func=forward_func)

        masks = [
            torch.tensor([0]),
            torch.tensor([[0, 1, 2, 3]]),
            torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]),
        ]

        for mask in masks:
            attribs = feature_importance.attribute(inp, feature_mask=mask)

            self.assertTrue(attribs is not None)
            self.assertTrue(attribs.shape == inp.shape)

            fm = mask.expand_as(inp[0])

            features = set(mask.flatten().tolist())
            for feature in features:
                m = (fm == feature).bool()
                attribs_for_feature = attribs[:, m]
                assertArraysAlmostEqual(attribs_for_feature[0], -attribs_for_feature[1])
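For orientation, the API these test methods exercise can be reduced to a short standalone sketch. This is not part of the original suite: the toy forward_func, the input shapes, and the grid-cell mask below are illustrative assumptions; only FeaturePermutation and its attribute()/feature_mask interface come from Captum's public captum.attr API.

    import torch
    from captum.attr import FeaturePermutation

    def forward_func(x):
        # per-example scalar output: sum over all flattened features
        return x.view(x.shape[0], -1).sum(dim=-1)

    fp = FeaturePermutation(forward_func)
    inp = torch.randn(8, 3, 4, 4)  # batch size must be > 1 so values can be permuted

    # Without a feature_mask, every scalar element is its own feature. A mask
    # with leading size-1 dimensions is shared across the batch; here each of
    # the 16 spatial positions is one feature group, shared across channels.
    mask = torch.arange(16).view(1, 1, 4, 4)
    attr = fp.attribute(inp, feature_mask=mask)
    print(attr.shape)  # same shape as inp: (8, 3, 4, 4)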
Code example #2
    def test_multiple_perturbations_per_eval(self) -> None:
        perturbations_per_eval = 4
        batch_size = 2
        input_size = (4,)

        inp = torch.randn((batch_size,) + input_size)

        def forward_func(x: Tensor) -> Tensor:
            return 1 - x

        target = 1
        feature_importance = FeaturePermutation(forward_func=forward_func)

        attribs = feature_importance.attribute(
            inp, perturbations_per_eval=perturbations_per_eval, target=target
        )
        self.assertTrue(attribs.size() == (batch_size,) + input_size)

        for i in range(inp.size(1)):
            if i == target:
                continue
            assertTensorAlmostEqual(self, attribs[:, i], 0)

        y = forward_func(inp)
        actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]])
        assertTensorAlmostEqual(self, attribs[:, target], actual_diff)
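A quick way to see why the expected value above is [(y[0] - y[1])[target], (y[1] - y[0])[target]]: FeaturePermutation permutes a feature's values across the batch, and with only two examples the only permutation that changes anything is swapping the two rows. The plain-torch check below (illustrative, not from the original suite) reproduces that arithmetic:

    import torch

    x = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                      [0.5, 0.6, 0.7, 0.8]])
    target = 1
    y = 1 - x                              # same toy forward_func as the test

    x_perm = x.clone()
    x_perm[:, target] = x[[1, 0], target]  # swap the target column between the two rows
    y_perm = 1 - x_perm

    # attribution = original output - output after permuting the feature
    diff = y[:, target] - y_perm[:, target]
    print(diff)  # tensor([ 0.4000, -0.4000]) == [(y[0]-y[1])[target], (y[1]-y[0])[target]]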
Code example #3
    def test_empty_sparse_features(self) -> None:
        model = BasicModelWithSparseInputs()
        inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])
        inp2 = torch.tensor([])

        # test empty sparse tensor
        feature_importance = FeaturePermutation(model)
        attr1, attr2 = feature_importance.attribute((inp1, inp2))
        self.assertEqual(attr1.shape, (1, 3))
        self.assertEqual(attr2.shape, (1,))
Code example #4
    def test_sparse_features(self) -> None:
        model = BasicModelWithSparseInputs()
        inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])
        # Length of sparse index list may not match # of examples
        inp2 = torch.tensor([1, 7, 2, 4, 5, 3, 6])

        feature_importance = FeaturePermutation(model)
        total_attr1, total_attr2 = feature_importance.attribute((inp1, inp2))

        # Permutation attributions vary from run to run, so average over
        # repeated runs before comparing against the expected values.
        for _ in range(50):
            attr1, attr2 = feature_importance.attribute((inp1, inp2))
            total_attr1 += attr1
            total_attr2 += attr2
        total_attr1 /= 50
        total_attr2 /= 50
        self.assertEqual(total_attr2.shape, (1,))
        assertTensorAlmostEqual(self, total_attr1, [0.0, 0.0, 0.0])
        assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2)
Code example #5
    def test_single_input(self) -> None:
        batch_size = 2
        input_size = (3,)

        def forward_func(x: Tensor) -> Tensor:
            return x.sum()

        feature_importance = FeaturePermutation(forward_func=forward_func)

        inp = torch.randn((batch_size,) + input_size) * 10

        inp[:, 0] = 5
        for _ in range(10):
            attribs = feature_importance.attribute(inp)

            # forward_func returns one scalar for the whole batch, so the
            # attributions are aggregated over examples (leading dim of 1)
            self.assertTrue(attribs.squeeze(0).size() == input_size)
            self.assertTrue((attribs[:, 0] == 0).all())
            self.assertTrue((attribs[:, 1] != 0).all())
            self.assertTrue((attribs[:, 2] != 0).all())
Code example #6
    def test_single_input(self) -> None:
        batch_size = 2
        input_size = (6,)
        constant_value = 10000

        def forward_func(x: Tensor) -> Tensor:
            return x.sum(dim=-1)

        feature_importance = FeaturePermutation(forward_func=forward_func)

        inp = torch.randn((batch_size,) + input_size)

        # Feature 0 is constant across the batch, so permuting it cannot change
        # the output; its attribution should be zero.
        inp[:, 0] = constant_value
        zeros = torch.zeros_like(inp[:, 0])

        attribs = feature_importance.attribute(inp)

        self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size)
        assertArraysAlmostEqual(attribs[:, 0], zeros)
        self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all())
Code example #7
    def test_multi_input(self) -> None:
        batch_size = 20
        inp1_size = (5, 2)
        inp2_size = (5, 3)

        labels = torch.randn(batch_size)

        def forward_func(*x: Tensor) -> Tensor:
            y = torch.zeros(x[0].shape[0:2])
            for xx in x:
                y += xx[:, :, 0] * xx[:, :, 1]
            y = y.sum(dim=-1)

            return torch.mean((y - labels) ** 2)

        feature_importance = FeaturePermutation(forward_func=forward_func)

        inp = (
            torch.randn((batch_size,) + inp1_size),
            torch.randn((batch_size,) + inp2_size),
        )

        feature_mask = (
            torch.arange(inp[0][0].numel()).view_as(inp[0][0]).unsqueeze(0),
            torch.arange(inp[1][0].numel()).view_as(inp[1][0]).unsqueeze(0),
        )

        # Feature 1 of the second input is made constant across the batch and
        # feature 2 is never used by forward_func, so both get zero attribution.
        inp[1][:, :, 1] = 4
        attribs = feature_importance.attribute(inp, feature_mask=feature_mask)

        self.assertTrue(isinstance(attribs, tuple))
        self.assertTrue(len(attribs) == 2)

        self.assertTrue(attribs[0].squeeze(0).size() == inp1_size)
        self.assertTrue(attribs[1].squeeze(0).size() == inp2_size)

        self.assertTrue((attribs[1][:, :, 1] == 0).all())
        self.assertTrue((attribs[1][:, :, 2] == 0).all())

        self.assertTrue((attribs[0] != 0).all())
        self.assertTrue((attribs[1][:, :, 0] != 0).all())
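Putting the multi-input pattern in a standalone form: each input gets its own feature mask, masks with a leading size-1 dimension are shared across the batch, and attribute() returns one attribution tensor per input. The toy model and shapes below are illustrative assumptions, not part of the original suite:

    import torch
    from captum.attr import FeaturePermutation

    def forward_func(a, b):
        # per-example scalar combining both inputs
        return a.sum(dim=1) + b.sum(dim=1)

    fp = FeaturePermutation(forward_func)
    a = torch.randn(6, 4)
    b = torch.randn(6, 3)

    # One mask per input, each with a leading batch dimension of 1 so it is
    # shared by every example: `a` keeps per-column features (ids 0..3), while
    # all of `b` is grouped into a single feature (id 4).
    mask_a = torch.arange(4).unsqueeze(0)                    # shape (1, 4)
    mask_b = torch.full((1, 3), 4, dtype=torch.long)         # shape (1, 3)

    attr_a, attr_b = fp.attribute((a, b), feature_mask=(mask_a, mask_b))
    print(attr_a.shape, attr_b.shape)  # torch.Size([6, 4]) torch.Size([6, 3])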