def potrf_mhlo(dtype, a, lower=False):
  """Lowers Cholesky decomposition to the CPU LAPACK ``?potrf`` custom call.

  Args:
    dtype: element dtype of ``a`` (float32/float64/complex64/complex128).
    a: an MHLO value holding a (possibly batched) square matrix.
    lower: if True, the lower triangle is factored, otherwise the upper.

  Returns:
    A pair ``(factor, info)``; ``info`` is a per-batch i32 status code.
  """
  aval = ir.RankedTensorType(a.type)
  shape = aval.shape
  m, n = shape[-2:]
  if m != n:
    raise ValueError(f"potrf expects a square matrix, got {aval}")
  # Select the dtype-specialized LAPACK entry point; comparisons happen in
  # the same order as the original if/elif chain.
  fn = None
  for candidate, target in ((np.float32, b"lapack_spotrf"),
                            (np.float64, b"lapack_dpotrf"),
                            (np.complex64, b"lapack_cpotrf"),
                            (np.complex128, b"lapack_zpotrf")):
    if dtype == candidate:
      fn = target
      break
  if fn is None:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  batch_shape = tuple(shape[:-2])
  nb = len(batch_shape)
  batch = 1
  for extent in batch_shape:
    batch *= extent
  batch_minor_to_major = tuple(range(nb - 1, -1, -1))
  scalar_layout = []
  # Matrix layout: trailing two dims are minor-most, batch dims follow.
  layout = (nb, nb + 1) + batch_minor_to_major
  info_layout = batch_minor_to_major
  out = custom_call(
      fn,
      [
          a.type,
          ir.RankedTensorType.get(batch_shape,
                                  ir.IntegerType.get_signless(32)),
      ],
      [_mhlo_s32(int(lower)), _mhlo_s32(batch), _mhlo_s32(n), a],
      operand_layouts=[scalar_layout] * 3 + [layout],
      result_layouts=[layout, info_layout])
  # Only (factor, info) are returned to the caller.
  return out[:2]
def _geqrf_mhlo(platform, gpu_solver, dtype, a):
  """QR decomposition lowered to the GPU solver's ``geqrf`` custom call."""
  aval = ir.RankedTensorType(a.type)
  shape = aval.shape
  assert len(shape) >= 2
  m, n = shape[-2:]
  batch_shape = tuple(shape[:-2])
  nb = len(batch_shape)
  batch = _prod(batch_shape)
  # The opaque descriptor carries the problem sizes; lwork is the workspace
  # length the kernel requires.
  lwork, opaque = gpu_solver.build_geqrf_descriptor(np.dtype(dtype), batch,
                                                    m, n)
  matrix_layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  result_types = [
      a.type,
      ir.RankedTensorType.get(batch_shape + (min(m, n),), aval.element_type),
      ir.RankedTensorType.get(batch_shape, i32_type),
      ir.RankedTensorType.get([lwork], aval.element_type),
  ]
  result_layouts = [
      matrix_layout,
      tuple(range(nb, -1, -1)),
      tuple(range(nb - 1, -1, -1)),
      [0],
  ]
  out = custom_call(
      f"{platform}solver_geqrf",
      result_types,
      [a],
      backend_config=opaque,
      operand_layouts=[matrix_layout],
      result_layouts=result_layouts)
  # Drop the trailing workspace buffer: (qr, tau, info).
  return out[:3]
def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
  """Cholesky decomposition lowered to the GPU solver's ``potrf`` call."""
  aval = ir.RankedTensorType(a.type)
  shape = aval.shape
  m, n = shape[-2:]
  assert m == n
  batch_shape = tuple(shape[:-2])
  nb = len(batch_shape)
  batch = _prod(batch_shape)
  lwork, opaque = gpu_solver.build_potrf_descriptor(
      np.dtype(dtype), lower, batch, n)
  batch_minor_to_major = tuple(range(nb - 1, -1, -1))
  matrix_layout = (nb, nb + 1) + batch_minor_to_major
  i32 = ir.IntegerType.get_signless(32)
  i8 = ir.IntegerType.get_signless(8)
  out = custom_call(
      f"{platform}solver_potrf",
      [
          a.type,
          ir.RankedTensorType.get(batch_shape, i32),
          # Byte workspace buffer sized by the descriptor.
          ir.RankedTensorType.get([lwork], i8),
      ],
      [a],
      backend_config=opaque,
      operand_layouts=[matrix_layout],
      result_layouts=[matrix_layout, batch_minor_to_major, [0]])
  # Discard the workspace result: (factor, info).
  return out[:2]
def csr_matmat_mhlo(data, indices, indptr, B, *, shape, transpose=False,
                    compute_dtype=None, compute_type=None, index_dtype,
                    data_dtype, B_dtype):
  """Sparse (CSR) matrix times dense matrix product via hipSPARSE.

  Computes ``A @ B`` (or ``A.T @ B`` when ``transpose``) where ``A`` is given
  in CSR form by (data, indices, indptr) with dense ``shape = (rows, cols)``.
  Returns the dense product; the kernel's scratch buffer is discarded.
  """
  data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr,
                                                  shape)
  rows, cols = shape
  B_shape = ir.RankedTensorType(B.type).shape
  _, Ccols = B_shape
  # Default the accumulation dtype/type to that of the CSR data.
  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type
  buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
      data_dtype, B_dtype, compute_dtype, index_dtype, rows, cols, Ccols,
      nnz, transpose)
  # The output has A's (possibly transposed) row count.
  out_size = cols if transpose else rows
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get([out_size, Ccols], compute_type),
              # Opaque byte workspace for the hipSPARSE kernel.
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [data, indices, indptr, B],
      call_target_name=ir.StringAttr.get("hipsparse_csr_matmat"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
      ]))
  # Extract only the dense product (tuple element 0).
  return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 0)).result
