def coo_todense_mhlo(data, row, col, *, shape, data_dtype, index_dtype):
  """COO to dense matrix."""
  data_type, _, nnz = _validate_coo_mhlo(data, row, col, shape)
  rows, cols = shape
  buffer_size, opaque = _hipsparse.build_coo_todense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)

  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get(shape, data_type),
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [data, row, col],
      call_target_name=ir.StringAttr.get("hipsparse_coo_todense"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ] * 3),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ]))
  return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result

def threefry2x32_lowering(keys, data):
  """ThreeFry2x32 kernel for GPU."""
  assert len(keys) == 2, keys
  assert len(data) == 2, data
  assert (ir.RankedTensorType(keys[0].type).element_type ==
          ir.IntegerType.get_unsigned(32)), keys[0].type
  typ = keys[0].type
  dims = ir.RankedTensorType(typ).shape
  for x in itertools.chain(keys, data):
    assert x.type == typ, (x.type, typ)
  ndims = len(dims)

  opaque = _cuda_prng.cuda_threefry2x32_descriptor(_prod(dims))
  layout = ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
                                       type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  tup = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([typ, typ])],
      [keys[0], keys[1], data[0], data[1]],
      call_target_name=ir.StringAttr.get("cuda_threefry2x32"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout] * 4),
      result_layouts=ir.ArrayAttr.get([layout] * 2)).result
  return [
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(2)
  ]

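# `_prod` is used throughout this file but is not defined in this excerpt. A
# minimal sketch of what it presumably does (the CPU builders below compute
# the same product with an explicit `b = 1; for d in ...: b *= d` loop); the
# real definition may differ:
def _prod(xs):
  out = 1
  for x in xs:
    out *= x
  return out
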
def lu_pivots_to_permutation_mhlo(pivots, *, permutation_size):
  """Kernel for the transformation of pivots to permutations on GPU."""
  typ = ir.RankedTensorType(pivots.type)
  dims = typ.shape
  i32_type = ir.IntegerType.get_signless(32)
  assert typ.element_type == i32_type, typ

  batch_size = _prod(dims[:-1])
  pivot_size = dims[-1]
  opaque = _hip_linalg.hip_lu_pivots_to_permutation_descriptor(
      batch_size, pivot_size, permutation_size)
  pivots_layout = ir.DenseIntElementsAttr.get(
      np.arange(len(dims) - 1, -1, -1), type=ir.IndexType.get())
  permutations_layout = pivots_layout
  permutations_dims = list(dims)
  permutations_dims[-1] = permutation_size
  permutations_type = ir.RankedTensorType.get(permutations_dims, i32_type)
  return mhlo.CustomCallOp(
      [permutations_type],
      [pivots],
      call_target_name=ir.StringAttr.get("hip_lu_pivots_to_permutation"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([pivots_layout]),
      result_layouts=ir.ArrayAttr.get([permutations_layout])).result

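# The builders in this file repeatedly spell row-major layout attributes out
# by hand (e.g. `np.arange(ndims - 1, -1, -1)` above). A hedged helper that
# captures the idiom; illustrative only, the existing functions do not call
# it:
def _row_major_layout(ndims):
  """Layout attribute listing dimensions minor-to-major, row-major order."""
  return ir.DenseIntElementsAttr.get(np.arange(ndims - 1, -1, -1),
                                     type=ir.IndexType.get())
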
def csr_matmat_mhlo(data, indices, indptr, B, *, shape, transpose=False,
                    compute_dtype=None, compute_type=None, index_dtype,
                    data_dtype, B_dtype):
  """CSR matrix/matrix multiply."""
  data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr,
                                                  shape)
  rows, cols = shape
  B_shape = ir.RankedTensorType(B.type).shape
  _, Ccols = B_shape

  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type

  buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
      data_dtype, B_dtype, compute_dtype, index_dtype,
      rows, cols, Ccols, nnz, transpose)
  out_size = cols if transpose else rows

  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get([out_size, Ccols], compute_type),
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [data, indices, indptr, B],
      call_target_name=ir.StringAttr.get("hipsparse_csr_matmat"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ]))
  return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result

def syevd_mhlo(dtype, a, lower=False): """Symmetric (Hermitian) eigendecomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape assert len(dims) >= 2 m, n = dims[-2:] assert m == n batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) batch = _prod(batch_dims) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) if n <= 32: kernel = b"cusolver_syevj" lwork, opaque = _cusolver.build_syevj_descriptor( np.dtype(dtype), lower, batch, n) else: kernel = b"cusolver_syevd" lwork, opaque = _cusolver.build_syevd_descriptor( np.dtype(dtype), lower, batch, n) if ir.ComplexType.isinstance(a_type.element_type): eigvals_type = ir.ComplexType(a_type.element_type).element_type else: eigvals_type = a_type.element_type layout = ir.DenseIntElementsAttr.get( np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), type=ir.IndexType.get()) i32_type = ir.IntegerType.get_signless(32) out = mhlo.CustomCallOp( [ ir.TupleType.get_tuple([ a.type, ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type), ir.RankedTensorType.get(batch_dims, i32_type), ir.RankedTensorType.get([lwork], a_type.element_type), ]) ], [a], call_target_name=ir.StringAttr.get(kernel), has_side_effect=ir.BoolAttr.get(False), backend_config=ir.StringAttr.get(opaque), api_version=ir.IntegerAttr.get(i32_type, 2), called_computations=ir.ArrayAttr.get([]), operand_layouts=ir.ArrayAttr.get([layout]), result_layouts=ir.ArrayAttr.get([ layout, ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)), type=ir.IndexType.get()), ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)), type=ir.IndexType.get()), ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()), ])) return [ mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result for i in range(3) ]
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
  """PocketFFT kernel for CPU."""
  a_type = ir.RankedTensorType(a.type)
  n = len(a_type.shape)

  fft_lengths = list(fft_lengths)
  descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
      list(a_type.shape), dtype, fft_type, fft_lengths)

  if out_dtype == np.float32:
    out_type = ir.F32Type.get()
  elif out_dtype == np.float64:
    out_type = ir.F64Type.get()
  elif out_dtype == np.complex64:
    out_type = ir.ComplexType.get(ir.F32Type.get())
  elif out_dtype == np.complex128:
    out_type = ir.ComplexType.get(ir.F64Type.get())
  else:
    raise ValueError(f"Unknown output type {out_dtype}")

  if 0 in a_type.shape or 0 in out_shape:
    # Zero-element case: skip the kernel and emit a broadcast zero constant.
    zero = mhlo.ConstOp(
        ir.RankedTensorType.get([], out_type),
        ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                 type=out_type))
    if jax._src.lib.mlir_api_version < 9:
      return mhlo.BroadcastOp(
          ir.RankedTensorType.get(out_shape, out_type), zero,
          ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
    else:
      return mhlo.BroadcastOp(
          zero,
          ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result

  u8_type = ir.IntegerType.get_unsigned(8)
  descriptor = mhlo.ConstOp(
      ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
      ir.DenseElementsAttr.get(
          np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
  layout = ir.DenseIntElementsAttr.get(np.arange(n - 1, -1, -1),
                                       type=ir.IndexType.get())
  return mhlo.CustomCallOp(
      [ir.RankedTensorType.get(out_shape, out_type)],
      [descriptor, a],
      call_target_name=ir.StringAttr.get("pocketfft"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0], np.int64),
                                      type=ir.IndexType.get()),
          layout,
      ]),
      result_layouts=ir.ArrayAttr.get([layout])).result

def getrf_mhlo(dtype, a): """LU decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape assert len(dims) >= 2 m, n = dims[-2:] batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) batch = _prod(batch_dims) if batch > 1 and m == n and m // batch <= 128: lwork, opaque = _hipblas.build_getrf_batched_descriptor( np.dtype(dtype), batch, m) workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) kernel = "hipblas_getrf_batched" else: lwork, opaque = _hipsolver.build_getrf_descriptor( np.dtype(dtype), batch, m, n) workspace = ir.RankedTensorType.get([lwork], a_type.element_type) kernel = "hipsolver_getrf" layout = ir.DenseIntElementsAttr.get( np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), type=ir.IndexType.get()) i32_type = ir.IntegerType.get_signless(32) out = mhlo.CustomCallOp( [ ir.TupleType.get_tuple([ a.type, ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32_type), ir.RankedTensorType.get(batch_dims, i32_type), workspace, ]) ], [a], call_target_name=ir.StringAttr.get(kernel), has_side_effect=ir.BoolAttr.get(False), backend_config=ir.StringAttr.get(opaque), api_version=ir.IntegerAttr.get(i32_type, 2), called_computations=ir.ArrayAttr.get([]), operand_layouts=ir.ArrayAttr.get([layout]), result_layouts=ir.ArrayAttr.get([ layout, ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1), type=ir.IndexType.get()), ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1), type=ir.IndexType.get()), ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()), ])) return [ mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result for i in range(3) ]
def getrf_mhlo(dtype, a):
  """LU decomposition (CPU, via LAPACK)."""
  dims = ir.RankedTensorType(a.type).shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  if dtype == np.float32:
    fn = b"lapack_sgetrf"
  elif dtype == np.float64:
    fn = b"lapack_dgetrf"
  elif dtype == np.complex64:
    fn = b"lapack_cgetrf"
  elif dtype == np.complex128:
    fn = b"lapack_zgetrf"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
      ]))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]

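# `_mhlo_s32` and `_mhlo_u8` are used by the CPU (LAPACK) builders in this
# file but are not defined in this excerpt. A plausible sketch, modeled on the
# two-argument mhlo.ConstOp usage in pocketfft_mhlo above; the real
# definitions may differ by MLIR API version:
def _mhlo_s32(x):
  i32 = ir.IntegerType.get_signless(32)
  return mhlo.ConstOp(
      ir.RankedTensorType.get([], i32),
      ir.DenseElementsAttr.get(np.array(x, dtype=np.int32), type=i32)).result


def _mhlo_u8(x):
  u8 = ir.IntegerType.get_unsigned(8)
  return mhlo.ConstOp(
      ir.RankedTensorType.get([], u8),
      ir.DenseElementsAttr.get(np.array(x, dtype=np.uint8), type=u8)).result
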
def trsm_mhlo(dtype, a, b, left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False):
  """Batched triangular solve (GPU, via hipBLAS).

  XLA implements unbatched triangular solve directly, so we need only
  implement the batched case.
  """
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  k = m if left_side else n

  a_type = ir.RankedTensorType(a.type)
  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_type, b_type))

  if conj_a and not trans_a:
    raise NotImplementedError(
        "Conjugation without transposition not supported")

  lwork, opaque = _hipblas.build_trsm_batched_descriptor(
      np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  work_type = ir.RankedTensorType.get([lwork],
                                      ir.IntegerType.get_signless(8))
  work_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  tup = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([b_type, work_type, work_type])],
      [a, b],
      call_target_name=ir.StringAttr.get("hipblas_trsm_batched"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout] * 2),
      result_layouts=ir.ArrayAttr.get([layout, work_layout,
                                       work_layout])).result
  return mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result

def potrf_mhlo(dtype, a, lower=False):
  """Cholesky decomposition (CPU, via LAPACK)."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  m, n = dims[-2:]
  if m != n:
    raise ValueError("potrf expects a square matrix, got {}".format(a_type))
  if dtype == np.float32:
    fn = b"lapack_spotrf"
  elif dtype == np.float64:
    fn = b"lapack_dpotrf"
  elif dtype == np.complex64:
    fn = b"lapack_cpotrf"
  elif dtype == np.complex128:
    fn = b"lapack_zpotrf"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  info_layout = ir.DenseIntElementsAttr.get(
      np.array(range(num_bd - 1, -1, -1)), type=ir.IndexType.get())
  tup = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [_mhlo_s32(int(lower)), _mhlo_s32(b), _mhlo_s32(n), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
      result_layouts=ir.ArrayAttr.get([layout, info_layout])).result
  return [
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result,
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 1)).result,
  ]

def orgqr_mhlo(dtype, a, tau):
  """Product of elementary Householder reflections."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)

  tau_dims = ir.RankedTensorType(tau.type).shape
  assert tau_dims[:-1] == dims[:-2]
  k = tau_dims[-1]

  lwork, opaque = _hipsolver.build_orgqr_descriptor(
      np.dtype(dtype), batch, m, n, k)
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims, i32_type),
              ir.RankedTensorType.get([lwork], a_type.element_type),
          ])
      ],
      [a, tau],
      call_target_name=ir.StringAttr.get(b"hipsolver_orgqr"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                      type=ir.IndexType.get()),
      ]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ]))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(2)
  ]

def coo_matvec_mhlo(data, row, col, x, *, shape, transpose=False,
                    compute_dtype=None, compute_type=None, index_dtype,
                    data_dtype, x_dtype):
  """COO matrix/vector multiply."""
  data_type, index_type, nnz = _validate_coo_mhlo(data, row, col, shape)
  rows, cols = shape

  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type

  buffer_size, opaque = _cusparse.build_coo_matvec_descriptor(
      data_dtype, x_dtype, compute_dtype, index_dtype,
      rows, cols, nnz, transpose)
  out_size = cols if transpose else rows

  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get([out_size], compute_type),
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [data, row, col, x],
      call_target_name=ir.StringAttr.get("cusparse_coo_matvec"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ] * 4),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ] * 2))
  return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result

def custom_call(
    call_target_name: str,
    out_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    operand_layouts: Sequence[Sequence[int]],
    result_layouts: Sequence[Sequence[int]],
    backend_config: Optional[str] = None,
    has_side_effect: bool = False,
    api_version: int = 2,
) -> Union[ir.Value, Sequence[ir.Value]]:
  """Less-verbose helper for building an MHLO custom call op.

  Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this
  helper may be able to go away.
  """
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      (out_types if len(out_types) == 1 else
       [ir.TupleType.get_tuple(out_types)]),
      operands,
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=ir.StringAttr.get(
          "" if backend_config is None else backend_config),
      api_version=ir.IntegerAttr.get(i32_type, api_version),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(
              np.atleast_1d(np.asarray(l, dtype=np.int64)),
              type=ir.IndexType.get()) for l in operand_layouts
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(
              np.atleast_1d(np.asarray(l, dtype=np.int64)),
              type=ir.IndexType.get()) for l in result_layouts
      ]))
  if len(out_types) == 1:
    return out.result
  else:
    return [
        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(len(out_types))
    ]

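# A minimal usage sketch for `custom_call`, assuming an active MLIR context
# and insertion point. The call target "my_kernel" and the helper name are
# hypothetical; real targets must be registered with XLA, as the kernels in
# this file are. Note how the layout boilerplate collapses to plain lists.
def _example_custom_call(a):
  a_type = ir.RankedTensorType(a.type)
  row_major = list(range(len(a_type.shape) - 1, -1, -1))
  return custom_call(
      "my_kernel",  # hypothetical call target name
      [a_type],     # a single result with the operand's type
      [a],
      operand_layouts=[row_major],
      result_layouts=[row_major])
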
def potrf_mhlo(dtype, a, lower):
  """Cholesky decomposition (GPU, via hipSOLVER)."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)

  lwork, opaque = _hipsolver.build_potrf_descriptor(
      np.dtype(dtype), lower, batch, n)
  kernel = b"hipsolver_potrf"

  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  info_layout = ir.DenseIntElementsAttr.get(
      np.array(range(num_bd - 1, -1, -1)), type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  info_type = ir.RankedTensorType.get(batch_dims, i32_type)
  work_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get())
  tup = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              info_type,
              ir.RankedTensorType.get([lwork],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [a],
      call_target_name=ir.StringAttr.get(kernel),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout]),
      result_layouts=ir.ArrayAttr.get([layout, info_layout,
                                       work_layout])).result
  return [
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result,
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 1)).result,
  ]

def gtsv2_mhlo(dl, d, du, B, *, m, n, ldb, t):
  """Calls `hipsparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
  f32 = (t == np.float32)
  if f32:
    buffer_size = _hipsparse.gtsv2_f32_buffer_size(m, n, ldb)
  else:
    buffer_size = _hipsparse.gtsv2_f64_buffer_size(m, n, ldb)
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get(
                  [ldb, n], ir.F32Type.get() if f32 else ir.F64Type.get()),
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [dl, d, du, B],
      call_target_name=ir.StringAttr.get(
          "hipsparse_gtsv2_" + ("f32" if f32 else "f64")),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(
          _hipsparse.build_gtsv2_descriptor(m, n, ldb)),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ] * 3 + [
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get())
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ]))
  return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result

def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype, index_type):
  """COO from dense matrix."""
  mat_type = ir.RankedTensorType(mat.type)
  rows, cols = mat_type.shape
  buffer_size, opaque = _hipsparse.build_coo_fromdense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)

  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get([nnz], mat_type.element_type),
              ir.RankedTensorType.get([nnz], index_type),
              ir.RankedTensorType.get([nnz], index_type),
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [mat],
      call_target_name=ir.StringAttr.get("hipsparse_coo_fromdense"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ] * 4))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]

def syevd_mhlo(dtype, a, lower=False):
  """Symmetric (Hermitian) eigendecomposition (CPU, via LAPACK)."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  i32_type = ir.IntegerType.get_signless(32)
  if dtype == np.float32:
    fn = b"lapack_ssyevd"
    eigvals_type = ir.F32Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.float64:
    fn = b"lapack_dsyevd"
    eigvals_type = ir.F64Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.complex64:
    fn = b"lapack_cheevd"
    eigvals_type = ir.F32Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)], eigvals_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.complex128:
    fn = b"lapack_zheevd"
    eigvals_type = ir.F64Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)], eigvals_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims + (n,), eigvals_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ] + workspace)
      ],
      [_mhlo_s32(1 if lower else 0), _mhlo_s32(b), _mhlo_s32(n), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                      type=ir.IndexType.get()),
      ] + workspace_layouts))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]

def gees_mhlo(a, jobvs=True, sort=False, select=None):
  """Schur decomposition (CPU, via LAPACK gees)."""
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())

  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")

  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')

  if not ir.ComplexType.isinstance(etype):
    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
    schurvecs_type = etype
    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
    workspace_layouts = [layout]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  else:
    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
          else "lapack_zgees")
    schurvecs_type = etype
    workspaces = [
        ir.RankedTensorType.get(dims, schurvecs_type),
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [
        layout,
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]

  i32_type = ir.IntegerType.get_signless(32)
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple(workspaces + eigvals + [
              ir.RankedTensorType.get(dims, schurvecs_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [
          _mhlo_s32(b),
          _mhlo_s32(n),
          _mhlo_u8(np.uint8(jobvs)),
          _mhlo_u8(np.uint8(sort)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
      result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
      ]))
  i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
  if sort == ord('S'):
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(4)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
  else:
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)

def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False,
              trans_a=False, conj_a=False, diag=False):
  """Triangular solve (CPU, via BLAS trsm)."""
  a_type = ir.RankedTensorType(a.type)
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  m, n = dims[-2:]
  k = m if left_side else n

  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  num_b = 1
  for d in batch_dims:
    num_b *= d

  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_type, b_type))

  if dtype == np.float32:
    fn = "blas_strsm"
  elif dtype == np.float64:
    fn = "blas_dtrsm"
  elif dtype == np.complex64:
    fn = "blas_ctrsm"
  elif dtype == np.complex128:
    fn = "blas_ztrsm"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  if conj_a and not trans_a:
    raise NotImplementedError(
        "Conjugation without transposition not supported")
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  return mhlo.CustomCallOp(
      [b.type],
      [
          _mhlo_s32(int(left_side)),
          _mhlo_s32(int(lower)),
          # Transposition code: 0 = none, 1 = transpose, 2 = conjugate
          # transpose.
          _mhlo_s32((2 if conj_a else 1) if trans_a else 0),
          _mhlo_s32(int(diag)),
          _mhlo_s32(m),
          _mhlo_s32(n),
          _mhlo_s32(num_b),
          alpha,
          a,
          b
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 8 + [layout] * 2),
      result_layouts=ir.ArrayAttr.get([layout])).result

def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition (CPU, via LAPACK gesdd)."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  i32_type = ir.IntegerType.get_signless(32)
  if dtype == np.float32:
    fn = b"lapack_sgesdd"
    singular_vals_type = ir.F32Type.get()
    lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.float64:
    fn = b"lapack_dgesdd"
    singular_vals_type = ir.F64Type.get()
    lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.complex64:
    fn = b"lapack_cgesdd"
    singular_vals_type = ir.F32Type.get()
    lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get(
            [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
            ir.F32Type.get()),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.complex128:
    fn = b"lapack_zgesdd"
    singular_vals_type = ir.F64Type.get()
    lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get(
            [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
            ir.F64Type.get()),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims + (min(m, n),),
                                      singular_vals_type),
              ir.RankedTensorType.get(
                  batch_dims + (m, m if full_matrices else min(m, n)),
                  a_type.element_type),
              ir.RankedTensorType.get(
                  batch_dims + (n if full_matrices else min(m, n), n),
                  a_type.element_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ] + workspace)
      ],
      [
          _mhlo_s32(int(full_matrices)),
          _mhlo_s32(int(compute_uv)),
          _mhlo_s32(b),
          _mhlo_s32(m),
          _mhlo_s32(n),
          _mhlo_s32(lwork),
          a
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 6 + [layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(
              np.array((num_bd,) + tuple(range(num_bd - 1, -1, -1))),
              type=ir.IndexType.get()),
          layout,
          layout,
          ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                      type=ir.IndexType.get()),
      ] + workspace_layouts))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(1, 5)
  ]

def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition (GPU, via cuSOLVER)."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type

  scalar_layout = ir.DenseIntElementsAttr.get(
      np.array(tuple(range(num_bd - 1, -1, -1))), type=ir.IndexType.get())
  vector_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd,) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)

  if m < 32 and n < 32:
    # Use the Jacobi kernel (gesvdj) for small matrices.
    # The batched kernel doesn't support "econ" mode.
    econ = not full_matrices and b == 1
    lwork, opaque = _cusolver.build_gesvdj_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
    k = min(m, n)
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n),),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvdj"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    v = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
    # gesvdj returns V rather than V^H; build the conjugate transpose.
    vt = mhlo.TransposeOp(
        v,
        ir.DenseIntElementsAttr.get(
            np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
    if np.issubdtype(dtype, np.complexfloating):
      vt = mhlo.ComplexOp(mhlo.RealOp(vt),
                          mhlo.NegOp(mhlo.ImagOp(vt))).result
    if not full_matrices and not econ:
      u = mhlo.SliceOp(
          u,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
      vt = mhlo.SliceOp(
          vt,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
  elif m < n:
    # Handle the wide case by computing the SVD of the transpose: m and n are
    # swapped in the descriptor and a transposed layout is used, and the
    # U/V^H results trade places below.
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n),),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (k, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (m, m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  else:
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n),),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  return s, u, vt, info

def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition (GPU, via hipSOLVER)."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type

  scalar_layout = ir.DenseIntElementsAttr.get(
      np.array(tuple(range(num_bd - 1, -1, -1))), type=ir.IndexType.get())
  vector_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd,) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  matrix_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)

  if m < n:
    # Handle the wide case by computing the SVD of the transpose: m and n are
    # swapped in the descriptor, and the U/V^H results trade places below.
    lwork, opaque = _hipsolver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n),),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (k, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (m, m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  else:
    lwork, opaque = _hipsolver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n),),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  return s, u, vt, info

def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
  """Nonsymmetric eigendecomposition (CPU, via LAPACK geev)."""
  dims = ir.RankedTensorType(a.type).shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())

  jobvl_c = ord('V' if jobvl else 'N')
  jobvr_c = ord('V' if jobvr else 'N')

  if dtype == np.float32:
    fn = b"lapack_sgeev"
    real = True
    eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n], ir.F32Type.get()),
        ir.RankedTensorType.get([n, n], ir.F32Type.get()),
        ir.RankedTensorType.get([n, n], ir.F32Type.get()),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                    type=ir.IndexType.get())
    ] * 3
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get()),
        ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get()),
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  elif dtype == np.float64:
    fn = b"lapack_dgeev"
    real = True
    eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n], ir.F64Type.get()),
        ir.RankedTensorType.get([n, n], ir.F64Type.get()),
        ir.RankedTensorType.get([n, n], ir.F64Type.get()),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                    type=ir.IndexType.get())
    ] * 3
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get()),
        ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get()),
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  elif dtype == np.complex64:
    fn = b"lapack_cgeev"
    real = False
    eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n],
                                ir.ComplexType.get(ir.F32Type.get())),
        ir.RankedTensorType.get([2 * n], ir.F32Type.get()),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n,),
                                ir.ComplexType.get(ir.F32Type.get()))
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]
  elif dtype == np.complex128:
    fn = b"lapack_zgeev"
    real = False
    eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n],
                                ir.ComplexType.get(ir.F64Type.get())),
        ir.RankedTensorType.get([2 * n], ir.F64Type.get()),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n,),
                                ir.ComplexType.get(ir.F64Type.get()))
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  i32_type = ir.IntegerType.get_signless(32)
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                            type=ir.IndexType.get())
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple(workspaces + eigvals + [
              ir.RankedTensorType.get(dims, eigvecs_type),
              ir.RankedTensorType.get(dims, eigvecs_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [_mhlo_s32(b), _mhlo_s32(n), _mhlo_u8(jobvl_c), _mhlo_u8(jobvr_c), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
      result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts +
                                      [layout] * 2 + [info_layout])).result
  i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
  if real:
    # Real dtypes return eigenvalues as separate real and imaginary halves
    # (tuple elements 3 and 4); combine them into a complex value.
    return (mhlo.ComplexOp(mhlo.GetTupleElementOp(out, i32_attr(3)),
                           mhlo.GetTupleElementOp(out, i32_attr(4))).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result,
            mhlo.GetTupleElementOp(out, i32_attr(6)).result,
            mhlo.GetTupleElementOp(out, i32_attr(7)).result)
  else:
    return (mhlo.GetTupleElementOp(out, i32_attr(2)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(4)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)