def test_dynamic_shape(self):
    with kernel_arena_scope():
        dN = te.VarHandle("n", te.Dtype.Int)
        A = te.Placeholder('A', te.Dtype.Double, [dN])
        B = te.Placeholder('B', te.Dtype.Double, [dN])

        def compute(i):
            return A.load([i]) - B.load([i])

        C = te.Compute('C', [te.DimArg(dN, 'i')], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()
        stmt = te.simplify(loopnest.root_stmt())

        cg = te.construct_codegen('ir_eval', stmt, [A, B, C, dN])

        def test_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            cg.call([tA, tB, tC, n])
            torch.testing.assert_allclose(tA - tB, tC)

        test_with_shape(8)
        test_with_shape(31)
def test_dynamic_shape(self):
    dN = te.VarHandle(torch.int32)
    A = te.BufHandle(torch.float64)
    B = te.BufHandle(torch.float64)

    def compute(i):
        return A.load(i) - B.load(i)

    C = te.Compute('C', [dN], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()

    cg = te.construct_codegen('ir_eval', loopnest.simplify(), [A, B, C, dN])

    def test_with_shape(n):
        tA = torch.randn(n, dtype=torch.double)
        tB = torch.randn(n, dtype=torch.double)
        tC = torch.empty(n, dtype=torch.double)
        cg.call([tA, tB, tC, n])
        torch.testing.assert_close(tA - tB, tC)

    test_with_shape(8)
    test_with_shape(31)
def test_dynamic_shape_2d(self):
    dN = te.VarHandle(torch.int32)
    dM = te.VarHandle(torch.int32)
    A = te.BufHandle([dN, dM], torch.float64)
    B = te.BufHandle([dN, dM], torch.float64)

    def compute(i, j):
        return A.load([i, j]) - B.load([i, j])

    C = te.Compute("C", [dN, dM], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()

    cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])

    def test_with_shape(n, m):
        tA = torch.randn(n, m, dtype=torch.double)
        tB = torch.randn(n, m, dtype=torch.double)
        tC = torch.empty(n, m, dtype=torch.double)
        cg.call([tA, tB, tC, n, m])
        torch.testing.assert_close(tA - tB, tC)

    test_with_shape(2, 4)
    test_with_shape(5, 3)
def construct_te_fn(op, n: int, dtype=torch.float32):
    A = torch._C._te.BufHandle("A", [n], dtype)

    def compute(i):
        return op(A.load([i]))

    C = te.Compute("C", [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen("ir_eval", stmt, [A, C])
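# A minimal usage sketch for construct_te_fn above (not part of the original
# snippets). It assumes `op` maps a te.ExprHandle to a te.ExprHandle -- here a
# squaring lambda built from the ExprHandle arithmetic operators -- and that,
# as in the tests above, the 'ir_eval' codegen accepts tensors directly.
def _example_construct_te_fn():
    n = 16
    cg = construct_te_fn(lambda x: x * x, n)  # codegen over buffers [A, C]
    tA = torch.randn(n)   # input buffer A
    tC = torch.empty(n)   # output buffer C
    cg.call([tA, tC])     # run the compiled kernel
    torch.testing.assert_close(tA * tA, tC)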
def test_alloc_in_loop(self):
    a, tmp, b = [
        te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
    ]
    body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
    for _ in range(4):
        i = te.VarHandle("i", torch.int32)
        body = te.For.make(i, 0, 100, body)
    nest = te.LoopNest(body, [b])
    nest.prepare_for_codegen()
    f = te.construct_codegen("llvm", nest.simplify(), [a, b])
    ta, tb = [torch.ones(1) for _ in range(2)]
    f.call([ta.data_ptr(), tb.data_ptr()])
def construct_adder(n: int, dtype=torch.float32):
    A = te.BufHandle('A', [n], dtype)
    B = te.BufHandle('B', [n], dtype)

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute('C', [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen('ir_eval', stmt, [A, B, C])
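# A short usage sketch for the BufHandle-based construct_adder above (not part
# of the original snippets); it mirrors how the tests above pass tensors
# straight to cg.call and checks the result against eager addition.
def _example_construct_adder():
    n = 64
    cg = construct_adder(n)
    tA = torch.randn(n)
    tB = torch.randn(n)
    tC = torch.empty(n)
    cg.call([tA, tB, tC])
    torch.testing.assert_close(tA + tB, tC)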
def construct_adder(n: int, dtype=te.Dtype.Float):
    dN = te.ExprHandle.int(n)
    A = te.Placeholder('A', dtype, [dN])
    B = te.Placeholder('B', dtype, [dN])

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen('ir_eval', stmt, [A, B, C])
def test_alloc_in_loop(self):
    a, tmp, b = [
        te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
        for name in ["a", "tmp", "b"]
    ]
    t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
    body = te.Block([
        tmp.store([t0], a.load([t0])),
        b.store([t0], tmp.load([t0])),
    ])
    for _ in range(4):
        i = te.VarHandle("i", te.Dtype.Int)
        body = te.For.make(i, t0, t100, body)
    nest = te.LoopNest(body, [b.data()])
    nest.prepare_for_codegen()
    f = te.construct_codegen("llvm", nest.simplify(), [a, b])
    ta, tb = [torch.ones(1) for _ in range(2)]
    f.call([ta.data_ptr(), tb.data_ptr()])
def test_external_calls(self):
    dtype = torch.float32

    A = te.BufHandle('A', [1, 4], dtype)
    B = te.BufHandle('B', [4, 1], dtype)
    C = te.BufHandle('C', [1, 1], dtype)

    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    codegen = te.construct_codegen('ir_eval', s, [A, B, C])

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    torch.testing.assert_close(torch.matmul(tA, tB), tC)
def test_external_calls(self):
    dtype = torch.float32

    ONE = te.ExprHandle.int(1)
    FOUR = te.ExprHandle.int(4)
    A = te.BufHandle('A', [ONE, FOUR], dtype)
    B = te.BufHandle('B', [FOUR, ONE], dtype)
    C = te.BufHandle('C', [ONE, ONE], dtype)

    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    torch.testing.assert_close(torch.matmul(tA, tB), tC)
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args), node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))
    return f
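# A hedged usage sketch for the Placeholder-based nnc_compile above (not part
# of the original snippets). It assumes `lower_function` can lower the traced
# op (torch.relu here) and that the module holds no parameters; per the
# docstring, the caller supplies the output tensor alongside the inputs.
def _example_nnc_compile_v1():
    class Relu(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    x = torch.randn(4, dtype=torch.float)
    out = torch.empty(4, dtype=torch.float)
    f = nnc_compile(Relu(), (x,))
    f([x, out])  # inputs followed by the output buffer
    torch.testing.assert_close(torch.relu(x), out)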
def nnc_compile(fx_model: fx.GraphModule, example_inputs, get_loopnest=False) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    # fx_model = fx.symbolic_trace(model)
    t = fx_model.graph.flatten_inps(*example_inputs)
    ShapeProp(fx_model).propagate(*fx_model.graph.flatten_inps(*example_inputs))

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        res = fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
        return res

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            if 'tensor_meta' not in node.meta:
                continue
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target, lookup_env(node.args), node.args)
                # if isinstance(stmt, list)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                continue
        elif node.op == 'output':
            args = node.args
            args = pytree.tree_map(
                lambda x: list(x)
                if isinstance(x, fx.immutable_collections.immutable_list) else x,
                args)
            flat_args, _ = pytree.tree_flatten(list(args))
            te_args = lookup_env(flat_args)
            outs = (list(te_args),
                    [(i.meta['tensor_meta'].shape, i.meta['tensor_meta'].dtype)
                     for i in flat_args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            print(node.op, node.target)
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    if get_loopnest:
        return loopnest
    # loopnest.inline_intermediate_bufs(True)
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])
    if module_attrs:
        module_stuff = [
            fetch_attr(i.target).contiguous().data for i in module_attrs
        ]
    else:
        module_stuff = []

    def f(*inps, out_tensors=None):
        inps = fx_model.graph.flatten_inps(*inps)
        if out_tensors is None:
            results = [torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]]
            # results = alloc_results
        else:
            results = out_tensors
        full_inps = module_stuff + list(inps) + results
        cg.call(full_inps)
        results = fx_model.graph.unflatten_outs(results)
        return results
    return f
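# A hedged usage sketch for the GraphModule-based nnc_compile above (not part
# of the original snippets). It assumes `fx_model` was traced elsewhere, that
# its graph exposes the flatten_inps/unflatten_outs helpers used in the body,
# and that `lower_function` can lower every op in the graph.
def _example_nnc_compile_v2(fx_model: fx.GraphModule, example_inputs):
    compiled = nnc_compile(fx_model, example_inputs)
    # Run with freshly allocated output tensors...
    out = compiled(*example_inputs)
    # ...or inspect the loop nest before code generation.
    loopnest = nnc_compile(fx_model, example_inputs, get_loopnest=True)
    return out, loopnest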