def test_dynamic_shape(self):
    """Codegen with a dynamic (VarHandle) length: C[i] = A[i] - B[i].

    The same compiled kernel is invoked with two different runtime sizes
    to confirm the generated code is genuinely shape-polymorphic.
    """
    length = te.VarHandle(torch.int32)
    buf_a = te.BufHandle(torch.float64)
    buf_b = te.BufHandle(torch.float64)

    def diff(i):
        return buf_a.load(i) - buf_b.load(i)

    result = te.Compute('C', [length], diff)
    nest = te.LoopNest([result])
    nest.prepare_for_codegen()
    codegen = te.construct_codegen(
        'ir_eval', nest.simplify(), [buf_a, buf_b, result, length])

    def check(n):
        a = torch.randn(n, dtype=torch.double)
        b = torch.randn(n, dtype=torch.double)
        out = torch.empty(n, dtype=torch.double)
        codegen.call([a, b, out, n])
        torch.testing.assert_close(a - b, out)

    check(8)
    check(31)
def flatten_lower(name, out_shape, inp_shapes, args):
    """Lower a flatten op: collapse dims [start_dim, end_dim] of the input.

    args is (input_buf, start_dim, end_dim); the output load maps the single
    flattened index back to the original multi-dimensional indices.
    """
    src, start_dim, end_dim = args
    dims = list(inp_shapes[0][0])
    flat_dims = dims[start_dim:end_dim + 1]

    def _product(values):
        result = 1
        for v in values:
            result *= v
        return result

    def _unflatten_index(flat_idx):
        # Peel off one original dimension at a time via div/mod by the
        # product of the remaining (inner) flattened dimensions.
        original = []
        remaining = _product(flat_dims)
        for size in flat_dims:
            remaining //= size
            original.append(flat_idx / to_expr(remaining))
            flat_idx = flat_idx % to_expr(remaining)
        return original

    def _loader(*indices):
        indices = list(indices)
        indices = (indices[:start_dim]
                   + _unflatten_index(indices[start_dim])
                   + indices[start_dim + 1:])
        return src.load(indices)

    return te.Compute(name, get_dim_args(out_shape), _loader)
def test_dynamic_shape_2d(self):
    """2-D dynamic-shape codegen: C[i, j] = A[i, j] - B[i, j].

    Both extents are runtime VarHandles; the compiled kernel is run with
    two different (n, m) pairs to verify shape polymorphism.
    """
    rows = te.VarHandle(torch.int32)
    cols = te.VarHandle(torch.int32)
    buf_a = te.BufHandle([rows, cols], torch.float64)
    buf_b = te.BufHandle([rows, cols], torch.float64)

    def diff(i, j):
        return buf_a.load([i, j]) - buf_b.load([i, j])

    result = te.Compute("C", [rows, cols], diff)
    nest = te.LoopNest([result])
    nest.prepare_for_codegen()
    codegen = te.construct_codegen(
        "ir_eval", nest.simplify(), [buf_a, buf_b, result, rows, cols])

    def check(n, m):
        a = torch.randn(n, m, dtype=torch.double)
        b = torch.randn(n, m, dtype=torch.double)
        out = torch.empty(n, m, dtype=torch.double)
        codegen.call([a, b, out, n, m])
        torch.testing.assert_close(a - b, out)

    check(2, 4)
    check(5, 3)
def test_dynamic_shape(self):
    """Dynamic-shape codegen via the Placeholder API: C[i] = A[i] - B[i].

    The kernel is compiled once with a VarHandle extent and invoked with
    two different runtime sizes.
    """
    with kernel_arena_scope():
        dN = te.VarHandle("n", te.Dtype.Int)
        A = te.Placeholder('A', te.Dtype.Double, [dN])
        B = te.Placeholder('B', te.Dtype.Double, [dN])

        def compute(i):
            return A.load([i]) - B.load([i])

        C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()
        stmt = te.simplify(loopnest.root_stmt())
        cg = te.construct_codegen('ir_eval', stmt, [A, B, C, dN])

        def test_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            cg.call([tA, tB, tC, n])
            # FIX: torch.testing.assert_allclose is deprecated (removed in
            # torch 2.x); use assert_close, matching the sibling tests.
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(8)
        test_with_shape(31)
def mm_lower(name, out_shape, inp_shapes, args):
    """Lower a matrix multiply: out[n, p] = sum_m lhs[n, m] * rhs[m, p]."""
    lhs, rhs = args[0], args[1]
    N, M = inp_shapes[0][0]
    P = inp_shapes[1][0][1]

    def elementwise_product(n, p, m):
        return lhs.load([n, m]) * rhs.load([m, p])

    # Materialize the rank-3 product tensor, then sum-reduce the m axis.
    products = te.Compute('mm', get_dim_args([N, P, M]), elementwise_product)
    return te.SumReduce(name, get_dim_args([N, P]), products,
                        get_dim_args([M]))
