示例#1
0
def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """GPU translation rule for triangular_solve (builder-method API).

  Inputs with more than one matrix in the leading (batch) dimensions and
  both trailing dimensions <= 32 are lowered via `cusolver.trsm`; every
  other input is lowered via the builder's `TriangularSolve` op.

  Args:
    c: the computation builder; `c.GetShape(a)` supplies the shape of `a`.
    a: operand holding the triangular matrix (or matrices); its last two
      dimensions are read as (m, n) and the rest as batch dimensions.
    b: right-hand-side operand, forwarded unchanged to the chosen backend.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: solve
      options, forwarded to the chosen backend.

  Returns:
    The result of `cusolver.trsm` or `c.TriangularSolve`.
  """
  shape = c.GetShape(a)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if batch > 1 and m <= 32 and n <= 32:
    # For the batched path, a conjugate-without-transpose request is turned
    # into an explicitly conjugated `a`, so trsm never sees
    # conjugate_a=True together with transpose_a=False.
    if conjugate_a and not transpose_a:
      a = c.Conj(a)
      conjugate_a = False
    return cusolver.trsm(
      c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
  # Fall back to the XLA implementation. This covers unbatched inputs and
  # batched inputs whose trailing dimensions exceed 32.
  return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
                           unit_diagonal)
示例#2
0
文件: linalg.py 项目: varun-alla/jax
def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """GPU translation rule for triangular_solve (xops API).

  Inputs with more than one matrix in the leading (batch) dimensions and
  both trailing dimensions <= 32 are lowered via `cusolver.trsm`; every
  other input is lowered via `xops.TriangularSolve`.

  Args:
    c: the computation builder; `c.get_shape(a)` supplies the shape of `a`.
    a: operand holding the triangular matrix (or matrices); its last two
      dimensions are read as the matrix dimensions and the rest as batch
      dimensions.
    b: right-hand-side operand, forwarded unchanged to the chosen backend.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: solve
      options. `transpose_a`/`conjugate_a` are mapped to a
      `TriangularSolveOptions_Transpose` value on the XLA path.

  Returns:
    The result of `cusolver.trsm` or `xops.TriangularSolve`.
  """
  a_dims = c.get_shape(a).dimensions()
  rows, cols = a_dims[-2:]
  num_batches = prod(a_dims[:-2])

  # Conjugation without transposition is applied to `a` directly, so neither
  # backend below is asked for conjugate_a=True with transpose_a=False.
  if conjugate_a and not transpose_a:
    a, conjugate_a = xops.Conj(a), False

  small_and_batched = num_batches > 1 and rows <= 32 and cols <= 32
  if small_and_batched:
    return cusolver.trsm(
      c, a, b, left_side, lower, transpose_a,
      conjugate_a, unit_diagonal)

  # XLA path: covers unbatched inputs and batched inputs whose trailing
  # dimensions exceed 32.
  transpose_kind = xops.TriangularSolveOptions_Transpose
  if transpose_a:
    # Transposing a conjugated matrix is the adjoint.
    transpose = (transpose_kind.ADJOINT if conjugate_a
                 else transpose_kind.TRANSPOSE)
  else:
    transpose = transpose_kind.NO_TRANSPOSE
  return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                              transpose)