Exemplo n.º 1
0
    def eigen_proxy(m, n, target_scheme, size_data=None):
        """Symmetrise ``m`` and ``n`` and solve them with ``maths.eighb``.

        Arguments:
            m: First operand, passed through ``maths.sym`` before solving.
            n: Second operand, passed through ``maths.sym`` before solving.
            target_scheme: Forwarded to ``maths.eighb`` as ``scheme``.
            size_data: When not None, both operands are additionally passed
                through ``clean_zero_padding`` together with this value.

        Returns:
            Whatever ``maths.eighb(m, n, scheme=target_scheme)`` returns.
        """
        m = maths.sym(m)
        n = maths.sym(n)

        # Padding clean-up is only requested when size information is given.
        if size_data is not None:
            m = clean_zero_padding(m, size_data)
            n = clean_zero_padding(n, size_data)

        return maths.eighb(m, n, scheme=target_scheme)
Exemplo n.º 2
0
def test_eighb_general_batch(device):
    """Check eighb eigenvalues on a packed batch of general eigenproblems.

    Eleven systems are built whose sizes come from
    ``torch.randint(2, 10, ...)``. Each ``a`` is a symmetrised random matrix
    and each ``b`` is an identity matrix scaled by a random vector, i.e. a
    diagonal matrix. Reference eigenvalues are computed one system at a time
    with ``linalg.eigh`` and compared against the batched ``maths.eighb``
    result for every scheme / ``aux`` combination. The result must also stay
    on ``device``.
    """
    sizes = torch.randint(2, 10, (11, ), device=device)

    # Build the individual systems first: all "a" matrices, then all "b".
    a = [maths.sym(torch.rand(s, s, device=device)) for s in sizes]
    b = [
        maths.sym(torch.eye(s, device=device) * torch.rand(s, device=device))
        for s in sizes
    ]
    a_batch = batch.pack(a)
    b_batch = batch.pack(b)

    # Reference eigenvalues, solved system-by-system and then packed.
    # NOTE(review): ``.sft()`` is not defined in this snippet; presumably it
    # converts a tensor into something ``linalg.eigh`` accepts -- confirm.
    reference = [
        torch.tensor(linalg.eigh(i.sft(), j.sft())[0]) for i, j in zip(a, b)
    ]
    w_ref = batch.pack(reference)

    # Every scheme is exercised with the aux flag both on and off.
    cases = [(scheme, aux)
             for scheme in ('chol', 'lowd')
             for aux in (True, False)]

    for scheme, aux in cases:
        w_calc = maths.eighb(a_batch, b_batch, scheme=scheme, aux=aux)[0]

        # Largest absolute deviation from the reference eigenvalues.
        mae_w = torch.max(torch.abs(w_calc.cpu() - w_ref))

        assert mae_w < 1E-12, f'Eigenvalue tolerance test {scheme}'
        assert w_calc.device == device, 'Device persistence check'
Exemplo n.º 3
0
def test_eighb_general_grad(device):
    """Run gradcheck on eighb for general eigenvalue problems.

    A single 8x8 system and a packed batch of five systems (sizes drawn by
    ``torch.randint(3, 8, ...)``) are each checked with ``gradcheck`` for the
    'chol' and 'lowd' schemes.
    """
    def eigen_proxy(m, n, target_scheme, size_data=None):
        """Symmetrise the inputs, optionally clean padding, then call eighb."""
        m = maths.sym(m)
        n = maths.sym(n)
        if size_data is not None:
            m = clean_zero_padding(m, size_data)
            n = clean_zero_padding(n, size_data)

        return maths.eighb(m, n, scheme=target_scheme)

    schemes = ['chol', 'lowd']

    # --- Single system -----------------------------------------------------
    a1 = maths.sym(torch.rand(8, 8, device=device))
    b1 = maths.sym(torch.eye(8, device=device) * torch.rand(8, device=device))
    for tensor in (a1, b1):
        tensor.requires_grad = True

    for scheme in schemes:
        # raise_exception=False makes gradcheck return a bool to assert on.
        passed = gradcheck(eigen_proxy, (a1, b1, scheme),
                           raise_exception=False)
        assert passed, f'Non-degenerate single test failed on {scheme}'

    # --- Batch of systems --------------------------------------------------
    sizes = torch.randint(3, 8, (5, ), device=device)
    a2 = batch.pack(
        [maths.sym(torch.rand(s, s, device=device)) for s in sizes])
    b2 = batch.pack([
        maths.sym(torch.eye(s, device=device) * torch.rand(s, device=device))
        for s in sizes
    ])
    for tensor in (a2, b2):
        tensor.requires_grad = True

    for scheme in schemes:
        # The sizes are handed over so the proxy can clean the packed padding.
        passed = gradcheck(eigen_proxy, (a2, b2, scheme, sizes),
                           raise_exception=False)
        assert passed, f'Non-degenerate batch test failed on {scheme}'
