def testFunctionMultiple(self):
     """Emit two functions both named "foo" and print the resulting module.

     The first function body is empty; the second one contains a single
     index constant.
     """
     self.setUp()
     for with_constant in (False, True):
         with self.module.function_context("foo", [], []):
             if with_constant:
                 E.constant_index(0)
     printWithCurrentFunctionName(str(self.module))
示例#2
0
    def testLoopNestContext(self):
        """Check the IR printed for a 4-deep LoopNestContext.

        Loop number d (0-based) uses constant bounds d and 10 * d + 5 and the
        step listed in `steps`; the innermost body sums all four induction
        variables.
        """
        with self.module.function_context("foo", [], []) as fun:
            lower_bounds = [E.constant_index(d) for d in range(4)]
            upper_bounds = [E.constant_index(10 * d + 5) for d in range(4)]
            steps = [1, 3, 5, 7]
            with E.LoopNestContext(lower_bounds, upper_bounds, steps) as ivs:
                iv0, iv1, iv2, iv3 = ivs
                iv0 + iv1 + iv2 + iv3

        code = str(fun)
        expected_lines = [
            ' "affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n',
            "  ^bb1(%i0: index):",
            '    "affine.for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n',
            "    ^bb2(%i1: index):",
            '      "affine.for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n',
            "      ^bb3(%i2: index):",
            '        "affine.for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n',
            "        ^bb4(%i3: index):",
            '          %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
        ]
        for line in expected_lines:
            self.assertIn(line, code)
示例#3
0
    def testConstants(self):
        """Check the printed form of float, integer, index and function constants.

        Every constant emitted in the function body has a matching assertion,
        including the 64-bit integer one.
        """
        with self.module.function_context("constants", [], []) as fun:
            for type_name in ("bf16", "f16", "f32", "f64"):
                E.constant_float(1.23, self.module.make_scalar_type(type_name))
            E.constant_int(1, 1)
            for width in (8, 16, 32, 64):
                E.constant_int(123, width)
            E.constant_index(123)
            # A constant referring to the enclosing function itself.
            E.constant_function(fun)

        code = str(fun)
        expected = [
            "constant 1.230000e+00 : bf16",
            "constant 1.230470e+00 : f16",
            "constant 1.230000e+00 : f32",
            "constant 1.230000e+00 : f64",
            "constant 1 : i1",
            "constant 123 : i8",
            "constant 123 : i16",
            "constant 123 : i32",
            # Covers the E.constant_int(123, 64) call above.
            "constant 123 : i64",
            "constant 123 : index",
            "constant @constants : () -> ()",
        ]
        for snippet in expected:
            self.assertIn(snippet, code)
 def testBlockArguments(self):
     """Print a function whose extra block takes two f32 arguments and adds them."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         E.constant_index(42)
         with E.BlockContext([self.f32Type] * 2) as block:
             _ = block.arg(0) + block.arg(1)
         printWithCurrentFunctionName(str(fun))
示例#5
0
    def testIndexedValue(self):
        """Check load/add/store through an IndexedValue inside a 2-D loop nest.

        Every element of a 10x42 f32 memref argument is loaded, incremented by
        1.0 and stored back; the memref is then returned.
        """
        memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
        with self.module.function_context("indexed", [memrefType],
                                          [memrefType]) as fun:
            A = E.IndexedValue(fun.arg(0))
            cst = E.constant_float(1., self.f32Type)
            with E.LoopNestContext(
                [E.constant_index(0), E.constant_index(0)],
                [E.constant_index(10),
                 E.constant_index(42)], [1, 1]) as (i, j):
                A.store([i, j], A.load([i, j]) + cst)
            E.ret([fun.arg(0)])

        code = str(fun)
        # The 2-D nest must print two constant-bound loops. Asserting the same
        # substring twice with assertIn cannot distinguish one loop from two,
        # so count the occurrences instead.
        self.assertEqual(code.count('"affine.for"()'), 2)
        self.assertIn(
            "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}",
            code)
        self.assertIn(
            "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}",
            code)
        self.assertIn("%0 = load %arg0[%i0, %i1] : memref<10x42xf32>", code)
        self.assertIn("%1 = addf %0, %cst : f32", code)
        self.assertIn("store %1, %arg0[%i0, %i1] : memref<10x42xf32>", code)
 def testLoopNestContext(self):
     """Print a 4-deep loop nest whose innermost body sums the induction variables."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         lower = [E.constant_index(d) for d in range(4)]
         upper = [E.constant_index(10 * d + 5) for d in range(4)]
         with E.LoopNestContext(lower, upper, [1, 3, 5, 7]) as ivs:
             iv0, iv1, iv2, iv3 = ivs
             iv0 + iv1 + iv2 + iv3
     printWithCurrentFunctionName(str(fun))
 def testBrArgs(self):
     """Print a function whose block branches to itself with swapped arguments."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         # Create an infinite loop: the block jumps back to itself, passing its
         # two index arguments in swapped order.
         with E.BlockContext([self.indexType] * 2) as loop_block:
             swapped = [loop_block.arg(1), loop_block.arg(0)]
             E.br(loop_block, swapped)
         initial = [E.constant_index(0), E.constant_index(1)]
         E.br(loop_block, initial)
         printWithCurrentFunctionName(str(fun))
 def testRet(self):
     """Print a function that returns the index constants 42 and 0."""
     self.setUp()
     with self.module.function_context(
             "foo", [], [self.indexType] * 2) as fun:
         E.ret([E.constant_index(42), E.constant_index(0)])
         printWithCurrentFunctionName(str(fun))
示例#9
0
 def testBlockArguments(self):
     """Check that a block's two f32 arguments are printed and can be added."""
     with self.module.function_context("foo", [], []) as fun:
         E.constant_index(42)
         with E.BlockContext([self.f32Type] * 2) as block:
             block.arg(0) + block.arg(1)
     code = str(fun)
     for expected in ("%c42 = constant 42 : index",
                      "^bb1(%0: f32, %1: f32):",
                      "  %2 = addf %0, %1 : f32"):
         self.assertIn(expected, code)
