def test_subgraph_uniquename(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a, b, c, d):
            add_1 = a + b
            add_2 = add_1 + c
            linear_1 = self.linear(add_1)
            add_3 = add_2 + d
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
    mm = MyModule()
    traced = symbolic_trace(mm)

    def split_cb(node: torch.fx.Node):
        if node.name == 'a' or node.name == 'b' or node.name == 'add':
            return 0
        else:
            return 1

    module_with_submodule = split_module(traced, mm, split_cb)
    self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))
def test_subgraph_trivial_resnet(self):
    # Smoke test that trivially splitting resnet into one partition works.
    # There was previously an issue causing submodule names to be aliased.
    m = resnet18()
    traced = symbolic_trace(m)
    a = torch.rand(64, 3, 7, 7)
    module_with_submodules = split_module(traced, m, lambda node: 0)
    module_with_submodules(a)
def do_partition(self) -> GraphModule:
    """Return a module with submodules (partitions)."""
    module_with_submodules = split_module(
        self.graph_module,
        self.torch_module,
        lambda node: self.node_to_partition[node]
    )
    return module_with_submodules
def do_partition(self) -> GraphModule:
    """Return a module with submodules (partitions)."""
    module_with_submodules = split_module(
        self.graph_module,
        self.torch_module,
        lambda node: self.node_to_partitions[node][0]
    )
    return module_with_submodules
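
# Both do_partition variants rely on the same split_module contract: a
# callback maps each Node to an integer partition ID, and nodes sharing an ID
# are grouped into a child GraphModule named submod_<id> inside a generated
# parent module. A minimal standalone sketch of that contract follows (the
# TwoStage module and partitioning rule are illustrative only, and the import
# path may differ across PyTorch versions):
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class TwoStage(torch.nn.Module):
    def forward(self, x):
        y = x + 1      # traces to a node named "add"
        return y * 2   # traces to a node named "mul"


m = TwoStage()
traced = symbolic_trace(m)


def to_partition(node: torch.fx.Node) -> int:
    # Send the add to submod_0 and everything else to submod_1.
    return 0 if node.name == "add" else 1


split = split_module(traced, m, to_partition)
# The split module computes the same function as the original trace.
assert torch.equal(split(torch.ones(3)), traced(torch.ones(3)))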
def test_subgraph_creation(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x, y):
            z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
            w = self.linear(y).clamp(min=0.0, max=1.0)
            return z + w

    # symbolically trace model
    my_module = MyModule()
    my_module_traced = symbolic_trace(my_module)

    # random mod partitioning
    partition_counter = 0
    NPARTITIONS = 3

    def mod_partition(node: Node):
        nonlocal partition_counter
        partition = partition_counter % NPARTITIONS
        partition_counter = (partition_counter + 1) % NPARTITIONS
        return partition

    # split module into a module with submodules
    module_with_submodules = split_module(my_module_traced, my_module, mod_partition)

    x = torch.rand(3, 4)
    y = torch.rand(3, 4)

    orig_out = my_module_traced(x, y)
    submodules_out = module_with_submodules(x, y)

    self.assertEqual(orig_out, submodules_out)
def split_const_subgraphs(
    module: torch.nn.Module,
) -> FoldedGraphModule:
    """
    Looks through `module` for any nodes that have all constant attribute
    inputs and separates them out into their own constant subgraph, and
    returns a FoldedGraphModule which runs that constant subgraph on the
    first run to set attributes on the module prior to running the
    non-constant portion of the graph.
    """
    mod_traced = torch.fx.symbolic_trace(module)

    # Build up a list of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    const_nodes: Set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip over placeholders/outputs because they can't be const folded
        # and we don't want to add tags to them.
        if node.op in {"placeholder", "output"}:
            continue

        # If the node itself is constant, or all of its inputs are constant,
        # then tag it as constant.
        if node.op == "get_attr" or set(node.all_input_nodes).issubset(const_nodes):
            const_nodes.add(node)
            if node.op != "get_attr":
                found_const_folding = True

    # If we did not find any const folding then return early without a const
    # fold subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)

    # Partition the module into two: submod_0 for the constant folding
    # subgraph, and submod_1 for the rest.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

    split = split_module(mod_traced, module, mod_partition)

    # Gather all names that are output from the const folding subgraph, which
    # we will need in order to set dummy params on the module.
    const_output_names: List[str] = []
    for node in split.submod_0.graph.nodes:
        if node.op == "output":
            # Note: we _make_tuple here because the output Node either
            # contains a single output Node, or Tuple[Node], so this
            # simplifies things.
            const_output_names = [o.name for o in _make_tuple(node.args[0])]
            break

    # Make sure the attr name we want to use is uniquely named in the module.
    for i in range(len(const_output_names)):
        # Add a suffix to make it easier to tell these were the result of
        # const folding.
        name = const_output_names[i] + "__CF"
        # Delete all characters that are illegal in a Python identifier.
        name = re.sub("[^0-9a-zA-Z_]+", "_", name)
        if name[0].isdigit():
            name = f"_{name}"
        # Now make sure it is in fact unique to the module by incrementing
        # the suffix value.
        while hasattr(mod_traced, name):
            match = re.match(r"(.*)_(\d+)$", name)
            if match is None:
                name = name + "_1"
            else:
                base, num = match.group(1, 2)
                name = f"{base}_{int(num) + 1}"
        const_output_names[i] = name

    # Now map the const_output_names to the names used in the parent graph
    # from the split (via call_function getitem) to see what order they are
    # passed into the non-const subgraph submod_1. First look to the parent
    # module containing/calling into the const/non-const submodules to
    # determine what the inputs are to each. Note if submod_0 had a single
    # output then there is no getitem, and we can simply use the output from
    # the call to submod_0.
    call_submod_0_args, call_submod_1_args = None, None
    orig_ph_targets: List[str] = []
    for node in split.graph.nodes:
        if node.op == "placeholder":
            orig_ph_targets.append(node.target)

        if node.op == "call_module":
            if node.target == "submod_0":
                call_submod_0_args = node.args
                continue
            elif node.target == "submod_1":
                call_submod_1_args = node.args
                continue
    assert call_submod_0_args is not None and call_submod_1_args is not None

    # Look through the args for the call into submod_1, and find the args
    # that come from submod_0. Also look for get_attrs fed directly from the
    # parent split into submod_1, i.e. those attrs that are not constant
    # folded.
    submod_1_input_idx_to_folded_attr_name: Dict[int, str] = {}
    submod_1_input_idx_to_unfolded_attr_name: Dict[int, str] = {}
    for i, node in enumerate(call_submod_1_args):
        const_output_name = None
        # If we only had a single output from submod_0 then we simply look
        # for the call_module into it.
        if len(const_output_names) == 1:
            if node.op == "call_module" and node.target == "submod_0":
                const_output_name = const_output_names[0]
        # Else we had multiple outputs from submod_0, so we need to look for
        # all getitems from the call to it.
        else:
            if (
                node.op == "call_function"
                and node.target == operator.__getitem__
                and node.args[0].target == "submod_0"
            ):
                const_output_name = const_output_names[node.args[1]]

        # Now map from the index at which the constant is passed into
        # submod_1 to the constant's output name, which we use for swapping
        # in get_attrs instead of placeholders in submod_1.
        if const_output_name is not None:
            submod_1_input_idx_to_folded_attr_name[i] = const_output_name
        elif node.op == "get_attr":
            submod_1_input_idx_to_unfolded_attr_name[i] = node.target

    assert len(submod_1_input_idx_to_folded_attr_name) == len(const_output_names)

    # Now we have a mapping from const output names to the index they are
    # passed into submod_1, so swap in get_attrs for placeholders.
    ph_idx = 0
    for node in split.submod_1.graph.nodes:
        if node.op != "placeholder":
            continue
        is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name
        is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name
        if not is_folded_attr and not is_unfolded_attr:
            ph_idx += 1
            continue

        const_output_name = (
            submod_1_input_idx_to_folded_attr_name[ph_idx]
            if is_folded_attr
            else submod_1_input_idx_to_unfolded_attr_name[ph_idx]
        )
        if is_folded_attr:
            assert not hasattr(mod_traced, const_output_name)
            # Use a dummy param, which will be overwritten when we run const
            # folding.
            setattr(
                mod_traced,
                const_output_name,
                torch.nn.Parameter(torch.randn(1)),
            )
        with split.submod_1.graph.inserting_before(node):
            node.replace_all_uses_with(
                split.submod_1.graph.get_attr(const_output_name))
        split.submod_1.graph.erase_node(node)
        ph_idx += 1

    # We may need to reorder placeholders to ensure they have the same order
    # as they do in the original split.
    ph_idx = 0
    node = next(iter(split.submod_1.graph.nodes))
    while node.op != "root":
        if node.op != "placeholder":
            node = node.next
            continue

        curr_orig_ph_target = orig_ph_targets[ph_idx]
        ph_idx += 1
        # If this ph is in the correct position, nothing to do.
        if curr_orig_ph_target == node.target:
            node = node.next
            continue

        # This ph is not in the correct order, so search the rest of the
        # graph for the ph we expected and prepend it before the current ph.
        later_node = node.next
        while later_node.op != "root":
            if (
                later_node.op == "placeholder"
                and curr_orig_ph_target == later_node.target
            ):
                break
            later_node = later_node.next
        assert later_node.op != "root"
        node.prepend(later_node)
        # Note we do not increment node here, as it still may be in the
        # wrong place (we just prepended the ph that should have come
        # before it).

    # split_module currently does not use get_attrs for attrs. Instead it
    # passes them in as args from the parent module, which used get_attrs.
    # Here we set them as get_attrs inside submod_0, allowing folding to run
    # without needing to know a priori which attrs should be passed as args.
    # We can unconditionally do this for all placeholders because we know
    # all placeholders to submod_0 must be constants accessible via
    # get_attr.
    for node in split.submod_0.graph.nodes:
        if node.op != "placeholder":
            continue
        in_node = next(n for n in call_submod_0_args if n.name == node.target)
        assert in_node.op == "get_attr"
        with split.submod_0.graph.inserting_before(node):
            node.replace_all_uses_with(
                split.submod_0.graph.get_attr(in_node.target))
        split.submod_0.graph.erase_node(node)

    return FoldedGraphModule(
        mod_traced, split.submod_1.graph, split.submod_0.graph, const_output_names
    )
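
# A minimal usage sketch of the fold-then-run flow described in the docstring
# above. ConstModule is a hypothetical toy module; `w + b` depends only on
# attributes, so it is const-foldable, while the matmul against the input is
# not. This assumes FoldedGraphModule behaves as documented (the constant
# subgraph runs on the first call to set attributes on the module).
import torch


class ConstModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))
        self.b = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        # (self.w + self.b) has only get_attr inputs and gets folded into
        # submod_0; the matmul depends on `x` and stays in submod_1.
        return x @ (self.w + self.b)


folded = split_const_subgraphs(ConstModule())
out = folded(torch.randn(2, 4))  # first call runs const folding, then the rest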