def test_affine_expr_roundtrip():
    """
    Round-trip affine maps, semi-affine maps and integer sets.

    The text is parsed with ``parse_string`` and dumped again; the dump must
    match the input exactly.
    """
    code = '''#map0 = (d0, d1) -> (d0, d1) #map1 = (d0) -> (d0) #map2 = () -> (0) #map3 = () -> (10) #map4 = (d0, d1, d2) -> (d0, d1 + d2 + 5) #map5 = (d0, d1, d2) -> (d0 + d1, d2) #map6 = (d0, d1)[s0] -> (d0, d1 + s0 + 7) #map7 = (d0, d1)[s0] -> (d0 + s0, d1) #map8 = (d0, d1) -> (d0 + d1 + 11) #map9 = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7) #map10 = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1) #samap0 = (d0)[s0] -> (d0 floordiv (s0 + 1)) #samap1 = (d0)[s0] -> (d0 floordiv s0) #samap2 = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1) #set0 = (d0) : (1 == 0) #set1 = (d0, d1)[s0] : () #set2 = (d0, d1)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, d1 >= 0, -d1 + s1 - 1 >= 0) #set3 = (d0, d1, d2) : (d0 - d2 * 4 == 0, d0 + d1 * 8 - 9 >= 0, -d0 - d1 * 8 + 11 >= 0) #set4 = (d0, d1, d2, d3, d4, d5) : (d0 * 1089234 + d1 * 203472 + 82342 >= 0, d0 * -55 + d1 * 24 + d2 * 238 - d3 * 234 - 9743 >= 0, d0 * -5445 - d1 * 284 + d2 * 23 + d3 * 34 - 5943 >= 0, d0 * -5445 + d1 * 284 + d2 * 238 - d3 * 34 >= 0, d0 * 445 + d1 * 284 + d2 * 238 + d3 * 39 >= 0, d0 * -545 + d1 * 214 + d2 * 218 - d3 * 94 >= 0, d0 * 44 - d1 * 184 - d2 * 231 + d3 * 14 >= 0, d0 * -45 + d1 * 284 + d2 * 138 - d3 * 39 >= 0, d0 * 154 - d1 * 84 + d2 * 238 - d3 * 34 >= 0, d0 * 54 - d1 * 284 - d2 * 223 + d3 * 384 >= 0, d0 * -55 + d1 * 284 + d2 * 23 + d3 * 34 >= 0, d0 * 54 - d1 * 84 + d2 * 28 - d3 * 34 >= 0, d0 * 54 - d1 * 24 - d2 * 23 + d3 * 34 >= 0, d0 * -55 + d1 * 24 + d2 * 23 + d3 * 4 >= 0, d0 * 15 - d1 * 84 + d2 * 238 - d3 * 3 >= 0, d0 * 5 - d1 * 24 - d2 * 223 + d3 * 84 >= 0, d0 * -5 + d1 * 284 + d2 * 23 - d3 * 4 >= 0, d0 * 14 + d2 * 4 + 7234 >= 0, d0 * -174 - d2 * 534 + 9834 >= 0, d0 * 194 - d2 * 954 + 9234 >= 0, d0 * 47 - d2 * 534 + 9734 >= 0, d0 * -194 - d2 * 934 + 984 >= 0, d0 * -947 - d2 * 953 + 234 >= 0, d0 * 184 - d2 * 884 + 884 >= 0, d0 * -174 + d2 * 834 + 234 >= 0, d0 * 844 + d2 * 634 + 9874 >= 0, d2 
* -797 - d3 * 79 + 257 >= 0, d0 * 2039 + d2 * 793 - d3 * 99 - d4 * 24 + d5 * 234 >= 0, d2 * 78 - d5 * 788 + 257 >= 0, d3 - (d5 + d0 * 97) floordiv 423 >= 0, ((d0 + (d3 mod 5) floordiv 2342) * 234) mod 2309 + (d0 + d3 * 2038) floordiv 208 >= 0, ((((d0 + d3 * 2300) * 239) floordiv 2342) mod 2309) mod 239423 == 0, d0 + d3 mod 2642 + (((((d3 + d0 * 2) mod 1247) mod 2038) mod 2390) mod 2039) floordiv 55 >= 0) '''
    dumped = parse_string(code).dump()
    assert dumped == code
def test_query():
    """Exercise the ``Writes`` / ``Reads`` / ``Isa`` query predicates on a saxpy kernel."""
    source = """ func @saxpy(%a : f64, %x : memref<?xf64>, %y : memref<?xf64>) { %c0 = constant 0 : index %n = dim %x, %c0 : memref<?xf64> affine.for %i = 0 to %n { %xi = affine.load %x[%i+1] : memref<?xf64> %axi = mulf %a, %xi : f64 %yi = affine.load %y[%i] : memref<?xf64> %axpyi = addf %yi, %axi : f64 affine.store %axpyi, %y[%i] : memref<?xf64> } return }"""
    parsed = parse_string(source)
    # Body of the saxpy function.
    func_body = parsed.module.region.body[0].region.body[0]
    # The third statement of the function is the affine.for; take its body.
    loop_body = func_body.body[2].op.region.body[0]
    # SSA value defined by the first statement (%c0).
    c0_value = func_body.body[0].result_list[0].value

    def first_matching(predicate):
        """Return the first op (function body, then loop body) satisfying *predicate*.

        Raises StopIteration when no op matches.
        """
        candidates = func_body.body + loop_body.body
        return next(op for op in candidates if predicate(op))

    assert first_matching(Writes("%c0")).dump() == "%c0 = constant 0 : index"

    load_of_y = first_matching(Reads("%y") & Isa(AffineLoadOp))
    assert load_of_y.dump() == "%yi = affine.load %y [ %i ] : memref<?xf64>"

    assert first_matching(Reads(c0_value)).dump() == "%n = dim %x , %c0 : memref<?xf64>"
