def test_rootfinder_with_params(dtype, device, bias_is_tensor):
    """Exercise ``rootfinder`` with parameters passed explicitly to the function.

    The function being solved is ``DummyModuleExplicit(addx=True).forward``,
    which receives ``(A, diag, bias)`` through the third positional argument
    of ``rootfinder``. The test asserts that the residual evaluated at the
    returned solution is close to zero, then runs ``gradcheck`` and
    ``gradgradcheck`` on the solve with respect to ``(y0, A, diag, bias)``.

    Parameters
    ----------
    dtype
        The dtype every generated tensor is cast to.
    device
        NOTE(review): accepted but not referenced anywhere in this body.
    bias_is_tensor
        When truthy, ``bias`` is a zero tensor that requires grad; otherwise
        it is the Python float ``0.0``.
    """
    # Seed both RNGs so the generated problem is reproducible.
    torch.manual_seed(100)
    random.seed(100)

    n_rows, n_batch = 3, 2
    batch_shape = (n_batch, n_rows)
    solver_kwargs = dict(method="broyden1", f_tol=1e-9, alpha=-0.5)
    module_cls = DummyModuleExplicit

    # The order of the ``torch.randn`` draws (A, diag, y0) is significant:
    # with a fixed seed, reordering them would change the generated values.
    A = torch.randn((n_rows, n_rows))
    A = (A * 0.5).to(dtype).requires_grad_()
    diag = torch.randn(batch_shape).to(dtype).requires_grad_()
    bias = (
        torch.zeros(batch_shape).to(dtype).requires_grad_()
        if bias_is_tensor
        else 0.0
    )
    y0 = torch.randn(batch_shape).to(dtype)

    # Forward check: the residual at the solution must vanish.
    params = (A, diag, bias)
    module = module_cls(addx=True)
    root = rootfinder(module.forward, y0, params, **solver_kwargs)
    residual = module.forward(root, *params)
    assert torch.allclose(residual * 0, residual)

    def getloss(y0, A, diag, bias):
        # A fresh module is built on every call, as in the forward check.
        fresh = module_cls(addx=True)
        return rootfinder(fresh.forward, y0, (A, diag, bias), **solver_kwargs)

    # Backward checks: first- and second-order gradients of the solve.
    grad_inputs = (y0, A, diag, bias)
    gradcheck(getloss, grad_inputs)
    gradgradcheck(getloss, grad_inputs)
def test_rootfinder_methods(dtype, device, method):
    """Check that ``rootfinder`` reaches a zero residual for a given method.

    A ``DummyModule`` is built from a random ``A`` with ``addx=True`` and is
    given ``diag`` and ``bias`` through ``set_diag_bias``. Its ``forward`` is
    solved starting from a random ``y0``, and the residual at the returned
    solution is asserted to be close to zero.

    Parameters
    ----------
    dtype
        Overwritten with ``torch.float64`` inside the body, so the incoming
        value has no effect.
    device
        NOTE(review): accepted but not referenced anywhere in this body.
    method
        Name of the solver method; must be one of ``"broyden1"``,
        ``"broyden2"`` or ``"linearmixing"`` (any other value raises
        ``KeyError`` from the options lookup).
    """
    # Seed both RNGs so the generated problem is reproducible.
    torch.manual_seed(100)
    random.seed(100)
    dtype = torch.float64

    n_rows, n_batch = 3, 2
    batch_shape = (n_batch, n_rows)

    shared_options = {"f_tol": 1e-9, "alpha": -0.5}
    # list the methods and the options here
    options_by_method = {
        "broyden1": shared_options,
        "broyden2": shared_options,
        "linearmixing": shared_options,
    }
    # Looked up before any tensor is drawn, so an unknown method fails early.
    solver_kwargs = {**options_by_method[method], "method": method}

    def as_param(tensor):
        # Cast to the working dtype and wrap as a trainable parameter.
        return torch.nn.Parameter(tensor.to(dtype).requires_grad_())

    # The order of the ``torch.randn`` draws (A, diag, y0) is significant:
    # with a fixed seed, reordering them would change the generated values.
    A = as_param(torch.randn((n_rows, n_rows)) * 0.5)
    diag = as_param(torch.randn(batch_shape))
    bias = as_param(torch.zeros(batch_shape))
    y0 = torch.randn(batch_shape).to(dtype)

    module = DummyModule(A, addx=True)
    module.set_diag_bias(diag, bias)
    root = rootfinder(module.forward, y0, **solver_kwargs)

    # The residual at the solution must vanish.
    residual = module.forward(root)
    assert torch.allclose(residual * 0, residual)
