Example #1
def test_minimize_methods(dtype, device, method):
    torch.manual_seed(400)
    random.seed(100)

    nr = 3
    nbatch = 2
    default_fwd_options = {
        "max_niter": 50,
        "f_tol": 1e-9,
        "alpha": -1.0,
    }
    linearmixing_fwd_options = {
        "max_niter": 50,
        "f_tol": 3e-6,
        "alpha": -0.3,
    }
    gd_fwd_options = {
        "maxiter": 5000,
        "f_rtol": 1e-10,
        "x_rtol": 1e-10,
        "step": 1e-2,
    }
    # list the methods and the options here
    options = {
        "broyden1": default_fwd_options,
        "broyden2": default_fwd_options,
        "linearmixing": linearmixing_fwd_options,
        "gd": gd_fwd_options,
        "adam": gd_fwd_options,
    }[method]

    # use a higher atol for the less accurate method
    atol = defaultdict(lambda: 1e-8)
    atol["linearmixing"] = 3e-6

    A = torch.nn.Parameter((torch.randn(
        (nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    # bias is detached from the optimization, so leave it non-differentiable
    bias = torch.zeros((nbatch, nr)).to(dtype)
    y0 = torch.randn((nbatch, nr)).to(dtype)
    activation = "square"  # square activation makes it easy to optimize

    fwd_options = {**options, "method": method}
    model = DummyModule(A, addx=False, activation=activation, sumoutput=True)
    model.set_diag_bias(diag, bias)
    y = minimize(model.forward, y0, **fwd_options)

    # check the grad (must be close to 0)
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        f = model.forward(y1)
    grady, = torch.autograd.grad(f, (y1, ))
    assert torch.allclose(grady, grady * 0, atol=atol[method])

    # check the hessian (must be posdef)
    h = hess(model.forward, (y1, ), idxs=0).fullmatrix()
    eigval, _ = torch.symeig(h)
    assert torch.all(eigval >= 0)
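
The test above takes dtype, device, and method as arguments, so it is driven by a parametrization layer that the snippet does not show. Below is a minimal sketch of how such a parametrization could look with plain pytest.mark.parametrize; the decorator and the value lists are assumptions, not the suite's actual fixtures:

import itertools

import pytest
import torch

# Hypothetical parametrization: float64 keeps the tight tolerances above realistic,
# and the method list mirrors the options dict inside the test.
@pytest.mark.parametrize(
    "dtype,device,method",
    list(itertools.product(
        [torch.float64],
        ["cpu"],
        ["broyden1", "broyden2", "linearmixing", "gd", "adam"],
    )))
def test_minimize_methods(dtype, device, method):
    ...  # body as in the example above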
Example #2
def test_minimize(dtype, device, clss):
    torch.manual_seed(400)
    random.seed(100)

    nr = 3
    nbatch = 2

    A = torch.nn.Parameter((torch.randn(
        (nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    # bias is detached from the optimization, so leave it non-differentiable
    bias = torch.zeros((nbatch, nr)).to(dtype)
    y0 = torch.randn((nbatch, nr)).to(dtype)
    fwd_options = {
        "method": "broyden1",
        "max_niter": 50,
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    activation = "square"  # square activation makes it easy to optimize

    model = clss(A, addx=False, activation=activation, sumoutput=True)
    model.set_diag_bias(diag, bias)
    y = minimize(model.forward, y0, **fwd_options)

    # check the grad (must be close to 0)
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        f = model.forward(y1)
    grady, = torch.autograd.grad(f, (y1, ))
    assert torch.allclose(grady, grady * 0)

    # check the hessian (must be posdef)
    h = hess(model.forward, (y1, ), idxs=0).fullmatrix()
    eigval, _ = torch.symeig(h)
    assert torch.all(eigval >= 0)

    def getloss(A, y0, diag, bias):
        model = clss(A, addx=False, activation=activation, sumoutput=True)
        model.set_diag_bias(diag, bias)
        y = minimize(model.forward, y0, **fwd_options)
        return y

    gradcheck(getloss, (A, y0, diag, bias))
    gradgradcheck(getloss, (A, y0, diag, bias))
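
gradcheck and gradgradcheck compare the analytical gradients that flow through the implicit minimizer against finite-difference estimates, which is only reliable in double precision. A self-contained illustration of the call pattern on a toy function (the function below is a hypothetical stand-in for getloss):

import torch
from torch.autograd import gradcheck, gradgradcheck

def fn(t):
    # toy stand-in for getloss; any smooth function of the inputs works
    return (t ** 3).sum()

# gradcheck needs float64 leaf tensors with requires_grad=True and returns True
# (or raises) depending on whether analytical and numerical gradients agree.
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
assert gradcheck(fn, (x,))
assert gradgradcheck(fn, (x,))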
Example #3
def test_minimize_methods(dtype, device):
    torch.manual_seed(400)
    random.seed(100)
    dtype = torch.float64

    nr = 3
    nbatch = 2
    default_fwd_options = {
        "max_niter": 50,
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    # list the methods and the options here
    methods_and_options = {
        "broyden1": default_fwd_options,
    }

    A = torch.nn.Parameter((torch.randn(
        (nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    # bias is detached from the optimization, so leave it non-differentiable
    bias = torch.zeros((nbatch, nr)).to(dtype)
    y0 = torch.randn((nbatch, nr)).to(dtype)
    activation = "square"  # square activation makes it easy to optimize

    for method in methods_and_options:
        fwd_options = {**methods_and_options[method], "method": method}
        model = DummyModule(A,
                            addx=False,
                            activation=activation,
                            sumoutput=True)
        model.set_diag_bias(diag, bias)
        y = minimize(model.forward, y0, **fwd_options)

        # check the grad (must be close to 0)
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            f = model.forward(y1)
        grady, = torch.autograd.grad(f, (y1, ))
        assert torch.allclose(grady, grady * 0)

        # check the hessian (must be posdef)
        h = hess(model.forward, (y1, ), idxs=0).fullmatrix()
        eigval, _ = torch.symeig(h)
        assert torch.all(eigval >= 0)
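
The positive-definiteness check relies on torch.symeig, which was deprecated and later removed from PyTorch. Assuming a recent PyTorch, the equivalent check can be written with torch.linalg.eigvalsh, sketched here on a stand-in matrix:

import torch

# stand-in symmetric PSD matrix; in the tests, h is the full Hessian from hess(...)
h = torch.randn(3, 3, dtype=torch.float64)
h = h @ h.T
eigval = torch.linalg.eigvalsh(h)  # ascending eigenvalues, like symeig's first output
assert torch.all(eigval >= 0)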
Example #4
def test_hess_func():
    na = 3
    params = getfnparams(na)
    nnparams = getnnparams(na)
    nparams = len(params)
    all_idxs = [None, (0, ), (1, ), (0, 1), (0, 1, 2)]
    funcs = [hfunc1, hfunc2(*nnparams)]

    for func in funcs:
        for idxs in all_idxs:
            if idxs is None:
                gradparams = params
            else:
                gradparams = [params[i] for i in idxs]

            hs = hess(func, params, idxs=idxs)
            assert len(hs) == len(gradparams)

            y = func(*params)
            nins = [torch.numel(p) for p in gradparams]
            w = [torch.rand_like(p) for p in gradparams]
            for i in range(len(hs)):
                assert list(hs[i].shape) == [nins[i], nins[i]]

            # assert the values
            dfdy = torch.autograd.grad(y, gradparams, create_graph=True)
            hs_mv_man = [
                torch.autograd.grad(dfdy[i], (gradparams[i], ),
                                    grad_outputs=w[i],
                                    retain_graph=True)[0]
                for i in range(len(dfdy))
            ]
            hs_mv = [
                hs[i].mv(w[i].reshape(-1, nins[i])) for i in range(len(dfdy))
            ]
            for i in range(len(dfdy)):
                assert torch.allclose(hs_mv[i].view(-1), hs_mv_man[i].view(-1))
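
The value check above compares hs[i].mv(w[i]) against a Hessian-vector product assembled manually from nested torch.autograd.grad calls. The same double-backward pattern on a standalone function (the function below is only an illustration, not one of the test's hfuncs):

import torch

def f(x):
    return (x ** 4).sum()  # simple scalar function for illustration

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
v = torch.randn(5, dtype=torch.float64)

# The first grad with create_graph=True keeps the graph of the gradient;
# the second grad with grad_outputs=v contracts the Hessian with v.
g, = torch.autograd.grad(f(x), (x,), create_graph=True)
hvp, = torch.autograd.grad(g, (x,), grad_outputs=v)
# hvp equals H @ v, which is what the LinearOperator's .mv computes.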
Example #5
def test_hess_grad():
    na = 3
    params = getfnparams(na)
    params2 = [torch.rand(1, dtype=dtype).requires_grad_() for p in params]
    hs = hess(hfunc1, params)

    def fcnl(i, v, *params):
        hs = hess(hfunc1, params)
        return hs[i].mv(v.view(-1))

    def fcnl2(i, v, *params2):
        fmv = get_pure_function(hs[i].mv)
        params0 = v.view(-1)
        params1 = fmv.objparams()
        params12 = [p1 * p2 for (p1, p2) in zip(params1, params2)]
        with fmv.useobjparams(params12):
            return fmv(params0)

    w = [torch.rand_like(p).requires_grad_() for p in params]
    for i in range(len(hs)):
        gradcheck(fcnl, (i, w[i], *params))
        gradgradcheck(fcnl, (i, w[i], *params))
        gradcheck(fcnl2, (i, w[i], *params2))
        gradgradcheck(fcnl2, (i, w[i], *params2))
Example #6
def fcnl(i, v, *params):
    hs = hess(hfunc1, params)
    return hs[i].mv(v.view(-1))
Example #7
def test_minimize(dtype, device, clss, method):
    torch.manual_seed(400)
    random.seed(100)

    method_fwd_options = {
        "broyden1": {
            "max_niter": 50,
            "f_tol": 1e-9,
            "alpha": -0.5,
        },
        "gd": {
            "maxiter": 10000,
            "f_rtol": 1e-14,
            "x_rtol": 1e-14,
            "step": 2e-2,
        },
    }

    nr = 2
    nbatch = 2

    A = torch.nn.Parameter((torch.randn(
        (nr, nr)) * 0.5).to(dtype).requires_grad_())
    diag = torch.nn.Parameter(
        torch.randn((nbatch, nr)).to(dtype).requires_grad_())
    bias = torch.nn.Parameter(
        torch.zeros((nbatch, nr)).to(dtype).requires_grad_())
    y0 = torch.randn((nbatch, nr)).to(dtype)
    fwd_options = method_fwd_options[method]
    bck_options = {
        "rtol": 1e-9,
        "atol": 1e-9,
    }
    activation = "square"  # square activation makes it easy to optimize

    model = clss(A, addx=False, activation=activation, sumoutput=True)
    model.set_diag_bias(diag, bias)
    y = minimize(model.forward, y0, method=method, **fwd_options)

    # check the grad (must be close to 0)
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        f = model.forward(y1)
    grady, = torch.autograd.grad(f, (y1, ))
    assert torch.allclose(grady, grady * 0)

    # check the hessian (must be posdef)
    h = hess(model.forward, (y1, ), idxs=0).fullmatrix()
    eigval, _ = torch.symeig(h)
    assert torch.all(eigval >= 0)

    def getloss(A, y0, diag, bias):
        model = clss(A, addx=False, activation=activation, sumoutput=True)
        model.set_diag_bias(diag, bias)
        y = minimize(model.forward,
                     y0,
                     method=method,
                     bck_options=bck_options,
                     **fwd_options)
        return y

    gradcheck(getloss, (A, y0, diag, bias))
    # PyTorch 1.8's gradgradcheck fails if there are unrelated variables;
    # I have made a PR to fix this, which should land in 1.9
    gradgradcheck(getloss, (A, y0, diag, bias.detach()))
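
Example #7 is the only one that passes bck_options, which tunes the linear solve used in the backward pass of minimize. Below is a minimal, self-contained sketch of differentiating through minimize on a small quadratic. The xitorch.optimize import path and the params= keyword are assumed from xitorch's public API, and the objective is a toy stand-in for the DummyModule models used in the tests:

import torch
from xitorch.optimize import minimize

a = torch.eye(2, dtype=torch.float64).requires_grad_()
b = torch.randn(2, 2, dtype=torch.float64).requires_grad_()

def quad(y, a, b):
    # scalar objective: 0.5 * y^T a y - b . y, summed over the batch
    return 0.5 * torch.einsum("bi,ij,bj->", y, a, y) - (y * b).sum()

y0 = torch.zeros(2, 2, dtype=torch.float64)
ymin = minimize(quad, y0, params=(a, b), method="broyden1",
                bck_options={"rtol": 1e-9, "atol": 1e-9})
ymin.sum().backward()  # gradients reach a and b through the implicit minimizer
assert a.grad is not None and b.grad is not None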