Code example #1
0
    def propagate_relevance(self,
                            module,
                            relevance_input,
                            relevance_output,
                            lrp_method,
                            lrp_params=None):
        """Redistribute relevance through a batch-norm layer (LRP).

        With the 'identity' rule the incoming relevance is passed through
        unchanged. Otherwise the normalization is folded into a per-channel
        affine transform y = w*x + b and the relevance is split between the
        scaled input term |w*x| and the effective bias term |b|.

        Returns the new relevance for the layer input plus the untouched
        second and third entries of ``relevance_input``.
        """
        # TODO: how is batchnorm handled by default in innvestigate?
        if lrp_method == 'identity':
            R = relevance_output[0]
        else:
            x = module.input[0]
            running_mean = module._buffers['running_mean']
            running_var = module._buffers['running_var']
            gamma = module._parameters['weight']
            beta = module._parameters['bias']

            # Fold normalization + affine into y = w*x + b, broadcast
            # per channel over the remaining (presumably spatial) dims.
            std = torch.sqrt(running_var + module.eps)
            w = (gamma / std)[:, None, None]
            b = (beta - running_mean * gamma / std)[:, None, None]
            weighted_input = x * w

            # Share relevance proportionally to |w*x| versus |b|.
            abs_wx = torch.abs(weighted_input)
            ratio = util.safe_divide(abs_wx, abs_wx + torch.abs(b))
            R = ratio * relevance_output[0]

        # Sanity checks: relevance must stay finite and non-degenerate.
        assert not torch.isnan(R.sum())
        assert not torch.isinf(R.sum())
        assert R.sum() != 0
        module.zero_grad()
        return R, relevance_input[1], relevance_input[2]
Code example #2
0
 def propagate_relevance(self, module, relevance_input, relevance_output,
                         lrp_method, lrp_params=None):
     """Gradient-based LRP redistribution through this layer.

     Re-runs the layer on a detached copy of its recorded input with
     autograd enabled, back-propagates the (safely divided) output
     relevance through the cloned layer, and returns the input
     relevance R = x * dZ/dx as a 1-tuple.
     """
     layer_input = module.input[0]
     layer_copy = self._clone_module(module)
     with torch.enable_grad():
         X = layer_input.clone().detach().requires_grad_(True)
         Z = layer_copy(X)
         # Stabilized ratio of output relevance to the layer output.
         S = util.safe_divide(relevance_output[0].clone().detach(), Z)
         (grad_wrt_input,) = torch.autograd.grad(Z, X, grad_outputs=S)
         R = X * grad_wrt_input
     # Relevance must remain finite.
     assert not torch.isnan(R.sum())
     assert not torch.isinf(R.sum())
     module.zero_grad()
     return (R, )