示例#1
0
def test_equil_methods(dtype, device, method):
    """Check that ``equilibrium`` converges to a fixed point for each method.

    A ``DummyModule`` is built from seeded random parameters. The solver is
    run with the requested ``method``, and the result ``y`` must satisfy
    ``model.forward(y) == y`` up to ``torch.allclose`` tolerance.

    Args:
        dtype: torch dtype the test tensors are cast to.
        device: device the test tensors are moved to (previously ignored).
        method: one of ``"broyden1"``, ``"broyden2"``, ``"linearmixing"``;
            any other value raises ``KeyError`` from the options lookup.
    """
    torch.manual_seed(100)
    random.seed(100)

    nr = 3
    nbatch = 2
    default_fwd_options = {
        "f_tol": 1e-9,
        "alpha": -0.5,
    }
    # list the methods and the options here
    options = {
        "broyden1": default_fwd_options,
        "broyden2": default_fwd_options,
        "linearmixing": default_fwd_options,
    }[method]

    def _place(t):
        # Draws happen with torch's default dtype/device first, so the seeded
        # values are the same for every (dtype, device) combination; only
        # then are they cast and moved to the parametrized dtype/device.
        return t.to(dtype=dtype, device=device)

    # Draw order (A, diag, bias, y0) is kept so the seeded data is unchanged.
    A = torch.nn.Parameter(_place(torch.randn((nr, nr)) * 0.5))
    diag = torch.nn.Parameter(_place(torch.randn((nbatch, nr))))
    bias = torch.nn.Parameter(_place(torch.zeros((nbatch, nr))))
    y0 = _place(torch.randn((nbatch, nr)))

    fwd_options = {**options, "method": method}
    model = DummyModule(A, addx=False)
    model.set_diag_bias(diag, bias)
    y = equilibrium(model.forward, y0, **fwd_options)
    f = model.forward(y)
    # y is an equilibrium iff applying the map once more leaves it unchanged.
    assert torch.allclose(y, f)
示例#2
0
def test_equil(dtype, device):
    """Check equilibrium convergence and its gradients for both module kinds.

    For each of ``DummyModule`` and ``DummyNNModule``, this verifies that the
    solution ``y`` is a fixed point of ``model.forward``. It then runs
    ``gradcheck`` and ``gradgradcheck`` on the map
    ``(A, y0, diag, bias) -> equilibrium``.

    Args:
        dtype: torch dtype the test tensors are cast to.
        device: device the test tensors are moved to (previously ignored).
    """
    torch.manual_seed(100)
    random.seed(100)

    nr = 3
    nbatch = 2
    fwd_options = {
        "f_tol": 1e-9,
        "alpha": -0.5,
    }

    def _place(t):
        # Draw with torch's defaults first so the seeded values do not depend
        # on dtype/device, then cast and move to the parametrized target.
        return t.to(dtype=dtype, device=device)

    for clss in [DummyModule, DummyNNModule]:
        # Draw order (A, diag, bias, y0) is kept so the seeded data is unchanged.
        A = torch.nn.Parameter(_place(torch.randn((nr, nr)) * 0.5))
        diag = torch.nn.Parameter(_place(torch.randn((nbatch, nr))))
        bias = torch.nn.Parameter(_place(torch.zeros((nbatch, nr))))
        y0 = _place(torch.randn((nbatch, nr)))

        model = clss(A, addx=False)
        model.set_diag_bias(diag, bias)
        y = equilibrium(model.forward, y0, **fwd_options)
        f = model.forward(y)
        assert torch.allclose(y, f)

        # Rebuilds the model from its inputs so the autograd checkers can
        # perturb A, y0, diag and bias. It closes over the loop variable
        # ``clss``; this is safe only because it is called within the same
        # iteration.
        def getloss(A, y0, diag, bias):
            model = clss(A, addx=False)
            model.set_diag_bias(diag, bias)
            y = equilibrium(model.forward, y0, **fwd_options)
            return y

        gradcheck(getloss, (A, y0, diag, bias))
        gradgradcheck(getloss, (A, y0, diag, bias))
示例#3
0
 def getloss(A, y0, diag, bias):
     """Return the equilibrium of a ``clss`` model built from the arguments.

     NOTE(review): ``clss``, ``bck_options`` and ``fwd_options`` are free
     variables, presumably supplied by an enclosing test (not visible here).
     """
     module = clss(A, addx=False)
     module.set_diag_bias(diag, bias)
     return equilibrium(
         module.forward,
         y0,
         bck_options=bck_options,
         **fwd_options,
     )
示例#4
0
def test_equil(dtype, device, clss):
    """Check equilibrium convergence and its gradients for one module class.

    The forward pass uses ``broyden1``; the backward pass uses ``cg``. The
    solution must be a fixed point of ``model.forward``. ``gradcheck`` and
    ``gradgradcheck`` run for real dtypes, or for complex dtypes only with
    ``DummyModule`` (to save time).

    Args:
        dtype: torch dtype the test tensors are cast to.
        device: device the test tensors are moved to (previously ignored).
        clss: module class constructed as ``clss(A, addx=False)``.
    """
    torch.manual_seed(100)
    random.seed(100)

    nr = 2
    nbatch = 2
    fwd_options = {
        "method": "broyden1",
        "f_tol": 1e-12,
        "alpha": -0.5,
    }
    bck_options = {
        "method": "cg",
    }

    def _place(t):
        # Draw with torch's defaults first so the seeded values do not depend
        # on dtype/device, then cast and move to the parametrized target.
        return t.to(dtype=dtype, device=device)

    # Draw order (A, diag, bias, y0) is kept so the seeded data is unchanged.
    A = torch.nn.Parameter(_place(torch.randn((nr, nr)) * 0.5))
    diag = torch.nn.Parameter(_place(torch.randn((nbatch, nr))))
    bias = torch.nn.Parameter(_place(torch.zeros((nbatch, nr))))
    y0 = _place(torch.randn((nbatch, nr)))

    model = clss(A, addx=False)
    model.set_diag_bias(diag, bias)
    y = equilibrium(model.forward, y0, bck_options=bck_options, **fwd_options)
    f = model.forward(y)
    assert torch.allclose(y, f)

    def getloss(A, y0, diag, bias):
        # Rebuild the model so the autograd checkers can perturb every input.
        model = clss(A, addx=False)
        model.set_diag_bias(diag, bias)
        y = equilibrium(model.forward,
                        y0,
                        bck_options=bck_options,
                        **fwd_options)
        return y

    # only check for real numbers or complex with DummyModule to save time
    checkgrad = not torch.is_complex(y0) or clss is DummyModule
    if checkgrad:
        gradcheck(getloss, (A, y0, diag, bias))
        gradgradcheck(getloss, (A, y0, diag, bias))
示例#5
0
 def getloss(a):
     """Return the equilibrium of ``clss(a).forward`` starting from ``y0``.

     NOTE(review): ``clss``, ``y0`` and ``fwd_options`` are free variables,
     presumably supplied by an enclosing test (not visible here).
     """
     fixed_point_map = clss(a).forward
     return equilibrium(fixed_point_map, y0, **fwd_options)