def triangular_solve_cpu_translation_rule(c, a, b, left_side, lower, transpose_a, conjugate_a):
  """CPU translation rule for triangular_solve.

  Lowers a 2-D float32/float64 solve to the LAPACK-backed ``trsm`` custom
  call; any other rank or element type falls back to the generic HLO
  ``TriangularSolve`` op.

  Args:
    c: the XLA computation builder.
    a, b: XLA operands for the triangular matrix and right-hand side.
    left_side, lower, transpose_a, conjugate_a: solve configuration flags,
      forwarded unchanged to the chosen lowering.
  """
  shape = c.GetShape(a)
  is_matrix = len(shape.dimensions()) == 2
  if is_matrix and shape.element_type() == np.float32:
    return lapack.jax_trsm(
        c, c.ConstantF32Scalar(1.0), a, b, left_side, lower, transpose_a,
        conjugate_a)
  if is_matrix and shape.element_type() == np.float64:
    return lapack.jax_trsm(
        c, c.ConstantF64Scalar(1.0), a, b, left_side, lower, transpose_a,
        conjugate_a)
  # Fall back to the HLO implementation for batched triangular_solve or
  # unsupported types.
  # TODO(phawkins): support BLAS primitives in batched mode.
  return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a)
def triangular_solve_cpu_translation_rule(c, a, b, left_side, lower, transpose_a, conjugate_a):
  """CPU translation rule for triangular_solve.

  Dispatches 2-D operands whose dtype is in ``_cpu_lapack_types`` to the
  LAPACK-backed ``trsm`` custom call; everything else falls back to the
  generic HLO ``TriangularSolve`` op.

  Args:
    c: the XLA computation builder.
    a, b: XLA operands for the triangular matrix and right-hand side.
    left_side, lower, transpose_a, conjugate_a: solve configuration flags,
      forwarded unchanged to the chosen lowering.
  """
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  use_lapack = len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types
  if not use_lapack:
    # Fall back to the HLO implementation for batched triangular_solve or
    # unsupported types.
    # TODO(phawkins): support BLAS primitives in batched mode.
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a)
  alpha = c.Constant(onp.array(1, dtype=dtype))
  return lapack.jax_trsm(
      c, alpha, a, b, left_side, lower, transpose_a, conjugate_a)