Exemplo n.º 4
0
def test_eighb_broadening_grad(device):
    """eighb gradient stability on standard, broadened, eigenvalue problems.

    There is no separate test for the standard eigenvalue problem without
    broadening as this would result in a direct call to torch.symeig which is
    unnecessary. However, it is important to note that conditional broadening
    technically is never tested, i.e. the lines:

    .. code-block:: python
        ...
        if ctx.bm == 'cond':  # <- Conditional broadening
            deltas = 1 / torch.where(torch.abs(deltas) > bf,
                                     deltas, bf) * torch.sign(deltas)
        ...

    of `_SymEigB` are never actually run. This is because it only activates
    when there are true eigenvalue degeneracies; & degenerate eigenvalue
    problems do not "play well" with the gradcheck operation.
    """
    def eigen_proxy(m, target_method, size_data=None):
        """Symmetrise ``m``, optionally clean padding, then solve it.

        Arguments:
            m: Matrix (or packed batch) passed through ``maths.sym``.
            target_method: Forwarded to ``maths.eighb`` as
                ``broadening_method``; ``None`` selects torch's own
                symmetric eigensolver as a baseline instead.
            size_data: When not None, ``m`` is also passed through
                ``clean_zero_padding`` together with this value.

        Returns:
            The eigenvalue/eigenvector pair produced by the chosen solver.
        """
        m = maths.sym(m)
        if size_data is not None:
            m = clean_zero_padding(m, size_data)

        if target_method is None:
            # torch.symeig is deprecated and absent from recent PyTorch
            # releases. Prefer torch.linalg.eigh when it exists; UPLO='U'
            # mirrors symeig's default of reading the upper triangle.
            eigh = getattr(getattr(torch, 'linalg', None), 'eigh', None)
            if eigh is None:
                return torch.symeig(m, True)
            return eigh(m, UPLO='U')

        return maths.eighb(m, broadening_method=target_method)

    # Generate a single standard eigenvalue test instance
    a1 = maths.sym(torch.rand(8, 8, device=device))
    a1.requires_grad = True

    broadening_methods = [None, 'none', 'cond', 'lorn']
    for method in broadening_methods:
        grad_is_safe = gradcheck(eigen_proxy, (a1, method),
                                 raise_exception=False)
        assert grad_is_safe, f'Non-degenerate single test failed on {method}'

    # Generate a batch of standard eigenvalue test instances
    sizes = torch.randint(3, 8, (5, ), device=device)
    a2 = batch.pack(
        [maths.sym(torch.rand(s, s, device=device)) for s in sizes])
    a2.requires_grad = True

    # Only 'cond' and 'lorn' are exercised on the packed batch.
    # NOTE(review): presumably the unbroadened solvers are skipped because
    # padded batches are degenerate -- confirm.
    for method in broadening_methods[2:]:
        grad_is_safe = gradcheck(eigen_proxy, (a2, method, sizes),
                                 raise_exception=False)
        assert grad_is_safe, f'Non-degenerate batch test failed on {method}'
Exemplo n.º 5
0
 def eigen_proxy(m, target_method, size_data=None):
     """Symmetrise ``m``, optionally clean padding, then solve it.

     Arguments:
         m: Matrix (or packed batch) passed through ``maths.sym``.
         target_method: Forwarded to ``maths.eighb`` as
             ``broadening_method``; ``None`` selects torch's own symmetric
             eigensolver as a baseline instead.
         size_data: When not None, ``m`` is also passed through
             ``clean_zero_padding`` together with this value.

     Returns:
         The eigenvalue/eigenvector pair produced by the chosen solver.
     """
     m = maths.sym(m)
     if size_data is not None:
         m = clean_zero_padding(m, size_data)

     if target_method is None:
         # torch.symeig is deprecated and absent from recent PyTorch
         # releases. Prefer torch.linalg.eigh when it exists; UPLO='U'
         # mirrors symeig's default of reading the upper triangle.
         eigh = getattr(getattr(torch, 'linalg', None), 'eigh', None)
         if eigh is None:
             return torch.symeig(m, True)
         return eigh(m, UPLO='U')

     return maths.eighb(m, broadening_method=target_method)
