def testSolveTriangular(self, lower, left_side, transpose_a, lhs_shape,
                        rhs_shape, dtype, rng):  # pylint: disable=invalid-name
    """Checks lapax.solve_triangular against an explicit-inverse NumPy reference.

    A triangular operand ``A`` is built from the Cholesky factor of
    ``K @ K^T + n * I``. The reference answer multiplies ``B`` by the inverse
    of ``op(A)`` on the side selected by ``left_side``. Only values are
    compared; dtypes are not checked.

    Args:
      lower: if True, ``A`` is the lower-triangular Cholesky factor ``L``;
        otherwise ``A`` is its transpose (upper triangular).
      left_side: if True, the reference is ``op(A)^{-1} @ B``; otherwise it is
        ``B @ op(A)^{-1}``.
      transpose_a: if True, ``op(A) = A^T``; otherwise ``op(A) = A``.
      lhs_shape: shape used to draw ``K``; ``lhs_shape[-1]`` is used as the
        matrix size (NOTE(review): presumably the trailing two dims are
        square, otherwise the ``+ n * I`` term would not broadcast).
      rhs_shape: shape of the right-hand side ``B``.
      dtype: dtype forwarded to ``rng``.
      rng: callable invoked as ``rng(shape, dtype)`` to produce test arrays.
    """
    T = lambda X: onp.swapaxes(X, -1, -2)  # transpose of the last two axes
    K = rng(lhs_shape, dtype)
    # K @ K^T is positive semi-definite; adding n * I shifts every eigenvalue
    # to at least n, so the matrix is positive definite and its Cholesky
    # factor is a well-conditioned triangular matrix.
    L = onp.linalg.cholesky(
        onp.matmul(K, T(K)) + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
    # onp.eye defaults to float64, which can upcast the sum; restore the
    # test dtype.
    L = L.astype(K.dtype)
    B = rng(rhs_shape, dtype)

    A = L if lower else T(L)
    inv = onp.linalg.inv(T(A) if transpose_a else A)
    np_ans = onp.matmul(inv, B) if left_side else onp.matmul(B, inv)

    # Solve with exactly the operand the reference was built from, instead of
    # re-deriving it, so the two sides of the comparison cannot drift apart.
    lapax_ans = lapax.solve_triangular(A, B, left_side, lower, transpose_a)
    self.assertAllClose(np_ans, lapax_ans, check_dtypes=False)
def testSolveLowerTriangularBroadcasting(self):
    """Compares a jitted batched lower-triangular solve with onp.linalg.solve.

    Two independent batches of lower-triangular 3x3 systems (leading batch
    dimension 3, two right-hand-side columns) are solved by one jitted
    function. Each result must match NumPy's general solver and must agree
    with a second call on the same inputs.
    """
    rs = onp.random.RandomState(1)
    # Draw order matters for reproducibility: both lhs batches first, then
    # both rhs batches.
    first_lhs = onp.tril(rs.randn(3, 3, 3))
    second_lhs = onp.tril(rs.randn(3, 3, 3))
    first_rhs = rs.randn(3, 3, 2)
    second_rhs = rs.randn(3, 3, 2)

    def lower_solve(a, b):
        return lapax.solve_triangular(
            a, b, left_side=True, lower=True, trans_a=False)

    jitted_solve = jit(lower_solve)

    for a, b in ((first_lhs, first_rhs), (second_lhs, second_rhs)):
        expected = onp.linalg.solve(a, b)
        first_call = jitted_solve(a, b)
        # NOTE(review): the repeat call on identical inputs presumably checks
        # that the compiled/cached path reproduces the first result; confirm.
        repeat_call = jitted_solve(a, b)
        self.assertArraysAllClose(expected, first_call, check_dtypes=True)
        self.assertArraysAllClose(first_call, repeat_call, check_dtypes=True)