示例#1
0
    def _execute_optimize(self, raw_graph):
        """Run the enabled raw-graph optimization passes over ``raw_graph``.

        Each pass gets a fresh ``find_nodes_from_type`` query issued right
        before it runs, so later queries happen after earlier passes have been
        applied; the order of the ``run_pass`` calls below is significant.

        Returns:
            The same ``raw_graph`` object that was handed to every pass.
        """
        searcher = GraphSearcher(raw_graph)

        def run_pass(opt_pass, *patterns):
            # One PatternType per pattern list; match, then hand hits to the pass.
            matches = searcher.find_nodes_from_type(
                [PatternType(pattern=pattern) for pattern in patterns])
            opt_pass(raw_graph, matches)

        # classification
        run_pass(OptPass.merge_param_transpose_with_addmm, ["t", "addmm"])
        run_pass(OptPass.penetrate_pack_unpack,
                 ["TupleConstruct", "TupleUnpack"])

        # shufflenet
        run_pass(OptPass.unpack_ListUnpack_op, ["ListUnpack"], ["TupleUnpack"])

        # Intentionally disabled in this pipeline: slice -> strided_slice,
        # select+copy_ (yolo_v3), strided_slice+index_put_ and
        # strided_slice+copy_ (3d pointpillar).

        # nd(>2) linear (JIRA 2646)
        run_pass(OptPass.merge_matmul_with_add, ["matmul", "add"],
                 ["matmul", "add_"])

        run_pass(OptPass.pack_ListConstruct_op, ["ListConstruct"],
                 ["TupleConstruct"])
        run_pass(OptPass.strip_reduantant_tensors_in_embedding_bag,
                 ["embedding_bag"])

        return raw_graph
示例#2
0
 def FuseBnToConv(self):
     """Find [conv-like, BATCH_NORM] pairs in ``self._graph`` and drop the BN.

     Every pattern shares one ``ConvBnHandler`` as its ``action`` (presumably
     invoked by the searcher to fold the BN into the conv -- confirm in
     GraphSearcher). Each matched BATCH_NORM node is then removed from
     ``self._graph``, at most once.
     """
     # find fusable batchnorm node
     fuse_bn_handler = ConvBnHandler()
     graph_searcher = GraphSearcher(self._graph)
     node_sets = graph_searcher.find_nodes_from_type([
         PatternType(pattern=[conv_type, NNDCT_OP.BATCH_NORM],
                     action=fuse_bn_handler)
         for conv_type in (NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
                           NNDCT_OP.CONVTRANSPOSE2D)
     ])
     # NOTE(review): the BN node is removed unconditionally, i.e. this assumes
     # the handler always succeeds in folding it -- confirm.
     # Track removed nodes so a BN that shows up in several matches is not
     # removed twice.
     removed_bn = set()
     for _, node_list in node_sets.items():
         for _, bn_node in node_list:
             if bn_node not in removed_bn:
                 self._graph.remove_node(bn_node)
                 removed_bn.add(bn_node)
示例#3
0
 def collect_layer_act_pair(self):
   """Map each conv-like node to the activation node that directly follows it.

   Searches ``self.graph`` for every [conv-like op, activation op] pair built
   from the two op lists below.

   Returns:
     dict mapping the first node of each match to the second. If a layer is
     matched more than once, the last match wins.
   """
   conv_like_ops = (NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
                    NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.CONV3D,
                    NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.CONVTRANSPOSE3D)
   act_ops = (NNDCT_OP.RELU, NNDCT_OP.RELU6, NNDCT_OP.HSWISH,
              NNDCT_OP.HSIGMOID)
   searcher = GraphSearcher(self.graph)
   matches = searcher.find_nodes_from_type([
       PatternType(pattern=[layer_op, act_op])
       for layer_op in conv_like_ops
       for act_op in act_ops
   ])
   return {
       layer: act
       for _, node_list in matches.items()
       for layer, act in node_list
   }