def test_rootfinder(dtype, device, clss):
    """Check the solution and gradients of ``rootfinder`` for a module class.

    An instance of ``clss`` is built from a random ``A`` with ``addx=True``
    and is given ``diag`` and ``bias`` through ``set_diag_bias``. Its
    ``forward`` is solved with the ``"broyden1"`` method starting from a
    random ``y0``. The residual at the returned solution is asserted to be
    close to zero, then ``gradcheck`` and ``gradgradcheck`` are run on the
    solve with respect to ``(A, y0, diag, bias)``.

    Parameters
    ----------
    dtype
        The dtype every generated tensor is cast to.
    device
        NOTE(review): accepted but not referenced anywhere in this body.
    clss
        Callable invoked as ``clss(A, addx=True)``; the result must provide
        ``set_diag_bias`` and ``forward``.
    """
    # Seed both RNGs so the generated problem is reproducible.
    torch.manual_seed(100)
    random.seed(100)

    n_rows, n_batch = 3, 2
    batch_shape = (n_batch, n_rows)
    solver_kwargs = dict(method="broyden1", f_tol=1e-9, alpha=-0.5)

    def as_param(tensor):
        # Cast to the working dtype and wrap as a trainable parameter.
        return torch.nn.Parameter(tensor.to(dtype).requires_grad_())

    # The order of the ``torch.randn`` draws (A, diag, y0) is significant:
    # with a fixed seed, reordering them would change the generated values.
    A = as_param(torch.randn((n_rows, n_rows)) * 0.5)
    diag = as_param(torch.randn(batch_shape))
    bias = as_param(torch.zeros(batch_shape))
    y0 = torch.randn(batch_shape).to(dtype)

    # Forward check: the residual at the solution must vanish.
    module = clss(A, addx=True)
    module.set_diag_bias(diag, bias)
    root = rootfinder(module.forward, y0, **solver_kwargs)
    residual = module.forward(root)
    assert torch.allclose(residual * 0, residual)

    def getloss(A, y0, diag, bias):
        # A fresh module is built on every call from the supplied inputs.
        fresh = clss(A, addx=True)
        fresh.set_diag_bias(diag, bias)
        return rootfinder(fresh.forward, y0, **solver_kwargs)

    # Backward checks: first- and second-order gradients of the solve.
    grad_inputs = (A, y0, diag, bias)
    gradcheck(getloss, grad_inputs)
    gradgradcheck(getloss, grad_inputs)
def get_intersection(r0, v, fcn):
    """Solve for the point where each ray meets the surface described by ``fcn``.

    The unknown ``y`` packs the surface input in ``y[..., :-1]`` and the
    multiplier applied to ``v`` in ``y[..., -1:]``. ``rootfinder`` is asked to
    make ``r0 + v * y[..., -1:] - fcn(y[..., :-1])`` vanish, starting from an
    all-zero guess shaped like ``v``.

    Parameters
    ----------
    r0
        (nbatch, ndim) initial point of the rays.
    v
        (nbatch, ndim) the direction of travel of the rays.
    fcn
        A function that takes (nbatch, ndim-1) and outputs (nbatch, ndim).

    Returns
    -------
    tuple
        ``(y[..., :-1], y[..., -1:])`` of the solution, i.e. tensors of shape
        (nbatch, ndim-1) and (nbatch, 1).
    """

    @xt.make_sibling(fcn)
    def rootfinder_fcn(y, r0, v):
        # Split the unknown into the surface input and the ray multiplier.
        surface_input, ray_step = y[..., :-1], y[..., -1:]
        on_surface = fcn(surface_input)  # (nbatch, ndim)
        on_ray = r0 + v * ray_step  # (nbatch, ndim)
        return on_ray - on_surface

    initial_guess = torch.zeros_like(v)
    solution = rootfinder(rootfinder_fcn, initial_guess, params=(r0, v))

    # (nbatch, ndim-1) and (nbatch, 1)
    return solution[..., :-1], solution[..., -1:]
