def _triangular_solve_cpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  if conjugate_a and not transpose_a:
    # A conjugation without a transposition can be folded into the operand.
    a = xops.Conj(a)
    conjugate_a = False
  if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
    return lapack.jax_trsm(
        xb.computation_builder_shim(c),
        xb.constant(c, onp.array(1, dtype=dtype)),
        a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
  else:
    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in the batched case.
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)

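# How a backend-specific rule like the one above is typically attached to its
# primitive (a sketch for illustration, assuming `triangular_solve_p` and the
# `xla` interpreter module are in scope; the actual registration call lives
# elsewhere in this module):
#
#   xla.backend_specific_translations['cpu'][triangular_solve_p] = \
#       _triangular_solve_cpu_translation_rule
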
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
  shape = c.GetShape(operand)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]
  cs = xb.computation_builder_shim(c)
  r, tau, info_geqrf = geqrf_impl(cs, operand)
  if m < n:
    # Wide case: Q comes from the leading m x m block of the geqrf output.
    q = xops.Slice(r, [0] * len(dims), list(batch_dims) + [m, m],
                   [1] * len(dims))
    q, info_orgqr = orgqr_impl(cs, q, tau)
  elif not full_matrices:
    # Reduced QR: Q keeps shape (..., m, n); R is the leading n x n block.
    q, info_orgqr = orgqr_impl(cs, r, tau)
    r = xops.Slice(r, [0] * len(dims), list(batch_dims) + [n, n],
                   [1] * len(dims))
  else:
    # Full QR: zero-pad the last dimension to m so orgqr emits a square Q.
    padding_config = [(0, 0, 0)] * len(dims)
    padding_config[-1] = (0, m - n, 0)
    q = xops.Pad(r, xops.Constant(c, onp.array(0, dtype=shape.element_type())),
                 xla_client.make_padding_config(padding_config))
    q, info_orgqr = orgqr_impl(cs, q, tau)
  # A nonzero info from either LAPACK call signals failure; poison the
  # corresponding batch entries with NaNs.
  ok = xops.And(
      xops.Eq(info_geqrf, xops.ConstantLiteral(c, onp.array(0, onp.int32))),
      xops.Eq(info_orgqr, xops.ConstantLiteral(c, onp.array(0, onp.int32))))
  q = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), q,
                           _nan_like(c, q))
  r = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), r,
                           _nan_like(c, r))
  return xops.Tuple(c, [q, r])

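# Shape bookkeeping for the three QR cases above (a sketch of the standard
# reduced/full QR conventions for an (..., m, n) operand):
#
#   m < n:                  q is (..., m, m), r is (..., m, n)
#   m >= n, reduced:        q is (..., m, n), r is (..., n, n)
#   m >= n, full_matrices:  q is (..., m, m), r is (..., m, n)
#
# In the full_matrices case the zero-padded (..., m, m) array is only scratch
# input for orgqr; the returned r keeps the (..., m, n) shape from geqrf.
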
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  result, info = potrf_impl(xb.computation_builder_shim(c), operand,
                            lower=True)
  # info == 0 means potrf succeeded; otherwise poison the result with NaNs.
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  return _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), result,
                              _nan_like(c, result))

def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  lu, pivot, info = getrf_impl(xb.computation_builder_shim(c), operand)
  # Subtract 1 from the pivot to get 0-based indices.
  pivot = xops.Sub(pivot, xops.ConstantLiteral(c, onp.array(1, onp.int32)))
  # getrf reports info > 0 for an exactly singular factor but still returns a
  # valid LU, so only info < 0 (an invalid argument) is masked with NaNs.
  ok = xops.Ge(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                            _nan_like(c, lu))
  return xops.Tuple(c, [lu, pivot])

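# A worked example of the pivot convention handled above: LAPACK's getrf
# returns 1-based pivot indices, so for a 3x3 input whose getrf pivots are
# [2, 3, 3], the rule emits the 0-based pivots [1, 2, 2].
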
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  # v: eigenvectors, w: eigenvalues.
  v, w, info = syevd_impl(xb.computation_builder_shim(c), operand, lower=lower)
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  v = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), v,
                           _nan_like(c, v))
  # The eigenvalues are a vector per batch entry, hence the (1,)-shaped mask.
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  return xops.Tuple(c, [v, w])

def eig_cpu_translation_rule(c, operand):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  # w: eigenvalues; vl/vr: left/right eigenvectors.
  w, vl, vr, info = _cpu_geev(xb.computation_builder_shim(c), operand)
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl,
                            _nan_like(c, vl))
  vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                            _nan_like(c, vr))
  return xops.Tuple(c, [w, vl, vr])

def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
  # Broadcast all four inputs to a common shape before calling the CUDA
  # kernel, which expects equal-shaped key and data arrays.
  shape = lax.broadcast_shapes(
      c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
      c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
  rank = len(shape)

  def _broadcast(x):
    # Map the operand's dimensions onto the trailing dimensions of the
    # target shape, matching NumPy-style right-aligned broadcasting.
    ndims = c.GetShape(x).rank()
    return xla_client.ops.BroadcastInDim(x, shape,
                                         tuple(range(rank - ndims, rank)))

  return cuda_prng.threefry2x32(
      xla_bridge.computation_builder_shim(c),
      (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))

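# Example of the broadcast above (a sketch): for a key of shape (4,)
# broadcast against data of shape (2, 4), rank == 2 and ndims == 1, so
# broadcast_dimensions == (1,) and the key is replicated along the new
# leading axis, exactly as NumPy would broadcast it.
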
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices,
                                  compute_uv):
  shape = c.GetShape(operand)
  batch_dims = shape.dimensions()[:-2]
  # s: singular values; u/vt: left/right singular vectors.
  s, u, vt, info = gesvd_impl(xb.computation_builder_shim(c), operand,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)
  ok = xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))
  s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
                           _nan_like(c, s))
  u = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), u,
                           _nan_like(c, u))
  vt = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vt,
                            _nan_like(c, vt))
  return xops.Tuple(c, [s, u, vt])

def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.GetShape(a)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if batch > 1 and m <= 32 and n <= 32:
    # The custom batched kernel only handles many small (<= 32 x 32) matrices.
    return cusolver.trsm(
        xb.computation_builder_shim(c), a, b, left_side, lower, transpose_a,
        conjugate_a, unit_diagonal)
  else:
    # Use the XLA implementation for unbatched solves and for matrices too
    # large for the batched kernel.
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)
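
# As with the CPU rule, the GPU rule above would be attached to the primitive
# via a backend-specific registration (a sketch; the actual call site lives
# elsewhere in this module):
#
#   xla.backend_specific_translations['gpu'][triangular_solve_p] = \
#       _triangular_solve_gpu_translation_rule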