Example #1
def fuse_unsqueeze_cat_sum(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.target != acc_ops.sum:
            continue
        prev_node = node.kwargs["input"]
        if prev_node.target != acc_ops.cat or len(prev_node.kwargs["tensors"]) != 2:
            continue
        lhs, rhs = prev_node.kwargs["tensors"][0], prev_node.kwargs["tensors"][1]
        if lhs.target != acc_ops.unsqueeze or rhs.target != acc_ops.unsqueeze:
            continue
        lhs_input = lhs.kwargs["input"]
        rhs_input = rhs.kwargs["input"]
        # prerequisite check
        cond1 = lhs.kwargs["dim"] == 0 and rhs.kwargs["dim"] == 0
        cond2 = prev_node.kwargs["dim"] == 0
        if not cond1 or not cond2:
            continue
        with gm.graph.inserting_before(node):
            fused_node = gm.graph.call_function(acc_ops.add, kwargs={"input": lhs_input, "other": rhs_input})
        node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
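
The pass above rewrites sum(cat([unsqueeze(a), unsqueeze(b)], dim=0), dim=0) into a single elementwise add. A minimal eager-mode sketch of the identity it relies on (standalone PyTorch, no acc_ops involved):

import torch

a, b = torch.randn(3, 4), torch.randn(3, 4)
# Stacking two tensors along a new leading dim and summing that dim out is just a + b.
stacked = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)  # shape (2, 3, 4)
assert torch.allclose(stacked.sum(dim=0), a + b)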
Example #2
def fuse_unsqueeze_cat_sum(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.target != acc_ops.sum:
            continue
        prev_node = node.kwargs["input"]
        if prev_node.target != acc_ops.cat or prev_node.kwargs["dim"] != 0:
            continue
        cat_inputs = prev_node.kwargs["tensors"]
        valid_pass = True
        for i in cat_inputs:
            if i.target != acc_ops.unsqueeze or i.kwargs["dim"] != 0:
                valid_pass = False
                break

        if not valid_pass:
            continue
        input_val = [i.kwargs["input"] for i in cat_inputs]

        with gm.graph.inserting_before(node):
            left = input_val[0]
            for i in range(1, len(input_val)):
                right = input_val[i]
                fused_node = gm.graph.call_function(acc_ops.add, kwargs={"input": left, "other": right})
                left = fused_node
        node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
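
This variant handles any number of concatenated unsqueeze inputs by left-folding them into a chain of acc_ops.add nodes. The same identity holds for N tensors, as this standalone sketch illustrates:

import torch

tensors = [torch.randn(3, 4) for _ in range(4)]
stacked_sum = torch.cat([t.unsqueeze(0) for t in tensors], dim=0).sum(dim=0)
acc = tensors[0]
for t in tensors[1:]:
    acc = acc + t  # mirrors the chain of add nodes built by the pass
assert torch.allclose(stacked_sum, acc)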
Example #3
def fuse_permute_matmul(gm: torch.fx.GraphModule):
    """
    Fuse patterns like permute + matmul if the permute transposes the last two dimensions.
    """
    for node in gm.graph.nodes:
        if node.target == acc_ops.matmul:
            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
            lhs_transposed = rhs_transposed = False
            skip = False

            if lhs.target == acc_ops.permute and check_permute(lhs):
                lhs_transposed = True
                lhs = lhs.kwargs["input"]

            if rhs.target == acc_ops.permute and check_permute(rhs):
                rhs_transposed = True
                rhs = rhs.kwargs["input"]

            if (not skip) and (lhs_transposed or rhs_transposed):
                with gm.graph.inserting_before(node):
                    fused_node = gm.graph.call_function(trt_transposed_matmul, args=(lhs, rhs, lhs_transposed, rhs_transposed))
                node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
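
The fusion relies on the fact that permuting only the last two dimensions of an operand before a matmul is equivalent to a matmul with a transposed contraction, which a backend can compute without materializing the permuted tensor. A standalone sketch of that identity (the actual trt_transposed_matmul kernel lives in TensorRT and is not reproduced here):

import torch

x = torch.randn(2, 3, 5)   # lhs; the graph permutes its last two dims
y = torch.randn(2, 3, 7)   # rhs
reference = torch.matmul(x.permute(0, 2, 1), y)                  # (2, 5, 7)
# (y^T @ x)^T == x^T @ y, so the transpose can be folded into the multiply.
folded = torch.matmul(y.transpose(-2, -1), x).transpose(-2, -1)
assert torch.allclose(reference, folded, atol=1e-5)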
Example #4
def fuse_permute_matmul(gm: torch.fx.GraphModule):
    """
    Fuse patterns like permute + matmul if the permute transposes the last two dimensions.
    """

    def check_permute(node: torch.fx.Node):
        ranks = len(node.meta["tensor_meta"].shape)
        permutation = list(i % ranks for i in node.kwargs["permutation"])  # type: ignore[union-attr]
        allowed_permutation = list(i for i in range(ranks))
        allowed_permutation[-1] = ranks - 2
        allowed_permutation[-2] = ranks - 1
        return len(node.users) == 1 and permutation == allowed_permutation

    for node in gm.graph.nodes:
        if node.target == acc_ops.matmul:
            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
            lhs_transposed = rhs_transposed = False

            if lhs.target == acc_ops.permute and check_permute(lhs):
                lhs_transposed = True
                lhs = lhs.kwargs["input"]

            if rhs.target == acc_ops.permute and check_permute(rhs):
                rhs_transposed = True
                rhs = rhs.kwargs["input"]

            if lhs_transposed or rhs_transposed:
                with gm.graph.inserting_before(node):
                    fused_node = gm.graph.call_function(trt_transposed_matmul, args=(lhs, rhs, lhs_transposed, rhs_transposed))
                node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
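
check_permute accepts a permute node only when its permutation swaps the last two dimensions, leaves every other dimension in place, and has a single user. A small standalone re-implementation of just the permutation test (the helper name _is_last_two_swap is hypothetical, used here only for illustration):

def _is_last_two_swap(permutation, rank):
    # Same check as check_permute above, minus the single-user condition.
    perm = [i % rank for i in permutation]
    allowed = list(range(rank))
    allowed[-1], allowed[-2] = rank - 2, rank - 1
    return perm == allowed

assert _is_last_two_swap((0, 2, 1), 3)        # swaps only the last two dims: matched
assert _is_last_two_swap((0, -1, -2), 3)      # negative dims normalize the same way
assert not _is_last_two_swap((2, 1, 0), 3)    # also moves the batch dim: not matched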
Example #5
    def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
        if node.op != "call_module":
            return node.target

        # Find the module that node.target points to
        m = dict(mod.named_modules())[node.target]
        return getattr(m, "_base_class_origin", type(m))
Example #6
    def __init__(self,
                 module: torch.fx.GraphModule,
                 input_shapes: List[InputTensorSpec],
                 logger_level=trt.Logger.WARNING):
        # Preprocess the model
        module = copy.deepcopy(module)
        module = module.cpu().float()
        module = NormalizeArgs(module).transform()
        super().__init__(module)

        self.logger = trt.Logger(logger_level)
        self.builder = trt.Builder(self.logger)

        # TODO: explicit batching
        # EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        # self.network = self.builder.create_network(EXPLICIT_BATCH)

        self.network = self.builder.create_network()

        self.input_shape_itr = iter(input_shapes)

        self._cur_node_name: Optional[str] = None

        self._input_names: List[str] = []
        self._output_names: List[str] = []
Example #7
def legalize_graph(gm: torch.fx.GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically sorted
    order of its input GraphModule, so that this order is restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {
        node: node.all_input_nodes.copy()
        for node in gm.graph.nodes
    }

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = torch.fx.Graph()
    value_remap: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node,
                                                    lambda n: value_remap[n])

    # Remove the copied over nodes from the adjacency list.
    for copied_node in value_remap.keys():
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            all_deps_copied = True
            for dep in deps:
                if dep not in value_remap:
                    all_deps_copied = False

            if all_deps_copied:
                value_remap[node] = new_graph.node_copy(
                    node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
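
A minimal usage sketch, assuming the legalize_graph above is in scope: after a transformation leaves nodes out of order, the graph can be rebuilt in topological order before recompiling.

import torch
import torch.fx

class Small(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x

gm = torch.fx.symbolic_trace(Small())
legalize_graph(gm)   # replaces gm.graph with a topologically sorted copy
gm.recompile()
print(gm(torch.randn(4)))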
Example #8
def fuse_permute_linear(gm: torch.fx.GraphModule):
    """
    Fuse patterns like permute + linear if the permute transposes the last two dimensions.
    """
    for node in gm.graph.nodes:
        if node.target == acc_ops.linear:
            inp = node.kwargs["input"]
            if inp.target == acc_ops.permute and check_permute(inp):
                inp = inp.kwargs["input"]
                weight = node.kwargs["weight"]
                bias = node.kwargs["bias"]
                with gm.graph.inserting_before(node):
                    fused_node = gm.graph.call_function(trt_transposed_linear, args=(inp, weight, bias))
                    node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
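
The replaced pattern computes linear(permute(x)) where the permute swaps the last two dimensions. The fusion is possible because the same result can be obtained without materializing the permuted activation; the sketch below shows one algebraically equivalent formulation (trt_transposed_linear's actual TensorRT implementation is not shown, this is only the underlying identity):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 6)   # activation before the permute
w = torch.randn(5, 8)      # (out_features, in_features)
b = torch.randn(5)

reference = F.linear(x.permute(0, 2, 1), w, b)       # permute, then linear: (2, 6, 5)
folded = torch.matmul(w, x).transpose(-2, -1) + b    # same result, x never permuted
assert torch.allclose(reference, folded, atol=1e-5)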
Example #9
def _remove_exceptions(gm: torch.fx.GraphModule) -> bool:
    """
    Unconditionally removes all call_modules to ConditionalExceptionWrappers
    found in GraphModule gm. Returns whether the graph is modified.
    """
    changed = False
    for node in gm.graph.nodes:
        if node.op == "call_module" and isinstance(
                gm.get_submodule(node.target), ConditionalExceptionWrapper):
            gm.graph.erase_node(node)
            changed = True
    return changed
Example #10
def verify_split_model(
    mod: torch.fx.GraphModule, acc_submodule_keyword: str = ACC_SUBMODULE_PREFIX, expected_number: int = 1,
) -> None:
    acc_submodule_num = 0
    for name, _ in mod.named_children():
        if name.startswith(acc_submodule_keyword):
            acc_submodule_num = acc_submodule_num + 1

    if acc_submodule_num < expected_number:
        raise RuntimeError(ERROR_MSG_NO_ACC_MODULE)
    elif acc_submodule_num > expected_number:
        raise RuntimeError(ERROR_MSG_MULTI_ACC_MODULES)
Example #11
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically sorted
    order of its input GraphModule, so that this order is restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module in-place sorted
    """
    indeg = {node: 0 for node in gm.graph.nodes}
    new_graph = torch.fx.Graph()
    # Track how many unfulfilled dependencies each node has
    for node in gm.graph.nodes:
        for user in node.users:
            indeg[user] += 1
    queue: collections.deque = collections.deque()
    # Add all nodes with no dependencies to the queue
    for node in gm.graph.nodes:
        if indeg[node] == 0:
            queue.append(node)
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    # Pop nodes from the queue, and add nodes that have had all their
    # dependencies fulfilled
    while len(queue) > 0:
        cur = queue.popleft()
        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                queue.append(user)
    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied.)
    if len(new_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(
            f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
        )
    gm.graph = new_graph
    return gm
Example #12
    def __init__(self,
                 module: torch.fx.GraphModule,
                 input_shapes: List[InputTensorSpec],
                 logger_level=trt.Logger.WARNING):
        # Preprocess the model
        module = copy.copy(module)
        module = module.cpu()
        module = NormalizeArgs(module).transform()
        super().__init__(module)

        self.logger = trt.Logger(logger_level)
        self.builder = trt.Builder(self.logger)
        self.network = self.builder.create_network()

        self.input_shape_itr = iter(input_shapes)

        self._cur_node_name: Optional[str] = None

        self._input_names: List[str] = []
        self._output_names: List[str] = []
Example #13
def fuse_sparse_matmul_add(gm: torch.fx.GraphModule):
    """
    Replace acc_ops.matmul + acc_ops.add with acc_ops.linear
    TRT8.2 can take advantage of structured sparsity (2:4), but the graph needs to contain a single FC layer.
    Later versions of TRT should work with matmul.

    Example before:
    def forward(self, x):
        a = self.a
        b = self.b
        addmm_mm = torch.fx.experimental.fx_acc.acc_ops.matmul(input = a, other = b);  a = b = None
        addmm_add = torch.fx.experimental.fx_acc.acc_ops.add(input = addmm_mm, other = x);  addmm_mm = x = None
        return addmm_add

    After:
    def forward(self, x):
        a = self.a
        b = self.b
        linear_1 = torch.fx.experimental.fx_acc.acc_ops.linear(input = a, weight = b, bias = x);  a = b = x = None
        return linear_1
    """
    counter = 0
    for node in gm.graph.nodes:
        if node.target != acc_ops.add:
            continue
        add_node = node
        bias = add_node.kwargs["other"]

        if bias.op != "get_attr":
            continue
        # test that bias tensor is one-dimensional, should correspond to shape (out_features)
        if get_attr(bias).dim() > 1:
            continue

        node = add_node.kwargs["input"]
        if node.target != acc_ops.matmul:
            continue
        matmul_node = node
        a = matmul_node.kwargs["input"]

        node = matmul_node.kwargs["other"]
        if node.op != "get_attr":
            continue

        get_attr_node = node
        weight = get_attr(get_attr_node)
        # TODO: verify that weight comply with TRT structured sparsity requirements:
        # For each output channel and for each spatial pixel in the kernel weights,
        # every 4 input channels must have at least 2 zeros.

        # test that weight tensor is two-dimensional, should correspond to shape (out_features, in_features)
        if weight.dim() != 2:
            continue

        weight_t = weight.transpose(0, 1)
        weight_t_name = "weight_t_tensor_" + str(counter)
        gm.register_buffer(weight_t_name, weight_t)
        counter += 1

        with gm.graph.inserting_before(add_node):
            weight_t_attr = gm.graph.get_attr(weight_t_name)
            fused_node = gm.graph.call_function(acc_ops.linear, kwargs={"input": a, "weight": weight_t_attr, "bias": bias})
        add_node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
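
The transposed weight is registered as a buffer because a linear layer multiplies by the transpose of its weight argument, so matmul(a, b) + x becomes linear(a, weight=b^T, bias=x). The identity in plain PyTorch, assuming acc_ops.linear follows the same convention as torch.nn.functional.linear:

import torch
import torch.nn.functional as F

a = torch.randn(6, 3)   # activation
b = torch.randn(3, 4)   # matmul operand, shape (in_features, out_features)
x = torch.randn(4)      # bias, shape (out_features,)

reference = torch.matmul(a, b) + x
as_linear = F.linear(a, b.transpose(0, 1), x)   # weight passed as (out_features, in_features)
assert torch.allclose(reference, as_linear, atol=1e-5)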
Example #14
def _split_nodes(traced_graph_module: torch.fx.GraphModule,
                 shard_count: int = 3) -> Dict:
    """Utility used to trace a graph and identify shard cutpoints."""

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far = []
    param_count: Dict[str, int] = {}
    shard_to_param_count = {}

    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
    for name, module in traced_graph_module.named_modules():
        name = name.replace(".", "_")
        param_count[name] = sum([x.numel() for x in module.parameters()])
    logging.info(f"Total number of params are {param_count['']}")
    per_shard_param = param_count[""] // shard_count
    logging.info(f"Per shard param count {per_shard_param}")

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in [
                "get_attr", "call_function", "call_method", "call_module"
        ]:

            min_shard_id = shard_id
            min_node_name = ""
            # For each of the args of a given node, find the arg that is not the
            # last node we traversed. This is to help us find skip connections
            # across shards.
            for arg in node.args:
                # If the node has args that are inputs to the forward function, they
                # may not have explicit names.
                if not hasattr(arg, "name"):
                    continue

                if arg.name in node_name_to_shard_id and arg.name != nodes_so_far[
                        -1]:
                    if node_name_to_shard_id[arg.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[arg.name]
                        min_node_name = arg.name

            # If there is an input that is not from the previous shard,
            # we collapse all the shards in between to be part of 1 shard.
            # and update the param count per shard accordingly.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.
                shard_to_param_count = _create_shard_to_param_count(
                    param_count, node_name_to_shard_id)

            # Update state that is tracking node -> shard id and shard id -> param count.
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            # TODO(anj): This could just be an update, we don't need to recreate the map.
            shard_to_param_count = _create_shard_to_param_count(
                param_count, node_name_to_shard_id)
            # If we have gone over the number of params per shard count that we want to
            # achieve, we should add a new shard.
            # The shard_id may not have been updated in the map if we are at a node that does not
            # have params.
            if shard_id in shard_to_param_count and shard_to_param_count[
                    shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id
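
A hedged usage sketch: it assumes _split_nodes and its helper _create_shard_to_param_count are importable from the module shown above, and simply traces a small model before asking for shard assignments.

import torch
import torch.fx

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
traced = torch.fx.symbolic_trace(model)
node_to_shard = _split_nodes(traced, shard_count=2)   # maps node name -> shard id
print(node_to_shard)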