def testFunctionMultiple(self):
    """Create two functions that are both named "foo" and print the module.

    The first function body is left empty; the second holds a single index
    constant (0).
    """
    self.setUp()
    # Same emission order as two hand-written contexts: empty body first,
    # then the body containing the index constant.
    for emit_constant in (False, True):
        with self.module.function_context("foo", [], []):
            if emit_constant:
                E.constant_index(0)
    printWithCurrentFunctionName(str(self.module))
def testLoopNestContext(self):
    """Check the generic-form IR emitted for a four-deep LoopNestContext.

    The nest uses lower bounds 0, 1, 2, 3, upper bounds 5, 15, 25, 35 and
    steps 1, 3, 5, 7. Its body sums the four induction variables, which is
    expected to print as a single composed "affine.apply".
    """
    # NOTE(review): testLoopNestContext is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        lbs = [E.constant_index(i) for i in range(4)]
        ubs = [E.constant_index(10 * i + 5) for i in range(4)]
        # Induction variables are i0..i3 instead of i, j, k, l: a bare `l` is
        # easily misread as `1` (PEP 8 E741).
        with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i0, i1, i2, i3):
            # Result intentionally unused: evaluated only for the IR it emits.
            i0 + i1 + i2 + i3
    code = str(fun)
    expected_fragments = (
        ' "affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (5)} : () -> () {\n',
        " ^bb1(%i0: index):",
        ' "affine.for"() {lower_bound: () -> (1), step: 3 : index, upper_bound: () -> (15)} : () -> () {\n',
        " ^bb2(%i1: index):",
        ' "affine.for"() {lower_bound: () -> (2), step: 5 : index, upper_bound: () -> (25)} : () -> () {\n',
        " ^bb3(%i2: index):",
        ' "affine.for"() {lower_bound: () -> (3), step: 7 : index, upper_bound: () -> (35)} : () -> () {\n',
        " ^bb4(%i3: index):",
        ' %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, code)
def testConstants(self):
    """Check the printed form of constants of each supported scalar type.

    Covers bf16/f16/f32/f64 floats, i1/i8/i16/i32/i64 integers, an index
    constant and a function-reference constant.
    """
    # NOTE(review): testConstants is defined twice in this chunk; if both live
    # in the same class, only the later definition is kept.
    with self.module.function_context("constants", [], []) as fun:
        for float_type_name in ("bf16", "f16", "f32", "f64"):
            E.constant_float(1.23, self.module.make_scalar_type(float_type_name))
        E.constant_int(1, 1)
        for bit_width in (8, 16, 32, 64):
            E.constant_int(123, bit_width)
        E.constant_index(123)
        E.constant_function(fun)
    code = str(fun)
    expected_fragments = (
        "constant 1.230000e+00 : bf16",
        # f16 has too few mantissa bits to hold 1.23 exactly.
        "constant 1.230470e+00 : f16",
        "constant 1.230000e+00 : f32",
        "constant 1.230000e+00 : f64",
        "constant 1 : i1",
        "constant 123 : i8",
        "constant 123 : i16",
        "constant 123 : i32",
        # The i64 constant is emitted above; check it like the other widths.
        "constant 123 : i64",
        "constant 123 : index",
        "constant @constants : () -> ()",
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, code)
def testBlockArguments(self):
    """Print a function containing a block with two f32 arguments.

    The entry block holds an index constant (42); the extra block adds its
    two arguments together.
    """
    # NOTE(review): testBlockArguments is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        E.constant_index(42)
        block_arg_types = [self.f32Type, self.f32Type]
        with E.BlockContext(block_arg_types) as block:
            first = block.arg(0)
            second = block.arg(1)
            first + second
    printWithCurrentFunctionName(str(fun))
def testIndexedValue(self):
    """Check load/addf/store through an IndexedValue inside a 10x42 loop nest.

    The function takes a memref<10x42xf32>, adds 1.0 to every element in
    place, and returns the same memref.
    """
    # NOTE(review): testIndexedValue is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
    with self.module.function_context("indexed", [memrefType],
                                      [memrefType]) as fun:
        A = E.IndexedValue(fun.arg(0))
        cst = E.constant_float(1., self.f32Type)
        with E.LoopNestContext(
                [E.constant_index(0), E.constant_index(0)],
                [E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
            A.store([i, j], A.load([i, j]) + cst)
        E.ret([fun.arg(0)])
    code = str(fun)
    # The nest is two loops deep, so exactly two "affine.for" ops must appear.
    # Repeating assertIn on the same substring would only prove one exists.
    self.assertEqual(code.count('"affine.for"()'), 2)
    self.assertIn(
        "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}",
        code)
    self.assertIn(
        "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}",
        code)
    self.assertIn("%0 = load %arg0[%i0, %i1] : memref<10x42xf32>", code)
    self.assertIn("%1 = addf %0, %cst : f32", code)
    self.assertIn("store %1, %arg0[%i0, %i1] : memref<10x42xf32>", code)
def testLoopNestContext(self):
    """Print a four-deep loop nest whose body sums the induction variables.

    Lower bounds are 0..3, upper bounds 5, 15, 25, 35 and steps 1, 3, 5, 7.
    """
    # NOTE(review): testLoopNestContext is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        lbs = [E.constant_index(i) for i in range(4)]
        ubs = [E.constant_index(10 * i + 5) for i in range(4)]
        # i0..i3 instead of i, j, k, l: a bare `l` reads like `1` (PEP 8 E741).
        with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i0, i1, i2, i3):
            # Result intentionally unused: evaluated only for the IR it emits.
            i0 + i1 + i2 + i3
    printWithCurrentFunctionName(str(fun))
def testBrArgs(self):
    """Print an infinite loop made of a block that branches back to itself.

    The looping block takes two index arguments and re-enters itself with
    them swapped. The enclosing block enters it with (0, 1).
    """
    # NOTE(review): testBrArgs is defined twice in this chunk; if both live in
    # the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        # Create an infinite loop.
        with E.BlockContext([self.indexType, self.indexType]) as loop_block:
            swapped = [loop_block.arg(1), loop_block.arg(0)]
            E.br(loop_block, swapped)
        initial = [E.constant_index(0), E.constant_index(1)]
        E.br(loop_block, initial)
    printWithCurrentFunctionName(str(fun))
def testRet(self):
    """Print a function that returns two index constants, 42 and 0."""
    # NOTE(review): testRet is defined twice in this chunk; if both live in the
    # same class, only the later definition is kept.
    self.setUp()
    result_types = [self.indexType, self.indexType]
    with self.module.function_context("foo", [], result_types) as fun:
        forty_two = E.constant_index(42)
        zero = E.constant_index(0)
        E.ret([forty_two, zero])
    printWithCurrentFunctionName(str(fun))
def testBlockArguments(self):
    """Check that block arguments are printed and usable as operands.

    A block with two f32 arguments adds them. The expected IR has a ^bb1
    header with %0/%1 and an addf producing %2.
    """
    # NOTE(review): testBlockArguments is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        E.constant_index(42)
        with E.BlockContext([self.f32Type, self.f32Type]) as block:
            first = block.arg(0)
            second = block.arg(1)
            first + second
    printed = str(fun)
    for expected in ("%c42 = constant 42 : index",
                     "^bb1(%0: f32, %1: f32):",
                     " %2 = addf %0, %1 : f32"):
        self.assertIn(expected, printed)
def testRet(self):
    """Check the IR of a function returning two index constants (42, 0)."""
    # NOTE(review): testRet is defined twice in this chunk; if both live in the
    # same class, only the later definition is kept.
    result_types = [self.indexType, self.indexType]
    with self.module.function_context("foo", [], result_types) as fun:
        forty_two = E.constant_index(42)
        zero = E.constant_index(0)
        E.ret([forty_two, zero])
    printed = str(fun)
    for expected in (" %c42 = constant 42 : index",
                     " %c0 = constant 0 : index",
                     " return %c42, %c0 : index, index"):
        self.assertIn(expected, printed)
def testLoopContext(self):
    """Print two nested LoopContexts.

    The outer loop runs over constants [0, 42) with step 1. The inner loop
    runs from 42 to the computed value 42 + 42 with step 2 and sums both
    induction variables.
    """
    # NOTE(review): testLoopContext is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        lhs = E.constant_index(0)
        rhs = E.constant_index(42)
        with E.LoopContext(lhs, rhs, 1) as i:
            # Results of these expressions are unused: they are evaluated only
            # for the IR they emit (the old `x = ...` binding was never read).
            lhs + rhs + i
            with E.LoopContext(rhs, rhs + rhs, 2) as j:
                i + j
    printWithCurrentFunctionName(str(fun))
def testBrArgs(self):
    """Check branch operands and block arguments of a self-looping block.

    The looping block swaps its two index arguments on every iteration. The
    enclosing block branches into it with (0, 1).
    """
    # NOTE(review): testBrArgs is defined twice in this chunk; if both live in
    # the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        # Create an infinite loop.
        with E.BlockContext([self.indexType, self.indexType]) as loop_block:
            swapped = [loop_block.arg(1), loop_block.arg(0)]
            E.br(loop_block, swapped)
        initial = [E.constant_index(0), E.constant_index(1)]
        E.br(loop_block, initial)
    printed = str(fun)
    for expected in (" %c0 = constant 0 : index",
                     " %c1 = constant 1 : index",
                     " br ^bb1(%c0, %c1 : index, index)",
                     "^bb1(%0: index, %1: index):",
                     " br ^bb1(%1, %0 : index, index)"):
        self.assertIn(expected, printed)
def testIndexedValue(self):
    """Print a 10x42 loop nest that adds 1.0 to every element of a memref.

    The function takes a memref<10x42 x f32> argument, updates it in place and
    returns that same argument.
    """
    # NOTE(review): testIndexedValue is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    self.setUp()
    memref_type = self.module.make_memref_type(self.f32Type, [10, 42])
    with self.module.function_context("indexed", [memref_type],
                                      [memref_type]) as fun:
        data = E.IndexedValue(fun.arg(0))
        one = E.constant_float(1., self.f32Type)
        lower = [E.constant_index(0), E.constant_index(0)]
        upper = [E.constant_index(10), E.constant_index(42)]
        with E.LoopNestContext(lower, upper, [1, 1]) as (row, col):
            incremented = data.load([row, col]) + one
            data.store([row, col], incremented)
        E.ret([fun.arg(0)])
    printWithCurrentFunctionName(str(fun))
def testBlockContextAppend(self):
    """Check that E.appendTo re-opens an existing block for more instructions.

    Constants 41 and 42 go to the entry block. Constants 0 and 1 go to ^bb1:
    0 is added when the block is created and 1 when it is re-entered later.
    """
    # NOTE(review): testBlockContextAppend is defined twice in this chunk; if
    # both live in the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        E.constant_index(41)
        with E.BlockContext() as extra_block:
            E.constant_index(0)
        E.constant_index(42)
        # Re-enter the saved block handle and append to its end.
        with E.BlockContext(E.appendTo(extra_block)):
            E.constant_index(1)
    printed = str(fun)
    # Block membership is inferred from the relative offsets of the printed
    # instructions and the ^bb1 label.
    patterns = ("%c41 = constant 41 : index",
                "%c42 = constant 42 : index",
                "^bb1:",
                "%c0 = constant 0 : index",
                "%c1 = constant 1 : index")
    offsets = [printed.find(pattern) for pattern in patterns]
    for offset in offsets:
        self.assertNotEqual(offset, -1)
    c41_at, c42_at, bb1_at, c0_at, c1_at = offsets
    self.assertGreater(bb1_at, c41_at)
    self.assertGreater(bb1_at, c42_at)
    self.assertLess(bb1_at, c0_at)
    self.assertLess(bb1_at, c1_at)
def testMultipleFunctions(self):
    """Check that successive function contexts add functions to one module.

    After "foo" is built the module text must contain it. After "bar" is
    built, bar's constant must be printed after bar's header.
    """
    with self.module.function_context("foo", [], []):
        E.constant_index(0)
    module_text = str(self.module)
    self.assertIn("func @foo()", module_text)
    self.assertIn(" %c0 = constant 0 : index", module_text)

    with self.module.function_context("bar", [], []):
        E.constant_index(42)
    module_text = str(self.module)
    bar_at = module_text.find("func @bar()")
    c42_at = module_text.find("%c42 = constant 42 : index")
    self.assertNotEqual(bar_at, -1)
    self.assertNotEqual(c42_at, -1)
    self.assertGreater(c42_at, bar_at)
def testMatrixMultiply(self):
    """Print a 32x32x32 loop nest storing A[i, k] * B[k, j] into C[i, j].

    Each iteration overwrites C[i, j] with the product; there is no
    accumulation across k.
    """
    # NOTE(review): testMatrixMultiply is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    self.setUp()
    square = self.module.make_memref_type(self.f32Type, [32, 32])
    with self.module.function_context("matmul", [square, square, square],
                                      []) as fun:
        lhs, rhs, out = [E.IndexedValue(fun.arg(n)) for n in range(3)]
        c0 = E.constant_index(0)
        c32 = E.constant_index(32)
        with E.LoopNestContext([c0] * 3, [c32] * 3, [1] * 3) as (i, j, k):
            product = lhs.load([i, k]) * rhs.load([k, j])
            out.store([i, j], product)
        E.ret([])
    printWithCurrentFunctionName(str(fun))
def testConstants(self):
    """Print constants of each supported scalar type.

    Emits bf16/f16/f32/f64 floats, i1/i8/i16/i32/i64 integers, an index
    constant, and a reference to the enclosing function.
    """
    # NOTE(review): testConstants is defined twice in this chunk; if both live
    # in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("constants", [], []) as fun:
        for float_type_name in ("bf16", "f16", "f32", "f64"):
            E.constant_float(1.23, self.module.make_scalar_type(float_type_name))
        E.constant_int(1, 1)
        for bit_width in (8, 16, 32, 64):
            E.constant_int(123, bit_width)
        E.constant_index(123)
        E.constant_function(fun)
    printWithCurrentFunctionName(str(fun))
def testMLIRBooleanCompilation(self):
    """Build a function combining boolean ops on index comparisons, then compile.

    Inside a 10x10x10 loop nest, a boolean is computed from comparisons of
    the induction variables with &, ~ and |. Element i of the output memref
    is set to (input[i] & that boolean). After self.module.compile(), the
    test prints whether get_engine_address() returned 0.
    """
    self.setUp()
    # A memref (not a tensor) of 10 elements of boolType, presumably i1.
    bool_memref = self.module.make_memref_type(self.boolType, [10])
    with self.module.function_context("mkbooltensor", [bool_memref, bool_memref],
                                      []) as f:
        # `src`/`dst` instead of `input`/`output`: `input` shadows a builtin.
        src = E.IndexedValue(f.arg(0))
        dst = E.IndexedValue(f.arg(1))
        zero = E.constant_index(0)
        ten = E.constant_index(10)
        with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
            b1 = (i < j) & (j < k)
            b2 = ~b1
            b3 = b2 | (k < j)
            dst.store([i], src.load([i]) & b3)
        E.ret([])
    self.module.compile()
    printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
def testBlockContext(self):
    """Print a function with a second block that adds a constant to itself.

    The constant (42) is created in the entry block; the sum is emitted in
    a new block opened with E.BlockContext().
    """
    # NOTE(review): testBlockContext is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        c42 = E.constant_index(42)
        with E.BlockContext():
            c42 + c42
    printWithCurrentFunctionName(str(fun))
def testDivisions(self):
    """Print floor division of an index and true division of two i32 values."""
    # NOTE(review): testDivisions is defined twice in this chunk; if both live
    # in the same class, only the later definition is kept.
    self.setUp()
    arg_types = [self.indexType, self.i32Type, self.i32Type]
    with self.module.function_context("division", arg_types, []) as fun:
        # Index values only support floor division.
        index_arg = fun.arg(0)
        divisor = E.constant_index(42)
        index_arg // divisor
        # Non-index values only support regular division.
        fun.arg(1) / fun.arg(2)
    printWithCurrentFunctionName(str(self.module))
def testBlockContextAppend(self):
    """Print a function where a block is re-entered later via E.appendTo.

    Constants 41 and 42 are emitted in the entry block. Constant 0 is
    emitted when the extra block is created; constant 1 is appended to it
    afterwards.
    """
    # NOTE(review): testBlockContextAppend is defined twice in this chunk; if
    # both live in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        E.constant_index(41)
        with E.BlockContext() as extra_block:
            E.constant_index(0)
        E.constant_index(42)
        # Re-open the previously created block and append to it.
        with E.BlockContext(E.appendTo(extra_block)):
            E.constant_index(1)
    printWithCurrentFunctionName(str(fun))
def testCondBr(self):
    """Print a conditional branch on the function's boolean argument.

    Both targets simply return. The second target takes one index argument,
    which receives the constant 0.
    """
    # NOTE(review): testCondBr is defined twice in this chunk; if both live in
    # the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [self.boolType], []) as fun:
        with E.BlockContext() as true_dest:
            E.ret([])
        with E.BlockContext([self.indexType]) as false_dest:
            E.ret([])
        zero = E.constant_index(0)
        condition = fun.arg(0)
        E.cond_br(condition, true_dest, [], false_dest, [zero])
    printWithCurrentFunctionName(str(fun))
def testCondBr(self):
    """Check the printed cond_br: no operands to ^bb1, one index to ^bb2."""
    # NOTE(review): testCondBr is defined twice in this chunk; if both live in
    # the same class, only the later definition is kept.
    with self.module.function_context("foo", [self.boolType], []) as fun:
        with E.BlockContext() as true_dest:
            E.ret([])
        with E.BlockContext([self.indexType]) as false_dest:
            E.ret([])
        zero = E.constant_index(0)
        condition = fun.arg(0)
        E.cond_br(condition, true_dest, [], false_dest, [zero])
    printed = str(fun)
    self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", printed)
def testLoopContext(self):
    """Check the IR of two nested LoopContexts.

    The outer loop has constant bounds [0, 42) and step 1. The inner loop's
    upper bound is the computed value 42 + 42 (%2) with step 2. Summing both
    induction variables should print as a two-operand "affine.apply".
    """
    # NOTE(review): testLoopContext is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        lhs = E.constant_index(0)
        rhs = E.constant_index(42)
        with E.LoopContext(lhs, rhs, 1) as i:
            # Results of these expressions are unused: they are evaluated only
            # for the IR they emit (the old `x = ...` binding was never read).
            lhs + rhs + i
            with E.LoopContext(rhs, rhs + rhs, 2) as j:
                i + j
    code = str(fun)
    # TODO(zinenko,ntv): use FileCheck for these tests
    expected_fragments = (
        ' "affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} : () -> () {\n',
        " ^bb1(%i0: index):",
        ' "affine.for"(%c42, %2) {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () {\n',
        " ^bb2(%i1: index):",
        ' %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index',
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, code)
def testDivisions(self):
    """Check that index floor division and i32 division both appear in the IR."""
    # NOTE(review): testDivisions is defined twice in this chunk; if both live
    # in the same class, only the later definition is kept.
    arg_types = [self.indexType, self.i32Type, self.i32Type]
    with self.module.function_context("division", arg_types, []) as fun:
        # Index values only support floor division.
        index_arg = fun.arg(0)
        divisor = E.constant_index(42)
        index_arg // divisor
        # Non-index values only support regular division.
        fun.arg(1) / fun.arg(2)
    module_text = str(self.module)
    for expected in ("floordiv 42", "divis %arg1, %arg2 : i32"):
        self.assertIn(expected, module_text)
def testMatrixMultiply(self):
    """Check the IR of a 32x32x32 loop nest storing A[i, k] * B[k, j] into C[i, j].

    Each iteration overwrites C[i, j]; there is no accumulation across k.
    """
    # NOTE(review): testMatrixMultiply is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    square = self.module.make_memref_type(self.f32Type, [32, 32])
    with self.module.function_context("matmul", [square, square, square],
                                      []) as fun:
        lhs, rhs, out = [E.IndexedValue(fun.arg(n)) for n in range(3)]
        c0 = E.constant_index(0)
        c32 = E.constant_index(32)
        with E.LoopNestContext([c0] * 3, [c32] * 3, [1] * 3) as (i, j, k):
            product = lhs.load([i, k]) * rhs.load([k, j])
            out.store([i, j], product)
        E.ret([])
    printed = str(fun)
    for expected in (
            '"affine.for"() {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()',
            "%0 = load %arg0[%i0, %i2] : memref<32x32xf32>",
            "%1 = load %arg1[%i2, %i1] : memref<32x32xf32>",
            "%2 = mulf %0, %1 : f32",
            "store %2, %arg2[%i0, %i1] : memref<32x32xf32>"):
        self.assertIn(expected, printed)
def testBlockContext(self):
    """Check that code emitted under E.BlockContext() lands in a new block.

    42 is created in the entry block. Adding it to itself inside ^bb1 is
    expected to print as an "affine.apply" whose map yields 84.
    """
    # NOTE(review): testBlockContext is defined twice in this chunk; if both
    # live in the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        c42 = E.constant_index(42)
        with E.BlockContext():
            c42 + c42
    printed = str(fun)
    # Block membership is inferred from the relative offsets of the printed
    # instructions and the ^bb1 label.
    # TODO(zinenko,ntv): this (and tests below) should use FileCheck instead.
    c42_at = printed.find("%c42 = constant 42 : index")
    bb1_at = printed.find("^bb1:")
    c84_at = printed.find(
        '%0 = "affine.apply"() {map: () -> (84)} : () -> index')
    for offset in (c42_at, bb1_at, c84_at):
        self.assertNotEqual(offset, -1)
    self.assertGreater(bb1_at, c42_at)
    self.assertLess(bb1_at, c84_at)
def testIndexCast(self):
    """Print a module whose function casts an index constant (0) to i32."""
    self.setUp()
    with self.module.function_context("testIndexCast", [], []):
        zero = E.constant_index(0)
        E.index_cast(zero, self.i32Type)
    printWithCurrentFunctionName(str(self.module))
def testBlockContextStandalone(self):
    """Print a function whose BlockContexts are created first and entered later.

    Both blocks are constructed up front. The first is entered twice
    (constants 0, then 1) and the second once (56, 57). Constants 41 and 42
    are emitted in the enclosing block between those visits.
    """
    # NOTE(review): testBlockContextStandalone is defined twice in this chunk;
    # if both live in the same class, only the later definition is kept.
    self.setUp()
    with self.module.function_context("foo", [], []) as fun:
        first_block = E.BlockContext()
        second_block = E.BlockContext()
        with first_block:
            E.constant_index(0)
        with second_block:
            E.constant_index(56)
            E.constant_index(57)
        E.constant_index(41)
        with first_block:
            E.constant_index(1)
        E.constant_index(42)
    printWithCurrentFunctionName(str(fun))
def testBlockContextStandalone(self):
    """Check instruction placement when BlockContexts are entered out of order.

    Expected layout: 41 and 42 before ^bb1; 0 and 1 after ^bb1 but before
    ^bb2; 56 and 57 after ^bb2.
    """
    # NOTE(review): testBlockContextStandalone is defined twice in this chunk;
    # if both live in the same class, only the later definition is kept.
    with self.module.function_context("foo", [], []) as fun:
        first_block = E.BlockContext()
        second_block = E.BlockContext()
        with first_block:
            E.constant_index(0)
        with second_block:
            E.constant_index(56)
            E.constant_index(57)
        E.constant_index(41)
        with first_block:
            E.constant_index(1)
        E.constant_index(42)
    printed = str(fun)
    # Block membership is inferred from the relative offsets of the printed
    # instructions and the block labels.
    patterns = (" %c41 = constant 41 : index",
                " %c42 = constant 42 : index",
                "^bb1:",
                " %c0 = constant 0 : index",
                " %c1 = constant 1 : index",
                "^bb2:",
                " %c56 = constant 56 : index",
                " %c57 = constant 57 : index")
    offsets = [printed.find(pattern) for pattern in patterns]
    for offset in offsets:
        self.assertNotEqual(offset, -1)
    c41_at, c42_at, bb1_at, c0_at, c1_at, bb2_at, c56_at, c57_at = offsets
    ordering_checks = (
        (self.assertGreater, bb1_at, c41_at),
        (self.assertGreater, bb1_at, c42_at),
        (self.assertLess, bb1_at, c0_at),
        (self.assertLess, bb1_at, c1_at),
        (self.assertGreater, bb2_at, c0_at),
        (self.assertGreater, bb2_at, c1_at),
        (self.assertGreater, bb2_at, bb1_at),
        (self.assertLess, bb2_at, c56_at),
        (self.assertLess, bb2_at, c57_at),
    )
    for check, left, right in ordering_checks:
        check(left, right)