def test_quad(dtype, device): torch.manual_seed(100) random.seed(100) nr = 2 fwd_options = { "method": "leggauss", "n": 100, } a = torch.nn.Parameter(torch.rand((nr,), dtype=dtype, device=device).requires_grad_()) b = torch.nn.Parameter(torch.randn((nr,), dtype=dtype, device=device).requires_grad_()) c = torch.randn((nr,), dtype=dtype, device=device).requires_grad_() xl = torch.zeros((1,), dtype=dtype, device=device).requires_grad_() xu = (torch.ones ((1,), dtype=dtype, device=device) * 0.5).requires_grad_() for clss in [IntegrationModule, IntegrationNNModule]: module = clss(a, b) y = quad(module.forward, xl, xu, params=(c,), **fwd_options) ytrue = (torch.sin(a * xu + b * c) - torch.sin(a * xl + b * c)) / a assert torch.allclose(y, ytrue) def getloss(a, b, c, xl, xu): module = clss(a, b) y = quad(module.forward, xl, xu, params=(c,), **fwd_options) return y gradcheck (getloss, (a, b, c, xl, xu)) gradgradcheck(getloss, (a, b, c, xl, xu)) # check if not all parameters require grad gradcheck (getloss, (a, b.detach(), c, xl, xu))
def test_quad_multi(dtype, device): torch.manual_seed(100) random.seed(100) nr = 4 fwd_options = { "method": "leggauss", "n": 100, } a = torch.nn.Parameter( torch.rand((nr, ), dtype=dtype, device=device).requires_grad_()) b = torch.nn.Parameter( torch.randn((nr, ), dtype=dtype, device=device).requires_grad_()) c = torch.randn((nr, ), dtype=dtype, device=device).requires_grad_() xl = torch.zeros((1, ), dtype=dtype, device=device).requires_grad_() xu = (torch.ones((1, ), dtype=dtype, device=device) * 0.5).requires_grad_() for clss in [IntegrationMultiModule, IntegrationNNMultiModule]: module = clss(a, b) y = quad(module.forward, xl, xu, params=(c, ), fwd_options=fwd_options) ytrue0 = (torch.sin(a * xu + b * c) - torch.sin(a * xl + b * c)) / a ytrue1 = (-torch.cos(a * xu + b * c) + torch.cos(a * xl + b * c)) / a assert len(y) == 2 assert torch.allclose(y[0], ytrue0) assert torch.allclose(y[1], ytrue1)
def getloss(a, b, c, xl, xu): module = clss(a, b) y = quad(module.forward, xl, xu, params=(c, ), fwd_options=fwd_options) return y
def get_loss(w): module = IntegrationInfModule(w) y = quad(module.forward, xl, xu, params=[], fwd_options=fwd_options) return y