def _triangular_solve_gpu_translation_rule(
        c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
    """Lower a triangular solve for GPU, choosing between cusolver and XLA.

    The shape of ``a`` is queried through ``c.GetShape``. Its last two
    dimensions are taken as the matrix dimensions and the product of all
    leading dimensions as the batch size. When the batch size is greater
    than one and both matrix dimensions are at most 32, the call is routed
    to ``cusolver.trsm``; every other case (no batch, or larger matrices)
    goes to ``c.TriangularSolve``.

    Args:
      c: builder object providing ``GetShape``, ``Conj`` and
        ``TriangularSolve``.
      a: operand whose shape drives the dispatch decision.
      b: second operand, forwarded unchanged.
      left_side, lower, unit_diagonal: flags forwarded unchanged to the
        selected implementation.
      transpose_a: transpose flag, forwarded to the selected implementation.
      conjugate_a: conjugation flag. On the ``cusolver.trsm`` path a
        conjugate-without-transpose request is applied to ``a`` directly
        and the flag is cleared before the call.

    Returns:
      Whatever the selected implementation returns.
    """
    dims = c.GetShape(a).dimensions()
    m, n = dims[-2:]
    batch = prod(dims[:-2])
    if batch > 1 and m <= 32 and n <= 32:
        if conjugate_a and not transpose_a:
            # Conjugate `a` explicitly so that trsm never receives
            # "conjugate without transpose".
            # NOTE(review): presumably trsm cannot express that combination
            # on its own -- confirm against cusolver.trsm.
            a = c.Conj(a)
            conjugate_a = False
        return cusolver.trsm(
            c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
    # Use the XLA implementation for everything else: unbatched solves and
    # batched solves whose matrices exceed the 32x32 threshold.
    return c.TriangularSolve(
        a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
def _triangular_solve_gpu_translation_rule(
        c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
    """Lower a triangular solve for GPU, choosing between cusolver and XLA.

    The shape of ``a`` is read via ``c.get_shape``. Its trailing two
    dimensions are the matrix dimensions; the product of the remaining
    leading dimensions is the batch size. A batch size above one combined
    with both matrix dimensions at most 32 selects ``cusolver.trsm``;
    all other inputs go through ``xops.TriangularSolve``.

    Args:
      c: builder object providing ``get_shape``; also handed to
        ``cusolver.trsm``.
      a: operand whose shape drives the dispatch decision.
      b: second operand, forwarded unchanged.
      left_side, lower, unit_diagonal: flags forwarded unchanged to the
        selected implementation.
      transpose_a: transpose flag; on the XLA path it is translated into a
        ``TriangularSolveOptions_Transpose`` value.
      conjugate_a: conjugation flag. A conjugate-without-transpose request
        is applied to ``a`` up front and the flag is cleared.

    Returns:
      Whatever the selected implementation returns.
    """
    dims = c.get_shape(a).dimensions()
    rows, cols = dims[-2:]
    batch_size = prod(dims[:-2])

    # Neither path below has a "conjugate, no transpose" mode (the XLA
    # option is chosen only from NO_TRANSPOSE / TRANSPOSE / ADJOINT), so
    # fold that case into the operand itself before dispatching.
    if conjugate_a and not transpose_a:
        a = xops.Conj(a)
        conjugate_a = False

    use_cusolver = batch_size > 1 and rows <= 32 and cols <= 32
    if use_cusolver:
        return cusolver.trsm(
            c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)

    # Everything else -- unbatched solves and batched solves with larger
    # matrices -- uses the XLA implementation.
    transpose_kinds = xops.TriangularSolveOptions_Transpose
    if not transpose_a:
        transpose = transpose_kinds.NO_TRANSPOSE
    elif conjugate_a:
        transpose = transpose_kinds.ADJOINT
    else:
        transpose = transpose_kinds.TRANSPOSE
    return xops.TriangularSolve(
        a, b, left_side, lower, unit_diagonal, transpose)