Esempio n. 1
0
    def testIndexedValue(self):
        """Build a 10x42 loop nest that adds 1.0 to every element of a memref.

        The generated function loads ``A[i, j]``, adds the f32 constant 1.0,
        stores the sum back into ``A[i, j]`` and returns the memref argument.
        The printed IR is checked for both loop bounds and for the
        load/addf/store body.
        """
        memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
        with self.module.function_context("indexed", [memrefType],
                                          [memrefType]) as fun:
            A = E.IndexedValue(fun.arg(0))
            cst = E.constant_float(1., self.f32Type)
            with E.LoopNestContext(
                [E.constant_index(0), E.constant_index(0)],
                [E.constant_index(10),
                 E.constant_index(42)], [1, 1]) as (i, j):
                A.store([i, j], A.load([i, j]) + cst)
            E.ret([fun.arg(0)])

        code = str(fun)
        # The nest has two loops, so at least two "affine.for" ops must be
        # printed; repeating the same assertIn would pass with only one.
        self.assertGreaterEqual(code.count('"affine.for"()'), 2)
        self.assertIn(
            "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}",
            code)
        self.assertIn(
            "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}",
            code)
        self.assertIn("%0 = load %arg0[%i0, %i1] : memref<10x42xf32>", code)
        self.assertIn("%1 = addf %0, %cst : f32", code)
        self.assertIn("store %1, %arg0[%i0, %i1] : memref<10x42xf32>", code)
 def testBr(self):
     """Emit a block that only returns, then a branch to it; print the function."""
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         with E.BlockContext() as target:
             E.ret()
         E.br(target)
         printWithCurrentFunctionName(str(fun))
 def testBrDeclaration(self):
     """Branch to a block before filling it in, then print the function.

     The block context is created first so that its handle can serve as the
     ``br`` target; the ``return`` is added to that block afterwards.
     """
     self.setUp()
     with self.module.function_context("foo", [], []) as fun:
         successor = E.BlockContext()
         successor_handle = successor.handle()
         E.br(successor_handle)
         with successor:
             E.ret()
         printWithCurrentFunctionName(str(fun))
 def testSelectOp(self):
     """Print a function returning ``select(arg0, 42, 0)`` on 32-bit ints."""
     self.setUp()
     with self.module.function_context("foo", [self.boolType],
                                       [self.i32Type]) as fun:
         if_true = E.constant_int(42, 32)
         if_false = E.constant_int(0, 32)
         chosen = E.select(fun.arg(0), if_true, if_false)
         E.ret([chosen])
         printWithCurrentFunctionName(str(fun))
 def testRet(self):
     """Print a function that returns the two index constants 42 and 0."""
     self.setUp()
     result_types = [self.indexType, self.indexType]
     with self.module.function_context("foo", [], result_types) as fun:
         forty_two = E.constant_index(42)
         zero = E.constant_index(0)
         E.ret([forty_two, zero])
         printWithCurrentFunctionName(str(fun))
Esempio n. 6
0
    def testCustomOpCompilation(self):
        """Build ops by name with ``E.op`` and check that the module compiles.

        The function materializes a 42 : i32 value via ``std.constant`` and
        feeds it, together with the argument, to ``std.addi``. After
        ``compile()`` the engine address must be non-zero.
        """
        i32 = self.i32Type
        with self.module.function_context("adder", [i32], []) as adder:
            answer_attr = self.module.integerAttr(i32, 42)
            answer = E.op("std.constant", [], [i32], value=answer_attr)
            E.op("std.addi", [answer, adder.arg(0)], [i32])
            E.ret([])

        self.module.compile()
        self.assertNotEqual(self.module.get_engine_address(), 0)
Esempio n. 7
0
    def testSelectOp(self):
        """Check that ``E.select`` prints as a ``select`` op on i32 values."""
        with self.module.function_context("foo", [self.boolType],
                                          [self.i32Type]) as fun:
            on_true = E.constant_int(42, 32)
            on_false = E.constant_int(0, 32)
            picked = E.select(fun.arg(0), on_true, on_false)
            E.ret([picked])

        expected = "%0 = select %arg0, %c42_i32, %c0_i32 : i32"
        self.assertIn(expected, str(fun))
Esempio n. 8
0
 def testRet(self):
     """Check the printed form of two index constants and a two-value return."""
     with self.module.function_context(
             "foo", [], [self.indexType, self.indexType]) as fun:
         forty_two = E.constant_index(42)
         zero = E.constant_index(0)
         E.ret([forty_two, zero])
     printed = str(fun)
     for expected_line in ("  %c42 = constant 42 : index",
                           "  %c0 = constant 0 : index",
                           "  return %c42, %c0 : index, index"):
         self.assertIn(expected_line, printed)
Esempio n. 9
0
 def testBrDeclaration(self):
     """Check a branch emitted before its target block is populated."""
     with self.module.function_context("foo", [], []) as fun:
         target = E.BlockContext()
         E.br(target.handle())
         with target:
             E.ret()
     printed = str(fun)
     for needle in ("  br ^bb1", "^bb1:", "  return"):
         self.assertIn(needle, printed)
