Пример #1
0
""" Implementation of the LLVM dialect. """

import inspect
import sys
from mlir.dialect import Dialect, DialectOp, DialectType, is_op, is_type


class LLVMVec(DialectType):
    """Vector type of the LLVM dialect.

    Two textual forms are declared:

    * ``llvm.vec < N x T >`` with an integer ``size`` and an element ``type``;
    * ``llvm.vec < ? x N x T >`` with a leading ``?``, an integer
      ``size_factor`` and an element ``type``.
    """

    _syntax_ = [
        "llvm.vec < {size.integer_literal} x {type.type} >",
        "llvm.vec < ? x {size_factor.integer_literal} x {type.type} >",
    ]


class LLVMPtr(DialectType):
    """Pointer type of the LLVM dialect.

    Written as ``llvm.ptr< T >``, where the pointee is captured in the
    ``type`` field.
    """

    _syntax_ = "llvm.ptr< {type.type} >"


# Build the dialect from the members of this module that ``is_type``
# accepts, so newly declared type classes above are picked up automatically.
llvm = Dialect(
    "llvm",
    # TODO: add some operations, e.g. by collecting module members with
    # ``is_op(obj, __name__)`` and passing them as ``ops=[...]``.
    types=[
        dialect_type
        for _, dialect_type in inspect.getmembers(
            sys.modules[__name__],
            lambda member: is_type(member, __name__),
        )
    ],
)
Пример #2
0
    _syntax_ = [
        'toy.densify {arg.ssa_id} : {type.tensor_type}',
        'toy.densify {arg.ssa_id} , {pad.constant_literal} : {type.tensor_type}'
    ]


##############################################################################
# Dialect

# The 'toy' dialect: one operation, one type, and two extra grammar rules
# (toy_impl_type / toy_impl_list) supplied through the preamble.
my_dialect = Dialect(
    'toy',
    ops=[DensifyOp],
    types=[RaggedTensorType],
    preamble='''
// Exclamation mark in Lark means that string tokens will be preserved upon parsing
!toy_impl_type : "coo" | "csr" | "csc" | "ell"
toy_impl_list   : toy_impl_type ("+" toy_impl_type)*
                     ''',
    transformers={
        # Parsed toy_impl_list nodes are handed to ToyImplementation.
        'toy_impl_list': ToyImplementation,
        # Each toy_impl_type node is replaced by its first child.
        'toy_impl_type': lambda children: children[0],
    },
)

##############################################################################
# Tests


def test_custom_dialect():
    code = '''module {
  func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> {
    %t_tensor = toy.densify %ragged : tensor<32x14xf64>
Пример #3
0
    tag_index: List[SsaUse]
    src_type: ast.MemRefType
    dst_type: ast.MemRefType
    tag_type: ast.MemRefType
    stride: Optional[SsaUse] = None
    transfer_per_stride: Optional[SsaUse] = None

    _syntax_ = [
        'affine.dma_start {src.ssa_use} [ {src_index.multi_dim_affine_expr_no_parens} ] , {dst.ssa_use} [ {dst_index.multi_dim_affine_expr_no_parens} ] , {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}',
        'affine.dma_start {src.ssa_use} [ {src_index.multi_dim_affine_expr_no_parens} ] , {dst.ssa_use} [ {dst_index.multi_dim_affine_expr_no_parens} ] , {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}'
    ]


@dataclass
class AffineDmaWaitOperation(DialectOp):
    """The ``affine.dma_wait`` operation.

    Textual form::

        affine.dma_wait <tag> [ <tag_index> ] , <size> : <type>

    Fields:
        tag: SSA use written directly after the op name.
        tag_index: affine index expression(s) inside the square brackets.
        size: SSA use written after the comma.
        type: memref type written after the colon.
    """

    tag: SsaUse
    tag_index: ast.MultiDimAffineExpr
    size: SsaUse
    type: ast.MemRefType

    # Split over several literals for readability only; the pieces
    # concatenate to a single syntax string.
    _syntax_ = ('affine.dma_wait {tag.ssa_use}'
                ' [ {tag_index.multi_dim_affine_expr_no_parens} ]'
                ' , {size.ssa_use}'
                ' : {type.memref_type}')


# Build the dialect from the members of this module that ``is_op``
# accepts, so the operation classes declared above are picked up
# automatically.
affine = Dialect(
    'affine',
    ops=[
        dialect_op
        for _, dialect_op in inspect.getmembers(
            sys.modules[__name__],
            lambda member: is_op(member, __name__),
        )
    ],
)
Пример #4
0
class LinalgYield(DialectOp):
    """The ``linalg.yield`` operation.

    Written as the op name, a list of SSA ids (``operand_ids``), a colon,
    and an unparenthesized type list (``operand_types``).
    """

    _syntax_ = (
        "linalg.yield"
        " {operand_ids.ssa_id_list}"
        " : {operand_types.type_list_no_parens}"
    )


class LinalgMatmul(DialectOp):
    """The ``linalg.matmul`` operation.

    Both accepted forms start with an ``ins( a , b : a_type , b_type )``
    clause; they differ in what follows it.
    """

    _syntax_ = [
        # Form 1: ``ins(...)`` followed by an ``outs( c : c_type )`` clause.
        (
            "linalg.matmul ins( {a_id.ssa_id} , {b_id.ssa_id}"
            " : {a_type.type} , {b_type.type} )"
            " outs( {c_id.ssa_id} : {c_type.type} )"
        ),
        # Form 2: ``ins(...)`` followed by ``init( init : init_type )`` and
        # a ``-> out_type`` result annotation.
        (
            "linalg.matmul ins( {a_id.ssa_id} , {b_id.ssa_id}"
            " : {a_type.type} , {b_type.type} )"
            " init( {init_id.ssa_id} : {init_type.type} )  -> {out_type.type}"
        ),
    ]


class LinalgMatvec(DialectOp):
    """The ``linalg.matvec`` operation.

    A single form is declared: an ``ins( a , b : a_type , b_type )`` clause
    followed by an ``outs( c : c_type )`` clause.
    """

    _syntax_ = [
        (
            "linalg.matvec ins( {a_id.ssa_id} , {b_id.ssa_id}"
            " : {a_type.type} , {b_type.type} )"
            " outs( {c_id.ssa_id} : {c_type.type} )"
        ),
    ]


# Build the dialect from the members of this module that ``is_op``
# accepts, so the operation classes declared above are picked up
# automatically.
linalg = Dialect(
    "linalg",
    ops=[
        dialect_op
        for _, dialect_op in inspect.getmembers(
            sys.modules[__name__],
            lambda member: is_op(member, __name__),
        )
    ],
)