예제 #1
0
파일: lax_linalg.py 프로젝트: yotarok/jax
def _triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                           transpose_a, conjugate_a,
                                           unit_diagonal):
    """Lower a triangular solve for the CPU backend.

    Uses ``lapack.jax_trsm`` when ``a`` is rank 2 (unbatched) and its element
    dtype is listed in ``_cpu_lapack_types``. Otherwise it emits the generic
    ``xops.TriangularSolve`` HLO op.

    Args:
      c: computation builder; queried for ``a``'s shape and used to build ops.
      a: op holding the triangular coefficient matrix (or a batch of them).
      b: op holding the right-hand side.
      left_side, lower, unit_diagonal: solve options passed through unchanged.
      transpose_a, conjugate_a: whether to apply ``a`` transposed and/or
        conjugated.

    Returns:
      The op produced by the selected implementation.
    """
    a_shape = c.GetShape(a)
    elem_type = a_shape.element_type().type

    # The HLO transpose options used below (NO_TRANSPOSE / TRANSPOSE / ADJOINT)
    # have no "conjugate only" mode, so a bare conjugation is applied to ``a``
    # directly. The flag is then cleared for both code paths.
    if conjugate_a and not transpose_a:
        a = xops.Conj(a)
        conjugate_a = False

    is_single_matrix = len(a_shape.dimensions()) == 2
    if is_single_matrix and onp.dtype(elem_type) in _cpu_lapack_types:
        builder = xb.computation_builder_shim(c)
        # Constant 1 in a's dtype; presumably the trsm scale factor (alpha).
        alpha = xb.constant(c, onp.array(1, dtype=elem_type))
        return lapack.jax_trsm(builder, alpha, a, b, left_side, lower,
                               transpose_a, conjugate_a, unit_diagonal)

    # Fallback: HLO implementation for unsupported dtypes or batched inputs.
    # TODO: Consider swapping XLA for LAPACK in batched case
    transpose_options = xops.TriangularSolveOptions_Transpose
    if transpose_a and conjugate_a:
        mode = transpose_options.ADJOINT
    elif transpose_a:
        mode = transpose_options.TRANSPOSE
    else:
        mode = transpose_options.NO_TRANSPOSE
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, mode)
예제 #2
0
def _triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """Lower a triangular solve for the CPU backend.

  Calls ``lapack.jax_trsm`` when ``a`` is rank 2 (unbatched) and its dtype is
  in ``_cpu_lapack_types``. Otherwise it falls back to the builder's generic
  ``TriangularSolve``.

  Args:
    c: computation builder; queried for ``a``'s shape and used to build ops.
    a: op holding the triangular coefficient matrix (or a batch of them).
    b: op holding the right-hand side.
    left_side, lower, unit_diagonal: solve options passed through unchanged.
    transpose_a, conjugate_a: whether to apply ``a`` transposed and/or
      conjugated.

  Returns:
    The op produced by the selected implementation.
  """
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  # Normalize with onp.dtype() before the membership test, as the other
  # versions of this rule do. ``element_type().type`` is presumably a NumPy
  # scalar type. A scalar type compares equal to its np.dtype but does not
  # hash equal, so if _cpu_lapack_types holds np.dtype objects (assumption:
  # it is defined outside this view) the raw lookup would always miss. Every
  # call would then silently take the slower HLO fallback.
  if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
    # Conjugate-only is folded into ``a`` itself, leaving jax_trsm with only
    # plain / transposed / conjugate-transposed cases.
    if conjugate_a and not transpose_a:
      a = c.Conj(a)
      conjugate_a = False
    return lapack.jax_trsm(
      c, c.Constant(onp.array(1, dtype=dtype)), a, b, left_side, lower,
                    transpose_a, conjugate_a, unit_diagonal)
  else:
    # Fall back to the HLO implementation for batched triangular_solve or
    # unsupported types.
    # TODO(phawkins): support BLAS primitives in batched mode.
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
                             unit_diagonal)
예제 #3
0
def _triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """Lower a triangular solve for the CPU backend.

  Batched inputs, and dtypes missing from ``_cpu_lapack_types``, go to the
  builder's generic ``TriangularSolve``. A single matrix of a supported dtype
  is handed to ``lapack.jax_trsm``.

  Args:
    c: computation builder; queried for ``a``'s shape and used to build ops.
    a: op holding the triangular coefficient matrix (or a batch of them).
    b: op holding the right-hand side.
    left_side, lower, unit_diagonal: solve options passed through unchanged.
    transpose_a, conjugate_a: whether to apply ``a`` transposed and/or
      conjugated.

  Returns:
    The op produced by the selected implementation.
  """
  a_shape = c.GetShape(a)
  elem_type = a_shape.element_type().type
  unbatched = len(a_shape.dimensions()) == 2

  if not (unbatched and onp.dtype(elem_type) in _cpu_lapack_types):
    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in batched case
    return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
                             unit_diagonal)

  # Conjugate-only is applied to ``a`` directly before calling jax_trsm,
  # which then sees conjugate_a=False.
  if conjugate_a and not transpose_a:
    a = c.Conj(a)
    conjugate_a = False
  # Constant 1 in a's dtype; presumably the trsm scale factor (alpha).
  alpha = c.Constant(onp.array(1, dtype=elem_type))
  return lapack.jax_trsm(c, alpha, a, b, left_side, lower, transpose_a,
                         conjugate_a, unit_diagonal)