示例#4
0
    def finetune(self, run_fn, run_args):
        """Finetune quantized layer parameters for better quantization accuracy.

        Outline of what the body does:
          1. Returns immediately with a warning when
             ``self.quantizer.quant_mode == 2``.
          2. Runs ``run_fn(*run_args)`` with the ``nndct_quant_off`` option set
             to True, caching the outputs of INPUT nodes and of the nodes
             returned by ``collect_last_quant_nodes()``.
          3. Clears ``param_quantized`` on every module that has it and runs
             ``run_fn(*run_args)`` again with ``nndct_quant_off`` set to False.
          4. Pins ``quant_mode`` to 2 and calls ``optimize_layer`` for every
             quantizable node with parameters, threading ``net_loss`` through.
          5. Clears per-node ``fp_history`` entries, resets the
             ``param_quantized`` / ``param_saved`` flags, sets ``quant_mode``
             to 1, restores ``nndct_param_corr`` and exports the parameters.

        Args:
            run_fn: callable that drives the model; called as
                ``run_fn(*run_args)``.
            run_args: positional arguments unpacked into ``run_fn``.

        NOTE(review): nothing here is wrapped in try/finally, so an exception
        from ``run_fn`` or ``optimize_layer`` leaves ``nndct_param_corr``,
        ``nndct_quant_off`` and ``quant_mode`` at their temporary values.
        """
        # Finetuning is skipped in test mode (quant_mode == 2).
        if self.quantizer.quant_mode == 2:
            NndctScreenLogger().warning(
                f"Finetune function will be ignored in test mode!")
            return
        NndctScreenLogger().info(
            f"=>Finetuning module parameters for better quantization accuracy... "
        )

        # backup option value
        # (parameter correction is disabled during finetuning and restored at
        # the end of this method)
        opt_bak_param_corr = NndctOption.nndct_param_corr.value
        set_option_value("nndct_param_corr", 0)

        # cache input and output
        #print("**** cache input and output")
        # Hook INPUT nodes and the last quantized nodes so their outputs are
        # cached while run_fn runs with quantization switched off.
        last_quant_nodes = self.collect_last_quant_nodes()
        with torch.no_grad():
            hook_mods = []
            for node in self.graph.nodes:
                if node.op.type == NNDCT_OP.INPUT or \
                node in last_quant_nodes:
                    # (self.quantizer.configer.is_node_quantizable(node, False) and
                    # len(node.op.params) > 0):
                    hook_mods.append(node.module)

            handlers = self.hook_cache_output(hook_mods)

            set_option_value("nndct_quant_off", True)
            run_fn(*run_args)
            self.clean_hooks(handlers)

            # for mod in self.quant_model.modules():
            #   if hasattr(mod, "node") and mod.node.op.type in [NNDCT_OP.DENSE, NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.CONVTRANSPOSE2D]:
            #     self._float_weights[mod.node].append(mod.weight.detach().cpu())

        torch.cuda.empty_cache()

        # calibration to get a set of quantization steps
        #print("****calibration to get float model tensor values")
        # Clear param_quantized on every module that carries the flag before
        # the run below.
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_quantized"):
                setattr(mod, "param_quantized", False)

        # evaluation to get float model tensors
        # NOTE(review): this run has nndct_quant_off set to False, i.e.
        # quantization enabled; the comment above may be stale -- confirm.
        set_option_value("nndct_quant_off", False)
        with torch.no_grad():
            run_fn(*run_args)
        torch.cuda.empty_cache()

        #print("****Parameter finetuning")
        device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
        # Find [layer, activation] pairs so optimize_layer can see which
        # activation follows each layer.
        graph_searcher = GraphSearcher(self.graph)
        node_sets = graph_searcher.find_nodes_from_type([
            PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.RELU]),
            PatternType(pattern=[NNDCT_OP.CONV2D, NNDCT_OP.RELU6]),
            PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU]),
            PatternType(pattern=[NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU6]),
            PatternType(pattern=[NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.RELU])
        ])

        # layer node -> following activation node
        layer_act_group = {}
        for _, node_list in node_sets.items():
            for nodeset in node_list:
                conv, act = nodeset
                layer_act_group[conv] = act

        # to avoid quantization steps change among parameter finetuning
        self.quantizer.quant_mode = 2

        # One list of cached outputs per network input node (copied from
        # cached_outputs filled during the first run above).
        net_inputs = []
        for node in self.input_nodes:
            cached_net_input = [
                out for out in self.cached_outputs[node.module]
            ]
            net_inputs.append(cached_net_input)

        # last_quant_nodes = self.collect_last_quant_nodes()
        last_quant_mods = [node.module for node in last_quant_nodes]

        # Initial loss on the cached inputs; optimize_layer returns the
        # updated value for each tuned layer below.
        handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
        net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
        self.clean_hooks(handlers)
        # model.clean_hooks()
        torch.cuda.empty_cache()

        # Pair quantized and float modules positionally; keep only quantizable
        # nodes that have parameters.
        # NOTE(review): assumes both models yield modules() in the same order.
        finetune_group = {}
        # hook_mods = []
        for qmod, fmod in zip(self._quant_model.modules(),
                              self._float_model.modules()):
            if hasattr(qmod, "node"):
                if (self.quantizer.configer.is_node_quantizable(
                        qmod.node, False) and len(qmod.node.op.params) > 0):
                    finetune_group[qmod.node] = [qmod, fmod]

                    # hook_mods.append(fmod)
        # self.hook_cache_output(hook_mods, hook_type="single")

        for node, module_pair in finetune_group.items():
            # if self.quantizer.configer.is_node_quantizable(node, False) and \
            #   len(node.op.params) > 0:
            quant_layer, float_layer = module_pair
            # Only the first parent is hooked; its output is this layer's
            # input (assumes a single-input layer -- TODO confirm).
            pn_node = self.graph.parents(node)[0]
            handlers = self.hook_cache_output([pn_node.module],
                                              hook_type="single")
            layer_inputs = []
            with torch.no_grad():
                for input_args in zip(*net_inputs):
                    # Only tensor arguments are moved to `device` and forwarded;
                    # non-tensor cached inputs are dropped.
                    new_input_args = []
                    for ip in input_args:
                        if isinstance(ip, torch.Tensor):
                            new_input_args.append(ip.to(device))
                    _ = self.quant_model(*new_input_args)

                    layer_inputs.append(
                        self.cached_output[pn_node.module].detach().cpu())
            self.clean_hooks(handlers)
            del self.cached_output[pn_node.module]
            #print(f"Tuning {node.name}")
            net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                           layer_act_group, net_inputs,
                                           net_loss, last_quant_mods, device)
            del layer_inputs
            torch.cuda.empty_cache()

        # recover quantizer status
        for node in self.graph.nodes:
            for _, fp_history in self.quantizer.fp_history.items():
                if node.name in fp_history:
                    fp_history[node.name].clear()
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_quantized"):
                setattr(mod, "param_quantized", False)
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_saved"):
                setattr(mod, "param_saved", False)
        self.quantizer.quant_mode = 1
        set_option_value("nndct_param_corr", opt_bak_param_corr)

        # export finetuned parameters
        self.quantizer.export_param()
