def test_trsm_wrapper(mat, arr, dtype, order, device, lower, transpose):
    """Compare the ``trsm`` wrapper against SciPy's BLAS ``dtrsm`` reference.

    ``mat`` and ``arr`` are first passed through ``fix_mat`` (with the requested
    ``dtype`` and ``order``, copied) and moved to ``device``. When ``device``
    starts with ``"cuda"`` and ``order == "C"``, ``trsm`` is expected to raise
    ``ValueError``. Otherwise its output must match ``dtrsm`` within a
    dtype-dependent relative tolerance.
    """
    alpha = 1e-2
    # float32 results cannot be held to the same precision as float64 ones.
    rtol = 1e-11
    if dtype == np.float32:
        rtol = 1e-2

    def _to_device(data):
        converted = fix_mat(data, dtype=dtype, order=order, copy=True)
        return move_tensor(converted, device=device)

    n_mat = _to_device(mat)
    n_arr = _to_device(arr)

    # Reference: SciPy's double-precision dtrsm applied to the *unconverted*
    # inputs (side=0 -> triangular matrix on the left, overwrite_b=0 -> arr kept).
    expected = sclb.dtrsm(alpha, mat, arr, side=0, lower=lower, trans_a=transpose, overwrite_b=0)

    def _run():
        return trsm(n_arr, n_mat, alpha=alpha, lower=lower, transpose=transpose)

    if device.startswith("cuda") and order == "C":
        # This combination must be rejected; no numerical comparison is made.
        with pytest.raises(ValueError):
            _run()
        return

    actual = _run()
    np.testing.assert_allclose(expected, actual.cpu().numpy(), rtol=rtol)
def test_trsm(self, mat, vec, solution, alpha, dtype, order_v, order_A, device):
    """Check ``trsm`` against a precomputed solution for one parametrized case.

    ``solution`` is unpacked as ``(expected, lower, trans)``. Besides the
    numerical comparison (relative tolerance ``self.rtol[dtype]``), the test
    checks the following properties of the returned tensor:

    * its storage differs from ``vec`` (different ``data_ptr``);
    * it has the same device, strides and dtype as ``vec``.
    """
    def _prepare(data, order):
        # NOTE(review): fix_mat presumably converts dtype / memory order; with
        # numpy=False the result is expected to be a torch tensor (it is later
        # queried via .data_ptr()/.stride()).
        host = fix_mat(data, dtype, order, copy=True, numpy=False)
        return move_tensor(host, device=device)

    mat = _prepare(mat, order_A)
    vec = _prepare(vec, order_v)

    expected, lower, trans = solution
    result = trsm(vec, mat, alpha, lower=int(lower), transpose=int(trans))

    # Output must be a fresh tensor that mirrors the input's layout.
    assert result.data_ptr() != vec.data_ptr(), "Vec was overwritten."
    assert result.device == vec.device, "Output device is incorrect."
    assert result.stride() == vec.stride(), "Stride was modified."
    assert result.dtype == vec.dtype, "Dtype was modified."

    host_result = result.cpu().numpy()
    tolerance = self.rtol[dtype]
    np.testing.assert_allclose(expected, host_result, rtol=tolerance)
def invTt(self, v: torch.Tensor) -> torch.Tensor:
    r"""Solve the system of equations :math:`T^\top x = v` for unknown vector :math:`x`.

    Multiple right-hand sides are supported (by simply passing a 2D tensor for `v`).

    Parameters
    ----------
    v
        The right-hand side of the triangular system of equations.

    Returns
    -------
    x
        The solution, computed with the `trsm` function.

    Notes
    -----
    Before solving, ``inplace_set_diag_th(self.fC, self.dT)`` is called, which
    (per the helper's name) writes ``self.dT`` onto the diagonal of ``self.fC``
    in place, so ``self.fC`` is modified as a side effect. ``trsm`` is then
    called with ``alpha=1.0`` (no scaling), ``lower=0`` (BLAS-style: use the
    upper triangle of ``self.fC``) and ``transpose=1`` (solve with the
    transposed matrix).

    See Also
    --------
    :func:`falkon.preconditioner.pc_utils.trsm` : the function used to solve the system of equations
    """
    # NOTE(review): presumably self.fC stores more than one triangular factor
    # sharing storage, so the diagonal must be reset to T's before each solve;
    # confirm against the class that populates fC/dT.
    inplace_set_diag_th(self.fC, self.dT)
    return trsm(v, self.fC, alpha=1.0, lower=0, transpose=1)