def test_dynamic_shape_2d(self):
        """Check an ``ir_eval`` elementwise subtraction over a symbolic 2-D shape.

        Both extents are ``VarHandle`` parameters of the generated kernel, so
        one compiled codegen object is exercised against several concrete
        ``(rows, cols)`` shapes.
        """
        rows = te.VarHandle(torch.int32)
        cols = te.VarHandle(torch.int32)
        dims = [rows, cols]
        lhs = te.BufHandle(dims, torch.float64)
        rhs = te.BufHandle(dims, torch.float64)

        out = te.Compute("C", dims, lambda i, j: lhs.load([i, j]) - rhs.load([i, j]))

        nest = te.LoopNest([out])
        nest.prepare_for_codegen()
        kernel = te.construct_codegen(
            "ir_eval", nest.simplify(), [lhs, rhs, out, rows, cols]
        )

        for n, m in ((2, 4), (5, 3)):
            x = torch.randn(n, m, dtype=torch.double)
            y = torch.randn(n, m, dtype=torch.double)
            result = torch.empty(n, m, dtype=torch.double)
            # The trailing ints bind the symbolic extents for this invocation.
            kernel.call([x, y, result, n, m])
            torch.testing.assert_close(x - y, result)
Пример #2
0
    def test_dynamic_shape(self):
        """Check an ``ir_eval`` elementwise subtraction over a symbolic 1-D length.

        ``dN`` is a runtime parameter of the generated kernel, so the same
        codegen object is invoked with tensors of several lengths.
        """
        dN = te.VarHandle(torch.int32)
        # The inputs carry the symbolic extent so that their declared shape
        # agrees with the 1-D index used by ``load`` and with the ``[dN]``
        # iteration domain of ``C``.
        A = te.BufHandle([dN], torch.float64)
        B = te.BufHandle([dN], torch.float64)

        def compute(i):
            return A.load(i) - B.load(i)

        C = te.Compute('C', [dN], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen('ir_eval', loopnest.simplify(),
                                  [A, B, C, dN])

        def test_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            # The trailing int binds ``dN`` for this invocation.
            cg.call([tA, tB, tC, n])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(8)
        test_with_shape(31)
Пример #3
0
def construct_adder(n: int, dtype=torch.float32):
    """Build an ``ir_eval`` kernel computing ``C[i] = A[i] + B[i]`` for ``i < n``.

    Args:
        n: static length of the three 1-D buffers.
        dtype: element type shared by ``A``, ``B`` and ``C``.

    Returns:
        The codegen object, whose arguments are bound in the order
        ``A``, ``B``, ``C``.
    """
    lhs, rhs = (te.BufHandle(label, [n], dtype) for label in ('A', 'B'))
    out = te.Compute('C', [n], lambda i: lhs.load([i]) + rhs.load([i]))

    nest = te.LoopNest([out])
    nest.prepare_for_codegen()
    return te.construct_codegen('ir_eval', te.simplify(nest.root_stmt()),
                                [lhs, rhs, out])
Пример #4
0
    def test_external_calls(self):
        """Route a 1x4 @ 4x1 matmul through the ``nnc_aten_matmul`` external call."""
        dtype = torch.float32

        A, B, C = (
            te.BufHandle(label, dims, dtype)
            for label, dims in (('A', [1, 4]), ('B', [4, 1]), ('C', [1, 1]))
        )

        matmul = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        nest = te.LoopNest(matmul, [C])
        nest.prepare_for_codegen()
        # The codegen is built from the ExternalCall statement itself rather
        # than from the loop nest's root statement.
        kernel = te.construct_codegen('ir_eval', matmul, [A, B, C])

        lhs = torch.ones(1, 4)
        rhs = torch.ones(4, 1)
        result = torch.empty(1, 1)
        kernel.call([lhs, rhs, result])
        torch.testing.assert_close(torch.matmul(lhs, rhs), result)
Пример #5
0
    def test_external_calls(self):
        """Check that an ``nnc_aten_matmul`` ExternalCall multiplies 1x4 by 4x1.

        Dimensions are given as explicit ``ExprHandle`` constants and the
        codegen arguments are wrapped in ``te.BufferArg``.
        """
        dtype = torch.float32

        one, four = te.ExprHandle.int(1), te.ExprHandle.int(4)
        A = te.BufHandle('A', [one, four], dtype)
        B = te.BufHandle('B', [four, one], dtype)
        C = te.BufHandle('C', [one, one], dtype)

        call_stmt = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        nest = te.LoopNest(call_stmt, [C])
        nest.prepare_for_codegen()
        buffer_args = list(map(te.BufferArg, (A, B, C)))
        kernel = te.construct_codegen('ir_eval', call_stmt, buffer_args)

        lhs = torch.ones(1, 4)
        rhs = torch.ones(4, 1)
        result = torch.empty(1, 1)
        kernel.call([lhs, rhs, result])
        torch.testing.assert_close(torch.matmul(lhs, rhs), result)
 def test_alloc_in_loop(self):
     """Run an LLVM kernel that routes ``a[0]`` through ``tmp[0]`` into ``b[0]``.

     The copy body is wrapped in four nested 100-iteration loops. ``tmp`` is
     neither an output of the loop nest nor a codegen argument (presumably
     it is allocated by the loop nest as an intermediate; confirm). The
     output tensor starts out different from the input, so the final check
     actually observes the kernel's write.
     """
     a, tmp, b = [
         te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
     ]
     body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
     # Each level gets a fresh index variable; the body itself does not use it.
     for _ in range(4):
         i = te.VarHandle("i", torch.int32)
         body = te.For.make(i, 0, 100, body)
     nest = te.LoopNest(body, [b])
     nest.prepare_for_codegen()
     f = te.construct_codegen("llvm", nest.simplify(), [a, b])
     ta = torch.ones(1)
     # Start from zeros: if ``tb`` began equal to ``ta``, the check below
     # could not distinguish a working kernel from a no-op.
     tb = torch.zeros(1)
     f.call([ta.data_ptr(), tb.data_ptr()])
     torch.testing.assert_close(tb, ta)
Пример #7
0
 def test_dtype_error(self):
     """Accept a real torch dtype for ``BufHandle`` but reject an unknown dtype string."""
     # A genuine torch dtype must construct without error.
     te.BufHandle('a', [1], torch.float32)
     # A bogus dtype given as a string must be refused with TypeError.
     with self.assertRaises(TypeError):
         te.BufHandle('a', [1], "float55")
Пример #8
0
def digamma_lower(name, out_shape, inp_shapes, args):
    """Lower a digamma op to an ``nnc_aten_digamma`` external call.

    Args:
        name: not used by this lowering (presumably kept to match a shared
            lowering-callback signature; confirm against callers).
        out_shape: passed to ``get_te_shapes`` to build the output dims.
        inp_shapes: per-input entries; ``inp_shapes[0][1]`` is passed to
            ``get_nnc_type`` to choose the output dtype (presumably the
            first input's dtype; confirm).
        args: lowered inputs; only ``args[0]`` is forwarded to the call.

    Returns:
        ``(out, stmts)``: the output buffer named ``'out'`` and a
        one-element list holding the ExternalCall statement.
    """
    out_dims = get_te_shapes(out_shape)
    out_dtype = get_nnc_type(inp_shapes[0][1])
    out = te.BufHandle('out', out_dims, out_dtype)
    stmts = [te.ExternalCall(out, "nnc_aten_digamma", [args[0]], [])]
    return out, stmts