示例#10
0
 def testRet(self):
     """Check that E.ret prints a return of both index constants."""
     with self.module.function_context(
             "foo", [], [self.indexType] * 2) as fun:
         E.ret([E.constant_index(42), E.constant_index(0)])
     code = str(fun)
     for expected in ("  %c42 = constant 42 : index",
                      "  %c0 = constant 0 : index",
                      "  return %c42, %c0 : index, index"):
         self.assertIn(expected, code)
 def testLoopContext(self):
     """Print two nested LoopContexts; the inner bounds are built from `rhs`."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         lhs = E.constant_index(0)
         rhs = E.constant_index(42)
         with E.LoopContext(lhs, rhs, 1) as i:
             lhs + rhs + i
             # Inner loop runs from rhs to rhs + rhs with step 2.
             with E.LoopContext(rhs, rhs + rhs, 2) as j:
                 # Emitted only for its IR; the result is intentionally not
                 # bound to a local.
                 i + j
         printWithCurrentFunctionName(str(fun))
示例#12
0
 def testBrArgs(self):
     """Check the branch operands and block arguments of a self-looping block."""
     with self.module.function_context("foo", [], []) as fun:
         # Create an infinite loop.
         with E.BlockContext([self.indexType] * 2) as loop:
             E.br(loop, [loop.arg(1), loop.arg(0)])
         E.br(loop, [E.constant_index(0), E.constant_index(1)])
     code = str(fun)
     for expected in ("  %c0 = constant 0 : index",
                      "  %c1 = constant 1 : index",
                      "  br ^bb1(%c0, %c1 : index, index)",
                      "^bb1(%0: index, %1: index):",
                      "  br ^bb1(%1, %0 : index, index)"):
         self.assertIn(expected, code)
示例#13
0
 def testIndexedValue(self):
   """Print a function that adds 1.0 to every element of a 10x42 f32 memref."""
   self.setUp()
   memref_type = self.module.make_memref_type(self.f32Type, [10, 42])
   with self.module.function_context("indexed", [memref_type],
                                     [memref_type]) as fun:
     buf = E.IndexedValue(fun.arg(0))
     one = E.constant_float(1., self.f32Type)
     lbs = [E.constant_index(0), E.constant_index(0)]
     ubs = [E.constant_index(10), E.constant_index(42)]
     with E.LoopNestContext(lbs, ubs, [1, 1]) as (i, j):
       updated = buf.load([i, j]) + one
       buf.store([i, j], updated)
     E.ret([fun.arg(0)])
     printWithCurrentFunctionName(str(fun))
示例#14
0
 def testBlockContextAppend(self):
     """Check that E.appendTo re-opens an existing block for further ops.

     Constants 41 and 42 must stay before ^bb1, while 0 and 1 (the latter
     appended later via E.appendTo) must follow it.
     """
     with self.module.function_context("foo", [], []) as fun:
         E.constant_index(41)
         with E.BlockContext() as blk:
             E.constant_index(0)
         E.constant_index(42)
         with E.BlockContext(E.appendTo(blk)):
             E.constant_index(1)
     code = str(fun)
     # Compare the offsets of each op and label in the printed IR to decide
     # which block every constant was placed in.
     needles = {
         "c41": "%c41 = constant 41 : index",
         "c42": "%c42 = constant 42 : index",
         "bb1": "^bb1:",
         "c0": "%c0 = constant 0 : index",
         "c1": "%c1 = constant 1 : index",
     }
     pos = {key: code.find(text) for key, text in needles.items()}
     for key in needles:
         self.assertNotEqual(pos[key], -1)
     self.assertGreater(pos["bb1"], pos["c41"])
     self.assertGreater(pos["bb1"], pos["c42"])
     self.assertLess(pos["bb1"], pos["c0"])
     self.assertLess(pos["bb1"], pos["c1"])
示例#15
0
    def testMultipleFunctions(self):
        """Check that functions added one after another all appear in the module."""
        with self.module.function_context("foo", [], []):
            E.constant_index(0)
        module_text = str(self.module)
        for expected in ("func @foo()", "  %c0 = constant 0 : index"):
            self.assertIn(expected, module_text)

        with self.module.function_context("bar", [], []):
            E.constant_index(42)
        module_text = str(self.module)
        bar_pos = module_text.find("func @bar()")
        c42_pos = module_text.find("%c42 = constant 42 : index")
        self.assertNotEqual(bar_pos, -1)
        self.assertNotEqual(c42_pos, -1)
        # The constant must be printed after @bar's header.
        self.assertGreater(c42_pos, bar_pos)
示例#16
0
 def testMatrixMultiply(self):
   """Print a "matmul" function over three 32x32 f32 memrefs.

   Note: the loop body stores A[i, k] * B[k, j] into C[i, j] without
   accumulating over k.
   """
   self.setUp()
   memref_type = self.module.make_memref_type(self.f32Type, [32, 32])
   with self.module.function_context(
       "matmul", [memref_type] * 3, []) as fun:
     A, B, C = [E.IndexedValue(fun.arg(idx)) for idx in range(3)]
     lo = E.constant_index(0)
     hi = E.constant_index(32)
     with E.LoopNestContext([lo] * 3, [hi] * 3, [1] * 3) as (i, j, k):
       product = A.load([i, k]) * B.load([k, j])
       C.store([i, j], product)
     E.ret([])
     printWithCurrentFunctionName(str(fun))
 def testConstants(self):
     """Print a function holding float, integer, index and function constants."""
     self.setUp()
     with self.module.function_context("constants", [], []) as fun:
         for name in ("bf16", "f16", "f32", "f64"):
             E.constant_float(1.23, self.module.make_scalar_type(name))
         E.constant_int(1, 1)
         for bits in (8, 16, 32, 64):
             E.constant_int(123, bits)
         E.constant_index(123)
         # A constant referring to the enclosing function itself.
         E.constant_function(fun)
         printWithCurrentFunctionName(str(fun))
示例#18
0
 def testMLIRBooleanCompilation(self):
   """Emit boolean comparisons and logic ops in a loop nest, then compile.

   After `self.module.compile()`, prints whether the module's engine address
   equals 0.
   """
   self.setUp()
   # 10-element memref of boolean (i1) values, used for both arguments.
   m = self.module.make_memref_type(self.boolType, [10])
   with self.module.function_context("mkbooltensor", [m, m], []) as f:
     # Named `src`/`dst` rather than `input`/`output`: `input` would shadow
     # the builtin of the same name.
     src = E.IndexedValue(f.arg(0))
     dst = E.IndexedValue(f.arg(1))
     zero = E.constant_index(0)
     ten = E.constant_index(10)
     with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
       b1 = (i < j) & (j < k)
       b2 = ~b1
       b3 = b2 | (k < j)
       dst.store([i], src.load([i]) & b3)
     E.ret([])
   self.module.compile()
   printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
 def testBlockContext(self):
     """Print a function whose second block adds an entry-block constant to itself."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         c42 = E.constant_index(42)
         with E.BlockContext():
             _ = c42 + c42
         printWithCurrentFunctionName(str(fun))