Exemplo n.º 6
0
def test_eighb_general_single(device):
    """Check eighb on one 10x10 general eigenvalue problem.

    For both the 'chol' and 'lowd' schemes the eigenvalues are compared with
    ``linalg.eigh``, the off-diagonal part of ``v @ v.T`` is required to
    vanish, and both results must live on ``device``.
    """
    a = maths.sym(torch.rand(10, 10, device=device))
    # "b" is an identity matrix scaled by a random vector, i.e. diagonal.
    b = maths.sym(torch.eye(10, device=device) * torch.rand(10, device=device))

    # NOTE(review): ``.sft()`` is not defined in this snippet; presumably it
    # converts a tensor into something ``linalg.eigh`` accepts -- confirm.
    w_ref = linalg.eigh(a.sft(), b.sft())[0]

    for scheme in ('chol', 'lowd'):
        w_calc, v_calc = maths.eighb(a, b, scheme=scheme)

        # Largest absolute eigenvalue deviation from the reference.
        mae_w = torch.max(torch.abs(w_calc.cpu() - w_ref))

        # Zero the diagonal of v @ v.T; whatever remains should be ~0.
        # NOTE(review): this check presumably relies on "b" being diagonal
        # (eigenvectors of a general problem are b-orthogonal) -- confirm.
        overlap = v_calc @ v_calc.T
        overlap.fill_diagonal_(0)
        mae_v = torch.max(torch.abs(overlap))

        same_device = w_calc.device == device and device == v_calc.device

        assert mae_w < 1E-12, f'Eigenvalue tolerance test {scheme}'
        assert mae_v < 1E-12, f'Eigenvector orthogonality test {scheme}'
        assert same_device, 'Device persistence check'
Exemplo n.º 7
0
def test_sym_batch(device):
    """Check maths.sym on a stack of ten 10x10 matrices.

    The prediction symmetrises over the last two dimensions; the reference
    symmetrises every matrix of the stack individually.
    """
    data = torch.rand(10, 10, 10, device=device)
    pred = maths.sym(data, -1, -2)

    # Reference: symmetrise each slice on its own and re-stack the results.
    slices = []
    for matrix in data:
        slices.append((matrix + matrix.T) / 2)
    ref = torch.stack(slices, 0)

    # Compare on the CPU so the check does not depend on the test device.
    abs_delta = torch.max(torch.abs(pred.cpu() - ref.cpu()))

    assert abs_delta < 1E-12, 'Tolerance check'
    assert pred.device == device, 'Device persistence check'
Exemplo n.º 8
0
def test_sym_single(device):
    """Check maths.sym on a single 10x10 matrix.

    The result must match ``(data + data.T) / 2`` to within 1E-12 and must
    remain on ``device``.
    """
    data = torch.rand(10, 10, device=device)
    pred = maths.sym(data)

    # Reference: average of the matrix and its transpose.
    transposed = data.T
    ref = (data + transposed) / 2

    # Compare on the CPU so the check does not depend on the test device.
    difference = pred.cpu() - ref.cpu()
    abs_delta = torch.max(torch.abs(difference))

    assert abs_delta < 1E-12, 'Tolerance check'
    assert pred.device == device, 'Device persistence check'
Exemplo n.º 9
0
def test_eighb_standard_single(device):
    """Check eighb on one 10x10 standard eigenvalue problem.

    Eigenvalues are compared with ``linalg.eigh``, the off-diagonal part of
    ``v @ v.T`` is required to vanish, and both results must live on
    ``device``.
    """
    a = maths.sym(torch.rand(10, 10, device=device))

    # NOTE(review): ``.sft()`` is not defined in this snippet; presumably it
    # converts a tensor into something ``linalg.eigh`` accepts -- confirm.
    w_ref = linalg.eigh(a.sft())[0]

    w_calc, v_calc = maths.eighb(a)

    # Largest absolute eigenvalue deviation from the reference.
    mae_w = torch.max(torch.abs(w_calc.cpu() - w_ref))

    # Zero the diagonal of v @ v.T; whatever remains should be ~0.
    overlap = v_calc @ v_calc.T
    overlap.fill_diagonal_(0)
    mae_v = torch.max(torch.abs(overlap))

    same_device = w_calc.device == device and device == v_calc.device

    assert mae_w < 1E-12, 'Eigenvalue tolerance test'
    assert mae_v < 1E-12, 'Eigenvector orthogonality test'
    assert same_device, 'Device persistence check'
Exemplo n.º 10
0
def test_eighb_standard_batch(device):
    """Check eighb eigenvalues on a packed batch of standard eigenproblems.

    Eleven symmetrised random matrices, with sizes drawn by
    ``torch.randint(2, 10, ...)``, are packed and solved in one
    ``maths.eighb`` call. The eigenvalues are compared with per-matrix
    ``linalg.eigh`` references and must remain on ``device``.
    """
    sizes = torch.randint(2, 10, (11, ), device=device)
    a = [maths.sym(torch.rand(s, s, device=device)) for s in sizes]
    a_batch = batch.pack(a)

    # Reference eigenvalues, solved matrix-by-matrix on the CPU, then packed.
    reference = []
    for matrix in a:
        reference.append(torch.tensor(linalg.eigh(matrix.cpu())[0]))
    w_ref = batch.pack(reference)

    w_calc = maths.eighb(a_batch)[0]

    # Largest absolute eigenvalue deviation from the reference.
    mae_w = torch.max(torch.abs(w_calc.cpu() - w_ref))

    assert mae_w < 1E-12, 'Eigenvalue tolerance test'
    assert w_calc.device == device, 'Device persistence check'