def eigen_proxy(m, n, target_scheme, size_data=None):
    """Symmetrise ``m`` and ``n`` and hand them to ``maths.eighb``.

    Arguments:
        m: first matrix (or batch) of the generalised eigenvalue problem;
            it is passed through ``maths.sym`` before solving.
        n: second matrix (or batch) of the generalised eigenvalue problem;
            it is also passed through ``maths.sym`` before solving.
        target_scheme: forwarded to ``maths.eighb`` as ``scheme`` (the tests
            in this file exercise 'chol' and 'lowd').
        size_data: optional; when supplied, both symmetrised matrices are
            passed through ``clean_zero_padding`` together with it.

    Returns:
        Whatever ``maths.eighb`` returns; elsewhere in this file that result
        is unpacked as ``(eigenvalues, eigenvectors)``.
    """
    sym_m = maths.sym(m)
    sym_n = maths.sym(n)
    if size_data is None:
        return maths.eighb(sym_m, sym_n, scheme=target_scheme)
    # Padding clean-up is applied after symmetrisation, m first then n.
    return maths.eighb(
        clean_zero_padding(sym_m, size_data),
        clean_zero_padding(sym_n, size_data),
        scheme=target_scheme)
def test_eighb_general_batch(device):
    """eighb accuracy on a batch of general eigenvalue problems.

    Eleven random symmetric matrices ``a`` (sizes drawn from 2-9) and matching
    diagonal matrices ``b`` are packed into batches with ``batch.pack``. The
    eigenvalues returned by ``maths.eighb`` are then compared against scipy's
    ``linalg.eigh`` reference for every combination of ``scheme`` and ``aux``.
    The test also checks that the result stays on ``device``.
    """
    sizes = torch.randint(2, 10, (11, ), device=device)
    a = [maths.sym(torch.rand(s, s, device=device)) for s in sizes]
    # eye(s) * rand(s) scales each column of the identity, giving a diagonal
    # matrix with random entries on its diagonal.
    b = [
        maths.sym(torch.eye(s, device=device) * torch.rand(s, device=device))
        for s in sizes]
    a_batch, b_batch = batch.pack(a), batch.pack(b)

    # Reference eigenvalues are computed one system at a time by scipy and
    # packed the same way as the inputs so they can be compared element-wise.
    w_ref = batch.pack(
        [torch.tensor(linalg.eigh(i.sft(), j.sft())[0]) for i, j in zip(a, b)])

    for scheme in ['chol', 'lowd']:
        for aux in [True, False]:
            w_calc = maths.eighb(a_batch, b_batch, scheme=scheme, aux=aux)[0]
            # Largest absolute deviation from the reference.
            max_err = torch.max(torch.abs(w_calc.cpu() - w_ref))
            same_device = w_calc.device == device
            # Include both parameters so a failure pinpoints the combination.
            assert max_err < 1E-12, \
                f'Eigenvalue tolerance test {scheme} (aux={aux})'
            assert same_device, \
                f'Device persistence check {scheme} (aux={aux})'
def eigen_proxy(m, target_method, size_data=None):
    """Solve the standard eigenvalue problem for ``m`` with a chosen backend.

    Arguments:
        m: matrix (or batch) to decompose; it is passed through ``maths.sym``
            first.
        target_method: forwarded to ``maths.eighb`` as ``broadening_method``.
            If ``None``, PyTorch's own symmetric eigensolver is used instead,
            presumably as an un-broadened reference (confirm with callers).
        size_data: optional; when supplied, ``clean_zero_padding`` is applied
            to the symmetrised matrix with it.

    Returns:
        ``(eigenvalues, eigenvectors)`` from the selected solver.
    """
    m = maths.sym(m)
    if size_data is not None:
        m = clean_zero_padding(m, size_data)

    if target_method is None:
        # ``torch.symeig`` was deprecated in PyTorch 1.9 and has since been
        # removed (calling it raises). Prefer ``torch.linalg.eigh``; UPLO='U'
        # reproduces symeig's default ``upper=True``. Fall back to symeig only
        # on old PyTorch versions that predate ``torch.linalg.eigh``.
        linalg_eigh = getattr(getattr(torch, 'linalg', None), 'eigh', None)
        if linalg_eigh is None:
            return torch.symeig(m, True)
        return linalg_eigh(m, UPLO='U')

    return maths.eighb(m, broadening_method=target_method)
