def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):
    """Split a joint forward/backward graph into two GraphModules.

    The forward module maps the primal inputs to the forward outputs plus
    ``saved_values`` (values stashed for the backward pass); the backward
    module maps ``saved_values`` + tangent inputs to the gradient outputs.
    ``saved_values`` is pruned *in place*: entries whose placeholder is never
    read by the backward graph are dropped, and both graphs are then
    re-extracted against the pruned list.
    """
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module)
    joint_graph = joint_module.graph
    primal_inputs = [n for n in joint_graph.nodes if _is_primal(n)]
    tangent_inputs = [n for n in joint_graph.nodes if _is_tangent(n)]

    def extract_pair():
        # Construct forward and backward graphs against the current
        # saved_values list.
        fwd = _extract_graph_with_inputs_outputs(joint_graph, primal_inputs,
                                                 fwd_outputs + saved_values)
        bwd = _extract_graph_with_inputs_outputs(
            joint_graph, saved_values + tangent_inputs, bwd_outputs)
        return fwd, bwd

    fwd_graph, bwd_graph = extract_pair()

    # Filter out saved values that don't actually end up being used by the
    # backwards pass (their placeholders have no users).
    for node in bwd_graph.nodes:
        if node.op != 'placeholder' or node.users:
            continue
        stale = next((sv for sv in saved_values if sv.name == node.name),
                     None)
        if stale is not None:
            saved_values.remove(stale)

    # Now, re-generate the fwd/bwd graphs with the pruned saved_values.
    # NB: This might increase compilation time, but I doubt it matters
    fwd_graph, bwd_graph = extract_pair()
    return (fx.GraphModule(joint_module, fwd_graph),
            fx.GraphModule(joint_module, bwd_graph))
def graph_fails(graph, inps):
    # Return whether `module_fails` still reproduces the failure on a
    # candidate (usually reduced) graph.
    # NOTE(review): this is a closure — `num_queries`, `fail_f` and
    # `module_fails` come from the enclosing minimizer scope (not visible
    # in this chunk).
    nonlocal num_queries
    # Deep-copy so building/linting the module can't mutate the caller's graph.
    graph = copy.deepcopy(graph)
    num_queries += 1
    mod = fx.GraphModule(fail_f, graph)
    # Catch malformed graphs before bothering to run them.
    mod.graph.lint()
    return module_fails(mod, inps)
