def test_quad(dtype, device):
    """Check ``quad`` against a closed-form integral and its gradients.

    For each integrand class, ``quad`` is run over ``[xl, xu]`` with the
    ``"leggauss"`` method (n=100) and compared against
    ``(sin(a*xu + b*c) - sin(a*xl + b*c)) / a``.  First- and second-order
    gradients are then checked numerically.  A final gradcheck passes a
    detached ``b`` so that not every input requires grad.
    """
    torch.manual_seed(100)
    random.seed(100)
    nr = 2
    fwd_options = {
        "method": "leggauss",
        "n": 100,
    }
    tkw = dict(dtype=dtype, device=device)

    # Keep the draw order (rand, randn, randn) so the seeded values are unchanged.
    a = torch.nn.Parameter(torch.rand((nr,), **tkw).requires_grad_())
    b = torch.nn.Parameter(torch.randn((nr,), **tkw).requires_grad_())
    c = torch.randn((nr,), **tkw).requires_grad_()
    xl = torch.zeros((1,), **tkw).requires_grad_()
    xu = (torch.ones((1,), **tkw) * 0.5).requires_grad_()

    for module_cls in (IntegrationModule, IntegrationNNModule):
        module = module_cls(a, b)
        y = quad(module.forward, xl, xu, params=(c,), **fwd_options)

        phase_hi = a * xu + b * c
        phase_lo = a * xl + b * c
        expected = (torch.sin(phase_hi) - torch.sin(phase_lo)) / a
        assert torch.allclose(y, expected)

        # Called only within this iteration, so the closure over module_cls is safe.
        def loss_fn(pa, pb, pc, lo, hi):
            integrand = module_cls(pa, pb).forward
            return quad(integrand, lo, hi, params=(pc,), **fwd_options)

        inputs = (a, b, c, xl, xu)
        gradcheck(loss_fn, inputs)
        gradgradcheck(loss_fn, inputs)
        # b is detached here, so not every input requires grad.
        gradcheck(loss_fn, (a, b.detach(), c, xl, xu))
Example #2
0
def test_quad_multi(dtype, device):
    """Check ``quad`` on an integrand that returns two outputs.

    For each multi-output integrand class, ``quad`` over ``[xl, xu]`` must
    return exactly two results.  The first must match
    ``(sin(a*xu + b*c) - sin(a*xl + b*c)) / a`` and the second
    ``(-cos(a*xu + b*c) + cos(a*xl + b*c)) / a``.
    """
    torch.manual_seed(100)
    random.seed(100)
    nr = 4
    fwd_options = {"method": "leggauss", "n": 100}
    tkw = dict(dtype=dtype, device=device)

    # Keep the draw order (rand, randn, randn) so the seeded values are unchanged.
    a = torch.nn.Parameter(torch.rand((nr,), **tkw).requires_grad_())
    b = torch.nn.Parameter(torch.randn((nr,), **tkw).requires_grad_())
    c = torch.randn((nr,), **tkw).requires_grad_()
    xl = torch.zeros((1,), **tkw).requires_grad_()
    xu = (torch.ones((1,), **tkw) * 0.5).requires_grad_()

    for module_cls in (IntegrationMultiModule, IntegrationNNMultiModule):
        module = module_cls(a, b)
        y = quad(module.forward, xl, xu, params=(c,), fwd_options=fwd_options)

        phase_hi = a * xu + b * c
        phase_lo = a * xl + b * c
        expected_sin = (torch.sin(phase_hi) - torch.sin(phase_lo)) / a
        expected_cos = (-torch.cos(phase_hi) + torch.cos(phase_lo)) / a

        assert len(y) == 2
        assert torch.allclose(y[0], expected_sin)
        assert torch.allclose(y[1], expected_cos)
Example #3
0
 def getloss(a, b, c, xl, xu):
     """Integrate ``clss(a, b).forward`` over ``[xl, xu]`` with ``quad``.

     ``c`` is forwarded to the integrand through ``params``.  ``clss`` and
     ``fwd_options`` are free variables from an enclosing scope that is not
     visible in this excerpt.  Returns whatever ``quad`` returns.
     """
     integrand = clss(a, b).forward
     return quad(integrand, xl, xu, params=(c, ), fwd_options=fwd_options)
Example #4
0
 def get_loss(w):
     """Integrate ``IntegrationInfModule(w).forward`` over ``[xl, xu]`` with ``quad``.

     No extra params are passed to the integrand.  ``xl``, ``xu`` and
     ``fwd_options`` are free variables from an enclosing scope that is not
     visible in this excerpt.  Returns whatever ``quad`` returns.
     """
     integrand = IntegrationInfModule(w).forward
     return quad(integrand, xl, xu, params=[], fwd_options=fwd_options)