# Imports assumed by these tests (not part of the original excerpt). The
# module paths for minimize, hess, and get_pure_function follow the xitorch
# layout and may differ in this repository.
import random
from collections import defaultdict

import pytest
import torch
from torch.autograd import gradcheck, gradgradcheck
from xitorch.grad import hess
from xitorch.optimize import minimize
from xitorch._core.pure_function import get_pure_function

dtype = torch.float64  # module-level dtype assumed by the hess tests below


def test_minimize_methods(dtype, device, method):
    torch.manual_seed(400)
    random.seed(100)

    nr = 3
    nbatch = 2
    default_fwd_options = {
        "max_niter": 50,
        "f_tol": 1e-9,
        "alpha": -1.0,
    }
    linearmixing_fwd_options = {
        "max_niter": 50,
        "f_tol": 3e-6,
        "alpha": -0.3,
    }
    gd_fwd_options = {
        "maxiter": 5000,
        "f_rtol": 1e-10,
        "x_rtol": 1e-10,
        "step": 1e-2,
    }
    # list the methods and their options here
    options = {
        "broyden1": default_fwd_options,
        "broyden2": default_fwd_options,
        "linearmixing": linearmixing_fwd_options,
        "gd": gd_fwd_options,
        "adam": gd_fwd_options,
    }[method]
    # specify a higher atol for the less accurate methods
    atol = defaultdict(lambda: 1e-8)
    atol["linearmixing"] = 3e-6

    A = torch.nn.Parameter((torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    # bias will be detached from the optimization graph, so make it non-differentiable
    bias = torch.zeros((nbatch, nr)).to(dtype)
    y0 = torch.randn((nbatch, nr)).to(dtype)
    activation = "square"  # square activation makes it easy to optimize

    fwd_options = {**options, "method": method}
    model = DummyModule(A, addx=False, activation=activation, sumoutput=True)
    model.set_diag_bias(diag, bias)
    y = minimize(model.forward, y0, **fwd_options)

    # check the gradient (must be close to 0 at the minimum)
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        f = model.forward(y1)
        grady, = torch.autograd.grad(f, (y1,))
    assert torch.allclose(grady, grady * 0, atol=atol[method])

    # check the hessian (must be positive definite at the minimum)
    h = hess(model.forward, (y1,), idxs=0).fullmatrix()
    eigval, _ = torch.symeig(h)
    assert torch.all(eigval >= 0)
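# The tests in this file rely on a DummyModule test helper that is not part of
# this excerpt. The class below is a minimal, hypothetical reconstruction
# inferred only from how it is called (the constructor signature,
# set_diag_bias, and the "square" activation with sumoutput=True); the real
# helper in the repository may differ.
class DummyModule(torch.nn.Module):
    def __init__(self, A, addx=True, activation="square", sumoutput=False):
        super().__init__()
        self.A = A
        self.addx = addx
        self.activation = activation
        self.sumoutput = sumoutput
        self.diag = None
        self.bias = None

    def set_diag_bias(self, diag, bias):
        self.diag = diag
        self.bias = bias

    def forward(self, x):
        # y = activation((A x [+ x]) * diag + bias), optionally summed to a scalar
        y = torch.matmul(x, self.A.transpose(-2, -1))
        if self.addx:
            y = y + x
        if self.diag is not None:
            y = y * self.diag
        if self.bias is not None:
            y = y + self.bias
        if self.activation == "square":
            y = y * y
        if self.sumoutput:
            y = y.sum()
        return y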
def test_minimize(dtype, device, clss):
    torch.manual_seed(400)
    random.seed(100)

    nr = 3
    nbatch = 2
    A = torch.nn.Parameter((torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    # bias will be detached from the optimization graph, so make it non-differentiable
    bias = torch.zeros((nbatch, nr)).to(dtype)
    y0 = torch.randn((nbatch, nr)).to(dtype)
    fwd_options = {
        "method": "broyden1",
        "max_niter": 50,
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    activation = "square"  # square activation makes it easy to optimize

    model = clss(A, addx=False, activation=activation, sumoutput=True)
    model.set_diag_bias(diag, bias)
    y = minimize(model.forward, y0, **fwd_options)

    # check the gradient (must be close to 0 at the minimum)
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        f = model.forward(y1)
        grady, = torch.autograd.grad(f, (y1,))
    assert torch.allclose(grady, grady * 0)

    # check the hessian (must be positive definite at the minimum)
    h = hess(model.forward, (y1,), idxs=0).fullmatrix()
    eigval, _ = torch.symeig(h)
    assert torch.all(eigval >= 0)

    def getloss(A, y0, diag, bias):
        model = clss(A, addx=False, activation=activation, sumoutput=True)
        model.set_diag_bias(diag, bias)
        y = minimize(model.forward, y0, **fwd_options)
        return y

    gradcheck(getloss, (A, y0, diag, bias))
    gradgradcheck(getloss, (A, y0, diag, bias))
def test_minimize_methods(dtype, device):
    torch.manual_seed(400)
    random.seed(100)
    dtype = torch.float64  # overrides the dtype argument: this test always runs in float64

    nr = 3
    nbatch = 2
    default_fwd_options = {
        "max_niter": 50,
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    # list the methods and the options here
    methods_and_options = {
        "broyden1": default_fwd_options,
    }

    A = torch.nn.Parameter((torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    # bias will be detached from the optimization graph, so make it non-differentiable
    bias = torch.zeros((nbatch, nr)).to(dtype)
    y0 = torch.randn((nbatch, nr)).to(dtype)
    activation = "square"  # square activation makes it easy to optimize

    for method in methods_and_options:
        fwd_options = {**methods_and_options[method], "method": method}
        model = DummyModule(A, addx=False, activation=activation, sumoutput=True)
        model.set_diag_bias(diag, bias)
        y = minimize(model.forward, y0, **fwd_options)

        # check the gradient (must be close to 0 at the minimum)
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            f = model.forward(y1)
            grady, = torch.autograd.grad(f, (y1,))
        assert torch.allclose(grady, grady * 0)

        # check the hessian (must be positive definite at the minimum)
        h = hess(model.forward, (y1,), idxs=0).fullmatrix()
        eigval, _ = torch.symeig(h)
        assert torch.all(eigval >= 0)
def test_hess_func():
    na = 3
    params = getfnparams(na)
    nnparams = getnnparams(na)
    nparams = len(params)
    all_idxs = [None, (0,), (1,), (0, 1), (0, 1, 2)]
    funcs = [hfunc1, hfunc2(*nnparams)]
    for func in funcs:
        for idxs in all_idxs:
            if idxs is None:
                gradparams = params
            else:
                gradparams = [params[i] for i in idxs]
            hs = hess(func, params, idxs=idxs)
            assert len(hs) == len(gradparams)

            y = func(*params)
            nins = [torch.numel(p) for p in gradparams]
            w = [torch.rand_like(p) for p in gradparams]
            for i in range(len(hs)):
                assert list(hs[i].shape) == [nins[i], nins[i]]

            # assert the values
            dfdy = torch.autograd.grad(y, gradparams, create_graph=True)
            hs_mv_man = [
                torch.autograd.grad(dfdy[i], (gradparams[i],),
                                    grad_outputs=w[i],
                                    retain_graph=True)[0]
                for i in range(len(dfdy))
            ]
            hs_mv = [
                hs[i].mv(w[i].reshape(-1, nins[i])) for i in range(len(dfdy))
            ]
            for i in range(len(dfdy)):
                assert torch.allclose(hs_mv[i].view(-1), hs_mv_man[i].view(-1))
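# getfnparams, getnnparams, hfunc1, and hfunc2 are fixtures that do not appear
# in this excerpt. The definitions below are a hypothetical stand-in consistent
# with how they are used: getfnparams(na) returns three differentiable tensors,
# hfunc1 maps them to a scalar with nonzero second derivatives in every input,
# and hfunc2(*nnparams) returns a callable with the same signature as hfunc1.
# The actual fixtures in the repository may differ.
def getfnparams(na):
    return [torch.rand((na,), dtype=dtype).requires_grad_() for _ in range(3)]

def getnnparams(na):
    return [torch.rand((na,), dtype=dtype).requires_grad_() for _ in range(2)]

def hfunc1(a, b, c):
    # smooth scalar function whose gradient depends on all three inputs
    return ((a * b * c) ** 2 + torch.cos(a + b + c)).sum()

def hfunc2(p1, p2):
    # returns a function of (a, b, c) that also closes over the nn parameters
    def fcn(a, b, c):
        return ((a * p1) ** 2 + (b * p2) ** 2 + c ** 2).sum()
    return fcn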
def test_hess_grad():
    na = 3
    params = getfnparams(na)
    params2 = [torch.rand(1, dtype=dtype).requires_grad_() for p in params]
    hs = hess(hfunc1, params)

    def fcnl(i, v, *params):
        hs = hess(hfunc1, params)
        return hs[i].mv(v.view(-1))

    def fcnl2(i, v, *params2):
        fmv = get_pure_function(hs[i].mv)
        params0 = v.view(-1)
        params1 = fmv.objparams()
        params12 = [p1 * p2 for (p1, p2) in zip(params1, params2)]
        with fmv.useobjparams(params12):
            return fmv(params0)

    w = [torch.rand_like(p).requires_grad_() for p in params]
    for i in range(len(hs)):
        gradcheck(fcnl, (i, w[i], *params))
        gradgradcheck(fcnl, (i, w[i], *params))
        gradcheck(fcnl2, (i, w[i], *params2))
        gradgradcheck(fcnl2, (i, w[i], *params2))
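# Note on fcnl2 above: hs[i].mv closes over the tensors held inside the
# Hessian LinearOperator, so gradcheck cannot see them as positional inputs.
# get_pure_function exposes those hidden tensors via objparams() and lets the
# test substitute scaled copies with useobjparams(), verifying that gradients
# also flow correctly through the operator's internal parameters.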
def test_minimize(dtype, device, clss, method):
    torch.manual_seed(400)
    random.seed(100)

    method_fwd_options = {
        "broyden1": {
            "max_niter": 50,
            "f_tol": 1e-9,
            "alpha": -0.5,
        },
        "gd": {
            "maxiter": 10000,
            "f_rtol": 1e-14,
            "x_rtol": 1e-14,
            "step": 2e-2,
        },
    }

    nr = 2
    nbatch = 2
    A = torch.nn.Parameter((torch.randn((nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(torch.zeros((nbatch, nr)).to(dtype).requires_grad_())
    y0 = torch.randn((nbatch, nr)).to(dtype)
    fwd_options = method_fwd_options[method]
    bck_options = {
        "rtol": 1e-9,
        "atol": 1e-9,
    }
    activation = "square"  # square activation makes it easy to optimize

    model = clss(A, addx=False, activation=activation, sumoutput=True)
    model.set_diag_bias(diag, bias)
    y = minimize(model.forward, y0, method=method, **fwd_options)

    # check the gradient (must be close to 0 at the minimum)
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        f = model.forward(y1)
        grady, = torch.autograd.grad(f, (y1,))
    assert torch.allclose(grady, grady * 0)

    # check the hessian (must be positive definite at the minimum)
    h = hess(model.forward, (y1,), idxs=0).fullmatrix()
    eigval, _ = torch.symeig(h)
    assert torch.all(eigval >= 0)

    def getloss(A, y0, diag, bias):
        model = clss(A, addx=False, activation=activation, sumoutput=True)
        model.set_diag_bias(diag, bias)
        y = minimize(model.forward, y0, method=method,
                     bck_options=bck_options, **fwd_options)
        return y

    gradcheck(getloss, (A, y0, diag, bias))
    # pytorch 1.8's gradgradcheck fails if there are unrelated variables;
    # a PR fixing this has been submitted and should land in 1.9
    gradgradcheck(getloss, (A, y0, diag, bias.detach()))
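# The tests in this file take dtype, device, clss, and method as arguments,
# which implies pytest parametrization (or fixtures) defined elsewhere in the
# repository. The decorators below are a hypothetical illustration consistent
# with the argument names; the actual parameter values may differ.
@pytest.mark.parametrize("dtype", [torch.float64])
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("clss", [DummyModule])
@pytest.mark.parametrize("method", ["broyden1", "gd"])
def example_parametrization(dtype, device, clss, method):
    # placeholder body; the real tests above receive their arguments this way
    ...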