示例#5
0
 def FuseBnToConv(self):
     """Find conv-like -> BATCH_NORM chains in ``self._graph`` and drop the BN.

     For every conv-like op type two chains are matched, all sharing a single
     ``ConvBnHandler`` as the pattern ``action``:
       * [conv-like, BATCH_NORM]
       * [conv-like, CONCAT, BATCH_NORM]
     All direct patterns come first, then all CONCAT patterns, each group in
     the order of ``conv_like_ops``. A matched BATCH_NORM node is removed from
     ``self._graph`` only when its ``merged`` flag is truthy (presumably set
     by the handler after a successful fold -- confirm), and at most once.
     """
     # find fusable batchnorm node
     fuse_bn_handler = ConvBnHandler()
     graph_searcher = GraphSearcher(self._graph)
     conv_like_ops = [
         NNDCT_OP.CONV2D,
         NNDCT_OP.DEPTHWISE_CONV2D,
         NNDCT_OP.CONVTRANSPOSE2D,
         NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D,
         NNDCT_OP.CONV3D,
         NNDCT_OP.DEPTHWISE_CONV3D,
         NNDCT_OP.CONVTRANSPOSE3D,
         NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D,
     ]
     # Ops allowed between the conv-like node and the BATCH_NORM node.
     bridges = [[], [NNDCT_OP.CONCAT]]
     node_sets = graph_searcher.find_nodes_from_type([
         PatternType(pattern=[conv_op, *bridge, NNDCT_OP.BATCH_NORM],
                     action=fuse_bn_handler)
         for bridge in bridges
         for conv_op in conv_like_ops
     ])
     # A BN can appear in several matches (e.g. behind a CONCAT fed by several
     # convs); remove it only once.
     removed_bn = set()
     for _, node_list in node_sets.items():
         for nodeset in node_list:
             # BATCH_NORM is the last node of every pattern above.
             bn_node = nodeset[-1]
             if bn_node.merged and bn_node not in removed_bn:
                 self._graph.remove_node(bn_node)
                 removed_bn.add(bn_node)
