def decompose(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    decompose(model, example_inputs) takes in a model, decomposes any of the
    functions in `decomposition_rules` into their constituent operations, and
    returns an `nn.Module` without any of the operations that have
    decomposition rules.
    """
    # Run it multiple times so we converge to a fixed point.
    for _ in range(5):
        model = fx.symbolic_trace(model)
        ShapeProp(model).propagate(*example_inputs)
        new_graph = fx.Graph()
        env = {}
        for node in model.graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # If the current function is in `decomposition_rules`, we use
                # `Proxy` objects to decompose the operations using the
                # decomposition rule. See
                # https://pytorch.org/docs/master/fx.html#proxy-retracing for
                # more details.
                proxy_args = map_arg(node.args, lambda n: fx.Proxy(env[n.name]))
                proxy_kwargs = map_arg(node.kwargs, lambda n: fx.Proxy(env[n.name]))
                new_node = decomposition_rules[node.target](
                    *proxy_args, **proxy_kwargs).node
                env[node.name] = new_node
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        model = fx.GraphModule(model, new_graph)
    return model

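# A minimal usage sketch for `decompose` (illustrative, not from the source:
# `relu_decomposition`, the `decomposition_rules` dict, and `M` are hypothetical
# stand-ins showing the shape of a Proxy-based rule, per the FX docs linked above).
def relu_decomposition(x):
    # relu(x) expressed with primitive ops; when called with `fx.Proxy` args
    # inside `decompose`, these ops are recorded into the new graph.
    return (x > 0) * x

decomposition_rules = {torch.nn.functional.relu: relu_decomposition}

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1

decomposed = decompose(M(), (torch.randn(4),))
print(decomposed.code)  # the generated forward no longer calls relu
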
def use_mkl_heuristic(graph: MklSubgraph) -> bool:
    nonlocal fx_model, old_modules
    input_nodes = graph.start_nodes
    if fx_model is None:
        fx_model = graph.fx_graph.owning_module
        old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
        ShapeProp(fx_model).propagate(example_inputs)
    sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
    output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
    submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

    def benchmark(f):
        for _ in range(warmup):
            f()
        begin = time.time()
        for _ in range(iters):
            out = f()
        return time.time() - begin

    mkl_time = benchmark(lambda: [
        i.to_dense()
        for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
    ])

    reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
    no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
    return mkl_time < no_mkl_time

def test_resnet50(self):
    gm_run = symbolic_trace(resnet50())
    sample_input = torch.randn(1, 3, 224, 224)

    # run shape propagation on the traced module
    ShapeProp(gm_run).propagate(sample_input)

    gm_static = symbolic_trace(resnet50())

    for n in gm_static.graph.nodes:
        n.type = None

    g = GraphTypeChecker({}, gm_static)
    g.type_check()
    # here we are checking for consistency with fully dynamic nodes
    for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
        assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))

    # here we give the same input as at runtime
    gm_static_with_types = symbolic_trace(resnet50())

    # we initialize our placeholders
    for n in gm_static_with_types.graph.nodes:
        if n.op == 'placeholder':
            n.type = TensorType((1, 3, 224, 224))

    g = GraphTypeChecker({}, gm_static_with_types)
    g.type_check()
    for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
        assert n1.type == TensorType(n2.meta['tensor_meta'].shape)

def _test_const_fold_tensor_meta(self, requires_grad):
    """
    Verify tensor_meta is handled correctly.
    """

    class ConstFoldTestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
            self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

        def forward(self, x, y):
            a = self.attr_1 + self.attr_1
            x = x - a
            return x * y + self.attr_2

    mod = ConstFoldTestModule()
    gm = torch.fx.symbolic_trace(mod)
    in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
    ShapeProp(gm).propagate(in_x, in_y)
    mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
    self._verify_const_fold_mod(mod_folded)

    mod_folded.run_folding()

    for n in mod_folded.graph.nodes:
        if n.op == "get_attr":
            attr = self._get_attr(n)
            self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

    # Now run both folded and non-folded to check results equal.
    base_result = mod(in_x, in_y)
    fold_result = mod_folded(in_x, in_y)
    self.assertTrue(torch.equal(fold_result, base_result))

def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...],
         example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """vmap
    Given a model with inputs, vmap will return a function that works on
    batched versions of those inputs. Which inputs will be batched is
    determined by in_axes. In addition, as vmap requires shape (actually rank)
    information, we will pass in example_args (example inputs for the original
    module).
    """
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Here we run a shape propagation pass in order to annotate the graph with
    # shape information.
    ShapeProp(fx_model).propagate(*example_args)
    # As vmap rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    new_graph: fx.Graph = fx.Graph()

    # We will create an environment mapping each old node (by name) to its
    # corresponding newly-created node.
    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    env = {}
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # If the node is an input placeholder, we simply copy it over and
            # annotate it with the batch dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
            new_node.meta = node.meta
            env[node.name] = new_node
        elif node.op == 'output':
            new_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # If any of the inputs to the function has a new batch dimension,
            # we will need to use our batching rules. Otherwise, we will simply
            # copy the node over.
            if any([
                    x.bdim is not None for x in new_args
                    if isinstance(x, fx.Node)
            ]):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
            new_node.meta = node.meta
            env[node.name] = new_node
        else:
            raise RuntimeError("Not yet implemented")

    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res

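# A minimal calling-convention sketch for `vmap`. It assumes the undefined
# `gen_batching_rule_function` knows how to batch the ops traced below
# (torch.sin and multiplication); `ScaledSin` is a hypothetical module, not
# from the source.
class ScaledSin(torch.nn.Module):
    def forward(self, x, scale):
        return torch.sin(x) * scale

example_args = (torch.randn(3), torch.randn(3))        # unbatched inputs
batched = vmap(ScaledSin(), in_axes=(0, None), example_args=example_args)
out = batched(torch.randn(8, 3), torch.randn(3))       # dim 0 of `x` is the batch
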
def grad(model: torch.nn.Module, example_inps: Tuple[Any, ...],
         get_value=True) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inps)

    # As grad rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    val_map = {}
    new_graph: fx.Graph = fx.Graph()
    orig_output = new_graph.graph_copy(fx_model.graph, val_map)

    def shape_proxy(node):
        proxy = fx.Proxy(val_map[node])
        proxy.shape = node.meta['shape']
        proxy.dim = lambda: len(proxy.shape)
        return proxy

    inputs = []
    ones = new_graph.create_node('call_function', torch.ones, ([], ))
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'output':
            assert (len(node.args) == 1)
            val_map[node.args[0]].grad = [fx.Proxy(ones)]
        elif node.op == 'placeholder':
            inputs.append(sum(val_map[node].grad).node)
        elif node.op == 'call_function':
            g = sum(val_map[node].grad)
            new_args = [
                shape_proxy(i) if isinstance(i, fx.Node) else i
                for i in node.args
            ]
            if node.target not in vjp_map:
                raise RuntimeError("vjp not yet implemented")
            new_grads = vjp_map[node.target](g, *new_args)
            if not isinstance(new_grads, tuple):
                new_grads = (new_grads, )
            for new_g, arg in zip(new_grads, new_args):
                if isinstance(arg, fx.Proxy):
                    if not hasattr(arg.node, 'grad'):
                        arg.node.grad = []
                    arg.node.grad.append(new_g)
        elif node.op == 'call_method':
            raise RuntimeError("doesn't support methods since i'm lazy")

    if len(inputs) == 1:
        inputs = inputs[0]
    else:
        inputs = inputs[::-1]
    if get_value:
        new_graph.output((orig_output, inputs))
    else:
        new_graph.output(inputs)
    res = fx.GraphModule(fx_model, new_graph)
    res.graph.lint()
    return res

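# A minimal calling-convention sketch for `grad`. It assumes `vjp_map` (not
# defined in this snippet) has vector-Jacobian rules for the traced ops;
# `fn` is a hypothetical example function.
def fn(x):
    return torch.sin(x) * x

inp = torch.randn(3)
grad_fn = grad(fn, (inp,))      # get_value=True: output is (value, grads)
value, dx = grad_fn(inp)
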
def get_size_of_all_nodes(fx_module: GraphModule, args: List[torch.Tensor]) -> None:
    """Given an fx graph module, annotate each node with its total size
    (weights + bias + output) and its output_size (output). For a non-module
    node, the total size is the output size.
    """
    # Mark shape and dtype for each node (node.shape and node.dtype)
    ShapeProp(fx_module).propagate(*args)
    # Annotate every node (except the output node) with its size in bytes
    total_size_of_graph = 0.0
    for node in fx_module.graph.nodes:
        if node.op == "output":
            break
        node.size_bytes = get_size_of_node(fx_module, node)
    return

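# Usage sketch: after the call, every non-output node of the traced module
# carries a `size_bytes` annotation (the toy `Sequential` model is illustrative,
# not from the source).
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)
get_size_of_all_nodes(gm, [torch.randn(2, 8)])
for node in gm.graph.nodes:
    if node.op != "output":
        print(node.name, node.size_bytes)
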
def test_resnet50(self):
    gm_run = symbolic_trace(resnet50())
    sample_input = torch.randn(1, 3, 224, 224)

    # run shape propagation on the traced module
    ShapeProp(gm_run).propagate(sample_input)

    gm_static = symbolic_trace(resnet50())

    for n in gm_static.graph.nodes:
        n.type = None

    g = GraphTypeChecker({}, gm_static)
    g.type_check()
    gm_static.graph.eliminate_dead_code()
    gm_run.graph.eliminate_dead_code()
    # here we are checking for consistency with fully dynamic nodes
    for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
        assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))

    # here we give the same input as at runtime
    gm_static_with_types = symbolic_trace(resnet50())

    # we initialize our placeholders
    for n in gm_static_with_types.graph.nodes:
        if n.op == 'placeholder':
            n.type = TensorType((1, 3, 224, 224))

    g = GraphTypeChecker({}, gm_static_with_types)
    g.type_check()
    for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
        assert n1.type == TensorType(n2.meta['tensor_meta'].shape)

    # apply shape inference to the graph and check
    # that the batch size is equal across all layers
    infer_symbolic_types(gm_static)

    batch_sizes = set()
    gm_static.graph.eliminate_dead_code()
    for n in gm_static.graph.nodes:
        assert isinstance(n.type, TensorType)
        batch_sizes.add(n.type.__args__[0])
    assert (len(batch_sizes) == 1)

def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args), node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))
    return f

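# Usage sketch for this first `nnc_compile` variant (assumes the TensorExpr
# bindings `te` and the undefined `lower_function` cover the traced ops).
# Per the docstring, the caller supplies the output buffer along with the
# inputs, and the result is written into it in place. `pointwise` is a
# hypothetical example function, not from the source.
def pointwise(a, b):
    return torch.cos(a) * torch.sin(b)

inps = (torch.randn(128), torch.randn(128))
out = torch.empty(128)
compiled = nnc_compile(pointwise, inps)
compiled(list(inps) + [out])    # `out` now holds pointwise(*inps)
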
def nnc_compile(fx_model: fx.GraphModule, example_inputs,
                get_loopnest=False) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    # fx_model = fx.symbolic_trace(model)
    t = fx_model.graph.flatten_inps(*example_inputs)
    ShapeProp(fx_model).propagate(*fx_model.graph.flatten_inps(*example_inputs))

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        res = fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
        return res

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            if 'tensor_meta' not in node.meta:
                continue
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target,
                                           lookup_env(node.args), node.args)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                continue
        elif node.op == 'output':
            args = node.args
            args = pytree.tree_map(
                lambda x: list(x)
                if isinstance(x, fx.immutable_collections.immutable_list) else x,
                args)
            flat_args, _ = pytree.tree_flatten(list(args))
            te_args = lookup_env(flat_args)
            outs = (list(te_args),
                    [(i.meta['tensor_meta'].shape, i.meta['tensor_meta'].dtype)
                     for i in flat_args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            print(node.op, node.target)
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    if get_loopnest:
        return loopnest
    # loopnest.inline_intermediate_bufs(True)
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])

    if module_attrs:
        module_stuff = [
            fetch_attr(i.target).contiguous().data for i in module_attrs
        ]
    else:
        module_stuff = []

    def f(*inps, out_tensors=None):
        inps = fx_model.graph.flatten_inps(*inps)
        if out_tensors is None:
            results = [
                torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]
            ]
        else:
            results = out_tensors
        full_inps = module_stuff + list(inps) + results
        cg.call(full_inps)
        results = fx_model.graph.unflatten_outs(results)
        return results

    return f

        # Tail of an enclosing wrapper function whose definition is not part
        # of this snippet:
        compiled_f = nnc_compile(fx_model, args, get_loopnest=True)
        return compiled_f

    return wrapped


################################
# Example usage and Benchmarking
################################


def bench(f, warmup=3, iters=1000):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)


if __name__ == '__main__':
    def f(a, b):
        return (torch.cos(a) * torch.sin(b))[:2000]

    mod = fx.symbolic_trace(f)
    inps = (torch.randn(5000), torch.randn(5000))
    ShapeProp(mod).propagate(*inps)
    cg = nnc_compile(mod, inps)
    bench(lambda: cg(*inps))
    bench(lambda: f(*inps))