Exemple #1
0
def test_affine_expr_roundtrip():
    """
    Round-trip affine maps, semi-affine maps and integer sets.

    The source below is already written in the form that ``dump()`` emits,
    so parsing it and dumping it back must reproduce it exactly.
    """
    code = '''#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0) -> (d0)
#map2 = () -> (0)
#map3 = () -> (10)
#map4 = (d0, d1, d2) -> (d0, d1 + d2 + 5)
#map5 = (d0, d1, d2) -> (d0 + d1, d2)
#map6 = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
#map7 = (d0, d1)[s0] -> (d0 + s0, d1)
#map8 = (d0, d1) -> (d0 + d1 + 11)
#map9 = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
#map10 = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
#samap0 = (d0)[s0] -> (d0 floordiv (s0 + 1))
#samap1 = (d0)[s0] -> (d0 floordiv s0)
#samap2 = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)
#set0 = (d0) : (1 == 0)
#set1 = (d0, d1)[s0] : ()
#set2 = (d0, d1)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, d1 >= 0, -d1 + s1 - 1 >= 0)
#set3 = (d0, d1, d2) : (d0 - d2 * 4 == 0, d0 + d1 * 8 - 9 >= 0, -d0 - d1 * 8 + 11 >= 0)
#set4 = (d0, d1, d2, d3, d4, d5) : (d0 * 1089234 + d1 * 203472 + 82342 >= 0, d0 * -55 + d1 * 24 + d2 * 238 - d3 * 234 - 9743 >= 0, d0 * -5445 - d1 * 284 + d2 * 23 + d3 * 34 - 5943 >= 0, d0 * -5445 + d1 * 284 + d2 * 238 - d3 * 34 >= 0, d0 * 445 + d1 * 284 + d2 * 238 + d3 * 39 >= 0, d0 * -545 + d1 * 214 + d2 * 218 - d3 * 94 >= 0, d0 * 44 - d1 * 184 - d2 * 231 + d3 * 14 >= 0, d0 * -45 + d1 * 284 + d2 * 138 - d3 * 39 >= 0, d0 * 154 - d1 * 84 + d2 * 238 - d3 * 34 >= 0, d0 * 54 - d1 * 284 - d2 * 223 + d3 * 384 >= 0, d0 * -55 + d1 * 284 + d2 * 23 + d3 * 34 >= 0, d0 * 54 - d1 * 84 + d2 * 28 - d3 * 34 >= 0, d0 * 54 - d1 * 24 - d2 * 23 + d3 * 34 >= 0, d0 * -55 + d1 * 24 + d2 * 23 + d3 * 4 >= 0, d0 * 15 - d1 * 84 + d2 * 238 - d3 * 3 >= 0, d0 * 5 - d1 * 24 - d2 * 223 + d3 * 84 >= 0, d0 * -5 + d1 * 284 + d2 * 23 - d3 * 4 >= 0, d0 * 14 + d2 * 4 + 7234 >= 0, d0 * -174 - d2 * 534 + 9834 >= 0, d0 * 194 - d2 * 954 + 9234 >= 0, d0 * 47 - d2 * 534 + 9734 >= 0, d0 * -194 - d2 * 934 + 984 >= 0, d0 * -947 - d2 * 953 + 234 >= 0, d0 * 184 - d2 * 884 + 884 >= 0, d0 * -174 + d2 * 834 + 234 >= 0, d0 * 844 + d2 * 634 + 9874 >= 0, d2 * -797 - d3 * 79 + 257 >= 0, d0 * 2039 + d2 * 793 - d3 * 99 - d4 * 24 + d5 * 234 >= 0, d2 * 78 - d5 * 788 + 257 >= 0, d3 - (d5 + d0 * 97) floordiv 423 >= 0, ((d0 + (d3 mod 5) floordiv 2342) * 234) mod 2309 + (d0 + d3 * 2038) floordiv 208 >= 0, ((((d0 + d3 * 2300) * 239) floordiv 2342) mod 2309) mod 239423 == 0, d0 + d3 mod 2642 + (((((d3 + d0 * 2) mod 1247) mod 2038) mod 2390) mod 2039) floordiv 55 >= 0)
'''
    parsed = parse_string(code)
    dumped = parsed.dump()
    assert dumped == code
Exemple #2
0
def test_query():
    """
    Exercise the op-query combinators on a small ``saxpy`` function.

    Queries search the ops of the function body followed by the ops of the
    ``affine.for`` loop body, returning the first op the predicate accepts.
    """
    parsed = parse_string("""
func @saxpy(%a : f64, %x : memref<?xf64>, %y : memref<?xf64>) {
%c0 = constant 0 : index
%n = dim %x, %c0 : memref<?xf64>

affine.for %i = 0 to %n {
  %xi = affine.load %x[%i+1] : memref<?xf64>
  %axi =  mulf %a, %xi : f64
  %yi = affine.load %y[%i] : memref<?xf64>
  %axpyi = addf %yi, %axi : f64
  affine.store %axpyi, %y[%i] : memref<?xf64>
}
return
}""")
    func_body = parsed.module.region.body[0].region.body[0]
    # The third op of the function body is the affine.for; take its body.
    loop_body = func_body.body[2].op.region.body[0]

    # SSA value produced by the first op (%c0), used to query by value.
    c0 = func_body.body[0].result_list[0].value

    def query(expr):
        candidates = func_body.body + loop_body.body
        return next(op for op in candidates if expr(op))

    expected_const = "%c0 = constant 0 : index"
    assert query(Writes("%c0")).dump() == expected_const

    expected_load = "%yi = affine.load %y [ %i ] : memref<?xf64>"
    assert query(Reads("%y") & Isa(AffineLoadOp)).dump() == expected_load

    expected_dim = "%n = dim %x , %c0 : memref<?xf64>"
    assert query(Reads(c0)).dump() == expected_dim