示例#6
0
    def finetune(self, run_fn, run_args):
        """Fast-finetune quantized layer parameters for better accuracy.

        Outline of what the body does:
          1. Returns immediately with a warning when
             ``self.quantizer.quant_mode == 2``.
          2. Runs ``run_fn(*run_args)`` with ``nndct_quant_off`` set to True,
             caching outputs of INPUT nodes and of the nodes returned by
             ``collect_last_quant_nodes()``, and records memory usage of the
             other conv-like layers. If the estimated cost exceeds 80% of the
             available host memory, the temporary option overrides are undone
             and the method returns without finetuning.
          3. Clears ``param_quantized`` on every module that has it and runs
             ``run_fn(*run_args)`` again with ``nndct_quant_off`` set to False.
          4. Pins ``quant_mode`` to 2 and calls ``optimize_layer`` for every
             quantizable node with parameters, threading ``net_loss`` through.
          5. Clears per-node ``config_history`` entries, resets the
             ``param_quantized`` / ``param_saved`` flags, sets ``quant_mode``
             to 1, restores ``nndct_param_corr`` and exports the parameters.

        Args:
            run_fn: callable that drives the model; called as
                ``run_fn(*run_args)``.
            run_args: positional arguments unpacked into ``run_fn``.

        NOTE(review): an exception from ``run_fn`` or ``optimize_layer`` still
        leaves the temporary option values / ``quant_mode`` in place.
        """
        if self.quantizer.quant_mode == 2:
            NndctScreenLogger().warning(
                "Finetune function will be ignored in test mode!")
            return
        NndctScreenLogger().info(
            "=>Preparing data for fast finetuning module parameters ...")

        # Host memory status in GB: the second line of `free -t -m` is split,
        # the first numeric column is taken as total and the last as available.
        # NOTE(review): Linux-only, and assumes a `free` whose last "Mem:"
        # column is "available" (procps-ng) -- confirm on target systems.
        with os.popen('free -t -m') as free_output:
            mem_line = free_output.readlines()[1]
        total_m, *_, available_m = [
            int(field) / 1024 for field in mem_line.split()[1:]
        ]
        NndctScreenLogger().info(
            f"Mem status(total mem: {total_m:.2f}G, available mem: {available_m:.2f}G)."
        )

        # backup option value; parameter correction is disabled while
        # finetuning and restored on every normal exit path below.
        opt_bak_param_corr = NndctOption.nndct_param_corr.value
        set_option_value("nndct_param_corr", 0)

        # cache input and output: run with quantization off, caching outputs
        # of INPUT / last quantized nodes and monitoring conv-like layers.
        last_quant_nodes = self.collect_last_quant_nodes()
        with torch.no_grad():
            cache_layers = []
            monitor_layers = []
            for node in self.graph.nodes:
                if node.op.type == NNDCT_OP.INPUT or node in last_quant_nodes:
                    cache_layers.append(node.module)
                elif self.quantizer.configer.is_conv_like(node):
                    monitor_layers.append(node.module)

            monitor_handlers = self.hook_memory_monitor(monitor_layers)
            cache_handlers = self.hook_cache_output(cache_layers,
                                                    monitor_mem=True)
            set_option_value("nndct_quant_off", True)
            run_fn(*run_args)
            # memory statistics: all cached layers plus twice the largest
            # remaining `_mem_count` entry.
            total_memory_cost = 0.0
            for layer in cache_layers:
                total_memory_cost += self._mem_count[layer]
                del self._mem_count[layer]

            # `default` avoids ValueError when no other layer was recorded.
            total_memory_cost += 2 * max(self._mem_count.values(),
                                         default=0.0)

            self.clean_hooks(monitor_handlers + cache_handlers)

        torch.cuda.empty_cache()
        NndctScreenLogger().info(
            f"Mem cost by fast finetuning: {total_memory_cost:.2f}G.")
        if total_memory_cost > 0.8 * available_m:
            NndctScreenLogger().warning(
                "There is not enough memory for fast finetuning and this process will be ignored! Try to use a smaller calibration dataset."
            )
            # Undo the option overrides made above; otherwise quantization
            # would stay switched off and param correction disabled for
            # everything that runs after this method.
            set_option_value("nndct_quant_off", False)
            set_option_value("nndct_param_corr", opt_bak_param_corr)
            return

        # calibration to get a set of quantization steps: clear the
        # param_quantized flag, then run with quantization enabled.
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_quantized"):
                setattr(mod, "param_quantized", False)

        set_option_value("nndct_quant_off", False)
        with torch.no_grad():
            run_fn(*run_args)
        torch.cuda.empty_cache()

        NndctScreenLogger().info(
            "=>Fast finetuning module parameters for better quantization accuracy..."
        )
        device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)

        # Every [tuning layer, activation] pair, in tuning-op-major order.
        tuning_ops = [
            NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
            NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.CONV3D,
            NNDCT_OP.DEPTHWISE_CONV3D, NNDCT_OP.CONVTRANSPOSE3D
        ]
        act_ops = [
            NNDCT_OP.HSWISH, NNDCT_OP.HSIGMOID, NNDCT_OP.RELU, NNDCT_OP.RELU6
        ]
        graph_searcher = GraphSearcher(self.graph)
        node_sets = graph_searcher.find_nodes_from_type([
            PatternType(pattern=[tuning_op, act_op])
            for tuning_op in tuning_ops
            for act_op in act_ops
        ])

        # layer node -> following activation node
        layer_act_group = {}
        for _, node_list in node_sets.items():
            for nodeset in node_list:
                conv, act = nodeset
                layer_act_group[conv] = act

        # to avoid quantization steps change among parameter finetuning
        self.quantizer.quant_mode = 2

        # One list of cached outputs per network input node.
        net_inputs = [
            list(self.cached_outputs[node.module]) for node in self.input_nodes
        ]

        last_quant_mods = [node.module for node in last_quant_nodes]

        # Initial loss on the cached inputs; optimize_layer returns the
        # updated value for each tuned layer below.
        handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
        net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
        self.clean_hooks(handlers)
        torch.cuda.empty_cache()

        # Pair quantized and float modules positionally; keep only quantizable
        # nodes that have parameters.
        # NOTE(review): assumes both models yield modules() in the same order.
        finetune_group = {}
        for qmod, fmod in zip(self._quant_model.modules(),
                              self._float_model.modules()):
            if hasattr(qmod, "node"):
                if (self.quantizer.configer.is_node_quantizable(
                        qmod.node, False) and len(qmod.node.op.params) > 0):
                    finetune_group[qmod.node] = [qmod, fmod]

        for node, module_pair in tqdm(finetune_group.items(),
                                      total=len(finetune_group)):
            _, float_layer = module_pair
            # Only the first parent is hooked; its output is this layer's
            # input (assumes a single-input layer -- TODO confirm).
            pn_node = self.graph.parents(node)[0]
            handlers = self.hook_cache_output([pn_node.module],
                                              hook_type="single")
            layer_inputs = []
            with torch.no_grad():
                for input_args in zip(*net_inputs):
                    # Only tensor arguments are moved to `device` and forwarded;
                    # non-tensor cached inputs are dropped.
                    new_input_args = []
                    for ip in input_args:
                        if isinstance(ip, torch.Tensor):
                            new_input_args.append(ip.to(device))
                    _ = self.quant_model(*new_input_args)

                    layer_inputs.append(
                        self.cached_output[pn_node.module].detach().cpu())
            self.clean_hooks(handlers)
            del self.cached_output[pn_node.module]
            net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                           layer_act_group, net_inputs,
                                           net_loss, last_quant_mods, device)
            del layer_inputs
            torch.cuda.empty_cache()

        # recover quantizer status
        for node in self.graph.nodes:
            for _, config_history in self.quantizer.config_history.items():
                if node.name in config_history:
                    config_history[node.name].clear()
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_quantized"):
                setattr(mod, "param_quantized", False)
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_saved"):
                setattr(mod, "param_saved", False)
        self.quantizer.quant_mode = 1
        set_option_value("nndct_param_corr", opt_bak_param_corr)

        NndctScreenLogger().info("=>Export fast finetuned parameters ...")
        # export finetuned parameters
        self.quantizer.export_param()
