def test_dynamic_shape_2d(self):
    """Codegen with two dynamic dims: C[i, j] = A[i, j] - B[i, j]."""
    dim_n = te.VarHandle(torch.int32)
    dim_m = te.VarHandle(torch.int32)
    buf_a = te.BufHandle([dim_n, dim_m], torch.float64)
    buf_b = te.BufHandle([dim_n, dim_m], torch.float64)

    C = te.Compute(
        "C", [dim_n, dim_m], lambda i, j: buf_a.load([i, j]) - buf_b.load([i, j])
    )
    nest = te.LoopNest([C])
    nest.prepare_for_codegen()
    cg = te.construct_codegen(
        "ir_eval", nest.simplify(), [buf_a, buf_b, C, dim_n, dim_m]
    )

    def check(n, m):
        # Bind the dynamic dims at call time and verify the elementwise result.
        tA = torch.randn(n, m, dtype=torch.double)
        tB = torch.randn(n, m, dtype=torch.double)
        tC = torch.empty(n, m, dtype=torch.double)
        cg.call([tA, tB, tC, n, m])
        torch.testing.assert_close(tA - tB, tC)

    check(2, 4)
    check(5, 3)
def test_dynamic_shape(self):
    """Codegen with one dynamic dim: C[i] = A[i] - B[i]."""
    dim_n = te.VarHandle(torch.int32)
    buf_a = te.BufHandle(torch.float64)
    buf_b = te.BufHandle(torch.float64)

    C = te.Compute('C', [dim_n], lambda i: buf_a.load(i) - buf_b.load(i))
    nest = te.LoopNest([C])
    nest.prepare_for_codegen()
    cg = te.construct_codegen('ir_eval', nest.simplify(), [buf_a, buf_b, C, dim_n])

    def check(n):
        # Bind the dynamic dim at call time and verify the elementwise result.
        tA = torch.randn(n, dtype=torch.double)
        tB = torch.randn(n, dtype=torch.double)
        tC = torch.empty(n, dtype=torch.double)
        cg.call([tA, tB, tC, n])
        torch.testing.assert_close(tA - tB, tC)

    check(8)
    check(31)
def construct_adder(n: int, dtype=torch.float32):
    """Build an ir_eval codegen computing C[i] = A[i] + B[i] over n elements."""
    A = te.BufHandle('A', [n], dtype)
    B = te.BufHandle('B', [n], dtype)

    def add_elem(i):
        return A.load([i]) + B.load([i])

    C = te.Compute('C', [n], add_elem)
    nest = te.LoopNest([C])
    nest.prepare_for_codegen()
    simplified = te.simplify(nest.root_stmt())
    return te.construct_codegen('ir_eval', simplified, [A, B, C])
def test_external_calls_int_dims(self):
    """ExternalCall matmul (1x4 @ 4x1) with plain-int buffer dims.

    NOTE(review): this method was previously also named ``test_external_calls``,
    colliding with a later definition of the same name in this file; the later
    ``def`` shadowed this one, so it never ran. Renamed to restore coverage.
    """
    dtype = torch.float32
    A = te.BufHandle('A', [1, 4], dtype)
    B = te.BufHandle('B', [4, 1], dtype)
    C = te.BufHandle('C', [1, 1], dtype)

    # C = A @ B via the NNC external-call mechanism rather than te.Compute.
    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    codegen = te.construct_codegen('ir_eval', s, [A, B, C])

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    torch.testing.assert_close(torch.matmul(tA, tB), tC)
def test_external_calls(self):
    """ExternalCall matmul (1x4 @ 4x1) with ExprHandle dims and BufferArgs."""
    dtype = torch.float32
    one = te.ExprHandle.int(1)
    four = te.ExprHandle.int(4)
    A = te.BufHandle('A', [one, four], dtype)
    B = te.BufHandle('B', [four, one], dtype)
    C = te.BufHandle('C', [one, one], dtype)

    # C = A @ B via the NNC external-call mechanism rather than te.Compute.
    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    buffer_args = [te.BufferArg(A), te.BufferArg(B), te.BufferArg(C)]
    codegen = te.construct_codegen('ir_eval', s, buffer_args)

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    torch.testing.assert_close(torch.matmul(tA, tB), tC)
def test_alloc_in_loop(self):
    """Copy a -> tmp -> b inside four nested loops and codegen it with LLVM."""
    a, tmp, b = (
        te.BufHandle(nm, [1], torch.float32) for nm in ("a", "tmp", "b")
    )
    body = te.Block(
        [tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))]
    )
    # Wrap the copy in four nested 0..100 loops so tmp is allocated in a loop.
    for _ in range(4):
        idx = te.VarHandle("i", torch.int32)
        body = te.For.make(idx, 0, 100, body)
    nest = te.LoopNest(body, [b])
    nest.prepare_for_codegen()
    f = te.construct_codegen("llvm", nest.simplify(), [a, b])
    ta = torch.ones(1)
    tb = torch.ones(1)
    f.call([ta.data_ptr(), tb.data_ptr()])
def test_dtype_error(self):
    """A torch dtype is accepted; a bogus dtype string raises TypeError."""
    te.BufHandle('a', [1], torch.float32)  # ok
    with self.assertRaises(TypeError):
        te.BufHandle('a', [1], "float55")
def digamma_lower(name, out_shape, inp_shapes, args):
    """Lower digamma to a single NNC external call writing into a fresh buffer.

    Returns the output buffer and the one-element statement list, per the
    lowering-function convention used by the surrounding file.
    """
    # Output dtype mirrors the first input's dtype; shape mirrors out_shape.
    out_dtype = get_nnc_type(inp_shapes[0][1])
    out = te.BufHandle('out', get_te_shapes(out_shape), out_dtype)
    stmt = te.ExternalCall(out, "nnc_aten_digamma", [args[0]], [])
    return out, [stmt]