def decompose(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    decompose(model, example_inputs) takes in a model and rewrites every
    call to a function with an entry in `decomposition_rules` into its
    constituent operations, returning an `nn.Module` without any of the
    operations with decomposition rules.
    """
    # Repeat the rewrite so rules whose expansions themselves contain
    # decomposable ops converge to a fixed point.
    for _ in range(5):
        model = fx.symbolic_trace(model)
        ShapeProp(model).propagate(*example_inputs)
        decomposed_graph = fx.Graph()
        env = {}

        def to_proxy(n):
            return fx.Proxy(env[n.name])

        for node in model.graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # Re-trace the decomposition rule through `Proxy` objects so
                # its constituent ops are recorded in place of this node. See
                # https://pytorch.org/docs/master/fx.html#proxy-retracing for
                # more details.
                proxy_args = map_arg(node.args, to_proxy)
                proxy_kwargs = map_arg(node.kwargs, to_proxy)
                env[node.name] = decomposition_rules[node.target](
                    *proxy_args, **proxy_kwargs).node
            else:
                env[node.name] = decomposed_graph.node_copy(
                    node, lambda n: env[n.name])
        model = fx.GraphModule(model, decomposed_graph)
    return model
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes.
    Will deepcopy your model by default, but can modify the model inplace
    as well.
    """
    conv_bn_pairs = [(nn.Conv1d, nn.BatchNorm1d),
                     (nn.Conv2d, nn.BatchNorm2d),
                     (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    fused_graph = copy.deepcopy(fx_model.graph)

    for pattern in conv_bn_pairs:
        for bn_node in fused_graph.nodes:
            if not matches_module_pattern(pattern, bn_node, modules):
                continue
            conv_node = bn_node.args[0]
            # The conv output must feed only the BN; otherwise fusing would
            # change what the other consumers see.
            if len(conv_node.users) > 1:
                continue
            bn = modules[bn_node.target]
            # Without tracked running stats there is nothing to fold in.
            if not bn.track_running_stats:
                continue
            fused_conv = fuse_conv_bn_eval(modules[conv_node.target], bn)
            replace_node_module(conv_node, modules, fused_conv)
            # Re-route every consumer of the BN straight to the fused conv,
            # then drop the now-dead BN node.
            bn_node.replace_all_uses_with(conv_node)
            fused_graph.erase_node(bn_node)
    return fx.GraphModule(fx_model, fused_graph)
def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...],
         example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """vmap

    Given a model with inputs, vmap returns a module that works on batched
    versions of those inputs: `in_axes` gives, per input, the axis carrying
    the batch (None for unbatched inputs). Since vmap requires shape
    (actually rank) information, `example_args` (example inputs for the
    original module) must be supplied.
    """
    axis_iter = iter(in_axes)
    traced = fx.symbolic_trace(model)
    # Annotate the graph with shape information; the batching rules need it.
    ShapeProp(traced).propagate(*example_args)
    # As vmap rewrites the whole graph, build an entirely new graph and map
    # old nodes to their replacements.
    batched_graph: fx.Graph = fx.Graph()
    env = {}

    def resolve(arg):
        return fx.node.map_aggregate(
            arg, lambda a: env[a.name] if isinstance(a, fx.Node) else a)

    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            # Copy the input over, tagging it with its batch dimension
            # from `in_axes`.
            copied = batched_graph.placeholder(node.name)
            copied.bdim = next(axis_iter)
            copied.meta = node.meta
            env[node.name] = copied
        elif node.op == 'output':
            batched_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            resolved_args = resolve(node.args)
            # If any input carries a batch dimension we must apply a
            # batching rule; otherwise the node is copied verbatim.
            needs_rule = any(a.bdim is not None for a in resolved_args
                             if isinstance(a, fx.Node))
            if needs_rule:
                copied = gen_batching_rule_function(node.target,
                                                    *resolved_args)
            else:
                copied = batched_graph.node_copy(node, lambda a: env[a.name])
                copied.bdim = None
            copied.meta = node.meta
            env[node.name] = copied
        else:
            raise RuntimeError("Not yet implemented")

    res = fx.GraphModule(traced, batched_graph)
    print(res.code)
    res.graph.lint()
    return res
def grad(model: torch.nn.Module, example_inps: Tuple[Any, ...],
         get_value=True) -> torch.nn.Module:
    """Build a GraphModule computing gradients of `model`'s output w.r.t. its inputs.

    Copies the forward graph, then walks it in reverse accumulating gradient
    proxies via `vjp_map`. If `get_value` is True the returned module yields
    `(original_output, grads)`, otherwise just the grads.
    """
    fx_model = fx.symbolic_trace(model)
    # Shape metadata is consumed by the vjp rules through shape_proxy below.
    ShapeProp(fx_model).propagate(*example_inps)
    # Copy the forward graph into a new graph; the backward computation is
    # appended to the same graph.
    val_map = {}
    new_graph: fx.Graph = fx.Graph()
    orig_output = new_graph.graph_copy(fx_model.graph, val_map)

    def shape_proxy(node):
        # Wrap the copied node in a Proxy that also exposes shape/dim(),
        # which vjp rules may query.
        proxy = fx.Proxy(val_map[node])
        proxy.shape = node.meta['shape']
        proxy.dim = lambda: len(proxy.shape)
        return proxy

    inputs = []
    # Seed gradient: d(output)/d(output) == ones.
    ones = new_graph.create_node('call_function', torch.ones, ([], ))
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'output':
            assert (len(node.args) == 1)
            val_map[node.args[0]].grad = [fx.Proxy(ones)]
        elif node.op == 'placeholder':
            # Gradient w.r.t. this input is the sum of all contributions
            # accumulated on its node.
            inputs.append(sum(val_map[node].grad).node)
        elif node.op == 'call_function':
            g = sum(val_map[node].grad)
            new_args = [
                shape_proxy(i) if isinstance(i, fx.Node) else i
                for i in node.args
            ]
            if node.target not in vjp_map:
                raise RuntimeError("vjp not yet implemented")
            new_grads = vjp_map[node.target](g, *new_args)
            if not isinstance(new_grads, tuple):
                new_grads = (new_grads, )
            # Accumulate each produced gradient onto the argument node it
            # belongs to (grad lists are created lazily).
            for new_g, arg in zip(new_grads, new_args):
                if isinstance(arg, fx.Proxy):
                    if not hasattr(arg.node, 'grad'):
                        arg.node.grad = []
                    arg.node.grad.append(new_g)
        elif node.op == 'call_method':
            raise RuntimeError("doesn't support methods since i'm lazy")
    # Single input: return its grad directly; otherwise restore input order
    # (the reversed walk collected them back-to-front).
    if len(inputs) == 1:
        inputs = inputs[0]
    else:
        inputs = inputs[::-1]
    if get_value:
        new_graph.output((orig_output, inputs))
    else:
        new_graph.output(inputs)
    res = fx.GraphModule(fx_model, new_graph)
    res.graph.lint()
    return res
def test_remove_duplicate_output_args(self):
    """remove_duplicate_output_args must dedup Sub's duplicated outputs and
    rewrite the parent's getitem indices to point at the surviving output.

    NOTE(review): the expected-graph strings below are whitespace-sensitive;
    their layout is reconstructed to match torch.fx's graph printer — verify
    against the fx version in use.
    """
    class Sub(nn.Module):
        def forward(self, x):
            # Returns the same value twice — the duplication under test.
            return (x, x)

    class Top(nn.Module):
        def __init__(self):
            super().__init__()
            self.a = Sub()

        def forward(self, x):
            a_res = self.a(x)
            return a_res[0] + a_res[1]

    class Tracer(fx.Tracer):
        def is_leaf_module(self, m, qn):
            if isinstance(m, Sub):  # don't trace into
                return True
            return False

    top = Top()
    ttop = fx.GraphModule(top, Tracer().trace(top), "top")
    ttop.a = fx.symbolic_trace(ttop.a)

    name_to_processed_subnet = dedup.remove_duplicate_output_args(
        ttop, ["a"])

    ttop(1)  # run inference should work

    processed_a = name_to_processed_subnet["a"]
    *_, a_output = processed_a.module.graph.nodes
    a_output: fx.Node

    # Parent graph: both getitems must now index output 0.
    ttop_graph_actual = str(ttop.graph).strip()
    ttop_graph_expected = """
graph():
    %x : [#users=1] = placeholder[target=x]
    %a : [#users=2] = call_module[target=a](args = (%x,), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {})
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %getitem_1), kwargs = {})
    return add
""".strip()
    assert (ttop_graph_expected == ttop_graph_actual
            ), f"Unexpected ttop graph: {ttop_graph_actual}"

    # Submodule graph: the duplicate output has been removed.
    ttop_a_graph_actual = str(ttop.a.graph).strip()
    ttop_a_graph_expected = """
graph():
    %x : [#users=1] = placeholder[target=x]
    return (x,)
""".strip()
    assert (ttop_a_graph_expected == ttop_a_graph_actual
            ), f"Unexpected ttop.a graph: {ttop_a_graph_actual}"
def draw_graph(traced: torch.fx.GraphModule,
               fname: str,
               figname: str = "fx_graph",
               clear_meta=True):
    """Render a traced GraphModule to an image file via FxGraphDrawer.

    The output format is taken from `fname`'s extension (default ".svg").
    With `clear_meta`, node metadata is stripped from a deep copy so the
    caller's module is left untouched.
    """
    if clear_meta:
        # Draw a copy so clearing meta doesn't mutate the caller's module.
        graph_copy = copy.deepcopy(traced.graph)
        traced = fx.GraphModule(traced, graph_copy)
        for node in traced.graph.nodes:
            node.meta = {}
    base, ext = os.path.splitext(fname)
    ext = ext or ".svg"
    print(f"Writing FX graph to file: {base}{ext}")
    drawer = graph_drawer.FxGraphDrawer(traced, figname)
    dot_graph = drawer.get_main_dot_graph()
    # pydot exposes one writer per format, e.g. write_svg / write_png.
    write = getattr(dot_graph, "write_" + ext.lstrip("."))
    write(f"{base}{ext}")
def truncate(model, k):
    """Return a GraphModule computing only the first `k` nodes of `model`.

    The model is symbolically traced, the first `k` nodes (placeholders
    included) are copied into a fresh graph, and the value of the k-th
    node becomes the output.
    """
    traced = fx.symbolic_trace(model)
    pruned = fx.Graph()
    env = {}
    for position, node in enumerate(list(traced.graph.nodes), start=1):
        env[node.name] = pruned.node_copy(node, lambda n: env[n.name])
        if position == k:
            # The k-th node's value is the truncated graph's output.
            pruned.output(env[node.name])
            break
    return fx.GraphModule(traced, pruned)
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node],
                     inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Given lists of nodes from an existing graph that represent a subgraph,
    returns a submodule that executes that subgraph.
    """
    subgraph = fx.Graph()
    # One placeholder per requested input, keyed by the original node.
    env: Dict[fx.Node, fx.Node] = {
        inp: subgraph.placeholder(inp.name)
        for inp in inputs
    }
    for node in nodes:
        env[node] = subgraph.node_copy(node, lambda used: env[used])
    subgraph.output([env[out] for out in outputs])
    subgraph.lint()
    return fx.GraphModule(orig_module, subgraph)
def profile_function(name, f, inp):
    """Benchmark `f` before and after CSE and print a CSV-ish summary line.

    Columns: name, time before, time after, node-count decrease, original
    node count.
    """
    fx_g = make_fx(f)(inp)
    new_g = fx.GraphModule(fx_g, fx_graph_cse(fx_g.graph))

    # do not benchmark against the scripted version because script already
    # does some CSE, which would mask the effect being measured.
    # script_f = torch.jit.script(fx_g)
    # script_g = torch.jit.script(new_g)
    # avg_cuda_time_f = profile_it(script_f, inp)
    # avg_cuda_time_g = profile_it(script_g, inp)
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)
    num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)

    print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
def check(f, t, delta, check_val=True, graph_input=False):
    """Run CSE over `f`'s graph and validate the node-count change and values.

    delta: expected decrease in node count; -1 means "any non-increase".
    check_val: also compare numerical outputs before/after CSE.
    graph_input: `f` is already an fx GraphModule rather than a callable.
    """
    fx_g = f if graph_input else make_fx(f)(t)
    new_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_graph)

    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    # CSE may only shrink the graph (or leave it unchanged); with a concrete
    # delta the exact decrease is pinned.
    if delta == -1:
        assert old_num_nodes >= new_num_nodes, (
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}")
    else:
        assert old_num_nodes == new_num_nodes + delta, (
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
        )

    # CSE must be idempotent: a second pass cannot reduce more nodes.
    pass_2_graph = fx_graph_cse(new_graph)
    pass_2_num_nodes = len(pass_2_graph.nodes)
    assert pass_2_num_nodes == new_num_nodes, (
        f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
    )

    # check correctness
    if not check_val:
        return
    true_result = fx_g(t)
    our_result = new_g(t)
    if true_result is None:
        # both return None
        assert our_result is None, f"true result is None, CSE result is {our_result}"
    else:
        # results returned are the same
        assert torch.all(true_result == our_result), (
            f"results are different {true_result}, {our_result}"
        )
def invert(model: torch.nn.Module) -> torch.nn.Module:
    """Build a module computing the inverse of `model`.

    Walks the traced graph backwards: the original output becomes a
    placeholder, each call_function is replaced by its inverse from
    `invert_mapping`, and the original placeholder becomes the output.
    """
    traced = fx.symbolic_trace(model)
    inverse_graph = fx.Graph()
    # Maps an original node's name to the node that produces its value in
    # the inverted graph.
    env = {}
    for node in reversed(traced.graph.nodes):
        if node.op == "output":
            # The original output is the inverse's input.
            env[node.args[0].name] = inverse_graph.placeholder(node.name)
        elif node.op == "call_function":
            # Feed the previously computed value through the inverse op.
            inverse_call = inverse_graph.call_function(
                invert_mapping[node.target], (env[node.name], ))
            env[node.args[0].name] = inverse_call
        elif node.op == "placeholder":
            # The original input is the inverse's output.
            inverse_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")
    inverse_graph.lint()
    return fx.GraphModule(traced, inverse_graph)
def codegen_tensor_product(
    irreps_in1: o3.Irreps,
    in1_var: List[float],
    irreps_in2: o3.Irreps,
    in2_var: List[float],
    irreps_out: o3.Irreps,
    out_var: List[float],
    instructions: List[Instruction],
    normalization: str = 'component',
    shared_weights: bool = False,
    specialized_code: bool = True,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """Generate FX GraphModules for a tensor product.

    Returns ``(tp_forward, tp_right)``: the forward module takes inputs
    ``(x1, x2, w)``; the "right" module takes ``(x2, w)`` and produces a
    matrix to be applied to ``x1``. Each instruction contributes one einsum
    (specialized where possible), with Wigner 3j symbols registered as
    buffers. If ``optimize_einsums`` is set, the generated graphs are run
    through opt_einsum_fx and made TorchScript-compatible.
    """
    graph_out = fx.Graph()
    graph_right = fx.Graph()

    # = Function definitions =
    x1s_out = fx.Proxy(graph_out.placeholder('x1', torch.Tensor))
    x2s_out = fx.Proxy(graph_out.placeholder('x2', torch.Tensor))
    ws_out = fx.Proxy(graph_out.placeholder('w', torch.Tensor))

    x2s_right = fx.Proxy(graph_right.placeholder('x2', torch.Tensor))
    ws_right = fx.Proxy(graph_right.placeholder('w', torch.Tensor))

    # Dummy 0-d tensors used only to compute broadcast shapes symbolically.
    empty_out = fx.Proxy(
        graph_out.call_function(torch.empty, ((), ), dict(device='cpu')))
    empty_right = fx.Proxy(
        graph_right.call_function(torch.empty, ((), ), dict(device='cpu')))
    if shared_weights:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]))[0].shape
        size_right = x2s_right.shape[:-1]
    else:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]),
            empty_out.expand(ws_out.shape[:-1]))[0].shape
        size_right = torch.broadcast_tensors(
            empty_right.expand(x2s_right.shape[:-1]),
            empty_right.expand(ws_right.shape[:-1]))[0].shape

    # = Short-circut for zero dimensional =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0:
        out_out = x1s_out.new_zeros(size_out + (irreps_out.dim, ))
        out_right = x2s_right.new_zeros(size_right + (
            irreps_in1.dim,
            irreps_out.dim,
        ))

        graph_out.output(out_out.node, torch.Tensor)
        graph_right.output(out_right.node, torch.Tensor)
        # Short circut
        return (fx.GraphModule({}, graph_out, "tp_forward"),
                fx.GraphModule({}, graph_right, "tp_right"))

    # = Broadcast inputs =
    if shared_weights:
        x1s_out, x2s_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(size_out + (-1, ))
    else:
        x1s_out, x2s_out, ws_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(
                size_out + (-1, )), ws_out.broadcast_to(size_out + (-1, ))
        x2s_right, ws_right = x2s_right.broadcast_to(
            size_right + (-1, )), ws_right.broadcast_to(size_right + (-1, ))

    outsize_out = size_out + (irreps_out.dim, )
    outsize_right = size_right + (
        irreps_in1.dim,
        irreps_out.dim,
    )

    # Flatten all leading (batch-like) dimensions into one.
    x1s_out = x1s_out.reshape(-1, irreps_in1.dim)
    x2s_out = x2s_out.reshape(-1, irreps_in2.dim)
    x2s_right = x2s_right.reshape(-1, irreps_in2.dim)

    batch_out = x1s_out.shape[0]
    batch_right = x2s_right.shape[0]

    # = Determine number of weights and reshape weights ==
    weight_numel = sum(
        prod(ins.path_shape) for ins in instructions if ins.has_weight)
    if weight_numel > 0:
        ws_out = ws_out.reshape(-1, weight_numel)
        ws_right = ws_right.reshape(-1, weight_numel)
    del weight_numel

    # = book-keeping for wigners =
    w3j = []
    w3j_dict_out = dict()
    w3j_dict_right = dict()

    # = extract individual input irreps =
    # If only one input irrep, can avoid creating a view
    if len(irreps_in1) == 1:
        x1_list_out = [
            x1s_out.reshape(batch_out, irreps_in1[0].mul,
                            irreps_in1[0].ir.dim)
        ]
    else:
        x1_list_out = [
            x1s_out[:, i].reshape(batch_out, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in1.slices(), irreps_in1)
        ]

    x2_list_out = []
    x2_list_right = []
    # If only one input irrep, can avoid creating a view
    if len(irreps_in2) == 1:
        x2_list_out.append(
            x2s_out.reshape(batch_out, irreps_in2[0].mul,
                            irreps_in2[0].ir.dim))
        x2_list_right.append(
            x2s_right.reshape(batch_right, irreps_in2[0].mul,
                              irreps_in2[0].ir.dim))
    else:
        for i, mul_ir in zip(irreps_in2.slices(), irreps_in2):
            x2_list_out.append(x2s_out[:, i].reshape(batch_out, mul_ir.mul,
                                                     mul_ir.ir.dim))
            x2_list_right.append(x2s_right[:, i].reshape(
                batch_right, mul_ir.mul, mul_ir.ir.dim))

    # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension
    z = '' if shared_weights else 'z'

    # Cache of input irrep pairs whose outer products (xx) have already been computed
    xx_dict = dict()

    # Current index in the flat weight tensor
    flat_weight_index = 0

    out_list_out = []
    out_list_right = []

    for ins in instructions:
        mul_ir_in1 = irreps_in1[ins.i_in1]
        mul_ir_in2 = irreps_in2[ins.i_in2]
        mul_ir_out = irreps_out[ins.i_out]

        # Selection rules: parity must multiply through; |l1 - l2| <= l <= l1 + l2.
        assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p
        assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l
                   ) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l

        if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
            continue

        alpha = ins.path_weight * out_var[ins.i_out] / sum(
            in1_var[i.i_in1] * in2_var[i.i_in2] for i in instructions
            if i.i_out == ins.i_out)

        # Open the profiler block
        name = f"{mul_ir_in1} x {mul_ir_in2} = {mul_ir_out} {ins.connection_mode} {ins.has_weight}"
        handle_out = graph_out.call_function(
            torch.ops.profiler._record_function_enter, (name, ))
        handle_right = graph_right.call_function(
            torch.ops.profiler._record_function_enter, (name, ))

        x1_out = x1_list_out[ins.i_in1]
        x2_out = x2_list_out[ins.i_in2]
        x2_right = x2_list_right[ins.i_in2]

        # Identity matrices used by the right() specializations.
        e1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        e2_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in2.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        i1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.ir.dim, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))

        assert ins.connection_mode in [
            'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'
        ]

        # Fold the per-mode fan-in normalization into alpha.
        alpha = sqrt(
            alpha / {
                'uvw': (mul_ir_in1.mul * mul_ir_in2.mul),
                'uvu': mul_ir_in2.mul,
                'uvv': mul_ir_in1.mul,
                'uuw': mul_ir_in1.mul,
                'uuu': 1,
                'uvuv': 1,
            }[ins.connection_mode])

        if ins.has_weight:
            # Extract the weight from the flattened weight tensor
            w_out = ws_out[:, flat_weight_index:flat_weight_index +
                           prod(ins.path_shape)].reshape(
                               (() if shared_weights else
                                (-1, )) + tuple(ins.path_shape))
            w_right = ws_right[:, flat_weight_index:flat_weight_index +
                               prod(ins.path_shape)].reshape(
                                   (() if shared_weights else
                                    (-1, )) + tuple(ins.path_shape))
            flat_weight_index += prod(ins.path_shape)

        # Construct the general xx in case this instruction isn't specialized
        # If this isn't used, the dead code will get removed
        key = (ins.i_in1, ins.i_in2, ins.connection_mode[:2])
        if key not in xx_dict:
            if ins.connection_mode[:2] == 'uv':
                xx_dict[key] = torch.einsum('zui,zvj->zuvij', x1_out, x2_out)
            if ins.connection_mode[:2] == 'uu':
                xx_dict[key] = torch.einsum('zui,zuj->zuij', x1_out, x2_out)
        xx = xx_dict[key]

        # Create a proxy & request for the relevant wigner w3j
        # If not used (because of specialized code), will get removed later.
        key = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)
        if key not in w3j:
            w3j_dict_out[key] = fx.Proxy(
                graph_out.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j_dict_right[key] = fx.Proxy(
                graph_right.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j.append(key)
        w3j_out = w3j_dict_out[key]
        w3j_right = w3j_dict_right[key]

        exp = {'component': 1, 'norm': -1}[normalization]

        if ins.connection_mode == 'uvw':
            assert ins.has_weight
            if specialized_code and key == (0, 0, 0):
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zv->zw", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim),
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,zv->zuw", w_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_in1.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zvj->zwj", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                ein_right = torch.einsum(f"{z}uvw,zvi->zuwi", w_right,
                                         x2_right)
            elif specialized_code and mul_ir_in2.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zui,zv->zwi", w_out, x1_out,
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,ij,zv->zuiwj", w_right, i1_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_out.ir.l == 0:
                ein_out = torch.einsum(f"{z}uvw,zui,zvi->zw", w_out, x1_out,
                                       x2_out) / sqrt(mul_ir_in1.ir.dim)**exp
                ein_right = torch.einsum(f"{z}uvw,zvi->zuiw", w_right,
                                         x2_right) / sqrt(
                                             mul_ir_in1.ir.dim)**exp
            else:
                ein_out = torch.einsum(f"{z}uvw,ijk,zuvij->zwk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uvw,ijk,zvj->zuiwk", w_right,
                                         w3j_right, x2_right)
        if ins.connection_mode == 'uvu':
            assert mul_ir_in1.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,uw,zv->zuw", w_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuwi", w_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,uw,zv->zuiwj", w_right, i1_right,
                        e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zu", w_out,
                                           x1_out, x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuiw", w_right,
                                             e1_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                # not so useful operation because v is summed
                ein_out = torch.einsum("ijk,zuvij->zuk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwk", w3j_right,
                                         e1_right, x2_right)
        if ins.connection_mode == 'uvv':
            assert mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zv", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,vw,zv->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zvj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zvi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,vw,zv->zuiwj", w_right, i1_right,
                        e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zv", w_out,
                                           x1_out, x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zvk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,zvj->zuivk",
                                             w_right, w3j_right, x2_right)
            else:
                # not so useful operation because u is summed
                # only specialize out for this path
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zv->zv",
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zvj->zvj",
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zv->zvi", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zvi->zv", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuvij->zvk", w3j_out, xx)
                s2ones = fx.Proxy(
                    graph_right.call_function(
                        torch.ones, (mul_ir_in1.mul, ),
                        dict(device=x2_right.device.node,
                             dtype=x2_right.dtype.node)))
                ein_right = torch.einsum("u,ijk,zvj->zuivk", s2ones,
                                         w3j_right, x2_right)
        if ins.connection_mode == 'uuw':
            assert mul_ir_in1.mul == mul_ir_in2.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zu->zw", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zuj->zwj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zui,zu->zwi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uw,zui,zui->zw", w_out,
                                           x1_out, x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uw,ijk,zuij->zwk", w_out,
                                           w3j_out, xx)
                # TODO: specialize right()
                ein_right = torch.einsum(f"{z}uw,ijk,zuj->zuiwk", w_right,
                                         w3j_right, x2_right)
            else:
                # equivalent to tp(x, y, 'uuu').sum('u')
                assert mul_ir_out.mul == 1
                ein_out = torch.einsum("ijk,zuij->zk", w3j_out, xx)
                ein_right = torch.einsum("ijk,zuj->zuik", w3j_right,
                                         x2_right)
        if ins.connection_mode == 'uuu':
            assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}u,zu,zu->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,uw,zu->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.einsum(f"{z}u,zui->zui", w_out,
                                           torch.cross(x1_out, x2_out,
                                                       dim=2)) / sqrt(2)
                    # For cross product, use the general case right()
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zu,zuj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zui,zu->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,ij,uw,zu->zuiwj", w_right, i1_right,
                        e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}u,zui,zui->zu", w_out,
                                           x1_out, x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}u,ijk,zuij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zu->zu",
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "uw,zu->zuw", e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.cross(x1_out, x2_out,
                                          dim=2) * (1.0 / sqrt(2))
                    # For cross product, use the general case right()
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zuj->zuj",
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum("uw,zui->zuwi", e2_right,
                                             x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zu->zui", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "ij,uw,zu->zuiwj", i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zui->zu", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum("uw,zui->zuiw", e2_right,
                                             x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuij->zuk", w3j_out, xx)
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
        if ins.connection_mode == 'uvuv':
            assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                # TODO implement specialized code
                ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuvk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwvk",
                                         w_right, w3j_right, e1_right,
                                         x2_right)
            else:
                # TODO implement specialized code
                ein_out = torch.einsum("ijk,zuvij->zuvk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwvk", w3j_right,
                                         e1_right, x2_right)

        ein_out = alpha * ein_out
        ein_right = alpha * ein_right

        out_list_out += [ein_out.reshape(batch_out, mul_ir_out.dim)]
        out_list_right += [
            ein_right.reshape(batch_right, mul_ir_in1.dim, mul_ir_out.dim)
        ]

        # Close the profiler block
        graph_out.call_function(torch.ops.profiler._record_function_exit,
                                (handle_out, ))
        graph_right.call_function(torch.ops.profiler._record_function_exit,
                                  (handle_right, ))

        # Remove unused w3js:
        if len(w3j_out.node.users) == 0 and len(w3j_right.node.users) == 0:
            del w3j[-1]
            # The w3j nodes are reshapes, so we have to remove them from the graph
            # Although they are dead code, they try to reshape to dimensions that don't exist
            # (since the corresponding w3js are not in w3j)
            # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript.
            graph_out.erase_node(w3j_dict_out.pop(key).node)
            graph_right.erase_node(w3j_dict_right.pop(key).node)

    # = Return the result =
    out_out = [
        _sum_tensors([
            out for ins, out in zip(instructions, out_list_out)
            if ins.i_out == i_out
        ],
                     shape=(batch_out, mul_ir_out.dim),
                     like=x1s_out)
        for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0
    ]
    if len(out_out) > 1:
        out_out = torch.cat(out_out, dim=1)
    else:
        # Avoid an unnecessary copy in a size one torch.cat
        out_out = out_out[0]

    out_right = [
        torch.cat([
            _sum_tensors([
                out for ins, out in zip(instructions, out_list_right)
                if (ins.i_in1, ins.i_out) == (i_in1, i_out)
            ],
                         shape=(batch_right, mul_ir_in1.dim,
                                mul_ir_out.dim),
                         like=x2s_right)
            for i_out, mul_ir_out in enumerate(irreps_out)
            if mul_ir_out.mul > 0
        ],
                  dim=2) for i_in1, mul_ir_in1 in enumerate(irreps_in1)
        if mul_ir_in1.mul > 0
    ]
    if len(out_right) > 1:
        out_right = torch.cat(out_right, dim=1)
    else:
        out_right = out_right[0]

    out_out = out_out.reshape(outsize_out)
    out_right = out_right.reshape(outsize_right)

    graph_out.output(out_out.node, torch.Tensor)
    graph_right.output(out_right.node, torch.Tensor)

    # check graphs
    graph_out.lint()
    graph_right.lint()

    # Make GraphModules
    wigner_mats = {}
    for l_1, l_2, l_out in w3j:
        wig = o3.wigner_3j(l_1, l_2, l_out)

        if normalization == 'component':
            wig *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        wigner_mats[f"_w3j_{l_1}_{l_2}_{l_out}"] = wig

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for wkey, wmat in wigner_mats.items():
        constants_root.register_buffer(wkey, wmat)
    graphmod_out = fx.GraphModule(constants_root,
                                  graph_out,
                                  class_name="tp_forward")
    graphmod_right = fx.GraphModule(constants_root,
                                    graph_right,
                                    class_name="tp_right")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # Note that for our einsums, we can optimize _once_ for _any_ batch dimension
        # and still get the right path for _all_ batch dimensions.
        # This is because our einsums are essentially of the form:
        #    zuvw,ijk,zuvij->zwk    OR     uvw,ijk,zuvij->zwk
        # In the first case, all but one operands have the batch dimension
        #    => The first contraction gains the batch dimension
        #    => All following contractions have batch dimension
        #    => All possible contraction paths have cost that scales linearly in batch size
        #    => The optimal path is the same for all batch sizes
        # For the second case, this logic follows as long as the first contraction is not between the first two operands. Since those two operands do not share any indexes, contracting them first is a rare pathological case. See
        # https://github.com/dgasmith/opt_einsum/issues/158
        # for more details.
        #
        # TODO: consider the impact maximum intermediate result size on this logic
        #       \- this is the `memory_limit` option in opt_einsum
        # TODO: allow user to choose opt_einsum parameters?
        #
        # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or dtypes.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, irreps_in1.dim)),
            torch.zeros((batchdim, irreps_in2.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                flat_weight_index,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))
        graphmod_right = jitable(
            optimize_einsums_full(graphmod_right, example_inputs[1:]))

    return graphmod_out, graphmod_right
def _codegen_linear(
    irreps_in: o3.Irreps,
    irreps_out: o3.Irreps,
    instructions: List[Instruction],
    biases: List[bool],
    f_in: Optional[int] = None,
    f_out: Optional[int] = None,
    shared_weights: bool = False,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, int, int]:
    """Generate the forward-pass ``fx.GraphModule`` for an equivariant linear layer.

    Symbolically traces the computation via ``fx.Proxy`` objects on placeholders
    ``x`` (input), ``w`` (flattened weights) and ``b`` (flattened biases), so the
    returned module is a compiled FX graph rather than eager code.

    Args:
        irreps_in: input irreps; each entry contributes a slice of the last dim of ``x``.
        irreps_out: output irreps; determines the layout of the output tensor.
        instructions: one entry per weight path; ``path_shape`` gives its weight block shape.
        biases: per-output-irrep flags; only flagged irreps get a bias slice.
        f_in, f_out: optional extra "feature" dimensions; when set, weights gain an
            ``(f_in, f_out)`` block and inputs/outputs gain a leading feature axis.
        shared_weights: if True, weights carry no batch ('z') index in the einsums.
        optimize_einsums: if True, run opt_einsum_fx on the traced graph.

    Returns:
        (graphmod_out, weight_numel, bias_numel): the FX module plus the total
        number of weight and bias scalars it expects.
    """
    graph_out = fx.Graph()

    # = Function definitions =
    # Placeholders become the generated forward()'s arguments.
    x = fx.Proxy(graph_out.placeholder('x', torch.Tensor))
    ws = fx.Proxy(graph_out.placeholder('w', torch.Tensor))
    bs = fx.Proxy(graph_out.placeholder('b', torch.Tensor))

    # Leading "batch" shape is everything except the irreps dim (and f_out, if present).
    if f_in is None:
        size = x.shape[:-1]
        outsize = size + (irreps_out.dim, )
    else:
        size = x.shape[:-2]
        outsize = size + (
            f_out,
            irreps_out.dim,
        )

    # Total number of bias scalars: only irreps flagged in `biases` get one per dim.
    bias_numel = sum(mul_ir.dim for bias, mul_ir in zip(biases, irreps_out)
                     if bias)
    if bias_numel > 0:
        if f_out is None:
            bs = bs.reshape(-1, bias_numel)
        else:
            bs = bs.reshape(-1, f_out, bias_numel)

    # = Short-circut for nothing to do =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0 and bias_numel == 0:
        out = x.new_zeros(outsize)

        graph_out.output(out.node, torch.Tensor)
        # Short circut
        # 0 is weight_numel
        return fx.GraphModule({}, graph_out, "linear_forward"), 0, 0

    # Flatten all leading batch dims into one for the einsums below.
    if f_in is None:
        x = x.reshape(-1, irreps_in.dim)
    else:
        x = x.reshape(-1, f_in, irreps_in.dim)
    batch_out = x.shape[0]

    # Per-output-irrep bias contributions (empty list when that irrep has no bias).
    out_bias_list = []
    bias_index = 0
    for bias, mul_ir_out in zip(biases, irreps_out):
        if bias:
            if sum(biases) == 1:
                # Only one biased irrep: the whole bias tensor belongs to it.
                b = bs
            else:
                b = bs.narrow(-1, bias_index, mul_ir_out.dim)
                bias_index += mul_ir_out.dim
            out_bias_list += [[
                b.expand(batch_out, -1) if f_out is None else b.expand(
                    batch_out, f_out, -1)
            ]]
        else:
            out_bias_list += [[]]

    weight_numel = sum(prod(ins.path_shape) for ins in instructions)
    if weight_numel > 0:
        ws = ws.reshape(-1, weight_numel) if f_in is None else ws.reshape(
            -1, f_in, f_out, weight_numel)

    # = extract individual input irreps =
    if len(irreps_in) == 1:
        # Single irrep: skip the narrow() and reshape directly.
        x_list = [
            x.reshape(batch_out, *(() if f_in is None else (f_in, )),
                      irreps_in[0].mul, irreps_in[0].ir.dim)
        ]
    else:
        x_list = [
            x.narrow(-1, i.start,
                     mul_ir.dim).reshape(batch_out,
                                         *(() if f_in is None else (f_in, )),
                                         mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in.slices(), irreps_in)
        ]

    # 'z' is the batch index in the einsum equations; dropped for shared weights.
    z = '' if shared_weights else 'z'

    flat_weight_index = 0

    out_list = []

    for ins in instructions:
        mul_ir_in = irreps_in[ins.i_in]
        mul_ir_out = irreps_out[ins.i_out]

        # Short-circut for empty irreps
        if mul_ir_in.dim == 0 or mul_ir_out.dim == 0:
            continue

        # Extract the weight from the flattened weight tensor
        path_nweight = prod(ins.path_shape)
        if len(instructions) == 1:
            # Avoid unnecessary view when there is only one weight
            w = ws
        else:
            w = ws.narrow(-1, flat_weight_index, path_nweight)
        w = w.reshape((() if shared_weights else (-1, )) +
                      (() if f_in is None else (f_in, f_out)) + ins.path_shape)
        flat_weight_index += path_nweight

        # u/w: input/output multiplicity; i: irrep dim; x/y: f_in/f_out feature axes.
        if f_in is None:
            ein_out = torch.einsum(f"{z}uw,zui->zwi", w, x_list[ins.i_in])
        else:
            ein_out = torch.einsum(f"{z}xyuw,zxui->zywi", w,
                                   x_list[ins.i_in])

        # Normalization: presumably keeps output variance ~1 by dividing by
        # sqrt(fan-in * number of paths feeding this output) — TODO confirm
        # against the layer's normalization convention.
        alpha = 1.0 / math.sqrt(
            (f_in or 1) * mul_ir_in.mul *
            sum(1 if other_ins.i_out == ins.i_out else 0
                for other_ins in instructions))
        ein_out = alpha * ein_out

        out_list += [
            ein_out.reshape(batch_out, *(() if f_out is None else (f_out, )),
                            mul_ir_out.dim)
        ]

    # = Return the result =
    # Sum all path contributions (plus bias, if any) per output irrep, then
    # concatenate the irreps back into one flat last dimension.
    out = [
        _sum_tensors([
            out for ins, out in zip(instructions, out_list)
            if ins.i_out == i_out
        ] + out_bias_list[i_out],
                     shape=(batch_out, *(() if f_out is None else (f_out, )),
                            mul_ir_out.dim),
                     like=x) for i_out, mul_ir_out in enumerate(irreps_out)
        if mul_ir_out.mul > 0
    ]
    if len(out) > 1:
        out = torch.cat(out, dim=-1)
    else:
        # Avoid an unnecessary copy in a size-one torch.cat.
        out = out[0]

    out = out.reshape(outsize)

    graph_out.output(out.node, torch.Tensor)

    # check graphs
    graph_out.lint()

    graphmod_out = fx.GraphModule({}, graph_out, "linear_forward")

    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # See _tensor_product/_codegen.py for notes
        # opt_einsum_fx only inspects traced shapes, so one small example batch
        # is enough to find the optimal contraction path for every batch size.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, *(() if f_in is None else (f_in, )),
                         irreps_in.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_in or 1,
                f_out or 1,
                weight_numel,
            ),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_out or 1,
                bias_numel,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))

    return graphmod_out, weight_numel, bias_numel
def optimize_for_inference(
        model: torch.nn.Module,
        pass_config: Optional[Dict[str, Any]] = None,
        tracer: Type[fx.Tracer] = fx.Tracer) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
    1. Conv/BN fusion
    2. Dropout removal
    3. MKL layout optimizations

    The third optimization takes a function `use_mkl_heuristic` that's used
    to determine whether a subgraph should be explicity run in MKL layout.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {
            'heuristic': use_mkl_length
        },
    }
    if pass_config is None:
        pass_config = {}
    # User-supplied config overrides defaults key-by-key.
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)
    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)
    # mkldnn_layout_optimize may be disabled by passing False instead of a dict.
    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model
    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError(
            "Heuristic not found in mkldnn_layout_optimize config")
    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"][
        "heuristic"]

    cur_tracer = tracer()
    # Trace a deep copy so the caller's model is not mutated by the rewrites below.
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    # NOTE(review): fx_model appears unused after this point — verify whether
    # it can be dropped or whether constructing it has a needed side effect.
    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
    modules: Dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
    # a MKLDNN node if its inputs are MKLDNN nodes.
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == 'call_module':
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    assert (sample_parameter.dtype == torch.float
                            ), "this pass is only for torch.float modules"
                    assert (sample_parameter.device == torch.device('cpu')
                            ), "this pass is only for CPU modules"
        elif node.op == 'call_function':
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn != MklSupport.NO:
            if supports_mkldnn == MklSupport.UNKNOWN:
                # UNKNOWN ops join MKLDNN only if at least one input is already
                # a to_dense (i.e. produced by an MKLDNN node).
                if not any([arg.target == 'to_dense' for arg in node.args]):
                    continue
            with fx_graph.inserting_before(node):
                mkldnn_args = fx.map_arg(
                    node.args,
                    lambda n: fx_graph.call_method('to_mkldnn', (n, )))

            node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

            with fx_graph.inserting_after(node):
                dense_x = fx_graph.create_node('call_method', 'to_dense',
                                               (node, ))
                node.replace_all_uses_with(dense_x)
                # replace_all_uses_with also rewired dense_x's own input; restore it.
                dense_x.args = (node, )

    # Does pre-conversion of all modules into MKLDNN (when possible)
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
    for node in fx_graph.nodes:
        if node.op == 'call_method' and node.target == 'to_dense':
            prv_node = node.args[0]
            users = list(node.users)
            for user in users:
                if user.op == 'call_method' and user.target == 'to_mkldnn':
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
            return uf.find(n.color)
        if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
            return uf.find(n.start_color)
        return None

    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
    # of input nodes (which are only `to_mkldnn` calls), output nodes
    # (`to_dense` calls), and intermediate nodes, which are run entirely on
    # MKLDNN layout tensors.
    #
    # Specifically, this code does a flood fill on a directed acyclic graph
    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
    # If every node only had one input, this would be sufficient. However, in
    # the case that a node has multiple inputs coming from different start
    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
    # using a Disjoint Set Union.
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == 'call_method' and node.target == 'to_mkldnn':
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == 'call_method' and node.target == 'to_dense':
            assert (get_color(node.args[0]) is not None)
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [
                get_color(i) for i in node.all_input_nodes
                if isinstance(i, fx.Node) if get_color(i) is not None
            ]

            if len(cur_colors) == 0:
                continue
            assert (not any(i is None for i in cur_colors))
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            # Merge all other colors into the smallest one.
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(
        lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, 'color'):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, 'start_color'):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, 'end_color'):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Now that we have all the subgraphs, we need to decide which MKLDNN
    # subgraphs we actually want to keep in MKLDNN.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            # Heuristic rejected this subgraph: drop the layout conversions and
            # restore the original (non-MKLDNN) modules.
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    mkldnn_conversions = 0
    for node in fx_graph.nodes:
        if node.target == 'to_mkldnn' or node.target == 'to_dense':
            mkldnn_conversions += 1

    logging.info(f"mkldnn conversions: {mkldnn_conversions}")
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result
def create_feature_extractor(
    model: nn.Module,
    return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warning: bool = False,
) -> fx.GraphModule:
    """
    Creates a new graph module that returns intermediate nodes from a given
    model as dictionary with user specified keys as strings, and the requested
    outputs as values. This is achieved by re-writing the computation graph of
    the model via FX to return the desired nodes as outputs. All unused nodes
    are removed, together with their corresponding parameters.

    Desired output nodes must be specified as a ``.`` separated
    path walking the module hierarchy from top level module down to leaf
    operation or leaf module. For more details on the node naming conventions
    used here, please see the :ref:`relevant subheading <about-node-names>`
    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.

    Not all models will be FX traceable, although with some massaging they can
    be made to cooperate. Here's a (not exhaustive) list of tips:

        - If you don't need to trace through a particular, problematic
          sub-module, turn it into a "leaf module" by passing a list of
          ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
          It will not be traced through, but rather, the resulting graph will
          hold a reference to that module's forward method.
        - Likewise, you may turn functions into leaf functions by passing a
          list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
          example below).
        - Some inbuilt Python functions can be problematic. For instance,
          ``int`` will raise an error during tracing. You may wrap them in your
          own function and then pass that in ``autowrap_functions`` as one of
          the ``tracer_kwargs``.

    For further information on FX see the
    `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.

    Args:
        model (nn.Module): model on which we will extract the features
        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
            containing the names (or partial names - see note above)
            of the nodes for which the activations will be returned. If it is
            a ``Dict``, the keys are the node names, and the values
            are the user-specified keys for the graph module's returned
            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
            node specification strings directly to output names. In the case
            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
            this should not be specified.
        train_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes for train
            mode are different than those from eval mode. If this is specified,
            ``eval_return_nodes`` must also be specified, and ``return_nodes``
            should not be specified.
        eval_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes for train
            mode are different than those from eval mode. If this is specified,
            ``train_return_nodes`` must also be specified, and `return_nodes`
            should not be specified.
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (which passes them onto it's parent class
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
            By default it will be set to wrap and make leaf nodes all torchvision ops:
            {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
            WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
            provided dictionary.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.

    Examples::

        >>> # Feature extraction with resnet
        >>> model = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
        >>> model = create_feature_extractor(
        >>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = model(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]

        >>> # Specifying leaf modules and leaf functions
        >>> def leaf_function(x):
        >>>     # This would raise a TypeError if traced through
        >>>     return int(x)
        >>>
        >>> class LeafModule(torch.nn.Module):
        >>>     def forward(self, x):
        >>>         # This would raise a TypeError if traced through
        >>>         int(x.shape[0])
        >>>         return torch.nn.functional.relu(x + 4)
        >>>
        >>> class MyModule(torch.nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv = torch.nn.Conv2d(3, 1, 3)
        >>>         self.leaf_module = LeafModule()
        >>>
        >>>     def forward(self, x):
        >>>         leaf_function(x.shape[0])
        >>>         x = self.conv(x)
        >>>         return self.leaf_module(x)
        >>>
        >>> model = create_feature_extractor(
        >>>     MyModule(), return_nodes=['leaf_module'],
        >>>     tracer_kwargs={'leaf_modules': [LeafModule],
        >>>                    'autowrap_functions': [leaf_function]})

    """
    tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
    # Remember the caller's mode so it can be restored at the end.
    is_training = model.training

    if all(arg is None
           for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
        raise ValueError(
            "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
        )

    if (train_return_nodes is None) ^ (eval_return_nodes is None):
        raise ValueError(
            "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
        )

    if not ((return_nodes is None) ^ (train_return_nodes is None)):
        raise ValueError(
            "If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
        )

    # Put *_return_nodes into Dict[str, str] format
    def to_strdict(n) -> Dict[str, str]:
        # A list maps every node spec to itself as the output name.
        if isinstance(n, list):
            return {str(i): str(i) for i in n}
        return {str(k): str(v) for k, v in n.items()}

    if train_return_nodes is None:
        return_nodes = to_strdict(return_nodes)
        train_return_nodes = deepcopy(return_nodes)
        eval_return_nodes = deepcopy(return_nodes)
    else:
        train_return_nodes = to_strdict(train_return_nodes)
        eval_return_nodes = to_strdict(eval_return_nodes)

    # Repeat the tracing and graph rewriting for train and eval mode
    tracers = {}
    graphs = {}
    mode_return_nodes: Dict[str, Dict[str, str]] = {
        "train": train_return_nodes,
        "eval": eval_return_nodes
    }
    for mode in ["train", "eval"]:
        if mode == "train":
            model.train()
        elif mode == "eval":
            model.eval()

        # Instantiate our NodePathTracer and use that to trace the model
        tracer = NodePathTracer(**tracer_kwargs)
        graph = tracer.trace(model)

        name = model.__class__.__name__ if isinstance(
            model, nn.Module) else model.__name__
        graph_module = fx.GraphModule(tracer.root, graph, name)

        available_nodes = list(tracer.node_to_qualname.values())
        # FIXME We don't know if we should expect this to happen
        if len(set(available_nodes)) != len(available_nodes):
            raise ValueError(
                "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
            )
        # Check that all outputs in return_nodes are present in the model
        for query in mode_return_nodes[mode].keys():
            # To check if a query is available we need to check that at least
            # one of the available names starts with it up to a .
            if not any([
                    re.match(rf"^{query}(\.|$)", n) is not None
                    for n in available_nodes
            ]):
                raise ValueError(
                    f"node: '{query}' is not present in model. Hint: use "
                    "`get_graph_node_names` to make sure the "
                    "`return_nodes` you specified are present. It may even "
                    "be that you need to specify `train_return_nodes` and "
                    "`eval_return_nodes` separately.")

        # Remove existing output nodes (train mode)
        orig_output_nodes = []
        for n in reversed(graph_module.graph.nodes):
            if n.op == "output":
                orig_output_nodes.append(n)
        if not orig_output_nodes:
            raise ValueError(
                "No output nodes found in graph_module.graph.nodes")
        for n in orig_output_nodes:
            graph_module.graph.erase_node(n)

        # Find nodes corresponding to return_nodes and make them into output_nodes
        nodes = [n for n in graph_module.graph.nodes]
        output_nodes = OrderedDict()
        # Iterate in reverse so the LAST node matching each query wins.
        for n in reversed(nodes):
            module_qualname = tracer.node_to_qualname.get(n)
            if module_qualname is None:
                # NOTE - Know cases where this happens:
                # - Node representing creation of a tensor constant - probably
                #   not interesting as a return node
                # - When packing outputs into a named tuple like in InceptionV3
                continue
            for query in mode_return_nodes[mode]:
                depth = query.count(".")
                # Match the query against the qualname truncated to the same depth.
                if ".".join(module_qualname.split(".")[:depth + 1]) == query:
                    output_nodes[mode_return_nodes[mode][query]] = n
                    mode_return_nodes[mode].pop(query)
                    break
        # Reverse back so outputs appear in forward (execution) order.
        output_nodes = OrderedDict(reversed(list(output_nodes.items())))

        # And add them in the end of the graph
        with graph_module.graph.inserting_after(nodes[-1]):
            graph_module.graph.output(output_nodes)

        # Remove unused modules / parameters
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        # Keep track of the tracer and graph so we can choose the main one
        tracers[mode] = tracer
        graphs[mode] = graph

    # Warn user if there are any discrepancies between the graphs of the
    # train and eval modes
    if not suppress_diff_warning:
        _warn_graph_differences(tracers["train"], tracers["eval"])

    # Build the final graph module
    graph_module = DualGraphModule(model,
                                   graphs["train"],
                                   graphs["eval"],
                                   class_name=name)

    # Restore original training mode
    model.train(is_training)
    graph_module.train(is_training)

    return graph_module
def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
        tries replacing quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    def graph_fails(graph, inps):
        # Wrap the candidate graph in a GraphModule, lint it, then ask the
        # user-supplied oracle whether it still reproduces the failure.
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    # Propagate concrete values so node metadata is available to the strategies.
    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def remove_suffix(cur_graph, cur_inps):
        # Try cutting the graph off at progressively finer-grained positions
        # (binary-search style: gap halves each round) and making that node the output.
        print("Strategy: Remove suffix")
        assert graph_fails(cur_graph, cur_inps)
        gap = 2**math.floor(math.log2(len(cur_graph.nodes)))
        tested = set()
        while gap >= 1:
            new_graph = fx.Graph()
            env = {}
            for idx, node in enumerate(cur_graph.nodes):
                new_node = new_graph.node_copy(node, lambda x: env[x])
                if node.op not in ['placeholder', 'output']:
                    if idx % gap == 0 and idx not in tested:
                        # Temporarily end the graph here and test it.
                        output_node = new_graph.output((new_node, ))
                        if graph_fails(new_graph, cur_inps) and len(
                                new_graph.nodes) < len(cur_graph.nodes):
                            print()
                            print(
                                f"SUCCESS: Removed [{idx}:{len(cur_graph.nodes)})"
                            )
                            return (new_graph, cur_inps), True
                        else:
                            tested.add(idx)
                            new_graph.erase_node(output_node)
                env[node] = new_node
            gap //= 2

        print("FAIL: Could not remove suffix")
        return (cur_graph, cur_inps), False

    def remove_unused_inputs(cur_graph, cur_inps):
        # Drop placeholders with no users, together with their matching inputs.
        assert graph_fails(cur_graph, cur_inps)
        ph_nodes = _get_placeholders(cur_graph)
        if len(ph_nodes) != len(cur_inps):
            print(cur_graph)
            print(len(cur_inps))

        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps) and graph_fails(cur_graph, new_inps):
            print("Strategy: Remove unused inputs")
            print(
                f"SUCCESS: Went from {len(cur_inps)} inputs to {len(new_inps)} inputs"
            )
            return (cur_graph, new_inps), True
        else:
            return (cur_graph, new_inps), False

    def eliminate_dead_code(cur_graph, cur_inps):
        orig_size = len(cur_graph.nodes)
        # eliminate_dead_code() returns True when it removed something.
        if cur_graph.eliminate_dead_code() and graph_fails(
                cur_graph, cur_inps):
            print("Strategy: Eliminate dead code")
            print(
                f"SUCCESS: Went from {orig_size} nodes to {len(cur_graph.nodes)} nodes"
            )
            return (cur_graph, cur_inps), True
        else:
            return (cur_graph, cur_inps), False

    def consolidate_placeholders(cur_graph):
        # Rebuild the graph with all placeholders first (FX requires
        # placeholders to precede other nodes).
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    def delta_debugging(cur_graph: fx.Graph, cur_inps):
        # Classic delta debugging: replace contiguous chunks of nodes with
        # placeholders (fed by their recorded concrete values), halving the
        # chunk size each round.
        print("Strategy: Delta Debugging")
        assert graph_fails(cur_graph, cur_inps)
        starting_placeholders = len(_get_placeholders(cur_graph))
        num_nodes = len(cur_graph.nodes)
        gap = int(2**math.floor(math.log2(num_nodes)))
        while gap >= 1:
            for start_range in range(0, num_nodes, gap):
                is_removing = False
                new_graph = copy.deepcopy(cur_graph)
                new_inps = cur_inps[:]
                end_range = min(num_nodes, start_range + gap)
                for idx in range(start_range, end_range):
                    new_node = list(new_graph.nodes)[idx]
                    if new_node.op not in ['placeholder', 'output']:
                        is_removing = True
                        _convert_node_to_placeholder(new_node, new_inps)
                if not is_removing:
                    continue
                new_graph = consolidate_placeholders(new_graph)
                if graph_fails(new_graph, new_inps):
                    print(
                        f"SUCCESS: Removed ({start_range}:{end_range}] - Went from {starting_placeholders} "
                        f"placeholders to {len(_get_placeholders(new_graph))}")
                    return (new_graph, new_inps), True
            gap //= 2

        print("FAIL: Could not remove prefix")
        # NOTE(review): this returns the closure variable `inps`, not
        # `cur_inps`, unlike the other strategies. Harmless today because the
        # caller ignores the pair when succeeded is False, but worth confirming
        # this is intentional.
        return (cur_graph, inps), False

    print("###################")
    print(f"Current size: {len(failing_graph.nodes)}")
    print("###################")
    # Keep applying all strategies until a full pass makes no progress.
    while True:
        any_succeeded = False
        strategies = [
            remove_suffix, eliminate_dead_code, remove_unused_inputs,
            delta_debugging, eliminate_dead_code, remove_unused_inputs
        ]
        for strategy in strategies:
            out = strategy(copy.deepcopy(failing_graph), inps[:])
            (cur_graph, cur_inps), succeeded = out
            if succeeded:
                print()
                print("###################")
                print(f"Current size: {len(cur_graph.nodes)}")
                print("###################")
                failing_graph = cur_graph
                inps = cur_inps
                any_succeeded = True

        if not any_succeeded:
            break
    failing_fx = fx.GraphModule(fail_f, failing_graph)
    # Emit a ready-to-run repro script for the minimized graph.
    print(f"""
inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
{failing_fx.code}
f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
    for _ in range(5):
        f(*inps)""")
    return failing_fx, inps
def deepcopy_fx_graph(fx_graph):
    """Return a deep copy of ``fx_graph`` as a fresh ``fx.Graph``.

    NOTE(review): ``fail_f`` is a free variable here (presumably the failing
    ``fx.GraphModule`` from an enclosing scope) — confirm it is in scope
    wherever this is called.
    """
    graph_clone = copy.deepcopy(fx_graph)
    # Round-trip through a GraphModule so the clone is re-owned by FX and
    # returned as that module's graph.
    return fx.GraphModule(fail_f, graph_clone).graph
def minifier(fail_f: fx.GraphModule,
             inps,
             module_fails,
             dump_state: Callable = dump_state):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
        tries replacing quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    # Count of oracle invocations, reported at the end.
    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        # Round-trip through a GraphModule so the copy is a well-formed fx.Graph.
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph

    def graph_fails(graph, inps):
        nonlocal num_queries
        # Test a copy so the oracle cannot mutate the candidate graph.
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    # Propagate concrete values so node metadata is available to the strategies.
    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def _register_strategy(strategy: Callable, name: str):
        # Wrap a raw strategy so it (a) works on copies, (b) validates that
        # any "success" actually made progress AND still fails the oracle.
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print()
            print(
                f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)"
            )
            new_state = strategy(deepcopy_fx_graph(old_state.graph),
                                 list(old_state.inps), granularity)
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                if new_nodes < old_nodes:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
                if new_inps > old_inps:
                    # NOTE(review): "progress" here is MORE inputs — presumably
                    # because delta debugging trades graph nodes for
                    # placeholders; confirm this is the intended direction.
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
                if new_outs < old_outs:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_outs} to {new_outs} outputs")

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                if not graph_fails(new_state.graph, new_state.inps):
                    print(
                        "WARNING: Something went wrong, not applying this minification"
                    )
                    return None
                return new_state
            else:
                print(f"FAIL: {name}")
            return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        # Try making an intermediate node the graph output, cutting off the suffix.
        tested = set()
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ['placeholder', 'output']:
                # If idx is divisible by (granularity * 2), it would have been checked already.
                if idx % granularity == 0 and (idx % (granularity * 2) !=
                                               0) and idx not in tested:
                    output_node = new_graph.output((new_node, ))
                    if len(new_graph.nodes) < len(
                            cur_graph.nodes) and graph_fails(
                                new_graph, cur_inps):
                        return ReproState(new_graph, cur_inps)
                    else:
                        tested.add(idx)
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None

    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        # Tag nodes with their position so output args can be ordered below.
        for idx, node in enumerate(cur_graph.nodes):
            node.idx = idx
            if node.op == 'output':
                output = node
                break

        # Non-Node output args sort last (int(1e9) sentinel).
        output_args = sorted(
            output.args[0],
            key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9))
        if len(output_args) == 1:
            return None

        # Try deleting each granularity-sized slice of the outputs.
        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] +
                           output_args[idx + granularity:], )
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None

    def remove_unused_inputs_unchecked(cur_state: ReproState):
        # Drop placeholders with no users and their matching inputs (no oracle check).
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        # Same as above, but only accept the result if it still fails.
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph,
                                                 new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(
        _remove_unused_wrapper)

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        # eliminate_dead_code() returns True when it removed something.
        if cur_graph.eliminate_dead_code() and graph_fails(
                cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    def _consolidate_placeholders(cur_graph):
        # Rebuild with placeholders first (FX requires them before other nodes).
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        # Replace each granularity-sized chunk of nodes with placeholders and
        # test whether the failure survives.
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            for idx in range(start_range, end_range):
                new_node = list(new_graph.nodes)[idx]
                if new_node.op not in ['placeholder', 'output']:
                    is_removing = True
                    _convert_node_to_placeholder(new_node, new_inps)
            if not is_removing:
                continue
            new_graph = _consolidate_placeholders(new_graph)
            # Opportunistically prune inputs made dead by the replacement.
            new_state = remove_unused_inputs_unchecked(
                ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return ReproState(new_state.graph, new_state.inps)

        return None

    failing_state = ReproState(failing_graph, inps)

    def try_granularity(failing_state, granularity, use_non_granular):
        # Run the strategy stack once at the given granularity; return the
        # first state that makes progress, or None.
        print(f"Trying granularity {granularity}")

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        # Output pruning is only worthwhile when outputs dominate the graph.
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        if use_non_granular:
            strategies += [eliminate_dead_code, remove_unused_inputs]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

    while True:
        # Persist the current best repro before every attempt.
        dump_state(fx.GraphModule(fail_f, failing_state.graph),
                   failing_state.inps)
        # Start coarse: largest power of two not exceeding the node count.
        granularity = int(2**(math.floor(
            math.log2(len(failing_state.graph.nodes)))))
        new_state = try_granularity(failing_state,
                                    granularity,
                                    use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        # Coarse attempt failed: refine granularity until something works.
        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(failing_state,
                                        granularity,
                                        use_non_granular=False)
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        # Last resort before giving up: try dropping outputs one at a time.
        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError(
            "Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries")
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)
    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py")
    return failing_fx, failing_state.inps
def graph_fails(graph, inps):
    """Build a GraphModule from ``graph`` and report whether it still fails.

    NOTE(review): ``fail_f`` and ``module_fails`` are free variables,
    presumably supplied by an enclosing scope — confirm they are in scope
    wherever this is defined.
    """
    candidate = fx.GraphModule(fail_f, graph)
    # Lint first so malformed graphs raise here rather than confusing the oracle.
    candidate.graph.lint()
    return module_fails(candidate, inps)