def test_equil_methods(dtype, device, method):
    """Solve y = model(y) with the requested solver method and verify the
    result is a fixed point of the model."""
    torch.manual_seed(100)
    random.seed(100)
    nr = 3
    nbatch = 2
    default_fwd_options = {
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    # every listed method currently shares the same solver options;
    # indexing the table still raises KeyError for an unknown method name
    options = dict.fromkeys(
        ["broyden1", "broyden2", "linearmixing"], default_fwd_options)[method]

    # random linear operator plus per-batch diagonal and (zero) bias,
    # all registered as parameters so gradients could flow if needed
    amat = torch.nn.Parameter(
        (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(
        torch.zeros((nbatch, nr)).to(dtype).requires_grad_())
    y0 = torch.randn((nbatch, nr)).to(dtype)

    fwd_options = dict(options, method=method)
    model = DummyModule(amat, addx=False)
    model.set_diag_bias(diag, bias)
    y = equilibrium(model.forward, y0, **fwd_options)

    # at equilibrium, feeding the solution back through the model
    # must reproduce it
    f = model.forward(y)
    assert torch.allclose(y, f)
def test_equil(dtype, device):
    """Fixed-point solve plus first- and second-order gradient checks for
    both dummy model classes."""
    torch.manual_seed(100)
    random.seed(100)
    nr = 3
    nbatch = 2
    fwd_options = {
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    for clss in [DummyModule, DummyNNModule]:
        # fresh random operator, diagonal, and zero bias for each class
        A = torch.nn.Parameter(
            (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
        diag = torch.nn.Parameter(
            torch.randn((nbatch, nr)).to(dtype).requires_grad_())
        bias = torch.nn.Parameter(
            torch.zeros((nbatch, nr)).to(dtype).requires_grad_())
        y0 = torch.randn((nbatch, nr)).to(dtype)

        model = clss(A, addx=False)
        model.set_diag_bias(diag, bias)
        y = equilibrium(model.forward, y0, **fwd_options)

        # the solution must be a fixed point of the model
        f = model.forward(y)
        assert torch.allclose(y, f)

        def getloss(A, y0, diag, bias):
            # rebuild the model inside the closure so autograd sees the
            # dependence on all four leaf tensors
            mdl = clss(A, addx=False)
            mdl.set_diag_bias(diag, bias)
            return equilibrium(mdl.forward, y0, **fwd_options)

        gradcheck(getloss, (A, y0, diag, bias))
        gradgradcheck(getloss, (A, y0, diag, bias))
def getloss(A, y0, diag, bias):
    """Rebuild the model from the leaf tensors and return the equilibrium
    solution, so autograd can differentiate through the solve.

    NOTE(review): `clss`, `bck_options`, and `fwd_options` are taken from
    the enclosing scope.
    """
    mdl = clss(A, addx=False)
    mdl.set_diag_bias(diag, bias)
    return equilibrium(mdl.forward, y0,
                       bck_options=bck_options, **fwd_options)
def test_equil(dtype, device, clss):
    """Fixed-point solve with explicit backward (linear-solver) options,
    plus gradient checks where they are affordable."""
    torch.manual_seed(100)
    random.seed(100)
    nr = 2
    nbatch = 2
    fwd_options = {
        "method": "broyden1",
        "f_tol": 1e-12,
        "alpha": -0.5,
    }
    bck_options = {
        "method": "cg",
    }

    # random linear operator, per-batch diagonal, and zero bias
    A = torch.nn.Parameter(
        (torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(
        torch.zeros((nbatch, nr)).to(dtype).requires_grad_())
    y0 = torch.randn((nbatch, nr)).to(dtype)

    model = clss(A, addx=False)
    model.set_diag_bias(diag, bias)
    y = equilibrium(model.forward, y0,
                    bck_options=bck_options, **fwd_options)

    # the solution must reproduce itself through the model
    f = model.forward(y)
    assert torch.allclose(y, f)

    def getloss(A, y0, diag, bias):
        # rebuild the model inside the closure so autograd sees the
        # dependence on all four leaf tensors
        mdl = clss(A, addx=False)
        mdl.set_diag_bias(diag, bias)
        return equilibrium(mdl.forward, y0,
                           bck_options=bck_options, **fwd_options)

    # only check for real numbers or complex with DummyModule to save time
    if not (torch.is_complex(y0) and clss is not DummyModule):
        gradcheck(getloss, (A, y0, diag, bias))
        gradgradcheck(getloss, (A, y0, diag, bias))
def getloss(a):
    """Return the equilibrium solution of a model built from ``a``.

    NOTE(review): ``clss``, ``y0``, and ``fwd_options`` come from the
    enclosing scope.
    """
    mdl = clss(a)
    return equilibrium(mdl.forward, y0, **fwd_options)