Exemple #3
0
def test_toy_simple():
    """
    Smoke test: parse a small generic-form ``toy`` module and pretty-print it.

    Nothing is compared. The test only checks that parsing and
    pretty-printing complete without raising, and it prints the output for
    manual inspection.
    """
    source = '''
module {
  func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
    %t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %t_tensor : tensor<3x2xf64>
  }
}
    '''

    parsed = parse_string(source)
    print(parsed.pretty())
Exemple #4
0
def test_custom_dialect():
    """
    Parse a module that uses a custom ``toy`` dialect type and op.

    The dialect is passed explicitly through ``dialects=[my_dialect]``. The
    pretty-printed result must equal the input exactly.
    """
    code = '''module {
  func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> {
    %t_tensor = toy.densify %ragged : tensor<32x14xf64>
    return %t_tensor : tensor<32x14xf64>
  }
}'''
    pretty_text = parse_string(code, dialects=[my_dialect]).pretty()
    print(pretty_text)

    # Round-trip: the printer must reproduce the original text verbatim.
    assert pretty_text == code
Exemple #5
0
def test_toy_roundtrip():
    """
    Round-trip a generic-form ``toy`` module written without extra whitespace.

    Parsing the text and calling ``dump()`` must give back the same string.
    """
    code = '''module {
  func @toy_func(%arg0: tensor<2x3xf64>) -> tensor<3x2xf64> {
    %0 = "toy.transpose"(%arg0) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %0 : tensor<3x2xf64>
  }
}'''

    assert parse_string(code).dump() == code
Exemple #6
0
def test_loop_dialect_roundtrip():
    """
    Round-trip a function that mixes ``affine.min`` with an ``scf.for`` loop.

    The text is already in ``dump()`` form, so parse + dump must be lossless.
    """
    src = """module {
  func @for(%outer: index, %A: memref<?xf32>, %B: memref<?xf32>, %C: memref<?xf32>, %result: memref<?xf32>) {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %d0 = dim %A , %c0 : memref<?xf32>
    %b0 = affine.min affine_map<()[s0, s1] -> (1024, s0 - s1)> ()[%d0, %outer]
    scf.for %i0 = %c0 to %b0 step %c1 {
      %B_elem = load %B [ %i0 ] : memref<?xf32>
      %C_elem = load %C [ %i0 ] : memref<?xf32>
      %sum_elem = addf %B_elem , %C_elem : f32
      store %sum_elem , %result [ %i0 ] : memref<?xf32>
    }
    return
  }
}"""
    dumped = parse_string(src).dump()
    assert dumped == src
Exemple #7
0
def parse_mlir_functions(mlir_text: Union[str, bytes],
                         cli: MlirOptCli) -> mlir.astnodes.Module:
    """Parse only the function signatures contained in some MLIR text.

    Processing steps:

    1. ``str`` input is encoded to bytes, because ``cli.apply_passes`` is
       given bytes.
    2. The text is run through ``cli.apply_passes`` with no passes. Per the
       original intent, this applies aliases and flattens function
       signatures.
    3. Only lines matched by ``FUNC_PATTERN`` are kept. Each kept line is
       stripped, and ``"builtin.func "`` is rewritten to ``"func "``.
    4. Lines ending in ``"{"`` (function definitions) get a closing ``"}"``
       so each becomes a valid empty function.
    5. The kept lines are joined with newlines and parsed with
       ``mlir.parse_string``.

    Args:
        mlir_text: MLIR source, either as ``str`` or ``bytes``.
        cli: Object exposing ``apply_passes(text, passes)``.

    Returns:
        The result of ``mlir.parse_string`` on the reduced text.
        NOTE(review): elsewhere the result of ``mlir.parse_string`` is
        accessed via ``.modules[0]`` / ``.module``, which suggests it may be
        a file-level node rather than a ``Module``. Confirm the return
        annotation.
    """
    if isinstance(mlir_text, str):
        mlir_text = mlir_text.encode()
    # Run text thru mlir-opt to apply aliases and flatten function signatures.
    # Assumes apply_passes returns str: the str-based replace/join below
    # require it. TODO confirm.
    mlir_text = cli.apply_passes(mlir_text, [])

    func_lines = []
    for raw_line in mlir_text.splitlines():
        # Remove everything except function signatures.
        if not FUNC_PATTERN.match(raw_line):
            continue
        line = raw_line.strip().replace("builtin.func ", "func ")
        # Add a trailing "}" to make defined functions valid. endswith() is
        # safe even if the stripped line ends up empty (line[-1] was not).
        if line.endswith("{"):
            line += "}"
        func_lines.append(line)

    mlir_ast = mlir.parse_string("\n".join(func_lines))
    return mlir_ast
Exemple #8
0
def get_ast(code: str):
    """Parse *code* with ``mlir.parse_string`` and return its first module."""
    parsed = mlir.parse_string(code)
    first_module = parsed.modules[0]
    return first_module
Exemple #9
0
def assert_roundtrip_equivalence(source):
    """Assert that parsing *source* and dumping it again yields *source*."""
    regenerated = mlir.parse_string(source).dump()
    assert source == regenerated