Example #1
def triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                          transpose_a, conjugate_a):
    # Lower a 2-D float32/float64 triangular solve to LAPACK's trsm via
    # jaxlib; everything else falls back to XLA's TriangularSolve HLO.
    shape = c.GetShape(a)
    if len(shape.dimensions()) == 2 and shape.element_type() == np.float32:
        return lapack.jax_trsm(c, c.ConstantF32Scalar(1.0), a, b, left_side,
                               lower, transpose_a, conjugate_a)
    elif len(shape.dimensions()) == 2 and shape.element_type() == np.float64:
        return lapack.jax_trsm(c, c.ConstantF64Scalar(1.0), a, b, left_side,
                               lower, transpose_a, conjugate_a)
    else:
        # Fall back to the HLO implementation for batched triangular_solve or
        # unsupported types.
        # TODO(phawkins): support BLAS primitives in batched mode.
        return c.TriangularSolve(a, b, left_side, lower, transpose_a,
                                 conjugate_a)
Example #2
def triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                          transpose_a, conjugate_a):
    # Same rule, generalized: any 2-D operand whose dtype is listed in
    # _cpu_lapack_types is lowered to LAPACK's trsm with a dtype-matched
    # scalar constant; `onp` is plain NumPy.
    shape = c.GetShape(a)
    dtype = shape.element_type().type
    if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
        return lapack.jax_trsm(c, c.Constant(onp.array(1, dtype=dtype)), a, b,
                               left_side, lower, transpose_a, conjugate_a)
    else:
        # Fall back to the HLO implementation for batched triangular_solve or
        # unsupported types.
        # TODO(phawkins): support BLAS primitives in batched mode.
        return c.TriangularSolve(a, b, left_side, lower, transpose_a,
                                 conjugate_a)
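
Example #2 folds the two dtype-specific branches of Example #1 into a single dtype-generic path: it builds the scalar 1.0 constant with c.Constant(onp.array(1, dtype=dtype)) and gates the LAPACK path on membership in a module-level set _cpu_lapack_types. That set is defined outside this excerpt; a minimal sketch of what it might contain (the exact contents are an assumption, and complex support depends on the jaxlib version):

import numpy as onp

# Hypothetical sketch only: the NumPy scalar types the CPU LAPACK trsm path
# is assumed to accept. The real definition lives in the surrounding JAX
# module and may differ.
_cpu_lapack_types = {onp.float32, onp.float64, onp.complex64, onp.complex128}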