""" Implementation of the LLVM dialect. """ import inspect import sys from mlir.dialect import Dialect, DialectOp, DialectType, is_op, is_type class LLVMVec(DialectType): _syntax_ = [ ("llvm.vec < {size.integer_literal} x {type.type} >"), ("llvm.vec < ? x {size_factor.integer_literal} x {type.type} >"), ] class LLVMPtr(DialectType): _syntax_ = "llvm.ptr< {type.type} >" # Inspect current module to get all classes defined above llvm = Dialect( "llvm", # TODO Add some operations # ops=[m[1] for m in inspect.getmembers( # sys.modules[__name__], lambda obj: is_op(obj, __name__))], types=[ m[1] for m in inspect.getmembers(sys.modules[__name__], lambda obj: is_type(obj, __name__)) ])
_syntax_ = [ 'toy.densify {arg.ssa_id} : {type.tensor_type}', 'toy.densify {arg.ssa_id} , {pad.constant_literal} : {type.tensor_type}' ] ############################################################################## # Dialect my_dialect = Dialect( 'toy', ops=[DensifyOp], types=[RaggedTensorType], preamble=''' // Exclamation mark in Lark means that string tokens will be preserved upon parsing !toy_impl_type : "coo" | "csr" | "csc" | "ell" toy_impl_list : toy_impl_type ("+" toy_impl_type)* ''', transformers=dict( toy_impl_list=ToyImplementation, # Will convert every instance to its contents toy_impl_type=lambda v: v[0])) ############################################################################## # Tests def test_custom_dialect(): code = '''module { func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> { %t_tensor = toy.densify %ragged : tensor<32x14xf64>
tag_index: List[SsaUse] src_type: ast.MemRefType dst_type: ast.MemRefType tag_type: ast.MemRefType stride: Optional[SsaUse] = None transfer_per_stride: Optional[SsaUse] = None _syntax_ = [ 'affine.dma_start {src.ssa_use} [ {src_index.multi_dim_affine_expr_no_parens} ] , {dst.ssa_use} [ {dst_index.multi_dim_affine_expr_no_parens} ] , {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}', 'affine.dma_start {src.ssa_use} [ {src_index.multi_dim_affine_expr_no_parens} ] , {dst.ssa_use} [ {dst_index.multi_dim_affine_expr_no_parens} ] , {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}' ] @dataclass class AffineDmaWaitOperation(DialectOp): tag: SsaUse tag_index: ast.MultiDimAffineExpr size: SsaUse type: ast.MemRefType _syntax_ = 'affine.dma_wait {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} : {type.memref_type}' # Inspect current module to get all classes defined above affine = Dialect( 'affine', ops=[ m[1] for m in inspect.getmembers(sys.modules[__name__], lambda obj: is_op(obj, __name__)) ])
class LinalgYield(DialectOp):
    """``linalg.yield``: a list of SSA values followed by their types."""
    _syntax_ = ("linalg.yield {operand_ids.ssa_id_list} : "
                "{operand_types.type_list_no_parens}")


class LinalgMatmul(DialectOp):
    """``linalg.matmul`` taking SSA inputs ``a`` and ``b`` in an ``ins`` clause.

    Two forms are accepted. They share the ``ins(...)`` prefix and differ only
    in the trailing clause:
    an ``outs( c : type )`` clause, or an ``init( ... ) -> type`` clause
    that also names the result type.
    """
    _syntax_ = [
        "linalg.matmul ins( {a_id.ssa_id} , {b_id.ssa_id} : "
        "{a_type.type} , {b_type.type} ) " + result_clause
        for result_clause in (
            "outs( {c_id.ssa_id} : {c_type.type} )",
            "init( {init_id.ssa_id} : {init_type.type} ) -> {out_type.type}",
        )
    ]


class LinalgMatvec(DialectOp):
    """``linalg.matvec`` with ``ins`` operands ``a``, ``b`` and ``outs`` ``c``."""
    _syntax_ = [
        "linalg.matvec ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , "
        "{b_type.type} ) outs( {c_id.ssa_id} : {c_type.type} )",
    ]


# Register every module member accepted by ``is_op`` (presumably the
# DialectOp subclasses defined in this module) with the ``linalg`` dialect.
linalg = Dialect(
    "linalg",
    ops=[
        op_cls
        for _, op_cls in inspect.getmembers(
            sys.modules[__name__], lambda obj: is_op(obj, __name__)
        )
    ],
)