def bmm_lower(name, out_shape, inp_shapes, args):
    """Lower a batched matmul: out[b, n, p] = sum_m lhs[b, n, m] * rhs[b, m, p]."""
    lhs, rhs = args[0], args[1]
    B, N, M = inp_shapes[0][0]
    P = inp_shapes[1][0][2]

    def elementwise_product(b, n, p, m):
        return lhs.load([b, n, m]) * rhs.load([b, m, p])

    # Materialize the rank-4 product tensor, then reduce the m axis with Sum.
    products = te.Compute('mm', get_dim_args([B, N, P, M]),
                          elementwise_product)
    return te.Reduce(name, get_dim_args([B, N, P]), te.Sum(), products,
                     get_dim_args([M]))
def transpose_lower(name, out_shape, inp_shapes, args):
    """Lower a transpose: load the input with two index positions swapped.

    args is (input_buf, dim_a, dim_b).
    """
    dim_a, dim_b = args[1], args[2]

    def _swapped(indices):
        indices = list(indices)
        indices[dim_a], indices[dim_b] = indices[dim_b], indices[dim_a]
        return indices

    def _loader(*indices):
        return args[0].load(_swapped(indices))

    return te.Compute(name, get_dim_args(out_shape), _loader)
def construct_te_fn(op, n: int, dtype=torch.float32):
    """Build an ir_eval codegen evaluating C[i] = op(A[i]) over n elements.

    Args:
        op: callable mapping an ExprHandle (the loaded A[i]) to an ExprHandle.
        n: number of elements in the input/output buffers.
        dtype: element dtype of the buffers.

    Returns:
        A codegen object; invoke it as cg.call([tA, tC]).
    """
    # FIX (consistency): the rest of this function already uses the `te`
    # alias; avoid spelling out torch._C._te for just this one call.
    A = te.BufHandle("A", [n], dtype)

    def compute(i):
        return op(A.load([i]))

    C = te.Compute("C", [n], compute)
    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    return te.construct_codegen("ir_eval", stmt, [A, C])
def construct_adder(n: int, dtype=torch.float32):
    """Build an ir_eval codegen computing C[i] = A[i] + B[i] over n elements.

    Returns a codegen object; invoke it as cg.call([tA, tB, tC]).
    """
    lhs = te.BufHandle('A', [n], dtype)
    rhs = te.BufHandle('B', [n], dtype)

    def elementwise_sum(i):
        return lhs.load([i]) + rhs.load([i])

    out = te.Compute('C', [n], elementwise_sum)
    nest = te.LoopNest([out])
    nest.prepare_for_codegen()
    simplified = te.simplify(nest.root_stmt())
    return te.construct_codegen('ir_eval', simplified, [lhs, rhs, out])
def construct_adder(n: int, dtype=te.Dtype.Float):
    """Build an ir_eval codegen computing C[i] = A[i] + B[i] (Placeholder API).

    Returns a codegen object; invoke it as cg.call([tA, tB, tC]).
    """
    extent = te.ExprHandle.int(n)
    lhs = te.Placeholder('A', dtype, [extent])
    rhs = te.Placeholder('B', dtype, [extent])

    def elementwise_sum(i):
        return lhs.load([i]) + rhs.load([i])

    out = te.Compute('C', [te.DimArg(extent, 'i')], elementwise_sum)
    nest = te.LoopNest([out])
    nest.prepare_for_codegen()
    simplified = te.simplify(nest.root_stmt())
    return te.construct_codegen('ir_eval', simplified, [lhs, rhs, out])
def cat_lower(name, out_shape, inp_shapes, args):
    """Lower a concatenation along `dim` as nested ifThenElse selects.

    args[0] is the list of input buffers, args[1] the concat dimension.
    For an output index i along `dim`, the generated expression selects the
    input whose slice covers i and loads it at i minus that input's offset.
    """
    tensors = args[0]
    dim = args[1]
    lengths = [i[0][dim] for i in inp_shapes[0]]

    def f(*idxs):
        idxs = list(idxs)
        sm = lengths[0]  # running offset: total length of inputs handled so far
        load = tensors[0].load(idxs)
        for length, tensor in list(zip(lengths, tensors))[1:]:
            new_idxs = idxs[:]
            new_idxs[dim] -= to_expr(sm)
            load = te.ifThenElse(idxs[dim] < to_expr(sm), load,
                                 tensor.load(new_idxs))
            # BUG FIX: advance the offset past this input. Previously `sm`
            # was never updated (`length` was unused), so with three or more
            # inputs every tensor after the second was compared and offset
            # against the first input's boundary only, producing wrong loads.
            sm += length
        return load

    return te.Compute(name, get_dim_args(out_shape), f)
def fn_lower(name, out_shape, inp_shapes, args):
    """Generic lowering: build a Compute whose body is produced by `f`.

    NOTE(review): `f` is a free variable here — presumably this def is
    nested inside a factory that binds it; confirm against the full file.
    """
    return te.Compute(name, get_dim_args(out_shape), f(inp_shapes, args))
def ones_like_lower(name, out_shape, inp_shapes, args):
    """Lower a ones_like op: every output element is the constant 1.0."""
    def fill_one(*_idxs):
        # The indices are irrelevant; every element is 1.0.
        return to_expr(1.0)

    return te.Compute(name, get_dim_args(out_shape), fill_one)