def test_softmax_classification_batch_multi_target(self) -> None:
    """Run DeepLift on a 3-example batch with one target index per example.

    The inputs are consecutive values 0..119 laid out as a (3, 40) tensor.
    The single baseline row is 1..40, shaped (1, 40). Every example
    targets class index 2, passed as a length-3 tensor. The shared
    ``softmax_classification`` helper performs the actual checks.
    """
    num_features = 40
    batch_size = 3

    flat = torch.arange(0.0, float(batch_size * num_features), requires_grad=True)
    batch_inputs = flat.reshape(batch_size, num_features)
    # One baseline row with the same feature width as the inputs.
    baseline_row = torch.arange(1.0, num_features + 1).unsqueeze(0)
    per_example_targets = torch.tensor([2] * batch_size)

    net = SoftmaxDeepLiftModel(num_features, 20, 10)
    explainer = DeepLift(net)
    self.softmax_classification(
        net, explainer, batch_inputs, baseline_row, per_example_targets
    )
def test_softmax_classification_batch_multi_baseline(self) -> None:
    """Run DeepLiftShap on a 2-example batch against several random baselines.

    The inputs are consecutive values 0..79 laid out as a (2, 40) tensor.
    The baselines are a (5, 40) tensor drawn from ``torch.randn``. The
    target is the scalar class index 2. The shared
    ``softmax_classification`` helper performs the actual checks.
    """
    num_in = 40
    num_baselines = 5
    # Renamed from ``input`` to avoid shadowing the builtin.
    inputs = torch.arange(0.0, num_in * 2.0, requires_grad=True).reshape(2, num_in)
    # Use num_in rather than a repeated literal 40 so the baseline width
    # always matches the input feature width.
    # NOTE(review): no seed is set in this method, so these baselines differ
    # between runs unless a seed is fixed elsewhere (e.g. setUp) — confirm.
    baselines = torch.randn(num_baselines, num_in)
    model = SoftmaxDeepLiftModel(num_in, 20, 10)
    dl = DeepLiftShap(model)
    self.softmax_classification(model, dl, inputs, baselines, torch.tensor(2))
def test_softmax_classification_zero_baseline(self) -> None:
    """Run DeepLift on a single example against a scalar zero baseline.

    The input is consecutive values 0..19 with a leading batch dimension of
    1, giving shape (1, 20). The baseline is the Python float ``0.0``. The
    target is the scalar class index 2. The shared
    ``softmax_classification`` helper performs the actual checks.
    """
    num_in = 20
    # Renamed from ``input`` to avoid shadowing the builtin.
    inputs = torch.arange(0.0, num_in * 1.0, requires_grad=True).unsqueeze(0)
    # A plain scalar baseline, presumably broadcast against the inputs by
    # DeepLift — TODO confirm against the attribution API.
    baselines = 0.0
    model = SoftmaxDeepLiftModel(num_in, 20, 10)
    dl = DeepLift(model)
    self.softmax_classification(model, dl, inputs, baselines, torch.tensor(2))