def propagate_relevance(self, module, relevance_input, relevance_output, lrp_method, lrp_params=None):
    """Propagate relevance backwards through a batch-normalisation module.

    With ``lrp_method == 'identity'`` the incoming relevance
    ``relevance_output[0]`` is passed through unchanged. For any other
    method the layer is treated as the per-channel affine map ``x * w + b``
    built from the running statistics, and the relevance is scaled by
    ``|x * w| / (|x * w| + |b|)``.

    Args:
        module: Module exposing ``input``, ``eps``, ``zero_grad()``, the
            buffers ``running_mean`` / ``running_var`` and the parameters
            ``weight`` / ``bias``. ``weight`` / ``bias`` may be ``None``
            (as for ``affine=False``), in which case scale 1 and shift 0
            are used.
        relevance_input: Indexable; entries 1 and 2 are returned untouched
            (presumably the weight/bias slots of the hook -- confirm
            against the caller).
        relevance_output: Indexable; entry 0 is the relevance arriving from
            the following layer.
        lrp_method: ``'identity'`` selects the pass-through rule; every
            other value selects the |x*w| / (|x*w| + |b|) rule.
        lrp_params: Unused here; kept for interface compatibility.

    Returns:
        Tuple ``(R, relevance_input[1], relevance_input[2])``.

    Raises:
        AssertionError: If the summed relevance is NaN, infinite or zero.
    """
    # TODO: how is batchnorm handled by default in innvestigate?
    if lrp_method == 'identity':
        R = relevance_output[0]
    else:
        input_ = module.input[0]
        mean = module._buffers['running_mean']
        var = module._buffers['running_var']
        gamma = module._parameters['weight']
        beta = module._parameters['bias']

        # Without affine parameters the layer only normalises, which is
        # the same as a scale of 1 and a shift of 0.
        if gamma is None:
            gamma = torch.ones_like(var)
        if beta is None:
            beta = torch.zeros_like(var)

        # Shared by both the effective weight and the effective bias.
        std = torch.sqrt(var + module.eps)

        # [:, None, None] turns the per-channel vectors into (C, 1, 1) so
        # they broadcast over the trailing two dimensions of the input.
        w = (gamma / std)[:, None, None]
        b = (beta - (mean * gamma) / std)[:, None, None]

        abs_xw = torch.abs(input_ * w)
        R = util.safe_divide(abs_xw, abs_xw + torch.abs(b)) * relevance_output[0]

    # Sanity checks on the total relevance; reduce once and reuse.
    # NOTE(review): placement of these checks after the if/else (rather
    # than inside the else branch) is reconstructed -- confirm.
    total = R.sum()
    assert not torch.isnan(total), "relevance sum is NaN"
    assert not torch.isinf(total), "relevance sum is infinite"
    assert total != 0, "relevance sum is zero"

    module.zero_grad()
    return R, relevance_input[1], relevance_input[2]
def propagate_relevance(self, module, relevance_input, relevance_output, lrp_method, lrp_params=None):
    """Propagate relevance through ``module`` using a gradient-times-input pass.

    A clone of the module (from ``self._clone_module``) is run forward on a
    detached copy ``X`` of the recorded input to obtain ``Z``. The incoming
    relevance is divided by ``Z`` (``util.safe_divide``) to give ``S``,
    ``Z.backward(S)`` fills ``X.grad``, and the result is ``X * X.grad``.

    Args:
        module: Module exposing ``input`` (entry 0 is used) and
            ``zero_grad()``.
        relevance_input: Unused here; kept for interface compatibility.
        relevance_output: Indexable; entry 0 is the relevance arriving from
            the following layer.
        lrp_method: Unused here; kept for interface compatibility.
        lrp_params: Unused here; kept for interface compatibility.

    Returns:
        One-element tuple ``(R,)`` holding the relevance for the input.

    Raises:
        AssertionError: If the summed relevance is NaN or infinite.
    """
    input_ = module.input[0]
    module_clone = self._clone_module(module)

    # Gradient tracking is switched on explicitly so the backward call works
    # even when the caller runs with gradients disabled.
    with torch.enable_grad():
        X = input_.clone().detach().requires_grad_(True)
        Z = module_clone(X)
        S = util.safe_divide(relevance_output[0].clone().detach(), Z)
        Z.backward(S)
        # NOTE(review): computed inside the enable_grad block, so R keeps
        # autograd history when grad mode is on -- confirm this is intended.
        R = X * X.grad

    # Sanity checks on the total relevance; reduce once and reuse.
    total = R.sum()
    assert not torch.isnan(total), "relevance sum is NaN"
    assert not torch.isinf(total), "relevance sum is infinite"

    # NOTE(review): backward ran through the clone; whether the clone shares
    # parameters with `module` is not visible here -- confirm that zeroing
    # the original module's gradients is the intended clean-up.
    module.zero_grad()
    return (R, )