Example #1
import torch.fx


def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
    if node.op != "call_module":
        # Non-module nodes (placeholder, get_attr, call_function,
        # call_method, output) carry their target directly.
        return node.target

    # For call_module nodes, node.target is the submodule's qualified name;
    # resolve it and report the class it originated from.
    m = dict(mod.named_modules())[node.target]
    return getattr(m, "_base_class_origin", type(m))
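A quick usage sketch (the SmallNet module below is made up for illustration): after symbolically tracing a module, get_target resolves each node to either its raw target or the class of the submodule it calls. _base_class_origin appears to be an attribute set by wrapping code to remember a module's original class before patching; when it is absent, the module's own type is returned.

import torch
import torch.fx


class SmallNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


traced = torch.fx.symbolic_trace(SmallNet())
for node in traced.graph.nodes:
    # The call_module node resolves to the Linear class; call_function
    # nodes (e.g. torch.relu) resolve to the function object itself.
    print(node.op, get_target(traced, node))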
Example #2
import logging
from typing import Dict

import torch.fx


def _split_nodes(traced_graph_module: torch.fx.GraphModule,
                 shard_count: int = 3) -> Dict[str, int]:
    """Walk a traced graph and identify shard cutpoints.

    Returns a mapping from node name to shard id.
    """

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far = []
    param_count: Dict[str, int] = {}
    shard_to_param_count: Dict[int, int] = {}

    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
    for name, module in traced_graph_module.named_modules():
        # fx node names use "_" where module names use ".", so normalize
        # module names to match. The root module ("") holds the total.
        name = name.replace(".", "_")
        param_count[name] = sum(p.numel() for p in module.parameters())
    logging.info(f"Total number of params is {param_count['']}")
    per_shard_param = param_count[""] // shard_count
    logging.info(f"Per-shard param count is {per_shard_param}")

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in ["get_attr", "call_function", "call_method", "call_module"]:
            min_shard_id = shard_id
            min_node_name = ""
            # Among this node's args, find the one assigned to the earliest
            # shard, ignoring the node we just traversed. This catches skip
            # connections that reach back across shard boundaries.
            for arg in node.args:
                # If the node has args that are inputs to the forward function, they
                # may not have explicit names.
                if not hasattr(arg, "name"):
                    continue

                if (arg.name in node_name_to_shard_id
                        and arg.name != nodes_so_far[-1]):
                    if node_name_to_shard_id[arg.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[arg.name]
                        min_node_name = arg.name

            # If some input comes from an earlier (non-adjacent) shard,
            # collapse all the shards in between into a single shard and
            # update the per-shard param count accordingly.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.
                shard_to_param_count = _create_shard_to_param_count(
                    param_count, node_name_to_shard_id)

            # Update state that is tracking node -> shard id and shard id -> param count.
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            # TODO(anj): This could just be an update, we don't need to recreate the map.
            shard_to_param_count = _create_shard_to_param_count(
                param_count, node_name_to_shard_id)
            # If the current shard has exceeded the per-shard param target,
            # start a new shard. shard_id may be missing from the map when
            # the current node carries no params.
            if (shard_id in shard_to_param_count
                    and shard_to_param_count[shard_id] > per_shard_param):
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id
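Example #2 relies on _create_shard_to_param_count, which is not shown. A minimal sketch of what that helper could look like, assuming it simply re-aggregates the per-module param counts by shard id; the signature and behavior here are inferred from the call sites, not taken from the source:

from typing import Dict


def _create_shard_to_param_count(
        param_count: Dict[str, int],
        node_name_to_shard_id: Dict[str, int]) -> Dict[int, int]:
    # Sum the param count of every node assigned to each shard. Node names
    # missing from param_count (e.g. call_function nodes) carry no params.
    shard_to_param_count: Dict[int, int] = {}
    for node_name, shard_id in node_name_to_shard_id.items():
        if node_name in param_count:
            shard_to_param_count[shard_id] = (
                shard_to_param_count.get(shard_id, 0) + param_count[node_name])
    return shard_to_param_count

With the helper in place, _split_nodes(torch.fx.symbolic_trace(model), shard_count=4) yields the node-name-to-shard-id mapping for any traceable model.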