def test_external_calls(self):
    dtype = torch.float32

    ONE = te.ExprHandle.int(1)
    FOUR = te.ExprHandle.int(4)
    A = te.BufHandle('A', [ONE, FOUR], dtype)
    B = te.BufHandle('B', [FOUR, ONE], dtype)
    C = te.BufHandle('C', [ONE, ONE], dtype)

    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])

    tA = torch.ones(1, 4)
    tB = torch.ones(4, 1)
    tC = torch.empty(1, 1)
    codegen.call([tA, tB, tC])
    torch.testing.assert_close(torch.matmul(tA, tB), tC)
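The same `te.ExternalCall` pattern generalizes beyond the fixed 1x4 / 4x1 shapes in the test. Below is a minimal sketch, not part of the test suite, that wraps the pattern in a small helper; it only reuses the calls shown above, and the `nnc_matmul` name, the `torch._C._te` import, and the shape handling are assumptions added for illustration.

import torch
import torch._C._te as te  # assumed binding for the `te` module used in the test


def nnc_matmul(tA: torch.Tensor, tB: torch.Tensor) -> torch.Tensor:
    """Multiply two float32 matrices through NNC's external-call mechanism (sketch)."""
    m, k = tA.shape
    k2, n = tB.shape
    assert k == k2, "inner dimensions must match"

    dims = lambda *sizes: [te.ExprHandle.int(s) for s in sizes]
    A = te.BufHandle('A', dims(m, k), torch.float32)
    B = te.BufHandle('B', dims(k, n), torch.float32)
    C = te.BufHandle('C', dims(m, n), torch.float32)

    # The external call dispatches back to ATen's matmul kernel at runtime.
    s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
    loopnest = te.LoopNest(s, [C])
    loopnest.prepare_for_codegen()
    cg = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])

    tC = torch.empty(m, n)
    cg.call([tA, tB, tC])
    return tC

Calling `nnc_matmul(torch.ones(2, 3), torch.ones(3, 5))` should match `torch.matmul` on the same inputs.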
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args as
    `model.forward`, with an extra argument corresponding to where the output is
    stored. This function takes the inputs (which must be PyTorch tensors with
    the same shapes as example_inputs), and passes them to an NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`s, which represent the outputs
    # of NNC computations.
    env = {}

    def get_te_shapes(node):
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of an NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args), node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node),
                                            get_te_shapes(node))
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))

    return f
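To make the calling convention concrete, here is a hedged usage sketch. The `DoubleAdd` module, the shapes, and the assumption that its traced ops are covered by `lower_function` are illustrative only; what the sketch shows is that the compiled callable expects the inputs followed by a preallocated output tensor, as the docstring describes.

import torch

# Hypothetical module for illustration; whether its ops lower successfully
# depends on what `lower_function` supports.
class DoubleAdd(torch.nn.Module):
    def forward(self, a, b):
        return a + b

example_inputs = (torch.randn(4, 4), torch.randn(4, 4))
f = nnc_compile(DoubleAdd(), example_inputs)

out = torch.empty(4, 4)       # destination buffer for the result
f([*example_inputs, out])     # inputs plus the output buffer, in one list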
def nnc_compile(fx_model: fx.GraphModule, example_inputs,
                get_loopnest=False) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args as
    `model.forward`, with an extra argument corresponding to where the output is
    stored. This function takes the inputs (which must be PyTorch tensors with
    the same shapes as example_inputs), and passes them to an NNC executor.
    """
    # fx_model = fx.symbolic_trace(model)
    t = fx_model.graph.flatten_inps(*example_inputs)
    ShapeProp(fx_model).propagate(*fx_model.graph.flatten_inps(*example_inputs))

    # This env maps from nodes to `te.ExprHandle`s, which represent the outputs
    # of NNC computations.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        res = fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
        return res

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            if 'tensor_meta' not in node.meta:
                continue
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of an NNC computation), and
            # put it in our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target,
                                           lookup_env(node.args), node.args)
                # if isinstance(stmt, list)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                continue
        elif node.op == 'output':
            args = node.args
            args = pytree.tree_map(
                lambda x: list(x)
                if isinstance(x, fx.immutable_collections.immutable_list) else x,
                args)
            flat_args, _ = pytree.tree_flatten(list(args))
            te_args = lookup_env(flat_args)
            outs = (list(te_args),
                    [(i.meta['tensor_meta'].shape, i.meta['tensor_meta'].dtype)
                     for i in flat_args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            print(node.op, node.target)
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    if get_loopnest:
        return loopnest
    # loopnest.inline_intermediate_bufs(True)
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])

    if module_attrs:
        module_stuff = [
            fetch_attr(i.target).contiguous().data for i in module_attrs
        ]
    else:
        module_stuff = []

    def f(*inps, out_tensors=None):
        inps = fx_model.graph.flatten_inps(*inps)
        if out_tensors is None:
            results = [
                torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]
            ]
            # results = alloc_results
        else:
            results = out_tensors
        full_inps = module_stuff + list(inps) + results
        cg.call(full_inps)
        results = fx_model.graph.unflatten_outs(results)
        return results
        if len(results) == 1:
            return results[0]
        return results

    return f
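For comparison, a hedged usage sketch of this `fx.GraphModule` version. The `DoubleAdd` module, the shapes, the op coverage of `lower_function`, and the availability of the `flatten_inps`/`unflatten_outs` helpers on the graph (they are not part of stock `torch.fx`, so the modified fx used here must provide them) are all assumptions.

import torch
import torch.fx as fx

# DoubleAdd as in the earlier sketch; the trace must only contain ops that
# `lower_function` can handle.
fx_model = fx.symbolic_trace(DoubleAdd())
example_inputs = (torch.randn(4, 4), torch.randn(4, 4))

# Inspect the generated loop nest without compiling.
loopnest = nnc_compile(fx_model, example_inputs, get_loopnest=True)
print(loopnest)

# Compile and run; outputs are allocated automatically unless `out_tensors`
# is passed explicitly.
f = nnc_compile(fx_model, example_inputs)
res = f(*example_inputs)

Unlike the earlier version, the compiled callable takes the inputs as positional arguments and returns the result, only writing into caller-provided storage when `out_tensors` is given.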