コード例 #1
0
    def testSolveTriangular(self, lower, left_side, transpose_a, lhs_shape,
                            rhs_shape, dtype, rng):
        """Check lapax.solve_triangular against an explicit-inverse reference.

        The triangular operand is derived from the Cholesky factor of
        ``K @ K^T + n * I``, where ``K = rng(lhs_shape, dtype)`` and ``n`` is
        the trailing dimension of ``lhs_shape``. For real-valued ``K`` the Gram
        matrix is positive semidefinite, and the ``n * I`` shift makes it
        positive definite and better conditioned.

        The reference answer is ``op(A)^{-1} @ B`` when ``left_side`` is true,
        and ``B @ op(A)^{-1}`` otherwise. Here ``op`` transposes the trailing
        two axes iff ``transpose_a`` is true.

        Args:
          lower: if true, ``A`` is the lower Cholesky factor; otherwise ``A``
            is its transpose (upper triangular).
          left_side: if true the reference solves ``op(A) X = B``; otherwise
            it solves ``X op(A) = B``.
          transpose_a: whether ``A`` is transposed before solving.
          lhs_shape: shape of the random matrix ``K`` used to build ``A``.
          rhs_shape: shape of the right-hand side ``B``.
          dtype: element dtype forwarded to ``rng``.
          rng: random-array factory, called as ``rng(shape, dtype)``.
        """
        # pylint: disable=invalid-name
        T = lambda X: onp.swapaxes(X, -1, -2)  # transpose trailing two axes
        n = lhs_shape[-1]
        K = rng(lhs_shape, dtype)
        # The diagonal shift keeps the factor well away from singular, so the
        # explicit inverse used for the reference answer stays accurate.
        L = onp.linalg.cholesky(onp.matmul(K, T(K)) + n * onp.eye(n))
        L = L.astype(K.dtype)
        B = rng(rhs_shape, dtype)

        A = L if lower else T(L)
        inv = onp.linalg.inv(T(A) if transpose_a else A)
        np_ans = onp.matmul(inv, B) if left_side else onp.matmul(B, inv)

        # Hand the solver the very same A used for the reference so the two
        # cannot silently diverge.
        lapax_ans = lapax.solve_triangular(A, B, left_side, lower, transpose_a)

        self.assertAllClose(np_ans, lapax_ans, check_dtypes=False)
コード例 #2
0
    def testSolveLowerTriangularBroadcasting(self):
        """Jitted batched lower-triangular solves agree with numpy.

        Each case pairs a stack of three lower-triangular 3x3 matrices with a
        stack of three 3x2 right-hand sides. The result must match
        ``onp.linalg.solve``. The jitted function is then called a second time
        on the same inputs, presumably to exercise the already-compiled path,
        and both results must agree.
        """
        rand = onp.random.RandomState(1)
        # Keep this draw order (both lhs stacks, then both rhs stacks) so the
        # seeded data is unchanged.
        lower_first = onp.tril(rand.randn(3, 3, 3))
        lower_second = onp.tril(rand.randn(3, 3, 3))
        rhs_first = rand.randn(3, 3, 2)
        rhs_second = rand.randn(3, 3, 2)

        def solve_lower(a, b):
            return lapax.solve_triangular(
                a, b, left_side=True, lower=True, trans_a=False)

        jitted_solve = jit(solve_lower)

        cases = ((lower_first, rhs_first), (lower_second, rhs_second))
        for mat, rhs_mat in cases:
            expected = onp.linalg.solve(mat, rhs_mat)
            got = jitted_solve(mat, rhs_mat)
            got_again = jitted_solve(mat, rhs_mat)
            self.assertArraysAllClose(expected, got, check_dtypes=True)
            self.assertArraysAllClose(got, got_again, check_dtypes=True)