def syevd_mhlo(dtype, a, lower=False):
  """Symmetric (Hermitian) eigendecomposition.

  Lowers to a cuSOLVER custom call: the batched `syevj` kernel for small
  matrices (n <= 32) — presumably the Jacobi solver, matching the
  `have_jacobi_solver` flag used by the generic lowering — and `syevd`
  otherwise.

  Args:
    dtype: NumPy-compatible element dtype of ``a`` (used to build the solver
      descriptor).
    a: MHLO value holding a (batch of) n-by-n symmetric/Hermitian matrices.
    lower: if True, the solver reads the lower triangle of ``a``.

  Returns:
    A list ``[eigenvectors, eigenvalues, info]`` of MHLO values; the trailing
    workspace output of the custom call is dropped.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  # NOTE: a dead tuple-valued `layout` assignment that was immediately
  # shadowed by the DenseIntElementsAttr below has been removed.
  if n <= 32:
    kernel = b"cusolver_syevj"
    lwork, opaque = _cusolver.build_syevj_descriptor(
        np.dtype(dtype), lower, batch, n)
  else:
    kernel = b"cusolver_syevd"
    lwork, opaque = _cusolver.build_syevd_descriptor(
        np.dtype(dtype), lower, batch, n)
  # Eigenvalues of a Hermitian matrix are real, so complex inputs produce
  # eigenvalues of the underlying real element type.
  if ir.ComplexType.isinstance(a_type.element_type):
    eigvals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    eigvals_type = a_type.element_type
  # Column-major layout for the two trailing (matrix) dimensions, batch
  # dimensions major-to-minor after them.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              a.type,  # eigenvectors (overwrites the input buffer)
              ir.RankedTensorType.get(batch_dims + (n, ), eigvals_type),
              ir.RankedTensorType.get(batch_dims, i32_type),  # info
              ir.RankedTensorType.get([lwork], a_type.element_type),  # work
          ])
      ], [a],
      call_target_name=ir.StringAttr.get(kernel),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([layout]),
      result_layouts=ir.ArrayAttr.get([
          layout,
          ir.DenseIntElementsAttr.get(np.array(range(num_bd, -1, -1)),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array(range(num_bd - 1, -1, -1)),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
      ]))
  # Unpack the first three tuple elements; the workspace is discarded.
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(3)
  ]
def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
                lower=False):
  """Symmetric (Hermitian) eigendecomposition.

  Emits a ``{platform}solver_syevj`` custom call for small matrices when a
  Jacobi solver is available, and ``{platform}solver_syevd`` otherwise.
  Returns ``[eigenvectors, eigenvalues, info]``; the workspace output of the
  custom call is dropped.
  """
  operand_type = ir.RankedTensorType(a.type)
  shape = operand_type.shape
  assert len(shape) >= 2
  rows, cols = shape[-2:]
  assert rows == cols
  batch_shape = tuple(shape[:-2])
  batch_rank = len(batch_shape)
  num_matrices = _prod(batch_shape)
  # Column-major layout for the trailing matrix dimensions; batch dims
  # major-to-minor in front of them.
  matrix_layout = ((batch_rank, batch_rank + 1) +
                   tuple(range(batch_rank - 1, -1, -1)))
  if have_jacobi_solver and cols <= 32:
    kernel = f"{platform}solver_syevj"
    lwork, opaque = gpu_solver.build_syevj_descriptor(
        np.dtype(dtype), lower, num_matrices, cols)
  else:
    kernel = f"{platform}solver_syevd"
    lwork, opaque = gpu_solver.build_syevd_descriptor(
        np.dtype(dtype), lower, num_matrices, cols)
  element_type = operand_type.element_type
  # Hermitian eigenvalues are real, so complex inputs yield real eigenvalues.
  if ir.ComplexType.isinstance(element_type):
    eigvals_type = ir.ComplexType(element_type).element_type
  else:
    eigvals_type = element_type
  i32_type = ir.IntegerType.get_signless(32)
  result_types = [
      a.type,  # eigenvectors
      ir.RankedTensorType.get(batch_shape + (cols, ), eigvals_type),
      ir.RankedTensorType.get(batch_shape, i32_type),  # info
      ir.RankedTensorType.get([lwork], element_type),  # workspace
  ]
  outputs = custom_call(
      kernel,
      result_types,
      [a],
      backend_config=opaque,
      operand_layouts=[matrix_layout],
      result_layouts=[
          matrix_layout,
          tuple(range(batch_rank, -1, -1)),
          tuple(range(batch_rank - 1, -1, -1)),
          [0],
      ])
  # Discard the trailing workspace buffer.
  return outputs[:3]
def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
                full_matrices=True, compute_uv=True):
  """Singular value decomposition.

  Dispatches to one of three GPU kernels:
    * ``gesvdj`` (batched Jacobi-style solver) when both matrix dimensions
      are < 32 and a Jacobi solver is available;
    * ``gesvd`` with swapped dimensions and a transposed layout when m < n;
    * ``gesvd`` directly otherwise.

  Returns a tuple ``(s, u, vt, info)`` of MHLO values.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  # Singular values are real even for complex inputs.
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type
  scalar_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd, ) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  if have_jacobi_solver and m < 32 and n < 32:
    # The batched kernel doesn't support "econ" mode.
    econ = not full_matrices and b == 1
    lwork, opaque = gpu_solver.build_gesvdj_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
    k = min(m, n)
    # Column-major layout for the trailing matrix dimensions.
    matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    # Custom-call outputs: (input copy, s, u, v, info, workspace); the first
    # and last are discarded.
    _, s, u, v, info, _ = custom_call(
        f"{platform}solver_gesvdj",
        [
            a.type,
            ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                    singular_vals_type),
            ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
                                    a_type.element_type),
            ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
                                    a_type.element_type),
            ir.RankedTensorType.get(batch_dims, i32_type),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ],
        [a],
        backend_config=opaque,
        operand_layouts=[matrix_layout],
        result_layouts=[
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            [0],
        ])
    # The kernel returns V; transpose the last two dims (and conjugate for
    # complex dtypes) to obtain V^H as expected by callers.
    vt = mhlo.TransposeOp(
        v,
        ir.DenseIntElementsAttr.get(
            np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
    if np.issubdtype(dtype, np.complexfloating):
      vt = mhlo.ComplexOp(mhlo.RealOp(vt), mhlo.NegOp(mhlo.ImagOp(vt))).result
    # Emulate "econ" mode by slicing the full factors when the batched
    # kernel could not produce the reduced factors itself.
    if not full_matrices and not econ:
      u = mhlo.SliceOp(
          u,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(
              np.array(batch_dims + (m, min(m, n)))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
      vt = mhlo.SliceOp(
          vt,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(
              np.array(batch_dims + (min(m, n), n))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
  elif m < n:
    # m < n: build the descriptor with n and m swapped and use a transposed
    # (row-major) matrix layout — presumably because the underlying gesvd
    # kernel requires m >= n; this swaps the roles of U and V^H in the
    # outputs (note ``vt`` is unpacked before ``u`` below).
    lwork, opaque = gpu_solver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    matrix_layout = (num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))
    _, s, vt, u, info, _ = custom_call(
        f"{platform}solver_gesvd",
        [
            a.type,
            ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                    singular_vals_type),
            ir.RankedTensorType.get(batch_dims + (k, n), a_type.element_type),
            ir.RankedTensorType.get(batch_dims + (m, m), a_type.element_type),
            ir.RankedTensorType.get(batch_dims, i32_type),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ],
        [a],
        backend_config=opaque,
        operand_layouts=[matrix_layout],
        result_layouts=[
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            [0],
        ])
  else:
    # m >= n: call gesvd directly with the column-major layout.
    lwork, opaque = gpu_solver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    _, s, u, vt, info, _ = custom_call(
        f"{platform}solver_gesvd",
        [
            a.type,
            ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                    singular_vals_type),
            ir.RankedTensorType.get(batch_dims + (m, k), a_type.element_type),
            ir.RankedTensorType.get(batch_dims + (n, n), a_type.element_type),
            ir.RankedTensorType.get(batch_dims, i32_type),
            ir.RankedTensorType.get([lwork], a_type.element_type),
        ],
        [a],
        backend_config=opaque,
        operand_layouts=[matrix_layout],
        result_layouts=[
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            [0],
        ])
  return s, u, vt, info
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition.

  Lowers to a ``hipsolver_gesvd`` custom call. When m < n the descriptor is
  built with the dimensions swapped and the unpacking order of U and V^H is
  reversed.

  Returns a tuple ``(s, u, vt, info)`` of MHLO values.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  # Singular values are real even for complex inputs.
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type
  scalar_layout = ir.DenseIntElementsAttr.get(np.array(
      tuple(range(num_bd - 1, -1, -1))),
                                              type=ir.IndexType.get())
  vector_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  # Column-major layout for the trailing matrix dimensions.
  # NOTE(review): unlike the cusolver lowering, the m < n branch below reuses
  # this same layout rather than a transposed (num_bd + 1, num_bd) one —
  # verify this matches what the hipsolver kernel expects.
  matrix_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  if m < n:
    # m < n: swap m and n in the descriptor; tuple element 2 is then V^H and
    # element 3 is U (see the unpacking below).
    lwork, opaque = _hipsolver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (k, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (m, m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  else:
    # m >= n: direct call; tuple element 2 is U and element 3 is V^H.
    lwork, opaque = _hipsolver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("hipsolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  return s, u, vt, info
def gesvd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition.

  Lowers to a cuSOLVER custom call: ``cusolver_gesvdj`` (batched kernel) when
  both matrix dimensions are < 32, otherwise ``cusolver_gesvd`` — with
  dimensions swapped and a transposed layout when m < n.

  Returns a tuple ``(s, u, vt, info)`` of MHLO values.
  """
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = _prod(batch_dims)
  # Singular values are real even for complex inputs.
  if ir.ComplexType.isinstance(a_type.element_type):
    singular_vals_type = ir.ComplexType(a_type.element_type).element_type
  else:
    singular_vals_type = a_type.element_type
  scalar_layout = ir.DenseIntElementsAttr.get(np.array(
      tuple(range(num_bd - 1, -1, -1))),
                                              type=ir.IndexType.get())
  vector_layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, ) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  i32_type = ir.IntegerType.get_signless(32)
  if m < 32 and n < 32:
    # The batched kernel doesn't support "econ" mode.
    econ = not full_matrices and b == 1
    lwork, opaque = _cusolver.build_gesvdj_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
    k = min(m, n)
    # Column-major layout for the trailing matrix dimensions.
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvdj"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    # NOTE(review): unlike its siblings this omits `.result`; the TransposeOp
    # builder below accepts the op — presumably coerced to its single result.
    # Verify against the MLIR Python bindings.
    v = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3))
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
    # The kernel returns V; transpose the last two dims (and conjugate for
    # complex dtypes) to obtain V^H as expected by callers.
    vt = mhlo.TransposeOp(
        v,
        ir.DenseIntElementsAttr.get(
            np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
    if np.issubdtype(dtype, np.complexfloating):
      vt = mhlo.ComplexOp(mhlo.RealOp(vt), mhlo.NegOp(mhlo.ImagOp(vt))).result
    # Emulate "econ" mode by slicing when the batched kernel couldn't.
    if not full_matrices and not econ:
      u = mhlo.SliceOp(
          u,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(
              np.array(batch_dims + (m, min(m, n)))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
      vt = mhlo.SliceOp(
          vt,
          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
          ir.DenseIntElementsAttr.get(
              np.array(batch_dims + (min(m, n), n))),
          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
  elif m < n:
    # m < n: swap m and n in the descriptor and use a transposed (row-major)
    # matrix layout — presumably because gesvd requires m >= n; this swaps
    # the roles of U and V^H (tuple element 2 is V^H, element 3 is U).
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    k = n if full_matrices else m
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (k, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (m, m),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  else:
    # m >= n: direct call; tuple element 2 is U and element 3 is V^H.
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)
    k = m if full_matrices else n
    matrix_layout = ir.DenseIntElementsAttr.get(
        np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
        type=ir.IndexType.get())
    out = mhlo.CustomCallOp(
        [
            ir.TupleType.get_tuple([
                a.type,
                ir.RankedTensorType.get(batch_dims + (min(m, n), ),
                                        singular_vals_type),
                ir.RankedTensorType.get(batch_dims + (m, k),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims + (n, n),
                                        a_type.element_type),
                ir.RankedTensorType.get(batch_dims, i32_type),
                ir.RankedTensorType.get([lwork], a_type.element_type),
            ])
        ], [a],
        call_target_name=ir.StringAttr.get("cusolver_gesvd"),
        has_side_effect=ir.BoolAttr.get(False),
        backend_config=ir.StringAttr.get(opaque),
        api_version=ir.IntegerAttr.get(i32_type, 2),
        called_computations=ir.ArrayAttr.get([]),
        operand_layouts=ir.ArrayAttr.get([matrix_layout]),
        result_layouts=ir.ArrayAttr.get([
            matrix_layout,
            vector_layout,
            matrix_layout,
            matrix_layout,
            scalar_layout,
            ir.DenseIntElementsAttr.get(np.array([0]),
                                        type=ir.IndexType.get()),
        ]))
    s = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 1)).result
    u = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 2)).result
    vt = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 3)).result
    info = mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, 4)).result
  return s, u, vt, info
def gees_mhlo(a, jobvs=True, sort=False, select=None):
  """Schur decomposition via a LAPACK ``gees`` custom call.

  The element dtype is inferred from the MLIR element type of ``a``. Sorting
  of eigenvalues is not implemented and ``select`` is currently unused.

  Returns ``(T, vs, info)`` MHLO values (Schur form, Schur vectors, and the
  LAPACK info code); with sorting it would also return ``sdim``, but that
  path is unreachable because ``sort=True`` raises above.
  """
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major layout for the trailing matrix dimensions.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())
  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")
  # Encode the LAPACK character flags as integer codes.
  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')
  if not ir.ComplexType.isinstance(etype):
    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
    schurvecs_type = etype
    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
    workspace_layouts = [layout]
    # Real gees returns separate real and imaginary eigenvalue vectors.
    eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)] * 2
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  else:
    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
          else "lapack_zgees")
    schurvecs_type = etype
    # Complex gees needs an extra real-valued workspace of length n.
    workspaces = [
        ir.RankedTensorType.get(dims, schurvecs_type),
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [
        layout,
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n, ), etype)]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]
  i32_type = ir.IntegerType.get_signless(32)
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0, ), np.int64),
                                              type=ir.IndexType.get())
  # Result tuple: workspaces + eigvals + [schur vectors, sdim, info].
  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple(workspaces + eigvals + [
              ir.RankedTensorType.get(dims, schurvecs_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
              ir.RankedTensorType.get(batch_dims, i32_type),
          ])
      ],
      [
          _mhlo_s32(b),
          _mhlo_s32(n),
          _mhlo_u8(np.uint8(jobvs)),
          _mhlo_u8(np.uint8(sort)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
      result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
      ]))
  i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
  # Element 0 is the Schur form T (the overwritten input); element 3 is the
  # Schur vectors in both branches (real: after wr/wi; complex: after rwork
  # and w); elements 4 and 5 are sdim and info.
  if sort == ord('S'):
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(4)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
  else:
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
  """Schur decomposition via a LAPACK ``gees`` custom call.

  Sorting of eigenvalues is not implemented and ``select`` is currently
  unused. Returns ``(T, vs, info)`` MHLO values; the sorted variant would
  additionally return ``sdim``, but that path is unreachable because
  ``sort=True`` raises above.
  """
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major layout for the trailing matrix dimensions.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")
  # Encode the LAPACK character flags as integer codes.
  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')
  # Pick the kernel by dtype, comparing with == to keep NumPy's loose
  # dtype/type equality semantics.
  for candidate, kernel_name in ((np.float32, "lapack_sgees"),
                                 (np.float64, "lapack_dgees"),
                                 (np.complex64, "lapack_cgees"),
                                 (np.complex128, "lapack_zgees")):
    if dtype == candidate:
      fn = kernel_name
      break
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  batch_layout = tuple(range(num_bd, -1, -1))
  if np.issubdtype(dtype, np.complexfloating):
    # Complex gees needs an extra real-valued workspace of length n and
    # returns a single complex eigenvalue vector.
    workspaces = [
        ir.RankedTensorType.get(dims, etype),
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [layout, [0]]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
    eigvals_layouts = [batch_layout]
  else:
    # Real gees returns separate real and imaginary eigenvalue vectors.
    workspaces = [ir.RankedTensorType.get(dims, etype)]
    workspace_layouts = [layout]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
    eigvals_layouts = [batch_layout] * 2
  i32_type = ir.IntegerType.get_signless(32)
  scalar_layout = []
  # Result tuple: workspaces + eigvals + [schur vectors, sdim, info].
  result_types = workspaces + eigvals + [
      ir.RankedTensorType.get(dims, etype),
      ir.RankedTensorType.get(batch_dims, i32_type),
      ir.RankedTensorType.get(batch_dims, i32_type),
  ]
  operands = [
      _mhlo_s32(b),
      _mhlo_s32(n),
      _mhlo_u8(np.uint8(jobvs)),
      _mhlo_u8(np.uint8(sort)),
      # TODO: figure out how to put the callable select function here
      a,
  ]
  out = custom_call(
      fn,
      result_types,
      operands,
      operand_layouts=[scalar_layout] * 4 + [layout],
      result_layouts=workspace_layouts + eigvals_layouts + [
          layout,
          tuple(range(num_bd - 1, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
      ])
  # out[0] is the Schur form T; out[3] is the Schur vectors in both
  # branches; out[4] and out[5] are sdim and info.
  if sort == ord('S'):
    return (out[0], out[3], out[4], out[5])
  return (out[0], out[3], out[5])