示例#7
0
    def _opt_raw_graph(self, raw_graph):
        """Run the raw-graph rewrite passes over ``raw_graph`` in a fixed order.

        Each pass gets a fresh ``find_nodes_from_type`` query issued right
        before it runs. The select/strided_slice merge is repeated until
        ``OptPass.merge_select_to_strided_slice`` returns a falsy value.

        Returns:
            The same ``raw_graph`` object that was handed to every pass.
        """
        searcher = GraphSearcher(raw_graph)

        def run_pass(opt_pass, *patterns):
            # One PatternType per pattern list; match, then hand hits to the
            # pass and propagate its return value.
            matches = searcher.find_nodes_from_type(
                [PatternType(pattern=pattern) for pattern in patterns])
            return opt_pass(raw_graph, matches)

        # torch 1.8.x will generate redundant type_as op before element-wise add.
        torch_version = get_torch_version()
        if 180 <= torch_version < 190:
            run_pass(OptPass.merge_internal_type_as, ["type_as", "add"])

        # classification
        run_pass(OptPass.merge_param_transpose_with_addmm, ["t", "addmm"],
                 ["t", "matmul"])

        # shufflenet
        run_pass(OptPass.unpack_ListUnpack_op, ["ListUnpack"], ["TupleUnpack"])
        run_pass(OptPass.slice_to_strided_slice, ["slice"])

        # yolo_v3
        run_pass(OptPass.select_to_slice_inplace_copy, ["select", "copy_"])

        # FADnet: keep merging until the pass reports no more work.
        while run_pass(OptPass.merge_select_to_strided_slice,
                       ["select", "strided_slice"]):
            pass
        run_pass(OptPass.merge_consecutive_strided_slice,
                 ["strided_slice", "strided_slice"])

        # 3d pointpillar
        run_pass(OptPass.stride_slice_to_index_inplace_put,
                 ["strided_slice", "index_put_"])
        run_pass(OptPass.create_stride_slice_inplace_copy,
                 ["strided_slice", "copy_"])

        # nd(>2) linear (JIRA 2646)
        run_pass(OptPass.merge_matmul_with_add, ["matmul", "add"],
                 ["matmul", "add_"])

        run_pass(OptPass.pack_ListConstruct_op, ["ListConstruct"],
                 ["TupleConstruct"])
        run_pass(OptPass.strip_reduantant_tensors_in_embedding_bag,
                 ["embedding_bag"])
        run_pass(OptPass.merge_empty_with_zero, ["empty", "zero_"])

        # Redundant-view removal (FADnet) must come after the strided_slice
        # merges above.
        run_pass(OptPass.remove_reduantant_view, ["view"])

        run_pass(OptPass.transform_const_scalar_to_const_tensor, ["add"],
                 ["sub"], ["mul"], ["div"])

        return raw_graph
