Exemplo n.º 1
0
    def test_external_calls(self):
        dtype = torch.float32

        A = te.BufHandle('A', [1, 4], dtype)
        B = te.BufHandle('B', [4, 1], dtype)
        C = te.BufHandle('C', [1, 1], dtype)

        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen('ir_eval', s, [A, B, C])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)
Exemplo n.º 2
0
    def test_external_calls(self):
        dtype = torch.float32

        ONE = te.ExprHandle.int(1)
        FOUR = te.ExprHandle.int(4)
        A = te.BufHandle('A', [ONE, FOUR], dtype)
        B = te.BufHandle('B', [FOUR, ONE], dtype)
        C = te.BufHandle('C', [ONE, ONE], dtype)

        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen('ir_eval', s,
                                       [te.BufferArg(x) for x in [A, B, C]])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)
Exemplo n.º 3
0
def digamma_lower(name, out_shape, inp_shapes, args):
    """Lower a digamma op to an NNC external call into ``nnc_aten_digamma``.

    Args:
        name: Not used by this lowering (presumably accepted so all lowerings
            share one signature -- confirm against the caller).
        out_shape: Output shape; converted to buffer dimensions via ``get_te_shapes``.
        inp_shapes: Per-input metadata; element ``[0][1]`` is handed to
            ``get_nnc_type`` to pick the output dtype (presumably the first
            input's dtype -- confirm).
        args: Lowered arguments; only ``args[0]`` is passed to the external call.

    Returns:
        A pair ``(out_buffer, [statement])``. The single statement is an
        ``ExternalCall`` that writes into ``out_buffer``.
    """
    result_dims = get_te_shapes(out_shape)
    result_dtype = get_nnc_type(inp_shapes[0][1])
    result_buf = te.BufHandle('out', result_dims, result_dtype)

    call_inputs = [args[0]]
    call_stmt = te.ExternalCall(result_buf, "nnc_aten_digamma", call_inputs, [])

    stmts = [call_stmt]
    return result_buf, stmts