def test_external_calls(self):
    """Check an NNC ``ExternalCall`` to ``nnc_aten_matmul`` against ``torch.matmul``.

    Builds a (1, 4) x (4, 1) -> (1, 1) matmul as a tensorexpr external call
    whose buffer dimensions are plain Python ints, runs it through the
    ``ir_eval`` codegen on all-ones inputs, and compares ``tC`` (the tensor
    bound to buffer ``C``) with eager ``torch.matmul``.
    """
    # NOTE(review): another ``test_external_calls`` is defined in this file;
    # if both live in the same class, the later definition shadows this one
    # and it never runs. Consider giving one of them a distinct name.
    dtype = torch.float32

    A = te.BufHandle('A', [1, 4], dtype)
    B = te.BufHandle('B', [4, 1], dtype)
    C = te.BufHandle('C', [1, 1], dtype)

    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    # NOTE(review): the codegen is built from ``s``, not from
    # ``loopnest.root_stmt()``; confirm this is intentional.
    codegen = te.construct_codegen('ir_eval', s, [A, B, C])

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    # assert_close(actual, expected): the codegen output is the value under
    # test and eager matmul is the reference, so tolerances and failure
    # messages are computed relative to the reference.
    torch.testing.assert_close(tC, torch.matmul(tA, tB))
def test_external_calls(self):
    """Check an NNC ``ExternalCall`` to ``nnc_aten_matmul`` against ``torch.matmul``.

    Builds a (1, 4) x (4, 1) -> (1, 1) matmul as a tensorexpr external call
    whose buffer dimensions are explicit ``te.ExprHandle.int`` constants,
    runs it through the ``ir_eval`` codegen (buffers wrapped in
    ``te.BufferArg``) on all-ones inputs, and compares ``tC`` (the tensor
    bound to buffer ``C``) with eager ``torch.matmul``.
    """
    # NOTE(review): another ``test_external_calls`` is defined in this file;
    # if both live in the same class, whichever is defined later shadows the
    # other. Consider giving one of them a distinct name.
    dtype = torch.float32
    ONE = te.ExprHandle.int(1)
    FOUR = te.ExprHandle.int(4)

    A = te.BufHandle('A', [ONE, FOUR], dtype)
    B = te.BufHandle('B', [FOUR, ONE], dtype)
    C = te.BufHandle('C', [ONE, ONE], dtype)

    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    # NOTE(review): the codegen is built from ``s``, not from
    # ``loopnest.root_stmt()``; confirm this is intentional.
    codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    # assert_close(actual, expected): the codegen output is the value under
    # test and eager matmul is the reference.
    torch.testing.assert_close(tC, torch.matmul(tA, tB))
def digamma_lower(name, out_shape, inp_shapes, args):
    """Lower ``digamma`` to an NNC external call to ``nnc_aten_digamma``.

    Args:
        name: Not used by this lowering; accepted to match the callback
            signature (presumably the op/node name — confirm at the
            registration site).
        out_shape: Result shape, converted with ``get_te_shapes``.
        inp_shapes: Per-input shape/dtype info. ``inp_shapes[0][1]`` is
            converted with ``get_nnc_type`` and used as the output dtype, so
            the output buffer takes the first input's dtype.
        args: Lowered arguments; only ``args[0]`` is forwarded to the call.

    Returns:
        Tuple ``(buffer, stmts)``: the output ``te.BufHandle`` named
        ``'out'`` and a one-element list containing the ``te.ExternalCall``
        that targets it.
    """
    result_dims = get_te_shapes(out_shape)
    result_dtype = get_nnc_type(inp_shapes[0][1])
    result_buf = te.BufHandle('out', result_dims, result_dtype)

    call_args = [args[0]]
    digamma_call = te.ExternalCall(result_buf, "nnc_aten_digamma", call_args, [])
    return result_buf, [digamma_call]