def test_toy_simple():
    """Smoke test: a generic "toy.transpose" op must parse; the pretty form is printed."""
    code = ''' module { func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { %t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64> return %t_tensor : tensor<3x2xf64> } } '''
    # No assertion on the output: only parsing and pretty-printing are exercised.
    print(parse_string(code).pretty())
def test_custom_dialect():
    """Parse code that uses ``my_dialect`` and check the pretty output equals the input."""
    code = '''module { func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> { %t_tensor = toy.densify %ragged : tensor<32x14xf64> return %t_tensor : tensor<32x14xf64> } }'''
    pretty_text = parse_string(code, dialects=[my_dialect]).pretty()
    print(pretty_text)
    # Round-trip check: the pretty-printed module must reproduce the source.
    assert pretty_text == code
def test_toy_roundtrip():
    """
    Round-trip compact MLIR text (no extra whitespace).

    Parsing the text and dumping the result must give back identical text.
    """
    code = '''module { func @toy_func(%arg0: tensor<2x3xf64>) -> tensor<3x2xf64> { %0 = "toy.transpose"(%arg0) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64> return %0 : tensor<3x2xf64> } }'''
    assert parse_string(code).dump() == code
def test_loop_dialect_roundtrip():
    """Round-trip a function using ``scf.for`` and ``affine.min`` through parse/dump."""
    src = """module { func @for(%outer: index, %A: memref<?xf32>, %B: memref<?xf32>, %C: memref<?xf32>, %result: memref<?xf32>) { %c0 = constant 0 : index %c1 = constant 1 : index %d0 = dim %A , %c0 : memref<?xf32> %b0 = affine.min affine_map<()[s0, s1] -> (1024, s0 - s1)> ()[%d0, %outer] scf.for %i0 = %c0 to %b0 step %c1 { %B_elem = load %B [ %i0 ] : memref<?xf32> %C_elem = load %C [ %i0 ] : memref<?xf32> %sum_elem = addf %B_elem , %C_elem : f32 store %sum_elem , %result [ %i0 ] : memref<?xf32> } return } }"""
    regenerated = parse_string(src).dump()
    assert regenerated == src
def parse_mlir_functions(
    mlir_text: Union[str, bytes], cli: MlirOptCli
) -> mlir.astnodes.Module:
    """Parse only the function lines of *mlir_text* into an ``mlir`` AST.

    The text is first passed through ``cli.apply_passes`` with an empty pass
    list. Every resulting line that matches ``FUNC_PATTERN`` is kept, stripped,
    and has ``"builtin.func "`` rewritten to ``"func "``. Any kept line that
    ends with ``{`` (a function whose body was dropped) gets a closing ``}``
    appended, so the parser sees a valid empty-bodied function.

    Args:
        mlir_text: MLIR source. A ``str`` is converted with ``str.encode()``
            (UTF-8) before being handed to ``cli``; ``bytes`` are passed as-is.
        cli: Wrapper used to run ``mlir-opt`` on the text.

    Returns:
        The result of ``mlir.parse_string`` on the extracted function lines.
    """
    if isinstance(mlir_text, str):
        mlir_text = mlir_text.encode()

    # Run text thru mlir-opt to apply aliases and flatten function signatures.
    # NOTE(review): the str-typed .replace() below implies apply_passes returns
    # decoded text (str), not bytes — confirm against MlirOptCli.
    mlir_text = cli.apply_passes(mlir_text, [])

    # Remove everything except function signatures.
    func_lines = [
        line.strip().replace("builtin.func ", "func ")
        for line in mlir_text.splitlines()
        if FUNC_PATTERN.match(line)
    ]
    # Add in trailing "}" to make defined functions valid. endswith() (rather
    # than line[-1]) does not raise IndexError if a kept line is empty.
    func_lines = [line + "}" if line.endswith("{") else line for line in func_lines]

    mlir_ast = mlir.parse_string("\n".join(func_lines))
    return mlir_ast
def get_ast(code: str):
    """Parse *code* with ``mlir.parse_string`` and return the first entry of ``.modules``."""
    parsed_file = mlir.parse_string(code)
    modules = parsed_file.modules
    return modules[0]
def assert_roundtrip_equivalence(source):
    """Assert that parsing *source* and dumping it reproduces *source* exactly.

    Raises:
        AssertionError: if the dumped text differs from *source*. The message
            includes both texts, so the mismatch is visible even when this
            helper lives in a module whose asserts pytest does not rewrite.
    """
    dumped = mlir.parse_string(source).dump()
    assert source == dumped, (
        "MLIR round-trip mismatch:\n"
        f"--- source ---\n{source}\n"
        f"--- dump ---\n{dumped}"
    )