Ejemplo n.º 1
0
def test_rootfinder_with_params(dtype, device, bias_is_tensor):
    """Check ``rootfinder`` when the model parameters are passed explicitly.

    A ``DummyModuleExplicit`` is built with ``addx=True`` and ``A``, ``diag``
    and ``bias`` are handed to ``rootfinder`` as its third positional
    argument. The test asserts that the model evaluated at the returned
    point is close to zero, then runs ``gradcheck`` and ``gradgradcheck``
    on the solve.

    Args:
        dtype: dtype every tensor created here is cast to.
        device: accepted from the caller; not used in the body.
        bias_is_tensor: when true, ``bias`` is a zero tensor that requires
            grad; otherwise it is the plain float ``0.0``.
    """
    # Seed both torch and the stdlib RNG so the random problem is repeatable.
    torch.manual_seed(100)
    random.seed(100)

    nr, nbatch = 3, 2
    batch_shape = (nbatch, nr)
    fwd_options = dict(method="broyden1", f_tol=1e-9, alpha=-0.5)
    clss = DummyModuleExplicit

    # The random draws are kept in the order A -> diag -> y0.
    A = (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_()
    diag = torch.randn(batch_shape).to(dtype).requires_grad_()
    bias = (
        torch.zeros(batch_shape).to(dtype).requires_grad_()
        if bias_is_tensor
        else 0.0
    )
    y0 = torch.randn(batch_shape).to(dtype)

    # Forward check: the model evaluated at the solution must be ~0.
    model = clss(addx=True)
    params = (A, diag, bias)
    root = rootfinder(model.forward, y0, params, **fwd_options)
    residual = model.forward(root, *params)
    assert torch.allclose(residual * 0, residual)

    def solve(y0, A, diag, bias):
        # Build a fresh module on every call and return the solution itself.
        module = clss(addx=True)
        return rootfinder(module.forward, y0, (A, diag, bias), **fwd_options)

    grad_inputs = (y0, A, diag, bias)
    gradcheck(solve, grad_inputs)
    gradgradcheck(solve, grad_inputs)
Ejemplo n.º 2
0
def test_rootfinder_methods(dtype, device, method):
    """Check that ``rootfinder`` reaches a root with the requested method.

    Only the forward result is verified: the model evaluated at the
    returned point must be close to zero.

    Args:
        dtype: overwritten with ``torch.float64`` inside the body.
        device: accepted from the caller; not used in the body.
        method: name of the method; must be one of the keys of the
            per-method option table below, otherwise a ``KeyError`` occurs.
    """
    torch.manual_seed(100)
    random.seed(100)
    dtype = torch.float64  # the incoming dtype is ignored on purpose here

    nr, nbatch = 3, 2
    batch_shape = (nbatch, nr)

    default_fwd_options = dict(f_tol=1e-9, alpha=-0.5)
    # list the methods and the options here
    options_per_method = {
        "broyden1": default_fwd_options,
        "broyden2": default_fwd_options,
        "linearmixing": default_fwd_options,
    }
    options = options_per_method[method]

    # The random draws are kept in the order A -> diag -> y0.
    A = torch.nn.Parameter(
        (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn(batch_shape).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(
        torch.zeros(batch_shape).to(dtype).requires_grad_())
    y0 = torch.randn(batch_shape).to(dtype)

    # Same options as selected above, with the method name added.
    fwd_options = dict(options, method=method)

    model = DummyModule(A, addx=True)
    model.set_diag_bias(diag, bias)
    root = rootfinder(model.forward, y0, **fwd_options)
    residual = model.forward(root)
    assert torch.allclose(residual * 0, residual)
Ejemplo n.º 3
0
def test_rootfinder(dtype, device, clss):
    """Check ``rootfinder`` and its gradients for the module class ``clss``.

    The module is built as ``clss(A, addx=True)`` and receives ``diag`` and
    ``bias`` through ``set_diag_bias``. The test asserts that the module
    evaluated at the returned point is close to zero, then runs
    ``gradcheck`` and ``gradgradcheck`` on the solve.

    Args:
        dtype: dtype every tensor created here is cast to.
        device: accepted from the caller; not used in the body.
        clss: class of the module under test.
    """
    torch.manual_seed(100)
    random.seed(100)

    nr, nbatch = 3, 2
    batch_shape = (nbatch, nr)
    fwd_options = dict(method="broyden1", f_tol=1e-9, alpha=-0.5)

    # The random draws are kept in the order A -> diag -> y0.
    A = torch.nn.Parameter(
        (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn(batch_shape).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(
        torch.zeros(batch_shape).to(dtype).requires_grad_())
    y0 = torch.randn(batch_shape).to(dtype)

    # Forward check: the module evaluated at the solution must be ~0.
    model = clss(A, addx=True)
    model.set_diag_bias(diag, bias)
    root = rootfinder(model.forward, y0, **fwd_options)
    residual = model.forward(root)
    assert torch.allclose(residual * 0, residual)

    def solve(A, y0, diag, bias):
        # Rebuild the module from the given tensors and return the solution.
        module = clss(A, addx=True)
        module.set_diag_bias(diag, bias)
        return rootfinder(module.forward, y0, **fwd_options)

    grad_inputs = (A, y0, diag, bias)
    gradcheck(solve, grad_inputs)
    gradgradcheck(solve, grad_inputs)
Ejemplo n.º 4
0
def get_intersection(r0, v, fcn):
    """Find where each ray coincides with the surface produced by ``fcn``.

    ``rootfinder`` is run on the difference between the ray position
    ``r0 + v * t`` and the surface position ``fcn(...)``, starting from an
    all-zero guess shaped like ``v``.

    Args:
        r0: (nbatch, ndim) initial point of the rays.
        v: (nbatch, ndim) the direction of travel of the rays.
        fcn: a function that takes (nbatch, ndim-1) and outputs
            (nbatch, ndim).

    Returns:
        A pair of slices of the solution: everything but the last component,
        (nbatch, ndim-1), and the last component alone, (nbatch, 1).
    """
    @xt.make_sibling(fcn)
    def rootfinder_fcn(y, r0, v):
        # The last component of y scales v; the rest is the input of fcn.
        on_surface = fcn(y[..., :-1])  # (nbatch, ndim)
        on_ray = r0 + v * y[..., -1:]  # (nbatch, ndim)
        return on_ray - on_surface

    solution = rootfinder(rootfinder_fcn, torch.zeros_like(v), params=(r0, v))
    # (nbatch, ndim-1) and (nbatch, 1)
    return solution[..., :-1], solution[..., -1:]
Ejemplo n.º 5
0
def test_rootfinder(dtype, device, clss):
    """Check ``rootfinder`` and, when enabled, its gradients for ``clss``.

    The module is built as ``clss(A, addx=True)`` and receives ``diag`` and
    ``bias`` through ``set_diag_bias``. The test asserts that the module
    evaluated at the returned point is close to zero. ``gradcheck`` and
    ``gradgradcheck`` are skipped when ``y0`` is complex and ``clss`` is not
    ``DummyModule``.

    Args:
        dtype: dtype every tensor created here is cast to.
        device: accepted from the caller; not used in the body.
        clss: class of the module under test.
    """
    torch.manual_seed(100)
    random.seed(100)

    nr, nbatch = 2, 2
    batch_shape = (nbatch, nr)
    fwd_options = dict(method="broyden1", f_tol=1e-9, alpha=-0.5)

    # The random draws are kept in the order A -> diag -> y0.
    A = torch.nn.Parameter(
        (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn(batch_shape).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(
        torch.zeros(batch_shape).to(dtype).requires_grad_())
    y0 = torch.randn(batch_shape).to(dtype)

    # Forward check: the module evaluated at the solution must be ~0.
    model = clss(A, addx=True)
    model.set_diag_bias(diag, bias)
    root = rootfinder(model.forward, y0, **fwd_options)
    residual = model.forward(root)
    assert torch.allclose(residual * 0, residual)

    def solve(A, y0, diag, bias):
        # Rebuild the module from the given tensors and return the solution.
        module = clss(A, addx=True)
        module.set_diag_bias(diag, bias)
        return rootfinder(module.forward, y0, **fwd_options)

    # only check for real numbers or complex with DummyModule to save time
    if torch.is_complex(y0) and clss is not DummyModule:
        return

    grad_inputs = (A, y0, diag, bias)
    gradcheck(solve, grad_inputs)
    gradgradcheck(solve, grad_inputs)
Ejemplo n.º 6
0
 def getloss(y0, A, diag, bias):
     """Return the ``rootfinder`` result for a freshly built ``clss`` module.

     ``A``, ``diag`` and ``bias`` are forwarded to ``rootfinder`` as one
     tuple. ``clss`` and ``fwd_options`` come from the enclosing scope,
     which is not part of this snippet.
     """
     module = clss(addx=True)
     explicit_params = (A, diag, bias)
     return rootfinder(module.forward, y0, explicit_params, **fwd_options)
Ejemplo n.º 7
0
 def getloss(A, y0, diag, bias):
     """Return the ``rootfinder`` result for a module rebuilt from the inputs.

     The module is ``clss(A, addx=True)`` with ``diag`` and ``bias`` set via
     ``set_diag_bias``. ``clss`` and ``fwd_options`` come from the enclosing
     scope, which is not part of this snippet.
     """
     module = clss(A, addx=True)
     module.set_diag_bias(diag, bias)
     return rootfinder(module.forward, y0, **fwd_options)
Ejemplo n.º 8
0
 def getloss(a):
     """Return the ``rootfinder`` result for the module ``clss(a)``.

     ``clss``, ``y0`` and ``fwd_options`` come from the enclosing scope,
     which is not part of this snippet.
     """
     module = clss(a)
     return rootfinder(module.forward, y0, **fwd_options)