def fuse_unsqueeze_cat_sum(gm: torch.fx.GraphModule):
    """
    Fuse the unsqueeze(dim=0) + cat(dim=0) + sum pattern over exactly two tensors
    into a single elementwise add.
    """
    for node in gm.graph.nodes:
        if node.target != acc_ops.sum:
            continue
        prev_node = node.kwargs["input"]
        if prev_node.target != acc_ops.cat or len(prev_node.kwargs["tensors"]) != 2:
            continue
        lhs, rhs = prev_node.kwargs["tensors"][0], prev_node.kwargs["tensors"][1]
        if lhs.target != acc_ops.unsqueeze or rhs.target != acc_ops.unsqueeze:
            continue
        lhs_input = lhs.kwargs["input"]
        rhs_input = rhs.kwargs["input"]
        # Prerequisite check: both unsqueezes and the cat must operate on dim 0.
        cond1 = lhs.kwargs["dim"] == 0 and rhs.kwargs["dim"] == 0
        cond2 = prev_node.kwargs["dim"] == 0
        if not cond1 or not cond2:
            continue
        with gm.graph.inserting_before(node):
            fused_node = gm.graph.call_function(
                acc_ops.add, kwargs={"input": lhs_input, "other": rhs_input}
            )
            node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
def fuse_unsqueeze_cat_sum(gm: torch.fx.GraphModule):
    """
    Fuse the unsqueeze(dim=0) + cat(dim=0) + sum pattern over an arbitrary number of
    tensors into a chain of elementwise adds.
    """
    for node in gm.graph.nodes:
        if node.target != acc_ops.sum:
            continue
        prev_node = node.kwargs["input"]
        if prev_node.target != acc_ops.cat or prev_node.kwargs["dim"] != 0:
            continue
        cat_inputs = prev_node.kwargs["tensors"]
        valid_pass = True
        for i in cat_inputs:
            if i.target != acc_ops.unsqueeze or i.kwargs["dim"] != 0:
                valid_pass = False
                break

        if not valid_pass:
            continue

        input_val = [i.kwargs["input"] for i in cat_inputs]

        with gm.graph.inserting_before(node):
            left = input_val[0]
            for i in range(1, len(input_val)):
                right = input_val[i]
                fused_node = gm.graph.call_function(
                    acc_ops.add, kwargs={"input": left, "other": right}
                )
                left = fused_node
            node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
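# Hedged sketch (not part of the pass above): the rewrite is valid because stacking
# tensors along a new leading dim and summing over that dim is just an elementwise
# add of the inputs. A plain-PyTorch sanity check:
import torch

a, b = torch.randn(3, 4), torch.randn(3, 4)
stacked = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)  # shape (2, 3, 4)
assert torch.allclose(stacked.sum(dim=0), a + b)  # same result as the fused add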
def fuse_permute_matmul(gm: torch.fx.GraphModule):
    """
    Fuse pattern like permute + matmul if permute is transposing the last two dimensions.
    """
    for node in gm.graph.nodes:
        if node.target == acc_ops.matmul:
            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
            lhs_transposed = rhs_transposed = False

            if lhs.target == acc_ops.permute and check_permute(lhs):
                lhs_transposed = True
                lhs = lhs.kwargs["input"]

            if rhs.target == acc_ops.permute and check_permute(rhs):
                rhs_transposed = True
                rhs = rhs.kwargs["input"]

            if lhs_transposed or rhs_transposed:
                with gm.graph.inserting_before(node):
                    fused_node = gm.graph.call_function(
                        trt_transposed_matmul,
                        args=(lhs, rhs, lhs_transposed, rhs_transposed),
                    )
                node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
def fuse_permute_matmul(gm: torch.fx.GraphModule):
    """
    Fuse pattern like permute + matmul if permute is transposing the last two dimensions.
    """

    def check_permute(node: torch.fx.Node):
        ranks = len(node.meta["tensor_meta"].shape)
        permutation = list(i % ranks for i in node.kwargs["permutation"])  # type: ignore[union-attr]
        # The only allowed permutation is one that swaps the last two dimensions.
        allowed_permutation = list(i for i in range(ranks))
        allowed_permutation[-1] = ranks - 2
        allowed_permutation[-2] = ranks - 1
        return len(node.users) == 1 and permutation == allowed_permutation

    for node in gm.graph.nodes:
        if node.target == acc_ops.matmul:
            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
            lhs_transposed = rhs_transposed = False

            if lhs.target == acc_ops.permute and check_permute(lhs):
                lhs_transposed = True
                lhs = lhs.kwargs["input"]

            if rhs.target == acc_ops.permute and check_permute(rhs):
                rhs_transposed = True
                rhs = rhs.kwargs["input"]

            if lhs_transposed or rhs_transposed:
                with gm.graph.inserting_before(node):
                    fused_node = gm.graph.call_function(
                        trt_transposed_matmul,
                        args=(lhs, rhs, lhs_transposed, rhs_transposed),
                    )
                node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
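# Hedged illustration (not from the source): check_permute only accepts a permutation
# that swaps the last two dims. For a rank-3 tensor that is (0, 2, 1); negative indices
# such as permute(x, 0, -1, -2) normalize to the same list.
ranks = 3
allowed = list(range(ranks))
allowed[-1], allowed[-2] = ranks - 2, ranks - 1  # -> [0, 2, 1]
assert [i % ranks for i in (0, -1, -2)] == allowed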
def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
    if node.op != "call_module":
        return node.target

    # Find the module that node.target points to
    m = dict(mod.named_modules())[node.target]
    return getattr(m, "_base_class_origin", type(m))
def __init__(
    self,
    module: torch.fx.GraphModule,
    input_shapes: List[InputTensorSpec],
    logger_level=trt.Logger.WARNING,
):
    # Preprocess the model
    module = copy.deepcopy(module)
    module = module.cpu().float()
    module = NormalizeArgs(module).transform()
    super().__init__(module)

    self.logger = trt.Logger(logger_level)
    self.builder = trt.Builder(self.logger)

    # TODO: explicit batching
    # EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    # self.network = self.builder.create_network(EXPLICIT_BATCH)
    self.network = self.builder.create_network()

    self.input_shape_itr = iter(input_shapes)

    self._cur_node_name: Optional[str] = None

    self._input_names: List[str] = []
    self._output_names: List[str] = []
def legalize_graph(gm: torch.fx.GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically
    sorted order of its input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.
    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = torch.fx.Graph()
    value_remap: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

    # Remove the copied over nodes from the adjacency list.
    for copied_node in value_remap.keys():
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            all_deps_copied = True
            for dep in deps:
                if dep not in value_remap:
                    all_deps_copied = False

            if all_deps_copied:
                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
def fuse_permute_linear(gm: torch.fx.GraphModule):
    """
    Fuse pattern like permute + linear if permute is transposing the last two dimensions.
    """
    for node in gm.graph.nodes:
        if node.target == acc_ops.linear:
            inp = node.kwargs["input"]
            if inp.target == acc_ops.permute and check_permute(inp):
                inp = inp.kwargs["input"]
                weight = node.kwargs["weight"]
                bias = node.kwargs["bias"]
                with gm.graph.inserting_before(node):
                    fused_node = gm.graph.call_function(
                        trt_transposed_linear, args=(inp, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
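# Hedged usage sketch: these fusion passes are typically chained on a GraphModule that
# has already been traced into acc_ops form (the tracer entry point is an assumption
# and varies across releases), e.g.:
def apply_trt_fusions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for fusion_pass in (fuse_permute_matmul, fuse_permute_linear, fuse_unsqueeze_cat_sum):
        gm = fusion_pass(gm)
    return gm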
def _remove_exceptions(gm: torch.fx.GraphModule) -> bool:
    """
    Unconditionally removes all call_modules to ConditionalExceptionWrappers
    found in GraphModule gm. Returns whether the graph is modified.
    """
    changed = False
    for node in gm.graph.nodes:
        if node.op == "call_module" and isinstance(
            gm.get_submodule(node.target), ConditionalExceptionWrapper
        ):
            gm.graph.erase_node(node)
            changed = True
    return changed
def verify_split_model(
    mod: torch.fx.GraphModule,
    acc_submodule_keyword: str = ACC_SUBMODULE_PREFIX,
    expected_number: int = 1,
) -> None:
    """
    Verify that the split module contains the expected number of acc submodules.
    """
    acc_submodule_num = 0
    for name, _ in mod.named_children():
        if name.startswith(acc_submodule_keyword):
            acc_submodule_num += 1

    if acc_submodule_num < expected_number:
        raise RuntimeError(ERROR_MSG_NO_ACC_MODULE)
    elif acc_submodule_num > expected_number:
        raise RuntimeError(ERROR_MSG_MULTI_ACC_MODULES)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically
    sorted order of its input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module, sorted in-place.
    """
    indeg = {node: 0 for node in gm.graph.nodes}
    new_graph = torch.fx.Graph()
    # Track how many unfulfilled dependencies each node has
    for node in gm.graph.nodes:
        for user in node.users:
            indeg[user] += 1
    queue: collections.deque = collections.deque()
    # Add all nodes with no dependencies to the queue
    for node in gm.graph.nodes:
        if indeg[node] == 0:
            queue.append(node)
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    # Pop nodes from the queue, and add nodes that have had all their
    # dependencies fulfilled
    while len(queue) > 0:
        cur = queue.popleft()
        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                queue.append(user)
    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied.)
    if len(new_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(
            f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
        )
    gm.graph = new_graph
    return gm
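# Hedged sketch: on an already-sorted graph, legalize_graph should be semantically a
# no-op, so a cheap sanity check is to trace a tiny module, sort it, and compare outputs.
import torch

class _Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x

gm = torch.fx.symbolic_trace(_Tiny())
x = torch.randn(4)
before = gm(x)
legalize_graph(gm)
gm.recompile()
assert torch.equal(gm(x), before)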
def __init__(
    self,
    module: torch.fx.GraphModule,
    input_shapes: List[InputTensorSpec],
    logger_level=trt.Logger.WARNING,
):
    # Preprocess the model
    module = copy.copy(module)
    module = module.cpu()
    module = NormalizeArgs(module).transform()
    super().__init__(module)

    self.logger = trt.Logger(logger_level)
    self.builder = trt.Builder(self.logger)
    self.network = self.builder.create_network()

    self.input_shape_itr = iter(input_shapes)

    self._cur_node_name: Optional[str] = None

    self._input_names: List[str] = []
    self._output_names: List[str] = []
def fuse_sparse_matmul_add(gm: torch.fx.GraphModule):
    """
    Replace acc_ops.matmul + acc_ops.add with acc_ops.linear.
    TRT 8.2 can take advantage of structured sparsity (2:4), but the graph needs to
    contain a single FC layer. Later versions of TRT should work with matmul.

    Example before:
        def forward(self, x):
            a = self.a
            b = self.b
            addmm_mm = torch.fx.experimental.fx_acc.acc_ops.matmul(input = a, other = b);  a = b = None
            addmm_add = torch.fx.experimental.fx_acc.acc_ops.add(input = addmm_mm, other = x);  addmm_mm = x = None
            return addmm_add

    After:
        def forward(self, x):
            a = self.a
            b = self.b
            linear_1 = torch.fx.experimental.fx_acc.acc_ops.linear(input = a, weight = b, bias = x);  a = b = x = None
            return linear_1
    """
    counter = 0
    for node in gm.graph.nodes:
        if node.target != acc_ops.add:
            continue
        add_node = node
        bias = add_node.kwargs["other"]

        if bias.op != "get_attr":
            continue
        # Test that the bias tensor is one-dimensional; it should correspond to shape (out_features,).
        if get_attr(bias).dim() > 1:
            continue

        node = add_node.kwargs["input"]
        if node.target != acc_ops.matmul:
            continue
        matmul_node = node
        a = matmul_node.kwargs["input"]

        node = matmul_node.kwargs["other"]
        if node.op != "get_attr":
            continue

        get_attr_node = node
        weight = get_attr(get_attr_node)
        # TODO: verify that weight complies with TRT structured sparsity requirements:
        # for each output channel and for each spatial pixel in the kernel weights,
        # every 4 input channels must have at least 2 zeros.

        # Test that the weight tensor is two-dimensional; it should correspond to shape (out_features, in_features).
        if weight.dim() != 2:
            continue

        weight_t = weight.transpose(0, 1)
        weight_t_name = "weight_t_tensor_" + str(counter)
        gm.register_buffer(weight_t_name, weight_t)
        counter += 1

        with gm.graph.inserting_before(add_node):
            weight_t_attr = gm.graph.get_attr(weight_t_name)
            fused_node = gm.graph.call_function(
                acc_ops.linear,
                kwargs={"input": a, "weight": weight_t_attr, "bias": bias},
            )
        add_node.replace_all_uses_with(fused_node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
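# Hedged sketch (plain PyTorch, not the acc_ops themselves): the rewrite relies on
# matmul(a, b) + bias == linear(a, b.T, bias), which is why the pass registers the
# transposed weight buffer before emitting acc_ops.linear.
import torch
import torch.nn.functional as F

a, b, bias = torch.randn(5, 3), torch.randn(3, 4), torch.randn(4)
assert torch.allclose(a @ b + bias, F.linear(a, b.transpose(0, 1), bias))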
def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
    """Utility used to trace a graph and identify shard cutpoints."""

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far = []
    param_count: Dict[str, int] = {}
    shard_to_param_count = {}

    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
    for name, module in traced_graph_module.named_modules():
        name = name.replace(".", "_")
        param_count[name] = sum([x.numel() for x in module.parameters()])

    logging.info(f"Total number of params are {param_count['']}")
    per_shard_param = param_count[""] // shard_count
    logging.info(f"Per shard param count {per_shard_param}")

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in ["get_attr", "call_function", "call_method", "call_module"]:
            min_shard_id = shard_id
            min_node_name = ""
            # For each of the args of a given node, find the arg that is not the
            # last node we traversed. This is to help us find skip connections
            # across shards.
            for arg in node.args:
                # If the node has args that are inputs to the forward function, they
                # may not have explicit names.
                if not hasattr(arg, "name"):
                    continue

                if arg.name in node_name_to_shard_id and arg.name != nodes_so_far[-1]:
                    if node_name_to_shard_id[arg.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[arg.name]
                        min_node_name = arg.name

            # If there is an input that is not from the previous shard,
            # we collapse all the shards in between to be part of 1 shard
            # and update the param count per shard accordingly.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.
                shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)

            # Update state that is tracking node -> shard id and shard id -> param count.
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            # TODO(anj): This could just be an update, we don't need to recreate the map.
            shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
            # If we have gone over the number of params per shard count that we want to
            # achieve, we should add a new shard.
            # The shard_id may not have been updated in the map if we are at a node that does not
            # have params.
            if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id
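# Hedged usage sketch: the shard map produced above can be fed to
# torch.fx.passes.split_module to actually partition the traced module into one
# submodule per shard id (the .get(..., 0) fallback is an assumption for nodes the
# heuristic never assigns, e.g. outputs).
from torch.fx.passes.split_module import split_module

def shard_model(model: torch.nn.Module, shard_count: int = 3) -> torch.fx.GraphModule:
    traced = torch.fx.symbolic_trace(model)
    shard_map = _split_nodes(traced, shard_count=shard_count)
    return split_module(traced, model, lambda node: shard_map.get(node.name, 0))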