Esempio n. 10
0
 def testBr(self):
     """Check a branch to a block that was emitted (with a return) beforehand."""
     with self.module.function_context("foo", [], []) as fun:
         with E.BlockContext() as dest:
             E.ret()
         E.br(dest)
     printed = str(fun)
     for needle in ("  br ^bb1", "^bb1:", "  return"):
         self.assertIn(needle, printed)
 def testCustomOpCompilation(self):
     """Compile a module built from named ops and print the engine check.

     Prints ``str(address == 0)`` for the engine address read after
     ``compile()``.
     """
     self.setUp()
     i32 = self.i32Type
     with self.module.function_context("adder", [i32], []) as adder:
         answer = E.op("std.constant", [], [i32],
                       value=self.module.integerAttr(i32, 42))
         E.op("std.addi", [answer, adder.arg(0)], [i32])
         E.ret([])
     self.module.compile()
     address = self.module.get_engine_address()
     printWithCurrentFunctionName(str(address == 0))
 def testCondBr(self):
     """Print a function that conditionally branches to one of two blocks.

     The second block takes one index argument, which is passed the
     constant 0; the first block takes none.
     """
     self.setUp()
     with self.module.function_context("foo", [self.boolType], []) as fun:
         with E.BlockContext() as true_dest:
             E.ret([])
         with E.BlockContext([self.indexType]) as false_dest:
             E.ret([])
         zero = E.constant_index(0)
         E.cond_br(fun.arg(0), true_dest, [], false_dest, [zero])
         printWithCurrentFunctionName(str(fun))
Esempio n. 13
0
    def testCondBr(self):
        """Check ``cond_br`` printing, including the second block's operand."""
        with self.module.function_context("foo", [self.boolType], []) as fun:
            with E.BlockContext() as true_dest:
                E.ret([])
            with E.BlockContext([self.indexType]) as false_dest:
                E.ret([])
            zero = E.constant_index(0)
            E.cond_br(fun.arg(0), true_dest, [], false_dest, [zero])

        self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", str(fun))
Esempio n. 14
0
 def testIndexedValue(self):
   """Print a 10x42 loop nest that adds 1.0 to every memref element in place."""
   self.setUp()
   memref_type = self.module.make_memref_type(self.f32Type, [10, 42])
   with self.module.function_context("indexed", [memref_type],
                                     [memref_type]) as fun:
     A = E.IndexedValue(fun.arg(0))
     one = E.constant_float(1., self.f32Type)
     lower_bounds = [E.constant_index(0), E.constant_index(0)]
     upper_bounds = [E.constant_index(10), E.constant_index(42)]
     with E.LoopNestContext(lower_bounds, upper_bounds, [1, 1]) as (i, j):
       A.store([i, j], A.load([i, j]) + one)
     E.ret([fun.arg(0)])
     printWithCurrentFunctionName(str(fun))
Esempio n. 15
0
 def testMatrixMultiply(self):
   """Print a 32x32x32 loop nest computing ``C[i, j] = A[i, k] * B[k, j]``.

   The body stores each product into ``C`` rather than accumulating it.
   """
   self.setUp()
   square = self.module.make_memref_type(self.f32Type, [32, 32])
   with self.module.function_context(
       "matmul", [square, square, square], []) as fun:
     A, B, C = [E.IndexedValue(fun.arg(pos)) for pos in range(3)]
     zero = E.constant_index(0)
     thirty_two = E.constant_index(32)
     with E.LoopNestContext([zero] * 3, [thirty_two] * 3,
                            [1] * 3) as (i, j, k):
       C.store([i, j], A.load([i, k]) * B.load([k, j]))
     E.ret([])
     printWithCurrentFunctionName(str(fun))
Esempio n. 16
0
 def testMLIRBooleanCompilation(self):
   """Compile a function that mixes boolean memref loads with index logic.

   Inside a 10x10x10 loop nest, the body combines index comparisons with
   ``&``, ``~`` and ``|``, then stores ``src[i] & b3`` into ``dst[i]``.
   After ``compile()``, prints whether the engine address is zero.
   """
   self.setUp()
   m = self.module.make_memref_type(self.boolType, [10])  # 1-D memref of i1
   with self.module.function_context("mkbooltensor", [m, m], []) as f:
     # Named src/dst rather than input/output so the builtin ``input`` is
     # not shadowed.
     src = E.IndexedValue(f.arg(0))
     dst = E.IndexedValue(f.arg(1))
     zero = E.constant_index(0)
     ten = E.constant_index(10)
     with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
       b1 = (i < j) & (j < k)
       b2 = ~b1
       b3 = b2 | (k < j)
       dst.store([i], src.load([i]) & b3)
     E.ret([])
   self.module.compile()
   printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
Esempio n. 17
0
    def testMatrixMultiply(self):
        """Check the IR of a 32x32x32 loop nest storing ``A[i, k] * B[k, j]``.

        The body writes each product to ``C[i, j]`` without accumulation.
        """
        square = self.module.make_memref_type(self.f32Type, [32, 32])
        with self.module.function_context("matmul", [square] * 3, []) as fun:
            A, B, C = [E.IndexedValue(fun.arg(pos)) for pos in range(3)]
            zero = E.constant_index(0)
            thirty_two = E.constant_index(32)
            with E.LoopNestContext([zero] * 3, [thirty_two] * 3,
                                   [1] * 3) as (i, j, k):
                C.store([i, j], A.load([i, k]) * B.load([k, j]))
            E.ret([])

        printed = str(fun)
        expected = (
            '"affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()',
            "%0 = load %arg0[%i0, %i2] : memref<32x32xf32>",
            "%1 = load %arg1[%i2, %i1] : memref<32x32xf32>",
            "%2 = mulf %0, %1 : f32",
            "store %2, %arg2[%i0, %i1] : memref<32x32xf32>",
        )
        for fragment in expected:
            self.assertIn(fragment, printed)