def syevd_mhlo(dtype, a, lower=False):
  """Symmetric (Hermitian) eigendecomposition via cuSOLVER syevj/syevd.

  Args:
    dtype: element dtype of ``a``.
    a: a (possibly batched) square matrix.
    lower: whether the lower triangle of ``a`` is referenced.

  Returns:
    ``[eigenvectors, eigenvalues, info]``; the workspace output is dropped.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  # NOTE: a dead assignment of `layout` (a plain tuple, shadowed below before
  # any use) was removed here.
  if n <= 32:
    # Small matrices use the Jacobi solver.
    kernel = b"cusolver_syevj"
    lwork, opaque = _cusolver.build_syevj_descriptor(
        np.dtype(dtype), lower, batch, n)
  else:
    kernel = b"cusolver_syevd"
    lwork, opaque = _cusolver.build_syevd_descriptor(
        np.dtype(dtype), lower, batch, n)
  # Eigenvalues of a Hermitian matrix are real, so complex inputs yield
  # real-typed eigenvalues.
  if ir.ComplexType.isinstance(a_type.element_type):
    eigvals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    eigvals_type = a_type.element_type
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
              ir.RankedTensorType.get([lwork], a_type.element_type),
          ])
      ],
      [a],
      call_target_name=ir.StringAttr.get(kernel),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
      ]))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]
def trsm_mhlo(dtype, a, b, left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False):
  """Batched triangular solve via hipBLAS.

  XLA implements unbatched triangular solve directly, so we need only
  implement the batched case. Solves ``a @ x = b`` (or the side/transpose
  variants selected by the flags) and returns ``x``.
  """
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  # The triangular operand is (k, k) with k matching b's contracted dim.
  k = m if left_side else n
  a_type = ir.RankedTensorType(a.type)
  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_type, b_type))
  if conj_a and not trans_a:
    raise NotImplementedError(
        "Conjugation without transposition not supported")
  lwork, opaque = _hipblas.build_trsm_batched_descriptor(
      np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  # Byte workspace; two identical buffers are requested — presumably one per
  # batched pointer array (a and b). TODO confirm against the kernel.
  work_type = ir.RankedTensorType.get([lwork],
                                      ir.IntegerType.get_signless(8))
  work_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  tup = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([b_type, work_type, work_type])],
      [a, b],
      call_target_name=ir.StringAttr.get("hipblas_trsm_batched"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout] * 2),
      result_layouts=ir.ArrayAttr.get([layout, work_layout,
                                       work_layout])).result
  # Return only the solution (tuple element 0).
  return mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False,
              trans_a=False, conj_a=False, diag=False):
  """Lowers a (batched) triangular solve to the CPU BLAS ``?trsm`` call."""
  aval = ir.RankedTensorType(a.type)
  bval = ir.RankedTensorType(b.type)
  shape = bval.shape
  m, n = shape[-2:]
  k = m if left_side else n
  batch_shape = tuple(shape[:-2])
  nb = len(batch_shape)
  batch = 1
  for extent in batch_shape:
    batch *= extent
  # The triangular operand must be (k, k) per batch and match b's dtype.
  if (batch_shape + (k, k) != tuple(aval.shape) or
      aval.element_type != bval.element_type):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        aval, bval))
  fn = None
  for candidate, target in ((np.float32, "blas_strsm"),
                            (np.float64, "blas_dtrsm"),
                            (np.complex64, "blas_ctrsm"),
                            (np.complex128, "blas_ztrsm")):
    if dtype == candidate:
      fn = target
      break
  if fn is None:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  if conj_a and not trans_a:
    raise NotImplementedError(
        "Conjugation without transposition not supported")
  # Encode the transpose mode: 0 = none, 1 = transpose, 2 = conj-transpose.
  if trans_a:
    trans_code = 2 if conj_a else 1
  else:
    trans_code = 0
  scalar_layout = []
  layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
  operands = [
      _mhlo_s32(int(left_side)),
      _mhlo_s32(int(lower)),
      _mhlo_s32(trans_code),
      _mhlo_s32(int(diag)),
      _mhlo_s32(m),
      _mhlo_s32(n),
      _mhlo_s32(batch),
      alpha,
      a,
      b,
  ]
  return custom_call(
      fn,
      [b.type],
      operands,
      operand_layouts=[scalar_layout] * 8 + [layout] * 2,
      result_layouts=[layout])
def getrf_mhlo(dtype, a):
  """LU decomposition with partial pivoting via hipBLAS/hipSOLVER.

  Returns ``[lu, pivots, info]``; the workspace output is discarded.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  # Heuristic: batches of small square matrices go through the batched
  # BLAS kernel; everything else uses the solver kernel.
  if batch > 1 and m == n and m // batch <= 128:
    lwork, opaque = _hipblas.build_getrf_batched_descriptor(
        np.dtype(dtype), batch, m)
    workspace = ir.RankedTensorType.get([lwork],
                                        ir.IntegerType.get_signless(8))
    kernel = "hipblas_getrf_batched"
  else:
    lwork, opaque = _hipsolver.build_getrf_descriptor(
        np.dtype(dtype), batch, m, n)
    workspace = ir.RankedTensorType.get([lwork], a_type.element_type)
    kernel = "hipsolver_getrf"
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              # Pivot indices, one per row of the smaller dimension.
              ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
              workspace,
          ])
      ],
      [a],
      call_target_name=ir.StringAttr.get(kernel),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
      ]))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]
def getrf_mhlo(dtype, a):
  """LU decomposition with partial pivoting via the CPU LAPACK ``?getrf``.

  Returns ``[lu, pivots, info]``.
  """
  dims = ir.RankedTensorType(a.type).shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # b is the flattened batch count.
  b = 1
  for d in batch_dims:
    b *= d
  if dtype == np.float32:
    fn = b"lapack_sgetrf"
  elif dtype == np.float64:
    fn = b"lapack_dgetrf"
  elif dtype == np.complex64:
    fn = b"lapack_cgetrf"
  elif dtype == np.complex128:
    fn = b"lapack_zgetrf"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  # Rank-0 layout for the leading scalar (size) operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims + (min(m, n), ), i32_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
      ]))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]
