def _execute_optimize(self, raw_graph):
  graph_searcher = GraphSearcher(raw_graph)

  # classification
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["t", "addmm"])])
  OptPass.merge_param_transpose_with_addmm(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["TupleConstruct", "TupleUnpack"])])
  OptPass.penetrate_pack_unpack(raw_graph, node_sets)

  # shufflenet
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["ListUnpack"]),
      PatternType(pattern=["TupleUnpack"])
  ])
  OptPass.unpack_ListUnpack_op(raw_graph, node_sets)

  # node_sets = graph_searcher.find_nodes_from_type([PatternType(pattern=["slice"])])
  # OptPass.slice_to_strided_slice(raw_graph, node_sets)

  # yolo_v3
  # node_sets = graph_searcher.find_nodes_from_type([PatternType(pattern=["select", "copy_"])])
  # OptPass.select_to_slice_inplace_copy(raw_graph, node_sets)

  # 3d pointpillar
  # node_sets = graph_searcher.find_nodes_from_type([PatternType(pattern=["strided_slice", "index_put_"])])
  # OptPass.stride_slice_to_index_inplace_put(raw_graph, node_sets)

  # node_sets = graph_searcher.find_nodes_from_type([PatternType(pattern=["strided_slice", "copy_"])])
  # OptPass.create_stride_slice_inplace_copy(raw_graph, node_sets)

  # nd(>2) linear (JIRA 2646)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["matmul", "add"]),
      PatternType(pattern=["matmul", "add_"])
  ])
  OptPass.merge_matmul_with_add(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["ListConstruct"]),
      PatternType(pattern=["TupleConstruct"])
  ])
  OptPass.pack_ListConstruct_op(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["embedding_bag"])])
  OptPass.strip_reduantant_tensors_in_embedding_bag(raw_graph, node_sets)

  return raw_graph
def FuseBnToConv(self):
  # find fusable batchnorm nodes
  fuse_bn_handler = ConvBnHandler()
  graph_searcher = GraphSearcher(self._graph)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler)
  ])

  # the pattern action folds BN parameters into the preceding conv,
  # so here we only drop the now-redundant BN nodes
  for id, node_list in node_sets.items():
    for nodeset in node_list:
      _, bn_node = nodeset
      self._graph.remove_node(bn_node)
def collect_layer_act_pair(self):
  graph_searcher = GraphSearcher(self.graph)
  patterns = []
  tuning_ops = [
      NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.CONVTRANSPOSE2D,
      NNDCT_OP.CONV3D, NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.CONVTRANSPOSE3D
  ]
  tail_act_ops = [
      NNDCT_OP.RELU, NNDCT_OP.RELU6, NNDCT_OP.HSWISH, NNDCT_OP.HSIGMOID
  ]
  for tuning_op in tuning_ops:
    for act_op in tail_act_ops:
      patterns.append(PatternType(pattern=[tuning_op, act_op]))

  node_sets = graph_searcher.find_nodes_from_type(patterns)
  layer_act_group = {}
  for _, node_list in node_sets.items():
    for nodeset in node_list:
      conv, act = nodeset
      layer_act_group[conv] = act
  return layer_act_group
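# Illustrative sketch (not part of the original source): how the mapping
# returned by collect_layer_act_pair() might be consumed. It relies only on
# plain dict iteration and the `node.name` attribute used elsewhere in this
# code base; `report_layer_act_pairs` is a hypothetical helper name.
def report_layer_act_pairs(layer_act_group):
  # each key is a conv-like node, each value is its trailing activation node
  for conv_node, act_node in layer_act_group.items():
    print(f"{conv_node.name} -> {act_node.name}")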
def finetune(self, run_fn, run_args):
  if self.quantizer.quant_mode == 2:
    NndctScreenLogger().warning(
        f"Finetune function will be ignored in test mode!")
    return

  NndctScreenLogger().info(
      f"=>Finetuning module parameters for better quantization accuracy...")

  # backup option value
  opt_bak_param_corr = NndctOption.nndct_param_corr.value
  set_option_value("nndct_param_corr", 0)

  # cache input and output
  #print("**** cache input and output")
  last_quant_nodes = self.collect_last_quant_nodes()
  with torch.no_grad():
    hook_mods = []
    for node in self.graph.nodes:
      if node.op.type == NNDCT_OP.INPUT or \
          node in last_quant_nodes:
        # (self.quantizer.configer.is_node_quantizable(node, False) and
        #  len(node.op.params) > 0):
        hook_mods.append(node.module)
    handlers = self.hook_cache_output(hook_mods)

    set_option_value("nndct_quant_off", True)
    run_fn(*run_args)
    self.clean_hooks(handlers)

    # for mod in self.quant_model.modules():
    #   if hasattr(mod, "node") and mod.node.op.type in [NNDCT_OP.DENSE, NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.CONVTRANSPOSE2D]:
    #     self._float_weights[mod.node].append(mod.weight.detach().cpu())

  torch.cuda.empty_cache()

  # calibration to get a set of quantization steps
  #print("****calibration to get float model tensor values")
  for mod in self.quant_model.modules():
    if hasattr(mod, "param_quantized"):
      setattr(mod, "param_quantized", False)

  # evaluation to get float model tensors
  set_option_value("nndct_quant_off", False)
  with torch.no_grad():
    run_fn(*run_args)
  torch.cuda.empty_cache()

  #print("****Parameter finetuning")
  device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
  graph_searcher = GraphSearcher(self.graph)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.RELU])
  ])

  layer_act_group = {}
  for _, node_list in node_sets.items():
    for nodeset in node_list:
      conv, act = nodeset
      layer_act_group[conv] = act

  # to avoid quantization steps change among parameter finetuning
  self.quantizer.quant_mode = 2

  net_inputs = []
  for node in self.input_nodes:
    cached_net_input = [out for out in self.cached_outputs[node.module]]
    net_inputs.append(cached_net_input)

  # last_quant_nodes = self.collect_last_quant_nodes()
  last_quant_mods = [node.module for node in last_quant_nodes]
  handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
  net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
  self.clean_hooks(handlers)
  # model.clean_hooks()
  torch.cuda.empty_cache()

  finetune_group = {}
  # hook_mods = []
  for qmod, fmod in zip(self._quant_model.modules(),
                        self._float_model.modules()):
    if hasattr(qmod, "node"):
      if (self.quantizer.configer.is_node_quantizable(qmod.node, False) and
          len(qmod.node.op.params) > 0):
        finetune_group[qmod.node] = [qmod, fmod]
        # hook_mods.append(fmod)
  # self.hook_cache_output(hook_mods, hook_type="single")

  for node, module_pair in finetune_group.items():
    # if self.quantizer.configer.is_node_quantizable(node, False) and \
    #     len(node.op.params) > 0:
    quant_layer, float_layer = module_pair
    pn_node = self.graph.parents(node)[0]
    handlers = self.hook_cache_output([pn_node.module], hook_type="single")
    layer_inputs = []
    with torch.no_grad():
      for input_args in zip(*net_inputs):
        new_input_args = []
        for ip in input_args:
          if isinstance(ip, torch.Tensor):
            new_input_args.append(ip.to(device))
        _ = self.quant_model(*new_input_args)
        layer_inputs.append(self.cached_output[pn_node.module].detach().cpu())
    self.clean_hooks(handlers)
    del self.cached_output[pn_node.module]
    #print(f"Tuning {node.name}")
    net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                   layer_act_group, net_inputs, net_loss,
                                   last_quant_mods, device)
    del layer_inputs
    torch.cuda.empty_cache()

  # recover quantizer status
  for node in self.graph.nodes:
    for _, fp_history in self.quantizer.fp_history.items():
      if node.name in fp_history:
        fp_history[node.name].clear()

  for mod in self.quant_model.modules():
    if hasattr(mod, "param_quantized"):
      setattr(mod, "param_quantized", False)

  for mod in self.quant_model.modules():
    if hasattr(mod, "param_saved"):
      setattr(mod, "param_saved", False)

  self.quantizer.quant_mode = 1
  set_option_value("nndct_param_corr", opt_bak_param_corr)

  # export finetuned parameters
  self.quantizer.export_param()
def FuseBnToConv(self):
  # find fusable batchnorm nodes
  fuse_bn_handler = ConvBnHandler()
  graph_searcher = GraphSearcher(self._graph)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONV3D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONV3D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.CONCAT, NNDCT_OP.BATCH_NORM],
                  action=fuse_bn_handler),
  ])

  # the pattern action folds BN parameters into the preceding conv;
  # remove each merged BN node exactly once
  removed_bn = set()
  for id, node_list in node_sets.items():
    for nodeset in node_list:
      bn_node = nodeset[-1]
      if bn_node.merged and bn_node not in removed_bn:
        self._graph.remove_node(bn_node)
        removed_bn.add(bn_node)
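# Illustrative sketch (assumption, not original source): the search idiom the
# passes above share. GraphSearcher, PatternType, find_nodes_from_type, and the
# node_sets iteration shape are taken from this code base; `find_conv_bn_pairs`
# is a hypothetical helper that matches conv/BN pairs without fusing them.
def find_conv_bn_pairs(graph):
  graph_searcher = GraphSearcher(graph)
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.BATCH_NORM])])
  pairs = []
  # node_sets maps each pattern to a list of matched node sequences,
  # ordered as in the pattern (conv first, BN last)
  for _, node_list in node_sets.items():
    for nodeset in node_list:
      conv_node, bn_node = nodeset[0], nodeset[-1]
      pairs.append((conv_node, bn_node))
  return pairs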
def finetune(self, run_fn, run_args):
  if self.quantizer.quant_mode == 2:
    NndctScreenLogger().warning(
        f"Finetune function will be ignored in test mode!")
    return

  NndctScreenLogger().info(
      f"=>Preparing data for fast finetuning module parameters ...")

  # memory status
  total_m, *_, available_m = list(
      map(lambda x: x / 1024,
          map(int, os.popen('free -t -m').readlines()[1].split()[1:])))
  NndctScreenLogger().info(
      f"Mem status (total mem: {total_m:.2f}G, available mem: {available_m:.2f}G).")

  # backup option value
  opt_bak_param_corr = NndctOption.nndct_param_corr.value
  set_option_value("nndct_param_corr", 0)

  # cache input and output
  #print("**** cache input and output")
  last_quant_nodes = self.collect_last_quant_nodes()
  with torch.no_grad():
    cache_layers = []
    monitor_layers = []
    for node in self.graph.nodes:
      if node.op.type == NNDCT_OP.INPUT or node in last_quant_nodes:
        cache_layers.append(node.module)
      elif self.quantizer.configer.is_conv_like(node):
        monitor_layers.append(node.module)
    monitor_handlers = self.hook_memory_monitor(monitor_layers)
    cache_handlers = self.hook_cache_output(cache_layers, monitor_mem=True)

    set_option_value("nndct_quant_off", True)
    run_fn(*run_args)

    # memory statistics
    total_memory_cost = 0.0
    for layer in cache_layers:
      total_memory_cost += self._mem_count[layer]
      del self._mem_count[layer]
    total_memory_cost += 2 * max(self._mem_count.values())
    self.clean_hooks(monitor_handlers + cache_handlers)

  torch.cuda.empty_cache()
  NndctScreenLogger().info(
      f"Mem cost by fast finetuning: {total_memory_cost:.2f}G.")
  if total_memory_cost > 0.8 * available_m:
    NndctScreenLogger().warning(
        f"There is not enough memory for fast finetuning and this process will be ignored! Try to use a smaller calibration dataset.")
    return

  # calibration to get a set of quantization steps
  #print("****calibration to get float model tensor values")
  for mod in self.quant_model.modules():
    if hasattr(mod, "param_quantized"):
      setattr(mod, "param_quantized", False)

  # evaluation to get float model tensors
  set_option_value("nndct_quant_off", False)
  with torch.no_grad():
    run_fn(*run_args)
  torch.cuda.empty_cache()

  #print("****Parameter finetuning")
  NndctScreenLogger().info(
      f"=>Fast finetuning module parameters for better quantization accuracy...")
  device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
  graph_searcher = GraphSearcher(self.graph)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.HSWISH]),
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.HSIGMOID]),
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.HSWISH]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.HSIGMOID]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.HSWISH]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.HSIGMOID]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.CONV3D, NNDCT_OP.HSWISH]),
      PatternType(pattern=[NNDCT_OP.CONV3D, NNDCT_OP.HSIGMOID]),
      PatternType(pattern=[NNDCT_OP.CONV3D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.CONV3D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.HSWISH]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.HSIGMOID]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.RELU6]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.HSWISH]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.HSIGMOID]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.RELU]),
      PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.RELU6]),
  ])

  layer_act_group = {}
  for _, node_list in node_sets.items():
    for nodeset in node_list:
      conv, act = nodeset
      layer_act_group[conv] = act

  # to avoid quantization steps change among parameter finetuning
  self.quantizer.quant_mode = 2

  net_inputs = []
  for node in self.input_nodes:
    cached_net_input = [out for out in self.cached_outputs[node.module]]
    net_inputs.append(cached_net_input)

  # last_quant_nodes = self.collect_last_quant_nodes()
  last_quant_mods = [node.module for node in last_quant_nodes]
  handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
  net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
  self.clean_hooks(handlers)
  # model.clean_hooks()
  torch.cuda.empty_cache()

  finetune_group = {}
  # hook_mods = []
  for qmod, fmod in zip(self._quant_model.modules(),
                        self._float_model.modules()):
    if hasattr(qmod, "node"):
      if (self.quantizer.configer.is_node_quantizable(qmod.node, False) and
          len(qmod.node.op.params) > 0):
        finetune_group[qmod.node] = [qmod, fmod]
        # hook_mods.append(fmod)
  # self.hook_cache_output(hook_mods, hook_type="single")

  #for node, module_pair in finetune_group.items():
  for idx, (node, module_pair) in tqdm(enumerate(finetune_group.items()),
                                       total=len(finetune_group.items())):
    # if self.quantizer.configer.is_node_quantizable(node, False) and \
    #     len(node.op.params) > 0:
    quant_layer, float_layer = module_pair
    pn_node = self.graph.parents(node)[0]
    handlers = self.hook_cache_output([pn_node.module], hook_type="single")
    layer_inputs = []
    with torch.no_grad():
      for input_args in zip(*net_inputs):
        new_input_args = []
        for ip in input_args:
          if isinstance(ip, torch.Tensor):
            new_input_args.append(ip.to(device))
        _ = self.quant_model(*new_input_args)
        layer_inputs.append(self.cached_output[pn_node.module].detach().cpu())
    self.clean_hooks(handlers)
    del self.cached_output[pn_node.module]
    #print(f"Tuning {node.name}")
    net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                   layer_act_group, net_inputs, net_loss,
                                   last_quant_mods, device)
    # print(f"{node.name}:{net_loss}")
    del layer_inputs
    torch.cuda.empty_cache()

  # recover quantizer status
  for node in self.graph.nodes:
    for _, config_history in self.quantizer.config_history.items():
      if node.name in config_history:
        config_history[node.name].clear()

  for mod in self.quant_model.modules():
    if hasattr(mod, "param_quantized"):
      setattr(mod, "param_quantized", False)

  for mod in self.quant_model.modules():
    if hasattr(mod, "param_saved"):
      setattr(mod, "param_saved", False)

  self.quantizer.quant_mode = 1
  set_option_value("nndct_param_corr", opt_bak_param_corr)

  NndctScreenLogger().info(f"=>Export fast finetuned parameters ...")
  # export finetuned parameters
  self.quantizer.export_param()
def _opt_raw_graph(self, raw_graph):
  graph_searcher = GraphSearcher(raw_graph)

  # torch 1.8.x will generate a redundant type_as op before element-wise add.
  if get_torch_version() >= 180 and get_torch_version() < 190:
    node_sets = graph_searcher.find_nodes_from_type(
        [PatternType(pattern=["type_as", "add"])])
    OptPass.merge_internal_type_as(raw_graph, node_sets)

  # classification
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["t", "addmm"]),
      PatternType(pattern=["t", "matmul"])
  ])
  OptPass.merge_param_transpose_with_addmm(raw_graph, node_sets)

  # shufflenet
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["ListUnpack"]),
      PatternType(pattern=["TupleUnpack"])
  ])
  OptPass.unpack_ListUnpack_op(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["slice"])])
  OptPass.slice_to_strided_slice(raw_graph, node_sets)

  # yolo_v3
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["select", "copy_"])])
  OptPass.select_to_slice_inplace_copy(raw_graph, node_sets)

  # FADnet
  while True:
    node_sets = graph_searcher.find_nodes_from_type(
        [PatternType(pattern=["select", "strided_slice"])])
    if not OptPass.merge_select_to_strided_slice(raw_graph, node_sets):
      break

  # FADnet
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["strided_slice", "strided_slice"])])
  OptPass.merge_consecutive_strided_slice(raw_graph, node_sets)

  # 3d pointpillar
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["strided_slice", "index_put_"])])
  OptPass.stride_slice_to_index_inplace_put(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["strided_slice", "copy_"])])
  OptPass.create_stride_slice_inplace_copy(raw_graph, node_sets)

  # nd(>2) linear (JIRA 2646)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["matmul", "add"]),
      PatternType(pattern=["matmul", "add_"])
  ])
  OptPass.merge_matmul_with_add(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["ListConstruct"]),
      PatternType(pattern=["TupleConstruct"])
  ])
  OptPass.pack_ListConstruct_op(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["embedding_bag"])])
  OptPass.strip_reduantant_tensors_in_embedding_bag(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["empty", "zero_"])])
  OptPass.merge_empty_with_zero(raw_graph, node_sets)

  # node deletion should be done after strided_slice merging:
  # delete redundant view (FADnet).
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["view"])])
  OptPass.remove_reduantant_view(raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["add"]),
      PatternType(pattern=["sub"]),
      PatternType(pattern=["mul"]),
      PatternType(pattern=["div"])
  ])
  OptPass.transform_const_scalar_to_const_tensor(raw_graph, node_sets)

  return raw_graph
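# Illustrative sketch (assumption, not the library's actual implementation):
# one way a get_torch_version()-style helper could produce the integer form
# (e.g. 180 for torch 1.8.0) that the version gate above compares against.
# This simple encoding assumes single-digit minor and patch numbers.
import re
import torch

def get_torch_version_int():  # hypothetical name; the source calls get_torch_version()
  # keep only the leading "major.minor.patch" digits, dropping suffixes like "+cu113"
  major, minor, patch = re.match(r"(\d+)\.(\d+)\.(\d+)", torch.__version__).groups()
  return int(major) * 100 + int(minor) * 10 + int(patch)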
def _opt_raw_graph(self, raw_graph):
  graph_searcher = GraphSearcher(raw_graph)

  # shufflenet
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["ListUnpack"]),
      PatternType(pattern=["TupleUnpack"])
  ])
  OptPass.unpack_ListUnpack_op(self, raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["slice"])])
  OptPass.slice_to_strided_slice(self, raw_graph, node_sets)

  # yolo_v3
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["select", "copy_"])])
  OptPass.select_to_slice_inplace_copy(self, raw_graph, node_sets)

  # 3d pointpillar
  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["strided_slice", "index_put_"])])
  OptPass.stride_slice_to_index_inplace_put(self, raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["strided_slice", "copy_"])])
  OptPass.create_stride_slice_inplace_copy(self, raw_graph, node_sets)

  # nd(>2) linear (JIRA 2646)
  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["matmul", "add"]),
      PatternType(pattern=["matmul", "add_"])
  ])
  OptPass.merge_matmul_with_add(self, raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type([
      PatternType(pattern=["ListConstruct"]),
      PatternType(pattern=["TupleConstruct"])
  ])
  OptPass.pack_ListConstruct_op(self, raw_graph, node_sets)

  node_sets = graph_searcher.find_nodes_from_type(
      [PatternType(pattern=["embedding_bag"])])
  OptPass.strip_reduantant_tensors_in_embedding_bag(self, raw_graph, node_sets)

  return raw_graph