예제 #1
0
    def eigen_proxy(m, n, target_scheme, size_data=None):
        """Solve the generalised eigenproblem for ``m`` and ``n`` via eighb.

        Both matrices are first passed through ``maths.sym`` (presumably a
        symmetrisation helper — confirm). When ``size_data`` is supplied, each
        is then passed through ``clean_zero_padding(x, size_data)``.

        Args:
            m: Left-hand matrix of the problem.
            n: Right-hand (metric) matrix of the problem.
            target_scheme: Forwarded unchanged as ``maths.eighb``'s ``scheme``
                keyword (the tests in this file use ``'chol'`` / ``'lowd'``).
            size_data: Optional value forwarded to ``clean_zero_padding``.

        Returns:
            Whatever ``maths.eighb`` returns for the prepared pair.
        """
        # Order matters for identical tracing: m is processed before n.
        prepared = [maths.sym(m), maths.sym(n)]
        if size_data is not None:
            prepared = [clean_zero_padding(x, size_data) for x in prepared]
        return maths.eighb(*prepared, scheme=target_scheme)
예제 #2
0
def test_eighb_general_batch(device):
    """eighb accuracy on a batch of general eigenvalue problems.

    Eleven random problems of sizes 2-9 are packed with ``batch.pack`` and
    solved in one call for every combination of ``scheme`` ('chol', 'lowd')
    and ``aux`` (True, False). The eigenvalues are compared with a per-system
    ``linalg.eigh(a, b)`` reference (presumably ``scipy.linalg``). The result
    must stay on ``device``.
    """
    sizes = torch.randint(2, 10, (11, ), device=device)
    a = [maths.sym(torch.rand(s, s, device=device)) for s in sizes]
    # eye(s) * rand(s) places the random row vector on the diagonal, so each
    # b is a diagonal matrix with entries drawn from [0, 1).
    b = [
        maths.sym(torch.eye(s, device=device) * torch.rand(s, device=device))
        for s in sizes
    ]
    a_batch, b_batch = batch.pack(a), batch.pack(b)

    # Reference eigenvalues are computed one system at a time and then packed
    # with the same helper as the inputs, so they compare element-wise.
    w_ref = batch.pack(
        [torch.tensor(linalg.eigh(i.sft(), j.sft())[0]) for i, j in zip(a, b)])

    aux_settings = [True, False]
    schemes = ['chol', 'lowd']
    for scheme in schemes:
        for aux in aux_settings:
            w_calc = maths.eighb(a_batch, b_batch, scheme=scheme, aux=aux)[0]

            mae_w = torch.max(torch.abs(w_calc.cpu() - w_ref))

            same_device = w_calc.device == device

            # Report both loop variables so that a failure identifies the
            # exact (scheme, aux) combination that broke.
            assert mae_w < 1E-12, (
                f'Eigenvalue tolerance test {scheme} (aux={aux})')
            assert same_device, (
                f'Device persistence check {scheme} (aux={aux})')
예제 #3
0
 def eigen_proxy(m, target_method, size_data=None):
     """Eigen-decompose ``m`` after passing it through ``maths.sym``.

     Args:
         m: Input matrix. NOTE(review): shape/dtype are not visible here;
             presumably a square (possibly batched) tensor — confirm.
         target_method: Forwarded to ``maths.eighb`` as its
             ``broadening_method`` keyword. When ``None``, PyTorch's own
             symmetric eigen-solver is used instead of ``maths.eighb``.
         size_data: Optional. When given, ``m`` is passed through
             ``clean_zero_padding(m, size_data)`` before solving.

     Returns:
         The result of the selected solver. The PyTorch path returns an
         ``(eigenvalues, eigenvectors)`` named tuple.
     """
     m = maths.sym(m)
     if size_data is not None:
         m = clean_zero_padding(m, size_data)

     if target_method is not None:
         return maths.eighb(m, broadening_method=target_method)

     # ``torch.symeig`` was deprecated in PyTorch 1.9 and removed in later
     # releases. ``torch.linalg.eigh`` is its documented replacement.
     # ``symeig(m, True)`` meant ``eigenvectors=True, upper=True``, so the
     # upper triangle is requested via ``UPLO='U'`` to keep the same
     # semantics. The fallback keeps pre-1.8 builds, which lack
     # ``torch.linalg.eigh``, working as before.
     linalg_eigh = getattr(getattr(torch, 'linalg', None), 'eigh', None)
     if linalg_eigh is None:
         return torch.symeig(m, True)
     return linalg_eigh(m, UPLO='U')