示例#8
0
    def _opt_raw_graph(self, raw_graph):
        """Run the raw-graph rewrite passes over ``raw_graph`` in a fixed order.

        Each pass gets a fresh ``find_nodes_from_type`` query issued right
        before it runs, and every OptPass method is invoked as
        ``opt_pass(self, raw_graph, node_sets)``.

        Returns:
            The same ``raw_graph`` object that was handed to every pass.
        """
        searcher = GraphSearcher(raw_graph)

        def run_pass(opt_pass, *patterns):
            # One PatternType per pattern list; match, then hand hits to the
            # pass together with this optimizer instance.
            matches = searcher.find_nodes_from_type(
                [PatternType(pattern=pattern) for pattern in patterns])
            opt_pass(self, raw_graph, matches)

        # shufflenet
        run_pass(OptPass.unpack_ListUnpack_op, ["ListUnpack"], ["TupleUnpack"])
        run_pass(OptPass.slice_to_strided_slice, ["slice"])

        # yolo_v3
        run_pass(OptPass.select_to_slice_inplace_copy, ["select", "copy_"])

        # 3d pointpillar
        run_pass(OptPass.stride_slice_to_index_inplace_put,
                 ["strided_slice", "index_put_"])
        run_pass(OptPass.create_stride_slice_inplace_copy,
                 ["strided_slice", "copy_"])

        # nd(>2) linear (JIRA 2646)
        run_pass(OptPass.merge_matmul_with_add, ["matmul", "add"],
                 ["matmul", "add_"])

        run_pass(OptPass.pack_ListConstruct_op, ["ListConstruct"],
                 ["TupleConstruct"])
        run_pass(OptPass.strip_reduantant_tensors_in_embedding_bag,
                 ["embedding_bag"])

        return raw_graph