示例#20
0
 def testDivisions(self):
   """Print a module exercising index floor division and i32 division."""
   self.setUp()
   arg_types = [self.indexType, self.i32Type, self.i32Type]
   with self.module.function_context("division", arg_types, []) as fun:
     # Index values only support floor division (`//`).
     fun.arg(0) // E.constant_index(42)
     # Regular values only support regular division (`/`).
     fun.arg(1) / fun.arg(2)
     printWithCurrentFunctionName(str(self.module))
 def testBlockContextAppend(self):
     """Print a function where a block is re-opened with E.appendTo to add ops."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         E.constant_index(41)
         with E.BlockContext() as blk:
             E.constant_index(0)
         E.constant_index(42)
         # Re-enter the block created above to emit one more op into it.
         with E.BlockContext(E.appendTo(blk)):
             E.constant_index(1)
         printWithCurrentFunctionName(str(fun))
 def testCondBr(self):
     """Print a function that conditionally branches on its boolean argument."""
     self.setUp()
     with self.module.function_context("foo", [self.boolType], []) as fun:
         with E.BlockContext() as true_block:
             E.ret([])
         with E.BlockContext([self.indexType]) as false_block:
             E.ret([])
         zero = E.constant_index(0)
         # The first target takes no operands; the second receives `zero`.
         E.cond_br(fun.arg(0), true_block, [], false_block, [zero])
         printWithCurrentFunctionName(str(fun))
示例#23
0
    def testCondBr(self):
        """Check the printed form of a cond_br to two blocks."""
        with self.module.function_context("foo", [self.boolType], []) as fun:
            with E.BlockContext() as then_block:
                E.ret([])
            with E.BlockContext([self.indexType]) as else_block:
                E.ret([])
            zero = E.constant_index(0)
            E.cond_br(fun.arg(0), then_block, [], else_block, [zero])

        self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", str(fun))
示例#24
0
 def testLoopContext(self):
     """Check the IR for a LoopContext nested in another with value-based bounds.

     The outer loop has constant bounds 0..42; the inner loop's bounds are
     `rhs` and `rhs + rhs`, so they are printed as loop operands.
     """
     with self.module.function_context("foo", [], []) as fun:
         lhs = E.constant_index(0)
         rhs = E.constant_index(42)
         with E.LoopContext(lhs, rhs, 1) as i:
             lhs + rhs + i
             with E.LoopContext(rhs, rhs + rhs, 2) as j:
                 # Emitted only for its IR; the result is intentionally not
                 # bound to a local.
                 i + j
     code = str(fun)
     # TODO(zinenko,ntv): use FileCheck for these tests
     expected_lines = [
         '  "affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n',
         "  ^bb1(%i0: index):",
         '    "affine.for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n',
         "    ^bb2(%i1: index):",
         '      %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index',
     ]
     for line in expected_lines:
         self.assertIn(line, code)
示例#25
0
    def testDivisions(self):
        """Check index floor division and i32 division (printed as `divis`)."""
        arg_types = [self.indexType, self.i32Type, self.i32Type]
        with self.module.function_context("division", arg_types, []) as fun:
            # Index values only support floor division (`//`).
            fun.arg(0) // E.constant_index(42)
            # Regular values only support regular division (`/`).
            fun.arg(1) / fun.arg(2)

        module_text = str(self.module)
        for expected in ("floordiv 42", "divis %arg1, %arg2 : i32"):
            self.assertIn(expected, module_text)
示例#26
0
    def testMatrixMultiply(self):
        """Check the loops, loads, multiply and store of a "matmul" function.

        Note: the loop body stores A[i, k] * B[k, j] into C[i, j] without
        accumulating over k.
        """
        memref_type = self.module.make_memref_type(self.f32Type, [32, 32])
        with self.module.function_context("matmul", [memref_type] * 3,
                                          []) as fun:
            A, B, C = [E.IndexedValue(fun.arg(idx)) for idx in range(3)]
            zero = E.constant_index(0)
            size = E.constant_index(32)
            with E.LoopNestContext([zero] * 3, [size] * 3,
                                   [1] * 3) as (i, j, k):
                C.store([i, j], A.load([i, k]) * B.load([k, j]))
            E.ret([])

        code = str(fun)
        expected_snippets = [
            '"affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()',
            "%0 = load %arg0[%i0, %i2] : memref<32x32xf32>",
            "%1 = load %arg1[%i2, %i1] : memref<32x32xf32>",
            "%2 = mulf %0, %1 : f32",
            "store %2, %arg2[%i0, %i1] : memref<32x32xf32>",
        ]
        for snippet in expected_snippets:
            self.assertIn(snippet, code)
示例#27
0
 def testBlockContext(self):
     """Check that an op emitted inside a BlockContext follows the ^bb1 label.

     Adding the constant 42 to itself is printed as an affine.apply producing
     84, which must appear after ^bb1 while the constant precedes it.
     """
     with self.module.function_context("foo", [], []) as fun:
         c42 = E.constant_index(42)
         with E.BlockContext():
             c42 + c42
     code = str(fun)
     # Compare the offsets of each op and label to decide which block every
     # instruction ended up in.
     # TODO(zinenko,ntv): this (and tests below) should use FileCheck instead.
     c42_pos, bb1_pos, sum_pos = (code.find(needle) for needle in (
         "%c42 = constant 42 : index",
         "^bb1:",
         '%0 = "affine.apply"() {map: () -> (84)} : () -> index'))
     for pos in (c42_pos, bb1_pos, sum_pos):
         self.assertNotEqual(pos, -1)
     self.assertGreater(bb1_pos, c42_pos)
     self.assertLess(bb1_pos, sum_pos)
示例#28
0
 def testIndexCast(self):
     """Print a module in which an index constant is cast to i32."""
     self.setUp()
     with self.module.function_context("testIndexCast", [], []):
         E.index_cast(E.constant_index(0), self.i32Type)
     printWithCurrentFunctionName(str(self.module))
 def testBlockContextStandalone(self):
     """Print a function whose BlockContexts are created first and entered later.

     The first block context is entered twice, so constants 0 and 1 are both
     emitted while it is active.
     """
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         first_block, second_block = E.BlockContext(), E.BlockContext()
         with first_block:
             E.constant_index(0)
         with second_block:
             for value in (56, 57):
                 E.constant_index(value)
         E.constant_index(41)
         with first_block:
             E.constant_index(1)
         E.constant_index(42)
         printWithCurrentFunctionName(str(fun))
示例#30
0
 def testBlockContextStandalone(self):
     """Check op placement when BlockContexts are created up front and re-entered.

     Constants 41 and 42 must precede ^bb1, constants 0 and 1 (the first block
     is entered twice) must lie between ^bb1 and ^bb2, and 56 and 57 must
     follow ^bb2.
     """
     with self.module.function_context("foo", [], []) as fun:
         blk1, blk2 = E.BlockContext(), E.BlockContext()
         with blk1:
             E.constant_index(0)
         with blk2:
             for value in (56, 57):
                 E.constant_index(value)
         E.constant_index(41)
         with blk1:
             E.constant_index(1)
         E.constant_index(42)
     code = str(fun)
     # Compare the offsets of each op and label in the printed IR to decide
     # which block every constant was placed in.
     patterns = [
         ("c41", "  %c41 = constant 41 : index"),
         ("c42", "  %c42 = constant 42 : index"),
         ("bb1", "^bb1:"),
         ("c0", "  %c0 = constant 0 : index"),
         ("c1", "  %c1 = constant 1 : index"),
         ("bb2", "^bb2:"),
         ("c56", "  %c56 = constant 56 : index"),
         ("c57", "  %c57 = constant 57 : index"),
     ]
     pos = {}
     for key, pattern in patterns:
         pos[key] = code.find(pattern)
     for key, _ in patterns:
         self.assertNotEqual(pos[key], -1)
     self.assertGreater(pos["bb1"], pos["c41"])
     self.assertGreater(pos["bb1"], pos["c42"])
     self.assertLess(pos["bb1"], pos["c0"])
     self.assertLess(pos["bb1"], pos["c1"])
     self.assertGreater(pos["bb2"], pos["c0"])
     self.assertGreater(pos["bb2"], pos["c1"])
     self.assertGreater(pos["bb2"], pos["bb1"])
     self.assertLess(pos["bb2"], pos["c56"])
     self.assertLess(pos["bb2"], pos["c57"])