예제 #4
0
def test_eighb_standard_single(device):
    """Check eighb on one standard (non-generalised) eigenvalue problem.

    The eigenvalues of a random 10x10 matrix (after ``maths.sym``) are
    compared with a ``linalg.eigh`` reference. The returned eigenvectors are
    checked for mutual orthogonality, and both outputs must remain on
    ``device``.
    """
    matrix = maths.sym(torch.rand(10, 10, device=device))

    reference = linalg.eigh(matrix.sft())[0]

    values, vectors = maths.eighb(matrix)

    # Largest absolute deviation from the reference eigenvalues.
    eig_err = torch.abs(values.cpu() - reference).max()
    # For orthogonal eigenvectors the off-diagonal part of V @ V.T vanishes.
    overlap = vectors @ vectors.T
    vec_err = torch.abs(overlap.fill_diagonal_(0)).max()
    on_device = values.device == device == vectors.device

    assert eig_err < 1E-12, 'Eigenvalue tolerance test'
    assert vec_err < 1E-12, 'Eigenvector orthogonality test'
    assert on_device, 'Device persistence check'
예제 #5
0
def test_eighb_standard_batch(device):
    """Check eighb on a packed batch of standard eigenvalue problems.

    Eleven random matrices with sizes drawn from 2-9 are passed through
    ``maths.sym``, packed with ``batch.pack`` and solved in a single call.
    The eigenvalues must match per-matrix ``linalg.eigh`` references (packed
    the same way) and must stay on ``device``.
    """
    sizes = torch.randint(2, 10, (11, ), device=device)
    matrices = [maths.sym(torch.rand(s, s, device=device)) for s in sizes]
    packed = batch.pack(matrices)

    reference = batch.pack(
        [torch.tensor(linalg.eigh(mat.cpu())[0]) for mat in matrices])

    values = maths.eighb(packed)[0]

    eig_err = torch.abs(values.cpu() - reference).max()

    on_device = values.device == device

    assert eig_err < 1E-12, 'Eigenvalue tolerance test'
    assert on_device, 'Device persistence check'
예제 #6
0
def test_eighb_general_single(device):
    """Check eighb on one generalised eigenvalue problem for each scheme.

    A random 10x10 matrix and a diagonal metric (both passed through
    ``maths.sym``) are solved with the 'chol' and 'lowd' schemes. For each
    scheme, the eigenvalues are compared with a ``linalg.eigh(a, b)``
    reference, the eigenvector overlap is checked, and both outputs must
    remain on ``device``.
    """
    lhs = maths.sym(torch.rand(10, 10, device=device))
    # eye(10) * rand(10) is diagonal, with entries drawn from [0, 1).
    metric = maths.sym(
        torch.eye(10, device=device) * torch.rand(10, device=device))

    reference = linalg.eigh(lhs.sft(), metric.sft())[0]

    for scheme in ('chol', 'lowd'):
        values, vectors = maths.eighb(lhs, metric, scheme=scheme)

        eig_err = torch.abs(values.cpu() - reference).max()
        # NOTE(review): generalised eigenvectors are normally metric-
        # orthonormal (V.T @ B @ V = I), which makes V @ V.T equal B^-1. This
        # off-diagonal check is therefore only meaningful because ``metric``
        # is diagonal. It assumes eighb returns B-normalised vectors for both
        # schemes — confirm.
        vec_err = torch.abs((vectors @ vectors.T).fill_diagonal_(0)).max()
        on_device = values.device == device == vectors.device

        assert eig_err < 1E-12, f'Eigenvalue tolerance test {scheme}'
        assert vec_err < 1E-12, f'Eigenvector orthogonality test {scheme}'
        assert on_device, 'Device persistence check'