def test_relu_deeplift_exact_match_wo_mutliplying_by_inputs(self) -> None:
        """DeepLift with ``multiply_by_inputs=False`` on a single example.

        No baselines are passed, so ``attribute`` uses Captum's default
        baselines (documented as zeros). With input multiplication disabled the
        returned values are the per-input multipliers rather than
        multiplier * (input - baseline), so the second input is expected to get
        0.5 instead of the 1.0 seen when inputs are multiplied in.
        """
        model = ReLUDeepLiftModel()
        explainer = DeepLift(model, multiply_by_inputs=False)

        # One scalar feature per input tensor: x1 = 1.0, x2 = 2.0.
        attr_x1, attr_x2 = explainer.attribute(
            (torch.tensor([1.0]), torch.tensor([2.0]))
        )

        self.assertEqual(attr_x1[0], 2.0)
        self.assertEqual(attr_x2[0], 0.5)
    def test_relu_deeplift_batch(self) -> None:
        """DeepLift on a batch of four identical examples with zero baselines.

        Every row of ``x1`` is 1.0 and every row of ``x2`` is 2.0; inputs and
        baselines are leaf tensors created with ``requires_grad=True``. The
        actual checks are delegated to ``self._deeplift_assert``.
        """
        batch_size = 4

        x1 = torch.tensor([[1.0]] * batch_size, requires_grad=True)
        x2 = torch.tensor([[2.0]] * batch_size, requires_grad=True)
        b1 = torch.tensor([[0.0]] * batch_size, requires_grad=True)
        b2 = torch.tensor([[0.0]] * batch_size, requires_grad=True)

        model = ReLUDeepLiftModel()
        self._deeplift_assert(model, DeepLift(model), (x1, x2), (b1, b2))
    def test_relu_deepliftshap_batch_4D_input(self) -> None:
        """DeepLiftShap on 4-D inputs of shape (4, 1, 1, 1) with zero baselines.

        The first input is all ones and the second all twos; the shared
        ``self._deeplift_assert`` helper performs the checks.
        """
        shape = (4, 1, 1, 1)
        inputs = (torch.ones(*shape), torch.tensor([[[[2.0]]]] * 4))
        baselines = (torch.zeros(*shape), torch.zeros(*shape))

        model = ReLUDeepLiftModel()
        explainer = DeepLiftShap(model)
        self._deeplift_assert(model, explainer, inputs, baselines)
    def test_relu_deepliftshap_batch_4D_input_wo_mutliplying_by_inputs(self) -> None:
        """DeepLiftShap with ``multiply_by_inputs=False`` on 4-D batched inputs.

        Inputs have shape (4, 1, 1, 1) (all ones and all twos) with zero
        baselines. With input multiplication disabled the attributions are the
        multipliers, expected to be 2.0 for the first input and 0.5 for the
        second at every batch position.
        """
        x1 = torch.ones(4, 1, 1, 1)
        x2 = torch.tensor([[[[2.0]]]] * 4)

        b1 = torch.zeros(4, 1, 1, 1)
        b2 = torch.zeros(4, 1, 1, 1)

        inputs = (x1, x2)
        baselines = (b1, b2)

        model = ReLUDeepLiftModel()
        attr = DeepLiftShap(model, multiply_by_inputs=False).attribute(
            inputs, baselines
        )
        # Attributions come back with the same shape as their inputs,
        # (4, 1, 1, 1). Build the expected tensors with that shape so the check
        # does not depend on the comparison helper squeezing away singleton
        # dimensions (a (4, 1) expectation fails a shape-strict comparison).
        assertTensorAlmostEqual(self, attr[0], 2 * torch.ones(4, 1, 1, 1))
        assertTensorAlmostEqual(self, attr[1], 0.5 * torch.ones(4, 1, 1, 1))
    def test_relu_deeplift_exact_match(self) -> None:
        """DeepLift on a single example, checked against exact expected values.

        With zero baselines, the attributions for inputs 1.0 and 2.0 are
        expected to be exactly 2.0 and 1.0, and the returned convergence delta
        is expected to be exactly 0.0.
        """
        inputs = tuple(
            torch.tensor([value], requires_grad=True) for value in (1.0, 2.0)
        )
        baselines = tuple(
            torch.tensor([0.0], requires_grad=True) for _ in range(2)
        )

        explainer = DeepLift(ReLUDeepLiftModel())
        (attr_x1, attr_x2), delta = explainer.attribute(
            inputs, baselines, return_convergence_delta=True
        )

        self.assertEqual(attr_x1[0], 2.0)
        self.assertEqual(attr_x2[0], 1.0)
        self.assertEqual(delta[0], 0.0)