def _threefry2x32_lowering(prng, platform, keys, data):
  """ThreeFry2x32 kernel for GPU.

  Lowers a counter-based PRNG invocation to the platform's custom call.
  ``keys`` and ``data`` are each pairs of u32 tensors of identical type.
  """
  assert len(keys) == 2, keys
  assert len(data) == 2, data
  key_type = keys[0].type
  assert (ir.RankedTensorType(key_type).element_type ==
          ir.IntegerType.get_unsigned(32)), key_type
  shape = ir.RankedTensorType(key_type).shape
  # All four operands must share the same tensor type.
  for operand in (*keys, *data):
    assert operand.type == key_type, (operand.type, key_type)
  rank = len(shape)
  opaque = prng.threefry2x32_descriptor(_prod(shape))
  minor_to_major = tuple(range(rank - 1, -1, -1))
  return custom_call(
      f"{platform}_threefry2x32",
      [key_type, key_type],
      [keys[0], keys[1], data[0], data[1]],
      backend_config=opaque,
      operand_layouts=[minor_to_major] * 4,
      result_layouts=[minor_to_major] * 2)
def potrf_mhlo(dtype, a, lower=False):
  """Cholesky decomposition via the CPU LAPACK ``?potrf`` custom call.

  Returns ``[factor, info]`` where ``info`` is the per-batch i32 status.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  m, n = dims[-2:]
  if m != n:
    raise ValueError(
        "potrf expects a square matrix, got {}".format(a_type))
  if dtype == np.float32:
    fn = b"lapack_spotrf"
  elif dtype == np.float64:
    fn = b"lapack_dpotrf"
  elif dtype == np.complex64:
    fn = b"lapack_cpotrf"
  elif dtype == np.complex128:
    fn = b"lapack_zpotrf"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # b is the flattened batch count.
  b = 1
  for d in batch_dims:
    b *= d
  # Rank-0 layout for the leading scalar operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  info_layout = ir.DenseIntElementsAttr.get(np.array(
      range(num_bd - 1, -1, -1)), type=ir.IndexType.get())
  tup = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [_mhlo_s32(int(lower)), _mhlo_s32(b), _mhlo_s32(n), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
      result_layouts=ir.ArrayAttr.get([layout, info_layout])).result
  return [
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result,
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 1)).result,
  ]
def geqrf_mhlo(dtype, a):
  """QR decomposition via the CPU LAPACK ``?geqrf`` custom call.

  Returns ``[qr, tau, info]``; the workspace output is discarded.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # b is the flattened batch count.
  b = 1
  for d in batch_dims:
    b *= d
  # Each dtype has both a kernel name and a workspace-size query.
  if dtype == np.float32:
    fn = b"lapack_sgeqrf"
    lwork = _lapack.lapack_sgeqrf_workspace(m, n)
  elif dtype == np.float64:
    fn = b"lapack_dgeqrf"
    lwork = _lapack.lapack_dgeqrf_workspace(m, n)
  elif dtype == np.complex64:
    fn = b"lapack_cgeqrf"
    lwork = _lapack.lapack_cgeqrf_workspace(m, n)
  elif dtype == np.complex128:
    fn = b"lapack_zgeqrf"
    lwork = _lapack.lapack_zgeqrf_workspace(m, n)
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  scalar_layout = []
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  out = custom_call(
      fn,
      [
          a.type,
          # tau: one elementary-reflector coefficient per min(m, n).
          ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                  a_type.element_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
          ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      [_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a],
      operand_layouts=[scalar_layout] * 4 + [layout],
      result_layouts=[
          layout,
          tuple(range(num_bd, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
          [0],
      ])
  return out[:3]
def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
                lower=False):
  """Symmetric (Hermitian) eigendecomposition on GPU (syevj/syevd)."""
  aval = ir.RankedTensorType(a.type)
  shape = aval.shape
  assert len(shape) >= 2
  m, n = shape[-2:]
  assert m == n
  batch_shape = tuple(shape[:-2])
  nb = len(batch_shape)
  batch = _prod(batch_shape)
  matrix_layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
  # The Jacobi kernel is used for small matrices when available.
  use_jacobi = have_jacobi_solver and n <= 32
  if use_jacobi:
    kernel = f"{platform}solver_syevj"
    lwork, opaque = gpu_solver.build_syevj_descriptor(
        np.dtype(dtype), lower, batch, n)
  else:
    kernel = f"{platform}solver_syevd"
    lwork, opaque = gpu_solver.build_syevd_descriptor(
        np.dtype(dtype), lower, batch, n)
  # Hermitian eigenvalues are real, so strip the complex wrapper if present.
  element_type = aval.element_type
  if ir.ComplexType.isinstance(element_type):
    eigvals_type = ir.ComplexType(element_type).element_type
  else:
    eigvals_type = element_type
  i32_type = ir.IntegerType.get_signless(32)
  result_types = [
      a.type,
      ir.RankedTensorType.get(batch_shape + (n,), eigvals_type),
      ir.RankedTensorType.get(batch_shape, i32_type),
      ir.RankedTensorType.get([lwork], element_type),
  ]
  out = custom_call(
      kernel,
      result_types,
      [a],
      backend_config=opaque,
      operand_layouts=[matrix_layout],
      result_layouts=[
          matrix_layout,
          tuple(range(nb, -1, -1)),
          tuple(range(nb - 1, -1, -1)),
          [0],
      ])
  # (eigenvectors, eigenvalues, info); the workspace buffer is dropped.
  return out[:3]
def potrf_mhlo(dtype, a, lower):
  """Cholesky decomposition via the hipSOLVER ``potrf`` custom call.

  Returns ``[factor, info]``; the workspace output is discarded.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  lwork, opaque = _hipsolver.build_potrf_descriptor(np.dtype(dtype), lower,
                                                    batch, n)
  kernel = b"hipsolver_potrf"
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  info_layout = ir.DenseIntElementsAttr.get(np.array(
      range(num_bd - 1, -1, -1)), type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  info_type = ir.RankedTensorType.get(batch_dims, i32_type)
  work_layout = ir.DenseIntElementsAttr.get(np.array([0]),
                                            type=ir.IndexType.get())
  tup = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              info_type,
              # Opaque byte workspace sized by the descriptor.
              ir.RankedTensorType.get([lwork],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [a],
      call_target_name=ir.StringAttr.get(kernel),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout]),
      result_layouts=ir.ArrayAttr.get(
          [layout, info_layout, work_layout])).result
  return [
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 0)).result,
      mhlo.GetTupleElementOp(tup, ir.IntegerAttr.get(i32_type, 1)).result,
  ]
def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
                        data_dtype, index_type):
  """Converts a dense matrix to CSR form via the GPU sparse library."""
  mat_val = ir.RankedTensorType(mat.type)
  rows, cols = mat_val.shape
  buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)
  result_types = [
      ir.RankedTensorType.get([nnz], mat_val.element_type),
      ir.RankedTensorType.get([nnz], index_type),
      # indptr has one entry per row plus the terminating offset.
      ir.RankedTensorType.get([rows + 1], index_type),
      # Opaque byte workspace for the kernel.
      ir.RankedTensorType.get([buffer_size],
                              ir.IntegerType.get_signless(8)),
  ]
  out = custom_call(
      f"{platform}sparse_csr_fromdense",
      result_types,
      [mat],
      backend_config=opaque,
      operand_layouts=[[1, 0]],
      result_layouts=[[0]] * 4)
  # (data, indices, indptr); workspace dropped.
  return out[:3]
def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
  """LU decomposition with partial pivoting on GPU (``getrf``)."""
  aval = ir.RankedTensorType(a.type)
  shape = aval.shape
  assert len(shape) >= 2
  m, n = shape[-2:]
  batch_shape = tuple(shape[:-2])
  nb = len(batch_shape)
  batch = _prod(batch_shape)
  # Heuristic: batched small square problems use the batched BLAS kernel;
  # everything else goes through the solver kernel.
  use_batched = batch > 1 and m == n and m // batch <= 128
  if use_batched:
    lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
        np.dtype(dtype), batch, m)
    workspace = ir.RankedTensorType.get([lwork],
                                        ir.IntegerType.get_signless(8))
    kernel = f"{platform}blas_getrf_batched"
  else:
    lwork, opaque = gpu_solver.build_getrf_descriptor(
        np.dtype(dtype), batch, m, n)
    workspace = ir.RankedTensorType.get([lwork], aval.element_type)
    kernel = f"{platform}solver_getrf"
  matrix_layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  out = custom_call(
      kernel,
      [
          a.type,
          # Pivot indices, one per row of the smaller dimension.
          ir.RankedTensorType.get(batch_shape + (min(m, n),), i32_type),
          ir.RankedTensorType.get(batch_shape, i32_type),
          workspace,
      ],
      [a],
      backend_config=opaque,
      operand_layouts=[matrix_layout],
      result_layouts=[
          matrix_layout,
          tuple(range(nb, -1, -1)),
          tuple(range(nb - 1, -1, -1)),
          [0],
      ])
  # (lu, pivots, info); workspace dropped.
  return out[:3]
def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *,
                     shape, transpose=False, compute_dtype=None,
                     compute_type=None, index_dtype, data_dtype, B_dtype):
  """Sparse (CSR) matrix times dense matrix product.

  Computes ``A @ B`` (or ``A.T @ B`` when ``transpose``) where ``A`` is given
  in CSR form by (data, indices, indptr) with dense ``shape = (rows, cols)``.
  Returns the dense product; the workspace output is discarded.
  """
  data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr,
                                                  shape)
  rows, cols = shape
  B_shape = ir.RankedTensorType(B.type).shape
  _, Ccols = B_shape
  # Default the accumulation dtype/type to that of the CSR data.
  if compute_dtype is None:
    compute_dtype = data_dtype
    compute_type = data_type
  buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor(
      data_dtype, B_dtype, compute_dtype, index_dtype, rows, cols, Ccols,
      nnz, transpose)
  # The output has A's (possibly transposed) row count.
  out_size = cols if transpose else rows
  out = custom_call(
      f"{platform}sparse_csr_matmat",
      [
          ir.RankedTensorType.get([out_size, Ccols], compute_type),
          # Opaque byte workspace for the kernel.
          ir.RankedTensorType.get([buffer_size],
                                  ir.IntegerType.get_signless(8)),
      ],
      [data, indices, indptr, B],
      backend_config=opaque,
      operand_layouts=[[0], [0], [0], [1, 0]],
      result_layouts=[[1, 0], [0]])
  return out[0]
def coo_fromdense_mhlo(mat, *, nnz, data_dtype, index_dtype, index_type):
  """COO from dense matrix via hipSPARSE.

  Returns ``[data, row_indices, col_indices]``; the workspace output is
  discarded.
  """
  mat_type = ir.RankedTensorType(mat.type)
  rows, cols = mat_type.shape
  buffer_size, opaque = _hipsparse.build_coo_fromdense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              ir.RankedTensorType.get([nnz], mat_type.element_type),
              # Row and column index vectors, one entry per nonzero.
              ir.RankedTensorType.get([nnz], index_type),
              ir.RankedTensorType.get([nnz], index_type),
              # Opaque byte workspace for the kernel.
              ir.RankedTensorType.get([buffer_size],
                                      ir.IntegerType.get_signless(8)),
          ])
      ],
      [mat],
      call_target_name=ir.StringAttr.get("hipsparse_coo_fromdense"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([1, 0]),
                                      type=ir.IndexType.get()),
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(np.array([0]),
                                      type=ir.IndexType.get()),
      ] * 4))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]
def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *,
                                   permutation_size):
  """Kernel for the transformation of pivots to permutations on GPU."""
  pivots_type = ir.RankedTensorType(pivots.type)
  shape = pivots_type.shape
  i32_type = ir.IntegerType.get_signless(32)
  assert pivots_type.element_type == i32_type, pivots_type
  batch_size = _prod(shape[:-1])
  pivot_size = shape[-1]
  opaque = gpu_linalg.lu_pivots_to_permutation_descriptor(
      batch_size, pivot_size, permutation_size)
  # Pivots and permutations share a row-major layout.
  row_major = tuple(range(len(shape) - 1, -1, -1))
  # The output keeps the batch dims but widens the last dim to the full
  # permutation length.
  out_shape = list(shape)
  out_shape[-1] = permutation_size
  return custom_call(
      f"{platform}_lu_pivots_to_permutation",
      [ir.RankedTensorType.get(out_shape, i32_type)],
      [pivots],
      backend_config=opaque,
      operand_layouts=[row_major],
      result_layouts=[row_major])
def syevd_mhlo(dtype, a, lower=False):
  """Symmetric (Hermitian) eigendecomposition via LAPACK syevd/heevd.

  Args:
    dtype: element dtype; float32/float64 use ``?syevd``, complex dtypes use
      ``?heevd``.
    a: a (possibly batched) square matrix.
    lower: whether the lower triangle of ``a`` is referenced.

  Returns:
    ``[eigenvectors, eigenvalues, info]``; workspace outputs are dropped.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # b is the flattened batch count.
  b = 1
  for d in batch_dims:
    b *= d
  i32_type = ir.IntegerType.get_signless(32)
  # NOTE: a dead assignment of `layout` (a plain tuple, shadowed below before
  # any use) was removed here.
  # Each dtype selects the kernel, the (real) eigenvalue element type, and
  # the workspace buffers the LAPACK routine requires.
  if dtype == np.float32:
    fn = b"lapack_ssyevd"
    eigvals_type = ir.F32Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
    ]
  elif dtype == np.float64:
    fn = b"lapack_dsyevd"
    eigvals_type = ir.F64Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
    ]
  elif dtype == np.complex64:
    fn = b"lapack_cheevd"
    eigvals_type = ir.F32Type.get()
    # Complex routines additionally need a real-valued rwork buffer.
    workspace = [
        ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)],
                                eigvals_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
    ]
  elif dtype == np.complex128:
    fn = b"lapack_zheevd"
    eigvals_type = ir.F64Type.get()
    workspace = [
        ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)],
                                eigvals_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]),
                                    type=ir.IndexType.get()),
    ]
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  # Rank-0 layout for the leading scalar operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,
              ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ] + workspace)
      ],
      [_mhlo_s32(1 if lower else 0), _mhlo_s32(b), _mhlo_s32(n), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 3 + [layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                      type=ir.IndexType.get()),
      ] + workspace_layouts))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition via the hipSOLVER ``gesvd`` custom call.

  Returns ``(s, u, vt, info)``. Because the kernel handles only m >= n,
  matrices with m < n are factored as the SVD of the transpose and the
  roles of u and vt are swapped when unpacking the results.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  # Singular values are real even for complex inputs.
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type
  scalar_layout = ir.DenseIntElementsAttr.get(np.array(
      tuple(range(num_bd - 1, -1, -1))), type=ir.IndexType.get())
  vector_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  matrix_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  if m < n:
    # Wide case: pass (n, m) sizes so the kernel sees a tall problem; the
    # result slots for u and vt are consequently swapped below.
    lwork, opaque = _hipsolver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (k, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (m, m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    # Note the swap: element 2 is vt, element 3 is u in this branch.
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out,
                                  ir.IntegerAttr.get(i32_type, 4)).result
  else:
    # Tall (or square) case: natural orientation.
    lwork, opaque = _hipsolver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ],
        [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out,
                                  ir.IntegerAttr.get(i32_type, 4)).result
  return s, u, vt, info
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
  """PocketFFT kernel for CPU.

  Lowers an FFT of ``a`` to the ``pocketfft`` custom call. The serialized
  descriptor (shapes, dtype, fft kind) is passed as a constant u8 operand.
  Empty inputs/outputs short-circuit to a broadcasted zero constant.
  """
  a_type = ir.RankedTensorType(a.type)
  n = len(a_type.shape)
  fft_lengths = list(fft_lengths)
  descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
      list(a_type.shape), dtype, fft_type, fft_lengths)
  if out_dtype == np.float32:
    out_type = ir.F32Type.get()
  elif out_dtype == np.float64:
    out_type = ir.F64Type.get()
  elif out_dtype == np.complex64:
    out_type = ir.ComplexType.get(ir.F32Type.get())
  elif out_dtype == np.complex128:
    out_type = ir.ComplexType.get(ir.F64Type.get())
  else:
    raise ValueError(f"Unknown output type {out_dtype}")
  # Zero-sized FFTs produce an all-zeros result without calling the kernel.
  if 0 in a_type.shape or 0 in out_shape:
    # Version gates: newer xla_client infers the constant's type from the
    # attribute; mlir_api_version 21 renamed ConstOp to ConstantOp.
    if xla_client._version >= 64:
      if jax._src.lib.mlir_api_version < 21:
        zero = mhlo.ConstOp(
            ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                     type=out_type))
      else:
        zero = mhlo.ConstantOp(
            ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                     type=out_type))
    else:
      if jax._src.lib.mlir_api_version < 21:
        zero = mhlo.ConstOp(
            ir.RankedTensorType.get([], out_type),
            ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                     type=out_type))
      else:
        zero = mhlo.ConstantOp(
            ir.RankedTensorType.get([], out_type),
            ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
                                     type=out_type))
    return mhlo.BroadcastOp(
        zero,
        ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
  u8_type = ir.IntegerType.get_unsigned(8)
  # Embed the opaque descriptor as a constant byte tensor operand, with the
  # same version gating as above.
  if xla_client._version >= 64:
    if jax._src.lib.mlir_api_version < 21:
      descriptor = mhlo.ConstOp(
          ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                 dtype=np.uint8),
                                   type=u8_type))
    else:
      descriptor = mhlo.ConstantOp(
          ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                 dtype=np.uint8),
                                   type=u8_type))
  else:
    if jax._src.lib.mlir_api_version < 21:
      descriptor = mhlo.ConstOp(
          ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
          ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                 dtype=np.uint8),
                                   type=u8_type))
    else:
      descriptor = mhlo.ConstantOp(
          ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
          ir.DenseElementsAttr.get(np.frombuffer(descriptor_bytes,
                                                 dtype=np.uint8),
                                   type=u8_type))
  layout = tuple(range(n - 1, -1, -1))
  return custom_call(
      "pocketfft",
      [ir.RankedTensorType.get(out_shape, out_type)],
      [descriptor, a],
      operand_layouts=[[0], layout],
      result_layouts=[layout])
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition.

  Lowers an SVD to a cuSOLVER custom call, choosing between the Jacobi
  kernel ("cusolver_gesvdj", used when both dimensions are < 32) and the
  QR-based kernel ("cusolver_gesvd").

  Args:
    dtype: NumPy element type of `a`.
    a: MHLO value with shape `batch_dims + (m, n)`.
    full_matrices: whether to emit full-sized `u` and `vt`.
    compute_uv: whether the kernel computes singular vectors.

  Returns:
    A tuple `(s, u, vt, info)` of MHLO values: singular values, left/right
    singular vectors (right ones transposed), and per-batch int32 status.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  # Singular values are always real, even for complex inputs.
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type
  # NOTE(review): despite its name, `scalar_layout` is the row-major layout of
  # the rank-num_bd `info` tensor.  For num_bd == 0 the `np.array(())` here is
  # float64-typed — presumably still accepted by DenseIntElementsAttr; verify.
  scalar_layout = ir.DenseIntElementsAttr.get(np.array(
      tuple(range(num_bd - 1, -1, -1))),
                                              type=ir.IndexType.get())
  # Layout for the rank-(num_bd + 1) singular-values tensor.
  vector_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  if m < 32 and n < 32:
    # Small matrices: use the Jacobi kernel (gesvdj).
    # The batched kernel doesn't support "econ" mode.
    econ = not full_matrices and b == 1
    lwork, opaque = _cusolver.build_gesvdj_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
    k = min(m, n)
    # Matrices use column-major layout in the minor two dims (num_bd before
    # num_bd + 1), row-major over the batch dims.
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvdj"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    # Tuple elements: 0 = input workspace, 1 = s, 2 = u, 3 = v (untransposed),
    # 4 = info, 5 = scratch.
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    v = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3))
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
    # gesvdj returns v; transpose (and conjugate for complex types) to get vt.
    vt = mhlo.TransposeOp(
        v,
        ir.DenseIntElementsAttr.get(
            np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
    if np.issubdtype(dtype, np.complexfloating):
      vt = mhlo.ComplexOp(mhlo.RealOp(vt), mhlo.NegOp(mhlo.ImagOp(vt))).result
    # When econ mode was unavailable (batched case) but the caller asked for
    # reduced matrices, slice the full results down.
    if not full_matrices and not econ:
      u = mhlo.SliceOp(
          u,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
      vt = mhlo.SliceOp(
          vt,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
  elif m < n:
    # gesvd path with m < n: the descriptor is built with the dimensions
    # swapped and the matrix layout transposed (num_bd + 1 before num_bd), so
    # the kernel effectively operates on the transpose; the u/vt tuple slots
    # are swapped accordingly below.
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (k, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (m, m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  else:
    # gesvd path with m >= n: natural orientation.
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  return s, u, vt, info
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition via LAPACK's ?gesdd.

  Emits a custom call targeting the `lapack_{s,d,c,z}gesdd` kernels.

  Args:
    dtype: NumPy element type of `a` (float32/float64/complex64/complex128).
    a: MHLO value with shape `batch_dims + (m, n)`.
    full_matrices: whether to produce full-sized `u` and `vt`.
    compute_uv: whether the kernel computes singular vectors.

  Returns:
    A list `[s, u, vt, info]` of MHLO values: singular values
    (`batch_dims + (min(m, n),)`), left/right singular vectors, and the
    per-batch int32 status tensor.

  Raises:
    NotImplementedError: for unsupported dtypes.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # Use the shared _prod helper for the flattened batch count, consistent with
  # gesvd_mhlo and the *_mhlo helpers earlier in this file.
  b = _prod(batch_dims)
  i32_type = ir.IntegerType.get_signless(32)
  # Per-dtype kernel name, singular-value element type, and workspace tensors
  # (integer work array, optional real work array for complex types, and the
  # main work buffer of `lwork` elements).
  if dtype == np.float32:
    fn = b"lapack_sgesdd"
    singular_vals_type = ir.F32Type.get()
    lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.float64:
    fn = b"lapack_dgesdd"
    singular_vals_type = ir.F64Type.get()
    lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.complex64:
    fn = b"lapack_cgesdd"
    singular_vals_type = ir.F32Type.get()
    lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        ir.RankedTensorType.get(
            [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
            ir.F32Type.get()),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  elif dtype == np.complex128:
    fn = b"lapack_zgesdd"
    singular_vals_type = ir.F64Type.get()
    lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = [
        ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32_type),
        # The rwork element count is shared between the c and z variants; only
        # the element type differs.
        ir.RankedTensorType.get(
            [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
            ir.F64Type.get()),
        ir.RankedTensorType.get([lwork], a_type.element_type),
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  # Rank-0 layout for the scalar configuration operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  # Column-major in the trailing two dims, row-major over batch dims.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  out = mhlo.CustomCallOp([
      ir.TupleType.get_tuple([
          a.type,
          ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                  singular_vals_type),
          ir.RankedTensorType.get(
              batch_dims + (m, m if full_matrices else min(m, n)),
              a_type.element_type),
          ir.RankedTensorType.get(
              batch_dims + (n if full_matrices else min(m, n), n),
              a_type.element_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
      ] + workspace)
  ], [
      _mhlo_s32(int(full_matrices)),
      _mhlo_s32(int(compute_uv)),
      _mhlo_s32(b),
      _mhlo_s32(m),
      _mhlo_s32(n),
      _mhlo_s32(lwork), a
  ],
                          call_target_name=ir.StringAttr.get(fn),
                          has_side_effect=ir.BoolAttr.get(False),
                          backend_config=ir.StringAttr.get(""),
                          api_version=ir.IntegerAttr.get(i32_type, 2),
                          called_computations=ir.ArrayAttr.get([]),
                          operand_layouts=ir.ArrayAttr.get([scalar_layout] * 6 +
                                                           [layout]),
                          result_layouts=ir.ArrayAttr.get([
                              layout,
                              ir.DenseIntElementsAttr.get(
                                  np.array((num_bd, ) +
                                           tuple(range(num_bd - 1, -1, -1))),
                                  type=ir.IndexType.get()),
                              layout,
                              layout,
                              ir.DenseIntElementsAttr.get(
                                  np.array(range(num_bd - 1, -1, -1)),
                                  type=ir.IndexType.get()),
                          ] + workspace_layouts))
  # Tuple element 0 is the input-aliased workspace; elements 1..4 are
  # s, u, vt, info.  The trailing workspace tensors are dropped.
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(1, 5)
  ]
def gees_mhlo(a, jobvs=True, sort=False, select=None):
  """Schur decomposition via LAPACK's ?gees.

  Emits a custom call targeting `lapack_{s,d,c,z}gees`, chosen from the MLIR
  element type of `a`.

  Args:
    a: MHLO value with shape `batch_dims + (n, n)`.
    jobvs: whether to compute the Schur vectors.
    sort: eigenvalue sorting; not implemented (must be falsy).
    select: unused; placeholder for the gees SELECT callback (see TODO below).

  Returns:
    A tuple `(T, vs, info)` of MHLO values: the Schur form (tuple element 0,
    which aliases the input operand), the Schur vectors, and the per-batch
    int32 status.  (A fourth `sdim` element would be returned if sorting were
    supported.)

  Raises:
    NotImplementedError: if `sort` is requested.
  """
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major in the trailing two dims, row-major over batch dims.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())

  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")

  # Encode the LAPACK character flags as u8 codes.  Note `sort` is always
  # ord('N') here because of the check above.
  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')

  if not ir.ComplexType.isinstance(etype):
    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
    schurvecs_type = etype
    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
    workspace_layouts = [layout]
    # Two real eigenvalue tensors — presumably the real and imaginary parts
    # (LAPACK's WR/WI); verify against the kernel implementation.
    eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)] * 2
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  else:
    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
          else "lapack_zgees")
    schurvecs_type = etype
    workspaces = [
        ir.RankedTensorType.get(dims, schurvecs_type),
        # Real-valued scratch vector of length n for the complex kernels.
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [
        layout,
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
    # Complex kernels return a single complex eigenvalue tensor.
    eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]
  i32_type = ir.IntegerType.get_signless(32)
  # Rank-0 layout for the scalar configuration operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  # Result tuple: workspaces + eigvals + [Schur vectors, sdim, info].
  # In both the real and complex cases the Schur vectors land at index 3 and
  # the final info tensor at index 5.
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple(workspaces + eigvals + [
              ir.RankedTensorType.get(dims, schurvecs_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [
          _mhlo_s32(b),
          _mhlo_s32(n),
          _mhlo_u8(np.uint8(jobvs)),
          _mhlo_u8(np.uint8(sort)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
      result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
      ]))
  i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
  # NOTE: the sorted branch is currently unreachable (sort is forced to
  # ord('N') above); it is kept for when sorting support lands.
  if sort == ord('S'):
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(4)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
  else:
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
  """General eigendecomposition via LAPACK's ?geev.

  Emits a custom call targeting `lapack_{s,d,c,z}geev`.

  Args:
    dtype: NumPy element type of `a`.
    a: MHLO value with shape `batch_dims + (n, n)`.
    jobvl: whether to compute left eigenvectors.
    jobvr: whether to compute right eigenvectors.

  Returns:
    A tuple `(w, vl, vr, info)` of MHLO values: complex eigenvalues (assembled
    from the separate real/imag outputs for real dtypes), left and right
    eigenvectors (always complex-typed), and the per-batch int32 status.

  Raises:
    NotImplementedError: for unsupported dtypes.
  """
  dims = ir.RankedTensorType(a.type).shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major in the trailing two dims, row-major over batch dims.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())

  # LAPACK character flags encoded as u8 codes.
  jobvl_c = ord('V' if jobvl else 'N')
  jobvr_c = ord('V' if jobvr else 'N')

  # Per-dtype kernel name, workspaces, and eigenvalue outputs.  Real dtypes
  # use three n x n scratch matrices and two real eigenvalue tensors
  # (real/imag parts); complex dtypes use one n x n scratch matrix, a length
  # 2n real buffer, and a single complex eigenvalue tensor.
  if dtype == np.float32:
    fn = b"lapack_sgeev"
    real = True
    eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n], ir.F32Type.get()),
        ir.RankedTensorType.get([n, n], ir.F32Type.get()),
        ir.RankedTensorType.get([n, n], ir.F32Type.get())
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]), type=ir.IndexType.get())
    ] * 3
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n, ), ir.F32Type.get()),
        ir.RankedTensorType.get(batch_dims + (n, ), ir.F32Type.get())
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  elif dtype == np.float64:
    fn = b"lapack_dgeev"
    real = True
    eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n], ir.F64Type.get()),
        ir.RankedTensorType.get([n, n], ir.F64Type.get()),
        ir.RankedTensorType.get([n, n], ir.F64Type.get())
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]), type=ir.IndexType.get())
    ] * 3
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n, ), ir.F64Type.get()),
        ir.RankedTensorType.get(batch_dims + (n, ), ir.F64Type.get())
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  elif dtype == np.complex64:
    fn = b"lapack_cgeev"
    real = False
    eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n], ir.ComplexType.get(ir.F32Type.get())),
        ir.RankedTensorType.get([2 * n], ir.F32Type.get())
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get())
    ]
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n, ),
                                ir.ComplexType.get(ir.F32Type.get()))
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]
  elif dtype == np.complex128:
    fn = b"lapack_zgeev"
    real = False
    eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
    workspaces = [
        ir.RankedTensorType.get([n, n], ir.ComplexType.get(ir.F64Type.get())),
        ir.RankedTensorType.get([2 * n], ir.F64Type.get())
    ]
    workspace_layouts = [
        ir.DenseIntElementsAttr.get(np.array([0, 1]),
                                    type=ir.IndexType.get()),
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get())
    ]
    eigvals = [
        ir.RankedTensorType.get(batch_dims + (n, ),
                                ir.ComplexType.get(ir.F64Type.get()))
    ]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  i32_type = ir.IntegerType.get_signless(32)
  # Rank-0 layout for the scalar configuration operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  info_layout = ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                            type=ir.IndexType.get())
  # Result tuple: workspaces + eigvals + [vl, vr, info].  Real dtypes have
  # 3 workspaces + 2 eigval tensors (indices 3,4 = wr,wi; 5,6 = vl,vr;
  # 7 = info); complex dtypes have 2 workspaces + 1 eigval tensor
  # (index 2 = w; 3,4 = vl,vr; 5 = info).
  # NOTE(review): `fn` is bytes and `jobvl_c`/`jobvr_c` are plain ints here,
  # unlike gees_mhlo which passes str/np.uint8 — presumably accepted by
  # ir.StringAttr.get/_mhlo_u8; verify against the bindings.
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple(workspaces + eigvals + [
              ir.RankedTensorType.get(dims, eigvecs_type),
              ir.RankedTensorType.get(dims, eigvecs_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [_mhlo_s32(b), _mhlo_s32(n), _mhlo_u8(jobvl_c), _mhlo_u8(jobvr_c), a],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
      result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts +
                                      [layout] * 2 + [info_layout])).result
  i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
  if real:
    # Assemble complex eigenvalues from the separate real/imag outputs.
    return (mhlo.ComplexOp(mhlo.GetTupleElementOp(out, i32_attr(3)),
                           mhlo.GetTupleElementOp(out, i32_attr(4))).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result,
            mhlo.GetTupleElementOp(out, i32_attr(6)).result,
            mhlo.GetTupleElementOp(out, i32_attr(7)).result)
  else:
    return (mhlo.GetTupleElementOp(out, i32_attr(2)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(4)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False):
  """Triangular solve via BLAS ?trsm, lowered to an MHLO custom call.

  Emits a call to `blas_{s,d,c,z}trsm` and returns the MHLO value holding the
  solution, with the same type as `b`.

  Args:
    dtype: NumPy element type of `a` and `b`.
    alpha: MHLO scalar scaling the right-hand side.
    a: MHLO value holding the triangular matrix, `batch_dims + (k, k)`.
    b: MHLO value holding the right-hand side, `batch_dims + (m, n)`.
    left_side: solve with `a` on the left (k == m) or right (k == n).
    lower: whether `a` is lower triangular.
    trans_a: transpose `a` before solving.
    conj_a: conjugate-transpose `a` (requires trans_a).
    diag: whether `a` has an implicit unit diagonal.

  Raises:
    ValueError: if the shapes or element types of `a` and `b` disagree.
    NotImplementedError: for unsupported dtypes or conjugation without
      transposition.
  """
  a_type = ir.RankedTensorType(a.type)
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  m, n = dims[-2:]
  # The triangular factor must be square, sized to the dimension of b that
  # it multiplies.
  k = m if left_side else n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_count = _prod(batch_dims)
  shapes_agree = batch_dims + (k, k) == tuple(a_type.shape)
  types_agree = a_type.element_type == b_type.element_type
  if not (shapes_agree and types_agree):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_type, b_type))
  # Select the BLAS kernel for the element type.
  for candidate, target in ((np.float32, "blas_strsm"),
                            (np.float64, "blas_dtrsm"),
                            (np.complex64, "blas_ctrsm"),
                            (np.complex128, "blas_ztrsm")):
    if dtype == candidate:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")
  # Transposition mode operand: 0 = none, 1 = transpose, 2 = conj-transpose.
  if not trans_a:
    trans_code = 0
  elif conj_a:
    trans_code = 2
  else:
    trans_code = 1
  index_type = ir.IndexType.get()
  # Rank-0 layout for the scalar configuration operands.
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=index_type)
  # Column-major in the trailing two dims, row-major over batch dims.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=index_type)
  operands = [
      _mhlo_s32(int(left_side)),
      _mhlo_s32(int(lower)),
      _mhlo_s32(trans_code),
      _mhlo_s32(int(diag)),
      _mhlo_s32(m),
      _mhlo_s32(n),
      _mhlo_s32(batch_count),
      alpha,
      a,
      b,
  ]
  call = mhlo.CustomCallOp(
      [b.type],
      operands,
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 8 + [layout] * 2),
      result_layouts=ir.ArrayAttr.get([layout]))
  return call.result