def test_eighb_standard_single(device):
    """Check eighb against scipy on one standard eigenvalue problem.

    A random symmetric 10x10 matrix is decomposed with ``maths.eighb``. The
    test checks that the eigenvalues match scipy's ``linalg.eigh`` reference,
    that the eigenvectors are mutually orthogonal, and that the results stay
    on ``device``.
    """
    matrix = maths.sym(torch.rand(10, 10, device=device))
    ref_vals = linalg.eigh(matrix.sft())[0]
    vals, vecs = maths.eighb(matrix)

    # Largest absolute deviation from the scipy reference.
    val_err = (vals.cpu() - ref_vals).abs().max()

    # For orthonormal eigenvectors V @ V.T is the identity, so after zeroing
    # the diagonal every remaining entry should be (numerically) zero.
    gram = vecs @ vecs.T
    gram.fill_diagonal_(0)
    vec_err = gram.abs().max()

    on_device = vals.device == device == vecs.device

    assert val_err < 1E-12, 'Eigenvalue tolerance test'
    assert vec_err < 1E-12, 'Eigenvector orthogonality test'
    assert on_device, 'Device persistence check'
def test_eighb_standard_batch(device):
    """eighb accuracy on a batch of standard eigenvalue problems.

    Eleven random symmetric matrices (sizes drawn from 2-9) are packed into a
    batch with ``batch.pack`` and decomposed in one call to ``maths.eighb``.
    The eigenvalues are compared with scipy's ``linalg.eigh`` reference, and
    the test checks that the result stays on ``device``.
    """
    sizes = torch.randint(2, 10, (11, ), device=device)
    a = [maths.sym(torch.rand(s, s, device=device)) for s in sizes]
    a_batch = batch.pack(a)

    # Matrices are handed to scipy through ``.sft()``, the same way the other
    # eighb tests in this file build their references, and the per-system
    # results are packed to match the batched output.
    w_ref = batch.pack([torch.tensor(linalg.eigh(i.sft())[0]) for i in a])
    w_calc = maths.eighb(a_batch)[0]

    # Largest absolute deviation from the reference.
    max_err = torch.max(torch.abs(w_calc.cpu() - w_ref))
    same_device = w_calc.device == device

    assert max_err < 1E-12, 'Eigenvalue tolerance test'
    assert same_device, 'Device persistence check'
def test_eighb_general_single(device):
    """Check eighb on one generalised eigenvalue problem for each scheme.

    ``A`` is a random symmetric 10x10 matrix. ``B`` is a diagonal matrix
    (``eye * rand`` puts random values on the diagonal). For each of the
    'chol' and 'lowd' schemes the test checks three things against
    scipy's ``linalg.eigh(A, B)``: the eigenvalues, the off-diagonal part
    of V @ V.T, and that the results stay on ``device``.
    """
    mat_a = maths.sym(torch.rand(10, 10, device=device))
    mat_b = maths.sym(
        torch.eye(10, device=device) * torch.rand(10, device=device))
    ref_vals = linalg.eigh(mat_a.sft(), mat_b.sft())[0]

    for method in ['chol', 'lowd']:
        vals, vecs = maths.eighb(mat_a, mat_b, scheme=method)

        # Largest absolute deviation from the scipy reference.
        val_err = torch.max(torch.abs(vals.cpu() - ref_vals))

        # NOTE(review): assuming eighb returns B-orthonormal eigenvectors
        # (V.T @ B @ V = I), V @ V.T equals inv(B). That matrix is diagonal
        # because B is, so only the off-diagonal entries are expected to vanish.
        outer = vecs @ vecs.T
        outer.fill_diagonal_(0)
        vec_err = torch.max(torch.abs(outer))

        on_device = vals.device == device == vecs.device

        assert val_err < 1E-12, f'Eigenvalue tolerance test {method}'
        assert vec_err < 1E-12, f'Eigenvector orthogonality test {method}'
        assert on_device, 'Device persistence check'