Esempio n. 1
0
def test_trsm_wrapper(mat, arr, dtype, order, device, lower, transpose):
    """Compare the ``trsm`` wrapper with the ``sclb.dtrsm`` reference routine.

    The reference is computed on the original ``mat``/``arr``. The wrapper receives
    copies converted by ``fix_mat`` and moved to ``device`` by ``move_tensor``.
    CUDA devices combined with C-ordered inputs must raise ``ValueError``. Every other
    combination must match the reference within a dtype-dependent relative tolerance.
    """
    alpha = 1e-2
    # float32 can only be trusted to a couple of digits; float64 is checked tightly.
    tolerance = 1e-11
    if dtype == np.float32:
        tolerance = 1e-2

    def _prepare(data):
        # Copy into the requested dtype / memory order, then place it on `device`.
        return move_tensor(fix_mat(data, dtype=dtype, order=order, copy=True), device=device)

    tri_mat = _prepare(mat)
    rhs = _prepare(arr)

    # Reference result from the unconverted inputs (sclb is presumably scipy.linalg.blas).
    reference = sclb.dtrsm(alpha, mat, arr, side=0, lower=lower, trans_a=transpose, overwrite_b=0)

    if device.startswith("cuda") and order == "C":
        # This device/layout combination is expected to be rejected by the wrapper.
        with pytest.raises(ValueError):
            trsm(rhs, tri_mat, alpha=alpha, lower=lower, transpose=transpose)
        return

    result = trsm(rhs, tri_mat, alpha=alpha, lower=lower, transpose=transpose)
    np.testing.assert_allclose(reference, result.cpu().numpy(), rtol=tolerance)
Esempio n. 2
0
    def test_trsm(self, mat, vec, solution, alpha, dtype, order_v, order_A,
                  device):
        """Solve with ``trsm`` and compare against a precomputed reference solution.

        ``solution`` is unpacked as ``(expected, lower, trans)``. The ``lower`` and
        ``trans`` flags are cast to ``int`` before being forwarded to ``trsm``. In
        addition to the numerical comparison (tolerance ``self.rtol[dtype]``), the
        result must not start at the same memory address as the right-hand side, and
        it must keep that tensor's device, stride and dtype.
        """
        def _prepare(data, order):
            # Copy into the requested dtype / memory order, then place it on `device`.
            return move_tensor(fix_mat(data, dtype, order, copy=True, numpy=False),
                               device=device)

        tri = _prepare(mat, order_A)
        rhs = _prepare(vec, order_v)

        expected, lower, trans = solution
        out = trsm(rhs, tri, alpha, lower=int(lower), transpose=int(trans))

        # Evaluated lazily and in order, so the first failing check reports its message.
        metadata_checks = [
            (lambda: out.data_ptr() != rhs.data_ptr(), "Vec was overwritten."),
            (lambda: out.device == rhs.device, "Output device is incorrect."),
            (lambda: out.stride() == rhs.stride(), "Stride was modified."),
            (lambda: out.dtype == rhs.dtype, "Dtype was modified."),
        ]
        for predicate, message in metadata_checks:
            assert predicate(), message

        np.testing.assert_allclose(expected,
                                   out.cpu().numpy(),
                                   rtol=self.rtol[dtype])
Esempio n. 3
0
    def invTt(self, v: torch.Tensor) -> torch.Tensor:
        r"""Solve the system of equations :math:`T^\top x = v` for unknown vector :math:`x`.

        Multiple right-hand sides are supported (by simply passing a 2D tensor for `v`).

        Before solving, ``self.dT`` is written onto the diagonal of ``self.fC`` through
        ``inplace_set_diag_th``. Judging by its name, that helper modifies ``self.fC``
        in place, so this method has a side effect on ``self.fC``.
        NOTE(review): presumably ``fC`` stores more than one triangular factor and
        the diagonal must be reset before every solve; confirm against the class.

        Parameters
        ----------
        v
            The right-hand side of the triangular system of equations

        Returns
        -------
        x
            The solution, computed with the `trsm` function.

        See Also
        --------
        :func:`falkon.preconditioner.pc_utils.trsm` : the function used to solve the system of equations
        """
        # Put T's diagonal (self.dT) into fC before using it as the triangular factor.
        inplace_set_diag_th(self.fC, self.dT)
        # lower=0 selects the upper triangle of fC; transpose=1 solves with its transpose.
        return trsm(v, self.fC, alpha=1.0, lower=0, transpose=1)