def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
    """Return the effective target of ``node``.

    For a ``call_module`` node, the submodule registered in ``mod`` under
    ``node.target`` is looked up; its ``_base_class_origin`` attribute is
    returned when it has one, otherwise the submodule's type. For every other
    op, ``node.target`` is returned unchanged.

    Raises:
        KeyError: if a ``call_module`` node's target is not among the names
            yielded by ``mod.named_modules()``.
    """
    if node.op == "call_module":
        # Resolve the qualified name stored on the node to the actual module.
        submodules_by_name = dict(mod.named_modules())
        submodule = submodules_by_name[node.target]
        return getattr(submodule, "_base_class_origin", type(submodule))
    return node.target
def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
    """Walk a traced graph and assign every node to a shard id.

    Nodes are visited in graph order. The running shard id is advanced once
    the parameter count attributed to the current shard exceeds
    ``total_params // shard_count``. When a node consumes the output of a node
    that lives in an earlier shard (a skip connection), every node visited
    since that producer is folded back into the producer's shard so the
    dependency does not cross shards.

    Args:
        traced_graph_module: The symbolically traced module to partition.
        shard_count: Number of shards to aim for; must be at least 1.

    Returns:
        A dict mapping node name to its integer shard id. The ``output`` node
        is not included.

    Raises:
        ValueError: If ``shard_count`` is smaller than 1.
    """
    if shard_count < 1:
        raise ValueError(f"shard_count must be a positive integer, got {shard_count}")

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far = []

    # Parameter count per (sub)module. Keys use "_" instead of "." for nested
    # module names; the root module is stored under "" and therefore holds the
    # total for the whole model.
    param_count: Dict[str, int] = {
        name.replace(".", "_"): sum(p.numel() for p in module.parameters())
        for name, module in traced_graph_module.named_modules()
    }

    logging.info("Total number of params are %s", param_count[""])
    per_shard_param = param_count[""] // shard_count
    logging.info("Per shard param count %s", per_shard_param)

    for node in traced_graph_module.graph.nodes:
        if node.op == "output":
            break

        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            continue

        if node.op not in ("get_attr", "call_function", "call_method", "call_module"):
            continue

        # Look for a skip connection: an input produced by an already-assigned
        # node other than the one visited immediately before this one.
        # ``all_input_nodes`` is used rather than ``node.args`` so that inputs
        # passed as keyword arguments or nested inside lists/tuples (e.g.
        # ``torch.cat([a, b])``) are seen as well. Non-node arguments never
        # appear in it.
        min_shard_id = shard_id
        min_node_name = ""
        for input_node in node.all_input_nodes:
            input_name = input_node.name
            if input_name not in node_name_to_shard_id or input_name == nodes_so_far[-1]:
                continue
            if node_name_to_shard_id[input_name] < min_shard_id:
                min_shard_id = node_name_to_shard_id[input_name]
                min_node_name = input_name

        # If there is an input from an earlier shard, collapse everything
        # between that producer and the current position into one shard.
        if min_shard_id < shard_id:
            for node_name in reversed(nodes_so_far):
                node_name_to_shard_id[node_name] = min_shard_id
                if node_name == min_node_name:
                    break
            shard_id = min_shard_id
            # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.

        # Record the current node, then rebuild the shard -> param count view.
        node_name_to_shard_id[node.name] = shard_id
        nodes_so_far.append(node.name)
        # TODO(anj): This could just be an update, we don't need to recreate the map.
        # NOTE(review): assumes the helper returns a {shard id: param count}
        # mapping and has no side effects - confirm against its definition.
        shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)

        # Start a new shard once the current one is over budget. The shard id
        # may be missing from the map when no node in it has parameters yet.
        if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
            shard_id += 1

    return node_name_to_shard_id