def test_rootfinder(dtype, device, clss):
    """Check the solution and, where affordable, the gradients of ``rootfinder``.

    An instance of ``clss`` is built from a random ``A`` with ``addx=True``
    and is given ``diag`` and ``bias`` through ``set_diag_bias``. Its
    ``forward`` is solved with the ``"broyden1"`` method starting from a
    random ``y0``, and the residual at the returned solution is asserted to be
    close to zero. ``gradcheck`` and ``gradgradcheck`` are then run with
    respect to ``(A, y0, diag, bias)``, except when ``y0`` is complex and
    ``clss`` is not ``DummyModule``.

    Parameters
    ----------
    dtype
        The dtype every generated tensor is cast to.
    device
        NOTE(review): accepted but not referenced anywhere in this body.
    clss
        Callable invoked as ``clss(A, addx=True)``; the result must provide
        ``set_diag_bias`` and ``forward``.
    """
    # Seed both RNGs so the generated problem is reproducible.
    torch.manual_seed(100)
    random.seed(100)

    n_rows, n_batch = 2, 2
    batch_shape = (n_batch, n_rows)
    solver_kwargs = dict(method="broyden1", f_tol=1e-9, alpha=-0.5)

    def as_param(tensor):
        # Cast to the working dtype and wrap as a trainable parameter.
        return torch.nn.Parameter(tensor.to(dtype).requires_grad_())

    # The order of the ``torch.randn`` draws (A, diag, y0) is significant:
    # with a fixed seed, reordering them would change the generated values.
    A = as_param(torch.randn((n_rows, n_rows)) * 0.5)
    diag = as_param(torch.randn(batch_shape))
    bias = as_param(torch.zeros(batch_shape))
    y0 = torch.randn(batch_shape).to(dtype)

    # Forward check: the residual at the solution must vanish.
    module = clss(A, addx=True)
    module.set_diag_bias(diag, bias)
    root = rootfinder(module.forward, y0, **solver_kwargs)
    residual = module.forward(root)
    assert torch.allclose(residual * 0, residual)

    def getloss(A, y0, diag, bias):
        # A fresh module is built on every call from the supplied inputs.
        fresh = clss(A, addx=True)
        fresh.set_diag_bias(diag, bias)
        return rootfinder(fresh.forward, y0, **solver_kwargs)

    # only check for real numbers or complex with DummyModule to save time
    if torch.is_complex(y0) and clss is not DummyModule:
        return

    grad_inputs = (A, y0, diag, bias)
    gradcheck(getloss, grad_inputs)
    gradgradcheck(getloss, grad_inputs)
def getloss(y0, A, diag, bias):
    """Return the ``rootfinder`` solution with parameters passed explicitly.

    A new ``clss(addx=True)`` instance is created on each call and its
    ``forward`` is solved starting from ``y0``, with ``(A, diag, bias)``
    handed to ``rootfinder`` as its third positional argument.

    ``clss``, ``rootfinder`` and ``fwd_options`` are not defined here; they
    are resolved from the surrounding scope.
    """
    module = clss(addx=True)
    params = (A, diag, bias)
    return rootfinder(module.forward, y0, params, **fwd_options)
def getloss(A, y0, diag, bias):
    """Return the ``rootfinder`` solution for a module configured from the inputs.

    A new ``clss(A, addx=True)`` instance is created on each call, given
    ``diag`` and ``bias`` through ``set_diag_bias``, and its ``forward`` is
    solved starting from ``y0``.

    ``clss``, ``rootfinder`` and ``fwd_options`` are not defined here; they
    are resolved from the surrounding scope.
    """
    module = clss(A, addx=True)
    module.set_diag_bias(diag, bias)
    return rootfinder(module.forward, y0, **fwd_options)
def getloss(a):
    """Return the ``rootfinder`` solution for a module built from ``a``.

    A new ``clss(a)`` instance is created on each call and its ``forward`` is
    solved starting from ``y0``.

    ``clss``, ``rootfinder``, ``y0`` and ``fwd_options`` are not defined here;
    they are resolved from the surrounding scope.
    """
    module = clss(a)
    return rootfinder(module.forward, y0, **fwd_options)